feat(example/transfer-learning): completed

sugarme 2020-07-02 16:26:54 +10:00
parent 4f57855c9b
commit d23a606a64
3 changed files with 93 additions and 20 deletions


@@ -11,6 +11,7 @@ import (
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
ts "github.com/sugarme/gotch/tensor"
"github.com/sugarme/gotch/vision"
)
@@ -38,29 +39,45 @@ func main() {
log.Fatal(err)
}
fmt.Printf("Dataset: %v\n", dataset)
fmt.Printf("Train shape: %v\n", dataset.TrainImages.MustSize())
fmt.Printf("Train shape: %v\n", dataset.TrainLabels.MustSize())
fmt.Printf("Test shape: %v\n", dataset.TestImages.MustSize())
fmt.Printf("Test shape: %v\n", dataset.TestLabels.MustSize())
fmt.Println("Dataset loaded")
// Create the model and load the weights from the file.
vs := nn.NewVarStore(gotch.CPU)
net := vision.ResNet18NoFinalLayer(vs.Root())
// for k, _ := range vs.Vars.NamedVariables {
// fmt.Printf("First variable name: %v\n", k)
// }
fmt.Printf("vs variables: %v\n", vs.Variables())
fmt.Printf("vs num of variables: %v\n", vs.Len())
err = vs.Load(weights)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Net infor: %v\n", net)
fmt.Println("Weights loaded")
panic("stop")
// Pre-compute the final activations.
linear := nn.NewLinear(vs.Root(), 512, dataset.Labels, *nn.DefaultLinearConfig())
sgd, err := nn.DefaultSGDConfig().Build(vs, 1e-3)
if err != nil {
log.Fatal(err)
}
trainImages := ts.NoGrad1(func() (retVal interface{}) {
return dataset.TrainImages.ApplyT(net, true)
}).(ts.Tensor)
testImages := ts.NoGrad1(func() (retVal interface{}) {
return dataset.TestImages.ApplyT(net, true)
}).(ts.Tensor)
fmt.Println("start training...")
for epoch := 1; epoch <= 1000; epoch++ {
predicted := trainImages.Apply(linear)
loss := predicted.CrossEntropyForLogits(dataset.TrainLabels)
sgd.BackwardStep(loss)
loss.MustDrop()
testAccuracy := testImages.Apply(linear).AccuracyForLogits(dataset.TestLabels)
fmt.Printf("Epoch %v\t Accuracy: %5.2f%%\n", epoch, testAccuracy.Values()[0]*100)
}
}
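A note on the loop above: only loss is dropped each epoch, while predicted and the accuracy tensor are left alive, even though the rest of this commit frees intermediate tensors explicitly with MustDrop. A minimal sketch of the same epoch body with those intermediates released as well, using only calls that already appear in this example (testLogits is an illustrative name):

predicted := trainImages.Apply(linear)
loss := predicted.CrossEntropyForLogits(dataset.TrainLabels)
sgd.BackwardStep(loss)
// Free the per-epoch tensors once the optimizer step is done.
predicted.MustDrop()
loss.MustDrop()
testLogits := testImages.Apply(linear)
testAccuracy := testLogits.AccuracyForLogits(dataset.TestLabels)
testLogits.MustDrop()
fmt.Printf("Epoch %v\t Accuracy: %5.2f%%\n", epoch, testAccuracy.Values()[0]*100)
testAccuracy.MustDrop()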


@@ -942,6 +942,21 @@ func NoGrad(fn interface{}) {
}
func NoGrad1(fn func() interface{}) (retVal interface{}) {
newTs := NewTensor()
newTs.Drop()
// Switch off Grad
prev := MustGradSetEnabled(false)
retVal = fn()
// Switch on Grad
_ = MustGradSetEnabled(prev)
return retVal
}
// NoGradGuard is a RAII guard that prevents gradient tracking until deallocated.
type NoGradGuard struct {
enabled bool


@@ -27,6 +27,7 @@ func downSample(path nn.Path, cIn, cOut, stride int64) (retVal ts.ModuleT) {
seq := nn.SeqT()
seq.Add(conv2d(path.Sub("0"), cIn, cOut, 1, 0, stride))
seq.Add(nn.BatchNorm2D(path.Sub("1"), cOut, nn.DefaultBatchNormConfig()))
retVal = seq
} else {
retVal = nn.SeqT()
}
@@ -46,8 +47,10 @@ func basicBlock(path nn.Path, cIn, cOut, stride int64) (retVal ts.ModuleT) {
return nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
ys := xs.Apply(conv1).ApplyT(bn1, train).MustRelu(false).Apply(conv2).ApplyT(bn2, train)
downsampleLayer := xs.ApplyT(downsample, train).MustAdd(ys, true)
res := downsampleLayer.MustRelu(true)
return res
})
}
@@ -72,17 +75,55 @@ func resnet(path nn.Path, nclasses int64, c1, c2, c3, c4 int64) (retVal nn.FuncT
layer4 := basicLayer(path.Sub("layer4"), 256, 512, 2, c4)
if nclasses > 0 {
linearConfig := nn.DefaultLinearConfig()
fc := nn.NewLinear(path.Sub("fc"), 512, nclasses, *linearConfig)
return nn.NewFuncT(func(xs ts.Tensor, train bool) (retVal ts.Tensor) {
c1 := xs.Apply(conv1)
bn1 := c1.ApplyT(bn1, train)
c1.MustDrop()
relu := bn1.MustRelu(true)
maxpool := relu.MustMaxPool2D([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
l1 := maxpool.ApplyT(layer1, train)
l2 := l1.ApplyT(layer2, train)
l1.MustDrop()
l3 := l2.ApplyT(layer3, train)
l2.MustDrop()
l4 := l3.ApplyT(layer4, train)
l3.MustDrop()
avgpool := l4.MustAdaptiveAvgPool2D([]int64{1, 1})
l4.MustDrop()
fv := avgpool.FlatView()
avgpool.MustDrop()
// final layer
retVal = fv.ApplyOpt(ts.WithModule(fc))
return retVal
})
} else {
// No final layer
return nn.NewFuncT(func(xs ts.Tensor, train bool) (retVal ts.Tensor) {
c1 := xs.Apply(conv1)
bn1 := c1.ApplyT(bn1, train)
c1.MustDrop()
relu := bn1.MustRelu(true)
maxpool := relu.MustMaxPool2D([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
l1 := maxpool.ApplyT(layer1, train)
l2 := l1.ApplyT(layer2, train)
l1.MustDrop()
l3 := l2.ApplyT(layer3, train)
l2.MustDrop()
l4 := l3.ApplyT(layer4, train)
l3.MustDrop()
avgpool := l4.MustAdaptiveAvgPool2D([]int64{1, 1})
l4.MustDrop()
retVal = avgpool.FlatView()
avgpool.MustDrop()
return retVal
})
}
}
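The substance of this rewrite, in both the nclasses > 0 branch and the headless branch, is that the single long method chain has been unrolled into named intermediates so each one can be freed with MustDrop as soon as the next layer has consumed it. Reduced to a two-layer sketch (layerA, layerB, h1, and h2 are illustrative names):

// Keep a handle to every intermediate tensor so it can be dropped explicitly
// once the following layer has consumed it.
h1 := xs.ApplyT(layerA, train)
h2 := h1.ApplyT(layerB, train)
h1.MustDrop() // h1 is no longer needed once h2 exists
// ... and so on down the network, dropping each intermediate in turn.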