From 9f5eccb4e5fad4562f58e7af277a7ad852d72002 Mon Sep 17 00:00:00 2001 From: sugarme Date: Wed, 22 Jul 2020 15:26:18 +1000 Subject: [PATCH] fix(DTypeFromData and FlattenData): fixed type check didn't reach data type is a slice --- dtype.go | 4 ++++ tensor/util.go | 10 ++++++++++ 2 files changed, 14 insertions(+) diff --git a/dtype.go b/dtype.go index f1ba17d..bdf83fc 100644 --- a/dtype.go +++ b/dtype.go @@ -181,6 +181,10 @@ func DTypeFromData(data interface{}) (retVal DType, err error) { return retVal, err } + if typ.Kind() == reflect.Slice { + return ToDType(typ.Elem()) + } + return ToDType(typ) } diff --git a/tensor/util.go b/tensor/util.go index 760c469..f4e9d01 100644 --- a/tensor/util.go +++ b/tensor/util.go @@ -274,6 +274,16 @@ func FlattenDim(shape []int64) int { // FlattenData flattens data to 1D array ([]T) func FlattenData(data interface{}) (fData interface{}, err error) { + // If data is 1D already, just return it. + dataVal := reflect.ValueOf(data) + dataTyp := reflect.TypeOf(data) + if dataVal.Kind() == reflect.Slice { + eleVal := dataTyp.Elem() + if eleVal.Kind() != reflect.Slice { + return data, nil + } + } + flat, err := flattenData(reflect.ValueOf(data).Interface(), 0, []interface{}{}) if err != nil { return nil, err