From 5801be4072850857183c294969855e59bff73ed1 Mon Sep 17 00:00:00 2001 From: sugarme Date: Wed, 18 Nov 2020 01:22:36 +1100 Subject: [PATCH] update to ts.SaveMultiNew --- example/convert-model/main.go | 8 +++++--- tensor/npy.go | 4 ++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/example/convert-model/main.go b/example/convert-model/main.go index e93bce5..bdff712 100644 --- a/example/convert-model/main.go +++ b/example/convert-model/main.go @@ -8,7 +8,9 @@ import ( ) func main() { - filepath := "../../data/convert-model/bert/model.npz" + // filepath := "../../data/convert-model/bert/model.npz" + // filepath := "/home/sugarme/projects/pytorch-pretrained/bert/model.npz" + filepath := "/home/sugarme/rustbert/bert/model.npz" namedTensors, err := ts.ReadNpz(filepath) if err != nil { @@ -27,8 +29,8 @@ func main() { // fmt.Printf("%v", namedTensors[70].Tensor) - outputFile := "../../data/convert-model/bert/model.gt" - err = ts.SaveMulti(namedTensors, outputFile) + outputFile := "/home/sugarme/projects/transformer/data/bert/model.gt" + err = ts.SaveMultiNew(namedTensors, outputFile) if err != nil { log.Fatal(err) } diff --git a/tensor/npy.go b/tensor/npy.go index 0007dfb..5d433ea 100644 --- a/tensor/npy.go +++ b/tensor/npy.go @@ -354,6 +354,10 @@ func ReadNpz(filePath string) ([]NamedTensor, error) { return nil, err } + if name == "bert.encoder.layer.0.attention.output.dense.weight" { + fmt.Printf("%4.2f", tensor) + } + namedTensors = append(namedTensors, NamedTensor{name, tensor}) // explicitly close before next one