feat(wrapper): added more dtype helpers and NewTensorFromData

This commit is contained in:
sugarme 2020-05-30 15:39:56 +10:00
parent 2430589319
commit b87d3c8281
6 changed files with 210 additions and 12 deletions

View File

@ -175,10 +175,25 @@ func DTypeFromData(data interface{}) (retVal DType, err error) {
return DType{}, err
}
retVal = DType{reflect.TypeOf(dataType)}
return ToDType(dataType)
return retVal, nil
}
// ElementGoType infers and returns Go type of element in given data
func ElementGoType(data interface{}) (retVal reflect.Type, err error) {
dataKind := reflect.ValueOf(data).Kind()
var dataType reflect.Type
switch dataKind {
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
dataType = reflect.TypeOf(data)
case reflect.Slice:
dataType = reflect.TypeOf(data).Elem()
default:
err = fmt.Errorf("Unsupported type for data type %v\n", dataType)
return DType{}, err
}
return dataType, nil
}
// DataDType infers and returns data type of tensor data
@ -259,3 +274,45 @@ func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) {
return matched, msg
}
var supportedTypes = map[reflect.Kind]bool{
reflect.Uint8: true,
reflect.Int8: true,
reflect.Int16: true,
reflect.Int32: true,
reflect.Int64: true,
reflect.Float32: true,
reflect.Float64: true,
reflect.Bool: true,
}
var scalarTypes = map[reflect.Kind]bool{
reflect.Bool: true,
reflect.Int: true,
reflect.Int8: true,
reflect.Int16: true,
reflect.Int32: true,
reflect.Int64: true,
reflect.Uint: true,
reflect.Uint8: true,
reflect.Uint16: true,
reflect.Uint32: true,
reflect.Uint64: true,
reflect.Uintptr: true,
reflect.Float32: true,
reflect.Float64: true,
reflect.Complex64: true,
reflect.Complex128: true,
}
// IsSupportedScalar checks whether given SCALAR type is supported
// TODO: check input is a scalar.
func IsSupportedScalar(k reflect.Kind) bool {
// if _, ok := scalarTypes[k]; !ok {
// log.Fatalf("Input type: %v is not a Go scalar type.", k)
// }
_, retVal := supportedTypes[k]
return retVal
}

View File

@ -1,11 +1,10 @@
package main
import (
"fmt"
// "fmt"
"log"
"reflect"
gotch "github.com/sugarme/gotch"
// gotch "github.com/sugarme/gotch"
wrapper "github.com/sugarme/gotch/wrapper"
)
@ -14,16 +13,21 @@ func main() {
// TODO: Check Go type of data and tensor DType
// For. if data is []int and DType is Bool
// It is still running but get wrong result.
data := []int32{1, 0, 1}
dtype := gotch.Int
data := []int32{1, 1, 1, 2, 2, 2}
shape := []int64{2, 3}
fmt.Println(gotch.DType{reflect.TypeOf(data)})
// dtype := gotch.Int
// ts := wrapper.NewTensor()
// sliceTensor, err := ts.FOfSlice(data, dtype)
// if err != nil {
// log.Fatal(err)
// }
ts := wrapper.NewTensor()
sliceTensor, err := ts.FOfSlice(data, dtype)
ts, err := wrapper.NewTensorFromData(data, shape)
if err != nil {
log.Fatal(err)
}
sliceTensor.Print()
ts.Print()
}

2
go.mod
View File

@ -1,3 +1,5 @@
module github.com/sugarme/gotch
go 1.14
require github.com/aunum/log v0.0.0-20200321163253-24c356e939b0

19
go.sum Normal file
View File

@ -0,0 +1,19 @@
github.com/aunum/log v0.0.0-20200321163253-24c356e939b0 h1:pLO0OS2sfb+XfZKCPjJeqDYWPngbK786h40oAKDqgpU=
github.com/aunum/log v0.0.0-20200321163253-24c356e939b0/go.mod h1:ze/JIQHfGKwpM8U2b39e8OH0KHt1ovEcjwPV3yfU+/c=
github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s=
github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU=
github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA=
github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM=
github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037 h1:YyJpGZS1sBuBCzLAR1VEpK193GlqGZbnPFnPV/5Rsb4=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.7 h1:VUgggvou5XRW9mHwD/yXxIYSMtY0zoKQf/v226p2nyo=
gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

View File

@ -65,3 +65,48 @@ func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor,
func (ts Tensor) Print() {
lib.AtPrint(ts.ctensor)
}
// NewTensorFromData creates tensor from given data and shape
func NewTensorFromData(data interface{}, shape []int64) (retVal *Tensor, err error) {
// 1. Check whether data and shape match
elementNum, err := DataDim(data)
if err != nil {
return nil, err
}
nflattend := FlattenDim(shape)
if elementNum != nflattend {
err = fmt.Errorf("Number of data elements and flatten shape dimension mismatched.\n")
return nil, err
}
// 2. Write raw data to C memory and get C pointer
dataPtr, err := DataAsPtr(data)
if err != nil {
return nil, err
}
// 3. Create tensor with pointer and shape
dtype, err := gotch.DTypeFromData(data)
if err != nil {
return nil, err
}
eltSizeInBytes, err := gotch.DTypeSize(dtype)
if err != nil {
return nil, err
}
cint, err := gotch.DType2CInt(dtype)
if err != nil {
return nil, err
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
retVal = &Tensor{ctensor}
return retVal, nil
}

View File

@ -9,7 +9,8 @@ import (
"fmt"
"reflect"
"unsafe"
// gotch "github.com/sugarme/gotch"
gotch "github.com/sugarme/gotch"
)
// nativeEndian is a ByteOrder for local platform.
@ -155,3 +156,73 @@ func ElementCount(shape []int64) int64 {
}
return n
}
// DataDim returns number of elements in data
func DataDim(data interface{}) (retVal int, err error) {
v := reflect.ValueOf(data)
switch gotch.IsSupportedScalar(v.Kind()) {
case true:
retVal = 1
default:
switch v.Kind() {
case reflect.Slice, reflect.Array:
retVal = v.Len()
default:
err = fmt.Errorf("Cannot count data element due to unsupported data type: %v\n.", v.Kind())
return 0, err
}
}
return retVal, nil
}
// DataAsPtr write to C memory and returns a C pointer.
//
// NOTE:
// Supported data types are scalar, slice/array of scalar type equivalent to
// DType.
func DataAsPtr(data interface{}) (dataPtr unsafe.Pointer, err error) {
// 1. Count number of elements in data
elementNum, err := DataDim(data)
if err != nil {
return nil, err
}
// 2. Element size in bytes
dtype, err := gotch.DTypeFromData(data)
fmt.Println(dtype)
if err != nil {
return nil, err
}
eltSizeInBytes, err := gotch.DTypeSize(dtype)
if err != nil {
return nil, err
}
nbytes := int(eltSizeInBytes) * int(elementNum)
// 3. Get C pointer and prepare C memory buffer for writing
dataPtr, buff := CMalloc(nbytes)
// 4. Write data to C memory
err = binary.Write(buff, nativeEndian, data)
if err != nil {
return nil, err
}
return dataPtr, nil
}
// FlattenDim counts number of elements with given shape
func FlattenDim(shape []int64) int {
n := int64(1)
for _, d := range shape {
n *= d
}
return int(n)
}