feat(wrapper): cleanup and add more type inferring
This commit is contained in:
parent
98c182cef8
commit
2430589319
169
dtype.go
169
dtype.go
|
@ -38,91 +38,46 @@ var (
|
|||
Bool DType = DType{reflect.TypeOf(true)} // 11
|
||||
)
|
||||
|
||||
/*
|
||||
* // ToCInt converts DType to CInt type value which is `C int`
|
||||
* func (dt DType) ToCInt() CInt {
|
||||
* switch dt.Kind() {
|
||||
* case reflect.Uint8:
|
||||
* return 0
|
||||
* case reflect.Int8:
|
||||
* return 1
|
||||
* case reflect.Int16:
|
||||
* return 2
|
||||
* case reflect.Int32:
|
||||
* return 3
|
||||
* case reflect.Int64:
|
||||
* return 4
|
||||
* case reflect.Float32:
|
||||
* return 6
|
||||
* case reflect.Float64:
|
||||
* return 7
|
||||
* case reflect.Bool:
|
||||
* return 11
|
||||
* default:
|
||||
* log.Fatalf("Unsupported type.")
|
||||
* }
|
||||
*
|
||||
* // unreachable
|
||||
* return CInt(-1)
|
||||
* }
|
||||
*
|
||||
* // OfCInt converts a value of type CInt to DType type value
|
||||
* func (dt DType) OfCInt(v CInt) DType {
|
||||
* switch v {
|
||||
* case 0:
|
||||
* return Uint8
|
||||
* case 1:
|
||||
* return Int8
|
||||
* case 2:
|
||||
* return Int16
|
||||
* case 3:
|
||||
* return Int
|
||||
* case 4:
|
||||
* return Int64
|
||||
* case 6:
|
||||
* return Float
|
||||
* case 7:
|
||||
* return Double
|
||||
* case 8:
|
||||
* case 11:
|
||||
* return Bool
|
||||
* default:
|
||||
* log.Fatalf("Unexpected DType %v\n", v)
|
||||
* }
|
||||
* return DType{reflect.TypeOf(false)}
|
||||
* }
|
||||
*
|
||||
* // EltSizeInBytes converts a DType value to number of bytes
|
||||
* // This is a ELement Size In Bytes in Libtorch.
|
||||
* // Has it been deprecated?
|
||||
* func (dt DType) EltSizeInBytes() uint {
|
||||
* switch dt.Kind() {
|
||||
* case reflect.Uint8:
|
||||
* return 1
|
||||
* case reflect.Int8:
|
||||
* return 1
|
||||
* case reflect.Int16:
|
||||
* return 2
|
||||
* case reflect.Int:
|
||||
* return 4
|
||||
* case reflect.Int64:
|
||||
* return 8
|
||||
* case reflect.Float32:
|
||||
* return 4
|
||||
* case reflect.Float64:
|
||||
* return 8
|
||||
* case reflect.Bool:
|
||||
* return 1
|
||||
* default:
|
||||
* log.Fatalf("Unsupported Type %v\n", dt.Type)
|
||||
* }
|
||||
* return uint(0)
|
||||
* }
|
||||
* */
|
||||
var dtypeGoType = map[DType]reflect.Type{
|
||||
Uint8: reflect.TypeOf(uint8(1)),
|
||||
Int8: reflect.TypeOf(int8(1)),
|
||||
Int16: reflect.TypeOf(int16(1)),
|
||||
Int: reflect.TypeOf(int32(1)),
|
||||
Int64: reflect.TypeOf(int64(1)),
|
||||
Float: reflect.TypeOf(float32(1)),
|
||||
Double: reflect.TypeOf(float64(1)),
|
||||
Bool: reflect.TypeOf(true),
|
||||
}
|
||||
|
||||
// ToGoType converts DType to Go type
|
||||
func (dt DType) ToGoType() reflect.Type {
|
||||
return dt.Type
|
||||
// ToDType infers and returns supported equivalent DType from given Go type
|
||||
func ToDType(typ reflect.Type) (retVal DType, err error) {
|
||||
var found = false
|
||||
for key, val := range dtypeGoType {
|
||||
if val == typ {
|
||||
retVal = key
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
err = fmt.Errorf("Unsupported Go type: %v", typ)
|
||||
return DType{}, err
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// ToGoType infers and returns supported equivalent Go type from given DType
|
||||
func ToGoType(dtype DType) (retVal reflect.Type, err error) {
|
||||
if _, ok := dtypeGoType[dtype]; !ok {
|
||||
err = fmt.Errorf("Unsupported DType %v", dtype)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
retVal = dtypeGoType[dtype]
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
var dtypeCInt = map[DType]CInt{
|
||||
|
@ -136,8 +91,14 @@ var dtypeCInt = map[DType]CInt{
|
|||
Bool: 11,
|
||||
}
|
||||
|
||||
func DType2CInt(dt DType) CInt {
|
||||
return dtypeCInt[dt]
|
||||
func DType2CInt(dt DType) (retVal CInt, err error) {
|
||||
if _, ok := dtypeCInt[dt]; !ok {
|
||||
err = fmt.Errorf("Unsupported CInt conversion from DType: %v\n", dt)
|
||||
}
|
||||
|
||||
retVal = dtypeCInt[dt]
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func CInt2DType(v CInt) (dtype DType, err error) {
|
||||
|
@ -171,8 +132,15 @@ var dtypeSize = map[DType]uint{
|
|||
}
|
||||
|
||||
// DTypeSize returns DType size in Bytes
|
||||
func DTypeSize(dt DType) uint {
|
||||
return dtypeSize[dt]
|
||||
func DTypeSize(dt DType) (retVal uint, err error) {
|
||||
if _, ok := dtypeSize[dt]; !ok {
|
||||
err = fmt.Errorf("Unsupported conversion DType size in Byte for DType: %v\n", dt)
|
||||
return 99, err
|
||||
}
|
||||
|
||||
retVal = dtypeSize[dt]
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
type DTypeDevice struct {
|
||||
|
@ -193,6 +161,26 @@ var (
|
|||
// Type Inferring:
|
||||
// ===============
|
||||
|
||||
// DTypeFromData infers returns equavalent DType from given data
|
||||
func DTypeFromData(data interface{}) (retVal DType, err error) {
|
||||
dataKind := reflect.ValueOf(data).Kind()
|
||||
var dataType reflect.Type
|
||||
switch dataKind {
|
||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||
dataType = reflect.TypeOf(data)
|
||||
case reflect.Slice:
|
||||
dataType = reflect.TypeOf(data).Elem()
|
||||
default:
|
||||
err = fmt.Errorf("Unsupported type for data type %v\n", dataType)
|
||||
return DType{}, err
|
||||
}
|
||||
|
||||
retVal = DType{reflect.TypeOf(dataType)}
|
||||
|
||||
return retVal, nil
|
||||
|
||||
}
|
||||
|
||||
// DataDType infers and returns data type of tensor data
|
||||
func DataDType(v interface{}, shape []int64) (retVal DType, err error) {
|
||||
// assuming that all elements in data have the same type
|
||||
|
@ -236,7 +224,10 @@ func ElementDType(v interface{}) (retVal DType, err error) {
|
|||
|
||||
// TypeOf infers and returns element Go type from given tensor DType and shape
|
||||
func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) {
|
||||
typ := dt.ToGoType()
|
||||
var typ reflect.Type
|
||||
if typ, err = ToGoType(dt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(shape) == 0:
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
|
||||
gotch "github.com/sugarme/gotch"
|
||||
wrapper "github.com/sugarme/gotch/wrapper"
|
||||
|
@ -12,9 +14,11 @@ func main() {
|
|||
// TODO: Check Go type of data and tensor DType
|
||||
// For. if data is []int and DType is Bool
|
||||
// It is still running but get wrong result.
|
||||
data := []float32{1.1, 1.2, 1.1}
|
||||
data := []int32{1, 0, 1}
|
||||
dtype := gotch.Int
|
||||
|
||||
fmt.Println(gotch.DType{reflect.TypeOf(data)})
|
||||
|
||||
ts := wrapper.NewTensor()
|
||||
sliceTensor, err := ts.FOfSlice(data, dtype)
|
||||
if err != nil {
|
||||
|
|
|
@ -4,7 +4,7 @@ package wrapper
|
|||
import "C"
|
||||
|
||||
import (
|
||||
// "fmt"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
gotch "github.com/sugarme/gotch"
|
||||
|
@ -24,11 +24,19 @@ func NewTensor() Tensor {
|
|||
// FOfSlice creates tensor from a slice data
|
||||
func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor, err error) {
|
||||
|
||||
if ok, msg := gotch.TypeCheck(data, dtype); !ok {
|
||||
err = fmt.Errorf("data type and DType are mismatched: %v\n", msg)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dataLen := reflect.ValueOf(data).Len()
|
||||
shape := []int64{int64(dataLen)}
|
||||
elementNum := ElementCount(shape)
|
||||
|
||||
eltSizeInBytes := gotch.DTypeSize(dtype)
|
||||
eltSizeInBytes, err := gotch.DTypeSize(dtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nbytes := int(eltSizeInBytes) * int(elementNum)
|
||||
|
||||
|
@ -38,7 +46,12 @@ func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor,
|
|||
return nil, err
|
||||
}
|
||||
|
||||
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(gotch.DType2CInt(dtype)))
|
||||
cint, err := gotch.DType2CInt(dtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
|
||||
|
||||
retVal = &Tensor{ctensor}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user