WIP(example/mnis): NoGrad working with anonymous function. Backward() still not working

This commit is contained in:
sugarme 2020-06-16 02:56:09 +10:00
parent cf61333fab
commit da950fe881

View File

@ -39,16 +39,12 @@ func runLinear() {
bs.ZeroGrad()
loss.Backward()
wsGrad := ws.MustGrad().MustMul1(ts.FloatScalar(-1.0))
bsGrad := bs.MustGrad().MustMul1(ts.FloatScalar(-1.0))
ts.NoGrad(func() {
ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
})
wsClone := ws.MustShallowClone()
bsClone := bs.MustShallowClone()
// wsClone.MustAdd_(wsGrad)
// bsClone.MustAdd_(bsGrad)
testLogits := ds.TestImages.MustMm(wsClone.MustAdd(wsGrad)).MustAdd(bsClone.MustAdd(bsGrad))
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
fmt.Printf("Epoch: %v - Train loss: %v - Test accuracy: %v\n", epoch, loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}), testAccuracy*100)