feat(example/yolo): remove debug prints and truncate NMS results at currentIndex
This commit is contained in:
parent
bd56c67b04
commit
ddaa00df94
|
@ -378,12 +378,6 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (r
|
|||
bboxAttrs := classes + 5
|
||||
nanchors := int64(len(anchors))
|
||||
|
||||
// fmt.Printf("xs shape: %v\n", xs.MustSize())
|
||||
// fmt.Printf("bsize: %v\n", bsize)
|
||||
// fmt.Printf("bboxAttrs: %v\n", bboxAttrs)
|
||||
// fmt.Printf("nanchors: %v\n", nanchors)
|
||||
// fmt.Printf("gridSize: %v\n", gridSize)
|
||||
|
||||
tmp1 := xs.MustView([]int64{bsize, bboxAttrs * nanchors, gridSize * gridSize}, false)
|
||||
tmp2 := tmp1.MustTranspose(1, 2, true)
|
||||
tmp3 := tmp2.MustContiguous(true)
|
||||
|
@ -401,8 +395,6 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (r
|
|||
xyOffsetTmp3 := xyOffsetTmp2.MustView([]int64{-1, 2}, true)
|
||||
xyOffset := xyOffsetTmp3.MustUnsqueeze(0, true)
|
||||
|
||||
fmt.Printf("xyOffset shape: %v\n", xyOffset.MustSize())
|
||||
|
||||
var flatAnchors []int64
|
||||
for _, a := range anchors {
|
||||
flatAnchors = append(flatAnchors, a...)
|
||||
|
@ -414,15 +406,11 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (r
|
|||
anchorVals = append(anchorVals, v)
|
||||
}
|
||||
|
||||
fmt.Printf("anchors: %v\n", anchorVals)
|
||||
|
||||
anchorsTmp1 := ts.MustOfSlice(anchorVals)
|
||||
anchorsTmp2 := anchorsTmp1.MustView([]int64{-1, 2}, true)
|
||||
anchorsTmp3 := anchorsTmp2.MustRepeat([]int64{gridSize * gridSize, 1}, true)
|
||||
anchorsTs := anchorsTmp3.MustUnsqueeze(0, true)
|
||||
|
||||
fmt.Printf("anchors ts shape: %v\n", anchorsTs.MustSize())
|
||||
|
||||
sliceApplyAndSet(xsTs, 0, 2, func(xs ts.Tensor) (res ts.Tensor) {
|
||||
tmp := xs.MustSigmoid(false)
|
||||
res = tmp.MustAdd(xyOffset, true)
|
||||
|
@ -529,7 +517,6 @@ func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
|
|||
classes := b.Bl.(Yolo).Classes
|
||||
anchors := b.Bl.(Yolo).Anchors
|
||||
xsTs := xs
|
||||
|
||||
if len(prevYs) > 0 {
|
||||
xsTs = prevYs[len(prevYs)-1]
|
||||
}
|
||||
|
|
|
@ -78,8 +78,6 @@ func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor)
|
|||
npreds := size2[0]
|
||||
predSize := size2[1]
|
||||
|
||||
fmt.Printf("npreds: %v - predSize: %v\n", npreds, predSize)
|
||||
|
||||
nclasses := uint(predSize - 5)
|
||||
|
||||
// The bounding boxes grouped by (maximum) class index.
|
||||
|
@ -116,7 +114,6 @@ func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor)
|
|||
|
||||
// Perform non-maximum suppression.
|
||||
var bboxesRes [][]Bbox
|
||||
fmt.Printf("Num of bboxes: %v\n", len(bboxes))
|
||||
for _, bboxesForClass := range bboxes {
|
||||
// 1. Sort by confidence
|
||||
sort.Sort(ByConfBbox(bboxesForClass))
|
||||
|
@ -139,18 +136,16 @@ func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor)
|
|||
currentIndex += 1
|
||||
}
|
||||
}
|
||||
// 3. Truncate currentIndex
|
||||
// 3. Truncate at currentIndex (exclusive)
|
||||
if currentIndex < len(bboxesForClass) {
|
||||
bboxesForClass = append(bboxesForClass[:currentIndex], bboxesForClass[currentIndex+1:]...)
|
||||
// bboxesForClass = append(bboxesForClass[:currentIndex], bboxesForClass[currentIndex+1:]...)
|
||||
bboxesForClass = append(bboxesForClass[:currentIndex])
|
||||
}
|
||||
|
||||
bboxesRes = append(bboxesRes, bboxesForClass)
|
||||
}
|
||||
|
||||
fmt.Printf("Num of bboxesRes: %v\n", len(bboxesRes))
|
||||
|
||||
// Annotate the original image and print boxes information.
|
||||
fmt.Printf("img shape: %v\n", img.MustSize())
|
||||
size3, err := img.Size3()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -164,9 +159,6 @@ func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor)
|
|||
var wRatio float64 = float64(initialW) / float64(w)
|
||||
var hRatio float64 = float64(initialH) / float64(h)
|
||||
|
||||
// fmt.Printf("wRatio: %v\n", wRatio)
|
||||
// fmt.Printf("hRatio: %v\n", hRatio)
|
||||
|
||||
for classIndex, bboxesForClass := range bboxesRes {
|
||||
for _, b := range bboxesForClass {
|
||||
fmt.Printf("%v: %v\n", CocoClasses[classIndex], b)
|
||||
|
@ -176,8 +168,6 @@ func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor)
|
|||
xmax := min(max(int64(b.xmax*wRatio), 0), (initialW - 1))
|
||||
ymax := min(max(int64(b.ymax*hRatio), 0), (initialH - 1))
|
||||
|
||||
// fmt.Printf("xmin: %v\t ymin: %v\t xmax: %v\t ymax: %v\n", xmin, ymin, xmax, ymax)
|
||||
|
||||
// draw rect
|
||||
drawRect(image, xmin, xmax, ymin, min(ymax, ymin+2))
|
||||
drawRect(image, xmin, xmax, max(ymin, ymax-2), ymax)
|
||||
|
@ -218,13 +208,8 @@ func main() {
|
|||
|
||||
var darknet Darknet = ParseConfig(configPath)
|
||||
|
||||
fmt.Printf("darknet number of parameters: %v\n", len(darknet.Parameters))
|
||||
fmt.Printf("darknet number of blocks: %v\n", len(darknet.Blocks))
|
||||
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
model := darknet.BuildModel(vs.Root())
|
||||
fmt.Printf("Model: %v\n", model)
|
||||
fmt.Printf("Image path: %v\n", imagePath)
|
||||
|
||||
err = vs.Load(modelPath)
|
||||
if err != nil {
|
||||
|
@ -253,6 +238,7 @@ func main() {
|
|||
imgTmp2 := imgTmp1.MustTotype(gotch.Float, true)
|
||||
img := imgTmp2.MustDiv1(ts.FloatScalar(255.0), true)
|
||||
predictTmp := model.ForwardT(img, false)
|
||||
|
||||
predictions := predictTmp.MustSqueeze(true)
|
||||
|
||||
imgRes := report(predictions, originalImage, netWidth, netHeight)
|
||||
|
|
Loading…
Reference in New Issue
Block a user