WIP: npy.go
This commit is contained in:
parent
9a4646f331
commit
94197e0710
259
tensor/npy.go
259
tensor/npy.go
|
@ -4,8 +4,12 @@ import (
|
|||
"bufio"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -13,17 +17,9 @@ const (
|
|||
NpySuffix string = ".npy"
|
||||
)
|
||||
|
||||
func readHeader(filepath string) (string, error) {
|
||||
|
||||
f, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
func readHeader(r io.Reader) (string, error) {
|
||||
magicStr := make([]byte, len(NpyMagicString))
|
||||
r := bufio.NewReader(f)
|
||||
|
||||
_, err = io.ReadFull(r, magicStr)
|
||||
_, err := io.ReadFull(r, magicStr)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
@ -72,7 +68,7 @@ func readHeader(filepath string) (string, error) {
|
|||
}
|
||||
|
||||
type Header struct {
|
||||
descr reflect.Kind
|
||||
descr gotch.DType
|
||||
fortranOrder bool
|
||||
shape []int64
|
||||
}
|
||||
|
@ -83,4 +79,243 @@ func (h *Header) toString() (string, error) {
|
|||
if h.fortranOrder {
|
||||
fortranOrder = "True"
|
||||
}
|
||||
|
||||
var shapeStr []string
|
||||
for _, v := range h.shape {
|
||||
shapeStr = append(shapeStr, string(v))
|
||||
}
|
||||
|
||||
shape := strings.Join(shapeStr, ",")
|
||||
|
||||
var descr string
|
||||
switch h.descr.Kind().String() {
|
||||
// case "float32":
|
||||
// descr = "f2"
|
||||
case "float32":
|
||||
descr = "f4"
|
||||
case "float64":
|
||||
descr = "f8"
|
||||
case "int":
|
||||
descr = "i4"
|
||||
case "int64":
|
||||
descr = "i8"
|
||||
case "int16":
|
||||
descr = "i2"
|
||||
case "int8":
|
||||
descr = "i1"
|
||||
case "uint8":
|
||||
descr = "u1"
|
||||
default:
|
||||
err := fmt.Errorf("Unsupported kind: %v\n", h.descr)
|
||||
return "", err
|
||||
}
|
||||
|
||||
if len(shape) > 0 {
|
||||
shape = shape + ","
|
||||
}
|
||||
|
||||
headStr := fmt.Sprintf("{'descr': '<%v', 'fortran_order': %v, 'shape': (%v), }", descr, fortranOrder, shape)
|
||||
|
||||
return headStr, nil
|
||||
}
|
||||
|
||||
// Parser for the npy header, a typical example would be:
|
||||
// {'descr': '<f8', 'fortran_order': False, 'shape': (128,), }
|
||||
func (h *Header) parse(header string) (*Header, error) {
|
||||
|
||||
// trim matches
|
||||
var chars []rune
|
||||
for _, r := range header {
|
||||
if r == '{' || r == '}' || r == ',' || r == ' ' {
|
||||
continue
|
||||
}
|
||||
|
||||
chars = append(chars, r)
|
||||
}
|
||||
|
||||
trimHeader := string(chars)
|
||||
|
||||
var parts []string
|
||||
startIdx := 0
|
||||
var cntParenthesis int64 = 0
|
||||
for i, r := range trimHeader {
|
||||
switch r {
|
||||
case '(':
|
||||
cntParenthesis += 1
|
||||
case ')':
|
||||
cntParenthesis -= 1
|
||||
case ',':
|
||||
if cntParenthesis == 0 {
|
||||
parts = append(parts, trimHeader[startIdx:i])
|
||||
startIdx = i + 1
|
||||
}
|
||||
default:
|
||||
// do nothing
|
||||
}
|
||||
}
|
||||
|
||||
parts = append(parts, header[startIdx:])
|
||||
|
||||
var partMap map[string]string = make(map[string]string)
|
||||
|
||||
for _, part := range parts {
|
||||
strings.TrimSpace(part)
|
||||
p := strings.TrimSpace(part)
|
||||
if len(p) > 0 {
|
||||
kv := strings.Split(p, ":")
|
||||
if len(kv) == 2 {
|
||||
var key, value string
|
||||
rKey := []rune(kv[0])
|
||||
for _, r := range rKey {
|
||||
if r == '\'' || r == ' ' {
|
||||
continue
|
||||
}
|
||||
key = key + string([]rune{r})
|
||||
}
|
||||
|
||||
rValue := []rune(kv[1])
|
||||
for _, r := range rValue {
|
||||
if r == '\'' || r == ' ' {
|
||||
continue
|
||||
}
|
||||
value = value + string([]rune{r})
|
||||
}
|
||||
|
||||
partMap[key] = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var fortranOrder bool
|
||||
fo, ok := partMap["fortran_order"]
|
||||
if !ok {
|
||||
fortranOrder = false
|
||||
}
|
||||
switch fo {
|
||||
case "False":
|
||||
fortranOrder = false
|
||||
case "True":
|
||||
fortranOrder = true
|
||||
default:
|
||||
err := fmt.Errorf("unknown fortran_order: %v\n", fo)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
d, ok := partMap["descr"]
|
||||
if !ok {
|
||||
err := fmt.Errorf("no descr in header.\n")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(d) == 0 {
|
||||
err := fmt.Errorf("empty descr.\n")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if strings.HasPrefix(d, ">") {
|
||||
err := fmt.Errorf("little-endian descr: %v\n", d)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var descrStr string
|
||||
for _, r := range d {
|
||||
if r == '=' || r == '<' {
|
||||
continue
|
||||
}
|
||||
descrStr += string([]rune{r})
|
||||
}
|
||||
|
||||
var descr gotch.DType
|
||||
switch descrStr {
|
||||
case "f2":
|
||||
descr = gotch.Float // use Go float32 as there's no float16
|
||||
case "f4":
|
||||
descr = gotch.Float
|
||||
case "f8":
|
||||
descr = gotch.Double
|
||||
case "i4":
|
||||
descr = gotch.Int
|
||||
case "i8":
|
||||
descr = gotch.Int64
|
||||
case "i2":
|
||||
descr = gotch.Int16
|
||||
case "i1":
|
||||
descr = gotch.Int8
|
||||
case "u1":
|
||||
descr = gotch.Uint8
|
||||
default:
|
||||
err := fmt.Errorf("unrecognized descr: %v\n", descr)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s, ok := partMap["shape"]
|
||||
if !ok {
|
||||
err := fmt.Errorf("no shape in header.\n")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var shapeStr string
|
||||
for _, r := range s {
|
||||
if r == '(' || r == ')' || r == ',' {
|
||||
continue
|
||||
}
|
||||
|
||||
shapeStr += string([]rune{r})
|
||||
}
|
||||
|
||||
var shape []int64
|
||||
if len(shapeStr) == 0 {
|
||||
shape = make([]int64, 0)
|
||||
} else {
|
||||
size := strings.Split(shapeStr, ",")
|
||||
for _, v := range size {
|
||||
dim, err := strconv.Atoi(strings.TrimSpace(v))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shape = append(shape, int64(dim))
|
||||
}
|
||||
}
|
||||
|
||||
return &Header{
|
||||
descr,
|
||||
fortranOrder,
|
||||
shape,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ReadNpy reads a .npy file and returns the stored tensor.
|
||||
func ReadNpy(filepath string) (*Tensor, error) {
|
||||
|
||||
f, err := os.Open(filepath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
r := bufio.NewReader(f)
|
||||
|
||||
h, err := readHeader(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hd := new(Header)
|
||||
|
||||
header, err := hd.parse(h)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if header.fortranOrder {
|
||||
err := fmt.Errorf("fortran order not supported.\n")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read all the rest
|
||||
var data []byte
|
||||
data, err = ioutil.ReadAll(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return OfDataSize(data, header.shape, header.descr)
|
||||
}
|
||||
|
|
|
@ -196,6 +196,23 @@ func OfSlice(data interface{}) (*Tensor, error) {
|
|||
return &Tensor{ctensor}, nil
|
||||
}
|
||||
|
||||
// OfDataSize creates Tensor from input byte data and specidied shape and dtype.
|
||||
func OfDataSize(data []byte, size []int64, dtype gotch.DType) (*Tensor, error) {
|
||||
// TODO: implement
|
||||
|
||||
}
|
||||
|
||||
// MustOfDataSize create Tensor from input byte data and specified shape and dtype
|
||||
// or panic if error
|
||||
func MustOfDataSize(data []byte, size []int64, dtype gotch.DType) *Tensor {
|
||||
ts, err := OfDataSize(data, size, dtype)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return ts
|
||||
}
|
||||
|
||||
// MustOfSlice create a tensor from slice of data. It will be panic if error.
|
||||
func MustOfSlice(data interface{}) *Tensor {
|
||||
ts, err := OfSlice(data)
|
||||
|
|
Loading…
Reference in New Issue
Block a user