WIP(example/yolo)
This commit is contained in:
parent
f9f71c7163
commit
530acf0894
84
example/yolo/coco-classes.go
Normal file
84
example/yolo/coco-classes.go
Normal file
|
@ -0,0 +1,84 @@
|
|||
package main
|
||||
|
||||
// CocoClasses lists the 80 class labels used by this example. The position of
// a label in the slice is its class index: report prints
// CocoClasses[classIndex] for every detected box.
var CocoClasses = []string{
	"person",
	"bicycle",
	"car",
	"motorbike",
	"aeroplane",
	"bus",
	"train",
	"truck",
	"boat",
	"traffic light",
	"fire hydrant",
	"stop sign",
	"parking meter",
	"bench",
	"bird",
	"cat",
	"dog",
	"horse",
	"sheep",
	"cow",
	"elephant",
	"bear",
	"zebra",
	"giraffe",
	"backpack",
	"umbrella",
	"handbag",
	"tie",
	"suitcase",
	"frisbee",
	"skis",
	"snowboard",
	"sports ball",
	"kite",
	"baseball bat",
	"baseball glove",
	"skateboard",
	"surfboard",
	"tennis racket",
	"bottle",
	"wine glass",
	"cup",
	"fork",
	"knife",
	"spoon",
	"bowl",
	"banana",
	"apple",
	"sandwich",
	"orange",
	"broccoli",
	"carrot",
	"hot dog",
	"pizza",
	"donut",
	"cake",
	"chair",
	"sofa",
	"pottedplant",
	"bed",
	"diningtable",
	"toilet",
	"tvmonitor",
	"laptop",
	"mouse",
	"remote",
	"keyboard",
	"cell phone",
	"microwave",
	"oven",
	"toaster",
	"sink",
	"refrigerator",
	"book",
	"clock",
	"vase",
	"scissors",
	"teddy bear",
	"hair drier",
	"toothbrush",
}
|
|
@ -362,8 +362,7 @@ func sliceApplyAndSet(xs ts.Tensor, start int64, len int64, f func(ts.Tensor) ts
|
|||
|
||||
slice.Copy_(src)
|
||||
src.MustDrop()
|
||||
// TODO: check whether we need to delete slice to prevent memory blow-up
|
||||
// slice.MustDrop()
|
||||
slice.MustDrop()
|
||||
}
|
||||
|
||||
func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (retVal ts.Tensor) {
|
||||
|
@ -386,35 +385,44 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (r
|
|||
// fmt.Printf("gridSize: %v\n", gridSize)
|
||||
|
||||
tmp1 := xs.MustView([]int64{bsize, bboxAttrs * nanchors, gridSize * gridSize}, false)
|
||||
|
||||
tmp2 := tmp1.MustTranspose(1, 2, true)
|
||||
|
||||
tmp3 := tmp2.MustContiguous(true)
|
||||
|
||||
xsTs := tmp3.MustView([]int64{bsize, gridSize * gridSize * nanchors, bboxAttrs}, true)
|
||||
|
||||
grid := ts.MustArange(ts.IntScalar(gridSize), gotch.Float, gotch.CPU)
|
||||
a := grid.MustRepeat([]int64{gridSize, 1}, false)
|
||||
bTmp := a.MustT(false)
|
||||
b := bTmp.MustContiguous(true)
|
||||
|
||||
xOffset := a.MustView([]int64{-1, 1}, false)
|
||||
yOffset := b.MustView([]int64{-1, 1}, false)
|
||||
|
||||
xyOffsetTmp1 := ts.MustCat([]ts.Tensor{xOffset, yOffset}, 1, false)
|
||||
xyOffsetTmp2 := xyOffsetTmp1.MustRepeat([]int64{1, nanchors}, true)
|
||||
xyOffsetTmp3 := xyOffsetTmp2.MustView([]int64{-1, 2}, true)
|
||||
xyOffset := xyOffsetTmp3.MustUnsqueeze(0, true)
|
||||
|
||||
fmt.Printf("xyOffset shape: %v\n", xyOffset.MustSize())
|
||||
|
||||
var flatAnchors []int64
|
||||
for _, a := range anchors {
|
||||
flatAnchors = append(flatAnchors, a...)
|
||||
}
|
||||
|
||||
anchorsTmp1 := ts.MustOfSlice(flatAnchors)
|
||||
var anchorVals []float32
|
||||
for _, a := range flatAnchors {
|
||||
v := float32(a) / float32(stride)
|
||||
anchorVals = append(anchorVals, v)
|
||||
}
|
||||
|
||||
fmt.Printf("anchors: %v\n", anchorVals)
|
||||
|
||||
anchorsTmp1 := ts.MustOfSlice(anchorVals)
|
||||
anchorsTmp2 := anchorsTmp1.MustView([]int64{-1, 2}, true)
|
||||
anchorsTmp3 := anchorsTmp2.MustRepeat([]int64{gridSize * gridSize, 1}, true)
|
||||
anchorsTs := anchorsTmp3.MustUnsqueeze(0, true)
|
||||
|
||||
fmt.Printf("anchors ts shape: %v\n", anchorsTs.MustSize())
|
||||
|
||||
sliceApplyAndSet(xsTs, 0, 2, func(xs ts.Tensor) (res ts.Tensor) {
|
||||
tmp := xs.MustSigmoid(false)
|
||||
res = tmp.MustAdd(xyOffset, true)
|
||||
|
@ -436,7 +444,6 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (r
|
|||
})
|
||||
|
||||
// TODO: delete all middle tensors.
|
||||
|
||||
return xsTs
|
||||
}
|
||||
|
||||
|
|
|
@ -8,16 +8,190 @@ import (
|
|||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
"log"
|
||||
"math"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// Tunables for the YOLO v3 example. configName used to be declared a second
// time as a stand-alone constant; only the grouped, typed declaration is kept
// so the identifier is declared exactly once.
const (
	// configName is the file name of the network configuration
	// (presumably resolved against the model directory by main, which is not
	// fully visible here).
	configName string = "yolo-v3.cfg"
	// confidenceThreshold is the confidence a prediction must exceed to be
	// considered by report.
	confidenceThreshold float64 = 0.5
	// nmsThreshold is the IoU above which report treats two boxes of the same
	// class as duplicates during non-maximum suppression.
	nmsThreshold float64 = 0.4
)
|
||||
|
||||
// Destinations for the command-line flags registered in init: model receives
// the value of -model and image the value of -image.
var (
	model, image string
)
|
||||
|
||||
// Bbox is one detected bounding box together with its scores.
type Bbox struct {
	// Corner coordinates of the box: (xmin, ymin) and (xmax, ymax).
	xmin, ymin float64
	xmax, ymax float64
	// confidence is the detection confidence of the box.
	confidence float64
	// classIndex is the index of the best-scoring class and
	// classConfidence is that class's score.
	classIndex      uint
	classConfidence float64
}
|
||||
|
||||
type ByConfBbox []Bbox
|
||||
|
||||
// Implement sort.Interface for []Bbox on Bbox.confidence:
|
||||
// =====================================================
|
||||
func (bb ByConfBbox) Len() int { return len(bb) }
|
||||
func (bb ByConfBbox) Less(i, j int) bool { return bb[i].confidence < bb[j].confidence }
|
||||
func (bb ByConfBbox) Swap(i, j int) { bb[i], bb[j] = bb[j], bb[i] }
|
||||
|
||||
// Intersection over union of two bounding boxes.
|
||||
func Iou(b1, b2 Bbox) (retVal float64) {
|
||||
b1Area := (b1.xmax - b1.xmin + 1.0) * (b1.ymax - b1.ymin + 1.0)
|
||||
b2Area := (b2.xmax - b2.xmin + 1.0) * (b2.ymax - b2.ymin + 1.0)
|
||||
|
||||
iXmin := math.Max(b1.xmin, b2.xmin)
|
||||
iXmax := math.Min(b1.xmax, b2.xmax)
|
||||
iYmin := math.Max(b1.ymin, b2.ymin)
|
||||
iYmax := math.Min(b1.ymax, b2.ymax)
|
||||
iArea := math.Max((iXmax-iXmin+1.0), 0.0) * math.Max((iYmax-iYmin+1.0), 0)
|
||||
|
||||
return (iArea) / (b1Area + b2Area - iArea)
|
||||
}
|
||||
|
||||
// Assuming x1 <= x2 and y1 <= y2
|
||||
func drawRect(t ts.Tensor, x1, x2, y1, y2 int64) {
|
||||
color := ts.MustOfSlice([]float64{0.0, 0.0, 1.0}).MustView([]int64{3, 1, 1}, true)
|
||||
|
||||
tmp1 := t.MustNarrow(2, x1, x2-x1, false)
|
||||
tmp2 := tmp1.MustNarrow(1, y1, y2-y1, true)
|
||||
tmp2.Copy_(color)
|
||||
tmp2.MustDrop()
|
||||
color.MustDrop()
|
||||
}
|
||||
|
||||
func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor) {
|
||||
size2, err := pred.Size2()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
npreds := size2[0]
|
||||
predSize := size2[1]
|
||||
|
||||
fmt.Printf("npreds: %v - predSize: %v\n", npreds, predSize)
|
||||
|
||||
nclasses := uint(predSize - 5)
|
||||
|
||||
// The bounding boxes grouped by (maximum) class index.
|
||||
var bboxes [][]Bbox = make([][]Bbox, int(nclasses))
|
||||
|
||||
// Extract the bounding boxes for which confidence is above the threshold.
|
||||
for index := 0; index < int(npreds); index++ {
|
||||
predIdx := pred.MustGet(index)
|
||||
var predVals []float64 = predIdx.Values()
|
||||
confidence := predVals[4]
|
||||
if confidence > confidenceThreshold {
|
||||
classIndex := 0
|
||||
for i := 0; i < int(nclasses); i++ {
|
||||
if predVals[5+i] > predVals[5+classIndex] {
|
||||
classIndex = i
|
||||
}
|
||||
}
|
||||
|
||||
if predVals[classIndex+5] > 0.0 {
|
||||
bbox := Bbox{
|
||||
xmin: predVals[0] - (predVals[2] / 2.0),
|
||||
ymin: predVals[1] - (predVals[3] / 2.0),
|
||||
xmax: predVals[0] + (predVals[2] / 2.0),
|
||||
ymax: predVals[1] + (predVals[3] / 2.0),
|
||||
// xmin: (predVals[0] - predVals[2]) / 2.0,
|
||||
// ymin: (predVals[1] - predVals[3]) / 2.0,
|
||||
// xmax: (predVals[0] + predVals[2]) / 2.0,
|
||||
// ymax: (predVals[1] + predVals[3]) / 2.0,
|
||||
confidence: confidence,
|
||||
classIndex: uint(classIndex),
|
||||
classConfidence: predVals[5+classIndex],
|
||||
}
|
||||
|
||||
bboxes[classIndex] = append(bboxes[classIndex], bbox)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Perform non-maximum suppression.
|
||||
var bboxesRes [][]Bbox
|
||||
fmt.Printf("Num of bboxes: %v\n", len(bboxes))
|
||||
for _, bboxesForClass := range bboxes {
|
||||
// 1. Sort by confidence
|
||||
sort.Sort(ByConfBbox(bboxesForClass))
|
||||
|
||||
// 2.
|
||||
var currentIndex = 0
|
||||
for index := 0; index < len(bboxesForClass); index++ {
|
||||
drop := false
|
||||
for predIndex := 0; predIndex < currentIndex; predIndex++ {
|
||||
iou := Iou(bboxesForClass[predIndex], bboxesForClass[index])
|
||||
if iou > nmsThreshold {
|
||||
drop = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !drop {
|
||||
// swap
|
||||
bboxesForClass[currentIndex], bboxesForClass[index] = bboxesForClass[index], bboxesForClass[currentIndex]
|
||||
currentIndex += 1
|
||||
}
|
||||
}
|
||||
// 3. Truncate currentIndex
|
||||
if len(bboxesForClass) > 1 {
|
||||
bboxesForClass = append(bboxesForClass[:currentIndex], bboxesForClass[currentIndex+1:]...)
|
||||
}
|
||||
bboxesRes = append(bboxesRes, bboxesForClass)
|
||||
}
|
||||
|
||||
fmt.Printf("Num of bboxesRes: %v\n", len(bboxesRes))
|
||||
|
||||
// Annotate the original image and print boxes information.
|
||||
fmt.Printf("img shape: %v\n", img.MustSize())
|
||||
size3, err := img.Size3()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
initialH := size3[1]
|
||||
initialW := size3[2]
|
||||
|
||||
imageTmp := img.MustTotype(gotch.Float, false)
|
||||
image := imageTmp.MustDiv1(ts.FloatScalar(255.0), true)
|
||||
|
||||
var wRatio float64 = float64(initialW) / float64(w)
|
||||
var hRatio float64 = float64(initialH) / float64(h)
|
||||
|
||||
// fmt.Printf("wRatio: %v\n", wRatio)
|
||||
// fmt.Printf("hRatio: %v\n", hRatio)
|
||||
|
||||
for classIndex, bboxesForClass := range bboxesRes {
|
||||
for _, b := range bboxesForClass {
|
||||
fmt.Printf("%v: %v\n", CocoClasses[classIndex], b)
|
||||
|
||||
xmin := min(max(int64(b.xmin*wRatio), 0), (initialW - 1))
|
||||
ymin := min(max(int64(b.ymin*hRatio), 0), (initialH - 1))
|
||||
xmax := min(max(int64(b.xmax*wRatio), 0), (initialW - 1))
|
||||
ymax := min(max(int64(b.ymax*hRatio), 0), (initialH - 1))
|
||||
|
||||
// fmt.Printf("xmin: %v\t ymin: %v\t xmax: %v\t ymax: %v\n", xmin, ymin, xmax, ymax)
|
||||
|
||||
// draw rect
|
||||
drawRect(img, xmin, xmax, ymin, int64(math.Min(float64(ymax), float64(ymin+2))))
|
||||
drawRect(img, xmin, xmax, int64(math.Max(float64(ymin), float64(ymax-2))), ymax)
|
||||
drawRect(img, xmin, int64(math.Min(float64(xmax), float64(xmin+2))), ymin, ymax)
|
||||
drawRect(img, int64(math.Max(float64(xmax-2), float64(xmin))), xmax, ymin, ymax)
|
||||
}
|
||||
}
|
||||
|
||||
imgTmp := image.MustMul1(ts.FloatScalar(255.0), false)
|
||||
retVal = imgTmp.MustTotype(gotch.Uint8, false)
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&model, "model", "../../data/yolo/yolo-v3.pt", "Yolo model weights file")
|
||||
flag.StringVar(&image, "image", "../../data/yolo/bondi.jpg", "image file to infer")
|
||||
|
@ -63,26 +237,43 @@ func main() {
|
|||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println("Image file loaded")
|
||||
fmt.Printf("Image shape: %v\n", originalImage.MustSize())
|
||||
|
||||
netHeight := darknet.Height()
|
||||
netWidth := darknet.Width()
|
||||
|
||||
fmt.Printf("net Height: %v\n", netHeight)
|
||||
fmt.Printf("net Width: %v\n", netWidth)
|
||||
imgClone := originalImage.MustShallowClone().MustDetach()
|
||||
|
||||
imageTs, err := vision.Resize(originalImage, netWidth, netHeight)
|
||||
imageTs, err := vision.Resize(imgClone, netWidth, netHeight)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("imageTs shape: %v\n", imageTs.MustSize())
|
||||
|
||||
imgTmp1 := imageTs.MustUnsqueeze(0, true)
|
||||
imgTmp2 := imgTmp1.MustTotype(gotch.Float, true)
|
||||
img := imgTmp2.MustDiv1(ts.FloatScalar(255.0), true)
|
||||
predictTmp := model.ForwardT(img, false)
|
||||
// predictions := predictTmp.MustSqueeze(true)
|
||||
fmt.Printf("predictTmp shape: %v\n", predictTmp.MustSize())
|
||||
predictions := predictTmp.MustSqueeze(true)
|
||||
|
||||
imgRes := report(predictions, originalImage, netWidth, netHeight)
|
||||
|
||||
err = vision.Save(imgRes, "image_result.jpg")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// max returns the larger of v1 and v2; when they are equal, that common value
// is returned.
func max(v1, v2 int64) (retVal int64) {
	retVal = v2
	if v1 > v2 {
		retVal = v1
	}

	return retVal
}
|
||||
|
||||
// min returns the smaller of v1 and v2; when they are equal, that common value
// is returned.
func min(v1, v2 int64) (retVal int64) {
	retVal = v2
	if v1 < v2 {
		retVal = v1
	}

	return retVal
}
|
||||
|
|
|
@ -731,3 +731,8 @@ func AtgTranspose(ptr *Ctensor, self Ctensor, dim0 int64, dim1 int64) {
|
|||
|
||||
C.atg_transpose(ptr, self, cdim0, cdim1)
|
||||
}
|
||||
|
||||
// AtgSqueeze forwards ptr and self unchanged to the C function atg_squeeze,
// whose C prototype is given below. It returns nothing.
// NOTE(review): presumably the C side stores the resulting tensor handle in
// *ptr (Tensor.Squeeze reads *ptr after the call) — confirm against the C API.
//
// void atg_squeeze(tensor *, tensor self);
|
||||
func AtgSqueeze(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_squeeze(ptr, self)
|
||||
}
|
||||
|
|
|
@ -2242,3 +2242,29 @@ func (ts Tensor) MustTranspose(dim0, dim1 int64, del bool) (retVal Tensor) {
|
|||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Squeeze(del bool) (retVal Tensor, err error) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
lib.AtgSqueeze(ptr, ts.ctensor)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustSqueeze(del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Squeeze(del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user