fix(vision/dataset): augmentation func - memory blow-up at RandomCutout

This commit is contained in:
sugarme 2020-07-10 17:18:47 +10:00
parent 8704adc867
commit 929e098adb
2 changed files with 11 additions and 4 deletions

View File

@ -56,9 +56,7 @@ func main() {
devicedData := item.Data.MustTo(vs.Device(), true)
devicedLabel := item.Label.MustTo(vs.Device(), true)
// bimages := vision.Augmentation(devicedData, true, 4, 8)
// NOTE: memory blow-up at augmentation/RandomCutout
bimages := vision.Augmentation(devicedData, true, 4, 0)
bimages := vision.Augmentation(devicedData, true, 4, 8)
devicedData.MustDrop()
devicedLabel.MustDrop()

View File

@ -148,12 +148,21 @@ func RandomCutout(t ts.Tensor, sz int64) (retVal ts.Tensor) {
wIdx := ts.NewNarrow(int64(startW), int64(startW)+sz)
srcIdx = append(srcIdx, nIdx, cIdx, hIdx, wIdx)
// TODO: there's memory blow-up here. Need to fix.
// NOTE: using ts.Fill_() causes memory blow-up. Why???
// view := output.Idx(srcIdx)
// zeroSc := ts.FloatScalar(0.0)
// view.Fill_(zeroSc)
// zeroSc.MustDrop()
// view.MustDrop()
view := output.Idx(srcIdx)
zeroTs, err := view.ZerosLike(false)
if err != nil {
log.Fatal(err)
}
view.Copy_(zeroTs)
zeroTs.MustDrop()
view.MustDrop()
}
return output