feat(example/yolo)

This commit is contained in:
sugarme 2020-07-15 18:00:31 +10:00
parent bd56c67b04
commit ddaa00df94
2 changed files with 4 additions and 31 deletions

View File

@ -378,12 +378,6 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (r
bboxAttrs := classes + 5
nanchors := int64(len(anchors))
// fmt.Printf("xs shape: %v\n", xs.MustSize())
// fmt.Printf("bsize: %v\n", bsize)
// fmt.Printf("bboxAttrs: %v\n", bboxAttrs)
// fmt.Printf("nanchors: %v\n", nanchors)
// fmt.Printf("gridSize: %v\n", gridSize)
tmp1 := xs.MustView([]int64{bsize, bboxAttrs * nanchors, gridSize * gridSize}, false)
tmp2 := tmp1.MustTranspose(1, 2, true)
tmp3 := tmp2.MustContiguous(true)
@ -401,8 +395,6 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (r
xyOffsetTmp3 := xyOffsetTmp2.MustView([]int64{-1, 2}, true)
xyOffset := xyOffsetTmp3.MustUnsqueeze(0, true)
fmt.Printf("xyOffset shape: %v\n", xyOffset.MustSize())
var flatAnchors []int64
for _, a := range anchors {
flatAnchors = append(flatAnchors, a...)
@ -414,15 +406,11 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (r
anchorVals = append(anchorVals, v)
}
fmt.Printf("anchors: %v\n", anchorVals)
anchorsTmp1 := ts.MustOfSlice(anchorVals)
anchorsTmp2 := anchorsTmp1.MustView([]int64{-1, 2}, true)
anchorsTmp3 := anchorsTmp2.MustRepeat([]int64{gridSize * gridSize, 1}, true)
anchorsTs := anchorsTmp3.MustUnsqueeze(0, true)
fmt.Printf("anchors ts shape: %v\n", anchorsTs.MustSize())
sliceApplyAndSet(xsTs, 0, 2, func(xs ts.Tensor) (res ts.Tensor) {
tmp := xs.MustSigmoid(false)
res = tmp.MustAdd(xyOffset, true)
@ -529,7 +517,6 @@ func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
classes := b.Bl.(Yolo).Classes
anchors := b.Bl.(Yolo).Anchors
xsTs := xs
if len(prevYs) > 0 {
xsTs = prevYs[len(prevYs)-1]
}

View File

@ -78,8 +78,6 @@ func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor)
npreds := size2[0]
predSize := size2[1]
fmt.Printf("npreds: %v - predSize: %v\n", npreds, predSize)
nclasses := uint(predSize - 5)
// The bounding boxes grouped by (maximum) class index.
@ -116,7 +114,6 @@ func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor)
// Perform non-maximum suppression.
var bboxesRes [][]Bbox
fmt.Printf("Num of bboxes: %v\n", len(bboxes))
for _, bboxesForClass := range bboxes {
// 1. Sort by confidence
sort.Sort(ByConfBbox(bboxesForClass))
@ -139,18 +136,16 @@ func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor)
currentIndex += 1
}
}
// 3. Truncate currentIndex
// 3. Truncate at currentIndex (exclusive)
if currentIndex < len(bboxesForClass) {
bboxesForClass = append(bboxesForClass[:currentIndex], bboxesForClass[currentIndex+1:]...)
// bboxesForClass = append(bboxesForClass[:currentIndex], bboxesForClass[currentIndex+1:]...)
bboxesForClass = append(bboxesForClass[:currentIndex])
}
bboxesRes = append(bboxesRes, bboxesForClass)
}
fmt.Printf("Num of bboxesRes: %v\n", len(bboxesRes))
// Annotate the original image and print boxes information.
fmt.Printf("img shape: %v\n", img.MustSize())
size3, err := img.Size3()
if err != nil {
log.Fatal(err)
@ -164,9 +159,6 @@ func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor)
var wRatio float64 = float64(initialW) / float64(w)
var hRatio float64 = float64(initialH) / float64(h)
// fmt.Printf("wRatio: %v\n", wRatio)
// fmt.Printf("hRatio: %v\n", hRatio)
for classIndex, bboxesForClass := range bboxesRes {
for _, b := range bboxesForClass {
fmt.Printf("%v: %v\n", CocoClasses[classIndex], b)
@ -176,8 +168,6 @@ func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor)
xmax := min(max(int64(b.xmax*wRatio), 0), (initialW - 1))
ymax := min(max(int64(b.ymax*hRatio), 0), (initialH - 1))
// fmt.Printf("xmin: %v\t ymin: %v\t xmax: %v\t ymax: %v\n", xmin, ymin, xmax, ymax)
// draw rect
drawRect(image, xmin, xmax, ymin, min(ymax, ymin+2))
drawRect(image, xmin, xmax, max(ymin, ymax-2), ymax)
@ -218,13 +208,8 @@ func main() {
var darknet Darknet = ParseConfig(configPath)
fmt.Printf("darknet number of parameters: %v\n", len(darknet.Parameters))
fmt.Printf("darknet number of blocks: %v\n", len(darknet.Blocks))
vs := nn.NewVarStore(gotch.CPU)
model := darknet.BuildModel(vs.Root())
fmt.Printf("Model: %v\n", model)
fmt.Printf("Image path: %v\n", imagePath)
err = vs.Load(modelPath)
if err != nil {
@ -253,6 +238,7 @@ func main() {
imgTmp2 := imgTmp1.MustTotype(gotch.Float, true)
img := imgTmp2.MustDiv1(ts.FloatScalar(255.0), true)
predictTmp := model.ForwardT(img, false)
predictions := predictTmp.MustSqueeze(true)
imgRes := report(predictions, originalImage, netWidth, netHeight)