fix(tensor/module): BatchAccuracyForLogits - Replace NoGradGuard with SetGradEnabled

sugarme 2020-07-10 18:37:07 +10:00
parent 929e098adb
commit 18873a8957
3 changed files with 13 additions and 6 deletions
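In short: BatchAccuracyForLogits used to suppress gradient tracking with a RAII-style NoGradGuard and a deferred Drop; it now toggles the global grad mode explicitly via MustGradSetEnabled. A minimal sketch of the two patterns, assuming the gotch tensor API exactly as it appears in the hunks below (the wrapper functions are illustrative, not repo code):

	package main

	import ts "github.com/sugarme/gotch/tensor"

	// Before: RAII-style guard; the deferred Drop restores the saved grad mode.
	func evalWithGuard(run func()) {
		noGradGuard := ts.NewNoGradGuard()
		defer noGradGuard.Drop()
		run() // forward passes here are not tracked by autograd
	}

	// After: explicit toggle; gradient tracking is switched back on by hand.
	func evalWithToggle(run func()) {
		_ = ts.MustGradSetEnabled(false)
		run() // forward passes here are not tracked by autograd
		_ = ts.MustGradSetEnabled(true)
	}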

View File

@@ -154,9 +154,9 @@ func main() {
 		loss.MustDrop()
 	}
-	// testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
-	// fmt.Printf("Epoch:\t %v\t Loss: \t %.3f \tAcc: %10.2f%%\n", epoch, lossVal, testAcc*100.0)
-	fmt.Printf("Epoch: %10.0d\tLoss:%10.3f\n", epoch, lossVal)
+	testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
+	fmt.Printf("Epoch:\t %v\t Loss: \t %.3f \tAcc: %10.2f%%\n", epoch, lossVal, testAcc*100.0)
+	// fmt.Printf("Epoch: %10.0d\tLoss:%10.3f\n", epoch, lossVal)
 	iter.Drop()
 	/*

View File

@@ -63,9 +63,7 @@ func BatchAccuracyForLogits(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize
 		sampleCount float64 = 0.0
 	)
-	noGradGuard := NewNoGradGuard()
-	defer noGradGuard.Drop()
+	_ = MustGradSetEnabled(false)
 	iter2 := MustNewIter2(xs, ys, int64(batchSize))
 	for {
@@ -88,6 +86,8 @@ func BatchAccuracyForLogits(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize
 		acc.MustDrop()
 	}
+	_ = MustGradSetEnabled(true)
 	return sumAccuracy / sampleCount
 }
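One behavioral difference worth noting: the old guard restored whatever grad mode was active before the call, while the new code unconditionally re-enables gradients with MustGradSetEnabled(true). The discarded `_ =` results suggest the function returns the previous mode; assuming that (this diff does not confirm it), a state-restoring variant would be:

	prev := MustGradSetEnabled(false) // assumed to return the prior grad mode
	defer MustGradSetEnabled(prev)    // put the original mode back on exit
	// ... batched forward passes ...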

View File

@@ -958,6 +958,13 @@ func NoGrad1(fn func() interface{}) (retVal interface{}) {
 }
+// NoGradGuard is a RAII guard that prevents gradient tracking until deallocated.
+// It actually sets a global flag that is checked by the backend whenever an op is performed on a variable.
+// The guard saves the current status and sets it to false in its constructor, and restores the saved status in its destructor.
+// That way it is similar to a `with torch.no_grad():` block in Python.
+// Ref. https://discuss.pytorch.org/t/how-does-nogradguard-works-in-cpp/34960/2
+//
+// TODO: should we implement a Go mutex here?
 type NoGradGuard struct {
 	enabled bool
 }
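For completeness, a usage sketch of the guard documented above, assuming only the NewNoGradGuard constructor and Drop method visible in this diff; it mirrors a Python `with torch.no_grad():` block:

	guard := NewNoGradGuard() // saves the current grad mode and disables tracking
	// ... inference; ops on variables are not recorded by autograd ...
	guard.Drop() // restores the saved grad mode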