Updated tensor.OfDataSize to write binary data directly; updated ReadNpy and ReadNpz

sugarme 2020-11-18 13:07:08 +11:00
parent 5801be4072
commit 8770fc378b
3 changed files with 8 additions and 51 deletions

View File

@ -8,9 +8,8 @@ import (
)
func main() {
// filepath := "../../data/convert-model/bert/model.npz"
// filepath := "/home/sugarme/projects/pytorch-pretrained/bert/model.npz"
filepath := "/home/sugarme/rustbert/bert/model.npz"
// NOTE. Python script to save model to .npz can be found at https://github.com/sugarme/pytorch-pretrained/bert/bert-base-uncased-to-npz.py
filepath := "../../data/convert-model/bert/model.npz"
namedTensors, err := ts.ReadNpz(filepath)
if err != nil {
@ -18,21 +17,9 @@ func main() {
}
fmt.Printf("Num of named tensor: %v\n", len(namedTensors))
/*
* for _, nt := range namedTensors {
* // fmt.Printf("%q\n", nt.Name)
* if nt.Name == "bert.encoder.layer.1.attention.output.LayerNorm.weight" {
* fmt.Printf("%0.3f", nt.Tensor)
* }
* }
* */
// fmt.Printf("%v", namedTensors[70].Tensor)
outputFile := "/home/sugarme/projects/transformer/data/bert/model.gt"
err = ts.SaveMultiNew(namedTensors, outputFile)
if err != nil {
log.Fatal(err)
}
}

View File

@ -15,7 +15,6 @@ import (
)
const (
// NpyMagicString string = `\x93NUMPY`
NpyMagicString string = "\x93NUMPY"
NpySuffix string = ".npy"
)
@ -103,7 +102,7 @@ func (h *NpyHeader) ToString() (string, error) {
var descr string
switch h.descr.Kind().String() {
// case "float32":
// case "float16": // NOTE. No float16 in Go primary types. TODO. implement
// descr = "f2"
case "float32":
descr = "f4"
@ -354,10 +353,6 @@ func ReadNpz(filePath string) ([]NamedTensor, error) {
return nil, err
}
if name == "bert.encoder.layer.0.attention.output.dense.weight" {
fmt.Printf("%4.2f", tensor)
}
namedTensors = append(namedTensors, NamedTensor{name, tensor})
// explicitly close before next one

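For context, here is a minimal sketch of how the loaders touched by this commit might be called. ReadNpz's signature is taken from the hunk above; the import alias, the placeholder file paths, and the ReadNpy and MustSize calls are assumptions about the gotch API rather than something shown in this diff.

package main

import (
	"fmt"
	"log"

	ts "github.com/sugarme/gotch/tensor"
)

func main() {
	// ReadNpz loads every array in a .npz archive as a NamedTensor
	// (signature shown in the hunk above). The path is a placeholder.
	namedTensors, err := ts.ReadNpz("model.npz")
	if err != nil {
		log.Fatal(err)
	}
	for _, nt := range namedTensors {
		// MustSize is assumed to return the tensor shape as []int64.
		fmt.Printf("%s: %v\n", nt.Name, nt.Tensor.MustSize())
	}

	// ReadNpy is only mentioned in the commit message; the signature
	// ReadNpy(path string) (*Tensor, error) is an assumption.
	single, err := ts.ReadNpy("weights.npy")
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("single tensor shape: %v\n", single.MustSize())
}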
View File

@ -196,7 +196,7 @@ func OfSlice(data interface{}) (*Tensor, error) {
return &Tensor{ctensor}, nil
}
// OfDataSize creates Tensor from input byte data and specidied shape and dtype.
// OfDataSize creates Tensor from input byte data, shape and dtype.
func OfDataSize(data []byte, shape []int64, dtype gotch.DType) (*Tensor, error) {
elementNum := ElementCount(shape)
@ -215,33 +215,7 @@ func OfDataSize(data []byte, shape []int64, dtype gotch.DType) (*Tensor, error)
dataPtr, buff := CMalloc(nbytes)
defer C.free(unsafe.Pointer(dataPtr))
typ, err := gotch.ToGoType(dtype)
if err != nil {
return nil, err
}
var v reflect.Value
switch typ.Name() {
case "float", "float32":
v = reflect.ValueOf(float32(0.1))
case "float64":
v = reflect.ValueOf(float64(0.1))
case "int", "int32":
v = reflect.ValueOf(int(1))
case "int64":
v = reflect.ValueOf(int64(1))
case "int8":
v = reflect.ValueOf(int8(1))
case "uint8":
v = reflect.ValueOf(uint8(1))
case "bool":
v = reflect.ValueOf(false)
default:
err := fmt.Errorf("unsupported dtype: %v\n", dtype)
return nil, err
}
if err = EncodeTensor(buff, v, shape); err != nil {
if err := binary.Write(buff, nativeEndian, data); err != nil {
return nil, err
}
@ -255,8 +229,6 @@ func OfDataSize(data []byte, shape []int64, dtype gotch.DType) (*Tensor, error)
return nil, err
}
buff.Reset()
return &Tensor{ctensor}, nil
}
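With this change OfDataSize no longer rebuilds the tensor element by element through reflection; it writes the caller's byte slice straight into the C buffer with binary.Write, so the bytes must already be laid out for the requested dtype. Below is a minimal sketch of a call, assuming a little-endian host (the function writes with nativeEndian) and that gotch.Float is the float32 DType; the values and shape are placeholders.

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"log"

	"github.com/sugarme/gotch"
	ts "github.com/sugarme/gotch/tensor"
)

func main() {
	// Serialize a 2x3 float32 matrix to raw little-endian bytes.
	// OfDataSize copies these bytes as-is, so 6 elements * 4 bytes
	// must match len(data) for the float32 dtype.
	vals := []float32{1, 2, 3, 4, 5, 6}
	var buf bytes.Buffer
	if err := binary.Write(&buf, binary.LittleEndian, vals); err != nil {
		log.Fatal(err)
	}

	x, err := ts.OfDataSize(buf.Bytes(), []int64{2, 3}, gotch.Float)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%v\n", x)
}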
@ -836,6 +808,7 @@ type NamedTensor struct {
// SaveMulti saves some named tensors to a file
//
// The file format is the same as the one used by the PyTorch C++ API.
// NOTE. This method is deprecated and will be replaced with `SaveMultiNew`
func SaveMulti(namedTensors []NamedTensor, path string) error {
var ctensors []lib.Ctensor
var names []string
@ -854,6 +827,8 @@ func SaveMulti(namedTensors []NamedTensor, path string) error {
}
// MustSaveMulti saves some named tensors to a file. It will panic if error
//
// NOTE. This method is deprecated and will be replaced with `MustSaveMultiNew`
func MustSaveMulti(namedTensors []NamedTensor, path string) {
err := SaveMulti(namedTensors, path)
if err != nil {