fix(vision/dataset): augmentation func - memory blow-up at RandomCutout
This commit is contained in:
parent
8704adc867
commit
929e098adb
|
@ -56,9 +56,7 @@ func main() {
|
|||
|
||||
devicedData := item.Data.MustTo(vs.Device(), true)
|
||||
devicedLabel := item.Label.MustTo(vs.Device(), true)
|
||||
// bimages := vision.Augmentation(devicedData, true, 4, 8)
|
||||
// NOTE: memory blow-up at augmentation/RandomCutout
|
||||
bimages := vision.Augmentation(devicedData, true, 4, 0)
|
||||
bimages := vision.Augmentation(devicedData, true, 4, 8)
|
||||
|
||||
devicedData.MustDrop()
|
||||
devicedLabel.MustDrop()
|
||||
|
|
|
@ -148,12 +148,21 @@ func RandomCutout(t ts.Tensor, sz int64) (retVal ts.Tensor) {
|
|||
wIdx := ts.NewNarrow(int64(startW), int64(startW)+sz)
|
||||
srcIdx = append(srcIdx, nIdx, cIdx, hIdx, wIdx)
|
||||
|
||||
// TODO: there's memory blow-up here. Need to fix.
|
||||
// NOTE: using ts.Fill_() causes memory blow-up. Why???
|
||||
// view := output.Idx(srcIdx)
|
||||
// zeroSc := ts.FloatScalar(0.0)
|
||||
// view.Fill_(zeroSc)
|
||||
// zeroSc.MustDrop()
|
||||
// view.MustDrop()
|
||||
|
||||
view := output.Idx(srcIdx)
|
||||
zeroTs, err := view.ZerosLike(false)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
view.Copy_(zeroTs)
|
||||
zeroTs.MustDrop()
|
||||
view.MustDrop()
|
||||
}
|
||||
|
||||
return output
|
||||
|
|
Loading…
Reference in New Issue
Block a user