removed CleanUp inside ts.NoGrad()

This commit is contained in:
sugarme 2023-07-06 00:20:11 +10:00
parent f45f0a7ed0
commit c095d9f7f8
3 changed files with 5 additions and 8 deletions

View File

@ -123,8 +123,7 @@ func runCNN1() {
if testAccuracy > bestAccuracy {
bestAccuracy = testAccuracy
}
}, 3000) // Sleep time in milliseconds. Can be adjusted to available GPU RAMs. Reduce time if having more GPU RAM.
})
}
fmt.Printf("Best test accuracy: %.2f%%\n", bestAccuracy*100.0)

View File

@ -48,11 +48,13 @@ func runLinear() {
ts.NoGrad(func() {
ws.Add_(ws.MustGrad(false).MustMulScalar(ts.FloatScalar(-1.0), true))
bs.Add_(bs.MustGrad(false).MustMulScalar(ts.FloatScalar(-1.0), true))
}, 100) // 100 msec sleeping time. Adjustable to available GPU RAM.
ts.CleanUp(100)
})
testLogits := testImages.MustMm(ws, false).MustAdd(bs, true)
testAccuracy := testLogits.MustArgmax([]int64{-1}, false, true).MustEqTensor(testLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float, true).MustView([]int64{-1}, true).MustFloat64Value([]int64{0})
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Float64Values()[0], testAccuracy*100)
}
}

View File

@ -1176,9 +1176,7 @@ func MustGradSetEnabled(b bool) bool {
}
// NoGrad runs a closure without keeping track of gradients.
func NoGrad(fn func(), sleepTimeOpt ...int) {
CleanUp(sleepTimeOpt...)
func NoGrad(fn func()) {
// Switch off Grad
MustGradSetEnabled(false)
@ -1186,8 +1184,6 @@ func NoGrad(fn func(), sleepTimeOpt ...int) {
// Switch on Grad
MustGradSetEnabled(true)
CleanUp(sleepTimeOpt...)
}
func NoGrad1(fn func() interface{}) interface{} {