312 lines
7.5 KiB
Go
312 lines
7.5 KiB
Go
package main
|
|
|
|
import (
|
|
"flag"
|
|
"fmt"
|
|
"log"
|
|
"math"
|
|
"path/filepath"
|
|
"sort"
|
|
|
|
"git.andr3h3nriqu3s.com/andr3/gotch"
|
|
"git.andr3h3nriqu3s.com/andr3/gotch/nn"
|
|
"git.andr3h3nriqu3s.com/andr3/gotch/ts"
|
|
"git.andr3h3nriqu3s.com/andr3/gotch/vision"
|
|
)
|
|
|
|
// Fixed settings of the YOLO example.
const (
	// confidenceThreshold is the value a prediction's confidence must
	// exceed for its box to be considered at all.
	confidenceThreshold float64 = 0.5
	// nmsThreshold is the intersection-over-union value above which a box is
	// dropped in favour of an already accepted box of the same class.
	nmsThreshold float64 = 0.4

	// saveDir is the directory the annotated image is written to.
	saveDir string = "../../data/yolo"
	// configName is the path of the network configuration file.
	configName string = "yolo-v3.cfg"
)
|
|
|
|
// Command-line settings, filled in from the -model and -image flags:
// model is the weights file to load and imageFile is the picture to run
// detection on.
var model, imageFile string
|
|
|
|
// Bbox is a single detected bounding box together with its scores.
type Bbox struct {
	// Corner coordinates of the box: (xmin, ymin) and (xmax, ymax).
	xmin, ymin float64
	xmax, ymax float64

	// confidence is the detection confidence of the box.
	confidence float64
	// classIndex is the index of the best-scoring class and
	// classConfidence is the score of that class.
	classIndex      uint
	classConfidence float64
}
|
|
|
|
type ByConfBbox []Bbox
|
|
|
|
// Implement sort.Interface for []Bbox on Bbox.confidence:
|
|
// =====================================================
|
|
func (bb ByConfBbox) Len() int { return len(bb) }
|
|
func (bb ByConfBbox) Less(i, j int) bool { return bb[i].confidence < bb[j].confidence }
|
|
func (bb ByConfBbox) Swap(i, j int) { bb[i], bb[j] = bb[j], bb[i] }
|
|
|
|
// Intersection over union of two bounding boxes.
|
|
func Iou(b1, b2 Bbox) (retVal float64) {
|
|
b1Area := (b1.xmax - b1.xmin + 1.0) * (b1.ymax - b1.ymin + 1.0)
|
|
b2Area := (b2.xmax - b2.xmin + 1.0) * (b2.ymax - b2.ymin + 1.0)
|
|
|
|
iXmin := math.Max(b1.xmin, b2.xmin)
|
|
iXmax := math.Min(b1.xmax, b2.xmax)
|
|
iYmin := math.Max(b1.ymin, b2.ymin)
|
|
iYmax := math.Min(b1.ymax, b2.ymax)
|
|
|
|
iArea := math.Max((iXmax-iXmin+1.0), 0.0) * math.Max((iYmax-iYmin+1.0), 0)
|
|
|
|
return (iArea) / (b1Area + b2Area - iArea)
|
|
}
|
|
|
|
// Assuming x1 <= x2 and y1 <= y2
|
|
func drawRect(t *ts.Tensor, x1, x2, y1, y2 int64) {
|
|
color := ts.MustOfSlice([]float64{0.0, 0.0, 1.0}).MustView([]int64{3, 1, 1}, true)
|
|
|
|
// NOTE: `narrow` will create a tensor (view) that share same storage with
|
|
// original one.
|
|
tmp1 := t.MustNarrow(2, x1, x2-x1, false)
|
|
tmp2 := tmp1.MustNarrow(1, y1, y2-y1, true)
|
|
tmp2.Copy_(color)
|
|
tmp2.MustDrop()
|
|
color.MustDrop()
|
|
}
|
|
|
|
func drawLabel(t *ts.Tensor, text []string, x, y int64) {
|
|
device, err := t.Device()
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
label := textToImageTs(text).MustTo(device, true)
|
|
|
|
labelSize := label.MustSize()
|
|
height := labelSize[1]
|
|
width := labelSize[2]
|
|
|
|
imageSize := t.MustSize()
|
|
lenY := height
|
|
if lenY > imageSize[1] {
|
|
lenY = imageSize[1] - y
|
|
}
|
|
|
|
lenX := width
|
|
if lenX > imageSize[2] {
|
|
lenX = imageSize[2] - x
|
|
}
|
|
|
|
// NOTE: `narrow` will create a tensor (view) that share same storage with
|
|
// original one.
|
|
|
|
tmp1 := t.MustNarrow(2, x, lenX, false)
|
|
tmp2 := tmp1.MustNarrow(1, y, lenY, true)
|
|
tmp2.Copy_(label)
|
|
tmp2.MustDrop()
|
|
label.MustDrop()
|
|
}
|
|
|
|
func report(pred *ts.Tensor, img *ts.Tensor, w int64, h int64) *ts.Tensor {
|
|
size2, err := pred.Size2()
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
npreds := size2[0]
|
|
predSize := size2[1]
|
|
|
|
nclasses := uint(predSize - 5)
|
|
|
|
// The bounding boxes grouped by (maximum) class index.
|
|
var bboxes [][]Bbox = make([][]Bbox, int(nclasses))
|
|
|
|
// Extract the bounding boxes for which confidence is above the threshold.
|
|
for index := 0; index < int(npreds); index++ {
|
|
predIdx := pred.MustGet(index)
|
|
var predVals []float64 = predIdx.Float64Values()
|
|
predIdx.MustDrop()
|
|
|
|
confidence := predVals[4]
|
|
if confidence > confidenceThreshold {
|
|
classIndex := 0
|
|
for i := 0; i < int(nclasses); i++ {
|
|
if predVals[5+i] > predVals[5+classIndex] {
|
|
classIndex = i
|
|
}
|
|
}
|
|
|
|
if predVals[classIndex+5] > 0.0 {
|
|
bbox := Bbox{
|
|
xmin: predVals[0] - (predVals[2] / 2.0),
|
|
ymin: predVals[1] - (predVals[3] / 2.0),
|
|
xmax: predVals[0] + (predVals[2] / 2.0),
|
|
ymax: predVals[1] + (predVals[3] / 2.0),
|
|
confidence: confidence,
|
|
classIndex: uint(classIndex),
|
|
classConfidence: predVals[5+classIndex],
|
|
}
|
|
|
|
bboxes[classIndex] = append(bboxes[classIndex], bbox)
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
// Perform non-maximum suppression.
|
|
var bboxesRes [][]Bbox
|
|
for _, bboxesForClass := range bboxes {
|
|
// 1. Sort by confidence
|
|
sort.Sort(ByConfBbox(bboxesForClass))
|
|
|
|
// 2.
|
|
var currentIndex = 0
|
|
for index := 0; index < len(bboxesForClass); index++ {
|
|
drop := false
|
|
for predIndex := 0; predIndex < currentIndex; predIndex++ {
|
|
iou := Iou(bboxesForClass[predIndex], bboxesForClass[index])
|
|
if iou > nmsThreshold {
|
|
drop = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !drop {
|
|
// swap
|
|
bboxesForClass[currentIndex], bboxesForClass[index] = bboxesForClass[index], bboxesForClass[currentIndex]
|
|
currentIndex += 1
|
|
}
|
|
}
|
|
// 3. Truncate at currentIndex (exclusive)
|
|
if currentIndex < len(bboxesForClass) {
|
|
bboxesForClass = append(bboxesForClass[:currentIndex])
|
|
}
|
|
|
|
bboxesRes = append(bboxesRes, bboxesForClass)
|
|
}
|
|
|
|
// Annotate the original image and print boxes information.
|
|
size3, err := img.Size3()
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
initialH := size3[1]
|
|
initialW := size3[2]
|
|
|
|
imageTmp := img.MustTotype(gotch.Float, false)
|
|
image := imageTmp.MustDiv1(ts.FloatScalar(255.0), true)
|
|
|
|
var wRatio float64 = float64(initialW) / float64(w)
|
|
var hRatio float64 = float64(initialH) / float64(h)
|
|
|
|
for classIndex, bboxesForClass := range bboxesRes {
|
|
for _, b := range bboxesForClass {
|
|
fmt.Printf("%v: %v\n", CocoClasses[classIndex], b)
|
|
|
|
xmin := min(max(int64(b.xmin*wRatio), 0), (initialW - 1))
|
|
ymin := min(max(int64(b.ymin*hRatio), 0), (initialH - 1))
|
|
xmax := min(max(int64(b.xmax*wRatio), 0), (initialW - 1))
|
|
ymax := min(max(int64(b.ymax*hRatio), 0), (initialH - 1))
|
|
|
|
// draw rect
|
|
drawRect(image, xmin, xmax, ymin, min(ymax, ymin+2))
|
|
drawRect(image, xmin, xmax, max(ymin, ymax-2), ymax)
|
|
drawRect(image, xmin, min(xmax, xmin+2), ymin, ymax)
|
|
drawRect(image, max(xmin, xmax-2), xmax, ymin, ymax)
|
|
|
|
label := fmt.Sprintf("%v; %.3f\n", CocoClasses[classIndex], b.confidence)
|
|
drawLabel(image, []string{label}, xmin, ymin-15)
|
|
}
|
|
}
|
|
|
|
imgTmp := image.MustMul1(ts.FloatScalar(255.0), true)
|
|
retVal := imgTmp.MustTotype(gotch.Uint8, true)
|
|
|
|
return retVal
|
|
}
|
|
|
|
func init() {
|
|
flag.StringVar(&model, "model", "../../data/yolo/yolo-v3.pt", "Yolo model weights file")
|
|
flag.StringVar(&imageFile, "image", "../../data/yolo/bondi.jpg", "image file to infer")
|
|
}
|
|
|
|
func main() {
|
|
|
|
flag.Parse()
|
|
configPath, err := filepath.Abs(configName)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
modelPath, err := filepath.Abs(model)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
imagePath, err := filepath.Abs(imageFile)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
var darknet *Darknet = ParseConfig(configPath)
|
|
|
|
vs := nn.NewVarStore(gotch.CPU)
|
|
model := darknet.BuildModel(vs.Root())
|
|
|
|
err = vs.Load(modelPath)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
fmt.Println("Yolo weights loaded.")
|
|
|
|
originalImage, err := vision.Load(imagePath)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
fmt.Println("Image file loaded")
|
|
|
|
netHeight := darknet.Height()
|
|
netWidth := darknet.Width()
|
|
|
|
imgClone := originalImage.MustShallowClone().MustDetach(false)
|
|
|
|
imageTs, err := vision.Resize(imgClone, netWidth, netHeight)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
imgTmp1 := imageTs.MustUnsqueeze(0, true)
|
|
imgTmp2 := imgTmp1.MustTotype(gotch.Float, true)
|
|
img := imgTmp2.MustDivScalar(ts.FloatScalar(255.0), true)
|
|
predictTmp := model.ForwardT(img, false)
|
|
|
|
predictions := predictTmp.MustSqueeze(true)
|
|
|
|
imgRes := report(predictions, originalImage, netWidth, netHeight)
|
|
|
|
savePath, err := filepath.Abs(saveDir)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
inputFile := filepath.Base(imagePath)
|
|
saveFile := fmt.Sprintf("%v/yolo_%v", savePath, inputFile)
|
|
err = vision.Save(imgRes, saveFile)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
// max returns the larger of v1 and v2.
func max(v1, v2 int64) (retVal int64) {
	retVal = v2
	if v1 > v2 {
		retVal = v1
	}
	return retVal
}
|
|
|
|
// min returns the smaller of v1 and v2.
func min(v1, v2 int64) (retVal int64) {
	retVal = v2
	if v1 < v2 {
		retVal = v1
	}
	return retVal
}
|