chore(example/cifar): reconfigure optimizer inside epoch for loop
This commit is contained in:
parent
18873a8957
commit
c265ba007e
|
@ -109,20 +109,40 @@ func main() {
|
|||
|
||||
net := fastResnet(vs.Root())
|
||||
|
||||
optConfig := nn.NewSGDConfig(0.9, 0.0, 5e-4, true)
|
||||
opt, err := optConfig.Build(vs, 0.01)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// optConfig := nn.NewSGDConfig(0.9, 0.0, 5e-4, true)
|
||||
// opt, err := optConfig.Build(vs, 0.01)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
|
||||
var lossVal float64
|
||||
startTime := time.Now()
|
||||
|
||||
for epoch := 0; epoch < 150; epoch++ {
|
||||
opt.SetLR(learningRate(epoch))
|
||||
// opt.SetLR(learningRate(epoch))
|
||||
optConfig := nn.NewSGDConfig(0.9, 0.0, 5e-4, true)
|
||||
var opt nn.Optimizer
|
||||
var err error
|
||||
switch {
|
||||
case epoch < 50:
|
||||
opt, err = optConfig.Build(vs, 0.1)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
case epoch < 100:
|
||||
opt, err = optConfig.Build(vs, 0.01)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
case epoch >= 100:
|
||||
opt, err = optConfig.Build(vs, 0.001)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
|
||||
// iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(512))
|
||||
// iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
|
||||
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(512))
|
||||
iter.Shuffle()
|
||||
// iter = iter.ToDevice(device)
|
||||
|
||||
|
@ -137,26 +157,21 @@ func main() {
|
|||
bimages := vision.Augmentation(devicedData, true, 4, 8)
|
||||
|
||||
logits := net.ForwardT(bimages, true)
|
||||
// logits := net.ForwardT(devicedData, true)
|
||||
|
||||
// logits := net.ForwardT(item.Data.MustTo(vs.Device(), true), false)
|
||||
loss := logits.CrossEntropyForLogits(devicedLabel)
|
||||
opt.BackwardStep(loss)
|
||||
|
||||
lossVal = loss.Values()[0]
|
||||
|
||||
// logits.MustDrop()
|
||||
// item.Data.MustDrop()
|
||||
// item.Label.MustDrop()
|
||||
devicedData.MustDrop()
|
||||
devicedLabel.MustDrop()
|
||||
bimages.MustDrop()
|
||||
loss.MustDrop()
|
||||
}
|
||||
|
||||
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||
fmt.Printf("Epoch:\t %v\t Loss: \t %.3f \tAcc: %10.2f%%\n", epoch, lossVal, testAcc*100.0)
|
||||
// fmt.Printf("Epoch: %10.0d\tLoss:%10.3f\n", epoch, lossVal)
|
||||
// testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||
// fmt.Printf("Epoch:\t %v\t Loss: \t %.3f \tAcc: %10.2f%%\n", epoch, lossVal, testAcc*100.0)
|
||||
fmt.Printf("Epoch: %10.0d\tLoss:%10.3f\n", epoch, lossVal)
|
||||
iter.Drop()
|
||||
|
||||
/*
|
||||
|
@ -174,6 +189,6 @@ func main() {
|
|||
}
|
||||
|
||||
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), 512)
|
||||
fmt.Printf("Loss: \t %.3f\t Accuracy: %10.2f%%\n", lossVal, testAcc*100.0)
|
||||
fmt.Printf("Accuracy: %10.2f%%\n", testAcc*100.0)
|
||||
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user