diff --git a/nn/varstore.go b/nn/varstore.go index e6cc6ee..14466fd 100644 --- a/nn/varstore.go +++ b/nn/varstore.go @@ -3,6 +3,7 @@ package nn import ( "fmt" "log" + "reflect" "strings" "sync" @@ -171,13 +172,22 @@ func (vs *VarStore) Load(filepath string) (err error) { // for tsName, _ := range vs.Vars.NamedVariables { for tsName := range vs.Vars.NamedVariables { - var currTs ts.Tensor - var ok bool - if currTs, ok = namedTensorsMap[tsName]; !ok { + + // missing variable + currTs, ok := namedTensorsMap[tsName] + if !ok { err = fmt.Errorf("Cannot find tensor with name: %v in variable store. \n", tsName) return err } + // mismatched shape + destShape := currTs.MustSize() + sourceShape := vs.Vars.NamedVariables[tsName].MustSize() + if !reflect.DeepEqual(destShape, sourceShape) { + err = fmt.Errorf("Mismatched shape error for variable name: %v - At store: %v - At source %v\n", tsName, destShape, sourceShape) + return err + } + ts.NoGrad(func() { vs.Vars.NamedVariables[tsName].Copy_(currTs) }) @@ -217,11 +227,21 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) { for tsName := range vs.Vars.NamedVariables { var currTs ts.Tensor var ok bool + + // missing variable if currTs, ok = namedTensorsMap[tsName]; !ok { - // missing missingVariables = append(missingVariables, tsName) } + // mismatched shape + destShape := currTs.MustSize() + sourceShape := vs.Vars.NamedVariables[tsName].MustSize() + if !reflect.DeepEqual(destShape, sourceShape) { + fmt.Printf("Mismatched shape error for variable name: %v - At store: %v - At source %v\n", tsName, destShape, sourceShape) + missingVariables = append(missingVariables, tsName) + continue + } + ts.NoGrad(func() { vs.Vars.NamedVariables[tsName].Copy_(currTs) }) @@ -274,7 +294,7 @@ func (vs *VarStore) Copy(src VarStore) (err error) { srcNamedVariables := src.Vars.NamedVariables device := vs.device - for k, _ := range vs.Vars.NamedVariables { + for k := range vs.Vars.NamedVariables { if _, ok := srcNamedVariables[k]; !ok { err = fmt.Errorf("VarStore copy error: cannot find %v in the source var store.\n", k) return err