WIP: restructure and tensor/kind.go
This commit is contained in:
parent
5f167e3b67
commit
51d5d127dc
|
@ -1,9 +1,9 @@
|
||||||
package torch
|
package gorch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
lib "github.com/sugarme/gotch/torch/libtch"
|
lib "github.com/sugarme/gotch/libtch"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Device struct {
|
type Device struct {
|
|
@ -1,129 +1,15 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
//#include <stdlib.h>
|
|
||||||
import "C"
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
t "github.com/sugarme/gotch/torch/libtch"
|
tensor "github.com/sugarme/gotch/tensor"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Tensor struct {
|
|
||||||
c_tensor *t.C_tensor
|
|
||||||
}
|
|
||||||
|
|
||||||
func FnOfSlice() (retVal Tensor, err error) {
|
|
||||||
|
|
||||||
data := []int{1, 2, 3, 4, 5, 6}
|
|
||||||
nflattened := len(data)
|
|
||||||
dtype := 3 // Kind.Int
|
|
||||||
eltSizeInBytes := 4 // Element Size in Byte for Int dtype
|
|
||||||
|
|
||||||
nbytes := eltSizeInBytes * int(uintptr(nflattened))
|
|
||||||
|
|
||||||
dataPtr := C.malloc(C.size_t(nbytes))
|
|
||||||
|
|
||||||
// Recall: 1 << 30 = 1 * 2 * 30
|
|
||||||
// Ref. See more at https://stackoverflow.com/questions/48756732
|
|
||||||
dataSlice := (*[1 << 30]byte)(dataPtr)[:nbytes:nbytes]
|
|
||||||
|
|
||||||
buf := bytes.NewBuffer(dataSlice[:0:nbytes])
|
|
||||||
|
|
||||||
encodeTensor(buf, reflect.ValueOf(data), []int64{1})
|
|
||||||
|
|
||||||
c_tensor := t.AtTensorOfData(dataPtr, int64(nflattened), 1, uint(eltSizeInBytes), int32(dtype))
|
|
||||||
|
|
||||||
retVal = Tensor{c_tensor}
|
|
||||||
|
|
||||||
return retVal, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func numElements(shape []int) int {
|
|
||||||
n := 1
|
|
||||||
for _, d := range shape {
|
|
||||||
n *= d
|
|
||||||
}
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
_, err := tensor.FnOfSlice()
|
||||||
t := t.NewTensor()
|
|
||||||
|
|
||||||
fmt.Printf("Type of t: %v\n", reflect.TypeOf(t))
|
|
||||||
|
|
||||||
res, err := FnOfSlice()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println(res)
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
|
||||||
switch v.Kind() {
|
|
||||||
case reflect.Bool:
|
|
||||||
b := byte(0)
|
|
||||||
if v.Bool() {
|
|
||||||
b = 1
|
|
||||||
}
|
|
||||||
if err := w.WriteByte(b); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
|
||||||
if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
case reflect.Array, reflect.Slice:
|
|
||||||
// If current dimension is a slice, verify that it has the expected size
|
|
||||||
// Go's type system makes that guarantee for arrays.
|
|
||||||
if v.Kind() == reflect.Slice {
|
|
||||||
expected := int(shape[0])
|
|
||||||
if v.Len() != expected {
|
|
||||||
return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Optimisation: if only one dimension is left we can use binary.Write() directly for this slice
|
|
||||||
if len(shape) == 1 && v.Len() > 0 {
|
|
||||||
switch v.Index(0).Kind() {
|
|
||||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
|
||||||
return binary.Write(w, nativeEndian, v.Interface())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
subShape := shape[1:]
|
|
||||||
for i := 0; i < v.Len(); i++ {
|
|
||||||
err := encodeTensor(w, v.Index(i), subShape)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unsupported type %v", v.Type())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var nativeEndian binary.ByteOrder
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
buf := [2]byte{}
|
|
||||||
*(*uint16)(unsafe.Pointer(&buf[0])) = uint16(0xABCD)
|
|
||||||
|
|
||||||
switch buf {
|
|
||||||
case [2]byte{0xCD, 0xAB}:
|
|
||||||
nativeEndian = binary.LittleEndian
|
|
||||||
case [2]byte{0xAB, 0xCD}:
|
|
||||||
nativeEndian = binary.BigEndian
|
|
||||||
default:
|
|
||||||
panic("Could not determine native endianness.")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
69
libtch/tensor.go
Normal file
69
libtch/tensor.go
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
package libtch
|
||||||
|
|
||||||
|
//#include "stdbool.h"
|
||||||
|
//#include "torch_api.h"
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// type c_void unsafe.Pointer
|
||||||
|
// type size_t uint
|
||||||
|
// type c_int int32
|
||||||
|
|
||||||
|
type C_tensor struct {
|
||||||
|
private unsafe.Pointer
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewTensor() *C_tensor {
|
||||||
|
t := C.at_new_tensor()
|
||||||
|
return &C_tensor{private: unsafe.Pointer(t)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_bytes uint, kind int) *C_tensor {
|
||||||
|
|
||||||
|
// just get pointer of the first element of shape
|
||||||
|
c_dims := (*C.int64_t)(unsafe.Pointer(&dims[0]))
|
||||||
|
c_ndims := *(*C.size_t)(unsafe.Pointer(&ndims))
|
||||||
|
c_elt_size_in_bytes := *(*C.size_t)(unsafe.Pointer(&elt_size_in_bytes))
|
||||||
|
c_kind := *(*C.int)(unsafe.Pointer(&kind))
|
||||||
|
|
||||||
|
// c_dims := (*C.long)(unsafe.Pointer(uintptr(dims)))
|
||||||
|
// c_ndims := *(*C.size_t)(unsafe.Pointer(uintptr(ndims)))
|
||||||
|
// c_elt_size_in_bytes := *(*C.size_t)(unsafe.Pointer(uintptr(elt_size_in_bytes)))
|
||||||
|
// c_kind := *(*C.int)(unsafe.Pointer(uintptr(kind)))
|
||||||
|
|
||||||
|
// t is of type `unsafe.Pointer` in Go and `*void` in C
|
||||||
|
t := C.at_tensor_of_data(vs, c_dims, c_ndims, c_elt_size_in_bytes, c_kind)
|
||||||
|
fmt.Printf("t type: %v\n", reflect.TypeOf(t).Kind())
|
||||||
|
fmt.Printf("1. C.tensor AtTensorOfData returned from C call: %v\n", t)
|
||||||
|
// Keep C pointer value tin Go struct
|
||||||
|
cTensorPtrVal := unsafe.Pointer(t)
|
||||||
|
fmt.Printf("2. cTensorPtrVal: %v\n", cTensorPtrVal)
|
||||||
|
|
||||||
|
var retVal *C_tensor
|
||||||
|
retVal = &C_tensor{private: cTensorPtrVal}
|
||||||
|
fmt.Printf("3. C_tensor.private: %v\n", (*retVal).private)
|
||||||
|
|
||||||
|
// test call C.at_print to print out tensor
|
||||||
|
// C.at_print(*(*C.tensor)(unsafe.Pointer(&t)))
|
||||||
|
AtPrint(retVal)
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func AtPrint(t *C_tensor) {
|
||||||
|
fmt.Printf("4. C_tensor.private AtPrint: %v\n", (*t).private)
|
||||||
|
cTensor := (C.tensor)((*t).private)
|
||||||
|
fmt.Printf("5. C.tensor AtPrint: %v\n", cTensor)
|
||||||
|
|
||||||
|
C.at_print(cTensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
func AtDataPtr(t *C_tensor) unsafe.Pointer {
|
||||||
|
cTensor := (C.tensor)((*t).private)
|
||||||
|
return C.at_data_ptr(cTensor)
|
||||||
|
}
|
102
tensor/kind.go
Normal file
102
tensor/kind.go
Normal file
|
@ -0,0 +1,102 @@
|
||||||
|
package tensor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"reflect"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CInt is equal to C type int. Go type is int32
|
||||||
|
type CInt = int32
|
||||||
|
|
||||||
|
// Kind is 'enum' like type. It represents different kind of elements
|
||||||
|
// that a Tensor can hold.
|
||||||
|
type Kind int
|
||||||
|
|
||||||
|
const (
|
||||||
|
Uint8 Kind = iota // 0
|
||||||
|
Int8 // 1
|
||||||
|
Int16 // 2
|
||||||
|
Int // 3
|
||||||
|
Int64 // 4
|
||||||
|
Half // 5
|
||||||
|
Float // 6
|
||||||
|
Double // 7
|
||||||
|
ComplexHalf // 8
|
||||||
|
ComplexFloat // 9
|
||||||
|
ComplexDouble // 10
|
||||||
|
Bool // 11
|
||||||
|
)
|
||||||
|
|
||||||
|
// ToCInt converts Kind to CInt type value which is `C int`
|
||||||
|
func (k Kind) ToCInt() CInt {
|
||||||
|
return CInt(k)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OfCInt converts a value of type CInt to Kind type value
|
||||||
|
func (k Kind) OfCInt(v CInt) Kind {
|
||||||
|
switch v {
|
||||||
|
case 0:
|
||||||
|
return Uint8
|
||||||
|
case 1:
|
||||||
|
return Int8
|
||||||
|
case 2:
|
||||||
|
return Int16
|
||||||
|
case 3:
|
||||||
|
return Int
|
||||||
|
case 4:
|
||||||
|
return Int64
|
||||||
|
case 5:
|
||||||
|
return Half
|
||||||
|
case 6:
|
||||||
|
return Float
|
||||||
|
case 7:
|
||||||
|
return Double
|
||||||
|
case 8:
|
||||||
|
return ComplexHalf
|
||||||
|
case 9:
|
||||||
|
return ComplexFloat
|
||||||
|
case 10:
|
||||||
|
return ComplexDouble
|
||||||
|
case 11:
|
||||||
|
return Bool
|
||||||
|
default:
|
||||||
|
log.Fatalf("Unexpected kind %v\n", v)
|
||||||
|
}
|
||||||
|
return Kind(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// EltSizeInBytes converts a Kind value to number of bytes
|
||||||
|
// This is a ELement Size In Byte in Libtorch.
|
||||||
|
// Has it been deprecated?
|
||||||
|
func (k Kind) EltSizeInBytes() uint {
|
||||||
|
switch {
|
||||||
|
case k.ToCInt() == int32(Uint8):
|
||||||
|
return 1
|
||||||
|
case k.ToCInt() == int32(Int8):
|
||||||
|
return 1
|
||||||
|
case k.ToCInt() == int32(Int16):
|
||||||
|
return 2
|
||||||
|
case k.ToCInt() == int32(Int):
|
||||||
|
return 4
|
||||||
|
case k.ToCInt() == int32(Int64):
|
||||||
|
return 8
|
||||||
|
case k.ToCInt() == int32(Half):
|
||||||
|
return 2
|
||||||
|
case k.ToCInt() == int32(Float):
|
||||||
|
return 4
|
||||||
|
case k.ToCInt() == int32(Double):
|
||||||
|
return 8
|
||||||
|
case k.ToCInt() == int32(ComplexHalf):
|
||||||
|
return 4
|
||||||
|
case k.ToCInt() == int32(ComplexDouble):
|
||||||
|
return 16
|
||||||
|
case k.ToCInt() == int32(Bool):
|
||||||
|
return 1
|
||||||
|
default:
|
||||||
|
log.Fatalf("Unreachable")
|
||||||
|
}
|
||||||
|
return uint(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: continue with devices...
|
191
tensor/tensor.go
Normal file
191
tensor/tensor.go
Normal file
|
@ -0,0 +1,191 @@
|
||||||
|
package tensor
|
||||||
|
|
||||||
|
//#include <stdlib.h>
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
// "runtime"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
lib "github.com/sugarme/gotch/libtch"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Tensor struct {
|
||||||
|
ctensor *t.C_tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
var nativeEndian binary.ByteOrder
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
buf := [2]byte{}
|
||||||
|
*(*uint16)(unsafe.Pointer(&buf[0])) = uint16(0xABCD)
|
||||||
|
|
||||||
|
switch buf {
|
||||||
|
case [2]byte{0xCD, 0xAB}:
|
||||||
|
nativeEndian = binary.LittleEndian
|
||||||
|
case [2]byte{0xAB, 0xCD}:
|
||||||
|
nativeEndian = binary.BigEndian
|
||||||
|
default:
|
||||||
|
panic("Could not determine native endianness.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FnOfSlice creates tensor from a slice data
|
||||||
|
func FnOfSlice() (retVal Tensor, err error) {
|
||||||
|
|
||||||
|
data := []int{0, 0, 0, 0}
|
||||||
|
shape := []int64{int64(len(data))}
|
||||||
|
nflattened := numElements(shape)
|
||||||
|
dtype := 3 // Kind.Int
|
||||||
|
eltSizeInBytes := 4 // Element Size in Byte for Int dtype
|
||||||
|
|
||||||
|
nbytes := eltSizeInBytes * int(uintptr(nflattened))
|
||||||
|
|
||||||
|
// NOTE: dataPrt is type of `*void` in C or type of `unsafe.Pointer` in Go
|
||||||
|
dataPtr := C.malloc(C.size_t(nbytes))
|
||||||
|
|
||||||
|
// Recall: 1 << 30 = 1 * 2 * 30
|
||||||
|
// Ref. See more at https://stackoverflow.com/questions/48756732
|
||||||
|
dataSlice := (*[1 << 30]byte)(dataPtr)[:nbytes:nbytes]
|
||||||
|
|
||||||
|
buf := bytes.NewBuffer(dataSlice[:0:nbytes])
|
||||||
|
|
||||||
|
encodeTensor(buf, reflect.ValueOf(data), shape)
|
||||||
|
|
||||||
|
c_tensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(dtype))
|
||||||
|
|
||||||
|
retVal = Tensor{c_tensor}
|
||||||
|
|
||||||
|
// Read back created tensor values by C libtorch
|
||||||
|
readDataPtr := lib.AtDataPtr(retVal.c_tensor)
|
||||||
|
readDataSlice := (*[1 << 30]byte)(readDataPtr)[:nbytes:nbytes]
|
||||||
|
// typ := typeOf(dtype, shape)
|
||||||
|
typ := reflect.TypeOf(int32(0)) // C. type `int` ~ Go type `int32`
|
||||||
|
val := reflect.New(typ)
|
||||||
|
if err := decodeTensor(bytes.NewReader(readDataSlice), shape, typ, val); err != nil {
|
||||||
|
panic(fmt.Sprintf("unable to decode Tensor of type %v and shape %v - %v", dtype, shape, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
tensorData := reflect.Indirect(val).Interface()
|
||||||
|
|
||||||
|
fmt.Println("%v", tensorData)
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func numElements(shape []int64) int64 {
|
||||||
|
n := int64(1)
|
||||||
|
for _, d := range shape {
|
||||||
|
n *= d
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
b := byte(0)
|
||||||
|
if v.Bool() {
|
||||||
|
b = 1
|
||||||
|
}
|
||||||
|
if err := w.WriteByte(b); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
||||||
|
if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Array, reflect.Slice:
|
||||||
|
// If current dimension is a slice, verify that it has the expected size
|
||||||
|
// Go's type system makes that guarantee for arrays.
|
||||||
|
if v.Kind() == reflect.Slice {
|
||||||
|
expected := int(shape[0])
|
||||||
|
if v.Len() != expected {
|
||||||
|
return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optimisation: if only one dimension is left we can use binary.Write() directly for this slice
|
||||||
|
if len(shape) == 1 && v.Len() > 0 {
|
||||||
|
switch v.Index(0).Kind() {
|
||||||
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
||||||
|
return binary.Write(w, nativeEndian, v.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
subShape := shape[1:]
|
||||||
|
for i := 0; i < v.Len(); i++ {
|
||||||
|
err := encodeTensor(w, v.Index(i), subShape)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported type %v", v.Type())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decodeTensor decodes the Tensor from the buffer to ptr using the format
|
||||||
|
// specified in c_api.h. Use stringDecoder for String tensors.
|
||||||
|
func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error {
|
||||||
|
switch typ.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
b, err := r.ReadByte()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ptr.Elem().SetBool(b == 1)
|
||||||
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
||||||
|
if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Slice:
|
||||||
|
val := reflect.Indirect(ptr)
|
||||||
|
val.Set(reflect.MakeSlice(typ, int(shape[0]), int(shape[0])))
|
||||||
|
|
||||||
|
// Optimization: if only one dimension is left we can use binary.Read() directly for this slice
|
||||||
|
if len(shape) == 1 && val.Len() > 0 {
|
||||||
|
switch val.Index(0).Kind() {
|
||||||
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
||||||
|
return binary.Read(r, nativeEndian, val.Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < val.Len(); i++ {
|
||||||
|
if err := decodeTensor(r, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported type %v", typ)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// // typeOf converts from a DType and Shape to the equivalent Go type.
|
||||||
|
// func typeOf(dt DType, shape []int64) reflect.Type {
|
||||||
|
// var ret reflect.Type
|
||||||
|
// for _, t := range types {
|
||||||
|
// if dt == DType(t.dataType) {
|
||||||
|
// ret = t.typ
|
||||||
|
// break
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// if ret == nil {
|
||||||
|
// // TODO get tensor name
|
||||||
|
// panic(fmt.Sprintf("Unsupported DType %d", int(dt)))
|
||||||
|
// }
|
||||||
|
// for range shape {
|
||||||
|
// ret = reflect.SliceOf(ret)
|
||||||
|
// }
|
||||||
|
// return ret
|
||||||
|
// }
|
|
@ -1,55 +0,0 @@
|
||||||
package torch
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Kind struct {
|
|
||||||
reflect.Type
|
|
||||||
}
|
|
||||||
|
|
||||||
type CInt = int32
|
|
||||||
|
|
||||||
/*
|
|
||||||
* Uint8,
|
|
||||||
* Int8,
|
|
||||||
* Int16,
|
|
||||||
* Int,
|
|
||||||
* Int64,
|
|
||||||
* Half,
|
|
||||||
* Float,
|
|
||||||
* Double,
|
|
||||||
* ComplexHalf,
|
|
||||||
* ComplexFloat,
|
|
||||||
* ComplexDouble,
|
|
||||||
* Bool,
|
|
||||||
* */
|
|
||||||
|
|
||||||
// TODO: recode these types
|
|
||||||
|
|
||||||
var (
|
|
||||||
Bool = Kind{reflect.TypeOf(true)}
|
|
||||||
Int = Kind{reflect.TypeOf(int(1))}
|
|
||||||
Int8 = Kind{reflect.TypeOf(int8(1))}
|
|
||||||
Int16 = Kind{reflect.TypeOf(int16(1))}
|
|
||||||
Int32 = Kind{reflect.TypeOf(int32(1))}
|
|
||||||
Int64 = Kind{reflect.TypeOf(int64(1))}
|
|
||||||
Uint = Kind{reflect.TypeOf(uint(1))}
|
|
||||||
Uint8 = Kind{reflect.TypeOf(uint8(1))}
|
|
||||||
Uint16 = Kind{reflect.TypeOf(uint16(1))}
|
|
||||||
Uint32 = Kind{reflect.TypeOf(uint32(1))}
|
|
||||||
Uint64 = Kind{reflect.TypeOf(uint64(1))}
|
|
||||||
Float32 = Kind{reflect.TypeOf(float32(1))}
|
|
||||||
Float64 = Kind{reflect.TypeOf(float64(1))}
|
|
||||||
Complex64 = Kind{reflect.TypeOf(complex64(1))}
|
|
||||||
Complex128 = Kind{reflect.TypeOf(complex128(1))}
|
|
||||||
String = Kind{reflect.TypeOf("")}
|
|
||||||
|
|
||||||
// aliases
|
|
||||||
Byte = Uint8
|
|
||||||
|
|
||||||
// extras
|
|
||||||
Uintptr = Kind{reflect.TypeOf(uintptr(0))}
|
|
||||||
UnsafePointer = Kind{reflect.TypeOf(unsafe.Pointer(&Uintptr))}
|
|
||||||
)
|
|
|
@ -1,32 +0,0 @@
|
||||||
package libtch
|
|
||||||
|
|
||||||
//#include "stdbool.h"
|
|
||||||
//#include "torch_api.h"
|
|
||||||
import "C"
|
|
||||||
|
|
||||||
import (
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
type c_void unsafe.Pointer
|
|
||||||
type size_t uint
|
|
||||||
type c_int int32
|
|
||||||
|
|
||||||
type C_tensor struct {
|
|
||||||
private uint8
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewTensor() *C_tensor {
|
|
||||||
t := C.at_new_tensor()
|
|
||||||
return &C_tensor{private: *(*uint8)(unsafe.Pointer(&t))}
|
|
||||||
}
|
|
||||||
|
|
||||||
func AtTensorOfData(vs unsafe.Pointer, dims int64, ndims uint, elt_size_in_bytes uint, kind int32) *C_tensor {
|
|
||||||
c_dims := (*C.long)(unsafe.Pointer(&dims))
|
|
||||||
c_ndims := *(*C.ulong)(unsafe.Pointer(&ndims))
|
|
||||||
c_elt_size_in_bytes := *(*C.ulong)(unsafe.Pointer(&elt_size_in_bytes))
|
|
||||||
c_kind := *(*C.int)(unsafe.Pointer(&kind))
|
|
||||||
|
|
||||||
t := C.at_tensor_of_data(vs, c_dims, c_ndims, c_elt_size_in_bytes, c_kind)
|
|
||||||
return &C_tensor{private: *(*uint8)(unsafe.Pointer(&t))}
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user