update to ts.SaveMultiNew
This commit is contained in:
parent
b4228528bb
commit
5801be4072
|
@ -8,7 +8,9 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
filepath := "../../data/convert-model/bert/model.npz"
|
// filepath := "../../data/convert-model/bert/model.npz"
|
||||||
|
// filepath := "/home/sugarme/projects/pytorch-pretrained/bert/model.npz"
|
||||||
|
filepath := "/home/sugarme/rustbert/bert/model.npz"
|
||||||
|
|
||||||
namedTensors, err := ts.ReadNpz(filepath)
|
namedTensors, err := ts.ReadNpz(filepath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -27,8 +29,8 @@ func main() {
|
||||||
|
|
||||||
// fmt.Printf("%v", namedTensors[70].Tensor)
|
// fmt.Printf("%v", namedTensors[70].Tensor)
|
||||||
|
|
||||||
outputFile := "../../data/convert-model/bert/model.gt"
|
outputFile := "/home/sugarme/projects/transformer/data/bert/model.gt"
|
||||||
err = ts.SaveMulti(namedTensors, outputFile)
|
err = ts.SaveMultiNew(namedTensors, outputFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -354,6 +354,10 @@ func ReadNpz(filePath string) ([]NamedTensor, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if name == "bert.encoder.layer.0.attention.output.dense.weight" {
|
||||||
|
fmt.Printf("%4.2f", tensor)
|
||||||
|
}
|
||||||
|
|
||||||
namedTensors = append(namedTensors, NamedTensor{name, tensor})
|
namedTensors = append(namedTensors, NamedTensor{name, tensor})
|
||||||
|
|
||||||
// explicitly close before next one
|
// explicitly close before next one
|
||||||
|
|
Loading…
Reference in New Issue
Block a user