update to ts.SaveMultiNew

This commit is contained in:
sugarme 2020-11-18 01:22:36 +11:00
parent b4228528bb
commit 5801be4072
2 changed files with 9 additions and 3 deletions

View File

@ -8,7 +8,9 @@ import (
) )
func main() { func main() {
filepath := "../../data/convert-model/bert/model.npz" // filepath := "../../data/convert-model/bert/model.npz"
// filepath := "/home/sugarme/projects/pytorch-pretrained/bert/model.npz"
filepath := "/home/sugarme/rustbert/bert/model.npz"
namedTensors, err := ts.ReadNpz(filepath) namedTensors, err := ts.ReadNpz(filepath)
if err != nil { if err != nil {
@ -27,8 +29,8 @@ func main() {
// fmt.Printf("%v", namedTensors[70].Tensor) // fmt.Printf("%v", namedTensors[70].Tensor)
outputFile := "../../data/convert-model/bert/model.gt" outputFile := "/home/sugarme/projects/transformer/data/bert/model.gt"
err = ts.SaveMulti(namedTensors, outputFile) err = ts.SaveMultiNew(namedTensors, outputFile)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -354,6 +354,10 @@ func ReadNpz(filePath string) ([]NamedTensor, error) {
return nil, err return nil, err
} }
if name == "bert.encoder.layer.0.attention.output.dense.weight" {
fmt.Printf("%4.2f", tensor)
}
namedTensors = append(namedTensors, NamedTensor{name, tensor}) namedTensors = append(namedTensors, NamedTensor{name, tensor})
// explicitly close before next one // explicitly close before next one