This commit is contained in:
sugarme 2020-07-08 20:37:58 +10:00
parent 2d2bf65ddf
commit 8eb73b0863

View File

@ -128,7 +128,6 @@ func main() {
opt.SetLR(learningRate(epoch))
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
iter = iter.ToDevice(device)
// iter.Shuffle()
// iter = iter.ToDevice(device)
@ -142,15 +141,15 @@ func main() {
// bimages := vision.Augmentation(item.Data, true, 4, 8)
// logits := net.ForwardT(bimages, true)
logits := net.ForwardT(item.Data, false)
loss := logits.CrossEntropyForLogits(item.Label)
logits := net.ForwardT(item.Data.MustTo(vs.Device(), true), false)
loss := logits.CrossEntropyForLogits(item.Label.MustTo(vs.Device(), true))
opt.BackwardStep(loss)
lossVal = loss.Values()[0]
// logits.MustDrop()
item.Data.MustDrop()
item.Label.MustDrop()
// item.Data.MustDrop()
// item.Label.MustDrop()
loss.MustDrop()
}