fix(vision/augmentation)

This commit is contained in:
sugarme 2020-07-08 21:23:17 +10:00
parent f76566834c
commit 0355d6de37
2 changed files with 10 additions and 8 deletions

View File

@ -137,10 +137,11 @@ func main() {
break
}
// bimages := vision.Augmentation(item.Data, true, 4, 8)
// logits := net.ForwardT(bimages, true)
bimages := vision.Augmentation(item.Data.MustTo(vs.Device(), true), true, 4, 8)
logits := net.ForwardT(item.Data.MustTo(vs.Device(), true), false)
logits := net.ForwardT(bimages, true)
// logits := net.ForwardT(item.Data.MustTo(vs.Device(), true), false)
loss := logits.CrossEntropyForLogits(item.Label.MustTo(vs.Device(), true))
opt.BackwardStep(loss)
@ -149,14 +150,14 @@ func main() {
// logits.MustDrop()
// item.Data.MustDrop()
// item.Label.MustDrop()
bimages.MustDrop()
loss.MustDrop()
}
fmt.Printf("Epoch:\t %v\tLoss: \t %.3f\n", epoch, lossVal)
si = gotch.GetSysInfo()
fmt.Printf("Epoch %v\t Used: [%8.2f MiB]\n", epoch, (float64(si.TotalRam-si.FreeRam)-float64(startRAM))/1024)
memUsed := (float64(si.TotalRam-si.FreeRam) - float64(startRAM)) / 1024
fmt.Printf("Epoch:\t %v\t Memory Used:\t [%8.2f MiB]\tLoss: \t %.3f\n", epoch, memUsed, lossVal)
iter.Drop()
}

View File

@ -135,8 +135,9 @@ func RandomCutout(t ts.Tensor, sz int64) (retVal ts.Tensor) {
wIdx := ts.NewNarrow(int64(startW), int64(startW)+sz)
srcIdx = append(srcIdx, nIdx, cIdx, hIdx, wIdx)
output.Idx(srcIdx)
output.Fill_(ts.FloatScalar(0.0))
tmp := output.Idx(srcIdx)
tmp.Fill_(ts.FloatScalar(0.0))
tmp.MustDrop()
}
return output