update to ts.SaveMultiNew
This commit is contained in:
parent
b4228528bb
commit
5801be4072
|
@ -8,7 +8,9 @@ import (
|
|||
)
|
||||
|
||||
func main() {
|
||||
filepath := "../../data/convert-model/bert/model.npz"
|
||||
// filepath := "../../data/convert-model/bert/model.npz"
|
||||
// filepath := "/home/sugarme/projects/pytorch-pretrained/bert/model.npz"
|
||||
filepath := "/home/sugarme/rustbert/bert/model.npz"
|
||||
|
||||
namedTensors, err := ts.ReadNpz(filepath)
|
||||
if err != nil {
|
||||
|
@ -27,8 +29,8 @@ func main() {
|
|||
|
||||
// fmt.Printf("%v", namedTensors[70].Tensor)
|
||||
|
||||
outputFile := "../../data/convert-model/bert/model.gt"
|
||||
err = ts.SaveMulti(namedTensors, outputFile)
|
||||
outputFile := "/home/sugarme/projects/transformer/data/bert/model.gt"
|
||||
err = ts.SaveMultiNew(namedTensors, outputFile)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -354,6 +354,10 @@ func ReadNpz(filePath string) ([]NamedTensor, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if name == "bert.encoder.layer.0.attention.output.dense.weight" {
|
||||
fmt.Printf("%4.2f", tensor)
|
||||
}
|
||||
|
||||
namedTensors = append(namedTensors, NamedTensor{name, tensor})
|
||||
|
||||
// explicitly close before next one
|
||||
|
|
Loading…
Reference in New Issue
Block a user