fix(DTypeFromData and FlattenData): fixed type check didn't reach data type is a slice
This commit is contained in:
parent
959a1c8a99
commit
9f5eccb4e5
4
dtype.go
4
dtype.go
|
@ -181,6 +181,10 @@ func DTypeFromData(data interface{}) (retVal DType, err error) {
|
||||||
return retVal, err
|
return retVal, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if typ.Kind() == reflect.Slice {
|
||||||
|
return ToDType(typ.Elem())
|
||||||
|
}
|
||||||
|
|
||||||
return ToDType(typ)
|
return ToDType(typ)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -274,6 +274,16 @@ func FlattenDim(shape []int64) int {
|
||||||
// FlattenData flattens data to 1D array ([]T)
|
// FlattenData flattens data to 1D array ([]T)
|
||||||
func FlattenData(data interface{}) (fData interface{}, err error) {
|
func FlattenData(data interface{}) (fData interface{}, err error) {
|
||||||
|
|
||||||
|
// If data is 1D already, just return it.
|
||||||
|
dataVal := reflect.ValueOf(data)
|
||||||
|
dataTyp := reflect.TypeOf(data)
|
||||||
|
if dataVal.Kind() == reflect.Slice {
|
||||||
|
eleVal := dataTyp.Elem()
|
||||||
|
if eleVal.Kind() != reflect.Slice {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
flat, err := flattenData(reflect.ValueOf(data).Interface(), 0, []interface{}{})
|
flat, err := flattenData(reflect.ValueOf(data).Interface(), 0, []interface{}{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
Loading…
Reference in New Issue
Block a user