Fix(nn/varstore): temporary fix for add not adding new variables to TrainableVariable

WIP(example/mnist): nn working but takes up memory. Need to free tensors after each epoch
sugarme 2020-06-19 20:22:51 +10:00
parent 3f569fdcee
commit a0faf0799d
3 changed files with 27 additions and 22 deletions
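
For context, a minimal sketch of the bug this commit works around, using only the gotch API as it appears in the diffs below: Build snapshots the var store's trainable variables at call time, so a layer created after the optimizer is built never reaches it.

    vs := nn.NewVarStore(gotch.CPU)

    // Build copies the trainable variables that exist right now: none yet.
    opt, err := nn.DefaultSGDConfig().Build(vs, 1e-2)
    if err != nil {
        log.Fatal(err)
    }

    // NewLinear registers Ws/Bs in vs AFTER the snapshot, so the optimizer
    // never receives them and BackwardStep leaves the layer untrained.
    cfg := nn.LinearConfig{
        WsInit: nn.NewConstInit(0.001),
        BsInit: nn.NewConstInit(0.001),
        Bias:   true,
    }
    linear := nn.NewLinear(vs.Root(), 1, 1, cfg)
    _ = linear

    // The temporary fix below makes defaultBuild walk vs.Variables()
    // instead, and has BackwardStep re-add variables it missed.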


@@ -16,7 +16,7 @@ const (
LabelNN int64 = 10
MnistDirNN string = "../../data/mnist"
epochsNN = 3
epochsNN = 50
batchSizeNN = 256
LrNN = 1e-3
@@ -51,19 +51,16 @@ func runNN() {
log.Fatal(err)
}
bsClone := l.Bs.MustShallowClone()
for epoch := 0; epoch < epochsNN; epoch++ {
loss := net.Forward(ds.TrainImages).CrossEntropyForLogits(ds.TrainLabels)
opt.BackwardStep(loss)
fmt.Printf("Bs vals: %v\n", bsClone.MustToString(int64(1)))
lossVal := loss.MustShallowClone().MustView([]int64{-1}).MustFloat64Value([]int64{0})
testAccuracy := net.Forward(ds.TestImages).AccuracyForLogits(ds.TestLabels).MustView([]int64{-1}).MustFloat64Value([]int64{0})
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100)
fmt.Printf("Loss: %v\n", lossVal)
}
}
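
On the WIP note above: gotch tensors wrap C-allocated storage, so the intermediates created on every pass have to be freed by hand or the loop's memory grows each epoch. A hedged sketch of one way to do that, assuming MustDrop is the call that releases a tensor's storage (not confirmed by this diff):

    for epoch := 0; epoch < epochsNN; epoch++ {
        logits := net.Forward(ds.TrainImages)
        loss := logits.CrossEntropyForLogits(ds.TrainLabels)
        opt.BackwardStep(loss)

        // Release per-epoch intermediates once they are consumed;
        // substitute whatever free/drop call the tensor API exposes.
        logits.MustDrop()
        loss.MustDrop()
    }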


@@ -24,25 +24,28 @@ func testOptimizer() {
vs := nn.NewVarStore(gotch.CPU)
opt, err := nn.DefaultSGDConfig().Build(vs, 1e-2)
if err != nil {
log.Fatal("Failed building SGD optimizer")
}
cfg := nn.LinearConfig{
WsInit: nn.NewConstInit(0.001),
BsInit: nn.NewConstInit(0.001),
Bias: true,
}
// fmt.Printf("Number of trainable variables: %v\n", vs.Len())
linear := nn.NewLinear(vs.Root(), 1, 1, cfg)
// fmt.Printf("Trainable variables at app: %v\n", vs.TrainableVariable())
loss := xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
initialLoss := loss.MustView([]int64{-1}).MustFloat64Value([]int64{0})
fmt.Printf("Initial Loss: %.3f\n", initialLoss)
opt, err := nn.DefaultSGDConfig().Build(vs, 1e-2)
if err != nil {
log.Fatal("Failed building SGD optimizer")
}
for i := 0; i < 50; i++ {
loss = xs.Apply(linear)
// loss = linear.Forward(xs)
// loss = xs.Apply(linear)
loss = linear.Forward(xs)
loss = loss.MustMseLoss(ys, ts.ReductionMean.ToInt())
fmt.Printf("Loss: %.3f\n", loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}))


@ -3,7 +3,6 @@ package nn
// Optimizers to be used for gradient-descent based training.
import (
// "github.com/sugarme/gotch"
"log"
ts "github.com/sugarme/gotch/tensor"
@@ -41,13 +40,18 @@ func defaultBuild(config OptimizerConfig, vs VarStore, lr float64) (retVal Optimizer, err error) {
return retVal, err
}
vs.variables.mutex.Lock()
defer vs.variables.mutex.Unlock()
// vs.variables.mutex.Lock()
// defer vs.variables.mutex.Unlock()
if len(vs.variables.TrainableVariable) > 0 {
if err = opt.AddParameters(vs.variables.TrainableVariable); err != nil {
return retVal, err
}
// fmt.Printf("Trainable Variables: \n:%v", len(vs.Variables()))
var parameters []ts.Tensor
for _, v := range vs.Variables() {
parameters = append(parameters, v)
}
// if err = opt.AddParameters(vs.variables.TrainableVariable); err != nil {
if err = opt.AddParameters(parameters); err != nil {
return retVal, err
}
return Optimizer{
@@ -224,6 +228,7 @@ func (opt *Optimizer) Step() {
func (opt *Optimizer) BackwardStep(loss ts.Tensor) {
opt.addMissingVariables()
err := opt.opt.ZeroGrad()
if err != nil {
log.Fatalf("Optimizer - BackwardStep method call - ZeroGrad error: %v\n", err)
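
BackwardStep now calls addMissingVariables, the runtime half of the workaround, but its body is not shown in this diff. A hypothetical sketch inferred from the defaultBuild change above; the variables and variablesInOptimizer fields are assumed bookkeeping on Optimizer, not confirmed by this commit:

    // Hypothetical: re-sync the C optimizer with variables that were
    // registered in the var store after the optimizer was built.
    func (opt *Optimizer) addMissingVariables() {
        nVars := len(opt.variables.TrainableVariable) // assumed field
        if nVars > opt.variablesInOptimizer {         // assumed int counter
            missing := opt.variables.TrainableVariable[opt.variablesInOptimizer:]
            if err := opt.opt.AddParameters(missing); err != nil {
                log.Fatalf("Optimizer - addMissingVariables - AddParameters error: %v\n", err)
            }
            opt.variablesInOptimizer = nVars
        }
    }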