WIP(example/yolo)

This commit is contained in:
sugarme 2020-07-13 22:00:03 +10:00
parent a146032afa
commit f16ab429c9
3 changed files with 186 additions and 4 deletions

View File

@ -146,8 +146,13 @@ type (
Route = []uint
Shortcut = uint
Yolo struct {
V1 int64
V2 []int64
Val1 int64
V2 []int64
}
Param struct {
Val1 int64
Val2 interface{}
}
)
@ -233,9 +238,8 @@ func conv(vs nn.Path, index uint, p int64, b block) (retVal1 int64, retVal2 inte
if leaky {
tmp2Mul := tmp2.MustMul1(ts.FloatScalar(0.1), false)
res = tmp2.MustMax1(tmp2Mul)
res = tmp2.MustMax1(tmp2Mul, true)
tmp2Mul.MustDrop()
tmp2.MustDrop()
} else {
res = tmp2
}
@ -245,3 +249,113 @@ func conv(vs nn.Path, index uint, p int64, b block) (retVal1 int64, retVal2 inte
return filters, fn
}
func upsample(prevChannels int64) (retVal1 int64, retVal2 interface{}) {
layer := nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
// []int64{n, c, h, w}
res, err := xs.Size4()
if err != nil {
log.Fatal(err)
}
h := res[2]
w := res[3]
return xs.MustUpsampleNearest2d([]int64{h * 2, w * 2}, 2.0, 2.0)
})
return prevChannels, layer
}
func intListOfString(s string) (retVal []int64) {
strs := strings.Split(s, ",")
for _, str := range strs {
str = strings.TrimSpace(str)
i, err := strconv.ParseInt(str, 10, 64)
if err != nil {
log.Fatal(err)
}
retVal = append(retVal, i)
}
return retVal
}
func uintOfIndex(index uint, i int64) (retVal uint) {
if i >= 0 {
return uint(i)
} else {
return uint(int64(index) + i)
}
}
func route(index uint, p []Param, blk block) (retVal1 int64, retVal2 interface{}) {
intLayers := intListOfString(blk.get("layers"))
var layers []uint
for _, l := range intLayers {
layers = append(layers, uintOfIndex(index, l))
}
var channels int64
for _, l := range layers {
channels += p[l].Val1
}
return channels, layers
}
func shortcut(index uint, p int64, blk block) (retVal1 int64, retVal2 interface{}) {
fromStr := blk.get("from")
from, err := strconv.ParseInt(fromStr, 10, 64)
if err != nil {
log.Fatal(err)
}
return p, uintOfIndex(index, from)
}
func yolo(p int64, blk block) (retVal1 int64, retVal2 interface{}) {
classesStr := blk.get("classes")
classes, err := strconv.ParseInt(classesStr, 10, 64)
if err != nil {
log.Fatal(err)
}
flat := intListOfString(blk.get("anchors"))
if (len(flat) % 2) != 0 {
log.Fatalf("Expected even number of flat")
}
var anchors [][]int64
for i := 0; i < len(flat)/2; i++ {
anchors = append(anchors, []int64{flat[2*i], flat[2*i+1]})
}
intMask := intListOfString(blk.get("mask"))
var retAnchors [][]int64
for _, i := range intMask {
retAnchors = append(retAnchors, anchors[i])
}
return p, retAnchors
}
// Apply f to a slice of tensor xs and replace xs values with f output.
func sliceApplyAndSet(xs ts.Tensor, start int64, len int64, f func(ts.Tensor) ts.Tensor) {
slice := xs.MustNarrow(2, start, len, false)
src := f(slice)
slice.Copy_(src)
src.MustDrop()
// TODO: check whether we need to delete slice to prevent memory blow-up
// slice.MustDrop()
}
// TODO: continue
// func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors
// [][]int64) (retVal ts.Tensor){
// }

View File

@ -693,3 +693,19 @@ func AtgArangeOut1(ptr *Ctensor, out Ctensor, start Cscalar, end Cscalar) {
C.atg_arange_out1(ptr, out, start, end)
}
// void atg_max1(tensor *, tensor self, tensor other);
func AtgMax1(ptr *Ctensor, self Ctensor, other Ctensor) {
C.atg_max1(ptr, self, other)
}
// void atg_upsample_nearest2d(tensor *, tensor self, int64_t *output_size_data, int output_size_len, double scales_h, double scales_w);
func AtgUpsampleNearest2d(ptr *Ctensor, self Ctensor, outputSizeData []int64, outputSizeLen int, scalesH float64, scalesW float64) {
coutputSizeDataPtr := (*C.int64_t)(unsafe.Pointer(&outputSizeData[0]))
coutputSizeLen := *(*C.int)(unsafe.Pointer(&outputSizeLen))
cscalesH := *(*C.double)(unsafe.Pointer(&scalesH))
cscalesW := *(*C.double)(unsafe.Pointer(&scalesW))
C.atg_upsample_nearest2d(ptr, self, coutputSizeDataPtr, coutputSizeLen, cscalesH, cscalesW)
}

View File

@ -2111,3 +2111,55 @@ func MustArangeOut1(out Tensor, start, end Scalar) (retVal Tensor) {
return retVal
}
func (ts Tensor) Max1(other Tensor, del bool) (retVal Tensor, err error) {
if del {
defer ts.MustDrop()
}
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgMax1(ptr, ts.ctensor, other.ctensor)
err = TorchErr()
if err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func (ts Tensor) MustMax1(other Tensor, del bool) (retVal Tensor) {
retVal, err := ts.Max1(other, del)
if err != nil {
log.Fatal(err)
}
return retVal
}
func (ts Tensor) UpsampleNearest2d(outputSize []int64, scalesH, scalesW float64) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgUpsampleNearest2d(ptr, ts.ctensor, outputSize, len(outputSize), scalesH, scalesW)
err = TorchErr()
if err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func (ts Tensor) MustUpsampleNearest2d(outputSize []int64, scalesH, scalesW float64) (retVal Tensor) {
retVal, err := ts.UpsampleNearest2d(outputSize, scalesH, scalesW)
if err != nil {
log.Fatal(err)
}
return retVal
}