chore(example/cifar): clean up

This commit is contained in:
sugarme 2020-07-09 14:15:02 +10:00
parent 9666d1c44e
commit d872a8e307

View File

@ -83,11 +83,11 @@ func fastResnet(p nn.Path) (retVal nn.SequentialT) {
func learningRate(epoch int) (retVal float64) {
switch {
case epoch < 50:
return 0.4
return float64(0.1)
case epoch < 100:
return 0.2
return float64(0.01)
default:
return 0.1
return float64(0.001)
}
}
@ -101,13 +101,6 @@ func main() {
fmt.Printf("TestLabel shape: %v\n", ds.TestLabels.MustSize())
fmt.Printf("Number of labels: %v\n", ds.Labels)
var si *gotch.SI
si = gotch.GetSysInfo()
fmt.Printf("Total RAM (MB):\t %8.2f\n", float64(si.TotalRam)/1024)
fmt.Printf("Used RAM (MB):\t %8.2f\n", float64(si.TotalRam-si.FreeRam)/1024)
startRAM := si.TotalRam - si.FreeRam
cuda := gotch.CudaBuilder(0)
device := cuda.CudaIfAvailable()
// device := gotch.CPU
@ -117,7 +110,7 @@ func main() {
net := fastResnet(vs.Root())
optConfig := nn.NewSGDConfig(0.9, 0.0, 5e-4, true)
opt, err := optConfig.Build(vs, 0.2)
opt, err := optConfig.Build(vs, 0.1)
if err != nil {
log.Fatal(err)
}
@ -125,11 +118,11 @@ func main() {
var lossVal float64
startTime := time.Now()
for epoch := 0; epoch < 24; epoch++ {
// opt.SetLR(learningRate(epoch))
for epoch := 0; epoch < 150; epoch++ {
opt.SetLR(learningRate(epoch))
// iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(512))
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
// iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(512))
iter.Shuffle()
// iter = iter.ToDevice(device)
@ -161,10 +154,9 @@ func main() {
loss.MustDrop()
}
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
si = gotch.GetSysInfo()
memUsed := (float64(si.TotalRam-si.FreeRam) - float64(startRAM)) / 1024
fmt.Printf("Epoch:\t %v\t Memory Used:\t [%8.2f MiB]\tLoss: \t %.3f \tAcc: %10.2f%%\n", epoch, memUsed, lossVal, testAcc*100.0)
// testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
// fmt.Printf("Epoch:\t %v\t Memory Used:\t [%8.2f MiB]\tLoss: \t %.3f \tAcc: %10.2f%%\n", epoch, memUsed, lossVal, testAcc*100.0)
fmt.Printf("Epoch:\t %v\t \tLoss: \t\t %.3f\n", epoch, lossVal)
iter.Drop()
// Print out GPU used