WIP(example/yolo)

This commit is contained in:
sugarme 2020-07-14 18:25:06 +10:00
parent 9ce34e697a
commit f9f71c7163
4 changed files with 117 additions and 30 deletions

View File

@ -1,6 +1,8 @@
package main
import (
"fmt"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
@ -11,4 +13,9 @@ func main() {
tensor := ts.MustArange(ts.IntScalar(2*3*4), gotch.Int64, gotch.CPU).MustView([]int64{2, 3, 4}, true)
tensor.Print()
nilTs := ts.NewTensor()
fmt.Printf("nilTs val: %v", nilTs)
}

View File

@ -143,12 +143,21 @@ func ParseConfig(path string) (retVal Darknet) {
}
type (
Layer = ts.ModuleT
Route = []uint
Shortcut = uint
Yolo struct {
Layer struct {
Val nn.FuncT
}
Route struct {
TsIdxs []uint
}
Shortcut struct {
TsIdx uint // tensor index
}
Anchor []int64
Yolo struct {
Classes int64
Anchors []int64
Anchors []Anchor
}
ChannelsBl struct {
@ -248,7 +257,7 @@ func conv(vs nn.Path, index uint, p int64, b Block) (retVal1 int64, retVal2 inte
return res
})
return filters, fn
return filters, Layer{Val: fn}
}
func upsample(prevChannels int64) (retVal1 int64, retVal2 interface{}) {
@ -264,7 +273,7 @@ func upsample(prevChannels int64) (retVal1 int64, retVal2 interface{}) {
return xs.MustUpsampleNearest2d([]int64{h * 2, w * 2}, 2.0, 2.0)
})
return prevChannels, layer
return prevChannels, Layer{Val: layer}
}
func intListOfString(s string) (retVal []int64) {
@ -302,7 +311,7 @@ func route(index uint, p []ChannelsBl, blk Block) (retVal1 int64, retVal2 interf
channels += p[l].Channels
}
return channels, layers
return channels, Route{TsIdxs: layers}
}
func shortcut(index uint, p int64, blk Block) (retVal1 int64, retVal2 interface{}) {
@ -313,7 +322,7 @@ func shortcut(index uint, p int64, blk Block) (retVal1 int64, retVal2 interface{
log.Fatal(err)
}
return p, uintOfIndex(index, from)
return p, Shortcut{TsIdx: uintOfIndex(index, from)}
}
func yolo(p int64, blk Block) (retVal1 int64, retVal2 interface{}) {
@ -338,9 +347,9 @@ func yolo(p int64, blk Block) (retVal1 int64, retVal2 interface{}) {
intMask := intListOfString(blk.get("mask"))
var retAnchors []int64
var retAnchors []Anchor
for _, i := range intMask {
retAnchors = append(retAnchors, anchors[i]...)
retAnchors = append(retAnchors, anchors[i])
}
return p, Yolo{Classes: classes, Anchors: retAnchors}
@ -357,8 +366,7 @@ func sliceApplyAndSet(xs ts.Tensor, start int64, len int64, f func(ts.Tensor) ts
// slice.MustDrop()
}
func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []int64) (retVal ts.Tensor) {
func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (retVal ts.Tensor) {
size4, err := xs.Size4()
if err != nil {
log.Fatal(err)
@ -371,11 +379,19 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []int64) (re
bboxAttrs := classes + 5
nanchors := int64(len(anchors))
// fmt.Printf("xs shape: %v\n", xs.MustSize())
// fmt.Printf("bsize: %v\n", bsize)
// fmt.Printf("bboxAttrs: %v\n", bboxAttrs)
// fmt.Printf("nanchors: %v\n", nanchors)
// fmt.Printf("gridSize: %v\n", gridSize)
tmp1 := xs.MustView([]int64{bsize, bboxAttrs * nanchors, gridSize * gridSize}, false)
tmp2 := tmp1.MustTranspose(1, 2, true)
tmp3 := tmp2.MustContiguous(true)
xsTs := tmp3.MustView([]int64{bsize, bboxAttrs * gridSize * nanchors, bboxAttrs}, true)
xsTs := tmp3.MustView([]int64{bsize, gridSize * gridSize * nanchors, bboxAttrs}, true)
grid := ts.MustArange(ts.IntScalar(gridSize), gotch.Float, gotch.CPU)
a := grid.MustRepeat([]int64{gridSize, 1}, false)
@ -389,7 +405,12 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []int64) (re
xyOffsetTmp3 := xyOffsetTmp2.MustView([]int64{-1, 2}, true)
xyOffset := xyOffsetTmp3.MustUnsqueeze(0, true)
anchorsTmp1 := ts.MustOfSlice(anchors)
var flatAnchors []int64
for _, a := range anchors {
flatAnchors = append(flatAnchors, a...)
}
anchorsTmp1 := ts.MustOfSlice(flatAnchors)
anchorsTmp2 := anchorsTmp1.MustView([]int64{-1, 2}, true)
anchorsTmp3 := anchorsTmp2.MustRepeat([]int64{gridSize * gridSize, 1}, true)
anchorsTs := anchorsTmp3.MustUnsqueeze(0, true)
@ -482,17 +503,17 @@ func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
if len(prevYs) > 0 {
xsTs = prevYs[len(prevYs)-1] // last prevYs element
}
ysTs = layer.ForwardT(xsTs, train)
ysTs = layer.Val.ForwardT(xsTs, train)
case "Route":
layerIdxs := b.Bl.(Route)
route := b.Bl.(Route)
var layers []ts.Tensor
for _, i := range layerIdxs {
for _, i := range route.TsIdxs {
layers = append(layers, prevYs[int(i)])
}
ysTs = ts.MustCat(layers, 1, true)
ysTs = ts.MustCat(layers, 1, false)
case "Shortcut":
from := b.Bl.(Shortcut)
from := b.Bl.(Shortcut).TsIdx
addTs := prevYs[int(from)]
last := prevYs[len(prevYs)-1]
ysTs = last.MustAdd(addTs, false) // TODO: Should we delete it?
@ -500,23 +521,28 @@ func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
case "Yolo":
classes := b.Bl.(Yolo).Classes
anchors := b.Bl.(Yolo).Anchors
last := xs
xsTs := xs
if len(prevYs) > 0 {
last = prevYs[len(prevYs)-1]
xsTs = prevYs[len(prevYs)-1]
}
dt := detect(last, imageHeight, classes, anchors)
dt := detect(xsTs, imageHeight, classes, anchors)
detections = append(detections, dt)
ysTs = ts.NewTensor()
default:
log.Fatalf("Unsupported block type: %v\n", blkTyp.Name())
// log.Fatalf("BuildModel - FuncT - Unsupported block type: %v\n", blkTyp.Name())
} // end of Switch
prevYs = append(prevYs, ysTs)
} // end of For loop
return ts.MustCat(detections, 1, true)
})
res = ts.MustCat(detections, 1, true)
return res
}) // end of NewFuncT
return retVal
}

View File

@ -1,27 +1,46 @@
package main
import (
"flag"
"fmt"
// "flag"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
ts "github.com/sugarme/gotch/tensor"
"github.com/sugarme/gotch/vision"
"log"
"path/filepath"
)
const configName = "yolo-v3.cfg"
func init() {
// Command-line options; init binds each of them to a flag of the same name.
var (
	// model is the path to the Yolo model weights file (-model flag).
	model string
	// image is the path to the image file to infer (-image flag).
	image string
)
func init() {
flag.StringVar(&model, "model", "../../data/yolo/yolo-v3.pt", "Yolo model weights file")
flag.StringVar(&image, "image", "../../data/yolo/bondi.jpg", "image file to infer")
}
func main() {
flag.Parse()
configPath, err := filepath.Abs(configName)
if err != nil {
log.Fatal(err)
}
modelPath, err := filepath.Abs(model)
if err != nil {
log.Fatal(err)
}
imagePath, err := filepath.Abs(image)
if err != nil {
log.Fatal(err)
}
var darknet Darknet = ParseConfig(configPath)
fmt.Printf("darknet number of parameters: %v\n", len(darknet.Parameters))
@ -29,6 +48,41 @@ func main() {
vs := nn.NewVarStore(gotch.CPU)
model := darknet.BuildModel(vs.Root())
fmt.Printf("Model: %v\n", model)
fmt.Printf("Image path: %v\n", imagePath)
err = vs.Load(modelPath)
if err != nil {
log.Fatal(err)
}
fmt.Println("Yolo weights loaded.")
originalImage, err := vision.Load(imagePath)
if err != nil {
log.Fatal(err)
}
fmt.Println("Image file loaded")
fmt.Printf("Image shape: %v\n", originalImage.MustSize())
netHeight := darknet.Height()
netWidth := darknet.Width()
fmt.Printf("net Height: %v\n", netHeight)
fmt.Printf("net Width: %v\n", netWidth)
imageTs, err := vision.Resize(originalImage, netWidth, netHeight)
if err != nil {
log.Fatal(err)
}
fmt.Printf("imageTs shape: %v\n", imageTs.MustSize())
imgTmp1 := imageTs.MustUnsqueeze(0, true)
imgTmp2 := imgTmp1.MustTotype(gotch.Float, true)
img := imgTmp2.MustDiv1(ts.FloatScalar(255.0), true)
predictTmp := model.ForwardT(img, false)
// predictions := predictTmp.MustSqueeze(true)
fmt.Printf("predictTmp shape: %v\n", predictTmp.MustSize())
}

View File

@ -82,7 +82,7 @@ func Save(tensor ts.Tensor, path string) (err error) {
// This expects as input a tensor of shape [channel, height, width] and returns
// a tensor of shape [channel, out_h, out_w].
func Resize(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) {
tmpTs, err := ts.ResizeHwc(t, outW, outH)
tmpTs, err := ts.ResizeHwc(chwToHWC(t), outW, outH)
if err != nil {
return retVal, err
}