diff --git a/dtype.go b/dtype.go new file mode 100644 index 0000000..ad3a39a --- /dev/null +++ b/dtype.go @@ -0,0 +1,250 @@ +package gotch + +import ( + "fmt" + // "log" + "reflect" +) + +// CInt is equal to C type int. Go type is int32 +type CInt = int32 + +// DType represents different kind of element that a tensor can hold. +// It has an embedded `reflect.Type` for type reflection. +type DType struct { + reflect.Type +} + +/* + * // Custom-made Float16 as not exist in Go + * // Ref: https://github.com/golang/go/issues/32022 + * type GoFloat16 = int16 // not implemented yet + * type GoComplexHalf = interface{} // not implemented yet! + * */ + +// TODO: double check these Torch DType to Go type +var ( + Uint8 DType = DType{reflect.TypeOf(uint8(1))} // 0 + Int8 DType = DType{reflect.TypeOf(int8(1))} // 1 + Int16 DType = DType{reflect.TypeOf(int16(1))} // 2 + Int DType = DType{reflect.TypeOf(int32(1))} // 3 + Int64 DType = DType{reflect.TypeOf(int64(1))} // 4 + // Half DType = DType{reflect.TypeOf(GoFloat16(1))} // 5 + Float DType = DType{reflect.TypeOf(float32(1))} // 6 + Double DType = DType{reflect.TypeOf(float64(1))} // 7 + // ComplexHalf DType = DType{reflect.TypeOf(GoComplexHalf(1))} // 8 + // ComplexFloat DType = DType{reflect.TypeOf(complex64(1))} // 9 + // ComplexDouble DType = DType{reflect.TypeOf(complex128(1))} // 10 + Bool DType = DType{reflect.TypeOf(true)} // 11 +) + +/* + * // ToCInt converts DType to CInt type value which is `C int` + * func (dt DType) ToCInt() CInt { + * switch dt.Kind() { + * case reflect.Uint8: + * return 0 + * case reflect.Int8: + * return 1 + * case reflect.Int16: + * return 2 + * case reflect.Int32: + * return 3 + * case reflect.Int64: + * return 4 + * case reflect.Float32: + * return 6 + * case reflect.Float64: + * return 7 + * case reflect.Bool: + * return 11 + * default: + * log.Fatalf("Unsupported type.") + * } + * + * // unreachable + * return CInt(-1) + * } + * + * // OfCInt converts a value of type CInt to DType type 
value + * func (dt DType) OfCInt(v CInt) DType { + * switch v { + * case 0: + * return Uint8 + * case 1: + * return Int8 + * case 2: + * return Int16 + * case 3: + * return Int + * case 4: + * return Int64 + * case 6: + * return Float + * case 7: + * return Double + * case 8: + * case 11: + * return Bool + * default: + * log.Fatalf("Unexpected DType %v\n", v) + * } + * return DType{reflect.TypeOf(false)} + * } + * + * // EltSizeInBytes converts a DType value to number of bytes + * // This is a ELement Size In Bytes in Libtorch. + * // Has it been deprecated? + * func (dt DType) EltSizeInBytes() uint { + * switch dt.Kind() { + * case reflect.Uint8: + * return 1 + * case reflect.Int8: + * return 1 + * case reflect.Int16: + * return 2 + * case reflect.Int: + * return 4 + * case reflect.Int64: + * return 8 + * case reflect.Float32: + * return 4 + * case reflect.Float64: + * return 8 + * case reflect.Bool: + * return 1 + * default: + * log.Fatalf("Unsupported Type %v\n", dt.Type) + * } + * return uint(0) + * } + * */ + +// ToGoType converts DType to Go type +func (dt DType) ToGoType() reflect.Type { + return dt.Type +} + +// dtypeCInt maps each supported DType to its CInt code. +var dtypeCInt = map[DType]CInt{ + Uint8: 0, + Int8: 1, + Int16: 2, + Int: 3, + Int64: 4, + Float: 6, + Double: 7, + Bool: 11, +} + +// DType2CInt returns the CInt code registered for dt in dtypeCInt. +// A DType that is not in the map yields 0, the zero value of the lookup. +func DType2CInt(dt DType) CInt { + return dtypeCInt[dt] +} + +// CInt2DType returns the DType registered for the CInt code v in dtypeCInt. +// It returns an error when no DType is registered for v. +func CInt2DType(v CInt) (dtype DType, err error) { + var found = false + for key, val := range dtypeCInt { + if val == v { + dtype = key + // Record the hit; without this every lookup fell through to the error below. + found = true + break + } + } + + if !found { + err = fmt.Errorf("Unsuported DType for CInt %v\n", v) + return DType{}, err + } + + return dtype, nil + +} + +// dtypeSize is a map of DType and its size in Bytes +var dtypeSize = map[DType]uint{ + Uint8: 1, + Int8: 1, + Int16: 2, + Int: 4, + Int64: 8, + Float: 4, + Double: 8, + Bool: 1, +} + +// DTypeSize returns DType size in Bytes +func DTypeSize(dt DType) uint { + return dtypeSize[dt] +} + +// DTypeDevice pairs a DType with a Device. +type DTypeDevice struct { + DType DType + Device Device +} + +var ( + FloatCPU DTypeDevice = 
DTypeDevice{Float, CPU} + DoubleCPU DTypeDevice = DTypeDevice{Double, CPU} + Int64CPU DTypeDevice = DTypeDevice{Int64, CPU} + + FloatCUDA DTypeDevice = DTypeDevice{Float, CudaBuilder(0)} + DoubleCUDA DTypeDevice = DTypeDevice{Double, CudaBuilder(0)} + Int64CUDA DTypeDevice = DTypeDevice{Int64, CudaBuilder(0)} +) + +// Type Inferring: +// =============== + +// DataDType infers and returns the DType of tensor data. +// For an empty shape, v is treated as a single element; otherwise v must be a +// non-empty slice or array and the DType is inferred from its first element. +// An error is returned when v or its first element has an unsupported type. +func DataDType(v interface{}, shape []int64) (retVal DType, err error) { + // assuming that all elements in data have the same type + if len(shape) == 0 { + return ElementDType(v) + } + + // Inspect v via reflection so typed slices such as []float32 are accepted; + // a plain v.([]interface{}) assertion panics on them and on empty input. + rv := reflect.ValueOf(v) + if k := rv.Kind(); (k != reflect.Slice && k != reflect.Array) || rv.Len() == 0 { + return DType{}, fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v)) + } + + return ElementDType(rv.Index(0).Interface()) +} + +// ElementDType infers and returns the DType of a single element value. +// It returns an error when the Go type of v has no matching DType. +func ElementDType(v interface{}) (retVal DType, err error) { + switch v.(type) { + case uint8: + retVal = Uint8 + case int8: + retVal = Int8 + case int16: + retVal = Int16 + case int32: + retVal = Int + case int64: + retVal = Int64 + case float32: + retVal = Float + case float64: + retVal = Double + case bool: + retVal = Bool + default: + err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v)) + } + + // Propagate err; returning nil here hid the unsupported-type case from callers. + return retVal, err +} + +// TypeOf infers and returns element Go type from given tensor DType and shape +func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) { + typ := dt.ToGoType() + + switch { + case len(shape) == 0: + return typ, nil + case len(shape) > 0: + return reflect.SliceOf(typ), nil + default: + err = fmt.Errorf("Unsupported data type.") + return nil, err + } +} diff --git a/example/tensor/main.go b/example/tensor/main.go index 9afa088..2ed5338 100644 --- a/example/tensor/main.go +++ b/example/tensor/main.go @@ -1,15 +1,25 @@ package main import ( - "fmt" + "log" - tensor "github.com/sugarme/gotch/tensor" + gotch "github.com/sugarme/gotch" + wrapper 
"github.com/sugarme/gotch/wrapper" ) func main() { - _, err := tensor.FnOfSlice() + + // TODO: Check Go type of data and tensor DType + // For. if data is []int and DType is Bool + // It is still running but get wrong result. + data := []bool{true, true, false} + dtype := gotch.Bool + + ts := wrapper.NewTensor() + sliceTensor, err := ts.FOfSlice(data, dtype) if err != nil { - fmt.Println(err) + log.Fatal(err) } + sliceTensor.Print() } diff --git a/kind.go b/kind.go deleted file mode 100644 index 3ca327b..0000000 --- a/kind.go +++ /dev/null @@ -1,135 +0,0 @@ -package gotch - -import ( - "log" - "reflect" -) - -// CInt is equal to C type int. Go type is int32 -type CInt = int32 - -// Kind represents different kind of element that a tensor can hold. -// It has an embedded `reflect.Type` for type reflection. -type Kind struct { - reflect.Type -} - -// TODO: double check these Torch DType to Go type -var ( - Uint8 = Kind{reflect.TypeOf(uint8(1))} // 0 - Int8 = Kind{reflect.TypeOf(int8(1))} // 1 - Int16 = Kind{reflect.TypeOf(int16(1))} // 2 - Int = Kind{reflect.TypeOf(int(1))} // 3 - Int64 = Kind{reflect.TypeOf(int64(1))} // 4 - Half = Kind{reflect.TypeOf(float32(1))} // 5 - Float = Kind{reflect.TypeOf(float64(1))} // 6 - Double = Kind{reflect.TypeOf(float64(1))} // 7 - ComplexHalf = kind{reflect.TypeOf(complex(1))} // 8 - ComplexFloat = Kind{reflect.TypeOf(complex64(1))} // 9 - ComplexDouble = kind{reflect.TypeOf(complex128(1))} // 10 - Bool = kind{reflect.TypeOf(true)} // 11 -) - -// ToCInt converts Kind to CInt type value which is `C int` -func (k Kind) ToCInt() CInt { - switch { - case k.Kind() == uint8: - return 0 - case k.Kind() == int8: - return 1 - case k.Kind() == int16: - return 2 - case k.Kind() == int: - return 3 - case k.Kind() == int64: - return 4 - case k.Kind() == float32: - return 5 - default: - log.Fatalf("Unsupported type.") - } - - // unreachable - return CInt(-1) -} - -// OfCInt converts a value of type CInt to Kind type value -func (k Kind) OfCInt(v 
CInt) Kind { - switch v { - case 0: - return Uint8 - case 1: - return Int8 - case 2: - return Int16 - case 3: - return Int - case 4: - return Int64 - case 5: - return Half - case 6: - return Float - case 7: - return Double - case 8: - return ComplexHalf - case 9: - return ComplexFloat - case 10: - return ComplexDouble - case 11: - return Bool - default: - log.Fatalf("Unexpected kind %v\n", v) - } - return Kind{reflect.TypeOf(false)} -} - -// EltSizeInBytes converts a Kind value to number of bytes -// This is a ELement Size In Byte in Libtorch. -// Has it been deprecated? -func (k Kind) EltSizeInBytes() uint { - switch { - case k.ToCInt() == int32(Uint8): - return 1 - case k.ToCInt() == int32(Int8): - return 1 - case k.ToCInt() == int32(Int16): - return 2 - case k.ToCInt() == int32(Int): - return 4 - case k.ToCInt() == int32(Int64): - return 8 - case k.ToCInt() == int32(Half): - return 2 - case k.ToCInt() == int32(Float): - return 4 - case k.ToCInt() == int32(Double): - return 8 - case k.ToCInt() == int32(ComplexHalf): - return 4 - case k.ToCInt() == int32(ComplexDouble): - return 16 - case k.ToCInt() == int32(Bool): - return 1 - default: - log.Fatalf("Unreachable") - } - return uint(0) -} - -type KindDevice struct { - Kind Kind - Device Device -} - -var ( - FloatCPU KindDevice = KindDevice{Float, CPU} - DoubleCPU KindDevice = KindDevice{Double, CPU} - Int64CPU KindDevice = KindDevice{Int64, CPU} - - FloatCUDA KindDevice = KindDevice{Float, CudaBuilder(0)} - DoubleCUDA KindDevice = KindDevice{Double, CudaBuilder(0)} - Int64CUDA KindDevice = KindDevice{Int64, CudaBuilder(0)} -) diff --git a/libtch/tensor.go b/libtch/tensor.go index 1abf6f7..44de43d 100644 --- a/libtch/tensor.go +++ b/libtch/tensor.go @@ -5,8 +5,8 @@ package libtch import "C" import ( - "fmt" - "reflect" + // "fmt" + // "reflect" "unsafe" ) @@ -33,27 +33,27 @@ func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_byt // t is of type `unsafe.Pointer` in Go and `*void` in C t := 
C.at_tensor_of_data(vs, c_dims, c_ndims, c_elt_size_in_bytes, c_kind) - fmt.Printf("t type: %v\n", reflect.TypeOf(t).Kind()) - fmt.Printf("1. C.tensor AtTensorOfData returned from C call: %v\n", t) + // fmt.Printf("t type: %v\n", reflect.TypeOf(t).Kind()) + // fmt.Printf("1. C.tensor AtTensorOfData returned from C call: %v\n", t) // Keep C pointer value tin Go struct cTensorPtrVal := unsafe.Pointer(t) - fmt.Printf("2. cTensorPtrVal: %v\n", cTensorPtrVal) + // fmt.Printf("2. cTensorPtrVal: %v\n", cTensorPtrVal) var retVal *C_tensor retVal = &C_tensor{private: cTensorPtrVal} - fmt.Printf("3. C_tensor.private: %v\n", (*retVal).private) + // fmt.Printf("3. C_tensor.private: %v\n", (*retVal).private) // test call C.at_print to print out tensor // C.at_print(*(*C.tensor)(unsafe.Pointer(&t))) - AtPrint(retVal) + // AtPrint(retVal) return retVal } func AtPrint(t *C_tensor) { - fmt.Printf("4. C_tensor.private AtPrint: %v\n", (*t).private) + // fmt.Printf("4. C_tensor.private AtPrint: %v\n", (*t).private) cTensor := (C.tensor)((*t).private) - fmt.Printf("5. C.tensor AtPrint: %v\n", cTensor) + // fmt.Printf("5. 
C.tensor AtPrint: %v\n", cTensor) C.at_print(cTensor) } diff --git a/wrapper/tensor.go b/wrapper/tensor.go index 6392ad0..f21102c 100644 --- a/wrapper/tensor.go +++ b/wrapper/tensor.go @@ -1,20 +1,18 @@ package wrapper -//#include +// #include import "C" import ( - "bytes" - "encoding/binary" - "fmt" + // "fmt" "reflect" - "unsafe" + gotch "github.com/sugarme/gotch" lib "github.com/sugarme/gotch/libtch" ) type Tensor struct { - ctensor *t.C_tensor + ctensor *lib.C_tensor } // NewTensor creates a new tensor @@ -24,45 +22,43 @@ func NewTensor() Tensor { } // FOfSlice creates tensor from a slice data -func(ts Tensor) FOfSlice(data []inteface{}) (retVal Tensor, err error) { +func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor, err error) { - data := []int{0, 0, 0, 0} - shape := []int64{int64(len(data))} - nflattened := numElements(shape) - dtype := 3 // Kind.Int - eltSizeInBytes := 4 // Element Size in Byte for Int dtype + dataLen := reflect.ValueOf(data).Len() + shape := []int64{int64(dataLen)} + elementNum := ElementCount(shape) + // eltSizeInBytes := dtype.EltSizeInBytes() // Element Size in Byte for Int dtype + eltSizeInBytes := gotch.DTypeSize(dtype) - nbytes := eltSizeInBytes * int(uintptr(nflattened)) + nbytes := int(eltSizeInBytes) * int(elementNum) - // NOTE: dataPrt is type of `*void` in C or type of `unsafe.Pointer` in Go - // data should be allocated to memory BY `C` side - dataPtr := C.malloc(C.size_t(nbytes)) + dataPtr, buff := CMalloc(nbytes) - // Recall: 1 << 30 = 1 * 2 * 30 - // Ref. 
See more at https://stackoverflow.com/questions/48756732 - dataSlice := (*[1 << 30]byte)(dataPtr)[:nbytes:nbytes] - - buf := bytes.NewBuffer(dataSlice[:0:nbytes]) - - EncodeTensor(buf, reflect.ValueOf(data), shape) - - c_tensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(dtype)) - - retVal = Tensor{c_tensor} - - // Read back created tensor values by C libtorch - readDataPtr := lib.AtDataPtr(retVal.c_tensor) - readDataSlice := (*[1 << 30]byte)(readDataPtr)[:nbytes:nbytes] - // typ := typeOf(dtype, shape) - typ := reflect.TypeOf(int32(0)) // C. type `int` ~ Go type `int32` - val := reflect.New(typ) - if err := DecodeTensor(bytes.NewReader(readDataSlice), shape, typ, val); err != nil { - panic(fmt.Sprintf("unable to decode Tensor of type %v and shape %v - %v", dtype, shape, err)) + if err = EncodeTensor(buff, reflect.ValueOf(data), shape); err != nil { + return nil, err } - tensorData := reflect.Indirect(val).Interface() + ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(gotch.DType2CInt(dtype))) - fmt.Println("%v", tensorData) + retVal = &Tensor{ctensor} + + // Read back created tensor values by C libtorch + // readDataPtr := lib.AtDataPtr(retVal.ctensor) + // readDataSlice := (*[1 << 30]byte)(readDataPtr)[:nbytes:nbytes] + // // typ := typeOf(dtype, shape) + // typ := reflect.TypeOf(int32(0)) // C. 
type `int` ~ Go type `int32` + // val := reflect.New(typ) + // if err := DecodeTensor(bytes.NewReader(readDataSlice), shape, typ, val); err != nil { + // panic(fmt.Sprintf("unable to decode Tensor of type %v and shape %v - %v", dtype, shape, err)) + // } + // + // tensorData := reflect.Indirect(val).Interface() + // + // fmt.Println("%v", tensorData) return retVal, nil } + +func (ts Tensor) Print() { + lib.AtPrint(ts.ctensor) +} diff --git a/wrapper/util.go b/wrapper/util.go index 19ca852..32acdb9 100644 --- a/wrapper/util.go +++ b/wrapper/util.go @@ -1,16 +1,20 @@ package wrapper +// #include +import "C" + import ( "bytes" "encoding/binary" - "errors" "fmt" "reflect" "unsafe" - - gotch "github.com/sugarme/gotch" + // gotch "github.com/sugarme/gotch" ) +// nativeEndian is a ByteOrder for local platform. +// Ref. https://stackoverflow.com/a/53286786 +// Ref. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/go/tensor.go#L488-L505 var nativeEndian binary.ByteOrder func init() { @@ -27,6 +31,36 @@ func init() { } } +// CMalloc allocates a given number of bytes to C side memory. +// It returns +// - dataPtr: a C pointer type of `*void` (`unsafe.Pointer` in Go). +// - buf : a Go pointer points to a given bytes of buffer (empty) in C memory +// allocated by C waiting for writing data to. +// +// NOTE: +// 1. Go pointer is a pointer to Go memory. C pointer is a pointer to C memory. +// 2. General rule is Go code can use C pointers. Go code may pass Go pointer to C +// provided that the Go memory to which it points does NOT contain any Go +// pointers. BUT C code must not store any Go pointers in Go memory, even +// temporarily. +// 3. Some Go values contain Go pointers IMPLICITLY: strings, slices, maps, +// channels and function values. Thus, pointers to these values should not be +// passed to C side. Instead, data should be allocated to C memory and return a +// C pointer to it using `C.malloc`. 
+// Ref: https://github.com/golang/proposal/blob/master/design/12416-cgo-pointers.md +func CMalloc(nbytes int) (dataPtr unsafe.Pointer, buf *bytes.Buffer) { + + dataPtr = C.malloc(C.size_t(nbytes)) + + // Recall: 1 << 30 = 2^30 (1 GiB); the cast below views the C memory as a byte array of that size, so nbytes must not exceed it + // Ref. See more at https://stackoverflow.com/questions/48756732 + dataSlice := (*[1 << 30]byte)(dataPtr)[:nbytes:nbytes] + buf = bytes.NewBuffer(dataSlice[:0:nbytes]) + + return dataPtr, buf +} + +// EncodeTensor writes the tensor data held in v to the buffer w and returns any error encountered. func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error { switch v.Kind() { case reflect.Bool: @@ -37,7 +71,7 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error { if err := w.WriteByte(b); err != nil { return err } - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64: if err := binary.Write(w, nativeEndian, v.Interface()); err != nil { return err } @@ -55,7 +89,7 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error { // Optimisation: if only one dimension is left we can use binary.Write() directly for this slice if len(shape) == 1 && v.Len() > 0 { switch v.Index(0).Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64: return binary.Write(w, nativeEndian, v.Interface()) } } @@ -74,8 +108,8 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error { return nil } -// DecodeTensor decodes the Tensor from the buffer to ptr using the format -// specified in c_api.h. 
Use stringDecoder for String tensors. +// DecodeTensor reads tensor data from r and stores it into the value ptr points to, +// using typ and shape to interpret the bytes; it returns any error encountered func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error { switch typ.Kind() { case reflect.Bool: @@ -84,7 +118,7 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect. return err } ptr.Elem().SetBool(b == 1) - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64: if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil { return err } @@ -96,7 +130,7 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect. // Optimization: if only one dimension is left we can use binary.Read() directly for this slice if len(shape) == 1 && val.Len() > 0 { switch val.Index(0).Kind() { - case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: + case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64: return binary.Read(r, nativeEndian, val.Interface()) } } @@ -113,47 +147,11 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect. 
return nil } -func numElements(shape []int64) int64 { +// ElementCount counts number of element in the tensor given a shape +func ElementCount(shape []int64) int64 { n := int64(1) for _, d := range shape { n *= d } return n } - -// GetKind returns data type `Kind` (a element of tensor can hold) -// v - a value of a data element -func GetKind(v interface{}) (retVal gotch.Kind, err error) { - - switch { - case reflect.TypeOf(v) == int: - retVal = gotch.Int - case reflect.TypeOf(v) == uint8: - retVal = gotch.Uint8 - - default: - err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v)) - return retVal, err - } - - return retVal, nil -} - -// // TypeOf converts from a DType and Shape to the equivalent Go type. -// func TypeOf(dt DType, shape []int64) reflect.Type { -// var ret reflect.Type -// for _, t := range types { -// if dt == DType(t.dataType) { -// ret = t.typ -// break -// } -// } -// if ret == nil { -// // TODO get tensor name -// panic(fmt.Sprintf("Unsupported DType %d", int(dt))) -// } -// for range shape { -// ret = reflect.SliceOf(ret) -// } -// return ret -// }