removed CleanUp inside ts.NoGrad()
This commit is contained in:
parent
f45f0a7ed0
commit
c095d9f7f8
|
@ -123,8 +123,7 @@ func runCNN1() {
|
|||
if testAccuracy > bestAccuracy {
|
||||
bestAccuracy = testAccuracy
|
||||
}
|
||||
|
||||
}, 3000) // Sleep time in milliseconds. Can be adjusted to available GPU RAMs. Reduce time if having more GPU RAM.
|
||||
})
|
||||
}
|
||||
|
||||
fmt.Printf("Best test accuracy: %.2f%%\n", bestAccuracy*100.0)
|
||||
|
|
|
@ -48,11 +48,13 @@ func runLinear() {
|
|||
ts.NoGrad(func() {
|
||||
ws.Add_(ws.MustGrad(false).MustMulScalar(ts.FloatScalar(-1.0), true))
|
||||
bs.Add_(bs.MustGrad(false).MustMulScalar(ts.FloatScalar(-1.0), true))
|
||||
}, 100) // 100 msec sleeping time. Adjustable to available GPU RAM.
|
||||
ts.CleanUp(100)
|
||||
})
|
||||
|
||||
testLogits := testImages.MustMm(ws, false).MustAdd(bs, true)
|
||||
testAccuracy := testLogits.MustArgmax([]int64{-1}, false, true).MustEqTensor(testLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float, true).MustView([]int64{-1}, true).MustFloat64Value([]int64{0})
|
||||
|
||||
fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Float64Values()[0], testAccuracy*100)
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1176,9 +1176,7 @@ func MustGradSetEnabled(b bool) bool {
|
|||
}
|
||||
|
||||
// NoGrad runs a closure without keeping track of gradients.
|
||||
func NoGrad(fn func(), sleepTimeOpt ...int) {
|
||||
CleanUp(sleepTimeOpt...)
|
||||
|
||||
func NoGrad(fn func()) {
|
||||
// Switch off Grad
|
||||
MustGradSetEnabled(false)
|
||||
|
||||
|
@ -1186,8 +1184,6 @@ func NoGrad(fn func(), sleepTimeOpt ...int) {
|
|||
|
||||
// Switch on Grad
|
||||
MustGradSetEnabled(true)
|
||||
|
||||
CleanUp(sleepTimeOpt...)
|
||||
}
|
||||
|
||||
func NoGrad1(fn func() interface{}) interface{} {
|
||||
|
|
Loading…
Reference in New Issue
Block a user