feat(example/transfer-learning): completed
This commit is contained in:
parent
4f57855c9b
commit
d23a606a64
|
@ -11,6 +11,7 @@ import (
|
|||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
|
@ -38,29 +39,45 @@ func main() {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Dataset: %v\n", dataset)
|
||||
fmt.Printf("Train shape: %v\n", dataset.TrainImages.MustSize())
|
||||
fmt.Printf("Train shape: %v\n", dataset.TrainLabels.MustSize())
|
||||
fmt.Printf("Test shape: %v\n", dataset.TestImages.MustSize())
|
||||
fmt.Printf("Test shape: %v\n", dataset.TestLabels.MustSize())
|
||||
fmt.Println("Dataset loaded")
|
||||
|
||||
// Create the model and load the weights from the file.
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
net := vision.ResNet18NoFinalLayer(vs.Root())
|
||||
|
||||
// for k, _ := range vs.Vars.NamedVariables {
|
||||
// fmt.Printf("First variable name: %v\n", k)
|
||||
// }
|
||||
fmt.Printf("vs variables: %v\n", vs.Variables())
|
||||
fmt.Printf("vs num of variables: %v\n", vs.Len())
|
||||
|
||||
err = vs.Load(weights)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Net infor: %v\n", net)
|
||||
fmt.Println("Weights loaded")
|
||||
|
||||
panic("stop")
|
||||
// Pre-compute the final activations.
|
||||
|
||||
linear := nn.NewLinear(vs.Root(), 512, dataset.Labels, *nn.DefaultLinearConfig())
|
||||
sgd, err := nn.DefaultSGDConfig().Build(vs, 1e-3)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
trainImages := ts.NoGrad1(func() (retVal interface{}) {
|
||||
return dataset.TrainImages.ApplyT(net, true)
|
||||
}).(ts.Tensor)
|
||||
|
||||
testImages := ts.NoGrad1(func() (retVal interface{}) {
|
||||
return dataset.TestImages.ApplyT(net, true)
|
||||
}).(ts.Tensor)
|
||||
|
||||
fmt.Println("start training...")
|
||||
|
||||
for epoch := 1; epoch <= 1000; epoch++ {
|
||||
|
||||
predicted := trainImages.Apply(linear)
|
||||
loss := predicted.CrossEntropyForLogits(dataset.TrainLabels)
|
||||
sgd.BackwardStep(loss)
|
||||
loss.MustDrop()
|
||||
|
||||
testAccuracy := testImages.Apply(linear).AccuracyForLogits(dataset.TestLabels)
|
||||
fmt.Printf("Epoch %v\t Accuracy: %5.2f%%\n", epoch, testAccuracy.Values()[0]*100)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -942,6 +942,21 @@ func NoGrad(fn interface{}) {
|
|||
|
||||
}
|
||||
|
||||
func NoGrad1(fn func() interface{}) (retVal interface{}) {
|
||||
newTs := NewTensor()
|
||||
newTs.Drop()
|
||||
|
||||
// Switch off Grad
|
||||
prev := MustGradSetEnabled(false)
|
||||
|
||||
retVal = fn()
|
||||
|
||||
// Switch on Grad
|
||||
_ = MustGradSetEnabled(prev)
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// NoGradGuard is a RAII guard that prevents gradient tracking until deallocated.
|
||||
type NoGradGuard struct {
|
||||
enabled bool
|
||||
|
|
|
@ -27,6 +27,7 @@ func downSample(path nn.Path, cIn, cOut, stride int64) (retVal ts.ModuleT) {
|
|||
seq := nn.SeqT()
|
||||
seq.Add(conv2d(path.Sub("0"), cIn, cOut, 1, 0, stride))
|
||||
seq.Add(nn.BatchNorm2D(path.Sub("1"), cOut, nn.DefaultBatchNormConfig()))
|
||||
retVal = seq
|
||||
} else {
|
||||
retVal = nn.SeqT()
|
||||
}
|
||||
|
@ -46,8 +47,10 @@ func basicBlock(path nn.Path, cIn, cOut, stride int64) (retVal ts.ModuleT) {
|
|||
|
||||
return nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
||||
ys := xs.Apply(conv1).ApplyT(bn1, train).MustRelu(false).Apply(conv2).ApplyT(bn2, train)
|
||||
downsampleLayer := xs.ApplyT(downsample, train).MustAdd(ys, true)
|
||||
res := downsampleLayer.MustRelu(true)
|
||||
|
||||
return xs.ApplyT(downsample, train).MustAdd(ys, true).MustRelu(true)
|
||||
return res
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -72,17 +75,55 @@ func resnet(path nn.Path, nclasses int64, c1, c2, c3, c4 int64) (retVal nn.FuncT
|
|||
layer4 := basicLayer(path.Sub("layer4"), 256, 512, 2, c4)
|
||||
|
||||
if nclasses > 0 {
|
||||
linearConfig := nn.DefaultLinearConfig()
|
||||
fc := nn.NewLinear(path.Sub("fc"), 512, nclasses, *linearConfig)
|
||||
|
||||
return nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
||||
return xs.Apply(conv1).ApplyT(bn1, train).MustRelu(false).MustMaxPool2D([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true).ApplyT(layer1, train).ApplyT(layer2, train).ApplyT(layer3, train).ApplyT(layer4, train).MustAdaptiveAvgPool2D([]int64{1, 1}).FlatView().ApplyOpt(ts.WithModule(fc))
|
||||
return nn.NewFuncT(func(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
||||
c1 := xs.Apply(conv1)
|
||||
bn1 := c1.ApplyT(bn1, train)
|
||||
c1.MustDrop()
|
||||
relu := bn1.MustRelu(true)
|
||||
maxpool := relu.MustMaxPool2D([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
|
||||
l1 := maxpool.ApplyT(layer1, train)
|
||||
l2 := l1.ApplyT(layer2, train)
|
||||
l1.MustDrop()
|
||||
l3 := l2.ApplyT(layer3, train)
|
||||
l2.MustDrop()
|
||||
l4 := l3.ApplyT(layer4, train)
|
||||
l3.MustDrop()
|
||||
avgpool := l4.MustAdaptiveAvgPool2D([]int64{1, 1})
|
||||
l4.MustDrop()
|
||||
fv := avgpool.FlatView()
|
||||
avgpool.MustDrop()
|
||||
|
||||
// final layer
|
||||
linearConfig := nn.DefaultLinearConfig()
|
||||
fc := nn.NewLinear(path.Sub("fc"), 512, nclasses, *linearConfig)
|
||||
|
||||
retVal = fv.ApplyOpt(ts.WithModule(fc))
|
||||
|
||||
return retVal
|
||||
})
|
||||
|
||||
} else {
|
||||
// No final layer
|
||||
return nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
||||
return xs.Apply(conv1).ApplyT(bn1, train).MustRelu(false).MustMaxPool2D([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true).ApplyT(layer1, train).ApplyT(layer2, train).ApplyT(layer3, train).ApplyT(layer4, train).MustAdaptiveAvgPool2D([]int64{1, 1}).FlatView()
|
||||
return nn.NewFuncT(func(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
||||
c1 := xs.Apply(conv1)
|
||||
bn1 := c1.ApplyT(bn1, train)
|
||||
c1.MustDrop()
|
||||
relu := bn1.MustRelu(true)
|
||||
maxpool := relu.MustMaxPool2D([]int64{3, 3}, []int64{2, 2}, []int64{1, 1}, []int64{1, 1}, false, true)
|
||||
l1 := maxpool.ApplyT(layer1, train)
|
||||
l2 := l1.ApplyT(layer2, train)
|
||||
l1.MustDrop()
|
||||
l3 := l2.ApplyT(layer3, train)
|
||||
l2.MustDrop()
|
||||
l4 := l3.ApplyT(layer4, train)
|
||||
l3.MustDrop()
|
||||
avgpool := l4.MustAdaptiveAvgPool2D([]int64{1, 1})
|
||||
l4.MustDrop()
|
||||
retVal = avgpool.FlatView()
|
||||
avgpool.MustDrop()
|
||||
|
||||
return retVal
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user