feat(wrapper): cleanup and add more type inferring

This commit is contained in:
sugarme 2020-05-30 12:36:49 +10:00
parent 98c182cef8
commit 2430589319
3 changed files with 101 additions and 93 deletions

169
dtype.go
View File

@ -38,91 +38,46 @@ var (
Bool DType = DType{reflect.TypeOf(true)} // 11
)
/*
* // ToCInt converts DType to CInt type value which is `C int`
* func (dt DType) ToCInt() CInt {
* switch dt.Kind() {
* case reflect.Uint8:
* return 0
* case reflect.Int8:
* return 1
* case reflect.Int16:
* return 2
* case reflect.Int32:
* return 3
* case reflect.Int64:
* return 4
* case reflect.Float32:
* return 6
* case reflect.Float64:
* return 7
* case reflect.Bool:
* return 11
* default:
* log.Fatalf("Unsupported type.")
* }
*
* // unreachable
* return CInt(-1)
* }
*
* // OfCInt converts a value of type CInt to DType type value
* func (dt DType) OfCInt(v CInt) DType {
* switch v {
* case 0:
* return Uint8
* case 1:
* return Int8
* case 2:
* return Int16
* case 3:
* return Int
* case 4:
* return Int64
* case 6:
* return Float
* case 7:
* return Double
* case 8:
* case 11:
* return Bool
* default:
* log.Fatalf("Unexpected DType %v\n", v)
* }
* return DType{reflect.TypeOf(false)}
* }
*
* // EltSizeInBytes converts a DType value to number of bytes
* // This is a ELement Size In Bytes in Libtorch.
* // Has it been deprecated?
* func (dt DType) EltSizeInBytes() uint {
* switch dt.Kind() {
* case reflect.Uint8:
* return 1
* case reflect.Int8:
* return 1
* case reflect.Int16:
* return 2
* case reflect.Int:
* return 4
* case reflect.Int64:
* return 8
* case reflect.Float32:
* return 4
* case reflect.Float64:
* return 8
* case reflect.Bool:
* return 1
* default:
* log.Fatalf("Unsupported Type %v\n", dt.Type)
* }
* return uint(0)
* }
* */
var dtypeGoType = map[DType]reflect.Type{
Uint8: reflect.TypeOf(uint8(1)),
Int8: reflect.TypeOf(int8(1)),
Int16: reflect.TypeOf(int16(1)),
Int: reflect.TypeOf(int32(1)),
Int64: reflect.TypeOf(int64(1)),
Float: reflect.TypeOf(float32(1)),
Double: reflect.TypeOf(float64(1)),
Bool: reflect.TypeOf(true),
}
// ToGoType converts DType to Go type
func (dt DType) ToGoType() reflect.Type {
return dt.Type
// ToDType infers and returns supported equivalent DType from given Go type
func ToDType(typ reflect.Type) (retVal DType, err error) {
var found = false
for key, val := range dtypeGoType {
if val == typ {
retVal = key
found = true
break
}
}
if !found {
err = fmt.Errorf("Unsupported Go type: %v", typ)
return DType{}, err
}
return retVal, nil
}
// ToGoType infers and returns supported equivalent Go type from given DType
func ToGoType(dtype DType) (retVal reflect.Type, err error) {
if _, ok := dtypeGoType[dtype]; !ok {
err = fmt.Errorf("Unsupported DType %v", dtype)
return nil, err
}
retVal = dtypeGoType[dtype]
return retVal, nil
}
var dtypeCInt = map[DType]CInt{
@ -136,8 +91,14 @@ var dtypeCInt = map[DType]CInt{
Bool: 11,
}
func DType2CInt(dt DType) CInt {
return dtypeCInt[dt]
func DType2CInt(dt DType) (retVal CInt, err error) {
if _, ok := dtypeCInt[dt]; !ok {
err = fmt.Errorf("Unsupported CInt conversion from DType: %v\n", dt)
}
retVal = dtypeCInt[dt]
return retVal, nil
}
func CInt2DType(v CInt) (dtype DType, err error) {
@ -171,8 +132,15 @@ var dtypeSize = map[DType]uint{
}
// DTypeSize returns DType size in Bytes
func DTypeSize(dt DType) uint {
return dtypeSize[dt]
func DTypeSize(dt DType) (retVal uint, err error) {
if _, ok := dtypeSize[dt]; !ok {
err = fmt.Errorf("Unsupported conversion DType size in Byte for DType: %v\n", dt)
return 99, err
}
retVal = dtypeSize[dt]
return retVal, nil
}
type DTypeDevice struct {
@ -193,6 +161,26 @@ var (
// Type Inferring:
// ===============
// DTypeFromData infers returns equavalent DType from given data
func DTypeFromData(data interface{}) (retVal DType, err error) {
dataKind := reflect.ValueOf(data).Kind()
var dataType reflect.Type
switch dataKind {
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
dataType = reflect.TypeOf(data)
case reflect.Slice:
dataType = reflect.TypeOf(data).Elem()
default:
err = fmt.Errorf("Unsupported type for data type %v\n", dataType)
return DType{}, err
}
retVal = DType{reflect.TypeOf(dataType)}
return retVal, nil
}
// DataDType infers and returns data type of tensor data
func DataDType(v interface{}, shape []int64) (retVal DType, err error) {
// assuming that all elements in data have the same type
@ -236,7 +224,10 @@ func ElementDType(v interface{}) (retVal DType, err error) {
// TypeOf infers and returns element Go type from given tensor DType and shape
func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) {
typ := dt.ToGoType()
var typ reflect.Type
if typ, err = ToGoType(dt); err != nil {
return nil, err
}
switch {
case len(shape) == 0:

View File

@ -1,7 +1,9 @@
package main
import (
"fmt"
"log"
"reflect"
gotch "github.com/sugarme/gotch"
wrapper "github.com/sugarme/gotch/wrapper"
@ -12,9 +14,11 @@ func main() {
// TODO: Check Go type of data and tensor DType
// For. if data is []int and DType is Bool
// It is still running but get wrong result.
data := []float32{1.1, 1.2, 1.1}
data := []int32{1, 0, 1}
dtype := gotch.Int
fmt.Println(gotch.DType{reflect.TypeOf(data)})
ts := wrapper.NewTensor()
sliceTensor, err := ts.FOfSlice(data, dtype)
if err != nil {

View File

@ -4,7 +4,7 @@ package wrapper
import "C"
import (
// "fmt"
"fmt"
"reflect"
gotch "github.com/sugarme/gotch"
@ -24,11 +24,19 @@ func NewTensor() Tensor {
// FOfSlice creates tensor from a slice data
func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor, err error) {
if ok, msg := gotch.TypeCheck(data, dtype); !ok {
err = fmt.Errorf("data type and DType are mismatched: %v\n", msg)
return nil, err
}
dataLen := reflect.ValueOf(data).Len()
shape := []int64{int64(dataLen)}
elementNum := ElementCount(shape)
eltSizeInBytes := gotch.DTypeSize(dtype)
eltSizeInBytes, err := gotch.DTypeSize(dtype)
if err != nil {
return nil, err
}
nbytes := int(eltSizeInBytes) * int(elementNum)
@ -38,7 +46,12 @@ func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor,
return nil, err
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(gotch.DType2CInt(dtype)))
cint, err := gotch.DType2CInt(dtype)
if err != nil {
return nil, err
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
retVal = &Tensor{ctensor}