diff --git a/dtype.go b/dtype.go index f1ba17d..bdf83fc 100644 --- a/dtype.go +++ b/dtype.go @@ -181,6 +181,10 @@ func DTypeFromData(data interface{}) (retVal DType, err error) { return retVal, err } + if typ.Kind() == reflect.Slice { + return ToDType(typ.Elem()) + } + return ToDType(typ) } diff --git a/tensor/util.go b/tensor/util.go index 760c469..f4e9d01 100644 --- a/tensor/util.go +++ b/tensor/util.go @@ -274,6 +274,16 @@ func FlattenDim(shape []int64) int { // FlattenData flattens data to 1D array ([]T) func FlattenData(data interface{}) (fData interface{}, err error) { + // If data is 1D already, just return it. + dataVal := reflect.ValueOf(data) + dataTyp := reflect.TypeOf(data) + if dataVal.Kind() == reflect.Slice { + eleVal := dataTyp.Elem() + if eleVal.Kind() != reflect.Slice { + return data, nil + } + } + flat, err := flattenData(reflect.ValueOf(data).Interface(), 0, []interface{}{}) if err != nil { return nil, err