feat(dtype), WIP(wrapper), example/tensor

This commit is contained in:
sugarme 2020-05-30 09:04:47 +10:00
parent 773f423fff
commit 6b0d6105ae
6 changed files with 352 additions and 233 deletions

250
dtype.go Normal file
View File

@ -0,0 +1,250 @@
package gotch
import (
"fmt"
// "log"
"reflect"
)
// CInt is equal to C type int. Go type is int32
type CInt = int32
// DType represents different kind of element that a tensor can hold.
// It has an embedded `reflect.Type` for type reflection.
type DType struct {
reflect.Type
}
/*
* // Custom-made Float16 as not exist in Go
* // Ref: https://github.com/golang/go/issues/32022
* type GoFloat16 = int16 // not implemented yet
* type GoComplexHalf = interface{} // not implemented yet!
* */
// TODO: double check these Torch DType to Go type
var (
Uint8 DType = DType{reflect.TypeOf(uint8(1))} // 0
Int8 DType = DType{reflect.TypeOf(int8(1))} // 1
Int16 DType = DType{reflect.TypeOf(int16(1))} // 2
Int DType = DType{reflect.TypeOf(int32(1))} // 3
Int64 DType = DType{reflect.TypeOf(int64(1))} // 4
// Half DType = DType{reflect.TypeOf(GoFloat16(1))} // 5
Float DType = DType{reflect.TypeOf(float32(1))} // 6
Double DType = DType{reflect.TypeOf(float64(1))} // 7
// ComplexHalf DType = DType{reflect.TypeOf(GoComplexHalf(1))} // 8
// ComplexFloat DType = DType{reflect.TypeOf(complex64(1))} // 9
// ComplexDouble DType = DType{reflect.TypeOf(complex128(1))} // 10
Bool DType = DType{reflect.TypeOf(true)} // 11
)
/*
* // ToCInt converts DType to CInt type value which is `C int`
* func (dt DType) ToCInt() CInt {
* switch dt.Kind() {
* case reflect.Uint8:
* return 0
* case reflect.Int8:
* return 1
* case reflect.Int16:
* return 2
* case reflect.Int32:
* return 3
* case reflect.Int64:
* return 4
* case reflect.Float32:
* return 6
* case reflect.Float64:
* return 7
* case reflect.Bool:
* return 11
* default:
* log.Fatalf("Unsupported type.")
* }
*
* // unreachable
* return CInt(-1)
* }
*
* // OfCInt converts a value of type CInt to DType type value
* func (dt DType) OfCInt(v CInt) DType {
* switch v {
* case 0:
* return Uint8
* case 1:
* return Int8
* case 2:
* return Int16
* case 3:
* return Int
* case 4:
* return Int64
* case 6:
* return Float
* case 7:
* return Double
* case 8:
* case 11:
* return Bool
* default:
* log.Fatalf("Unexpected DType %v\n", v)
* }
* return DType{reflect.TypeOf(false)}
* }
*
* // EltSizeInBytes converts a DType value to number of bytes
* // This is a ELement Size In Bytes in Libtorch.
* // Has it been deprecated?
* func (dt DType) EltSizeInBytes() uint {
* switch dt.Kind() {
* case reflect.Uint8:
* return 1
* case reflect.Int8:
* return 1
* case reflect.Int16:
* return 2
* case reflect.Int:
* return 4
* case reflect.Int64:
* return 8
* case reflect.Float32:
* return 4
* case reflect.Float64:
* return 8
* case reflect.Bool:
* return 1
* default:
* log.Fatalf("Unsupported Type %v\n", dt.Type)
* }
* return uint(0)
* }
* */
// ToGoType converts DType to Go type
func (dt DType) ToGoType() reflect.Type {
return dt.Type
}
var dtypeCInt = map[DType]CInt{
Uint8: 0,
Int8: 1,
Int16: 2,
Int: 3,
Int64: 4,
Float: 6,
Double: 7,
Bool: 11,
}
func DType2CInt(dt DType) CInt {
return dtypeCInt[dt]
}
func CInt2DType(v CInt) (dtype DType, err error) {
var found = false
for key, val := range dtypeCInt {
if val == v {
dtype = key
break
}
}
if !found {
err = fmt.Errorf("Unsuported DType for CInt %v\n", v)
return DType{}, err
}
return dtype, nil
}
// dtypeSize is a map of DType and its size in Bytes
var dtypeSize = map[DType]uint{
Uint8: 1,
Int8: 1,
Int16: 2,
Int: 4,
Int64: 8,
Float: 4,
Double: 8,
Bool: 1,
}
// DTypeSize returns DType size in Bytes
func DTypeSize(dt DType) uint {
return dtypeSize[dt]
}
type DTypeDevice struct {
DType DType
Device Device
}
var (
FloatCPU DTypeDevice = DTypeDevice{Float, CPU}
DoubleCPU DTypeDevice = DTypeDevice{Double, CPU}
Int64CPU DTypeDevice = DTypeDevice{Int64, CPU}
FloatCUDA DTypeDevice = DTypeDevice{Float, CudaBuilder(0)}
DoubleCUDA DTypeDevice = DTypeDevice{Double, CudaBuilder(0)}
Int64CUDA DTypeDevice = DTypeDevice{Int64, CudaBuilder(0)}
)
// Type Inferring:
// ===============
// DataDType infers and returns data type of tensor data
func DataDType(v interface{}, shape []int64) (retVal DType, err error) {
// assuming that all elements in data have the same type
switch {
case len(shape) == 0:
retVal, err = ElementDType(v)
case len(shape) > 0:
return ElementDType(v.([]interface{})[0])
default:
err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v))
return DType{}, err
}
return DType{}, nil
}
// ElementDType infers and returns its own tensor data type
func ElementDType(v interface{}) (retVal DType, err error) {
switch v.(type) {
case uint8:
retVal = Uint8
case int8:
retVal = Int8
case int16:
retVal = Int16
case int32:
retVal = Int
case int64:
retVal = Int64
case float32:
retVal = Float
case float64:
retVal = Double
case bool:
retVal = Bool
default:
err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v))
}
return retVal, nil
}
// TypeOf infers and returns element Go type from given tensor DType and shape
func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) {
typ := dt.ToGoType()
switch {
case len(shape) == 0:
return typ, nil
case len(shape) > 0:
return reflect.SliceOf(typ), nil
default:
err = fmt.Errorf("Unsupported data type.")
return nil, err
}
}

View File

@ -1,15 +1,25 @@
package main
import (
"fmt"
"log"
tensor "github.com/sugarme/gotch/tensor"
gotch "github.com/sugarme/gotch"
wrapper "github.com/sugarme/gotch/wrapper"
)
func main() {
_, err := tensor.FnOfSlice()
// TODO: Check Go type of data and tensor DType
// For. if data is []int and DType is Bool
// It is still running but get wrong result.
data := []bool{true, true, false}
dtype := gotch.Bool
ts := wrapper.NewTensor()
sliceTensor, err := ts.FOfSlice(data, dtype)
if err != nil {
fmt.Println(err)
log.Fatal(err)
}
sliceTensor.Print()
}

135
kind.go
View File

@ -1,135 +0,0 @@
package gotch
import (
"log"
"reflect"
)
// CInt is equal to C type int. Go type is int32
type CInt = int32
// Kind represents different kind of element that a tensor can hold.
// It has an embedded `reflect.Type` for type reflection.
type Kind struct {
reflect.Type
}
// TODO: double check these Torch DType to Go type
var (
Uint8 = Kind{reflect.TypeOf(uint8(1))} // 0
Int8 = Kind{reflect.TypeOf(int8(1))} // 1
Int16 = Kind{reflect.TypeOf(int16(1))} // 2
Int = Kind{reflect.TypeOf(int(1))} // 3
Int64 = Kind{reflect.TypeOf(int64(1))} // 4
Half = Kind{reflect.TypeOf(float32(1))} // 5
Float = Kind{reflect.TypeOf(float64(1))} // 6
Double = Kind{reflect.TypeOf(float64(1))} // 7
ComplexHalf = kind{reflect.TypeOf(complex(1))} // 8
ComplexFloat = Kind{reflect.TypeOf(complex64(1))} // 9
ComplexDouble = kind{reflect.TypeOf(complex128(1))} // 10
Bool = kind{reflect.TypeOf(true)} // 11
)
// ToCInt converts Kind to CInt type value which is `C int`
func (k Kind) ToCInt() CInt {
switch {
case k.Kind() == uint8:
return 0
case k.Kind() == int8:
return 1
case k.Kind() == int16:
return 2
case k.Kind() == int:
return 3
case k.Kind() == int64:
return 4
case k.Kind() == float32:
return 5
default:
log.Fatalf("Unsupported type.")
}
// unreachable
return CInt(-1)
}
// OfCInt converts a value of type CInt to Kind type value
func (k Kind) OfCInt(v CInt) Kind {
switch v {
case 0:
return Uint8
case 1:
return Int8
case 2:
return Int16
case 3:
return Int
case 4:
return Int64
case 5:
return Half
case 6:
return Float
case 7:
return Double
case 8:
return ComplexHalf
case 9:
return ComplexFloat
case 10:
return ComplexDouble
case 11:
return Bool
default:
log.Fatalf("Unexpected kind %v\n", v)
}
return Kind{reflect.TypeOf(false)}
}
// EltSizeInBytes converts a Kind value to number of bytes
// This is a ELement Size In Byte in Libtorch.
// Has it been deprecated?
func (k Kind) EltSizeInBytes() uint {
switch {
case k.ToCInt() == int32(Uint8):
return 1
case k.ToCInt() == int32(Int8):
return 1
case k.ToCInt() == int32(Int16):
return 2
case k.ToCInt() == int32(Int):
return 4
case k.ToCInt() == int32(Int64):
return 8
case k.ToCInt() == int32(Half):
return 2
case k.ToCInt() == int32(Float):
return 4
case k.ToCInt() == int32(Double):
return 8
case k.ToCInt() == int32(ComplexHalf):
return 4
case k.ToCInt() == int32(ComplexDouble):
return 16
case k.ToCInt() == int32(Bool):
return 1
default:
log.Fatalf("Unreachable")
}
return uint(0)
}
type KindDevice struct {
Kind Kind
Device Device
}
var (
FloatCPU KindDevice = KindDevice{Float, CPU}
DoubleCPU KindDevice = KindDevice{Double, CPU}
Int64CPU KindDevice = KindDevice{Int64, CPU}
FloatCUDA KindDevice = KindDevice{Float, CudaBuilder(0)}
DoubleCUDA KindDevice = KindDevice{Double, CudaBuilder(0)}
Int64CUDA KindDevice = KindDevice{Int64, CudaBuilder(0)}
)

View File

@ -5,8 +5,8 @@ package libtch
import "C"
import (
"fmt"
"reflect"
// "fmt"
// "reflect"
"unsafe"
)
@ -33,27 +33,27 @@ func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_byt
// t is of type `unsafe.Pointer` in Go and `*void` in C
t := C.at_tensor_of_data(vs, c_dims, c_ndims, c_elt_size_in_bytes, c_kind)
fmt.Printf("t type: %v\n", reflect.TypeOf(t).Kind())
fmt.Printf("1. C.tensor AtTensorOfData returned from C call: %v\n", t)
// fmt.Printf("t type: %v\n", reflect.TypeOf(t).Kind())
// fmt.Printf("1. C.tensor AtTensorOfData returned from C call: %v\n", t)
// Keep C pointer value tin Go struct
cTensorPtrVal := unsafe.Pointer(t)
fmt.Printf("2. cTensorPtrVal: %v\n", cTensorPtrVal)
// fmt.Printf("2. cTensorPtrVal: %v\n", cTensorPtrVal)
var retVal *C_tensor
retVal = &C_tensor{private: cTensorPtrVal}
fmt.Printf("3. C_tensor.private: %v\n", (*retVal).private)
// fmt.Printf("3. C_tensor.private: %v\n", (*retVal).private)
// test call C.at_print to print out tensor
// C.at_print(*(*C.tensor)(unsafe.Pointer(&t)))
AtPrint(retVal)
// AtPrint(retVal)
return retVal
}
func AtPrint(t *C_tensor) {
fmt.Printf("4. C_tensor.private AtPrint: %v\n", (*t).private)
// fmt.Printf("4. C_tensor.private AtPrint: %v\n", (*t).private)
cTensor := (C.tensor)((*t).private)
fmt.Printf("5. C.tensor AtPrint: %v\n", cTensor)
// fmt.Printf("5. C.tensor AtPrint: %v\n", cTensor)
C.at_print(cTensor)
}

View File

@ -1,20 +1,18 @@
package wrapper
//#include <stdlib.h>
// #include <stdlib.h>
import "C"
import (
"bytes"
"encoding/binary"
"fmt"
// "fmt"
"reflect"
"unsafe"
gotch "github.com/sugarme/gotch"
lib "github.com/sugarme/gotch/libtch"
)
type Tensor struct {
ctensor *t.C_tensor
ctensor *lib.C_tensor
}
// NewTensor creates a new tensor
@ -24,45 +22,43 @@ func NewTensor() Tensor {
}
// FOfSlice creates tensor from a slice data
func(ts Tensor) FOfSlice(data []inteface{}) (retVal Tensor, err error) {
func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor, err error) {
data := []int{0, 0, 0, 0}
shape := []int64{int64(len(data))}
nflattened := numElements(shape)
dtype := 3 // Kind.Int
eltSizeInBytes := 4 // Element Size in Byte for Int dtype
dataLen := reflect.ValueOf(data).Len()
shape := []int64{int64(dataLen)}
elementNum := ElementCount(shape)
// eltSizeInBytes := dtype.EltSizeInBytes() // Element Size in Byte for Int dtype
eltSizeInBytes := gotch.DTypeSize(dtype)
nbytes := eltSizeInBytes * int(uintptr(nflattened))
nbytes := int(eltSizeInBytes) * int(elementNum)
// NOTE: dataPrt is type of `*void` in C or type of `unsafe.Pointer` in Go
// data should be allocated to memory BY `C` side
dataPtr := C.malloc(C.size_t(nbytes))
dataPtr, buff := CMalloc(nbytes)
// Recall: 1 << 30 = 1 * 2 * 30
// Ref. See more at https://stackoverflow.com/questions/48756732
dataSlice := (*[1 << 30]byte)(dataPtr)[:nbytes:nbytes]
buf := bytes.NewBuffer(dataSlice[:0:nbytes])
EncodeTensor(buf, reflect.ValueOf(data), shape)
c_tensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(dtype))
retVal = Tensor{c_tensor}
// Read back created tensor values by C libtorch
readDataPtr := lib.AtDataPtr(retVal.c_tensor)
readDataSlice := (*[1 << 30]byte)(readDataPtr)[:nbytes:nbytes]
// typ := typeOf(dtype, shape)
typ := reflect.TypeOf(int32(0)) // C. type `int` ~ Go type `int32`
val := reflect.New(typ)
if err := DecodeTensor(bytes.NewReader(readDataSlice), shape, typ, val); err != nil {
panic(fmt.Sprintf("unable to decode Tensor of type %v and shape %v - %v", dtype, shape, err))
if err = EncodeTensor(buff, reflect.ValueOf(data), shape); err != nil {
return nil, err
}
tensorData := reflect.Indirect(val).Interface()
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(gotch.DType2CInt(dtype)))
fmt.Println("%v", tensorData)
retVal = &Tensor{ctensor}
// Read back created tensor values by C libtorch
// readDataPtr := lib.AtDataPtr(retVal.ctensor)
// readDataSlice := (*[1 << 30]byte)(readDataPtr)[:nbytes:nbytes]
// // typ := typeOf(dtype, shape)
// typ := reflect.TypeOf(int32(0)) // C. type `int` ~ Go type `int32`
// val := reflect.New(typ)
// if err := DecodeTensor(bytes.NewReader(readDataSlice), shape, typ, val); err != nil {
// panic(fmt.Sprintf("unable to decode Tensor of type %v and shape %v - %v", dtype, shape, err))
// }
//
// tensorData := reflect.Indirect(val).Interface()
//
// fmt.Println("%v", tensorData)
return retVal, nil
}
func (ts Tensor) Print() {
lib.AtPrint(ts.ctensor)
}

View File

@ -1,16 +1,20 @@
package wrapper
// #include <stdlib.h>
import "C"
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"reflect"
"unsafe"
gotch "github.com/sugarme/gotch"
// gotch "github.com/sugarme/gotch"
)
// nativeEndian is a ByteOrder for local platform.
// Ref. https://stackoverflow.com/a/53286786
// Ref. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/go/tensor.go#L488-L505
var nativeEndian binary.ByteOrder
func init() {
@ -27,6 +31,36 @@ func init() {
}
}
// CMalloc allocates a given number of bytes to C side memory.
// It returns
// - dataPtr: a C pointer type of `*void` (`unsafe.Pointer` in Go).
// - buf : a Go pointer points to a given bytes of buffer (empty) in C memory
// allocated by C waiting for writing data to.
//
// NOTE:
// 1. Go pointer is a pointer to Go memory. C pointer is a pointer to C memory.
// 2. General rule is Go code can use C pointers. Go code may pass Go pointer to C
// provided that the Go memory to which it points does NOT contain any Go
// pointers. BUT C code must not store any Go pointers in Go memory, even
// temporarily.
// 3. Some Go values contain Go pointers IMPLICITLY: strings, slices, maps,
// channels and function values. Thus, pointers to these values should not be
// passed to C side. Instead, data should be allocated to C memory and return a
// C pointer to it using `C.malloc`.
// Ref: https://github.com/golang/proposal/blob/master/design/12416-cgo-pointers.md
func CMalloc(nbytes int) (dataPtr unsafe.Pointer, buf *bytes.Buffer) {
dataPtr = C.malloc(C.size_t(nbytes))
// Recall: 1 << 30 = 1 * 2 * 30
// Ref. See more at https://stackoverflow.com/questions/48756732
dataSlice := (*[1 << 30]byte)(dataPtr)[:nbytes:nbytes]
buf = bytes.NewBuffer(dataSlice[:0:nbytes])
return dataPtr, buf
}
// EncodeTensor loads tensor data to C memory and returns a C pointer.
func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
switch v.Kind() {
case reflect.Bool:
@ -37,7 +71,7 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
if err := w.WriteByte(b); err != nil {
return err
}
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
return err
}
@ -55,7 +89,7 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
// Optimisation: if only one dimension is left we can use binary.Write() directly for this slice
if len(shape) == 1 && v.Len() > 0 {
switch v.Index(0).Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
return binary.Write(w, nativeEndian, v.Interface())
}
}
@ -74,8 +108,8 @@ func EncodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
return nil
}
// DecodeTensor decodes the Tensor from the buffer to ptr using the format
// specified in c_api.h. Use stringDecoder for String tensors.
// DecodeTensor decodes tensor value from a C memory buffer given
// C pointer, data type and shape and returns data value of type interface
func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error {
switch typ.Kind() {
case reflect.Bool:
@ -84,7 +118,7 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
return err
}
ptr.Elem().SetBool(b == 1)
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil {
return err
}
@ -96,7 +130,7 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
// Optimization: if only one dimension is left we can use binary.Read() directly for this slice
if len(shape) == 1 && val.Len() > 0 {
switch val.Index(0).Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64:
return binary.Read(r, nativeEndian, val.Interface())
}
}
@ -113,47 +147,11 @@ func DecodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.
return nil
}
func numElements(shape []int64) int64 {
// ElementCount counts number of element in the tensor given a shape
func ElementCount(shape []int64) int64 {
n := int64(1)
for _, d := range shape {
n *= d
}
return n
}
// GetKind returns data type `Kind` (a element of tensor can hold)
// v - a value of a data element
func GetKind(v interface{}) (retVal gotch.Kind, err error) {
switch {
case reflect.TypeOf(v) == int:
retVal = gotch.Int
case reflect.TypeOf(v) == uint8:
retVal = gotch.Uint8
default:
err = fmt.Errorf("Unsupported data type for %v\n", reflect.TypeOf(v))
return retVal, err
}
return retVal, nil
}
// // TypeOf converts from a DType and Shape to the equivalent Go type.
// func TypeOf(dt DType, shape []int64) reflect.Type {
// var ret reflect.Type
// for _, t := range types {
// if dt == DType(t.dataType) {
// ret = t.typ
// break
// }
// }
// if ret == nil {
// // TODO get tensor name
// panic(fmt.Sprintf("Unsupported DType %d", int(dt)))
// }
// for range shape {
// ret = reflect.SliceOf(ret)
// }
// return ret
// }