fix(vision/augmentation)
This commit is contained in:
parent
f76566834c
commit
0355d6de37
|
@ -137,10 +137,11 @@ func main() {
|
|||
break
|
||||
}
|
||||
|
||||
// bimages := vision.Augmentation(item.Data, true, 4, 8)
|
||||
// logits := net.ForwardT(bimages, true)
|
||||
bimages := vision.Augmentation(item.Data.MustTo(vs.Device(), true), true, 4, 8)
|
||||
|
||||
logits := net.ForwardT(item.Data.MustTo(vs.Device(), true), false)
|
||||
logits := net.ForwardT(bimages, true)
|
||||
|
||||
// logits := net.ForwardT(item.Data.MustTo(vs.Device(), true), false)
|
||||
loss := logits.CrossEntropyForLogits(item.Label.MustTo(vs.Device(), true))
|
||||
opt.BackwardStep(loss)
|
||||
|
||||
|
@ -149,14 +150,14 @@ func main() {
|
|||
// logits.MustDrop()
|
||||
// item.Data.MustDrop()
|
||||
// item.Label.MustDrop()
|
||||
bimages.MustDrop()
|
||||
loss.MustDrop()
|
||||
|
||||
}
|
||||
|
||||
fmt.Printf("Epoch:\t %v\tLoss: \t %.3f\n", epoch, lossVal)
|
||||
|
||||
si = gotch.GetSysInfo()
|
||||
fmt.Printf("Epoch %v\t Used: [%8.2f MiB]\n", epoch, (float64(si.TotalRam-si.FreeRam)-float64(startRAM))/1024)
|
||||
memUsed := (float64(si.TotalRam-si.FreeRam) - float64(startRAM)) / 1024
|
||||
fmt.Printf("Epoch:\t %v\t Memory Used:\t [%8.2f MiB]\tLoss: \t %.3f\n", epoch, memUsed, lossVal)
|
||||
iter.Drop()
|
||||
|
||||
}
|
||||
|
|
|
@ -135,8 +135,9 @@ func RandomCutout(t ts.Tensor, sz int64) (retVal ts.Tensor) {
|
|||
wIdx := ts.NewNarrow(int64(startW), int64(startW)+sz)
|
||||
srcIdx = append(srcIdx, nIdx, cIdx, hIdx, wIdx)
|
||||
|
||||
output.Idx(srcIdx)
|
||||
output.Fill_(ts.FloatScalar(0.0))
|
||||
tmp := output.Idx(srcIdx)
|
||||
tmp.Fill_(ts.FloatScalar(0.0))
|
||||
tmp.MustDrop()
|
||||
}
|
||||
|
||||
return output
|
||||
|
|
Loading…
Reference in New Issue
Block a user