WIP(npy): added ReadNpy and ReadNpz
This commit is contained in:
parent
94197e0710
commit
b4228528bb
36
example/convert-model/main.go
Normal file
36
example/convert-model/main.go
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
filepath := "../../data/convert-model/bert/model.npz"
|
||||||
|
|
||||||
|
namedTensors, err := ts.ReadNpz(filepath)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Num of named tensor: %v\n", len(namedTensors))
|
||||||
|
/*
|
||||||
|
* for _, nt := range namedTensors {
|
||||||
|
* // fmt.Printf("%q\n", nt.Name)
|
||||||
|
* if nt.Name == "bert.encoder.layer.1.attention.output.LayerNorm.weight" {
|
||||||
|
* fmt.Printf("%0.3f", nt.Tensor)
|
||||||
|
* }
|
||||||
|
* }
|
||||||
|
* */
|
||||||
|
|
||||||
|
// fmt.Printf("%v", namedTensors[70].Tensor)
|
||||||
|
|
||||||
|
outputFile := "../../data/convert-model/bert/model.gt"
|
||||||
|
err = ts.SaveMulti(namedTensors, outputFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
164
tensor/npy.go
164
tensor/npy.go
|
@ -1,11 +1,13 @@
|
||||||
package tensor
|
package tensor
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"archive/zip"
|
||||||
"bufio"
|
"bufio"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
@ -13,7 +15,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
NpyMagicString string = `\x93NUMPY`
|
// NpyMagicString string = `\x93NUMPY`
|
||||||
|
NpyMagicString string = "\x93NUMPY"
|
||||||
NpySuffix string = ".npy"
|
NpySuffix string = ".npy"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -54,7 +57,7 @@ func readHeader(r io.Reader) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var hLen int = 0
|
var hLen int = 0
|
||||||
for i := len(headerLen); i > 0; i-- {
|
for i := len(headerLen) - 1; i >= 0; i-- {
|
||||||
hLen = hLen*256 + int(headerLen[i])
|
hLen = hLen*256 + int(headerLen[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,13 +70,24 @@ func readHeader(r io.Reader) (string, error) {
|
||||||
return string(header), nil
|
return string(header), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Header struct {
|
type NpyHeader struct {
|
||||||
descr gotch.DType
|
descr gotch.DType
|
||||||
fortranOrder bool
|
fortranOrder bool
|
||||||
shape []int64
|
shape []int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Header) toString() (string, error) {
|
// NewHeader creates Header from input data
|
||||||
|
//
|
||||||
|
// NOTE. This is mainly for unit test purpose
|
||||||
|
func NewNpyHeader(dtype gotch.DType, fo bool, shape []int64) *NpyHeader {
|
||||||
|
return &NpyHeader{
|
||||||
|
descr: dtype,
|
||||||
|
fortranOrder: fo,
|
||||||
|
shape: shape,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *NpyHeader) ToString() (string, error) {
|
||||||
var fortranOrder string = "False"
|
var fortranOrder string = "False"
|
||||||
|
|
||||||
if h.fortranOrder {
|
if h.fortranOrder {
|
||||||
|
@ -82,7 +96,7 @@ func (h *Header) toString() (string, error) {
|
||||||
|
|
||||||
var shapeStr []string
|
var shapeStr []string
|
||||||
for _, v := range h.shape {
|
for _, v := range h.shape {
|
||||||
shapeStr = append(shapeStr, string(v))
|
shapeStr = append(shapeStr, fmt.Sprintf("%v", v))
|
||||||
}
|
}
|
||||||
|
|
||||||
shape := strings.Join(shapeStr, ",")
|
shape := strings.Join(shapeStr, ",")
|
||||||
|
@ -110,8 +124,8 @@ func (h *Header) toString() (string, error) {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(shape) > 0 {
|
if len(h.shape) == 1 {
|
||||||
shape = shape + ","
|
shape += ","
|
||||||
}
|
}
|
||||||
|
|
||||||
headStr := fmt.Sprintf("{'descr': '<%v', 'fortran_order': %v, 'shape': (%v), }", descr, fortranOrder, shape)
|
headStr := fmt.Sprintf("{'descr': '<%v', 'fortran_order': %v, 'shape': (%v), }", descr, fortranOrder, shape)
|
||||||
|
@ -119,22 +133,14 @@ func (h *Header) toString() (string, error) {
|
||||||
return headStr, nil
|
return headStr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parser for the npy header, a typical example would be:
|
// ParseNpyHeader parses the given npy header string.
|
||||||
|
//
|
||||||
|
// A typical example would be:
|
||||||
// {'descr': '<f8', 'fortran_order': False, 'shape': (128,), }
|
// {'descr': '<f8', 'fortran_order': False, 'shape': (128,), }
|
||||||
func (h *Header) parse(header string) (*Header, error) {
|
func ParseNpyHeader(header string) (*NpyHeader, error) {
|
||||||
|
|
||||||
// trim matches
|
|
||||||
var chars []rune
|
|
||||||
for _, r := range header {
|
|
||||||
if r == '{' || r == '}' || r == ',' || r == ' ' {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
chars = append(chars, r)
|
|
||||||
}
|
|
||||||
|
|
||||||
trimHeader := string(chars)
|
|
||||||
|
|
||||||
|
// trim all prefix or suffix patterns
|
||||||
|
trimHeader := trimMatches([]rune{'{', '}', ','}, header)
|
||||||
var parts []string
|
var parts []string
|
||||||
startIdx := 0
|
startIdx := 0
|
||||||
var cntParenthesis int64 = 0
|
var cntParenthesis int64 = 0
|
||||||
|
@ -155,32 +161,15 @@ func (h *Header) parse(header string) (*Header, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
parts = append(parts, header[startIdx:])
|
parts = append(parts, header[startIdx:])
|
||||||
|
|
||||||
var partMap map[string]string = make(map[string]string)
|
var partMap map[string]string = make(map[string]string)
|
||||||
|
|
||||||
for _, part := range parts {
|
for _, part := range parts {
|
||||||
strings.TrimSpace(part)
|
strings.TrimSpace(part)
|
||||||
p := strings.TrimSpace(part)
|
p := strings.TrimSpace(part)
|
||||||
if len(p) > 0 {
|
if len(p) > 0 {
|
||||||
kv := strings.Split(p, ":")
|
kv := strings.Split(p, ":")
|
||||||
if len(kv) == 2 {
|
if len(kv) == 2 {
|
||||||
var key, value string
|
key := trimMatches([]rune{'\''}, kv[0])
|
||||||
rKey := []rune(kv[0])
|
value := trimMatches([]rune{'\''}, kv[1])
|
||||||
for _, r := range rKey {
|
|
||||||
if r == '\'' || r == ' ' {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
key = key + string([]rune{r})
|
|
||||||
}
|
|
||||||
|
|
||||||
rValue := []rune(kv[1])
|
|
||||||
for _, r := range rValue {
|
|
||||||
if r == '\'' || r == ' ' {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
value = value + string([]rune{r})
|
|
||||||
}
|
|
||||||
|
|
||||||
partMap[key] = value
|
partMap[key] = value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -217,13 +206,7 @@ func (h *Header) parse(header string) (*Header, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var descrStr string
|
descrStr := trimMatches([]rune{'=', '<'}, d)
|
||||||
for _, r := range d {
|
|
||||||
if r == '=' || r == '<' {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
descrStr += string([]rune{r})
|
|
||||||
}
|
|
||||||
|
|
||||||
var descr gotch.DType
|
var descr gotch.DType
|
||||||
switch descrStr {
|
switch descrStr {
|
||||||
|
@ -254,15 +237,7 @@ func (h *Header) parse(header string) (*Header, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var shapeStr string
|
shapeStr := trimMatches([]rune{'(', ')', ','}, s)
|
||||||
for _, r := range s {
|
|
||||||
if r == '(' || r == ')' || r == ',' {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
shapeStr += string([]rune{r})
|
|
||||||
}
|
|
||||||
|
|
||||||
var shape []int64
|
var shape []int64
|
||||||
if len(shapeStr) == 0 {
|
if len(shapeStr) == 0 {
|
||||||
shape = make([]int64, 0)
|
shape = make([]int64, 0)
|
||||||
|
@ -277,13 +252,29 @@ func (h *Header) parse(header string) (*Header, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Header{
|
return &NpyHeader{
|
||||||
descr,
|
descr,
|
||||||
fortranOrder,
|
fortranOrder,
|
||||||
shape,
|
shape,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// trimMatches trims all prefix or suffix specified in the input string slice from string data
|
||||||
|
func trimMatches(matches []rune, s string) string {
|
||||||
|
// First: trim leading and trailing space
|
||||||
|
trimStr := strings.TrimSpace(s)
|
||||||
|
for _, m := range matches {
|
||||||
|
if strings.HasPrefix(trimStr, string([]rune{m})) {
|
||||||
|
trimStr = strings.TrimPrefix(trimStr, string([]rune{m}))
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(trimStr, string([]rune{m})) {
|
||||||
|
trimStr = strings.TrimSuffix(trimStr, string([]rune{m}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return trimStr
|
||||||
|
}
|
||||||
|
|
||||||
// ReadNpy reads a .npy file and returns the stored tensor.
|
// ReadNpy reads a .npy file and returns the stored tensor.
|
||||||
func ReadNpy(filepath string) (*Tensor, error) {
|
func ReadNpy(filepath string) (*Tensor, error) {
|
||||||
|
|
||||||
|
@ -291,6 +282,7 @@ func ReadNpy(filepath string) (*Tensor, error) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
r := bufio.NewReader(f)
|
r := bufio.NewReader(f)
|
||||||
|
|
||||||
|
@ -299,9 +291,7 @@ func ReadNpy(filepath string) (*Tensor, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
hd := new(Header)
|
header, err := ParseNpyHeader(h)
|
||||||
|
|
||||||
header, err := hd.parse(h)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -319,3 +309,57 @@ func ReadNpy(filepath string) (*Tensor, error) {
|
||||||
|
|
||||||
return OfDataSize(data, header.shape, header.descr)
|
return OfDataSize(data, header.shape, header.descr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReadNpz reads a compressed numpy file (.npz) and returns named tensors
|
||||||
|
func ReadNpz(filePath string) ([]NamedTensor, error) {
|
||||||
|
var namedTensors []NamedTensor
|
||||||
|
r, err := zip.OpenReader(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer r.Close()
|
||||||
|
|
||||||
|
for _, f := range r.File {
|
||||||
|
basename := f.Name
|
||||||
|
// remove file extension to get tensor name
|
||||||
|
name := strings.TrimSuffix(basename, filepath.Ext(basename))
|
||||||
|
rc, err := f.Open()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
headerStr, err := readHeader(rc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
header, err := ParseNpyHeader(headerStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if header.fortranOrder {
|
||||||
|
err := fmt.Errorf("fortran order not supported.\n")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var data []byte
|
||||||
|
data, err = ioutil.ReadAll(rc)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tensor, err := OfDataSize(data, header.shape, header.descr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
namedTensors = append(namedTensors, NamedTensor{name, tensor})
|
||||||
|
|
||||||
|
// explicitly close before next one
|
||||||
|
rc.Close()
|
||||||
|
data = make([]byte, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return namedTensors, nil
|
||||||
|
}
|
||||||
|
|
62
tensor/npy_test.go
Normal file
62
tensor/npy_test.go
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
package tensor_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNpyHeaderParse(t *testing.T) {
|
||||||
|
|
||||||
|
h1 := "{'descr': '<f8', 'fortran_order': False, 'shape': (128,), }"
|
||||||
|
want1 := ts.NewNpyHeader(gotch.Double, false, []int64{128})
|
||||||
|
|
||||||
|
testParse(t, want1, h1)
|
||||||
|
|
||||||
|
h2 := "{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128), }"
|
||||||
|
want2 := ts.NewNpyHeader(gotch.Float, true, []int64{256, 1, 128})
|
||||||
|
|
||||||
|
testParse(t, want2, h2)
|
||||||
|
|
||||||
|
h3, err := ts.ParseNpyHeader(h1)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
testToString(t, h1, h3)
|
||||||
|
|
||||||
|
h4, err := ts.ParseNpyHeader(h2)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
testToString(t, h2, h4)
|
||||||
|
|
||||||
|
h5 := ts.NewNpyHeader(gotch.Int64, false, []int64{})
|
||||||
|
want5 := "{'descr': '<i8', 'fortran_order': False, 'shape': (), }"
|
||||||
|
testToString(t, want5, h5)
|
||||||
|
}
|
||||||
|
|
||||||
|
func testParse(t *testing.T, want *ts.NpyHeader, headerStr string) {
|
||||||
|
got, err := ts.ParseNpyHeader(headerStr)
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(want, got) {
|
||||||
|
t.Errorf("want: %+v\n", want)
|
||||||
|
t.Errorf("got: %+v\n", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testToString(t *testing.T, want string, h *ts.NpyHeader) {
|
||||||
|
got, err := h.ToString()
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(want, got) {
|
||||||
|
t.Errorf("want: %+v\n", want)
|
||||||
|
t.Errorf("got: %+v\n", got)
|
||||||
|
}
|
||||||
|
}
|
|
@ -197,9 +197,67 @@ func OfSlice(data interface{}) (*Tensor, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// OfDataSize creates Tensor from input byte data and specidied shape and dtype.
|
// OfDataSize creates Tensor from input byte data and specidied shape and dtype.
|
||||||
func OfDataSize(data []byte, size []int64, dtype gotch.DType) (*Tensor, error) {
|
func OfDataSize(data []byte, shape []int64, dtype gotch.DType) (*Tensor, error) {
|
||||||
// TODO: implement
|
|
||||||
|
|
||||||
|
elementNum := ElementCount(shape)
|
||||||
|
eltSizeInBytes, err := gotch.DTypeSize(dtype)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
nbytes := int(eltSizeInBytes) * int(elementNum)
|
||||||
|
|
||||||
|
if nbytes != len(data) {
|
||||||
|
err := fmt.Errorf("data and shape mismatched for dtype (%v): byte data (%v) - shape (%v).\n", dtype, len(data), shape)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dataPtr, buff := CMalloc(nbytes)
|
||||||
|
defer C.free(unsafe.Pointer(dataPtr))
|
||||||
|
|
||||||
|
typ, err := gotch.ToGoType(dtype)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var v reflect.Value
|
||||||
|
switch typ.Name() {
|
||||||
|
case "float", "float32":
|
||||||
|
v = reflect.ValueOf(float32(0.1))
|
||||||
|
case "float64":
|
||||||
|
v = reflect.ValueOf(float64(0.1))
|
||||||
|
case "int", "int32":
|
||||||
|
v = reflect.ValueOf(int(1))
|
||||||
|
case "int64":
|
||||||
|
v = reflect.ValueOf(int64(1))
|
||||||
|
case "int8":
|
||||||
|
v = reflect.ValueOf(int8(1))
|
||||||
|
case "uint8":
|
||||||
|
v = reflect.ValueOf(uint8(1))
|
||||||
|
case "bool":
|
||||||
|
v = reflect.ValueOf(false)
|
||||||
|
default:
|
||||||
|
err := fmt.Errorf("unsupported dtype: %v\n", dtype)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = EncodeTensor(buff, v, shape); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cint, err := gotch.DType2CInt(dtype)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
buff.Reset()
|
||||||
|
|
||||||
|
return &Tensor{ctensor}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// MustOfDataSize create Tensor from input byte data and specified shape and dtype
|
// MustOfDataSize create Tensor from input byte data and specified shape and dtype
|
||||||
|
|
Loading…
Reference in New Issue
Block a user