WIP(example/mnis): NoGrad working with anonymous function. Backward() still not working
This commit is contained in:
parent
cf61333fab
commit
da950fe881
|
@ -39,16 +39,12 @@ func runLinear() {
|
|||
bs.ZeroGrad()
|
||||
loss.Backward()
|
||||
|
||||
wsGrad := ws.MustGrad().MustMul1(ts.FloatScalar(-1.0))
|
||||
bsGrad := bs.MustGrad().MustMul1(ts.FloatScalar(-1.0))
|
||||
ts.NoGrad(func() {
|
||||
ws.MustAdd_(ws.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
bs.MustAdd_(bs.MustGrad().MustMul1(ts.FloatScalar(-1.0)))
|
||||
})
|
||||
|
||||
wsClone := ws.MustShallowClone()
|
||||
bsClone := bs.MustShallowClone()
|
||||
|
||||
// wsClone.MustAdd_(wsGrad)
|
||||
// bsClone.MustAdd_(bsGrad)
|
||||
|
||||
testLogits := ds.TestImages.MustMm(wsClone.MustAdd(wsGrad)).MustAdd(bsClone.MustAdd(bsGrad))
|
||||
testLogits := ds.TestImages.MustMm(ws).MustAdd(bs)
|
||||
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
|
||||
fmt.Printf("Epoch: %v - Train loss: %v - Test accuracy: %v\n", epoch, loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}), testAccuracy*100)
|
||||
|
|
Loading…
Reference in New Issue
Block a user