fix(nn/varstore): fixed incorrect Load function and replaced '|' with '.'
This commit is contained in:
parent
32d4a68e12
commit
76f4b41ad1
|
@ -48,17 +48,19 @@ func main() {
|
|||
vs := nn.NewVarStore(gotch.CPU)
|
||||
net := vision.ResNet18NoFinalLayer(vs.Root())
|
||||
|
||||
for k, _ := range vs.Vars.NamedVariables {
|
||||
fmt.Printf("First variable name: %v\n", k)
|
||||
}
|
||||
|
||||
panic("Stop")
|
||||
// for k, _ := range vs.Vars.NamedVariables {
|
||||
// fmt.Printf("First variable name: %v\n", k)
|
||||
// }
|
||||
fmt.Printf("vs variables: %v\n", vs.Variables())
|
||||
fmt.Printf("vs num of variables: %v\n", vs.Len())
|
||||
|
||||
err = vs.Load(weights)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
panic(net)
|
||||
fmt.Printf("Net infor: %v\n", net)
|
||||
|
||||
panic("stop")
|
||||
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ package libtch
|
|||
import "C"
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
|
@ -257,6 +258,7 @@ func AtLoadCallback(filename string, dataPtr unsafe.Pointer) {
|
|||
//export callback_fn
|
||||
func callback_fn(dataPtr unsafe.Pointer, name *C.char, ctensor C.tensor) {
|
||||
tsName := C.GoString(name)
|
||||
tsName = strings.ReplaceAll(tsName, "|", ".")
|
||||
namedCtensor := NamedCtensor{
|
||||
Name: tsName,
|
||||
Ctensor: ctensor,
|
||||
|
|
|
@ -159,24 +159,28 @@ func (vs *VarStore) Load(filepath string) (err error) {
|
|||
return err
|
||||
}
|
||||
|
||||
var namedTensorsMap map[string]ts.Tensor = make(map[string]ts.Tensor, 0)
|
||||
for _, namedTensor := range namedTensors {
|
||||
namedTensorsMap[namedTensor.Name] = namedTensor.Tensor
|
||||
}
|
||||
|
||||
// Match by name and copy values in place (update) from newly loaded tensors
|
||||
// to existing named tensors if the name matches. Return an error otherwise.
|
||||
vs.Vars.mutex.Lock()
|
||||
defer vs.Vars.mutex.Unlock()
|
||||
|
||||
for _, namedTs := range namedTensors {
|
||||
for tsName, _ := range vs.Vars.NamedVariables {
|
||||
var currTs ts.Tensor
|
||||
var ok bool
|
||||
if currTs, ok = vs.Vars.NamedVariables[namedTs.Name]; !ok {
|
||||
err = fmt.Errorf("Cannot find tensor with name: %v in variable store. \n", namedTs.Name)
|
||||
if currTs, ok = namedTensorsMap[tsName]; !ok {
|
||||
err = fmt.Errorf("Cannot find tensor with name: %v in variable store. \n", tsName)
|
||||
return err
|
||||
}
|
||||
|
||||
ts.NoGrad(func() {
|
||||
ts.Copy_(currTs, namedTs.Tensor)
|
||||
vs.Vars.NamedVariables[tsName].Copy_(currTs)
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package vision
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
nn "github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
@ -23,8 +25,8 @@ func downSample(path nn.Path, cIn, cOut, stride int64) (retVal ts.ModuleT) {
|
|||
|
||||
if stride != 1 || cIn != cOut {
|
||||
seq := nn.SeqT()
|
||||
seq.Add(conv2d(path, cIn, cOut, 1, 0, stride))
|
||||
|
||||
seq.Add(conv2d(path.Sub("0"), cIn, cOut, 1, 0, stride))
|
||||
seq.Add(nn.BatchNorm2D(path.Sub("1"), cOut, nn.DefaultBatchNormConfig()))
|
||||
} else {
|
||||
retVal = nn.SeqT()
|
||||
}
|
||||
|
@ -54,8 +56,8 @@ func basicLayer(path nn.Path, cIn, cOut, stride, cnt int64) (retVal ts.ModuleT)
|
|||
layer := nn.SeqT()
|
||||
layer.Add(basicBlock(path.Sub("0"), cIn, cOut, stride))
|
||||
|
||||
for blockIndex := 0; blockIndex < int(cnt); blockIndex++ {
|
||||
layer.Add(basicBlock(path.Sub(string(blockIndex)), cOut, cOut, 1))
|
||||
for blockIndex := 1; blockIndex < int(cnt); blockIndex++ {
|
||||
layer.Add(basicBlock(path.Sub(fmt.Sprint(blockIndex)), cOut, cOut, 1))
|
||||
}
|
||||
|
||||
return layer
|
||||
|
@ -65,9 +67,9 @@ func resnet(path nn.Path, nclasses int64, c1, c2, c3, c4 int64) (retVal nn.FuncT
|
|||
conv1 := conv2d(path.Sub("conv1"), 3, 64, 7, 3, 2)
|
||||
bn1 := nn.BatchNorm2D(path.Sub("bn1"), 64, nn.DefaultBatchNormConfig())
|
||||
layer1 := basicLayer(path.Sub("layer1"), 64, 64, 1, c1)
|
||||
layer2 := basicLayer(path.Sub("layer2"), 64, 64, 1, c2)
|
||||
layer3 := basicLayer(path.Sub("layer3"), 64, 64, 1, c3)
|
||||
layer4 := basicLayer(path.Sub("layer4"), 64, 64, 1, c4)
|
||||
layer2 := basicLayer(path.Sub("layer2"), 64, 128, 2, c2)
|
||||
layer3 := basicLayer(path.Sub("layer3"), 128, 256, 2, c3)
|
||||
layer4 := basicLayer(path.Sub("layer4"), 256, 512, 2, c4)
|
||||
|
||||
if nclasses > 0 {
|
||||
linearConfig := nn.DefaultLinearConfig()
|
||||
|
@ -125,7 +127,7 @@ func bottleneckLayer(path nn.Path, cIn, cOut, stride, cnt int64) (retVal ts.Modu
|
|||
layer := nn.SeqT()
|
||||
layer.Add(bottleneckBlock(path.Sub("0"), cIn, cOut, stride, 4))
|
||||
for blockIndex := 0; blockIndex < int(cnt); blockIndex++ {
|
||||
layer.Add(bottleneckBlock(path.Sub(string(blockIndex)), (cOut * 4), cOut, 1, 4))
|
||||
layer.Add(bottleneckBlock(path.Sub(fmt.Sprint(blockIndex)), (cOut * 4), cOut, 1, 4))
|
||||
}
|
||||
|
||||
return layer
|
||||
|
|
Loading…
Reference in New Issue
Block a user