fix(nn/varstore): fixed incorrect Load function and replaced '|' with '.'

This commit is contained in:
sugarme 2020-07-02 12:33:18 +10:00
parent 32d4a68e12
commit 76f4b41ad1
4 changed files with 29 additions and 19 deletions

View File

@ -48,17 +48,19 @@ func main() {
vs := nn.NewVarStore(gotch.CPU)
net := vision.ResNet18NoFinalLayer(vs.Root())
for k, _ := range vs.Vars.NamedVariables {
fmt.Printf("First variable name: %v\n", k)
}
panic("Stop")
// for k, _ := range vs.Vars.NamedVariables {
// fmt.Printf("First variable name: %v\n", k)
// }
fmt.Printf("vs variables: %v\n", vs.Variables())
fmt.Printf("vs num of variables: %v\n", vs.Len())
err = vs.Load(weights)
if err != nil {
log.Fatal(err)
}
panic(net)
fmt.Printf("Net infor: %v\n", net)
panic("stop")
}

View File

@ -9,6 +9,7 @@ package libtch
import "C"
import (
"strings"
"unsafe"
)
@ -257,6 +258,7 @@ func AtLoadCallback(filename string, dataPtr unsafe.Pointer) {
//export callback_fn
func callback_fn(dataPtr unsafe.Pointer, name *C.char, ctensor C.tensor) {
tsName := C.GoString(name)
tsName = strings.ReplaceAll(tsName, "|", ".")
namedCtensor := NamedCtensor{
Name: tsName,
Ctensor: ctensor,

View File

@ -159,24 +159,28 @@ func (vs *VarStore) Load(filepath string) (err error) {
return err
}
var namedTensorsMap map[string]ts.Tensor = make(map[string]ts.Tensor, 0)
for _, namedTensor := range namedTensors {
namedTensorsMap[namedTensor.Name] = namedTensor.Tensor
}
// Copy values in place from the newly loaded tensors into the existing
// named variables wherever the names match; return an error otherwise.
vs.Vars.mutex.Lock()
defer vs.Vars.mutex.Unlock()
for _, namedTs := range namedTensors {
for tsName, _ := range vs.Vars.NamedVariables {
var currTs ts.Tensor
var ok bool
if currTs, ok = vs.Vars.NamedVariables[namedTs.Name]; !ok {
err = fmt.Errorf("Cannot find tensor with name: %v in variable store. \n", namedTs.Name)
if currTs, ok = namedTensorsMap[tsName]; !ok {
err = fmt.Errorf("Cannot find tensor with name: %v in variable store. \n", tsName)
return err
}
ts.NoGrad(func() {
ts.Copy_(currTs, namedTs.Tensor)
vs.Vars.NamedVariables[tsName].Copy_(currTs)
})
}
return nil
}

View File

@ -1,6 +1,8 @@
package vision
import (
"fmt"
nn "github.com/sugarme/gotch/nn"
ts "github.com/sugarme/gotch/tensor"
)
@ -23,8 +25,8 @@ func downSample(path nn.Path, cIn, cOut, stride int64) (retVal ts.ModuleT) {
if stride != 1 || cIn != cOut {
seq := nn.SeqT()
seq.Add(conv2d(path, cIn, cOut, 1, 0, stride))
seq.Add(conv2d(path.Sub("0"), cIn, cOut, 1, 0, stride))
seq.Add(nn.BatchNorm2D(path.Sub("1"), cOut, nn.DefaultBatchNormConfig()))
} else {
retVal = nn.SeqT()
}
@ -54,8 +56,8 @@ func basicLayer(path nn.Path, cIn, cOut, stride, cnt int64) (retVal ts.ModuleT)
layer := nn.SeqT()
layer.Add(basicBlock(path.Sub("0"), cIn, cOut, stride))
for blockIndex := 0; blockIndex < int(cnt); blockIndex++ {
layer.Add(basicBlock(path.Sub(string(blockIndex)), cOut, cOut, 1))
for blockIndex := 1; blockIndex < int(cnt); blockIndex++ {
layer.Add(basicBlock(path.Sub(fmt.Sprint(blockIndex)), cOut, cOut, 1))
}
return layer
@ -65,9 +67,9 @@ func resnet(path nn.Path, nclasses int64, c1, c2, c3, c4 int64) (retVal nn.FuncT
conv1 := conv2d(path.Sub("conv1"), 3, 64, 7, 3, 2)
bn1 := nn.BatchNorm2D(path.Sub("bn1"), 64, nn.DefaultBatchNormConfig())
layer1 := basicLayer(path.Sub("layer1"), 64, 64, 1, c1)
layer2 := basicLayer(path.Sub("layer2"), 64, 64, 1, c2)
layer3 := basicLayer(path.Sub("layer3"), 64, 64, 1, c3)
layer4 := basicLayer(path.Sub("layer4"), 64, 64, 1, c4)
layer2 := basicLayer(path.Sub("layer2"), 64, 128, 2, c2)
layer3 := basicLayer(path.Sub("layer3"), 128, 256, 2, c3)
layer4 := basicLayer(path.Sub("layer4"), 256, 512, 2, c4)
if nclasses > 0 {
linearConfig := nn.DefaultLinearConfig()
@ -125,7 +127,7 @@ func bottleneckLayer(path nn.Path, cIn, cOut, stride, cnt int64) (retVal ts.Modu
layer := nn.SeqT()
layer.Add(bottleneckBlock(path.Sub("0"), cIn, cOut, stride, 4))
for blockIndex := 0; blockIndex < int(cnt); blockIndex++ {
layer.Add(bottleneckBlock(path.Sub(string(blockIndex)), (cOut * 4), cOut, 1, 4))
layer.Add(bottleneckBlock(path.Sub(fmt.Sprint(blockIndex)), (cOut * 4), cOut, 1, 4))
}
return layer