example/cifar: updated
This commit is contained in:
parent
531bda04b5
commit
22333d4544
|
@ -154,9 +154,9 @@ func main() {
|
|||
loss.MustDrop()
|
||||
}
|
||||
|
||||
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||
fmt.Printf("Epoch:\t %v\t Loss: \t %.3f \tAcc: %10.2f%%\n", epoch, lossVal, testAcc*100.0)
|
||||
// fmt.Printf("Epoch: %10.0d\tLoss:%10.3f\n", epoch, lossVal)
|
||||
// testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||
// fmt.Printf("Epoch:\t %v\t Loss: \t %.3f \tAcc: %10.2f%%\n", epoch, lossVal, testAcc*100.0)
|
||||
fmt.Printf("Epoch: %10.0d\tLoss:%10.3f\n", epoch, lossVal)
|
||||
iter.Drop()
|
||||
|
||||
/*
|
||||
|
@ -174,6 +174,6 @@ func main() {
|
|||
}
|
||||
|
||||
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||
fmt.Printf("Loss: \t %.3f\t Accuracy: %.2f\n", lossVal, testAcc*100)
|
||||
fmt.Printf("Loss: \t %.3f\t Accuracy: %10.2f%%\n", lossVal, testAcc*100.0)
|
||||
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
|
||||
}
|
||||
|
|
|
@ -64,6 +64,7 @@ func BatchAccuracyForLogits(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize
|
|||
)
|
||||
|
||||
noGradGuard := NewNoGradGuard()
|
||||
|
||||
defer noGradGuard.Drop()
|
||||
|
||||
iter2 := MustNewIter2(xs, ys, int64(batchSize))
|
||||
|
@ -88,6 +89,7 @@ func BatchAccuracyForLogits(m ModuleT, xs, ys Tensor, d gotch.Device, batchSize
|
|||
}
|
||||
|
||||
return sumAccuracy / sampleCount
|
||||
|
||||
}
|
||||
|
||||
// BatchAccuracyForLogitIdx is an alternative of BatchAccuracyForLogits to
|
||||
|
|
Loading…
Reference in New Issue
Block a user