fix(nn/varstore): Load/PartialLoad methods added mismatched shape check

This commit is contained in:
sugarme 2020-08-06 09:18:39 +10:00
parent f2fd373edf
commit 5edded6ca1

View File

@ -3,6 +3,7 @@ package nn
import (
"fmt"
"log"
"reflect"
"strings"
"sync"
@ -171,13 +172,22 @@ func (vs *VarStore) Load(filepath string) (err error) {
// for tsName, _ := range vs.Vars.NamedVariables {
for tsName := range vs.Vars.NamedVariables {
var currTs ts.Tensor
var ok bool
if currTs, ok = namedTensorsMap[tsName]; !ok {
// missing variable
currTs, ok := namedTensorsMap[tsName]
if !ok {
err = fmt.Errorf("Cannot find tensor with name: %v in variable store. \n", tsName)
return err
}
// mismatched shape
destShape := currTs.MustSize()
sourceShape := vs.Vars.NamedVariables[tsName].MustSize()
if !reflect.DeepEqual(destShape, sourceShape) {
err = fmt.Errorf("Mismatched shape error for variable name: %v - At store: %v - At source %v\n", tsName, destShape, sourceShape)
return err
}
ts.NoGrad(func() {
vs.Vars.NamedVariables[tsName].Copy_(currTs)
})
@ -217,11 +227,21 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) {
for tsName := range vs.Vars.NamedVariables {
var currTs ts.Tensor
var ok bool
// missing variable
if currTs, ok = namedTensorsMap[tsName]; !ok {
// missing
missingVariables = append(missingVariables, tsName)
}
// mismatched shape
destShape := currTs.MustSize()
sourceShape := vs.Vars.NamedVariables[tsName].MustSize()
if !reflect.DeepEqual(destShape, sourceShape) {
fmt.Printf("Mismatched shape error for variable name: %v - At store: %v - At source %v\n", tsName, destShape, sourceShape)
missingVariables = append(missingVariables, tsName)
continue
}
ts.NoGrad(func() {
vs.Vars.NamedVariables[tsName].Copy_(currTs)
})
@ -274,7 +294,7 @@ func (vs *VarStore) Copy(src VarStore) (err error) {
srcNamedVariables := src.Vars.NamedVariables
device := vs.device
for k, _ := range vs.Vars.NamedVariables {
for k := range vs.Vars.NamedVariables {
if _, ok := srcNamedVariables[k]; !ok {
err = fmt.Errorf("VarStore copy error: cannot find %v in the source var store.\n", k)
return err