fix(tensor/module): BatchAccuracyForLogits - Replace NoGradGuard with SetGradEnabled
This commit is contained in:
parent
929e098adb
commit
18873a8957
|
@ -154,9 +154,9 @@ func main() {
|
|||
loss.MustDrop()
|
||||
}
|
||||
|
||||
// testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||
// fmt.Printf("Epoch:\t %v\t Loss: \t %.3f \tAcc: %10.2f%%\n", epoch, lossVal, testAcc*100.0)
|
||||
fmt.Printf("Epoch: %10.0d\tLoss:%10.3f\n", epoch, lossVal)
|
||||
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||
fmt.Printf("Epoch:\t %v\t Loss: \t %.3f \tAcc: %10.2f%%\n", epoch, lossVal, testAcc*100.0)
|
||||
// fmt.Printf("Epoch: %10.0d\tLoss:%10.3f\n", epoch, lossVal)
|
||||
iter.Drop()
|
||||
|
||||
/*
|
||||
|
|
|
@ -63,9 +63,7 @@ func BatchAccuracyForLogits(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize
|
|||
sampleCount float64 = 0.0
|
||||
)
|
||||
|
||||
noGradGuard := NewNoGradGuard()
|
||||
|
||||
defer noGradGuard.Drop()
|
||||
_ = MustGradSetEnabled(false)
|
||||
|
||||
iter2 := MustNewIter2(xs, ys, int64(batchSize))
|
||||
for {
|
||||
|
@ -88,6 +86,8 @@ func BatchAccuracyForLogits(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize
|
|||
acc.MustDrop()
|
||||
}
|
||||
|
||||
_ = MustGradSetEnabled(true)
|
||||
|
||||
return sumAccuracy / sampleCount
|
||||
|
||||
}
|
||||
|
|
|
@ -958,6 +958,13 @@ func NoGrad1(fn func() interface{}) (retVal interface{}) {
|
|||
}
|
||||
|
||||
// NoGradGuard is a RAII guard that prevents gradient tracking until deallocated.
|
||||
// It actually sets a global flag that is checked by the backend whenever an op is done on a variable.
|
||||
// The guard itself saved the current status and set it to false in the constructor.
|
||||
// And restore the saved status in it’s destructor.
|
||||
// That way it is similar to a with torch.no_grad(): block in python.
|
||||
// Ref. https://discuss.pytorch.org/t/how-does-nogradguard-works-in-cpp/34960/2
|
||||
//
|
||||
// TODO: should we implement Go `mutex` here???
|
||||
type NoGradGuard struct {
|
||||
enabled bool
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user