chore(example/cifar): added shuffle
This commit is contained in:
parent
8eb73b0863
commit
11d345360d
|
@ -107,9 +107,9 @@ func main() {
|
|||
|
||||
startRAM := si.TotalRam - si.FreeRam
|
||||
|
||||
cuda := gotch.CudaBuilder(0)
|
||||
device := cuda.CudaIfAvailable()
|
||||
// device := gotch.CPU
|
||||
// cuda := gotch.CudaBuilder(0)
|
||||
// device := cuda.CudaIfAvailable()
|
||||
device := gotch.CPU
|
||||
|
||||
vs := nn.NewVarStore(device)
|
||||
|
||||
|
@ -128,8 +128,7 @@ func main() {
|
|||
opt.SetLR(learningRate(epoch))
|
||||
|
||||
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
|
||||
|
||||
// iter.Shuffle()
|
||||
iter.Shuffle()
|
||||
// iter = iter.ToDevice(device)
|
||||
|
||||
for {
|
||||
|
@ -154,7 +153,7 @@ func main() {
|
|||
|
||||
}
|
||||
|
||||
fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, lossVal)
|
||||
fmt.Printf("Epoch:\t %v\tLoss: \t %.3f\n", epoch, lossVal)
|
||||
|
||||
si = gotch.GetSysInfo()
|
||||
fmt.Printf("Epoch %v\t Used: [%8.2f MiB]\n", epoch, (float64(si.TotalRam-si.FreeRam)-float64(startRAM))/1024)
|
||||
|
@ -163,6 +162,6 @@ func main() {
|
|||
}
|
||||
|
||||
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||
fmt.Printf("Loss: \t %.2f\t Accuracy: %.2f\n", lossVal, testAcc*100)
|
||||
fmt.Printf("Loss: \t %.3f\t Accuracy: %.2f\n", lossVal, testAcc*100)
|
||||
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user