
This commit is contained in:
sugarme 2020-07-15 13:42:40 +10:00
parent f9f71c7163
commit 530acf0894
5 changed files with 330 additions and 17 deletions

View File

@ -0,0 +1,84 @@
package main
// CocoClasses lists the 80 COCO object-category names in Darknet
// `coco.names` order. The position of a name in the slice is the class index
// used to label a detection, so the order must not be changed.
var CocoClasses []string = []string{
	"person",
	"bicycle",
	"car",
	"motorbike",
	"aeroplane",
	"bus",
	"train",
	"truck",
	"boat",
	"traffic light",
	"fire hydrant",
	"stop sign",
	"parking meter",
	"bench",
	"bird",
	"cat",
	"dog",
	"horse",
	"sheep",
	"cow",
	"elephant",
	"bear",
	"zebra",
	"giraffe",
	"backpack",
	"umbrella",
	"handbag",
	"tie",
	"suitcase",
	"frisbee",
	"skis",
	"snowboard",
	"sports ball",
	"kite",
	"baseball bat",
	"baseball glove",
	"skateboard",
	"surfboard",
	"tennis racket",
	"bottle",
	"wine glass",
	"cup",
	"fork",
	"knife",
	"spoon",
	"bowl",
	"banana",
	"apple",
	"sandwich",
	"orange",
	"broccoli",
	"carrot",
	"hot dog",
	"pizza",
	"donut",
	"cake",
	"chair",
	"sofa",
	"pottedplant",
	"bed",
	"diningtable",
	"toilet",
	"tvmonitor",
	"laptop",
	"mouse",
	"remote",
	"keyboard",
	"cell phone",
	"microwave",
	"oven",
	"toaster",
	"sink",
	"refrigerator",
	"book",
	"clock",
	"vase",
	"scissors",
	"teddy bear",
	"hair drier",
	"toothbrush",
}

View File

@ -362,8 +362,7 @@ func sliceApplyAndSet(xs ts.Tensor, start int64, len int64, f func(ts.Tensor) ts
// TODO: check whether we need to delete slice to prevent memory blow-up
// slice.MustDrop()
func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (retVal ts.Tensor) {
@ -386,35 +385,44 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (r
// fmt.Printf("gridSize: %v\n", gridSize)
tmp1 := xs.MustView([]int64{bsize, bboxAttrs * nanchors, gridSize * gridSize}, false)
tmp2 := tmp1.MustTranspose(1, 2, true)
tmp3 := tmp2.MustContiguous(true)
xsTs := tmp3.MustView([]int64{bsize, gridSize * gridSize * nanchors, bboxAttrs}, true)
grid := ts.MustArange(ts.IntScalar(gridSize), gotch.Float, gotch.CPU)
a := grid.MustRepeat([]int64{gridSize, 1}, false)
bTmp := a.MustT(false)
b := bTmp.MustContiguous(true)
xOffset := a.MustView([]int64{-1, 1}, false)
yOffset := b.MustView([]int64{-1, 1}, false)
xyOffsetTmp1 := ts.MustCat([]ts.Tensor{xOffset, yOffset}, 1, false)
xyOffsetTmp2 := xyOffsetTmp1.MustRepeat([]int64{1, nanchors}, true)
xyOffsetTmp3 := xyOffsetTmp2.MustView([]int64{-1, 2}, true)
xyOffset := xyOffsetTmp3.MustUnsqueeze(0, true)
fmt.Printf("xyOffset shape: %v\n", xyOffset.MustSize())
var flatAnchors []int64
for _, a := range anchors {
flatAnchors = append(flatAnchors, a...)
anchorsTmp1 := ts.MustOfSlice(flatAnchors)
var anchorVals []float32
for _, a := range flatAnchors {
v := float32(a) / float32(stride)
anchorVals = append(anchorVals, v)
fmt.Printf("anchors: %v\n", anchorVals)
anchorsTmp1 := ts.MustOfSlice(anchorVals)
anchorsTmp2 := anchorsTmp1.MustView([]int64{-1, 2}, true)
anchorsTmp3 := anchorsTmp2.MustRepeat([]int64{gridSize * gridSize, 1}, true)
anchorsTs := anchorsTmp3.MustUnsqueeze(0, true)
fmt.Printf("anchors ts shape: %v\n", anchorsTs.MustSize())
sliceApplyAndSet(xsTs, 0, 2, func(xs ts.Tensor) (res ts.Tensor) {
tmp := xs.MustSigmoid(false)
res = tmp.MustAdd(xyOffset, true)
@ -436,7 +444,6 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (r
// TODO: delete all middle tensors.
return xsTs

View File

@ -8,16 +8,190 @@ import (
ts "github.com/sugarme/gotch/tensor"
// Detection settings for the YOLO example.
const (
	// configName is the file name of the network configuration.
	configName string = "yolo-v3.cfg"
	// confidenceThreshold is the score a prediction's confidence value must
	// strictly exceed to be kept by report.
	confidenceThreshold float64 = 0.5
	// nmsThreshold is the IoU above which a box is dropped during the
	// suppression pass in report.
	nmsThreshold float64 = 0.4
)

// Command-line settings, bound to the -model and -image flags in init.
var (
	// model is the path of the Yolo model weights file.
	model string
	// image is the path of the image file to infer.
	image string
)
// Bbox is one detected bounding box, stored as corner coordinates together
// with the scores that produced it.
type Bbox struct {
	// xmin, ymin, xmax and ymax are the box corners.
	xmin float64
	ymin float64
	xmax float64
	ymax float64
	// confidence is the prediction's overall confidence value.
	confidence float64
	// classIndex is the index of the highest-scoring class.
	classIndex uint
	// classConfidence is the score of that class.
	classConfidence float64
}
// ByConfBbox is a []Bbox that satisfies sort.Interface, ordering boxes by
// ascending confidence.
type ByConfBbox []Bbox

// Len reports how many boxes the collection holds.
func (boxes ByConfBbox) Len() int {
	return len(boxes)
}

// Less reports whether box i has a strictly lower confidence than box j.
func (boxes ByConfBbox) Less(i, j int) bool {
	return boxes[j].confidence > boxes[i].confidence
}

// Swap exchanges the boxes at positions i and j in place.
func (boxes ByConfBbox) Swap(i, j int) {
	held := boxes[i]
	boxes[i] = boxes[j]
	boxes[j] = held
}
// Intersection over union of two bounding boxes.
func Iou(b1, b2 Bbox) (retVal float64) {
b1Area := (b1.xmax - b1.xmin + 1.0) * (b1.ymax - b1.ymin + 1.0)
b2Area := (b2.xmax - b2.xmin + 1.0) * (b2.ymax - b2.ymin + 1.0)
iXmin := math.Max(b1.xmin, b2.xmin)
iXmax := math.Min(b1.xmax, b2.xmax)
iYmin := math.Max(b1.ymin, b2.ymin)
iYmax := math.Min(b1.ymax, b2.ymax)
iArea := math.Max((iXmax-iXmin+1.0), 0.0) * math.Max((iYmax-iYmin+1.0), 0)
return (iArea) / (b1Area + b2Area - iArea)
// drawRect builds a 3-element colour tensor (0, 0, 1) viewed as {3, 1, 1} and
// takes two successive narrowed views of t: first along dimension 2 with
// arguments (x1, x2-x1), then along dimension 1 with arguments (y1, y2-y1).
//
// Assuming x1 <= x2 and y1 <= y2
//
// NOTE(review): in this view of the file neither color nor tmp2 is used after
// being created, and the closing brace is absent. The statement that writes
// color into the narrowed view is presumably missing from this capture;
// confirm against the full file.
func drawRect(t ts.Tensor, x1, x2, y1, y2 int64) {
color := ts.MustOfSlice([]float64{0.0, 0.0, 1.0}).MustView([]int64{3, 1, 1}, true)
// MustNarrow arguments are presumably (dim, start, length, del) — TODO confirm.
tmp1 := t.MustNarrow(2, x1, x2-x1, false)
tmp2 := tmp1.MustNarrow(1, y1, y2-y1, true)
// report turns raw predictions into per-class bounding boxes, runs a
// suppression pass over each class, prints the surviving boxes and returns a
// uint8 tensor derived from img.
//
// pred is read as a 2-D tensor of npreds rows by predSize values. Within a row,
// values 0-3 are used as centre x, centre y, width and height, value 4 as the
// confidence, and the remaining predSize-5 values as per-class scores.
// w and h are the dimensions the box coordinates refer to; boxes are rescaled
// by initialW/w and initialH/h onto img.
//
// NOTE(review): this view of the function omits some lines (closing braces and
// the bodies of the `err != nil` checks are not visible); confirm details
// against the full file.
func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor) {
size2, err := pred.Size2()
if err != nil {
npreds := size2[0]
predSize := size2[1]
fmt.Printf("npreds: %v - predSize: %v\n", npreds, predSize)
// Everything after the first five values of a row is a class score.
nclasses := uint(predSize - 5)
// The bounding boxes grouped by (maximum) class index.
var bboxes [][]Bbox = make([][]Bbox, int(nclasses))
// Extract the bounding boxes for which confidence is above the threshold.
for index := 0; index < int(npreds); index++ {
predIdx := pred.MustGet(index)
var predVals []float64 = predIdx.Values()
confidence := predVals[4]
if confidence > confidenceThreshold {
// Arg-max over the class scores; ties keep the lowest index.
classIndex := 0
for i := 0; i < int(nclasses); i++ {
if predVals[5+i] > predVals[5+classIndex] {
classIndex = i
// Only keep the box when its best class score is positive.
if predVals[classIndex+5] > 0.0 {
bbox := Bbox{
// Convert centre/size to corner coordinates.
xmin: predVals[0] - (predVals[2] / 2.0),
ymin: predVals[1] - (predVals[3] / 2.0),
xmax: predVals[0] + (predVals[2] / 2.0),
ymax: predVals[1] + (predVals[3] / 2.0),
// xmin: (predVals[0] - predVals[2]) / 2.0,
// ymin: (predVals[1] - predVals[3]) / 2.0,
// xmax: (predVals[0] + predVals[2]) / 2.0,
// ymax: (predVals[1] + predVals[3]) / 2.0,
confidence: confidence,
classIndex: uint(classIndex),
classConfidence: predVals[5+classIndex],
bboxes[classIndex] = append(bboxes[classIndex], bbox)
// Perform non-maximum suppression.
var bboxesRes [][]Bbox
fmt.Printf("Num of bboxes: %v\n", len(bboxes))
for _, bboxesForClass := range bboxes {
// 1. Sort by confidence
// NOTE(review): no sort call is visible in this view. ByConfBbox.Less orders
// ascending, while the pass below favours whichever boxes come first — confirm
// the boxes are ordered highest-confidence first before this loop.
// 2.
// Boxes in [0, currentIndex) are the ones kept so far; each candidate is
// compared against them and dropped when its IoU exceeds nmsThreshold.
var currentIndex = 0
for index := 0; index < len(bboxesForClass); index++ {
drop := false
for predIndex := 0; predIndex < currentIndex; predIndex++ {
iou := Iou(bboxesForClass[predIndex], bboxesForClass[index])
if iou > nmsThreshold {
drop = true
if !drop {
// swap
bboxesForClass[currentIndex], bboxesForClass[index] = bboxesForClass[index], bboxesForClass[currentIndex]
currentIndex += 1
// 3. Truncate currentIndex
// NOTE(review): the append below removes only the single element at
// currentIndex instead of truncating to the kept prefix, so suppressed boxes
// after it survive. When nothing was dropped (currentIndex == len) and
// len > 1, the slice expression [currentIndex+1:] is out of range and panics.
// Likely intended: bboxesForClass = bboxesForClass[:currentIndex].
if len(bboxesForClass) > 1 {
bboxesForClass = append(bboxesForClass[:currentIndex], bboxesForClass[currentIndex+1:]...)
bboxesRes = append(bboxesRes, bboxesForClass)
fmt.Printf("Num of bboxesRes: %v\n", len(bboxesRes))
// Annotate the original image and print boxes information.
fmt.Printf("img shape: %v\n", img.MustSize())
size3, err := img.Size3()
if err != nil {
// Dimensions 1 and 2 of img are used as its height and width.
initialH := size3[1]
initialW := size3[2]
imageTmp := img.MustTotype(gotch.Float, false)
// This local `image` shadows the package-level `image` string variable.
image := imageTmp.MustDiv1(ts.FloatScalar(255.0), true)
var wRatio float64 = float64(initialW) / float64(w)
var hRatio float64 = float64(initialH) / float64(h)
// fmt.Printf("wRatio: %v\n", wRatio)
// fmt.Printf("hRatio: %v\n", hRatio)
for classIndex, bboxesForClass := range bboxesRes {
for _, b := range bboxesForClass {
fmt.Printf("%v: %v\n", CocoClasses[classIndex], b)
// Rescale to img coordinates and clamp into [0, dimension-1].
xmin := min(max(int64(b.xmin*wRatio), 0), (initialW - 1))
ymin := min(max(int64(b.ymin*hRatio), 0), (initialH - 1))
xmax := min(max(int64(b.xmax*wRatio), 0), (initialW - 1))
ymax := min(max(int64(b.ymax*hRatio), 0), (initialH - 1))
// fmt.Printf("xmin: %v\t ymin: %v\t xmax: %v\t ymax: %v\n", xmin, ymin, xmax, ymax)
// draw rect
// Four strips, at most 2 units thick: top, bottom, left, right.
// NOTE(review): the strips are drawn on `img`, but the returned tensor is
// built from `image` (the float copy made above). Confirm which tensor is
// meant to receive the rectangles.
drawRect(img, xmin, xmax, ymin, int64(math.Min(float64(ymax), float64(ymin+2))))
drawRect(img, xmin, xmax, int64(math.Max(float64(ymin), float64(ymax-2))), ymax)
drawRect(img, xmin, int64(math.Min(float64(xmax), float64(xmin+2))), ymin, ymax)
drawRect(img, int64(math.Max(float64(xmax-2), float64(xmin))), xmax, ymin, ymax)
// Scale back by 255 and convert to uint8 for the caller.
imgTmp := image.MustMul1(ts.FloatScalar(255.0), false)
retVal = imgTmp.MustTotype(gotch.Uint8, false)
return retVal
func init() {
flag.StringVar(&model, "model", "../../data/yolo/yolo-v3.pt", "Yolo model weights file")
flag.StringVar(&image, "image", "../../data/yolo/bondi.jpg", "image file to infer")
@ -63,26 +237,43 @@ func main() {
fmt.Println("Image file loaded")
fmt.Printf("Image shape: %v\n", originalImage.MustSize())
netHeight := darknet.Height()
netWidth := darknet.Width()
fmt.Printf("net Height: %v\n", netHeight)
fmt.Printf("net Width: %v\n", netWidth)
imgClone := originalImage.MustShallowClone().MustDetach()
imageTs, err := vision.Resize(originalImage, netWidth, netHeight)
imageTs, err := vision.Resize(imgClone, netWidth, netHeight)
if err != nil {
fmt.Printf("imageTs shape: %v\n", imageTs.MustSize())
imgTmp1 := imageTs.MustUnsqueeze(0, true)
imgTmp2 := imgTmp1.MustTotype(gotch.Float, true)
img := imgTmp2.MustDiv1(ts.FloatScalar(255.0), true)
predictTmp := model.ForwardT(img, false)
// predictions := predictTmp.MustSqueeze(true)
fmt.Printf("predictTmp shape: %v\n", predictTmp.MustSize())
predictions := predictTmp.MustSqueeze(true)
imgRes := report(predictions, originalImage, netWidth, netHeight)
err = vision.Save(imgRes, "image_result.jpg")
if err != nil {
// max returns the larger of v1 and v2; when they are equal it returns v2.
//
// Within this package the name shadows the generic builtin max available
// since Go 1.21.
func max(v1, v2 int64) (retVal int64) {
	if v1 > v2 {
		return v1
	}
	return v2
}
// min returns the smaller of v1 and v2; when they are equal it returns v2.
//
// Within this package the name shadows the generic builtin min available
// since Go 1.21.
func min(v1, v2 int64) (retVal int64) {
	if v1 < v2 {
		return v1
	}
	return v2
}

View File

@ -731,3 +731,8 @@ func AtgTranspose(ptr *Ctensor, self Ctensor, dim0 int64, dim1 int64) {
C.atg_transpose(ptr, self, cdim0, cdim1)
// void atg_squeeze(tensor *, tensor self);
//
// AtgSqueeze forwards ptr and self unchanged to the C function atg_squeeze.
// NOTE(review): ptr presumably receives the resulting tensor handle (the first
// parameter of the C prototype is a tensor pointer) — confirm.
func AtgSqueeze(ptr *Ctensor, self Ctensor) {
C.atg_squeeze(ptr, self)

View File

@ -2242,3 +2242,29 @@ func (ts Tensor) MustTranspose(dim0, dim1 int64, del bool) (retVal Tensor) {
return retVal
// Squeeze calls lib.AtgSqueeze with the receiver's C tensor and wraps the
// handle written through ptr in a new Tensor.
//
// When del is true a deferred MustDrop is registered on the receiver, so it
// runs as Squeeze returns. A non-nil error from TorchErr is returned to the
// caller together with the zero-value retVal.
//
// NOTE(review): presumably mirrors torch.squeeze (removing size-1 dimensions) —
// confirm against the libtorch documentation.
func (ts Tensor) Squeeze(del bool) (retVal Tensor, err error) {
if del {
defer ts.MustDrop()
// ptr is the out-parameter handed to the C binding.
// NOTE(review): no release of this C.malloc allocation is visible in this
// view of the file — confirm whether it is freed elsewhere.
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgSqueeze(ptr, ts.ctensor)
// The binding reports failure through TorchErr rather than a return value.
err = TorchErr()
if err != nil {
return retVal, err
retVal = Tensor{ctensor: *ptr}
return retVal, nil
// MustSqueeze calls Squeeze(del) and returns the resulting tensor.
//
// NOTE(review): the body of the error check is not visible in this view; going
// by the Must* naming it presumably aborts on a non-nil error — confirm
// against the full file.
func (ts Tensor) MustSqueeze(del bool) (retVal Tensor) {
retVal, err := ts.Squeeze(del)
if err != nil {
return retVal