gotch/ts/npy.go
Goncalves Henriques, Andre (UG - Computer Science) 9257404edd Move the name of the module
2024-04-21 15:15:00 +01:00

375 lines
7.3 KiB
Go

package ts
import (
"archive/zip"
"bufio"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"strconv"
"strings"
"git.andr3h3nriqu3s.com/andr3/gotch"
)
// Constants describing the NumPy .npy file format.
const (
// NpyMagicString is the byte sequence readHeader expects at the very start
// of a .npy stream; anything else is rejected as a magic string mismatch.
NpyMagicString string = "\x93NUMPY"
// NpySuffix is the ".npy" file-name suffix. It is not referenced by the code
// visible in this file.
NpySuffix string = ".npy"
)
// readHeader consumes the fixed preamble of a .npy stream from r and returns
// the raw header dictionary text that follows it.
//
// Layout read, in order: the magic string, two version bytes (major, minor),
// a little-endian header length (2 bytes for major version 1, 4 bytes for
// major versions 2 and 3), then that many bytes of header text. On success
// the reader is left positioned at the first byte after the header.
//
// An error is returned when the magic string does not match, the major
// version is not supported, or the stream ends before the header is complete.
func readHeader(r io.Reader) (string, error) {
	magic := make([]byte, len(NpyMagicString))
	if _, err := io.ReadFull(r, magic); err != nil {
		return "", err
	}
	if string(magic) != NpyMagicString {
		return "", fmt.Errorf("magic string mismatched.\n")
	}

	// version[0] is the major version, version[1] the minor version.
	version := make([]byte, 2)
	if _, err := io.ReadFull(r, version); err != nil {
		return "", err
	}

	var headerLenLength int
	switch version[0] {
	case 1:
		headerLenLength = 2
	case 2, 3:
		// Version 3 only changes the header text encoding; the length field
		// is laid out exactly as in version 2.
		headerLenLength = 4
	default:
		// Return immediately: falling through would read a zero-byte length
		// field and report an empty header as if it were valid.
		return "", fmt.Errorf("Unsupported version: %v\n", version[0])
	}

	headerLen := make([]byte, headerLenLength)
	if _, err := io.ReadFull(r, headerLen); err != nil {
		return "", err
	}

	// Decode the little-endian length, most significant byte first.
	hLen := 0
	for i := len(headerLen) - 1; i >= 0; i-- {
		hLen = hLen*256 + int(headerLen[i])
	}

	header := make([]byte, hLen)
	if _, err := io.ReadFull(r, header); err != nil {
		return "", err
	}
	return string(header), nil
}
// NpyHeader holds the three entries decoded from a .npy header dictionary.
type NpyHeader struct {
// descr is the element type, taken from the header's 'descr' entry.
descr gotch.DType
// fortranOrder mirrors the header's 'fortran_order' entry. ReadNpy and
// ReadNpz reject headers where it is true.
fortranOrder bool
// shape lists the dimension sizes from the header's 'shape' entry; it is
// empty when the tuple in the header is empty.
shape []int64
}
// NewNpyHeader builds an NpyHeader from the given element type, Fortran-order
// flag and shape. The shape slice is stored as passed in, not copied.
//
// NOTE. This is mainly for unit test purpose
func NewNpyHeader(dtype gotch.DType, fo bool, shape []int64) *NpyHeader {
	hdr := new(NpyHeader)
	hdr.descr = dtype
	hdr.fortranOrder = fo
	hdr.shape = shape
	return hdr
}
// ToString renders the header as the dictionary text stored in a .npy file,
// for example:
// {'descr': '<f8', 'fortran_order': False, 'shape': (128,), }
//
// The dtype code is always written with a '<' prefix. A one-dimensional shape
// gets a trailing comma inside the parentheses. An error is returned when the
// header's dtype has no code in the switch below.
func (h *NpyHeader) ToString() (string, error) {
	// Resolve the dtype code first so unsupported types fail early.
	var code string
	switch h.descr {
	case gotch.Half:
		code = "f2"
	case gotch.Float:
		code = "f4"
	case gotch.Double:
		code = "f8"
	case gotch.Int:
		code = "i4"
	case gotch.Int64:
		code = "i8"
	case gotch.Int16:
		code = "i2"
	case gotch.Int8:
		code = "i1"
	case gotch.Uint8:
		code = "u1"
	default:
		return "", fmt.Errorf("Unsupported kind: %v\n", h.descr)
	}

	dims := make([]string, 0, len(h.shape))
	for _, d := range h.shape {
		dims = append(dims, strconv.FormatInt(d, 10))
	}
	shapeText := strings.Join(dims, ",")
	if len(dims) == 1 {
		shapeText += ","
	}

	order := "False"
	if h.fortranOrder {
		order = "True"
	}

	return fmt.Sprintf("{'descr': '<%v', 'fortran_order': %v, 'shape': (%v), }", code, order, shapeText), nil
}
// ParseNpyHeader parses the given npy header string.
//
// A typical example would be:
// {'descr': '<f8', 'fortran_order': False, 'shape': (128,), }
//
// The text is split into top-level "key: value" entries (commas nested inside
// the shape tuple do not split), and the 'descr', 'fortran_order' and 'shape'
// entries are decoded. A missing 'fortran_order' entry is treated as False.
//
// An error is returned for an unknown 'fortran_order' value, a missing or
// empty 'descr', a big-endian ('>') descr, an unrecognized dtype code, a
// missing 'shape', or a shape dimension that is not a base-10 integer.
func ParseNpyHeader(header string) (*NpyHeader, error) {
	// Strip surrounding whitespace and the enclosing braces.
	trimHeader := trimMatches([]rune{'{', '}', ','}, header)

	// Split on commas that are not nested inside parentheses.
	var parts []string
	startIdx := 0
	depth := 0
	for i, r := range trimHeader {
		switch r {
		case '(':
			depth++
		case ')':
			depth--
		case ',':
			if depth == 0 {
				parts = append(parts, trimHeader[startIdx:i])
				startIdx = i + 1
			}
		}
	}
	// startIdx indexes trimHeader, so the tail must be sliced from it too;
	// slicing the untrimmed header shifts the text and corrupts the last entry.
	parts = append(parts, trimHeader[startIdx:])

	partMap := make(map[string]string, len(parts))
	for _, part := range parts {
		p := strings.TrimSpace(part)
		if len(p) == 0 {
			continue
		}
		kv := strings.Split(p, ":")
		if len(kv) != 2 {
			continue
		}
		key := trimMatches([]rune{'\''}, kv[0])
		value := trimMatches([]rune{'\''}, kv[1])
		partMap[key] = value
	}

	// fortran_order defaults to false when the entry is absent.
	fortranOrder := false
	if fo, ok := partMap["fortran_order"]; ok {
		switch fo {
		case "False":
			fortranOrder = false
		case "True":
			fortranOrder = true
		default:
			return nil, fmt.Errorf("unknown fortran_order: %v\n", fo)
		}
	}

	d, ok := partMap["descr"]
	if !ok {
		return nil, fmt.Errorf("no descr in header.\n")
	}
	if len(d) == 0 {
		return nil, fmt.Errorf("empty descr.\n")
	}
	// '>' marks big-endian data, which is not handled here.
	if strings.HasPrefix(d, ">") {
		return nil, fmt.Errorf("big-endian descr: %v\n", d)
	}

	// Drop the byte-order marker: '=' native, '<' little-endian, and '|'
	// (not applicable), which NumPy writes for single-byte types.
	descrStr := trimMatches([]rune{'=', '<', '|'}, d)
	var descr gotch.DType
	switch descrStr {
	case "f2":
		// Mirrors ToString, which writes gotch.Half as "f2".
		descr = gotch.Half
	case "f4":
		descr = gotch.Float
	case "f8":
		descr = gotch.Double
	case "i4":
		descr = gotch.Int
	case "i8":
		descr = gotch.Int64
	case "i2":
		descr = gotch.Int16
	case "i1":
		descr = gotch.Int8
	case "u1":
		descr = gotch.Uint8
	default:
		return nil, fmt.Errorf("unrecognized descr: %v\n", descrStr)
	}

	s, ok := partMap["shape"]
	if !ok {
		return nil, fmt.Errorf("no shape in header.\n")
	}
	shapeStr := trimMatches([]rune{'(', ')', ','}, s)

	// An empty tuple "()" yields an empty, non-nil shape.
	shape := make([]int64, 0)
	if len(shapeStr) > 0 {
		for _, v := range strings.Split(shapeStr, ",") {
			// Parse as 64-bit so large dimensions survive on 32-bit platforms.
			dim, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64)
			if err != nil {
				return nil, err
			}
			shape = append(shape, dim)
		}
	}

	return &NpyHeader{
		descr:        descr,
		fortranOrder: fortranOrder,
		shape:        shape,
	}, nil
}
// trimMatches strips surrounding whitespace from s and then, for each rune in
// matches taken in order, removes at most one leading and one trailing
// occurrence of that rune. Whitespace is trimmed only once, up front, so a
// rune is removed only if it sits at the very edge when its turn comes.
func trimMatches(matches []rune, s string) string {
	out := strings.TrimSpace(s)
	for _, m := range matches {
		edge := string(m)
		out = strings.TrimSuffix(strings.TrimPrefix(out, edge), edge)
	}
	return out
}
// ReadNpy reads a .npy file and returns the stored tensor.
//
// The header is read with readHeader and decoded with ParseNpyHeader; headers
// flagged as Fortran order are rejected. All bytes after the header are then
// handed to OfDataSize together with the header's shape and dtype. Any error
// from opening, reading or parsing the file is returned unchanged.
func ReadNpy(filepath string) (*Tensor, error) {
	file, err := os.Open(filepath)
	if err != nil {
		return nil, err
	}
	defer file.Close()

	reader := bufio.NewReader(file)

	rawHeader, err := readHeader(reader)
	if err != nil {
		return nil, err
	}

	hdr, err := ParseNpyHeader(rawHeader)
	if err != nil {
		return nil, err
	}
	if hdr.fortranOrder {
		return nil, fmt.Errorf("fortran order not supported.\n")
	}

	// Whatever follows the header is the raw element data.
	payload, err := ioutil.ReadAll(reader)
	if err != nil {
		return nil, err
	}

	// NOTE(TT.). case tensor 1 element with shape = []
	if len(hdr.shape) == 0 && len(payload) > 0 {
		hdr.shape = []int64{1}
	}

	return OfDataSize(payload, hdr.shape, hdr.descr)
}
// ReadNpz reads a compressed numpy file (.npz) and returns named tensors
//
// Every member of the zip archive is decoded as a .npy stream, in archive
// order. A tensor's name is the member's name with its file extension
// removed. Members flagged as Fortran order are rejected. The first failure
// aborts the read and returns that error with a nil slice.
func ReadNpz(filePath string) ([]NamedTensor, error) {
	r, err := zip.OpenReader(filePath)
	if err != nil {
		return nil, err
	}
	defer r.Close()

	// readEntry decodes a single archive member. It is a closure so the
	// deferred Close runs once per member, on error paths as well as success.
	readEntry := func(f *zip.File) (*Tensor, error) {
		rc, err := f.Open()
		if err != nil {
			return nil, err
		}
		defer rc.Close()

		headerStr, err := readHeader(rc)
		if err != nil {
			return nil, err
		}
		header, err := ParseNpyHeader(headerStr)
		if err != nil {
			return nil, err
		}
		if header.fortranOrder {
			return nil, fmt.Errorf("fortran order not supported.\n")
		}

		data, err := ioutil.ReadAll(rc)
		if err != nil {
			return nil, err
		}

		// NOTE(TT.). case tensor 1 element with shape = []
		if len(data) > 0 && len(header.shape) == 0 {
			header.shape = []int64{1}
		}
		return OfDataSize(data, header.shape, header.descr)
	}

	var namedTensors []NamedTensor
	for _, f := range r.File {
		tensor, err := readEntry(f)
		if err != nil {
			return nil, err
		}
		// remove file extension to get tensor name
		name := strings.TrimSuffix(f.Name, filepath.Ext(f.Name))
		namedTensors = append(namedTensors, NamedTensor{name, tensor})
	}
	return namedTensors, nil
}