Fix(nn/varstore): temporary fix for Add: do not add new variables to TrainableVariable
WIP(example/mnist): NN example works but keeps accumulating memory; tensors need to be freed after each epoch
This commit is contained in:
parent 3f569fdcee
commit a0faf0799d
@@ -16,7 +16,7 @@ const (
 	LabelNN    int64  = 10
 	MnistDirNN string = "../../data/mnist"
 
-	epochsNN    = 3
+	epochsNN    = 50
 	batchSizeNN = 256
 
 	LrNN = 1e-3
@@ -51,19 +51,16 @@ func runNN() {
 		log.Fatal(err)
 	}
 
 	bsClone := l.Bs.MustShallowClone()
 
 	for epoch := 0; epoch < epochsNN; epoch++ {
 
 		loss := net.Forward(ds.TrainImages).CrossEntropyForLogits(ds.TrainLabels)
 
 		opt.BackwardStep(loss)
 
 		fmt.Printf("Bs vals: %v\n", bsClone.MustToString(int64(1)))
 
 		lossVal := loss.MustShallowClone().MustView([]int64{-1}).MustFloat64Value([]int64{0})
 
 		testAccuracy := net.Forward(ds.TestImages).AccuracyForLogits(ds.TestLabels).MustView([]int64{-1}).MustFloat64Value([]int64{0})
 
 		fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, lossVal, testAccuracy*100)
 
 		fmt.Printf("Loss: %v\n", lossVal)
 	}
 
 }
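The commit message's memory note points at this loop: every epoch's Forward and CrossEntropyForLogits calls allocate fresh C-side tensors that are never released, so usage grows with epochsNN. A minimal sketch of the per-epoch freeing the message calls for, assuming the tensor type exposes MustDrop to release the underlying C memory (the named logits variable is illustrative, not part of this commit):

	for epoch := 0; epoch < epochsNN; epoch++ {
		// Each forward pass allocates fresh C-side tensors.
		logits := net.Forward(ds.TrainImages)
		loss := logits.CrossEntropyForLogits(ds.TrainLabels)

		opt.BackwardStep(loss)

		// Free this epoch's tensors so memory stays flat across epochs.
		logits.MustDrop()
		loss.MustDrop()
	}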
@@ -24,25 +24,28 @@ func testOptimizer() {
 
 	vs := nn.NewVarStore(gotch.CPU)
 
-	opt, err := nn.DefaultSGDConfig().Build(vs, 1e-2)
-	if err != nil {
-		log.Fatal("Failed building SGD optimizer")
-	}
-
 	cfg := nn.LinearConfig{
 		WsInit: nn.NewConstInit(0.001),
 		BsInit: nn.NewConstInit(0.001),
 		Bias:   true,
 	}
 
 	// fmt.Printf("Number of trainable variables: %v\n", vs.Len())
 	linear := nn.NewLinear(vs.Root(), 1, 1, cfg)
 	// fmt.Printf("Trainable variables at app: %v\n", vs.TrainableVariable())
 
 	loss := xs.Apply(linear).MustMseLoss(ys, ts.ReductionMean.ToInt())
 	initialLoss := loss.MustView([]int64{-1}).MustFloat64Value([]int64{0})
 	fmt.Printf("Initial Loss: %.3f\n", initialLoss)
 
+	opt, err := nn.DefaultSGDConfig().Build(vs, 1e-2)
+	if err != nil {
+		log.Fatal("Failed building SGD optimizer")
+	}
 
 	for i := 0; i < 50; i++ {
-		loss = xs.Apply(linear)
+		// loss = linear.Forward(xs)
+		// loss = xs.Apply(linear)
+		loss = linear.Forward(xs)
 		loss = loss.MustMseLoss(ys, ts.ReductionMean.ToInt())
 
 		fmt.Printf("Loss: %.3f\n", loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}))
 
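With the temporary fix, defaultBuild (below) snapshots whatever variables the store holds at Build time, which is why the SGD build moves under nn.NewLinear here. A sketch of the ordering this imposes, reusing the names from the hunk above:

	vs := nn.NewVarStore(gotch.CPU)
	linear := nn.NewLinear(vs.Root(), 1, 1, cfg) // creates the Ws/Bs variables in vs

	// Build the optimizer only after the model exists,
	// so the snapshot of vs contains its parameters.
	opt, err := nn.DefaultSGDConfig().Build(vs, 1e-2)
	if err != nil {
		log.Fatal("Failed building SGD optimizer")
	}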
@@ -3,7 +3,6 @@ package nn
 // Optimizers to be used for gradient-descent based training.
 
 import (
-	// "github.com/sugarme/gotch"
 	"log"
 
 	ts "github.com/sugarme/gotch/tensor"
@@ -41,13 +40,18 @@ func defaultBuild(config OptimizerConfig, vs VarStore, lr float64) (retVal Optimizer, err error) {
 		return retVal, err
 	}
 
-	vs.variables.mutex.Lock()
-	defer vs.variables.mutex.Unlock()
+	// vs.variables.mutex.Lock()
+	// defer vs.variables.mutex.Unlock()
 
-	if len(vs.variables.TrainableVariable) > 0 {
-		if err = opt.AddParameters(vs.variables.TrainableVariable); err != nil {
-			return retVal, err
-		}
+	// fmt.Printf("Trainable Variables: \n:%v", len(vs.Variables()))
+	var parameters []ts.Tensor
+	for _, v := range vs.Variables() {
+		parameters = append(parameters, v)
+	}
+
+	// if err = opt.AddParameters(vs.variables.TrainableVariable); err != nil {
+	if err = opt.AddParameters(parameters); err != nil {
+		return retVal, err
 	}
 
 	return Optimizer{
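Note what the workaround changes: instead of registering only vs.variables.TrainableVariable, every tensor in vs.Variables() is handed to the C optimizer, trainable or not. A hypothetical refinement that keeps the same collection loop but filters for trainable tensors; MustRequiresGrad is an assumption here, not something this commit uses:

	var parameters []ts.Tensor
	for _, v := range vs.Variables() {
		// Hypothetical filter: only pass tensors that track gradients.
		if v.MustRequiresGrad() {
			parameters = append(parameters, v)
		}
	}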
@@ -224,6 +228,7 @@ func (opt *Optimizer) Step() {
 func (opt *Optimizer) BackwardStep(loss ts.Tensor) {
 
+	opt.addMissingVariables()
 
 	err := opt.opt.ZeroGrad()
 	if err != nil {
 		log.Fatalf("Optimizer - BackwardStep method call - ZeroGrad error: %v\n", err)
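The new addMissingVariables call is what lets BackwardStep pick up variables created after the optimizer was built. The commit does not show the method's body; a plausible sketch, assuming the Optimizer keeps a reference to the store's variables and a count of those already registered (both field names are assumptions):

	func (opt *Optimizer) addMissingVariables() {
		// Assumed fields: opt.variables mirrors the VarStore's variables,
		// opt.variablesInOptimizer counts tensors already registered.
		missing := len(opt.variables.TrainableVariable) - opt.variablesInOptimizer
		if missing > 0 {
			// Register only the variables created since the last Build/step.
			if err := opt.opt.AddParameters(opt.variables.TrainableVariable[opt.variablesInOptimizer:]); err != nil {
				log.Fatalf("Optimizer - addMissingVariables error: %v\n", err)
			}
			opt.variablesInOptimizer = len(opt.variables.TrainableVariable)
		}
	}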