WIP(fix BatchAccuracyForLogits)
This commit is contained in:
parent
c265ba007e
commit
290421a30f
|
@ -118,6 +118,8 @@ func main() {
|
|||
var lossVal float64
|
||||
startTime := time.Now()
|
||||
|
||||
var bestAccuracy float64
|
||||
|
||||
for epoch := 0; epoch < 150; epoch++ {
|
||||
// opt.SetLR(learningRate(epoch))
|
||||
optConfig := nn.NewSGDConfig(0.9, 0.0, 5e-4, true)
|
||||
|
@ -169,9 +171,14 @@ func main() {
|
|||
loss.MustDrop()
|
||||
}
|
||||
|
||||
// testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||
// fmt.Printf("Epoch:\t %v\t Loss: \t %.3f \tAcc: %10.2f%%\n", epoch, lossVal, testAcc*100.0)
|
||||
fmt.Printf("Epoch: %10.0d\tLoss:%10.3f\n", epoch, lossVal)
|
||||
vs.Freeze()
|
||||
testAcc := batchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||
vs.Unfreeze()
|
||||
fmt.Printf("Epoch:\t %v\t Loss: \t %.3f \tAcc: %10.2f%%\n", epoch, lossVal, testAcc*100.0)
|
||||
// fmt.Printf("Epoch: %10.0d\tLoss:%10.3f\n", epoch, lossVal)
|
||||
if testAcc > bestAccuracy {
|
||||
bestAccuracy = testAcc
|
||||
}
|
||||
iter.Drop()
|
||||
|
||||
/*
|
||||
|
@ -188,7 +195,39 @@ func main() {
|
|||
* */
|
||||
}
|
||||
|
||||
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||
fmt.Printf("Accuracy: %10.2f%%\n", testAcc*100.0)
|
||||
// testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||
fmt.Printf("Best Accuracy: %10.2f%%\n", bestAccuracy*100.0)
|
||||
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
|
||||
}
|
||||
|
||||
func batchAccuracyForLogits(m ts.ModuleT, xs, ys ts.Tensor, d gotch.Device, batchSize int) (retVal float64) {
|
||||
|
||||
var (
|
||||
sumAccuracy float64 = 0.0
|
||||
sampleCount float64 = 0.0
|
||||
)
|
||||
|
||||
iter2 := ts.MustNewIter2(xs, ys, int64(batchSize))
|
||||
for {
|
||||
item, ok := iter2.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
size := float64(item.Data.MustSize()[0])
|
||||
bImages := item.Data.MustTo(d, true)
|
||||
bLabels := item.Label.MustTo(d, true)
|
||||
|
||||
logits := m.ForwardT(bImages, false)
|
||||
acc := logits.AccuracyForLogits(bLabels)
|
||||
sumAccuracy += acc.Values()[0] * size
|
||||
sampleCount += size
|
||||
|
||||
bImages.MustDrop()
|
||||
bLabels.MustDrop()
|
||||
acc.MustDrop()
|
||||
}
|
||||
|
||||
return sumAccuracy / sampleCount
|
||||
|
||||
}
|
||||
|
|
|
@ -125,6 +125,8 @@ func runCNN1() {
|
|||
logits := net.ForwardT(bImages, true)
|
||||
loss := logits.CrossEntropyForLogits(bLabels)
|
||||
|
||||
// loss = loss.MustSetRequiresGrad(true)
|
||||
|
||||
opt.BackwardStep(loss)
|
||||
|
||||
epocLoss = loss.MustShallowClone()
|
||||
|
@ -138,10 +140,12 @@ func runCNN1() {
|
|||
// loss.MustDrop()
|
||||
}
|
||||
|
||||
// testAccuracy := ts.BatchAccuracyForLogitsIdx(net, testImages, testLabels, vs.Device(), 1024)
|
||||
// fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Values()[0], testAccuracy*100)
|
||||
vs.Freeze()
|
||||
testAccuracy := batchAccuracyForLogits(net, testImages, testLabels, vs.Device(), 1024)
|
||||
vs.Unfreeze()
|
||||
fmt.Printf("Epoch: %v\t Loss: %.2f \t Test accuracy: %.2f%%\n", epoch, epocLoss.Values()[0], testAccuracy*100.0)
|
||||
|
||||
fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, epocLoss.Values()[0])
|
||||
// fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, epocLoss.Values()[0])
|
||||
epocLoss.MustDrop()
|
||||
imagesTs.MustDrop()
|
||||
labelsTs.MustDrop()
|
||||
|
@ -197,13 +201,47 @@ func runCNN2() {
|
|||
loss.MustDrop()
|
||||
}
|
||||
|
||||
fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, lossVal)
|
||||
// fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, lossVal)
|
||||
|
||||
// testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
|
||||
// fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\t Accuracy: %.2f\n", epoch, lossVal, testAcc*100)
|
||||
vs.Freeze()
|
||||
testAcc := batchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
|
||||
vs.Unfreeze()
|
||||
fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\t Accuracy: %.2f\n", epoch, lossVal, testAcc*100.0)
|
||||
}
|
||||
|
||||
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
|
||||
fmt.Printf("Loss: \t %.2f\t Accuracy: %.2f\n", lossVal, testAcc*100)
|
||||
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
|
||||
}
|
||||
|
||||
func batchAccuracyForLogits(m ts.ModuleT, xs, ys ts.Tensor, d gotch.Device, batchSize int) (retVal float64) {
|
||||
|
||||
var (
|
||||
sumAccuracy float64 = 0.0
|
||||
sampleCount float64 = 0.0
|
||||
)
|
||||
|
||||
iter2 := ts.MustNewIter2(xs, ys, int64(batchSize))
|
||||
for {
|
||||
item, ok := iter2.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
size := float64(item.Data.MustSize()[0])
|
||||
bImages := item.Data.MustTo(d, true)
|
||||
bLabels := item.Label.MustTo(d, true)
|
||||
|
||||
logits := m.ForwardT(bImages, false)
|
||||
acc := logits.AccuracyForLogits(bLabels)
|
||||
sumAccuracy += acc.Values()[0] * size
|
||||
sampleCount += size
|
||||
|
||||
bImages.MustDrop()
|
||||
bLabels.MustDrop()
|
||||
acc.MustDrop()
|
||||
}
|
||||
|
||||
return sumAccuracy / sampleCount
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user