From c095d9f7f8c665d6db322d49cbe19d8090032b36 Mon Sep 17 00:00:00 2001 From: sugarme Date: Thu, 6 Jul 2023 00:20:11 +1000 Subject: [PATCH] removed CleanUp inside ts.NoGrad() --- example/mnist/cnn.go | 3 +-- example/mnist/linear.go | 4 +++- ts/tensor.go | 6 +----- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/example/mnist/cnn.go b/example/mnist/cnn.go index 692f790..5a0d505 100644 --- a/example/mnist/cnn.go +++ b/example/mnist/cnn.go @@ -123,8 +123,7 @@ func runCNN1() { if testAccuracy > bestAccuracy { bestAccuracy = testAccuracy } - - }, 3000) // Sleep time in milliseconds. Can be adjusted to available GPU RAMs. Reduce time if having more GPU RAM. + }) } fmt.Printf("Best test accuracy: %.2f%%\n", bestAccuracy*100.0) diff --git a/example/mnist/linear.go b/example/mnist/linear.go index 8a89ce7..0c10f88 100644 --- a/example/mnist/linear.go +++ b/example/mnist/linear.go @@ -48,11 +48,13 @@ func runLinear() { ts.NoGrad(func() { ws.Add_(ws.MustGrad(false).MustMulScalar(ts.FloatScalar(-1.0), true)) bs.Add_(bs.MustGrad(false).MustMulScalar(ts.FloatScalar(-1.0), true)) - }, 100) // 100 msec sleeping time. Adjustable to available GPU RAM. + ts.CleanUp(100) + }) testLogits := testImages.MustMm(ws, false).MustAdd(bs, true) testAccuracy := testLogits.MustArgmax([]int64{-1}, false, true).MustEqTensor(testLabels, true).MustTotype(gotch.Float, true).MustMean(gotch.Float, true).MustView([]int64{-1}, true).MustFloat64Value([]int64{0}) fmt.Printf("Epoch: %v - Loss: %.3f - Test accuracy: %.2f%%\n", epoch, loss.Float64Values()[0], testAccuracy*100) + } } diff --git a/ts/tensor.go b/ts/tensor.go index 23c0b9a..bc8c67c 100644 --- a/ts/tensor.go +++ b/ts/tensor.go @@ -1176,9 +1176,7 @@ func MustGradSetEnabled(b bool) bool { } // NoGrad runs a closure without keeping track of gradients. -func NoGrad(fn func(), sleepTimeOpt ...int) { - CleanUp(sleepTimeOpt...) - +func NoGrad(fn func()) { // Switch off Grad MustGradSetEnabled(false) @@ -1186,8 +1184,6 @@ func NoGrad(fn func(), sleepTimeOpt ...int) { // Switch on Grad MustGradSetEnabled(true) - - CleanUp(sleepTimeOpt...) } func NoGrad1(fn func() interface{}) interface{} {