feat(wrapper): updated FlattenData to deal with nested slice
This commit is contained in:
parent
b87d3c8281
commit
67c80a4786
92
dtype.go
92
dtype.go
|
@ -163,31 +163,67 @@ var (
|
||||||
|
|
||||||
// DTypeFromData infers returns equavalent DType from given data
|
// DTypeFromData infers returns equavalent DType from given data
|
||||||
func DTypeFromData(data interface{}) (retVal DType, err error) {
|
func DTypeFromData(data interface{}) (retVal DType, err error) {
|
||||||
dataKind := reflect.ValueOf(data).Kind()
|
|
||||||
var dataType reflect.Type
|
// NOTE: call `Interface()` to get data type back to interface{} type
|
||||||
switch dataKind {
|
typ, _, err := dataCheck(reflect.ValueOf(data).Interface(), 0)
|
||||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
if err != nil {
|
||||||
dataType = reflect.TypeOf(data)
|
return retVal, err
|
||||||
case reflect.Slice:
|
|
||||||
dataType = reflect.TypeOf(data).Elem()
|
|
||||||
default:
|
|
||||||
err = fmt.Errorf("Unsupported type for data type %v\n", dataType)
|
|
||||||
return DType{}, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ToDType(dataType)
|
return ToDType(typ)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: 0 is reflect.Kind() of Invalid
|
||||||
|
// See: https://golang.org/pkg/reflect/#Kind
|
||||||
|
func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
|
||||||
|
v := reflect.ValueOf(data)
|
||||||
|
var goType reflect.Type = reflect.TypeOf(data)
|
||||||
|
var total int = count
|
||||||
|
var round = 0
|
||||||
|
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
if round == 0 {
|
||||||
|
round = v.Len()
|
||||||
|
}
|
||||||
|
for i := 0; i < v.Len(); i++ {
|
||||||
|
round--
|
||||||
|
goType, total, err = dataCheck(v.Index(i).Interface(), total)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return reflect.TypeOf(reflect.Zero), 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return goType, total, nil
|
||||||
|
|
||||||
|
case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||||
|
total++
|
||||||
|
if goType.String() != "invalid" {
|
||||||
|
goType = v.Type()
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("Input Data: unsupported data structure or type: %v\n", v.Kind())
|
||||||
|
return reflect.TypeOf(reflect.Zero), 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return goType, total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ElementGoType infers and returns Go type of element in given data
|
// ElementGoType infers and returns Go type of element in given data
|
||||||
func ElementGoType(data interface{}) (retVal reflect.Type, err error) {
|
func ElementGoType(data interface{}) (retVal reflect.Type, err error) {
|
||||||
dataKind := reflect.ValueOf(data).Kind()
|
dataValue := reflect.ValueOf(data)
|
||||||
var dataType reflect.Type
|
return elementType(dataValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
func elementType(data reflect.Value) (dataType reflect.Type, err error) {
|
||||||
|
dataKind := data.Kind()
|
||||||
switch dataKind {
|
switch dataKind {
|
||||||
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||||
dataType = reflect.TypeOf(data)
|
dataType = data.Type()
|
||||||
case reflect.Slice:
|
case reflect.Slice, reflect.Array:
|
||||||
dataType = reflect.TypeOf(data).Elem()
|
data = data.Elem()
|
||||||
|
dataType, err = elementType(data) // recursively type inferring
|
||||||
default:
|
default:
|
||||||
err = fmt.Errorf("Unsupported type for data type %v\n", dataType)
|
err = fmt.Errorf("Unsupported type for data type %v\n", dataType)
|
||||||
return DType{}, err
|
return DType{}, err
|
||||||
|
@ -257,22 +293,20 @@ func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) {
|
||||||
|
|
||||||
// TypeCheck checks whether data Go type matching DType
|
// TypeCheck checks whether data Go type matching DType
|
||||||
func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) {
|
func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) {
|
||||||
|
dataValue := reflect.ValueOf(data)
|
||||||
dataKind := reflect.ValueOf(data).Kind()
|
var dataType reflect.Type
|
||||||
dataType := reflect.TypeOf(data)
|
var err error
|
||||||
|
dataType, err = elementType(dataValue)
|
||||||
switch dataKind {
|
if err != nil {
|
||||||
case reflect.Slice:
|
msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
|
||||||
dataEleType := reflect.TypeOf(data).Elem()
|
msg += err.Error()
|
||||||
matched = dataEleType == dtype.Type
|
return false, msg
|
||||||
msg = fmt.Sprintf("data type: %v, DType: %v", dataEleType, dtype.Kind())
|
|
||||||
default:
|
|
||||||
matched = dataType == dtype.Type
|
|
||||||
msg = fmt.Sprintf("data type: %v, DType: %v", dataType, dtype.Kind())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return matched, msg
|
matched = dataType == dtype.Type
|
||||||
|
msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
|
||||||
|
|
||||||
|
return matched, msg
|
||||||
}
|
}
|
||||||
|
|
||||||
var supportedTypes = map[reflect.Kind]bool{
|
var supportedTypes = map[reflect.Kind]bool{
|
||||||
|
|
|
@ -13,8 +13,12 @@ func main() {
|
||||||
// TODO: Check Go type of data and tensor DType
|
// TODO: Check Go type of data and tensor DType
|
||||||
// For. if data is []int and DType is Bool
|
// For. if data is []int and DType is Bool
|
||||||
// It is still running but get wrong result.
|
// It is still running but get wrong result.
|
||||||
data := []int32{1, 1, 1, 2, 2, 2}
|
data := [][]int64{
|
||||||
shape := []int64{2, 3}
|
{1, 1, 1, 2, 2, 2, 1},
|
||||||
|
{1, 1, 1, 2, 2, 2, 1},
|
||||||
|
}
|
||||||
|
// shape := []int64{2, 7}
|
||||||
|
shape := []int64{2, 7}
|
||||||
|
|
||||||
// dtype := gotch.Int
|
// dtype := gotch.Int
|
||||||
// ts := wrapper.NewTensor()
|
// ts := wrapper.NewTensor()
|
||||||
|
@ -30,4 +34,12 @@ func main() {
|
||||||
|
|
||||||
ts.Print()
|
ts.Print()
|
||||||
|
|
||||||
|
// typ, count, err := wrapper.DataCheck(data)
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// fmt.Printf("typ: %v\n", typ)
|
||||||
|
// fmt.Printf("Count: %v\n", count)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -77,7 +77,7 @@ func NewTensorFromData(data interface{}, shape []int64) (retVal *Tensor, err err
|
||||||
nflattend := FlattenDim(shape)
|
nflattend := FlattenDim(shape)
|
||||||
|
|
||||||
if elementNum != nflattend {
|
if elementNum != nflattend {
|
||||||
err = fmt.Errorf("Number of data elements and flatten shape dimension mismatched.\n")
|
err = fmt.Errorf("Number of data elements (%v) and flatten shape (%v) dimension mismatched.\n", elementNum, nflattend)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
165
wrapper/util.go
165
wrapper/util.go
|
@ -158,24 +158,55 @@ func ElementCount(shape []int64) int64 {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DataDim returns number of elements in data
|
// DataDim returns number of elements in data
|
||||||
|
// NOTE: only support scalar and (nested) slice/array of scalar type
|
||||||
func DataDim(data interface{}) (retVal int, err error) {
|
func DataDim(data interface{}) (retVal int, err error) {
|
||||||
v := reflect.ValueOf(data)
|
|
||||||
|
|
||||||
switch gotch.IsSupportedScalar(v.Kind()) {
|
_, count, err := dataCheck(reflect.ValueOf(data).Interface(), 0)
|
||||||
case true:
|
return count, err
|
||||||
retVal = 1
|
}
|
||||||
default:
|
|
||||||
switch v.Kind() {
|
// DataCheck checks the input data for element Go type and number of elements.
|
||||||
case reflect.Slice, reflect.Array:
|
// It will return errors if element type is not supported.
|
||||||
retVal = v.Len()
|
func DataCheck(data interface{}) (k reflect.Type, n int, err error) {
|
||||||
default:
|
|
||||||
err = fmt.Errorf("Cannot count data element due to unsupported data type: %v\n.", v.Kind())
|
return dataCheck(reflect.ValueOf(data).Interface(), 0)
|
||||||
return 0, err
|
}
|
||||||
|
|
||||||
|
// NOTE: 0 is reflect.Kind() of Invalid
|
||||||
|
// See: https://golang.org/pkg/reflect/#Kind
|
||||||
|
func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) {
|
||||||
|
v := reflect.ValueOf(data)
|
||||||
|
var goType reflect.Type = reflect.TypeOf(data)
|
||||||
|
var total int = count
|
||||||
|
var round = 0
|
||||||
|
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
if round == 0 {
|
||||||
|
round = v.Len()
|
||||||
|
}
|
||||||
|
for i := 0; i < v.Len(); i++ {
|
||||||
|
round--
|
||||||
|
goType, total, err = dataCheck(v.Index(i).Interface(), total)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return reflect.TypeOf(reflect.Zero), 0, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return goType, total, nil
|
||||||
|
|
||||||
|
case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool:
|
||||||
|
total++
|
||||||
|
if goType.String() != "invalid" {
|
||||||
|
goType = v.Type()
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("Input Data: unsupported data structure or type: %v\n", v.Kind())
|
||||||
|
return reflect.TypeOf(reflect.Zero), 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return retVal, nil
|
return goType, total, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DataAsPtr write to C memory and returns a C pointer.
|
// DataAsPtr write to C memory and returns a C pointer.
|
||||||
|
@ -193,7 +224,6 @@ func DataAsPtr(data interface{}) (dataPtr unsafe.Pointer, err error) {
|
||||||
|
|
||||||
// 2. Element size in bytes
|
// 2. Element size in bytes
|
||||||
dtype, err := gotch.DTypeFromData(data)
|
dtype, err := gotch.DTypeFromData(data)
|
||||||
fmt.Println(dtype)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -209,7 +239,17 @@ func DataAsPtr(data interface{}) (dataPtr unsafe.Pointer, err error) {
|
||||||
dataPtr, buff := CMalloc(nbytes)
|
dataPtr, buff := CMalloc(nbytes)
|
||||||
|
|
||||||
// 4. Write data to C memory
|
// 4. Write data to C memory
|
||||||
err = binary.Write(buff, nativeEndian, data)
|
// NOTE: data should be **fixed size** values so that binary.Write can work
|
||||||
|
// A fixed-size value is either a fixed-size arithmetic type (bool, int8, uint8,
|
||||||
|
// int16, float32, complex64, ...) or an array or struct containing only fixed-size values.
|
||||||
|
// See more: https://golang.org/pkg/encoding/binary/
|
||||||
|
// Therefore, we will need to flatten data to `[]T`
|
||||||
|
fData, err := FlattenData(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = binary.Write(buff, nativeEndian, fData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -226,3 +266,100 @@ func FlattenDim(shape []int64) int {
|
||||||
|
|
||||||
return int(n)
|
return int(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FlattenData flattens data to 1D array ([]T)
|
||||||
|
func FlattenData(data interface{}) (fData interface{}, err error) {
|
||||||
|
|
||||||
|
flat, err := flattenData(reflect.ValueOf(data).Interface(), 0, []interface{}{})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ele := flat[0]
|
||||||
|
|
||||||
|
// Boring task. Convert interface to specific type.
|
||||||
|
// Any good way to do???
|
||||||
|
switch reflect.ValueOf(ele).Kind() {
|
||||||
|
case reflect.Uint8:
|
||||||
|
var retVal []uint8
|
||||||
|
for _, v := range flat {
|
||||||
|
retVal = append(retVal, v.(uint8))
|
||||||
|
}
|
||||||
|
return retVal, nil
|
||||||
|
case reflect.Int8:
|
||||||
|
var retVal []int8
|
||||||
|
for _, v := range flat {
|
||||||
|
retVal = append(retVal, v.(int8))
|
||||||
|
}
|
||||||
|
return retVal, nil
|
||||||
|
case reflect.Int16:
|
||||||
|
var retVal []int16
|
||||||
|
for _, v := range flat {
|
||||||
|
retVal = append(retVal, v.(int16))
|
||||||
|
}
|
||||||
|
return retVal, nil
|
||||||
|
case reflect.Int32:
|
||||||
|
var retVal []int32
|
||||||
|
for _, v := range flat {
|
||||||
|
retVal = append(retVal, v.(int32))
|
||||||
|
}
|
||||||
|
return retVal, nil
|
||||||
|
case reflect.Int64:
|
||||||
|
var retVal []int64
|
||||||
|
for _, v := range flat {
|
||||||
|
retVal = append(retVal, v.(int64))
|
||||||
|
}
|
||||||
|
return retVal, nil
|
||||||
|
case reflect.Float32:
|
||||||
|
var retVal []float32
|
||||||
|
for _, v := range flat {
|
||||||
|
retVal = append(retVal, v.(float32))
|
||||||
|
}
|
||||||
|
return retVal, nil
|
||||||
|
case reflect.Float64:
|
||||||
|
var retVal []float64
|
||||||
|
for _, v := range flat {
|
||||||
|
retVal = append(retVal, v.(float64))
|
||||||
|
}
|
||||||
|
return retVal, nil
|
||||||
|
case reflect.Bool:
|
||||||
|
var retVal []bool
|
||||||
|
for _, v := range flat {
|
||||||
|
retVal = append(retVal, v.(bool))
|
||||||
|
}
|
||||||
|
return retVal, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("Unsupport type for input data: %v\n", reflect.ValueOf(ele).Kind())
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, err
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func flattenData(data interface{}, round int, flat []interface{}) (f []interface{}, err error) {
|
||||||
|
v := reflect.ValueOf(data)
|
||||||
|
var flatData []interface{} = flat
|
||||||
|
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.Slice, reflect.Array:
|
||||||
|
if round == 0 {
|
||||||
|
round = v.Len()
|
||||||
|
}
|
||||||
|
for i := 0; i < v.Len(); i++ {
|
||||||
|
round--
|
||||||
|
flatData, err = flattenData(v.Index(i).Interface(), round, flatData)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return flatData, nil
|
||||||
|
|
||||||
|
case reflect.Int32, reflect.Int64:
|
||||||
|
flatData = append(flatData, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
return flatData, nil
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user