fix(example/yolo): fixed memory leak at darknet and report
This commit is contained in:
parent
a272b8034d
commit
9fb9cd67c9
|
@ -366,6 +366,9 @@ func sliceApplyAndSet(xs ts.Tensor, start int64, len int64, f func(ts.Tensor) ts
|
|||
}
|
||||
|
||||
func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (retVal ts.Tensor) {
|
||||
|
||||
device, err := xs.Device()
|
||||
|
||||
size4, err := xs.Size4()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -383,13 +386,16 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (r
|
|||
tmp3 := tmp2.MustContiguous(true)
|
||||
xsTs := tmp3.MustView([]int64{bsize, gridSize * gridSize * nanchors, bboxAttrs}, true)
|
||||
|
||||
grid := ts.MustArange(ts.IntScalar(gridSize), gotch.Float, gotch.CPU)
|
||||
a := grid.MustRepeat([]int64{gridSize, 1}, false)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
grid := ts.MustArange(ts.IntScalar(gridSize), gotch.Float, device)
|
||||
a := grid.MustRepeat([]int64{gridSize, 1}, true)
|
||||
bTmp := a.MustT(false)
|
||||
b := bTmp.MustContiguous(true)
|
||||
|
||||
xOffset := a.MustView([]int64{-1, 1}, false)
|
||||
yOffset := b.MustView([]int64{-1, 1}, false)
|
||||
xOffset := a.MustView([]int64{-1, 1}, true)
|
||||
yOffset := b.MustView([]int64{-1, 1}, true)
|
||||
xyOffsetTmp1 := ts.MustCat([]ts.Tensor{xOffset, yOffset}, 1, false)
|
||||
xyOffsetTmp2 := xyOffsetTmp1.MustRepeat([]int64{1, nanchors}, true)
|
||||
xyOffsetTmp3 := xyOffsetTmp2.MustView([]int64{-1, 2}, true)
|
||||
|
@ -409,7 +415,7 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (r
|
|||
anchorsTmp1 := ts.MustOfSlice(anchorVals)
|
||||
anchorsTmp2 := anchorsTmp1.MustView([]int64{-1, 2}, true)
|
||||
anchorsTmp3 := anchorsTmp2.MustRepeat([]int64{gridSize * gridSize, 1}, true)
|
||||
anchorsTs := anchorsTmp3.MustUnsqueeze(0, true)
|
||||
anchorsTs := anchorsTmp3.MustUnsqueeze(0, true).MustTo(device, true)
|
||||
|
||||
sliceApplyAndSet(xsTs, 0, 2, func(xs ts.Tensor) (res ts.Tensor) {
|
||||
tmp := xs.MustSigmoid(false)
|
||||
|
@ -488,11 +494,12 @@ func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
|
|||
var prevYs []ts.Tensor = make([]ts.Tensor, 0)
|
||||
var detections []ts.Tensor = make([]ts.Tensor, 0)
|
||||
|
||||
// NOTE: we will delete all tensors in prevYs after looping
|
||||
for _, b := range blocks {
|
||||
blkTyp := reflect.TypeOf(b.Bl)
|
||||
var ysTs ts.Tensor
|
||||
switch blkTyp.Name() {
|
||||
case "Layer": // Layer type
|
||||
case "Layer":
|
||||
layer := b.Bl.(Layer)
|
||||
xsTs := xs
|
||||
if len(prevYs) > 0 {
|
||||
|
@ -511,8 +518,7 @@ func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
|
|||
from := b.Bl.(Shortcut).TsIdx
|
||||
addTs := prevYs[int(from)]
|
||||
last := prevYs[len(prevYs)-1]
|
||||
ysTs = last.MustAdd(addTs, false) // TODO: Should we delete it?
|
||||
addTs.MustDrop()
|
||||
ysTs = last.MustAdd(addTs, false)
|
||||
case "Yolo":
|
||||
classes := b.Bl.(Yolo).Classes
|
||||
anchors := b.Bl.(Yolo).Anchors
|
||||
|
@ -524,6 +530,7 @@ func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
|
|||
dt := detect(xsTs, imageHeight, classes, anchors)
|
||||
|
||||
detections = append(detections, dt)
|
||||
|
||||
ysTs = ts.NewTensor()
|
||||
|
||||
default:
|
||||
|
@ -535,6 +542,15 @@ func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
|
|||
|
||||
res = ts.MustCat(detections, 1, true)
|
||||
|
||||
// Now, free-up memory held up by prevYs
|
||||
for _, t := range prevYs {
|
||||
if t.MustDefined() {
|
||||
// fmt.Printf("will delete ts: %v\n", t)
|
||||
// NOTE: if t memory is delete previously (in switch-case), there will be panic!
|
||||
t.MustDrop()
|
||||
}
|
||||
}
|
||||
|
||||
return res
|
||||
}) // end of NewFuncT
|
||||
|
||||
|
|
|
@ -88,6 +88,8 @@ func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor)
|
|||
for index := 0; index < int(npreds); index++ {
|
||||
predIdx := pred.MustGet(index)
|
||||
var predVals []float64 = predIdx.Values()
|
||||
predIdx.MustDrop()
|
||||
|
||||
confidence := predVals[4]
|
||||
if confidence > confidenceThreshold {
|
||||
classIndex := 0
|
||||
|
@ -111,6 +113,7 @@ func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor)
|
|||
bboxes[classIndex] = append(bboxes[classIndex], bbox)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Perform non-maximum suppression.
|
||||
|
@ -154,7 +157,7 @@ func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor)
|
|||
initialW := size3[2]
|
||||
|
||||
imageTmp := img.MustTotype(gotch.Float, false)
|
||||
image := imageTmp.MustDiv1(ts.FloatScalar(255.0), false)
|
||||
image := imageTmp.MustDiv1(ts.FloatScalar(255.0), true)
|
||||
|
||||
var wRatio float64 = float64(initialW) / float64(w)
|
||||
var hRatio float64 = float64(initialH) / float64(h)
|
||||
|
@ -176,8 +179,8 @@ func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor)
|
|||
}
|
||||
}
|
||||
|
||||
imgTmp := image.MustMul1(ts.FloatScalar(255.0), false)
|
||||
retVal = imgTmp.MustTotype(gotch.Uint8, false)
|
||||
imgTmp := image.MustMul1(ts.FloatScalar(255.0), true)
|
||||
retVal = imgTmp.MustTotype(gotch.Uint8, true)
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user