updated
This commit is contained in:
parent
2d2bf65ddf
commit
8eb73b0863
|
@ -128,7 +128,6 @@ func main() {
|
|||
opt.SetLR(learningRate(epoch))
|
||||
|
||||
iter := ts.MustNewIter2(ds.TrainImages, ds.TrainLabels, int64(64))
|
||||
iter = iter.ToDevice(device)
|
||||
|
||||
// iter.Shuffle()
|
||||
// iter = iter.ToDevice(device)
|
||||
|
@ -142,15 +141,15 @@ func main() {
|
|||
// bimages := vision.Augmentation(item.Data, true, 4, 8)
|
||||
// logits := net.ForwardT(bimages, true)
|
||||
|
||||
logits := net.ForwardT(item.Data, false)
|
||||
loss := logits.CrossEntropyForLogits(item.Label)
|
||||
logits := net.ForwardT(item.Data.MustTo(vs.Device(), true), false)
|
||||
loss := logits.CrossEntropyForLogits(item.Label.MustTo(vs.Device(), true))
|
||||
opt.BackwardStep(loss)
|
||||
|
||||
lossVal = loss.Values()[0]
|
||||
|
||||
// logits.MustDrop()
|
||||
item.Data.MustDrop()
|
||||
item.Label.MustDrop()
|
||||
// item.Data.MustDrop()
|
||||
// item.Label.MustDrop()
|
||||
loss.MustDrop()
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user