WIP(example/yolo)

This commit is contained in:
sugarme 2020-07-15 13:42:40 +10:00
parent f9f71c7163
commit 530acf0894
5 changed files with 330 additions and 17 deletions

View File

@ -0,0 +1,84 @@
package main
// CocoClasses holds the 80 human-readable class labels. report looks a label
// up by class index, so the order of the entries is significant and must not
// be changed. (Presumably this matches the label order the network weights
// were trained with - confirm against the model's names file.)
var CocoClasses = []string{
	// People and vehicles.
	"person", "bicycle", "car", "motorbike", "aeroplane", "bus", "train", "truck", "boat",
	// Street furniture.
	"traffic light", "fire hydrant", "stop sign", "parking meter", "bench",
	// Animals.
	"bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe",
	// Accessories.
	"backpack", "umbrella", "handbag", "tie", "suitcase",
	// Sports equipment.
	"frisbee", "skis", "snowboard", "sports ball", "kite",
	"baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket",
	// Tableware.
	"bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl",
	// Food.
	"banana", "apple", "sandwich", "orange", "broccoli",
	"carrot", "hot dog", "pizza", "donut", "cake",
	// Furniture.
	"chair", "sofa", "pottedplant", "bed", "diningtable", "toilet",
	// Electronics.
	"tvmonitor", "laptop", "mouse", "remote", "keyboard", "cell phone",
	// Appliances.
	"microwave", "oven", "toaster", "sink", "refrigerator",
	// Miscellaneous indoor objects.
	"book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
}

View File

@ -362,8 +362,7 @@ func sliceApplyAndSet(xs ts.Tensor, start int64, len int64, f func(ts.Tensor) ts
slice.Copy_(src)
src.MustDrop()
// TODO: check whether we need to delete slice to prevent memory blow-up
// slice.MustDrop()
slice.MustDrop()
}
func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (retVal ts.Tensor) {
@ -386,35 +385,44 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (r
// fmt.Printf("gridSize: %v\n", gridSize)
tmp1 := xs.MustView([]int64{bsize, bboxAttrs * nanchors, gridSize * gridSize}, false)
tmp2 := tmp1.MustTranspose(1, 2, true)
tmp3 := tmp2.MustContiguous(true)
xsTs := tmp3.MustView([]int64{bsize, gridSize * gridSize * nanchors, bboxAttrs}, true)
grid := ts.MustArange(ts.IntScalar(gridSize), gotch.Float, gotch.CPU)
a := grid.MustRepeat([]int64{gridSize, 1}, false)
bTmp := a.MustT(false)
b := bTmp.MustContiguous(true)
xOffset := a.MustView([]int64{-1, 1}, false)
yOffset := b.MustView([]int64{-1, 1}, false)
xyOffsetTmp1 := ts.MustCat([]ts.Tensor{xOffset, yOffset}, 1, false)
xyOffsetTmp2 := xyOffsetTmp1.MustRepeat([]int64{1, nanchors}, true)
xyOffsetTmp3 := xyOffsetTmp2.MustView([]int64{-1, 2}, true)
xyOffset := xyOffsetTmp3.MustUnsqueeze(0, true)
fmt.Printf("xyOffset shape: %v\n", xyOffset.MustSize())
var flatAnchors []int64
for _, a := range anchors {
flatAnchors = append(flatAnchors, a...)
}
anchorsTmp1 := ts.MustOfSlice(flatAnchors)
var anchorVals []float32
for _, a := range flatAnchors {
v := float32(a) / float32(stride)
anchorVals = append(anchorVals, v)
}
fmt.Printf("anchors: %v\n", anchorVals)
anchorsTmp1 := ts.MustOfSlice(anchorVals)
anchorsTmp2 := anchorsTmp1.MustView([]int64{-1, 2}, true)
anchorsTmp3 := anchorsTmp2.MustRepeat([]int64{gridSize * gridSize, 1}, true)
anchorsTs := anchorsTmp3.MustUnsqueeze(0, true)
fmt.Printf("anchors ts shape: %v\n", anchorsTs.MustSize())
sliceApplyAndSet(xsTs, 0, 2, func(xs ts.Tensor) (res ts.Tensor) {
tmp := xs.MustSigmoid(false)
res = tmp.MustAdd(xyOffset, true)
@ -436,7 +444,6 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (r
})
// TODO: delete all middle tensors.
return xsTs
}

View File

@ -8,16 +8,190 @@ import (
ts "github.com/sugarme/gotch/tensor"
"github.com/sugarme/gotch/vision"
"log"
"math"
"path/filepath"
"sort"
)
const configName = "yolo-v3.cfg"
// Settings for the YOLO example.
const (
	// nmsThreshold is the IoU above which report treats two boxes of the same
	// class as duplicates during non-maximum suppression.
	nmsThreshold float64 = 0.4

	// confidenceThreshold is the minimum confidence (element 4 of a prediction
	// row) a prediction needs before report considers it at all.
	confidenceThreshold float64 = 0.5

	// configName is the file name of the network configuration
	// (presumably resolved relative to the model's directory - confirm in main).
	configName string = "yolo-v3.cfg"
)
// Command-line options; init binds them to the -model and -image flags.
var (
	// model is the path of the model weights file.
	model string
	// image is the path of the image to run inference on.
	image string
)
// Bbox describes a single detection: an axis-aligned rectangle given by its
// two corners, together with the scores attached to it.
type Bbox struct {
	// Corner with the smallest coordinates.
	xmin, ymin float64
	// Corner with the largest coordinates.
	xmax, ymax float64
	// confidence is the score taken from element 4 of the prediction row.
	confidence float64
	// classIndex is the index of the best-scoring class.
	classIndex uint
	// classConfidence is the score of that best-scoring class.
	classConfidence float64
}
// ByConfBbox adapts a []Bbox to sort.Interface, ordering the boxes by
// ascending confidence.
type ByConfBbox []Bbox

// Len reports how many boxes the slice holds.
func (b ByConfBbox) Len() int {
	return len(b)
}

// Less reports whether box i has a strictly lower confidence than box j.
func (b ByConfBbox) Less(i, j int) bool {
	lhs, rhs := b[i].confidence, b[j].confidence
	return rhs > lhs
}

// Swap exchanges the boxes at positions i and j.
func (b ByConfBbox) Swap(i, j int) {
	held := b[i]
	b[i] = b[j]
	b[j] = held
}
// Iou computes the intersection-over-union ratio of two bounding boxes.
//
// Widths and heights are measured inclusively (max - min + 1). When the boxes
// do not overlap the intersection is clamped to zero, so the result is 0.
func Iou(b1, b2 Bbox) (retVal float64) {
	firstArea := (b1.xmax - b1.xmin + 1.0) * (b1.ymax - b1.ymin + 1.0)
	secondArea := (b2.xmax - b2.xmin + 1.0) * (b2.ymax - b2.ymin + 1.0)

	// Corners of the overlapping rectangle (may be inverted if disjoint).
	left, right := math.Max(b1.xmin, b2.xmin), math.Min(b1.xmax, b2.xmax)
	top, bottom := math.Max(b1.ymin, b2.ymin), math.Min(b1.ymax, b2.ymax)

	overlapW := math.Max(right-left+1.0, 0.0)
	overlapH := math.Max(bottom-top+1.0, 0)
	overlap := overlapW * overlapH

	union := firstArea + secondArea - overlap
	return overlap / union
}
// drawRect fills a rectangular region of t in place with the per-channel
// value (0, 0, 1).
//
// The region is dimension 2 narrowed to [x1, x2) and dimension 1 narrowed to
// [y1, y2); the caller must ensure x1 <= x2 and y1 <= y2. The fill value is
// shaped {3, 1, 1}, so t presumably has 3 channels in dimension 0
// (channel, height, width layout - confirm with callers).
func drawRect(t ts.Tensor, x1, x2, y1, y2 int64) {
	width := x2 - x1
	height := y2 - y1

	rgb := ts.MustOfSlice([]float64{0.0, 0.0, 1.0})
	color := rgb.MustView([]int64{3, 1, 1}, true)

	// The column view is released by the second narrow (del = true); the
	// resulting region is still a view into t, so Copy_ writes into t itself.
	columns := t.MustNarrow(2, x1, width, false)
	region := columns.MustNarrow(1, y1, height, true)
	region.Copy_(color)

	region.MustDrop()
	color.MustDrop()
}
// report turns raw network predictions into an annotated image.
//
// pred is read as a 2-D tensor with one row per candidate box, laid out as
// [centerX, centerY, width, height, confidence, classScore0, classScore1, ...].
// img is read as a 3-D tensor whose dimensions 1 and 2 are height and width.
// w and h are the width and height of the coordinate space the predictions
// are expressed in; boxes are rescaled from it to the size of img.
//
// Steps:
//  1. keep the rows whose confidence exceeds confidenceThreshold and group
//     them by best-scoring class;
//  2. run non-maximum suppression per class;
//  3. print every surviving box and draw its outline on a copy of img.
//
// The returned tensor is a new Uint8 tensor; img itself is not modified.
// Any tensor-library error terminates the program through log.Fatal.
func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor) {
	size2, err := pred.Size2()
	if err != nil {
		log.Fatal(err)
	}
	npreds := size2[0]
	predSize := size2[1]
	fmt.Printf("npreds: %v - predSize: %v\n", npreds, predSize)

	// The first 5 values of a row are box geometry and confidence.
	nclasses := uint(predSize - 5)

	// The bounding boxes grouped by (maximum) class index.
	var bboxes [][]Bbox = make([][]Bbox, int(nclasses))

	// Extract the bounding boxes for which confidence is above the threshold.
	for index := 0; index < int(npreds); index++ {
		predIdx := pred.MustGet(index)
		var predVals []float64 = predIdx.Values()
		// The values now live in a Go slice; release the row tensor right
		// away so the loop does not accumulate one tensor per prediction.
		predIdx.MustDrop()

		confidence := predVals[4]
		if confidence <= confidenceThreshold {
			continue
		}

		classIndex := 0
		for i := 0; i < int(nclasses); i++ {
			if predVals[5+i] > predVals[5+classIndex] {
				classIndex = i
			}
		}
		if predVals[5+classIndex] <= 0.0 {
			continue
		}

		// Convert centre/size to corner coordinates.
		halfW := predVals[2] / 2.0
		halfH := predVals[3] / 2.0
		bboxes[classIndex] = append(bboxes[classIndex], Bbox{
			xmin:            predVals[0] - halfW,
			ymin:            predVals[1] - halfH,
			xmax:            predVals[0] + halfW,
			ymax:            predVals[1] + halfH,
			confidence:      confidence,
			classIndex:      uint(classIndex),
			classConfidence: predVals[5+classIndex],
		})
	}

	// Perform non-maximum suppression. bboxesRes keeps one entry per class
	// (possibly empty) so that its index is still the class index.
	fmt.Printf("Num of bboxes: %v\n", len(bboxes))
	bboxesRes := make([][]Bbox, 0, len(bboxes))
	for _, bboxesForClass := range bboxes {
		bboxesRes = append(bboxesRes, suppressOverlaps(bboxesForClass))
	}
	fmt.Printf("Num of bboxesRes: %v\n", len(bboxesRes))

	// Annotate the original image and print boxes information.
	fmt.Printf("img shape: %v\n", img.MustSize())
	size3, err := img.Size3()
	if err != nil {
		log.Fatal(err)
	}
	initialH := size3[1]
	initialW := size3[2]

	// Work on a float copy scaled by 1/255 so the caller's tensor stays
	// untouched and drawRect's 0..1 colour values are meaningful.
	imageTmp := img.MustTotype(gotch.Float, false)
	image := imageTmp.MustDiv1(ts.FloatScalar(255.0), true)

	var wRatio float64 = float64(initialW) / float64(w)
	var hRatio float64 = float64(initialH) / float64(h)

	for classIndex, bboxesForClass := range bboxesRes {
		for _, b := range bboxesForClass {
			fmt.Printf("%v: %v\n", CocoClasses[classIndex], b)

			// Rescale to image coordinates and clamp to the image bounds.
			xmin := min(max(int64(b.xmin*wRatio), 0), (initialW - 1))
			ymin := min(max(int64(b.ymin*hRatio), 0), (initialH - 1))
			xmax := min(max(int64(b.xmax*wRatio), 0), (initialW - 1))
			ymax := min(max(int64(b.ymax*hRatio), 0), (initialH - 1))

			// Draw the four edges, each at most 2 pixels thick, on the
			// working copy (the tensor the result is built from).
			drawRect(image, xmin, xmax, ymin, min(ymax, ymin+2)) // top
			drawRect(image, xmin, xmax, max(ymin, ymax-2), ymax) // bottom
			drawRect(image, xmin, min(xmax, xmin+2), ymin, ymax) // left
			drawRect(image, max(xmax-2, xmin), xmax, ymin, ymax) // right
		}
	}

	// Scale back to 0..255 and convert to bytes, releasing each intermediate.
	imgTmp := image.MustMul1(ts.FloatScalar(255.0), true)
	retVal = imgTmp.MustTotype(gotch.Uint8, true)

	return retVal
}

// suppressOverlaps performs greedy non-maximum suppression on the boxes of a
// single class. It reorders boxes in place and returns the kept prefix,
// highest confidence first: a box is kept only if its IoU with every box kept
// so far does not exceed nmsThreshold.
func suppressOverlaps(boxes []Bbox) []Bbox {
	// ByConfBbox orders ascending; NMS must visit the most confident first.
	sort.Sort(sort.Reverse(ByConfBbox(boxes)))

	kept := 0
	for index := range boxes {
		drop := false
		for prev := 0; prev < kept; prev++ {
			if Iou(boxes[prev], boxes[index]) > nmsThreshold {
				drop = true
				break
			}
		}
		if !drop {
			// Compact the survivors to the front of the slice.
			boxes[kept], boxes[index] = boxes[index], boxes[kept]
			kept++
		}
	}

	// Everything from kept onwards was suppressed.
	return boxes[:kept]
}
func init() {
flag.StringVar(&model, "model", "../../data/yolo/yolo-v3.pt", "Yolo model weights file")
flag.StringVar(&image, "image", "../../data/yolo/bondi.jpg", "image file to infer")
@ -63,26 +237,43 @@ func main() {
log.Fatal(err)
}
fmt.Println("Image file loaded")
fmt.Printf("Image shape: %v\n", originalImage.MustSize())
netHeight := darknet.Height()
netWidth := darknet.Width()
fmt.Printf("net Height: %v\n", netHeight)
fmt.Printf("net Width: %v\n", netWidth)
imgClone := originalImage.MustShallowClone().MustDetach()
imageTs, err := vision.Resize(originalImage, netWidth, netHeight)
imageTs, err := vision.Resize(imgClone, netWidth, netHeight)
if err != nil {
log.Fatal(err)
}
fmt.Printf("imageTs shape: %v\n", imageTs.MustSize())
imgTmp1 := imageTs.MustUnsqueeze(0, true)
imgTmp2 := imgTmp1.MustTotype(gotch.Float, true)
img := imgTmp2.MustDiv1(ts.FloatScalar(255.0), true)
predictTmp := model.ForwardT(img, false)
// predictions := predictTmp.MustSqueeze(true)
fmt.Printf("predictTmp shape: %v\n", predictTmp.MustSize())
predictions := predictTmp.MustSqueeze(true)
imgRes := report(predictions, originalImage, netWidth, netHeight)
err = vision.Save(imgRes, "image_result.jpg")
if err != nil {
log.Fatal(err)
}
}
// max returns the larger of v1 and v2 (v2 when they are equal).
func max(v1, v2 int64) (retVal int64) {
	retVal = v2
	if v1 > v2 {
		retVal = v1
	}
	return retVal
}
// min returns the smaller of v1 and v2 (v2 when they are equal).
func min(v1, v2 int64) (retVal int64) {
	retVal = v2
	if v1 < v2 {
		retVal = v1
	}
	return retVal
}

View File

@ -731,3 +731,8 @@ func AtgTranspose(ptr *Ctensor, self Ctensor, dim0 int64, dim1 int64) {
C.atg_transpose(ptr, self, cdim0, cdim1)
}
// AtgSqueeze is the Go binding for the C function:
//
// void atg_squeeze(tensor *, tensor self);
//
// It forwards both arguments unchanged. ptr is the output slot: callers read
// the resulting tensor handle from *ptr after the call, so it must point to
// writable memory large enough for one Ctensor. self is the input tensor
// handle. The binding itself reports no error (presumably failures surface
// through the library's separate error channel - confirm with TorchErr).
func AtgSqueeze(ptr *Ctensor, self Ctensor) {
C.atg_squeeze(ptr, self)
}

View File

@ -2242,3 +2242,29 @@ func (ts Tensor) MustTranspose(dim0, dim1 int64, del bool) (retVal Tensor) {
return retVal
}
// Squeeze calls lib.AtgSqueeze on ts and wraps the resulting handle in a new
// Tensor (presumably the input with its size-1 dimensions removed, as
// torch.squeeze does - confirm against the C library).
//
// When del is true, ts is dropped before Squeeze returns, on both the success
// and the error path. On failure the error reported by TorchErr is returned
// together with the zero Tensor.
func (ts Tensor) Squeeze(del bool) (retVal Tensor, err error) {
	if del {
		defer ts.MustDrop()
	}

	// The C side stores the result handle through ptr, so the slot must be
	// large enough for one Ctensor; a zero-byte allocation is not. The slot
	// is only a transfer buffer: the handle is copied out of it below, so it
	// is freed on every return path instead of being leaked.
	var handle lib.Ctensor
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(C.size_t(unsafe.Sizeof(handle)))))
	defer C.free(unsafe.Pointer(ptr))

	lib.AtgSqueeze(ptr, ts.ctensor)
	if err = TorchErr(); err != nil {
		return retVal, err
	}

	// *ptr is read here, before the deferred free runs.
	retVal = Tensor{ctensor: *ptr}

	return retVal, nil
}
// MustSqueeze is the no-error variant of Squeeze: it forwards del unchanged
// and terminates the program through log.Fatal if Squeeze reports an error.
func (ts Tensor) MustSqueeze(del bool) (retVal Tensor) {
	var err error
	if retVal, err = ts.Squeeze(del); err != nil {
		log.Fatal(err)
	}

	return retVal
}