Updated tensor.OfDataSize to write binary data directly; updated ReadNpy and ReadNpz
This commit is contained in:
parent
5801be4072
commit
8770fc378b
|
@ -8,9 +8,8 @@ import (
|
|||
)
|
||||
|
||||
func main() {
|
||||
// filepath := "../../data/convert-model/bert/model.npz"
|
||||
// filepath := "/home/sugarme/projects/pytorch-pretrained/bert/model.npz"
|
||||
filepath := "/home/sugarme/rustbert/bert/model.npz"
|
||||
// NOTE. Python script to save model to .npz can be found at https://github.com/sugarme/pytorch-pretrained/bert/bert-base-uncased-to-npz.py
|
||||
filepath := "../../data/convert-model/bert/model.npz"
|
||||
|
||||
namedTensors, err := ts.ReadNpz(filepath)
|
||||
if err != nil {
|
||||
|
@ -18,21 +17,9 @@ func main() {
|
|||
}
|
||||
|
||||
fmt.Printf("Num of named tensor: %v\n", len(namedTensors))
|
||||
/*
|
||||
* for _, nt := range namedTensors {
|
||||
* // fmt.Printf("%q\n", nt.Name)
|
||||
* if nt.Name == "bert.encoder.layer.1.attention.output.LayerNorm.weight" {
|
||||
* fmt.Printf("%0.3f", nt.Tensor)
|
||||
* }
|
||||
* }
|
||||
* */
|
||||
|
||||
// fmt.Printf("%v", namedTensors[70].Tensor)
|
||||
|
||||
outputFile := "/home/sugarme/projects/transformer/data/bert/model.gt"
|
||||
err = ts.SaveMultiNew(namedTensors, outputFile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -15,7 +15,6 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
// NpyMagicString string = `\x93NUMPY`
|
||||
NpyMagicString string = "\x93NUMPY"
|
||||
NpySuffix string = ".npy"
|
||||
)
|
||||
|
@ -103,7 +102,7 @@ func (h *NpyHeader) ToString() (string, error) {
|
|||
|
||||
var descr string
|
||||
switch h.descr.Kind().String() {
|
||||
// case "float32":
|
||||
// case "float16": // NOTE. Go has no built-in float16 type. TODO: implement.
|
||||
// descr = "f2"
|
||||
case "float32":
|
||||
descr = "f4"
|
||||
|
@ -354,10 +353,6 @@ func ReadNpz(filePath string) ([]NamedTensor, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if name == "bert.encoder.layer.0.attention.output.dense.weight" {
|
||||
fmt.Printf("%4.2f", tensor)
|
||||
}
|
||||
|
||||
namedTensors = append(namedTensors, NamedTensor{name, tensor})
|
||||
|
||||
// explicitly close before next one
|
||||
|
|
|
@ -196,7 +196,7 @@ func OfSlice(data interface{}) (*Tensor, error) {
|
|||
return &Tensor{ctensor}, nil
|
||||
}
|
||||
|
||||
// OfDataSize creates Tensor from input byte data and specidied shape and dtype.
|
||||
// OfDataSize creates Tensor from input byte data, shape and dtype.
|
||||
func OfDataSize(data []byte, shape []int64, dtype gotch.DType) (*Tensor, error) {
|
||||
|
||||
elementNum := ElementCount(shape)
|
||||
|
@ -215,33 +215,7 @@ func OfDataSize(data []byte, shape []int64, dtype gotch.DType) (*Tensor, error)
|
|||
dataPtr, buff := CMalloc(nbytes)
|
||||
defer C.free(unsafe.Pointer(dataPtr))
|
||||
|
||||
typ, err := gotch.ToGoType(dtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var v reflect.Value
|
||||
switch typ.Name() {
|
||||
case "float", "float32":
|
||||
v = reflect.ValueOf(float32(0.1))
|
||||
case "float64":
|
||||
v = reflect.ValueOf(float64(0.1))
|
||||
case "int", "int32":
|
||||
v = reflect.ValueOf(int(1))
|
||||
case "int64":
|
||||
v = reflect.ValueOf(int64(1))
|
||||
case "int8":
|
||||
v = reflect.ValueOf(int8(1))
|
||||
case "uint8":
|
||||
v = reflect.ValueOf(uint8(1))
|
||||
case "bool":
|
||||
v = reflect.ValueOf(false)
|
||||
default:
|
||||
err := fmt.Errorf("unsupported dtype: %v\n", dtype)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = EncodeTensor(buff, v, shape); err != nil {
|
||||
if err := binary.Write(buff, nativeEndian, data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -255,8 +229,6 @@ func OfDataSize(data []byte, shape []int64, dtype gotch.DType) (*Tensor, error)
|
|||
return nil, err
|
||||
}
|
||||
|
||||
buff.Reset()
|
||||
|
||||
return &Tensor{ctensor}, nil
|
||||
}
|
||||
|
||||
|
@ -836,6 +808,7 @@ type NamedTensor struct {
|
|||
// SaveMulti saves some named tensors to a file
|
||||
//
|
||||
// The file format is the same as the one used by the PyTorch C++ API.
|
||||
// NOTE. This method is deprecated and will be replaced with `SaveMultiNew`
|
||||
func SaveMulti(namedTensors []NamedTensor, path string) error {
|
||||
var ctensors []lib.Ctensor
|
||||
var names []string
|
||||
|
@ -854,6 +827,8 @@ func SaveMulti(namedTensors []NamedTensor, path string) error {
|
|||
}
|
||||
|
||||
// MustSaveMulti saves some named tensors to a file. It will panic if error
|
||||
//
|
||||
// NOTE. This method is deprecated and will be replaced with `MustSaveMultiNew`
|
||||
func MustSaveMulti(namedTensors []NamedTensor, path string) {
|
||||
err := SaveMulti(namedTensors, path)
|
||||
if err != nil {
|
||||
|
|
Loading…
Reference in New Issue
Block a user