fixed read npy failed with single element tensor with zero shape
This commit is contained in:
parent
30068c6b41
commit
3959bc3f93
10
ts/npy.go
10
ts/npy.go
|
@ -306,6 +306,11 @@ func ReadNpy(filepath string) (*Tensor, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE(TT.). case tensor 1 element with shape = []
|
||||||
|
if len(data) > 0 && len(header.shape) == 0 {
|
||||||
|
header.shape = []int64{1}
|
||||||
|
}
|
||||||
|
|
||||||
return OfDataSize(data, header.shape, header.descr)
|
return OfDataSize(data, header.shape, header.descr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -348,6 +353,11 @@ func ReadNpz(filePath string) ([]NamedTensor, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE(TT.). case tensor 1 element with shape = []
|
||||||
|
if len(data) > 0 && len(header.shape) == 0 {
|
||||||
|
header.shape = []int64{1}
|
||||||
|
}
|
||||||
|
|
||||||
tensor, err := OfDataSize(data, header.shape, header.descr)
|
tensor, err := OfDataSize(data, header.shape, header.descr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
Loading…
Reference in New Issue
Block a user