diff --git a/ts/npy.go b/ts/npy.go index e621f84..9f8f2e8 100644 --- a/ts/npy.go +++ b/ts/npy.go @@ -306,6 +306,11 @@ func ReadNpy(filepath string) (*Tensor, error) { return nil, err } + // NOTE(TT.). case tensor 1 element with shape = [] + if len(data) > 0 && len(header.shape) == 0 { + header.shape = []int64{1} + } + return OfDataSize(data, header.shape, header.descr) } @@ -348,6 +353,11 @@ func ReadNpz(filePath string) ([]NamedTensor, error) { return nil, err } + // NOTE(TT.). case tensor 1 element with shape = [] + if len(data) > 0 && len(header.shape) == 0 { + header.shape = []int64{1} + } + tensor, err := OfDataSize(data, header.shape, header.descr) if err != nil { return nil, err