diff --git a/example/mnist/linear.go b/example/mnist/linear.go
index 589a4cc..dd1d25d 100644
--- a/example/mnist/linear.go
+++ b/example/mnist/linear.go
@@ -39,16 +39,12 @@ func runLinear() {
 		bs.ZeroGrad()
 		loss.Backward()
 
-		wsGrad := ws.MustGrad().MustMul1(ts.FloatScalar(-1.0))
-		bsGrad := bs.MustGrad().MustMul1(ts.FloatScalar(-1.0))
+		ts.NoGrad(func() {
+			ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
+			bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
+		})
 
-		wsClone := ws.MustShallowClone()
-		bsClone := bs.MustShallowClone()
-
-		// wsClone.MustAdd_(wsGrad)
-		// bsClone.MustAdd_(bsGrad)
-
-		testLogits := ds.TestImages.MustMm(wsClone.MustAdd(wsGrad)).MustAdd(bsClone.MustAdd(bsGrad))
+		testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
 
 		testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
 		fmt.Printf("Epoch: %v - Train loss: %v - Test accuracy: %v\n", epoch, loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}), testAccuracy*100)
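
The change replaces the clone-and-add workaround with a direct in-place SGD step: wrapping the mutation of ws and bs in ts.NoGrad keeps the parameter update out of autograd tracking, after which the test pass can read the live tensors directly (ds.TestImages.MustMm(ws).MustAdd(bs)) instead of rebuilding them from shallow clones. A minimal sketch of the same pattern, using only the calls visible in the patch, with a hypothetical lr learning-rate variable in place of the hard-coded -1.0 factor:

    // Sketch only, not part of the patch: the in-place SGD step from the
    // diff, with a hypothetical learning rate lr instead of -1.0.
    lr := 0.1
    ts.NoGrad(func() {
    	// param <- param - lr * grad, mutated in place; running the
    	// update inside NoGrad keeps it out of the autograd graph.
    	ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-lr)))
    	bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-lr)))
    })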