fix(nn/varstore): Load/PartialLoad methods added mismatched shape check
This commit is contained in:
parent
f2fd373edf
commit
5edded6ca1
|
@ -3,6 +3,7 @@ package nn
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
@ -171,13 +172,22 @@ func (vs *VarStore) Load(filepath string) (err error) {
|
||||||
|
|
||||||
// for tsName, _ := range vs.Vars.NamedVariables {
|
// for tsName, _ := range vs.Vars.NamedVariables {
|
||||||
for tsName := range vs.Vars.NamedVariables {
|
for tsName := range vs.Vars.NamedVariables {
|
||||||
var currTs ts.Tensor
|
|
||||||
var ok bool
|
// missing variable
|
||||||
if currTs, ok = namedTensorsMap[tsName]; !ok {
|
currTs, ok := namedTensorsMap[tsName]
|
||||||
|
if !ok {
|
||||||
err = fmt.Errorf("Cannot find tensor with name: %v in variable store. \n", tsName)
|
err = fmt.Errorf("Cannot find tensor with name: %v in variable store. \n", tsName)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mismatched shape
|
||||||
|
destShape := currTs.MustSize()
|
||||||
|
sourceShape := vs.Vars.NamedVariables[tsName].MustSize()
|
||||||
|
if !reflect.DeepEqual(destShape, sourceShape) {
|
||||||
|
err = fmt.Errorf("Mismatched shape error for variable name: %v - At store: %v - At source %v\n", tsName, destShape, sourceShape)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
ts.NoGrad(func() {
|
ts.NoGrad(func() {
|
||||||
vs.Vars.NamedVariables[tsName].Copy_(currTs)
|
vs.Vars.NamedVariables[tsName].Copy_(currTs)
|
||||||
})
|
})
|
||||||
|
@ -217,11 +227,21 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) {
|
||||||
for tsName := range vs.Vars.NamedVariables {
|
for tsName := range vs.Vars.NamedVariables {
|
||||||
var currTs ts.Tensor
|
var currTs ts.Tensor
|
||||||
var ok bool
|
var ok bool
|
||||||
|
|
||||||
|
// missing variable
|
||||||
if currTs, ok = namedTensorsMap[tsName]; !ok {
|
if currTs, ok = namedTensorsMap[tsName]; !ok {
|
||||||
// missing
|
|
||||||
missingVariables = append(missingVariables, tsName)
|
missingVariables = append(missingVariables, tsName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mismatched shape
|
||||||
|
destShape := currTs.MustSize()
|
||||||
|
sourceShape := vs.Vars.NamedVariables[tsName].MustSize()
|
||||||
|
if !reflect.DeepEqual(destShape, sourceShape) {
|
||||||
|
fmt.Printf("Mismatched shape error for variable name: %v - At store: %v - At source %v\n", tsName, destShape, sourceShape)
|
||||||
|
missingVariables = append(missingVariables, tsName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
ts.NoGrad(func() {
|
ts.NoGrad(func() {
|
||||||
vs.Vars.NamedVariables[tsName].Copy_(currTs)
|
vs.Vars.NamedVariables[tsName].Copy_(currTs)
|
||||||
})
|
})
|
||||||
|
@ -274,7 +294,7 @@ func (vs *VarStore) Copy(src VarStore) (err error) {
|
||||||
srcNamedVariables := src.Vars.NamedVariables
|
srcNamedVariables := src.Vars.NamedVariables
|
||||||
device := vs.device
|
device := vs.device
|
||||||
|
|
||||||
for k, _ := range vs.Vars.NamedVariables {
|
for k := range vs.Vars.NamedVariables {
|
||||||
if _, ok := srcNamedVariables[k]; !ok {
|
if _, ok := srcNamedVariables[k]; !ok {
|
||||||
err = fmt.Errorf("VarStore copy error: cannot find %v in the source var store.\n", k)
|
err = fmt.Errorf("VarStore copy error: cannot find %v in the source var store.\n", k)
|
||||||
return err
|
return err
|
||||||
|
|
Loading…
Reference in New Issue
Block a user