Updated tensor.OfDataSize to write directly binary data; Updated ReadNpy and ReadNpz
This commit is contained in:
parent
5801be4072
commit
8770fc378b
|
@ -8,9 +8,8 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// filepath := "../../data/convert-model/bert/model.npz"
|
// NOTE. Python script to save model to .npz can be found at https://github.com/sugarme/pytorch-pretrained/bert/bert-base-uncased-to-npz.py
|
||||||
// filepath := "/home/sugarme/projects/pytorch-pretrained/bert/model.npz"
|
filepath := "../../data/convert-model/bert/model.npz"
|
||||||
filepath := "/home/sugarme/rustbert/bert/model.npz"
|
|
||||||
|
|
||||||
namedTensors, err := ts.ReadNpz(filepath)
|
namedTensors, err := ts.ReadNpz(filepath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -18,21 +17,9 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("Num of named tensor: %v\n", len(namedTensors))
|
fmt.Printf("Num of named tensor: %v\n", len(namedTensors))
|
||||||
/*
|
|
||||||
* for _, nt := range namedTensors {
|
|
||||||
* // fmt.Printf("%q\n", nt.Name)
|
|
||||||
* if nt.Name == "bert.encoder.layer.1.attention.output.LayerNorm.weight" {
|
|
||||||
* fmt.Printf("%0.3f", nt.Tensor)
|
|
||||||
* }
|
|
||||||
* }
|
|
||||||
* */
|
|
||||||
|
|
||||||
// fmt.Printf("%v", namedTensors[70].Tensor)
|
|
||||||
|
|
||||||
outputFile := "/home/sugarme/projects/transformer/data/bert/model.gt"
|
outputFile := "/home/sugarme/projects/transformer/data/bert/model.gt"
|
||||||
err = ts.SaveMultiNew(namedTensors, outputFile)
|
err = ts.SaveMultiNew(namedTensors, outputFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// NpyMagicString string = `\x93NUMPY`
|
|
||||||
NpyMagicString string = "\x93NUMPY"
|
NpyMagicString string = "\x93NUMPY"
|
||||||
NpySuffix string = ".npy"
|
NpySuffix string = ".npy"
|
||||||
)
|
)
|
||||||
|
@ -103,7 +102,7 @@ func (h *NpyHeader) ToString() (string, error) {
|
||||||
|
|
||||||
var descr string
|
var descr string
|
||||||
switch h.descr.Kind().String() {
|
switch h.descr.Kind().String() {
|
||||||
// case "float32":
|
// case "float16": // NOTE. No float16 in Go primary types. TODO. implement
|
||||||
// descr = "f2"
|
// descr = "f2"
|
||||||
case "float32":
|
case "float32":
|
||||||
descr = "f4"
|
descr = "f4"
|
||||||
|
@ -354,10 +353,6 @@ func ReadNpz(filePath string) ([]NamedTensor, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if name == "bert.encoder.layer.0.attention.output.dense.weight" {
|
|
||||||
fmt.Printf("%4.2f", tensor)
|
|
||||||
}
|
|
||||||
|
|
||||||
namedTensors = append(namedTensors, NamedTensor{name, tensor})
|
namedTensors = append(namedTensors, NamedTensor{name, tensor})
|
||||||
|
|
||||||
// explicitly close before next one
|
// explicitly close before next one
|
||||||
|
|
|
@ -196,7 +196,7 @@ func OfSlice(data interface{}) (*Tensor, error) {
|
||||||
return &Tensor{ctensor}, nil
|
return &Tensor{ctensor}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// OfDataSize creates Tensor from input byte data and specidied shape and dtype.
|
// OfDataSize creates Tensor from input byte data, shape and dtype.
|
||||||
func OfDataSize(data []byte, shape []int64, dtype gotch.DType) (*Tensor, error) {
|
func OfDataSize(data []byte, shape []int64, dtype gotch.DType) (*Tensor, error) {
|
||||||
|
|
||||||
elementNum := ElementCount(shape)
|
elementNum := ElementCount(shape)
|
||||||
|
@ -215,33 +215,7 @@ func OfDataSize(data []byte, shape []int64, dtype gotch.DType) (*Tensor, error)
|
||||||
dataPtr, buff := CMalloc(nbytes)
|
dataPtr, buff := CMalloc(nbytes)
|
||||||
defer C.free(unsafe.Pointer(dataPtr))
|
defer C.free(unsafe.Pointer(dataPtr))
|
||||||
|
|
||||||
typ, err := gotch.ToGoType(dtype)
|
if err := binary.Write(buff, nativeEndian, data); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var v reflect.Value
|
|
||||||
switch typ.Name() {
|
|
||||||
case "float", "float32":
|
|
||||||
v = reflect.ValueOf(float32(0.1))
|
|
||||||
case "float64":
|
|
||||||
v = reflect.ValueOf(float64(0.1))
|
|
||||||
case "int", "int32":
|
|
||||||
v = reflect.ValueOf(int(1))
|
|
||||||
case "int64":
|
|
||||||
v = reflect.ValueOf(int64(1))
|
|
||||||
case "int8":
|
|
||||||
v = reflect.ValueOf(int8(1))
|
|
||||||
case "uint8":
|
|
||||||
v = reflect.ValueOf(uint8(1))
|
|
||||||
case "bool":
|
|
||||||
v = reflect.ValueOf(false)
|
|
||||||
default:
|
|
||||||
err := fmt.Errorf("unsupported dtype: %v\n", dtype)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = EncodeTensor(buff, v, shape); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -255,8 +229,6 @@ func OfDataSize(data []byte, shape []int64, dtype gotch.DType) (*Tensor, error)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
buff.Reset()
|
|
||||||
|
|
||||||
return &Tensor{ctensor}, nil
|
return &Tensor{ctensor}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -836,6 +808,7 @@ type NamedTensor struct {
|
||||||
// SaveMulti saves some named tensors to a file
|
// SaveMulti saves some named tensors to a file
|
||||||
//
|
//
|
||||||
// The file format is the same as the one used by the PyTorch C++ API.
|
// The file format is the same as the one used by the PyTorch C++ API.
|
||||||
|
// NOTE. This method is depreciated and will be replaced with `SaveMultiNew`
|
||||||
func SaveMulti(namedTensors []NamedTensor, path string) error {
|
func SaveMulti(namedTensors []NamedTensor, path string) error {
|
||||||
var ctensors []lib.Ctensor
|
var ctensors []lib.Ctensor
|
||||||
var names []string
|
var names []string
|
||||||
|
@ -854,6 +827,8 @@ func SaveMulti(namedTensors []NamedTensor, path string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// MustSaveMulti saves some named tensors to a file. It will panic if error
|
// MustSaveMulti saves some named tensors to a file. It will panic if error
|
||||||
|
//
|
||||||
|
// NOTE. This method is depreciated and will be replaced with `MustSaveMultiNew`
|
||||||
func MustSaveMulti(namedTensors []NamedTensor, path string) {
|
func MustSaveMulti(namedTensors []NamedTensor, path string) {
|
||||||
err := SaveMulti(namedTensors, path)
|
err := SaveMulti(namedTensors, path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user