diff --git a/dtype.go b/dtype.go index f213930..c580a36 100644 --- a/dtype.go +++ b/dtype.go @@ -175,10 +175,25 @@ func DTypeFromData(data interface{}) (retVal DType, err error) { return DType{}, err } - retVal = DType{reflect.TypeOf(dataType)} + return ToDType(dataType) - return retVal, nil +} +// ElementGoType infers and returns Go type of element in given data +func ElementGoType(data interface{}) (retVal reflect.Type, err error) { + dataKind := reflect.ValueOf(data).Kind() + var dataType reflect.Type + switch dataKind { + case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: + dataType = reflect.TypeOf(data) + case reflect.Slice: + dataType = reflect.TypeOf(data).Elem() + default: + err = fmt.Errorf("Unsupported type for data type %v\n", dataType) + return DType{}, err + } + + return dataType, nil } // DataDType infers and returns data type of tensor data @@ -259,3 +274,45 @@ func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) { return matched, msg } + +var supportedTypes = map[reflect.Kind]bool{ + reflect.Uint8: true, + reflect.Int8: true, + reflect.Int16: true, + reflect.Int32: true, + reflect.Int64: true, + reflect.Float32: true, + reflect.Float64: true, + reflect.Bool: true, +} + +var scalarTypes = map[reflect.Kind]bool{ + reflect.Bool: true, + reflect.Int: true, + reflect.Int8: true, + reflect.Int16: true, + reflect.Int32: true, + reflect.Int64: true, + reflect.Uint: true, + reflect.Uint8: true, + reflect.Uint16: true, + reflect.Uint32: true, + reflect.Uint64: true, + reflect.Uintptr: true, + reflect.Float32: true, + reflect.Float64: true, + reflect.Complex64: true, + reflect.Complex128: true, +} + +// IsSupportedScalar checks whether given SCALAR type is supported +// TODO: check input is a scalar. +func IsSupportedScalar(k reflect.Kind) bool { + // if _, ok := scalarTypes[k]; !ok { + // log.Fatalf("Input type: %v is not a Go scalar type.", k) + // } + + _, retVal := supportedTypes[k] + + return retVal +} diff --git a/example/tensor/main.go b/example/tensor/main.go index bb95276..aa7cfe9 100644 --- a/example/tensor/main.go +++ b/example/tensor/main.go @@ -1,11 +1,10 @@ package main import ( - "fmt" + // "fmt" "log" - "reflect" - gotch "github.com/sugarme/gotch" + // gotch "github.com/sugarme/gotch" wrapper "github.com/sugarme/gotch/wrapper" ) @@ -14,16 +13,21 @@ func main() { // TODO: Check Go type of data and tensor DType // For. if data is []int and DType is Bool // It is still running but get wrong result. - data := []int32{1, 0, 1} - dtype := gotch.Int + data := []int32{1, 1, 1, 2, 2, 2} + shape := []int64{2, 3} - fmt.Println(gotch.DType{reflect.TypeOf(data)}) + // dtype := gotch.Int + // ts := wrapper.NewTensor() + // sliceTensor, err := ts.FOfSlice(data, dtype) + // if err != nil { + // log.Fatal(err) + // } - ts := wrapper.NewTensor() - sliceTensor, err := ts.FOfSlice(data, dtype) + ts, err := wrapper.NewTensorFromData(data, shape) if err != nil { log.Fatal(err) } - sliceTensor.Print() + ts.Print() + } diff --git a/go.mod b/go.mod index 3f67ae4..a8eb3e0 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/sugarme/gotch go 1.14 + +require github.com/aunum/log v0.0.0-20200321163253-24c356e939b0 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..96024c1 --- /dev/null +++ b/go.sum @@ -0,0 +1,19 @@ +github.com/aunum/log v0.0.0-20200321163253-24c356e939b0 h1:pLO0OS2sfb+XfZKCPjJeqDYWPngbK786h40oAKDqgpU= +github.com/aunum/log v0.0.0-20200321163253-24c356e939b0/go.mod h1:ze/JIQHfGKwpM8U2b39e8OH0KHt1ovEcjwPV3yfU+/c= +github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s= +github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= +github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= +github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM= +github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 h1:YyJpGZS1sBuBCzLAR1VEpK193GlqGZbnPFnPV/5Rsb4= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo= +gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/wrapper/tensor.go b/wrapper/tensor.go index 4724d03..cb23938 100644 --- a/wrapper/tensor.go +++ b/wrapper/tensor.go @@ -65,3 +65,48 @@ func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor, func (ts Tensor) Print() { lib.AtPrint(ts.ctensor) } + +// NewTensorFromData creates tensor from given data and shape +func NewTensorFromData(data interface{}, shape []int64) (retVal *Tensor, err error) { + // 1. Check whether data and shape match + elementNum, err := DataDim(data) + if err != nil { + return nil, err + } + + nflattend := FlattenDim(shape) + + if elementNum != nflattend { + err = fmt.Errorf("Number of data elements and flatten shape dimension mismatched.\n") + return nil, err + } + + // 2. Write raw data to C memory and get C pointer + dataPtr, err := DataAsPtr(data) + if err != nil { + return nil, err + } + + // 3. Create tensor with pointer and shape + dtype, err := gotch.DTypeFromData(data) + if err != nil { + return nil, err + } + + eltSizeInBytes, err := gotch.DTypeSize(dtype) + if err != nil { + return nil, err + } + + cint, err := gotch.DType2CInt(dtype) + if err != nil { + return nil, err + } + + ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint)) + + retVal = &Tensor{ctensor} + + return retVal, nil + +} diff --git a/wrapper/util.go b/wrapper/util.go index 32acdb9..49c9fcf 100644 --- a/wrapper/util.go +++ b/wrapper/util.go @@ -9,7 +9,8 @@ import ( "fmt" "reflect" "unsafe" - // gotch "github.com/sugarme/gotch" + + gotch "github.com/sugarme/gotch" ) // nativeEndian is a ByteOrder for local platform. @@ -155,3 +156,73 @@ func ElementCount(shape []int64) int64 { } return n } + +// DataDim returns number of elements in data +func DataDim(data interface{}) (retVal int, err error) { + v := reflect.ValueOf(data) + + switch gotch.IsSupportedScalar(v.Kind()) { + case true: + retVal = 1 + default: + switch v.Kind() { + case reflect.Slice, reflect.Array: + retVal = v.Len() + default: + err = fmt.Errorf("Cannot count data element due to unsupported data type: %v\n.", v.Kind()) + return 0, err + } + + } + + return retVal, nil +} + +// DataAsPtr write to C memory and returns a C pointer. +// +// NOTE: +// Supported data types are scalar, slice/array of scalar type equivalent to +// DType. +func DataAsPtr(data interface{}) (dataPtr unsafe.Pointer, err error) { + + // 1. Count number of elements in data + elementNum, err := DataDim(data) + if err != nil { + return nil, err + } + + // 2. Element size in bytes + dtype, err := gotch.DTypeFromData(data) + fmt.Println(dtype) + if err != nil { + return nil, err + } + + eltSizeInBytes, err := gotch.DTypeSize(dtype) + if err != nil { + return nil, err + } + + nbytes := int(eltSizeInBytes) * int(elementNum) + + // 3. Get C pointer and prepare C memory buffer for writing + dataPtr, buff := CMalloc(nbytes) + + // 4. Write data to C memory + err = binary.Write(buff, nativeEndian, data) + if err != nil { + return nil, err + } + + return dataPtr, nil +} + +// FlattenDim counts number of elements with given shape +func FlattenDim(shape []int64) int { + n := int64(1) + for _, d := range shape { + n *= d + } + + return int(n) +}