chore(example/cifar): clean up
This commit is contained in:
parent
9666d1c44e
commit
d872a8e307
|
@ -83,11 +83,11 @@ func fastResnet(p nn.Path) (retVal nn.SequentialT) {
|
|||
func learningRate(epoch int) (retVal float64) {
|
||||
switch {
|
||||
case epoch < 50:
|
||||
return 0.4
|
||||
return float64(0.1)
|
||||
case epoch < 100:
|
||||
return 0.2
|
||||
return float64(0.01)
|
||||
default:
|
||||
return 0.1
|
||||
return float64(0.001)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -101,13 +101,6 @@ func main() {
|
|||
fmt.Printf("TestLabel shape: %v\n", ds.TestLabels.MustSize())
|
||||
fmt.Printf("Number of labels: %v\n", ds.Labels)
|
||||
|
||||
var si *gotch.SI
|
||||
si = gotch.GetSysInfo()
|
||||
fmt.Printf("Total RAM (MB):\t %8.2f\n", float64(si.TotalRam)/1024)
|
||||
fmt.Printf("Used RAM (MB):\t %8.2f\n", float64(si.TotalRam-si.FreeRam)/1024)
|
||||
|
||||
startRAM := si.TotalRam - si.FreeRam
|
||||
|
||||
cuda := gotch.CudaBuilder(0)
|
||||
device := cuda.CudaIfAvailable()
|
||||
// device := gotch.CPU
|
||||
|
@ -117,7 +110,7 @@ func main() {
|
|||
net := fastResnet(vs.Root())
|
||||
|
||||
optConfig := nn.NewSGDConfig(0.9, 0.0, 5e-4, true)
|
||||
opt, err := optConfig.Build(vs, 0.2)
|
||||
opt, err := optConfig.Build(vs, 0.1)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -125,11 +118,11 @@ func main() {
|
|||
var lossVal float64
|
||||
startTime := time.Now()
|
||||
|
||||
for epoch := 0; epoch < 24; epoch++ {
|
||||
// opt.SetLR(learningRate(epoch))
|
||||
for epoch := 0; epoch < 150; epoch++ {
|
||||
opt.SetLR(learningRate(epoch))
|
||||
|
||||
// iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
|
||||
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(512))
|
||||
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
|
||||
// iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(512))
|
||||
iter.Shuffle()
|
||||
// iter = iter.ToDevice(device)
|
||||
|
||||
|
@ -161,10 +154,9 @@ func main() {
|
|||
loss.MustDrop()
|
||||
}
|
||||
|
||||
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||
si = gotch.GetSysInfo()
|
||||
memUsed := (float64(si.TotalRam-si.FreeRam) - float64(startRAM)) / 1024
|
||||
fmt.Printf("Epoch:\t %v\t Memory Used:\t [%8.2f MiB]\tLoss: \t %.3f \tAcc: %10.2f%%\n", epoch, memUsed, lossVal, testAcc*100.0)
|
||||
// testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||
// fmt.Printf("Epoch:\t %v\t Memory Used:\t [%8.2f MiB]\tLoss: \t %.3f \tAcc: %10.2f%%\n", epoch, memUsed, lossVal, testAcc*100.0)
|
||||
fmt.Printf("Epoch:\t %v\t \tLoss: \t\t %.3f\n", epoch, lossVal)
|
||||
iter.Drop()
|
||||
|
||||
// Print out GPU used
|
||||
|
|
Loading…
Reference in New Issue
Block a user