WIP(example/yolo)
This commit is contained in:
parent
9ce34e697a
commit
f9f71c7163
|
@ -1,6 +1,8 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
@ -11,4 +13,9 @@ func main() {
|
|||
tensor := ts.MustArange(ts.IntScalar(2*3*4), gotch.Int64, gotch.CPU).MustView([]int64{2, 3, 4}, true)
|
||||
|
||||
tensor.Print()
|
||||
|
||||
nilTs := ts.NewTensor()
|
||||
|
||||
fmt.Printf("nilTs val: %v", nilTs)
|
||||
|
||||
}
|
||||
|
|
|
@ -143,12 +143,21 @@ func ParseConfig(path string) (retVal Darknet) {
|
|||
}
|
||||
|
||||
type (
|
||||
Layer = ts.ModuleT
|
||||
Route = []uint
|
||||
Shortcut = uint
|
||||
Yolo struct {
|
||||
Layer struct {
|
||||
Val nn.FuncT
|
||||
}
|
||||
Route struct {
|
||||
TsIdxs []uint
|
||||
}
|
||||
Shortcut struct {
|
||||
TsIdx uint // tensor index
|
||||
}
|
||||
|
||||
Anchor []int64
|
||||
|
||||
Yolo struct {
|
||||
Classes int64
|
||||
Anchors []int64
|
||||
Anchors []Anchor
|
||||
}
|
||||
|
||||
ChannelsBl struct {
|
||||
|
@ -248,7 +257,7 @@ func conv(vs nn.Path, index uint, p int64, b Block) (retVal1 int64, retVal2 inte
|
|||
return res
|
||||
})
|
||||
|
||||
return filters, fn
|
||||
return filters, Layer{Val: fn}
|
||||
}
|
||||
|
||||
func upsample(prevChannels int64) (retVal1 int64, retVal2 interface{}) {
|
||||
|
@ -264,7 +273,7 @@ func upsample(prevChannels int64) (retVal1 int64, retVal2 interface{}) {
|
|||
return xs.MustUpsampleNearest2d([]int64{h * 2, w * 2}, 2.0, 2.0)
|
||||
})
|
||||
|
||||
return prevChannels, layer
|
||||
return prevChannels, Layer{Val: layer}
|
||||
}
|
||||
|
||||
func intListOfString(s string) (retVal []int64) {
|
||||
|
@ -302,7 +311,7 @@ func route(index uint, p []ChannelsBl, blk Block) (retVal1 int64, retVal2 interf
|
|||
channels += p[l].Channels
|
||||
}
|
||||
|
||||
return channels, layers
|
||||
return channels, Route{TsIdxs: layers}
|
||||
}
|
||||
|
||||
func shortcut(index uint, p int64, blk Block) (retVal1 int64, retVal2 interface{}) {
|
||||
|
@ -313,7 +322,7 @@ func shortcut(index uint, p int64, blk Block) (retVal1 int64, retVal2 interface{
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return p, uintOfIndex(index, from)
|
||||
return p, Shortcut{TsIdx: uintOfIndex(index, from)}
|
||||
}
|
||||
|
||||
func yolo(p int64, blk Block) (retVal1 int64, retVal2 interface{}) {
|
||||
|
@ -338,9 +347,9 @@ func yolo(p int64, blk Block) (retVal1 int64, retVal2 interface{}) {
|
|||
|
||||
intMask := intListOfString(blk.get("mask"))
|
||||
|
||||
var retAnchors []int64
|
||||
var retAnchors []Anchor
|
||||
for _, i := range intMask {
|
||||
retAnchors = append(retAnchors, anchors[i]...)
|
||||
retAnchors = append(retAnchors, anchors[i])
|
||||
}
|
||||
|
||||
return p, Yolo{Classes: classes, Anchors: retAnchors}
|
||||
|
@ -357,8 +366,7 @@ func sliceApplyAndSet(xs ts.Tensor, start int64, len int64, f func(ts.Tensor) ts
|
|||
// slice.MustDrop()
|
||||
}
|
||||
|
||||
func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []int64) (retVal ts.Tensor) {
|
||||
|
||||
func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (retVal ts.Tensor) {
|
||||
size4, err := xs.Size4()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -371,11 +379,19 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []int64) (re
|
|||
bboxAttrs := classes + 5
|
||||
nanchors := int64(len(anchors))
|
||||
|
||||
// fmt.Printf("xs shape: %v\n", xs.MustSize())
|
||||
// fmt.Printf("bsize: %v\n", bsize)
|
||||
// fmt.Printf("bboxAttrs: %v\n", bboxAttrs)
|
||||
// fmt.Printf("nanchors: %v\n", nanchors)
|
||||
// fmt.Printf("gridSize: %v\n", gridSize)
|
||||
|
||||
tmp1 := xs.MustView([]int64{bsize, bboxAttrs * nanchors, gridSize * gridSize}, false)
|
||||
|
||||
tmp2 := tmp1.MustTranspose(1, 2, true)
|
||||
|
||||
tmp3 := tmp2.MustContiguous(true)
|
||||
|
||||
xsTs := tmp3.MustView([]int64{bsize, bboxAttrs * gridSize * nanchors, bboxAttrs}, true)
|
||||
xsTs := tmp3.MustView([]int64{bsize, gridSize * gridSize * nanchors, bboxAttrs}, true)
|
||||
|
||||
grid := ts.MustArange(ts.IntScalar(gridSize), gotch.Float, gotch.CPU)
|
||||
a := grid.MustRepeat([]int64{gridSize, 1}, false)
|
||||
|
@ -389,7 +405,12 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []int64) (re
|
|||
xyOffsetTmp3 := xyOffsetTmp2.MustView([]int64{-1, 2}, true)
|
||||
xyOffset := xyOffsetTmp3.MustUnsqueeze(0, true)
|
||||
|
||||
anchorsTmp1 := ts.MustOfSlice(anchors)
|
||||
var flatAnchors []int64
|
||||
for _, a := range anchors {
|
||||
flatAnchors = append(flatAnchors, a...)
|
||||
}
|
||||
|
||||
anchorsTmp1 := ts.MustOfSlice(flatAnchors)
|
||||
anchorsTmp2 := anchorsTmp1.MustView([]int64{-1, 2}, true)
|
||||
anchorsTmp3 := anchorsTmp2.MustRepeat([]int64{gridSize * gridSize, 1}, true)
|
||||
anchorsTs := anchorsTmp3.MustUnsqueeze(0, true)
|
||||
|
@ -482,17 +503,17 @@ func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
|
|||
if len(prevYs) > 0 {
|
||||
xsTs = prevYs[len(prevYs)-1] // last prevYs element
|
||||
}
|
||||
ysTs = layer.ForwardT(xsTs, train)
|
||||
ysTs = layer.Val.ForwardT(xsTs, train)
|
||||
case "Route":
|
||||
layerIdxs := b.Bl.(Route)
|
||||
route := b.Bl.(Route)
|
||||
var layers []ts.Tensor
|
||||
for _, i := range layerIdxs {
|
||||
for _, i := range route.TsIdxs {
|
||||
layers = append(layers, prevYs[int(i)])
|
||||
}
|
||||
ysTs = ts.MustCat(layers, 1, true)
|
||||
ysTs = ts.MustCat(layers, 1, false)
|
||||
|
||||
case "Shortcut":
|
||||
from := b.Bl.(Shortcut)
|
||||
from := b.Bl.(Shortcut).TsIdx
|
||||
addTs := prevYs[int(from)]
|
||||
last := prevYs[len(prevYs)-1]
|
||||
ysTs = last.MustAdd(addTs, false) // TODO: Should we delete it?
|
||||
|
@ -500,23 +521,28 @@ func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
|
|||
case "Yolo":
|
||||
classes := b.Bl.(Yolo).Classes
|
||||
anchors := b.Bl.(Yolo).Anchors
|
||||
last := xs
|
||||
xsTs := xs
|
||||
|
||||
if len(prevYs) > 0 {
|
||||
last = prevYs[len(prevYs)-1]
|
||||
xsTs = prevYs[len(prevYs)-1]
|
||||
}
|
||||
dt := detect(last, imageHeight, classes, anchors)
|
||||
|
||||
dt := detect(xsTs, imageHeight, classes, anchors)
|
||||
|
||||
detections = append(detections, dt)
|
||||
ysTs = ts.NewTensor()
|
||||
|
||||
default:
|
||||
log.Fatalf("Unsupported block type: %v\n", blkTyp.Name())
|
||||
// log.Fatalf("BuildModel - FuncT - Unsupported block type: %v\n", blkTyp.Name())
|
||||
} // end of Switch
|
||||
|
||||
prevYs = append(prevYs, ysTs)
|
||||
} // end of For loop
|
||||
|
||||
return ts.MustCat(detections, 1, true)
|
||||
})
|
||||
res = ts.MustCat(detections, 1, true)
|
||||
|
||||
return res
|
||||
}) // end of NewFuncT
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
|
@ -1,27 +1,46 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
// "flag"
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
"log"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
const configName = "yolo-v3.cfg"
|
||||
|
||||
func init() {
|
||||
var (
|
||||
model string
|
||||
image string
|
||||
)
|
||||
|
||||
func init() {
|
||||
flag.StringVar(&model, "model", "../../data/yolo/yolo-v3.pt", "Yolo model weights file")
|
||||
flag.StringVar(&image, "image", "../../data/yolo/bondi.jpg", "image file to infer")
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
flag.Parse()
|
||||
configPath, err := filepath.Abs(configName)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
modelPath, err := filepath.Abs(model)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
imagePath, err := filepath.Abs(image)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
var darknet Darknet = ParseConfig(configPath)
|
||||
|
||||
fmt.Printf("darknet number of parameters: %v\n", len(darknet.Parameters))
|
||||
|
@ -29,6 +48,41 @@ func main() {
|
|||
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
model := darknet.BuildModel(vs.Root())
|
||||
|
||||
fmt.Printf("Model: %v\n", model)
|
||||
fmt.Printf("Image path: %v\n", imagePath)
|
||||
|
||||
err = vs.Load(modelPath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println("Yolo weights loaded.")
|
||||
|
||||
originalImage, err := vision.Load(imagePath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Println("Image file loaded")
|
||||
fmt.Printf("Image shape: %v\n", originalImage.MustSize())
|
||||
|
||||
netHeight := darknet.Height()
|
||||
netWidth := darknet.Width()
|
||||
|
||||
fmt.Printf("net Height: %v\n", netHeight)
|
||||
fmt.Printf("net Width: %v\n", netWidth)
|
||||
|
||||
imageTs, err := vision.Resize(originalImage, netWidth, netHeight)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("imageTs shape: %v\n", imageTs.MustSize())
|
||||
|
||||
imgTmp1 := imageTs.MustUnsqueeze(0, true)
|
||||
imgTmp2 := imgTmp1.MustTotype(gotch.Float, true)
|
||||
img := imgTmp2.MustDiv1(ts.FloatScalar(255.0), true)
|
||||
predictTmp := model.ForwardT(img, false)
|
||||
// predictions := predictTmp.MustSqueeze(true)
|
||||
fmt.Printf("predictTmp shape: %v\n", predictTmp.MustSize())
|
||||
|
||||
}
|
||||
|
|
|
@ -82,7 +82,7 @@ func Save(tensor ts.Tensor, path string) (err error) {
|
|||
// This expects as input a tensor of shape [channel, height, width] and returns
|
||||
// a tensor of shape [channel, out_h, out_w].
|
||||
func Resize(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) {
|
||||
tmpTs, err := ts.ResizeHwc(t, outW, outH)
|
||||
tmpTs, err := ts.ResizeHwc(chwToHWC(t), outW, outH)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user