From 8770fc378b9e515907a2ecf107422393228adec4 Mon Sep 17 00:00:00 2001 From: sugarme Date: Wed, 18 Nov 2020 13:07:08 +1100 Subject: [PATCH] Updated tensor.OfDataSize to write directly binary data; Updated ReadNpy and ReadNpz --- example/convert-model/main.go | 17 ++--------------- tensor/npy.go | 7 +------ tensor/tensor.go | 35 +++++------------------------------ 3 files changed, 8 insertions(+), 51 deletions(-) diff --git a/example/convert-model/main.go b/example/convert-model/main.go index bdff712..a800a16 100644 --- a/example/convert-model/main.go +++ b/example/convert-model/main.go @@ -8,9 +8,8 @@ import ( ) func main() { - // filepath := "../../data/convert-model/bert/model.npz" - // filepath := "/home/sugarme/projects/pytorch-pretrained/bert/model.npz" - filepath := "/home/sugarme/rustbert/bert/model.npz" + // NOTE. Python script to save model to .npz can be found at https://github.com/sugarme/pytorch-pretrained/bert/bert-base-uncased-to-npz.py + filepath := "../../data/convert-model/bert/model.npz" namedTensors, err := ts.ReadNpz(filepath) if err != nil { @@ -18,21 +17,9 @@ func main() { } fmt.Printf("Num of named tensor: %v\n", len(namedTensors)) - /* - * for _, nt := range namedTensors { - * // fmt.Printf("%q\n", nt.Name) - * if nt.Name == "bert.encoder.layer.1.attention.output.LayerNorm.weight" { - * fmt.Printf("%0.3f", nt.Tensor) - * } - * } - * */ - - // fmt.Printf("%v", namedTensors[70].Tensor) - outputFile := "/home/sugarme/projects/transformer/data/bert/model.gt" err = ts.SaveMultiNew(namedTensors, outputFile) if err != nil { log.Fatal(err) } - } diff --git a/tensor/npy.go b/tensor/npy.go index 5d433ea..f7a11c9 100644 --- a/tensor/npy.go +++ b/tensor/npy.go @@ -15,7 +15,6 @@ import ( ) const ( - // NpyMagicString string = `\x93NUMPY` NpyMagicString string = "\x93NUMPY" NpySuffix string = ".npy" ) @@ -103,7 +102,7 @@ func (h *NpyHeader) ToString() (string, error) { var descr string switch h.descr.Kind().String() { - // case "float32": + // case "float16": // NOTE. No float16 in Go primary types. TODO. implement // descr = "f2" case "float32": descr = "f4" @@ -354,10 +353,6 @@ func ReadNpz(filePath string) ([]NamedTensor, error) { return nil, err } - if name == "bert.encoder.layer.0.attention.output.dense.weight" { - fmt.Printf("%4.2f", tensor) - } - namedTensors = append(namedTensors, NamedTensor{name, tensor}) // explicitly close before next one diff --git a/tensor/tensor.go b/tensor/tensor.go index ba2536f..704c4ef 100644 --- a/tensor/tensor.go +++ b/tensor/tensor.go @@ -196,7 +196,7 @@ func OfSlice(data interface{}) (*Tensor, error) { return &Tensor{ctensor}, nil } -// OfDataSize creates Tensor from input byte data and specidied shape and dtype. +// OfDataSize creates Tensor from input byte data, shape and dtype. func OfDataSize(data []byte, shape []int64, dtype gotch.DType) (*Tensor, error) { elementNum := ElementCount(shape) @@ -215,33 +215,7 @@ func OfDataSize(data []byte, shape []int64, dtype gotch.DType) (*Tensor, error) dataPtr, buff := CMalloc(nbytes) defer C.free(unsafe.Pointer(dataPtr)) - typ, err := gotch.ToGoType(dtype) - if err != nil { - return nil, err - } - - var v reflect.Value - switch typ.Name() { - case "float", "float32": - v = reflect.ValueOf(float32(0.1)) - case "float64": - v = reflect.ValueOf(float64(0.1)) - case "int", "int32": - v = reflect.ValueOf(int(1)) - case "int64": - v = reflect.ValueOf(int64(1)) - case "int8": - v = reflect.ValueOf(int8(1)) - case "uint8": - v = reflect.ValueOf(uint8(1)) - case "bool": - v = reflect.ValueOf(false) - default: - err := fmt.Errorf("unsupported dtype: %v\n", dtype) - return nil, err - } - - if err = EncodeTensor(buff, v, shape); err != nil { + if err := binary.Write(buff, nativeEndian, data); err != nil { return nil, err } @@ -255,8 +229,6 @@ func OfDataSize(data []byte, shape []int64, dtype gotch.DType) (*Tensor, error) return nil, err } - buff.Reset() - return &Tensor{ctensor}, nil } @@ -836,6 +808,7 @@ type NamedTensor struct { // SaveMulti saves some named tensors to a file // // The file format is the same as the one used by the PyTorch C++ API. +// NOTE. This method is depreciated and will be replaced with `SaveMultiNew` func SaveMulti(namedTensors []NamedTensor, path string) error { var ctensors []lib.Ctensor var names []string @@ -854,6 +827,8 @@ func SaveMulti(namedTensors []NamedTensor, path string) error { } // MustSaveMulti saves some named tensors to a file. It will panic if error +// +// NOTE. This method is depreciated and will be replaced with `MustSaveMultiNew` func MustSaveMulti(namedTensors []NamedTensor, path string) { err := SaveMulti(namedTensors, path) if err != nil {