change(example/cifar): new set of hyperparameters

This commit is contained in:
sugarme 2020-07-11 10:12:09 +10:00
parent 64c1f378c5
commit 44ef7776e5

View File

@ -126,17 +126,17 @@ func main() {
var opt nn.Optimizer
var err error
switch {
case epoch < 50:
case epoch < 150:
opt, err = optConfig.Build(vs, 0.1)
if err != nil {
log.Fatal(err)
}
case epoch < 100:
case epoch < 250:
opt, err = optConfig.Build(vs, 0.01)
if err != nil {
log.Fatal(err)
}
case epoch >= 100:
case epoch >= 250:
opt, err = optConfig.Build(vs, 0.001)
if err != nil {
log.Fatal(err)
@ -144,7 +144,7 @@ func main() {
}
// iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(512))
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(128))
iter.Shuffle()
// iter = iter.ToDevice(device)
@ -172,7 +172,7 @@ func main() {
}
vs.Freeze()
testAcc := batchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
testAcc := batchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 100)
vs.Unfreeze()
fmt.Printf("Epoch:\t %v\t Loss: \t %.3f \tAcc: %10.2f%%\n", epoch, lossVal, testAcc*100.0)
// fmt.Printf("Epoch: %10.0d\tLoss:%10.3f\n", epoch, lossVal)