fix(nn/varstore): Load/PartialLoad methods added mismatched shape check
This commit is contained in:
parent
f2fd373edf
commit
5edded6ca1
|
@ -3,6 +3,7 @@ package nn
|
|||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
|
@ -171,13 +172,22 @@ func (vs *VarStore) Load(filepath string) (err error) {
|
|||
|
||||
// for tsName, _ := range vs.Vars.NamedVariables {
|
||||
for tsName := range vs.Vars.NamedVariables {
|
||||
var currTs ts.Tensor
|
||||
var ok bool
|
||||
if currTs, ok = namedTensorsMap[tsName]; !ok {
|
||||
|
||||
// missing variable
|
||||
currTs, ok := namedTensorsMap[tsName]
|
||||
if !ok {
|
||||
err = fmt.Errorf("Cannot find tensor with name: %v in variable store. \n", tsName)
|
||||
return err
|
||||
}
|
||||
|
||||
// mismatched shape
|
||||
destShape := currTs.MustSize()
|
||||
sourceShape := vs.Vars.NamedVariables[tsName].MustSize()
|
||||
if !reflect.DeepEqual(destShape, sourceShape) {
|
||||
err = fmt.Errorf("Mismatched shape error for variable name: %v - At store: %v - At source %v\n", tsName, destShape, sourceShape)
|
||||
return err
|
||||
}
|
||||
|
||||
ts.NoGrad(func() {
|
||||
vs.Vars.NamedVariables[tsName].Copy_(currTs)
|
||||
})
|
||||
|
@ -217,11 +227,21 @@ func (vs *VarStore) LoadPartial(filepath string) (retVal []string, err error) {
|
|||
for tsName := range vs.Vars.NamedVariables {
|
||||
var currTs ts.Tensor
|
||||
var ok bool
|
||||
|
||||
// missing variable
|
||||
if currTs, ok = namedTensorsMap[tsName]; !ok {
|
||||
// missing
|
||||
missingVariables = append(missingVariables, tsName)
|
||||
}
|
||||
|
||||
// mismatched shape
|
||||
destShape := currTs.MustSize()
|
||||
sourceShape := vs.Vars.NamedVariables[tsName].MustSize()
|
||||
if !reflect.DeepEqual(destShape, sourceShape) {
|
||||
fmt.Printf("Mismatched shape error for variable name: %v - At store: %v - At source %v\n", tsName, destShape, sourceShape)
|
||||
missingVariables = append(missingVariables, tsName)
|
||||
continue
|
||||
}
|
||||
|
||||
ts.NoGrad(func() {
|
||||
vs.Vars.NamedVariables[tsName].Copy_(currTs)
|
||||
})
|
||||
|
@ -274,7 +294,7 @@ func (vs *VarStore) Copy(src VarStore) (err error) {
|
|||
srcNamedVariables := src.Vars.NamedVariables
|
||||
device := vs.device
|
||||
|
||||
for k, _ := range vs.Vars.NamedVariables {
|
||||
for k := range vs.Vars.NamedVariables {
|
||||
if _, ok := srcNamedVariables[k]; !ok {
|
||||
err = fmt.Errorf("VarStore copy error: cannot find %v in the source var store.\n", k)
|
||||
return err
|
||||
|
|
Loading…
Reference in New Issue
Block a user