WIP(example/yolo)
This commit is contained in:
parent
f16ab429c9
commit
9ce34e697a
|
@ -5,76 +5,78 @@ import (
|
|||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
type block struct {
|
||||
blockType *string // optional
|
||||
parameters map[string]string
|
||||
type Block struct {
|
||||
BlockType *string // optional
|
||||
Parameters map[string]string
|
||||
}
|
||||
|
||||
func (b block) get(key string) (retVal string) {
|
||||
val, ok := b.parameters[key]
|
||||
func (b *Block) get(key string) (retVal string) {
|
||||
val, ok := b.Parameters[key]
|
||||
if !ok {
|
||||
log.Fatalf("Cannot find %v in net parameters.\n", key)
|
||||
log.Fatalf("Cannot find %v in Block parameters.\n", key)
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
type Darknet struct {
|
||||
blocks []block
|
||||
parameters map[string]string
|
||||
Blocks []Block
|
||||
Parameters map[string]string
|
||||
}
|
||||
|
||||
func (d Darknet) get(key string) (retVal string) {
|
||||
val, ok := d.parameters[key]
|
||||
val, ok := d.Parameters[key]
|
||||
if !ok {
|
||||
log.Fatalf("Cannot find %v in net parameters.\n", key)
|
||||
log.Fatalf("Cannot find %v in Darknet parameters.\n", key)
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
type accumulator struct {
|
||||
parameters map[string]string
|
||||
net Darknet
|
||||
blockType *string // optional
|
||||
type Accumulator struct {
|
||||
Parameters map[string]string
|
||||
Net Darknet
|
||||
BlockType *string // optional
|
||||
}
|
||||
|
||||
func newAccumulator() (retVal accumulator) {
|
||||
func newAccumulator() (retVal Accumulator) {
|
||||
|
||||
return accumulator{
|
||||
blockType: nil,
|
||||
parameters: make(map[string]string, 0),
|
||||
net: Darknet{
|
||||
blocks: make([]block, 0),
|
||||
parameters: make(map[string]string, 0),
|
||||
return Accumulator{
|
||||
BlockType: nil,
|
||||
Parameters: make(map[string]string, 0),
|
||||
Net: Darknet{
|
||||
Blocks: make([]Block, 0),
|
||||
Parameters: make(map[string]string, 0),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (acc *accumulator) finishBlock() {
|
||||
if acc.blockType != nil {
|
||||
if *acc.blockType == "net" {
|
||||
acc.net.parameters = acc.parameters
|
||||
func (acc *Accumulator) finishBlock() {
|
||||
if acc.BlockType != nil {
|
||||
if *acc.BlockType == "net" {
|
||||
acc.Net.Parameters = acc.Parameters
|
||||
} else {
|
||||
block := block{
|
||||
blockType: acc.blockType,
|
||||
parameters: acc.parameters,
|
||||
block := Block{
|
||||
BlockType: acc.BlockType,
|
||||
Parameters: acc.Parameters,
|
||||
}
|
||||
acc.net.blocks = append(acc.net.blocks, block)
|
||||
acc.Net.Blocks = append(acc.Net.Blocks, block)
|
||||
}
|
||||
|
||||
// clear parameters
|
||||
acc.parameters = make(map[string]string, 0)
|
||||
acc.Parameters = make(map[string]string, 0)
|
||||
}
|
||||
|
||||
acc.blockType = nil
|
||||
acc.BlockType = nil
|
||||
}
|
||||
|
||||
func ParseConfig(path string) (retVal Darknet) {
|
||||
|
@ -101,24 +103,24 @@ func ParseConfig(path string) (retVal Darknet) {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
|
||||
for _, ln := range lines {
|
||||
line := ln
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
line = strings.TrimSpace(line)
|
||||
|
||||
line = strings.TrimSpace(line) // trim all spaces before and after
|
||||
line = strings.ReplaceAll(line, " ", "") // trim all space in between
|
||||
if strings.HasPrefix(line, "[") {
|
||||
// make sure line ends with "]"
|
||||
if !strings.HasSuffix(line, "]") {
|
||||
log.Fatalf("Line doesn't end with ']'\n")
|
||||
}
|
||||
|
||||
line = strings.TrimPrefix(line, "[")
|
||||
line = strings.TrimSuffix(line, "]")
|
||||
|
||||
acc.finishBlock()
|
||||
acc.blockType = &line
|
||||
acc.BlockType = &line
|
||||
|
||||
} else {
|
||||
var keyValue []string
|
||||
keyValue = strings.Split(line, "=")
|
||||
|
@ -131,14 +133,13 @@ func ParseConfig(path string) (retVal Darknet) {
|
|||
// log.Fatalf("Multiple values for key - %v\n", line)
|
||||
// }
|
||||
|
||||
acc.parameters[keyValue[0]] = keyValue[1]
|
||||
|
||||
acc.Parameters[keyValue[0]] = keyValue[1]
|
||||
}
|
||||
} // end of for
|
||||
|
||||
acc.finishBlock()
|
||||
|
||||
return acc.net
|
||||
return acc.Net
|
||||
}
|
||||
|
||||
type (
|
||||
|
@ -146,17 +147,17 @@ type (
|
|||
Route = []uint
|
||||
Shortcut = uint
|
||||
Yolo struct {
|
||||
Val1 int64
|
||||
V2 []int64
|
||||
Classes int64
|
||||
Anchors []int64
|
||||
}
|
||||
|
||||
Param struct {
|
||||
Val1 int64
|
||||
Val2 interface{}
|
||||
ChannelsBl struct {
|
||||
Channels int64
|
||||
Bl interface{}
|
||||
}
|
||||
)
|
||||
|
||||
func conv(vs nn.Path, index uint, p int64, b block) (retVal1 int64, retVal2 interface{}) {
|
||||
func conv(vs nn.Path, index uint, p int64, b Block) (retVal1 int64, retVal2 interface{}) {
|
||||
|
||||
activation := b.get("activation")
|
||||
|
||||
|
@ -190,7 +191,7 @@ func conv(vs nn.Path, index uint, p int64, b block) (retVal1 int64, retVal2 inte
|
|||
bn *nn.BatchNorm
|
||||
bias bool
|
||||
)
|
||||
if pStr, ok := b.parameters["batch_normalize"]; ok {
|
||||
if pStr, ok := b.Parameters["batch_normalize"]; ok {
|
||||
p, err := strconv.ParseInt(pStr, 10, 64)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -288,7 +289,7 @@ func uintOfIndex(index uint, i int64) (retVal uint) {
|
|||
}
|
||||
}
|
||||
|
||||
func route(index uint, p []Param, blk block) (retVal1 int64, retVal2 interface{}) {
|
||||
func route(index uint, p []ChannelsBl, blk Block) (retVal1 int64, retVal2 interface{}) {
|
||||
intLayers := intListOfString(blk.get("layers"))
|
||||
|
||||
var layers []uint
|
||||
|
@ -298,13 +299,13 @@ func route(index uint, p []Param, blk block) (retVal1 int64, retVal2 interface{}
|
|||
|
||||
var channels int64
|
||||
for _, l := range layers {
|
||||
channels += p[l].Val1
|
||||
channels += p[l].Channels
|
||||
}
|
||||
|
||||
return channels, layers
|
||||
}
|
||||
|
||||
func shortcut(index uint, p int64, blk block) (retVal1 int64, retVal2 interface{}) {
|
||||
func shortcut(index uint, p int64, blk Block) (retVal1 int64, retVal2 interface{}) {
|
||||
fromStr := blk.get("from")
|
||||
|
||||
from, err := strconv.ParseInt(fromStr, 10, 64)
|
||||
|
@ -315,13 +316,14 @@ func shortcut(index uint, p int64, blk block) (retVal1 int64, retVal2 interface{
|
|||
return p, uintOfIndex(index, from)
|
||||
}
|
||||
|
||||
func yolo(p int64, blk block) (retVal1 int64, retVal2 interface{}) {
|
||||
func yolo(p int64, blk Block) (retVal1 int64, retVal2 interface{}) {
|
||||
classesStr := blk.get("classes")
|
||||
classes, err := strconv.ParseInt(classesStr, 10, 64)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// anchorsStr := blk.get("anchors")
|
||||
flat := intListOfString(blk.get("anchors"))
|
||||
|
||||
if (len(flat) % 2) != 0 {
|
||||
|
@ -336,12 +338,12 @@ func yolo(p int64, blk block) (retVal1 int64, retVal2 interface{}) {
|
|||
|
||||
intMask := intListOfString(blk.get("mask"))
|
||||
|
||||
var retAnchors [][]int64
|
||||
var retAnchors []int64
|
||||
for _, i := range intMask {
|
||||
retAnchors = append(retAnchors, anchors[i])
|
||||
retAnchors = append(retAnchors, anchors[i]...)
|
||||
}
|
||||
|
||||
return p, retAnchors
|
||||
return p, Yolo{Classes: classes, Anchors: retAnchors}
|
||||
}
|
||||
|
||||
// Apply f to a slice of tensor xs and replace xs values with f output.
|
||||
|
@ -355,7 +357,166 @@ func sliceApplyAndSet(xs ts.Tensor, start int64, len int64, f func(ts.Tensor) ts
|
|||
// slice.MustDrop()
|
||||
}
|
||||
|
||||
// TODO: continue
|
||||
// func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors
|
||||
// [][]int64) (retVal ts.Tensor){
|
||||
// }
|
||||
func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []int64) (retVal ts.Tensor) {
|
||||
|
||||
size4, err := xs.Size4()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
bsize := size4[0]
|
||||
height := size4[2]
|
||||
|
||||
stride := imageHeight / height
|
||||
gridSize := imageHeight / stride
|
||||
bboxAttrs := classes + 5
|
||||
nanchors := int64(len(anchors))
|
||||
|
||||
tmp1 := xs.MustView([]int64{bsize, bboxAttrs * nanchors, gridSize * gridSize}, false)
|
||||
tmp2 := tmp1.MustTranspose(1, 2, true)
|
||||
tmp3 := tmp2.MustContiguous(true)
|
||||
|
||||
xsTs := tmp3.MustView([]int64{bsize, bboxAttrs * gridSize * nanchors, bboxAttrs}, true)
|
||||
|
||||
grid := ts.MustArange(ts.IntScalar(gridSize), gotch.Float, gotch.CPU)
|
||||
a := grid.MustRepeat([]int64{gridSize, 1}, false)
|
||||
bTmp := a.MustT(false)
|
||||
b := bTmp.MustContiguous(true)
|
||||
xOffset := a.MustView([]int64{-1, 1}, false)
|
||||
yOffset := b.MustView([]int64{-1, 1}, false)
|
||||
|
||||
xyOffsetTmp1 := ts.MustCat([]ts.Tensor{xOffset, yOffset}, 1, false)
|
||||
xyOffsetTmp2 := xyOffsetTmp1.MustRepeat([]int64{1, nanchors}, true)
|
||||
xyOffsetTmp3 := xyOffsetTmp2.MustView([]int64{-1, 2}, true)
|
||||
xyOffset := xyOffsetTmp3.MustUnsqueeze(0, true)
|
||||
|
||||
anchorsTmp1 := ts.MustOfSlice(anchors)
|
||||
anchorsTmp2 := anchorsTmp1.MustView([]int64{-1, 2}, true)
|
||||
anchorsTmp3 := anchorsTmp2.MustRepeat([]int64{gridSize * gridSize, 1}, true)
|
||||
anchorsTs := anchorsTmp3.MustUnsqueeze(0, true)
|
||||
|
||||
sliceApplyAndSet(xsTs, 0, 2, func(xs ts.Tensor) (res ts.Tensor) {
|
||||
tmp := xs.MustSigmoid(false)
|
||||
res = tmp.MustAdd(xyOffset, true)
|
||||
return res
|
||||
})
|
||||
|
||||
sliceApplyAndSet(xsTs, 4, classes+1, func(xs ts.Tensor) (res ts.Tensor) {
|
||||
return xs.MustSigmoid(false)
|
||||
})
|
||||
|
||||
sliceApplyAndSet(xsTs, 2, 2, func(xs ts.Tensor) (res ts.Tensor) {
|
||||
tmp := xs.MustExp(false)
|
||||
res = tmp.MustMul(anchorsTs, true)
|
||||
return res
|
||||
})
|
||||
|
||||
sliceApplyAndSet(xsTs, 0, 4, func(xs ts.Tensor) (res ts.Tensor) {
|
||||
return xs.MustMul1(ts.IntScalar(stride), false)
|
||||
})
|
||||
|
||||
// TODO: delete all middle tensors.
|
||||
|
||||
return xsTs
|
||||
}
|
||||
|
||||
func (dn *Darknet) Height() (retVal int64) {
|
||||
imageHeightStr := dn.get("height")
|
||||
retVal, err := strconv.ParseInt(imageHeightStr, 10, 64)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (dn *Darknet) Width() (retVal int64) {
|
||||
imageWidthStr := dn.get("width")
|
||||
retVal, err := strconv.ParseInt(imageWidthStr, 10, 64)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
|
||||
var blocks []ChannelsBl // Param is a struct{int64, interface{}}
|
||||
var prevChannels int64 = 3
|
||||
|
||||
for index, blk := range dn.Blocks {
|
||||
var channels int64
|
||||
var bl interface{}
|
||||
|
||||
switch *blk.BlockType {
|
||||
case "convolutional":
|
||||
channels, bl = conv(vs.Sub(fmt.Sprintf("%v", index)), uint(index), prevChannels, blk)
|
||||
case "upsample":
|
||||
channels, bl = upsample(prevChannels)
|
||||
case "shortcut":
|
||||
channels, bl = shortcut(uint(index), prevChannels, blk)
|
||||
case "route":
|
||||
channels, bl = route(uint(index), blocks, blk)
|
||||
case "yolo":
|
||||
channels, bl = yolo(prevChannels, blk)
|
||||
default:
|
||||
log.Fatalf("Unsupported block type: %v\n", *blk.BlockType)
|
||||
}
|
||||
prevChannels = channels
|
||||
blocks = append(blocks, ChannelsBl{channels, bl})
|
||||
}
|
||||
|
||||
imageHeight := dn.Height()
|
||||
|
||||
retVal = nn.NewFuncT(func(xs ts.Tensor, train bool) (res ts.Tensor) {
|
||||
|
||||
var prevYs []ts.Tensor = make([]ts.Tensor, 0)
|
||||
var detections []ts.Tensor = make([]ts.Tensor, 0)
|
||||
|
||||
for _, b := range blocks {
|
||||
blkTyp := reflect.TypeOf(b.Bl)
|
||||
var ysTs ts.Tensor
|
||||
switch blkTyp.Name() {
|
||||
case "Layer": // Layer type
|
||||
layer := b.Bl.(Layer)
|
||||
xsTs := xs
|
||||
if len(prevYs) > 0 {
|
||||
xsTs = prevYs[len(prevYs)-1] // last prevYs element
|
||||
}
|
||||
ysTs = layer.ForwardT(xsTs, train)
|
||||
case "Route":
|
||||
layerIdxs := b.Bl.(Route)
|
||||
var layers []ts.Tensor
|
||||
for _, i := range layerIdxs {
|
||||
layers = append(layers, prevYs[int(i)])
|
||||
}
|
||||
ysTs = ts.MustCat(layers, 1, true)
|
||||
|
||||
case "Shortcut":
|
||||
from := b.Bl.(Shortcut)
|
||||
addTs := prevYs[int(from)]
|
||||
last := prevYs[len(prevYs)-1]
|
||||
ysTs = last.MustAdd(addTs, false) // TODO: Should we delete it?
|
||||
addTs.MustDrop()
|
||||
case "Yolo":
|
||||
classes := b.Bl.(Yolo).Classes
|
||||
anchors := b.Bl.(Yolo).Anchors
|
||||
last := xs
|
||||
if len(prevYs) > 0 {
|
||||
last = prevYs[len(prevYs)-1]
|
||||
}
|
||||
dt := detect(last, imageHeight, classes, anchors)
|
||||
detections = append(detections, dt)
|
||||
ysTs = ts.NewTensor()
|
||||
|
||||
default:
|
||||
log.Fatalf("Unsupported block type: %v\n", blkTyp.Name())
|
||||
} // end of Switch
|
||||
|
||||
prevYs = append(prevYs, ysTs)
|
||||
} // end of For loop
|
||||
|
||||
return ts.MustCat(detections, 1, true)
|
||||
})
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
|
@ -3,6 +3,8 @@ package main
|
|||
import (
|
||||
"fmt"
|
||||
// "flag"
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
"log"
|
||||
"path/filepath"
|
||||
)
|
||||
|
@ -22,6 +24,11 @@ func main() {
|
|||
|
||||
var darknet Darknet = ParseConfig(configPath)
|
||||
|
||||
fmt.Printf("darknet number of parameters: %v\n", len(darknet.parameters))
|
||||
fmt.Printf("darknet number of blocks: %v\n", len(darknet.blocks))
|
||||
fmt.Printf("darknet number of parameters: %v\n", len(darknet.Parameters))
|
||||
fmt.Printf("darknet number of blocks: %v\n", len(darknet.Blocks))
|
||||
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
model := darknet.BuildModel(vs.Root())
|
||||
|
||||
fmt.Printf("Model: %v\n", model)
|
||||
}
|
||||
|
|
|
@ -709,3 +709,25 @@ func AtgUpsampleNearest2d(ptr *Ctensor, self Ctensor, outputSizeData []int64, ou
|
|||
|
||||
C.atg_upsample_nearest2d(ptr, self, coutputSizeDataPtr, coutputSizeLen, cscalesH, cscalesW)
|
||||
}
|
||||
|
||||
// void atg_repeat(tensor *, tensor self, int64_t *repeats_data, int repeats_len);
|
||||
func AtgRepeat(ptr *Ctensor, self Ctensor, repeatData []int64, repeatLen int) {
|
||||
crepeatDataPtr := (*C.int64_t)(unsafe.Pointer(&repeatData[0]))
|
||||
crepeatLen := *(*C.int)(unsafe.Pointer(&repeatLen))
|
||||
|
||||
C.atg_repeat(ptr, self, crepeatDataPtr, crepeatLen)
|
||||
}
|
||||
|
||||
// void atg_contiguous(tensor *, tensor self);
|
||||
func AtgContiguous(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_contiguous(ptr, self)
|
||||
}
|
||||
|
||||
// void atg_transpose(tensor *, tensor self, int64_t dim0, int64_t dim1);
|
||||
func AtgTranspose(ptr *Ctensor, self Ctensor, dim0 int64, dim1 int64) {
|
||||
|
||||
cdim0 := *(*C.int64_t)(unsafe.Pointer(&dim0))
|
||||
cdim1 := *(*C.int64_t)(unsafe.Pointer(&dim1))
|
||||
|
||||
C.atg_transpose(ptr, self, cdim0, cdim1)
|
||||
}
|
||||
|
|
|
@ -2163,3 +2163,82 @@ func (ts Tensor) MustUpsampleNearest2d(outputSize []int64, scalesH, scalesW floa
|
|||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Repeat(repeatData []int64, del bool) (retVal Tensor, err error) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
lib.AtgRepeat(ptr, ts.ctensor, repeatData, len(repeatData))
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustRepeat(repeatData []int64, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Repeat(repeatData, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Contiguous(del bool) (retVal Tensor, err error) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
lib.AtgContiguous(ptr, ts.ctensor)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustContiguous(del bool) (retVal Tensor) {
|
||||
|
||||
retVal, err := ts.Contiguous(del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Transpose(dim0, dim1 int64, del bool) (retVal Tensor, err error) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
lib.AtgTranspose(ptr, ts.ctensor, dim0, dim1)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustTranspose(dim0, dim1 int64, del bool) (retVal Tensor) {
|
||||
retVal, err := ts.Transpose(dim0, dim1, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user