WIP(npy): added ReadNpy and ReadNpz
This commit is contained in:
parent
94197e0710
commit
b4228528bb
36
example/convert-model/main.go
Normal file
36
example/convert-model/main.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func main() {
|
||||
filepath := "../../data/convert-model/bert/model.npz"
|
||||
|
||||
namedTensors, err := ts.ReadNpz(filepath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Num of named tensor: %v\n", len(namedTensors))
|
||||
/*
|
||||
* for _, nt := range namedTensors {
|
||||
* // fmt.Printf("%q\n", nt.Name)
|
||||
* if nt.Name == "bert.encoder.layer.1.attention.output.LayerNorm.weight" {
|
||||
* fmt.Printf("%0.3f", nt.Tensor)
|
||||
* }
|
||||
* }
|
||||
* */
|
||||
|
||||
// fmt.Printf("%v", namedTensors[70].Tensor)
|
||||
|
||||
outputFile := "../../data/convert-model/bert/model.gt"
|
||||
err = ts.SaveMulti(namedTensors, outputFile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
}
|
164
tensor/npy.go
164
tensor/npy.go
|
@ -1,11 +1,13 @@
|
|||
package tensor
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
|
@ -13,7 +15,8 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
NpyMagicString string = `\x93NUMPY`
|
||||
// NpyMagicString string = `\x93NUMPY`
|
||||
NpyMagicString string = "\x93NUMPY"
|
||||
NpySuffix string = ".npy"
|
||||
)
|
||||
|
||||
|
@ -54,7 +57,7 @@ func readHeader(r io.Reader) (string, error) {
|
|||
}
|
||||
|
||||
var hLen int = 0
|
||||
for i := len(headerLen); i > 0; i-- {
|
||||
for i := len(headerLen) - 1; i >= 0; i-- {
|
||||
hLen = hLen*256 + int(headerLen[i])
|
||||
}
|
||||
|
||||
|
@ -67,13 +70,24 @@ func readHeader(r io.Reader) (string, error) {
|
|||
return string(header), nil
|
||||
}
|
||||
|
||||
type Header struct {
|
||||
// NpyHeader holds the three entries of a .npy header dictionary, e.g.
// {'descr': '<f8', 'fortran_order': False, 'shape': (128,), }.
type NpyHeader struct {
	descr        gotch.DType // element dtype, rendered as the 'descr' entry
	fortranOrder bool        // rendered as the 'fortran_order' entry (True/False)
	shape        []int64     // dimension sizes, rendered as the 'shape' tuple
}
|
||||
|
||||
func (h *Header) toString() (string, error) {
|
||||
// NewHeader creates Header from input data
|
||||
//
|
||||
// NOTE. This is mainly for unit test purpose
|
||||
func NewNpyHeader(dtype gotch.DType, fo bool, shape []int64) *NpyHeader {
|
||||
return &NpyHeader{
|
||||
descr: dtype,
|
||||
fortranOrder: fo,
|
||||
shape: shape,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *NpyHeader) ToString() (string, error) {
|
||||
var fortranOrder string = "False"
|
||||
|
||||
if h.fortranOrder {
|
||||
|
@ -82,7 +96,7 @@ func (h *Header) toString() (string, error) {
|
|||
|
||||
var shapeStr []string
|
||||
for _, v := range h.shape {
|
||||
shapeStr = append(shapeStr, string(v))
|
||||
shapeStr = append(shapeStr, fmt.Sprintf("%v", v))
|
||||
}
|
||||
|
||||
shape := strings.Join(shapeStr, ",")
|
||||
|
@ -110,8 +124,8 @@ func (h *Header) toString() (string, error) {
|
|||
return "", err
|
||||
}
|
||||
|
||||
if len(shape) > 0 {
|
||||
shape = shape + ","
|
||||
if len(h.shape) == 1 {
|
||||
shape += ","
|
||||
}
|
||||
|
||||
headStr := fmt.Sprintf("{'descr': '<%v', 'fortran_order': %v, 'shape': (%v), }", descr, fortranOrder, shape)
|
||||
|
@ -119,22 +133,14 @@ func (h *Header) toString() (string, error) {
|
|||
return headStr, nil
|
||||
}
|
||||
|
||||
// Parser for the npy header, a typical example would be:
|
||||
// ParseNpyHeader parses the given npy header string.
|
||||
//
|
||||
// A typical example would be:
|
||||
// {'descr': '<f8', 'fortran_order': False, 'shape': (128,), }
|
||||
func (h *Header) parse(header string) (*Header, error) {
|
||||
|
||||
// trim matches
|
||||
var chars []rune
|
||||
for _, r := range header {
|
||||
if r == '{' || r == '}' || r == ',' || r == ' ' {
|
||||
continue
|
||||
}
|
||||
|
||||
chars = append(chars, r)
|
||||
}
|
||||
|
||||
trimHeader := string(chars)
|
||||
func ParseNpyHeader(header string) (*NpyHeader, error) {
|
||||
|
||||
// trim all prefix or suffix patterns
|
||||
trimHeader := trimMatches([]rune{'{', '}', ','}, header)
|
||||
var parts []string
|
||||
startIdx := 0
|
||||
var cntParenthesis int64 = 0
|
||||
|
@ -155,32 +161,15 @@ func (h *Header) parse(header string) (*Header, error) {
|
|||
}
|
||||
|
||||
parts = append(parts, header[startIdx:])
|
||||
|
||||
var partMap map[string]string = make(map[string]string)
|
||||
|
||||
for _, part := range parts {
|
||||
strings.TrimSpace(part)
|
||||
p := strings.TrimSpace(part)
|
||||
if len(p) > 0 {
|
||||
kv := strings.Split(p, ":")
|
||||
if len(kv) == 2 {
|
||||
var key, value string
|
||||
rKey := []rune(kv[0])
|
||||
for _, r := range rKey {
|
||||
if r == '\'' || r == ' ' {
|
||||
continue
|
||||
}
|
||||
key = key + string([]rune{r})
|
||||
}
|
||||
|
||||
rValue := []rune(kv[1])
|
||||
for _, r := range rValue {
|
||||
if r == '\'' || r == ' ' {
|
||||
continue
|
||||
}
|
||||
value = value + string([]rune{r})
|
||||
}
|
||||
|
||||
key := trimMatches([]rune{'\''}, kv[0])
|
||||
value := trimMatches([]rune{'\''}, kv[1])
|
||||
partMap[key] = value
|
||||
}
|
||||
}
|
||||
|
@ -217,13 +206,7 @@ func (h *Header) parse(header string) (*Header, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var descrStr string
|
||||
for _, r := range d {
|
||||
if r == '=' || r == '<' {
|
||||
continue
|
||||
}
|
||||
descrStr += string([]rune{r})
|
||||
}
|
||||
descrStr := trimMatches([]rune{'=', '<'}, d)
|
||||
|
||||
var descr gotch.DType
|
||||
switch descrStr {
|
||||
|
@ -254,15 +237,7 @@ func (h *Header) parse(header string) (*Header, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
var shapeStr string
|
||||
for _, r := range s {
|
||||
if r == '(' || r == ')' || r == ',' {
|
||||
continue
|
||||
}
|
||||
|
||||
shapeStr += string([]rune{r})
|
||||
}
|
||||
|
||||
shapeStr := trimMatches([]rune{'(', ')', ','}, s)
|
||||
var shape []int64
|
||||
if len(shapeStr) == 0 {
|
||||
shape = make([]int64, 0)
|
||||
|
@ -277,13 +252,29 @@ func (h *Header) parse(header string) (*Header, error) {
|
|||
}
|
||||
}
|
||||
|
||||
return &Header{
|
||||
return &NpyHeader{
|
||||
descr,
|
||||
fortranOrder,
|
||||
shape,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// trimMatches strips surrounding whitespace from s and then, for each rune in
// matches (in the order given), removes at most one leading and at most one
// trailing occurrence of that rune.
//
// The runes are visited in a single pass, so order matters: a rune is only
// removed if it sits at the very edge of the string when its turn comes.
// For example, with matches {'(', ')', ','} the input "(128,)" becomes "128",
// while with matches {'{', '}', ','} the input "{'a': 1, }" becomes
// "'a': 1, " because a space, not the comma, is at the edge once '}' is gone.
func trimMatches(matches []rune, s string) string {
	trimmed := strings.TrimSpace(s)

	for _, m := range matches {
		cut := string(m)
		// TrimPrefix/TrimSuffix return their input unchanged when there is
		// nothing to remove, so no HasPrefix/HasSuffix guard is needed.
		trimmed = strings.TrimPrefix(trimmed, cut)
		trimmed = strings.TrimSuffix(trimmed, cut)
	}

	return trimmed
}
|
||||
|
||||
// ReadNpy reads a .npy file and returns the stored tensor.
|
||||
func ReadNpy(filepath string) (*Tensor, error) {
|
||||
|
||||
|
@ -291,6 +282,7 @@ func ReadNpy(filepath string) (*Tensor, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
r := bufio.NewReader(f)
|
||||
|
||||
|
@ -299,9 +291,7 @@ func ReadNpy(filepath string) (*Tensor, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
hd := new(Header)
|
||||
|
||||
header, err := hd.parse(h)
|
||||
header, err := ParseNpyHeader(h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -319,3 +309,57 @@ func ReadNpy(filepath string) (*Tensor, error) {
|
|||
|
||||
return OfDataSize(data, header.shape, header.descr)
|
||||
}
|
||||
|
||||
// ReadNpz reads a compressed numpy file (.npz) and returns named tensors
|
||||
func ReadNpz(filePath string) ([]NamedTensor, error) {
|
||||
var namedTensors []NamedTensor
|
||||
r, err := zip.OpenReader(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer r.Close()
|
||||
|
||||
for _, f := range r.File {
|
||||
basename := f.Name
|
||||
// remove file extension to get tensor name
|
||||
name := strings.TrimSuffix(basename, filepath.Ext(basename))
|
||||
rc, err := f.Open()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
headerStr, err := readHeader(rc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
header, err := ParseNpyHeader(headerStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if header.fortranOrder {
|
||||
err := fmt.Errorf("fortran order not supported.\n")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var data []byte
|
||||
data, err = ioutil.ReadAll(rc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tensor, err := OfDataSize(data, header.shape, header.descr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
namedTensors = append(namedTensors, NamedTensor{name, tensor})
|
||||
|
||||
// explicitly close before next one
|
||||
rc.Close()
|
||||
data = make([]byte, 0)
|
||||
}
|
||||
|
||||
return namedTensors, nil
|
||||
}
|
||||
|
|
62
tensor/npy_test.go
Normal file
62
tensor/npy_test.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package tensor_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func TestNpyHeaderParse(t *testing.T) {
|
||||
|
||||
h1 := "{'descr': '<f8', 'fortran_order': False, 'shape': (128,), }"
|
||||
want1 := ts.NewNpyHeader(gotch.Double, false, []int64{128})
|
||||
|
||||
testParse(t, want1, h1)
|
||||
|
||||
h2 := "{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128), }"
|
||||
want2 := ts.NewNpyHeader(gotch.Float, true, []int64{256, 1, 128})
|
||||
|
||||
testParse(t, want2, h2)
|
||||
|
||||
h3, err := ts.ParseNpyHeader(h1)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
testToString(t, h1, h3)
|
||||
|
||||
h4, err := ts.ParseNpyHeader(h2)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
testToString(t, h2, h4)
|
||||
|
||||
h5 := ts.NewNpyHeader(gotch.Int64, false, []int64{})
|
||||
want5 := "{'descr': '<i8', 'fortran_order': False, 'shape': (), }"
|
||||
testToString(t, want5, h5)
|
||||
}
|
||||
|
||||
func testParse(t *testing.T, want *ts.NpyHeader, headerStr string) {
|
||||
got, err := ts.ParseNpyHeader(headerStr)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("want: %+v\n", want)
|
||||
t.Errorf("got: %+v\n", got)
|
||||
}
|
||||
}
|
||||
|
||||
func testToString(t *testing.T, want string, h *ts.NpyHeader) {
|
||||
got, err := h.ToString()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("want: %+v\n", want)
|
||||
t.Errorf("got: %+v\n", got)
|
||||
}
|
||||
}
|
|
@ -197,9 +197,67 @@ func OfSlice(data interface{}) (*Tensor, error) {
|
|||
}
|
||||
|
||||
// OfDataSize creates Tensor from input byte data and specified shape and dtype.
|
||||
func OfDataSize(data []byte, size []int64, dtype gotch.DType) (*Tensor, error) {
|
||||
// TODO: implement
|
||||
func OfDataSize(data []byte, shape []int64, dtype gotch.DType) (*Tensor, error) {
|
||||
|
||||
elementNum := ElementCount(shape)
|
||||
eltSizeInBytes, err := gotch.DTypeSize(dtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nbytes := int(eltSizeInBytes) * int(elementNum)
|
||||
|
||||
if nbytes != len(data) {
|
||||
err := fmt.Errorf("data and shape mismatched for dtype (%v): byte data (%v) - shape (%v).\n", dtype, len(data), shape)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dataPtr, buff := CMalloc(nbytes)
|
||||
defer C.free(unsafe.Pointer(dataPtr))
|
||||
|
||||
typ, err := gotch.ToGoType(dtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var v reflect.Value
|
||||
switch typ.Name() {
|
||||
case "float", "float32":
|
||||
v = reflect.ValueOf(float32(0.1))
|
||||
case "float64":
|
||||
v = reflect.ValueOf(float64(0.1))
|
||||
case "int", "int32":
|
||||
v = reflect.ValueOf(int(1))
|
||||
case "int64":
|
||||
v = reflect.ValueOf(int64(1))
|
||||
case "int8":
|
||||
v = reflect.ValueOf(int8(1))
|
||||
case "uint8":
|
||||
v = reflect.ValueOf(uint8(1))
|
||||
case "bool":
|
||||
v = reflect.ValueOf(false)
|
||||
default:
|
||||
err := fmt.Errorf("unsupported dtype: %v\n", dtype)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = EncodeTensor(buff, v, shape); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cint, err := gotch.DType2CInt(dtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
|
||||
if err = TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buff.Reset()
|
||||
|
||||
return &Tensor{ctensor}, nil
|
||||
}
|
||||
|
||||
// MustOfDataSize create Tensor from input byte data and specified shape and dtype
|
||||
|
|
Loading…
Reference in New Issue
Block a user