From 67c80a47863b2d5291f5645792cac21eacd38cf4 Mon Sep 17 00:00:00 2001 From: sugarme Date: Mon, 1 Jun 2020 15:45:25 +1000 Subject: [PATCH] feat(wrapper): updated FlattenData to deal with nested slice --- dtype.go | 92 +++++++++++++++-------- example/tensor/main.go | 16 +++- wrapper/tensor.go | 2 +- wrapper/util.go | 165 +++++++++++++++++++++++++++++++++++++---- 4 files changed, 229 insertions(+), 46 deletions(-) diff --git a/dtype.go b/dtype.go index c580a36..7e877f7 100644 --- a/dtype.go +++ b/dtype.go @@ -163,31 +163,67 @@ var ( // DTypeFromData infers returns equavalent DType from given data func DTypeFromData(data interface{}) (retVal DType, err error) { - dataKind := reflect.ValueOf(data).Kind() - var dataType reflect.Type - switch dataKind { - case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: - dataType = reflect.TypeOf(data) - case reflect.Slice: - dataType = reflect.TypeOf(data).Elem() - default: - err = fmt.Errorf("Unsupported type for data type %v\n", dataType) - return DType{}, err + + // NOTE: call `Interface()` to get data type back to interface{} type + typ, _, err := dataCheck(reflect.ValueOf(data).Interface(), 0) + if err != nil { + return retVal, err } - return ToDType(dataType) + return ToDType(typ) +} +// NOTE: 0 is reflect.Kind() of Invalid +// See: https://golang.org/pkg/reflect/#Kind +func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) { + v := reflect.ValueOf(data) + var goType reflect.Type = reflect.TypeOf(data) + var total int = count + var round = 0 + + switch v.Kind() { + case reflect.Slice, reflect.Array: + if round == 0 { + round = v.Len() + } + for i := 0; i < v.Len(); i++ { + round-- + goType, total, err = dataCheck(v.Index(i).Interface(), total) + + if err != nil { + return reflect.TypeOf(reflect.Zero), 0, err + } + } + + return goType, total, nil + + case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: + total++ + if goType.String() != "invalid" { + goType = v.Type() + } + default: + err = fmt.Errorf("Input Data: unsupported data structure or type: %v\n", v.Kind()) + return reflect.TypeOf(reflect.Zero), 0, err + } + + return goType, total, nil } // ElementGoType infers and returns Go type of element in given data func ElementGoType(data interface{}) (retVal reflect.Type, err error) { - dataKind := reflect.ValueOf(data).Kind() - var dataType reflect.Type + dataValue := reflect.ValueOf(data) + return elementType(dataValue) +} + +func elementType(data reflect.Value) (dataType reflect.Type, err error) { + dataKind := data.Kind() switch dataKind { case reflect.Uint8, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: - dataType = reflect.TypeOf(data) - case reflect.Slice: - dataType = reflect.TypeOf(data).Elem() + dataType = data.Type() + case reflect.Slice, reflect.Array: + data = data.Elem() + dataType, err = elementType(data) // recursively type inferring default: err = fmt.Errorf("Unsupported type for data type %v\n", dataType) return DType{}, err @@ -257,22 +293,20 @@ func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) { // TypeCheck checks whether data Go type matching DType func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) { - - dataKind := reflect.ValueOf(data).Kind() - dataType := reflect.TypeOf(data) - - switch dataKind { - case reflect.Slice: - dataEleType := reflect.TypeOf(data).Elem() - matched = dataEleType == dtype.Type - msg = fmt.Sprintf("data type: %v, DType: %v", dataEleType, dtype.Kind()) - default: - matched = dataType == dtype.Type - msg = fmt.Sprintf("data type: %v, DType: %v", dataType, dtype.Kind()) + dataValue := reflect.ValueOf(data) + var dataType reflect.Type + var err error + dataType, err = elementType(dataValue) + if err != nil { + msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind()) + msg += err.Error() + return false, msg } - return matched, msg + matched = dataType == dtype.Type + msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind()) + return matched, msg } var supportedTypes = map[reflect.Kind]bool{ diff --git a/example/tensor/main.go b/example/tensor/main.go index aa7cfe9..678a4a0 100644 --- a/example/tensor/main.go +++ b/example/tensor/main.go @@ -13,8 +13,12 @@ func main() { // TODO: Check Go type of data and tensor DType // For. if data is []int and DType is Bool // It is still running but get wrong result. - data := []int32{1, 1, 1, 2, 2, 2} - shape := []int64{2, 3} + data := [][]int64{ + {1, 1, 1, 2, 2, 2, 1}, + {1, 1, 1, 2, 2, 2, 1}, + } + // shape := []int64{2, 7} + shape := []int64{2, 7} // dtype := gotch.Int // ts := wrapper.NewTensor() @@ -30,4 +34,12 @@ func main() { ts.Print() + // typ, count, err := wrapper.DataCheck(data) + // if err != nil { + // log.Fatal(err) + // } + // + // fmt.Printf("typ: %v\n", typ) + // fmt.Printf("Count: %v\n", count) + } diff --git a/wrapper/tensor.go b/wrapper/tensor.go index cb23938..8b7d4f0 100644 --- a/wrapper/tensor.go +++ b/wrapper/tensor.go @@ -77,7 +77,7 @@ func NewTensorFromData(data interface{}, shape []int64) (retVal *Tensor, err err nflattend := FlattenDim(shape) if elementNum != nflattend { - err = fmt.Errorf("Number of data elements and flatten shape dimension mismatched.\n") + err = fmt.Errorf("Number of data elements (%v) and flatten shape (%v) dimension mismatched.\n", elementNum, nflattend) return nil, err } diff --git a/wrapper/util.go b/wrapper/util.go index 49c9fcf..0b48d1f 100644 --- a/wrapper/util.go +++ b/wrapper/util.go @@ -158,24 +158,55 @@ func ElementCount(shape []int64) int64 { } // DataDim returns number of elements in data +// NOTE: only support scalar and (nested) slice/array of scalar type func DataDim(data interface{}) (retVal int, err error) { - v := reflect.ValueOf(data) - switch gotch.IsSupportedScalar(v.Kind()) { - case true: - retVal = 1 - default: - switch v.Kind() { - case reflect.Slice, reflect.Array: - retVal = v.Len() - default: - err = fmt.Errorf("Cannot count data element due to unsupported data type: %v\n.", v.Kind()) - return 0, err + _, count, err := dataCheck(reflect.ValueOf(data).Interface(), 0) + return count, err +} + +// DataCheck checks the input data for element Go type and number of elements. +// It will return errors if element type is not supported. +func DataCheck(data interface{}) (k reflect.Type, n int, err error) { + + return dataCheck(reflect.ValueOf(data).Interface(), 0) +} + +// NOTE: 0 is reflect.Kind() of Invalid +// See: https://golang.org/pkg/reflect/#Kind +func dataCheck(data interface{}, count int) (k reflect.Type, n int, err error) { + v := reflect.ValueOf(data) + var goType reflect.Type = reflect.TypeOf(data) + var total int = count + var round = 0 + + switch v.Kind() { + case reflect.Slice, reflect.Array: + if round == 0 { + round = v.Len() + } + for i := 0; i < v.Len(); i++ { + round-- + goType, total, err = dataCheck(v.Index(i).Interface(), total) + + if err != nil { + return reflect.TypeOf(reflect.Zero), 0, err + } } + return goType, total, nil + + case reflect.Uint8, reflect.Int8, reflect.Int32, reflect.Int64, reflect.Float32, reflect.Float64, reflect.Bool: + total++ + if goType.String() != "invalid" { + goType = v.Type() + } + default: + err = fmt.Errorf("Input Data: unsupported data structure or type: %v\n", v.Kind()) + return reflect.TypeOf(reflect.Zero), 0, err } - return retVal, nil + return goType, total, nil } // DataAsPtr write to C memory and returns a C pointer. @@ -193,7 +224,6 @@ func DataAsPtr(data interface{}) (dataPtr unsafe.Pointer, err error) { // 2. Element size in bytes dtype, err := gotch.DTypeFromData(data) - fmt.Println(dtype) if err != nil { return nil, err } @@ -209,7 +239,17 @@ func DataAsPtr(data interface{}) (dataPtr unsafe.Pointer, err error) { dataPtr, buff := CMalloc(nbytes) // 4. Write data to C memory - err = binary.Write(buff, nativeEndian, data) + // NOTE: data should be **fixed size** values so that binary.Write can work + // A fixed-size value is either a fixed-size arithmetic type (bool, int8, uint8, + // int16, float32, complex64, ...) or an array or struct containing only fixed-size values. + // See more: https://golang.org/pkg/encoding/binary/ + // Therefore, we will need to flatten data to `[]T` + fData, err := FlattenData(data) + if err != nil { + return nil, err + } + + err = binary.Write(buff, nativeEndian, fData) if err != nil { return nil, err } @@ -226,3 +266,100 @@ func FlattenDim(shape []int64) int { return int(n) } + +// FlattenData flattens data to 1D array ([]T) +func FlattenData(data interface{}) (fData interface{}, err error) { + + flat, err := flattenData(reflect.ValueOf(data).Interface(), 0, []interface{}{}) + if err != nil { + return nil, err + } + + ele := flat[0] + + // Boring task. Convert interface to specific type. + // Any good way to do??? + switch reflect.ValueOf(ele).Kind() { + case reflect.Uint8: + var retVal []uint8 + for _, v := range flat { + retVal = append(retVal, v.(uint8)) + } + return retVal, nil + case reflect.Int8: + var retVal []int8 + for _, v := range flat { + retVal = append(retVal, v.(int8)) + } + return retVal, nil + case reflect.Int16: + var retVal []int16 + for _, v := range flat { + retVal = append(retVal, v.(int16)) + } + return retVal, nil + case reflect.Int32: + var retVal []int32 + for _, v := range flat { + retVal = append(retVal, v.(int32)) + } + return retVal, nil + case reflect.Int64: + var retVal []int64 + for _, v := range flat { + retVal = append(retVal, v.(int64)) + } + return retVal, nil + case reflect.Float32: + var retVal []float32 + for _, v := range flat { + retVal = append(retVal, v.(float32)) + } + return retVal, nil + case reflect.Float64: + var retVal []float64 + for _, v := range flat { + retVal = append(retVal, v.(float64)) + } + return retVal, nil + case reflect.Bool: + var retVal []bool + for _, v := range flat { + retVal = append(retVal, v.(bool)) + } + return retVal, nil + + default: + err = fmt.Errorf("Unsupport type for input data: %v\n", reflect.ValueOf(ele).Kind()) + return nil, err + } + + return nil, err + +} + +func flattenData(data interface{}, round int, flat []interface{}) (f []interface{}, err error) { + v := reflect.ValueOf(data) + var flatData []interface{} = flat + + switch v.Kind() { + case reflect.Slice, reflect.Array: + if round == 0 { + round = v.Len() + } + for i := 0; i < v.Len(); i++ { + round-- + flatData, err = flattenData(v.Index(i).Interface(), round, flatData) + if err != nil { + return nil, err + } + } + + return flatData, nil + + case reflect.Int32, reflect.Int64: + flatData = append(flatData, data) + } + + return flatData, nil +}