corrected pointer receiver conversion in vision sub-packages and examples

This commit is contained in:
sugarme 2020-11-01 11:59:08 +11:00
parent 5414b6ed57
commit a6d09580aa
11 changed files with 141 additions and 150 deletions

View File

@ -11,6 +11,10 @@ func main() {
// Create a tensor [2,3,4]
tensor := ts.MustArange(ts.IntScalar(2*3*4), gotch.Int64, gotch.CPU).MustView([]int64{2, 3, 4}, true)
tensor.Print()
mul := ts.MustOnes([]int64{4, 5}, gotch.Int64, gotch.CPU)
res := tensor.MustMatmul(mul, false)
res.Print()
}

View File

@ -18,7 +18,7 @@ import (
"github.com/sugarme/gotch/vision"
)
func convBn(p nn.Path, cIn, cOut int64) (retVal nn.SequentialT) {
func convBn(p *nn.Path, cIn, cOut int64) *nn.SequentialT {
config := nn.DefaultConv2DConfig()
config.Padding = []int64{1, 1}
config.Bias = false
@ -27,19 +27,19 @@ func convBn(p nn.Path, cIn, cOut int64) (retVal nn.SequentialT) {
seq.Add(nn.NewConv2D(p, cIn, cOut, 3, config))
seq.Add(nn.BatchNorm2D(p, cOut, nn.DefaultBatchNormConfig()))
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
seq.AddFn(nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
return xs.MustRelu(false)
}))
return seq
}
func layer(p nn.Path, cIn, cOut int64) (retVal nn.FuncT) {
func layer(p *nn.Path, cIn, cOut int64) nn.FuncT {
pre := convBn(p.Sub("pre"), cIn, cOut)
block1 := convBn(p.Sub("b1"), cOut, cOut)
block2 := convBn(p.Sub("b2"), cOut, cOut)
return nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
return nn.NewFuncT(func(xs *ts.Tensor, train bool) *ts.Tensor {
tmp1 := xs.ApplyT(pre, train)
preTs := tmp1.MaxPool2DDefault(2, true)
tmp2 := preTs.ApplyT(block1, train)
@ -53,17 +53,17 @@ func layer(p nn.Path, cIn, cOut int64) (retVal nn.FuncT) {
})
}
func fastResnet(p nn.Path) (retVal nn.SequentialT) {
func fastResnet(p *nn.Path) *nn.SequentialT {
seq := nn.SeqT()
seq.Add(convBn(p.Sub("pre"), 3, 64))
seq.Add(layer(p.Sub("layer1"), 64, 128))
seq.Add(convBn(p.Sub("inter"), 128, 256))
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
seq.AddFn(nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
return xs.MaxPool2DDefault(2, false)
}))
seq.Add(layer(p.Sub("layer2"), 256, 512))
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
seq.AddFn(nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
tmp := xs.MaxPool2DDefault(4, false)
res := tmp.FlatView()
tmp.MustDrop()
@ -72,7 +72,7 @@ func fastResnet(p nn.Path) (retVal nn.SequentialT) {
}))
seq.Add(nn.NewLinear(p.Sub("linear"), 512, 10, nn.DefaultLinearConfig()))
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
seq.AddFn(nn.NewFunc(func(xs *ts.Tensor) *ts.Tensor {
return xs.MustMul1(ts.FloatScalar(0.125), false)
}))
@ -89,8 +89,8 @@ func main() {
fmt.Printf("TestLabel shape: %v\n", ds.TestLabels.MustSize())
fmt.Printf("Number of labels: %v\n", ds.Labels)
cuda := gotch.CudaBuilder(0)
device := cuda.CudaIfAvailable()
// device := gotch.CPU
device := gotch.NewCuda().CudaIfAvailable()
vs := nn.NewVarStore(device)
@ -104,7 +104,7 @@ func main() {
for epoch := 0; epoch < 150; epoch++ {
optConfig := nn.NewSGDConfig(0.9, 0.0, 5e-4, true)
var (
opt nn.Optimizer
opt *nn.Optimizer
err error
)
switch {

View File

@ -31,7 +31,7 @@ var (
style string
)
func gramMatrix(m ts.Tensor) (retVal ts.Tensor) {
func gramMatrix(m *ts.Tensor) *ts.Tensor {
sizes, err := m.Size4()
if err != nil {
log.Fatal(err)
@ -52,7 +52,7 @@ func gramMatrix(m ts.Tensor) (retVal ts.Tensor) {
return gram.MustDiv1(ts.IntScalar(a*b*c*d), true)
}
func styleLoss(m1 ts.Tensor, m2 ts.Tensor) (retVal ts.Tensor) {
func styleLoss(m1 *ts.Tensor, m2 *ts.Tensor) *ts.Tensor {
gram1 := gramMatrix(m1)
// m1.MustDrop()
gram2 := gramMatrix(m2)
@ -87,9 +87,9 @@ func main() {
log.Fatal(err)
}
cuda := gotch.CudaBuilder(0)
device := cuda.CudaIfAvailable()
// device := gotch.CPU
// cuda := gotch.CudaBuilder(0)
// device := cuda.CudaIfAvailable()
device := gotch.CPU
netVS := nn.NewVarStore(device)
in := vision.NewImageNet()
@ -153,13 +153,13 @@ func main() {
sLoss := ts.MustZeros([]int64{1}, gotch.Float, device)
cLoss := ts.MustZeros([]int64{1}, gotch.Float, device)
for _, idx := range StyleIndexes {
l := styleLoss(inputLayers[idx], styleLayers[idx])
l := styleLoss(&inputLayers[idx], &styleLayers[idx])
sLoss = sLoss.MustAdd(l, true)
l.MustDrop()
}
for _, idx := range ContentIndexes {
// NOTE: set `del` = true called panic at GPU train (tested on Colab)
l := inputLayers[idx].MustMseLoss(contentLayers[idx], int64(ts.ReductionMean), false)
l := inputLayers[idx].MustMseLoss(&contentLayers[idx], int64(ts.ReductionMean), false)
cLoss = cLoss.MustAdd(l, true)
l.MustDrop()
}

View File

@ -62,22 +62,24 @@ func main() {
trainImages := ts.NoGrad1(func() (retVal interface{}) {
return dataset.TrainImages.ApplyT(net, true)
}).(ts.Tensor)
}).(*ts.Tensor)
testImages := ts.NoGrad1(func() (retVal interface{}) {
return dataset.TestImages.ApplyT(net, true)
}).(ts.Tensor)
}).(*ts.Tensor)
fmt.Println("start training...")
for epoch := 1; epoch <= 1000; epoch++ {
predicted := trainImages.Apply(linear)
predicted := trainImages.ApplyT(linear, true)
loss := predicted.CrossEntropyForLogits(dataset.TrainLabels)
sgd.BackwardStep(loss)
loss.MustDrop()
testAccuracy := testImages.Apply(linear).AccuracyForLogits(dataset.TestLabels)
fmt.Printf("Epoch %v\t Accuracy: %5.2f%%\n", epoch, testAccuracy.Float64Values()[0]*100)
ts.NoGrad(func() {
testAccuracy := testImages.Apply(linear).AccuracyForLogits(dataset.TestLabels)
fmt.Printf("Epoch %v\t Accuracy: %5.2f%%\n", epoch, testAccuracy.Float64Values()[0]*100)
})
}
}

View File

@ -34,35 +34,35 @@ type Encoder struct {
gru nn.GRU
}
func newEncoder(vs nn.Path, inDim, hiddenDim int64) (retVal Encoder) {
func newEncoder(vs *nn.Path, inDim, hiddenDim int64) *Encoder {
gru := nn.NewGRU(vs, hiddenDim, hiddenDim, nn.DefaultRNNConfig())
embedding := nn.NewEmbedding(vs, inDim, hiddenDim, nn.DefaultEmbeddingConfig())
return Encoder{embedding, gru}
return &Encoder{*embedding, *gru}
}
func (e Encoder) forward(xs ts.Tensor, state nn.GRUState) (retTs ts.Tensor, retState nn.GRUState) {
func (e *Encoder) forward(xs *ts.Tensor, state *nn.GRUState) (*ts.Tensor, *nn.GRUState) {
retTs = e.embedding.Forward(xs).MustView([]int64{1, -1}, true)
retState = e.gru.Step(retTs, state).(nn.GRUState)
retTs := e.embedding.Forward(xs).MustView([]int64{1, -1}, true)
retState := e.gru.Step(retTs, state).(*nn.GRUState)
return retTs, retState
}
type Decoder struct {
device gotch.Device
embedding nn.Embedding
gru nn.GRU
attn nn.Linear
attnCombine nn.Linear
linear nn.Linear
embedding *nn.Embedding
gru *nn.GRU
attn *nn.Linear
attnCombine *nn.Linear
linear *nn.Linear
}
func newDecoder(vs nn.Path, hiddenDim, outDim int64) (retVal Decoder) {
func newDecoder(vs *nn.Path, hiddenDim, outDim int64) *Decoder {
return Decoder{
return &Decoder{
device: vs.Device(),
embedding: nn.NewEmbedding(vs, outDim, hiddenDim, nn.DefaultEmbeddingConfig()),
gru: nn.NewGRU(vs, hiddenDim, hiddenDim, nn.DefaultRNNConfig()),
@ -72,7 +72,7 @@ func newDecoder(vs nn.Path, hiddenDim, outDim int64) (retVal Decoder) {
}
}
func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor, isTraining bool) (retTs ts.Tensor, retState nn.GRUState) {
func (d *Decoder) forward(xs *ts.Tensor, state *nn.GRUState, encOutputs *ts.Tensor, isTraining bool) (*ts.Tensor, *nn.GRUState) {
forwardTsTmp := d.embedding.Forward(xs)
forwardTsTmp.MustDropout_(0.1, isTraining)
@ -81,7 +81,7 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor,
// NOTE. forwardTs shape: [1, 256] state [1, 1, 256]
// hence, just get state[0] of 3D tensor state
stateTs := state.Value().MustShallowClone().MustView([]int64{1, -1}, true)
catTs := ts.MustCat([]ts.Tensor{forwardTs, stateTs}, 1)
catTs := ts.MustCat([]ts.Tensor{*forwardTs, *stateTs}, 1)
stateTs.MustDrop()
// NOTE. d.attn Ws shape : [512, 10]
@ -97,44 +97,44 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor,
sz2 := size3[1]
sz3 := size3[2]
var encOutputsTs ts.Tensor
var encOutputsTs *ts.Tensor
if sz2 == MaxLength {
encOutputsTs = encOutputs.MustShallowClone()
} else {
shape := []int64{sz1, MaxLength - sz2, sz3}
zerosTs := ts.MustZeros(shape, gotch.Float, d.device)
encOutputsTs = ts.MustCat([]ts.Tensor{encOutputs, zerosTs}, 1)
encOutputsTs = ts.MustCat([]ts.Tensor{*encOutputs, *zerosTs}, 1)
zerosTs.MustDrop()
}
attnApplied := attnWeights.MustBmm(encOutputsTs, true).MustSqueeze1(1, true)
encOutputsTs.MustDrop()
cTs := ts.MustCat([]ts.Tensor{forwardTs, attnApplied}, 1)
cTs := ts.MustCat([]ts.Tensor{*forwardTs, *attnApplied}, 1)
forwardTs.MustDrop()
attnApplied.MustDrop()
aTs := cTs.Apply(d.attnCombine)
cTs.MustDrop()
xsTs := aTs.MustRelu(true)
retState = d.gru.Step(xsTs, state).(nn.GRUState)
retState := d.gru.Step(xsTs, state).(*nn.GRUState)
xsTs.MustDrop()
retTs = d.linear.Forward(retState.Value()).MustLogSoftmax(-1, gotch.Float, true)
retTs := d.linear.Forward(retState.Value()).MustLogSoftmax(-1, gotch.Float, true)
return retTs, retState
}
type Model struct {
encoder Encoder
decoder Decoder
decoderStart ts.Tensor
encoder *Encoder
decoder *Decoder
decoderStart *ts.Tensor
decoderEos int64
device gotch.Device
}
func newModel(vs nn.Path, ilang Lang, olang Lang, hiddenDim int64) (retVal Model) {
return Model{
func newModel(vs *nn.Path, ilang Lang, olang Lang, hiddenDim int64) *Model {
return &Model{
encoder: newEncoder(vs.Sub("enc"), int64(ilang.Len()), hiddenDim),
decoder: newDecoder(vs.Sub("dec"), hiddenDim, int64(olang.Len())),
decoderStart: ts.MustOfSlice([]int64{int64(olang.SosToken())}).MustTo(vs.Device(), true),
@ -143,16 +143,16 @@ func newModel(vs nn.Path, ilang Lang, olang Lang, hiddenDim int64) (retVal Model
}
}
func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
func (m *Model) trainLoss(input []int, target []int) *ts.Tensor {
state := m.encoder.gru.ZeroState(1)
var encOutputs []ts.Tensor
for _, v := range input {
s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
outTs, outState := m.encoder.forward(s, state.(*nn.GRUState))
s.MustDrop()
encOutputs = append(encOutputs, outTs)
state.(nn.GRUState).Tensor.MustDrop()
encOutputs = append(encOutputs, *outTs)
state.(*nn.GRUState).Tensor.MustDrop()
state = outState
}
@ -167,8 +167,8 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
for _, s := range target {
// TODO: fix memory leak at decoder.forward
outTs, outState := m.decoder.forward(prev, state.(nn.GRUState), stackTs, true)
state.(nn.GRUState).Tensor.MustDrop()
outTs, outState := m.decoder.forward(prev, state.(*nn.GRUState), stackTs, true)
state.(*nn.GRUState).Tensor.MustDrop()
state = outState
targetTs := ts.MustOfSlice([]int64{int64(s)}).MustTo(m.device, true)
@ -195,7 +195,7 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
outTs.MustDrop()
}
state.(nn.GRUState).Tensor.MustDrop()
state.(*nn.GRUState).Tensor.MustDrop()
stackTs.MustDrop()
prev.MustDrop()
@ -203,16 +203,16 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
}
func (m *Model) predict(input []int) (retVal []int) {
func (m *Model) predict(input []int) []int {
state := m.encoder.gru.ZeroState(1)
var encOutputs []ts.Tensor
for _, v := range input {
s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
outTs, outState := m.encoder.forward(s, state.(*nn.GRUState))
encOutputs = append(encOutputs, outTs)
state.(nn.GRUState).Tensor.MustDrop()
encOutputs = append(encOutputs, *outTs)
state.(*nn.GRUState).Tensor.MustDrop()
state = outState
}
@ -225,7 +225,7 @@ func (m *Model) predict(input []int) (retVal []int) {
var outputSeq []int
for i := 0; i < int(MaxLength); i++ {
outTs, outState := m.decoder.forward(prev, state.(nn.GRUState), stackTs, true)
outTs, outState := m.decoder.forward(prev, state.(*nn.GRUState), stackTs, true)
_, output := outTs.MustTopK(1, -1, true, true)
outputVal := output.Int64Values()[0]
outputSeq = append(outputSeq, int(outputVal))
@ -234,7 +234,7 @@ func (m *Model) predict(input []int) (retVal []int) {
break
}
state.(nn.GRUState).Tensor.MustDrop()
state.(*nn.GRUState).Tensor.MustDrop()
state = outState
prev.MustDrop()
prev = output
@ -249,8 +249,8 @@ type LossStats struct {
samples int
}
func newLossStats() (retVal LossStats) {
return LossStats{
func newLossStats() *LossStats {
return &LossStats{
totalLoss: 0.0,
samples: 0,
}
@ -261,7 +261,7 @@ func (ls *LossStats) update(loss float64) {
ls.samples += 1
}
func (ls *LossStats) avgAndReset() (retVal float64) {
func (ls *LossStats) avgAndReset() float64 {
avg := ls.totalLoss / float64(ls.samples)
ls.totalLoss = 0.0
ls.samples = 0

View File

@ -19,7 +19,7 @@ type Block struct {
Parameters map[string]string
}
func (b *Block) get(key string) (retVal string) {
func (b *Block) get(key string) string {
val, ok := b.Parameters[key]
if !ok {
log.Fatalf("Cannot find %v in Block parameters.\n", key)
@ -33,7 +33,7 @@ type Darknet struct {
Parameters map[string]string
}
func (d Darknet) get(key string) (retVal string) {
func (d *Darknet) get(key string) string {
val, ok := d.Parameters[key]
if !ok {
log.Fatalf("Cannot find %v in Darknet parameters.\n", key)
@ -44,16 +44,16 @@ func (d Darknet) get(key string) (retVal string) {
type Accumulator struct {
Parameters map[string]string
Net Darknet
Net *Darknet
BlockType *string // optional
}
func newAccumulator() (retVal Accumulator) {
func newAccumulator() *Accumulator {
return Accumulator{
return &Accumulator{
BlockType: nil,
Parameters: make(map[string]string, 0),
Net: Darknet{
Net: &Darknet{
Blocks: make([]Block, 0),
Parameters: make(map[string]string, 0),
},
@ -79,7 +79,7 @@ func (acc *Accumulator) finishBlock() {
acc.BlockType = nil
}
func ParseConfig(path string) (retVal Darknet) {
func ParseConfig(path string) *Darknet {
acc := newAccumulator()
@ -166,7 +166,7 @@ type (
}
)
func conv(vs nn.Path, index uint, p int64, b Block) (retVal1 int64, retVal2 interface{}) {
func conv(vs *nn.Path, index uint, p int64, b *Block) (retVal1 int64, retVal2 interface{}) {
activation := b.get("activation")
@ -209,7 +209,7 @@ func conv(vs nn.Path, index uint, p int64, b Block) (retVal1 int64, retVal2 inte
if p != 0 {
sub := vs.Sub(fmt.Sprintf("batch_norm_%v", index))
bnVal := nn.BatchNorm2D(sub, filters, nn.DefaultBatchNormConfig())
bn = &bnVal
bn = bnVal
bias = false
}
} else {
@ -234,18 +234,19 @@ func conv(vs nn.Path, index uint, p int64, b Block) (retVal1 int64, retVal2 inte
log.Fatalf("Unsupported activation(%v)\n", activation)
}
fn := nn.NewFuncT(func(xs ts.Tensor, train bool) (res ts.Tensor) {
fn := nn.NewFuncT(func(xs *ts.Tensor, train bool) *ts.Tensor {
tmp1 := xs.Apply(conv)
var tmp2 ts.Tensor
var tmp2 *ts.Tensor
if bn != nil {
tmp2 = tmp1.ApplyT(*bn, train)
tmp2 = tmp1.ApplyT(bn, train)
tmp1.MustDrop()
} else {
tmp2 = tmp1
}
var res *ts.Tensor
if leaky {
tmp2Mul := tmp2.MustMul1(ts.FloatScalar(0.1), false)
res = tmp2.MustMax1(tmp2Mul, true)
@ -261,7 +262,7 @@ func conv(vs nn.Path, index uint, p int64, b Block) (retVal1 int64, retVal2 inte
}
func upsample(prevChannels int64) (retVal1 int64, retVal2 interface{}) {
layer := nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
layer := nn.NewFuncT(func(xs *ts.Tensor, train bool) *ts.Tensor {
// []int64{n, c, h, w}
res, err := xs.Size4()
if err != nil {
@ -276,7 +277,8 @@ func upsample(prevChannels int64) (retVal1 int64, retVal2 interface{}) {
return prevChannels, Layer{Val: layer}
}
func intListOfString(s string) (retVal []int64) {
func intListOfString(s string) []int64 {
var retVal []int64
strs := strings.Split(s, ",")
for _, str := range strs {
str = strings.TrimSpace(str)
@ -290,7 +292,7 @@ func intListOfString(s string) (retVal []int64) {
return retVal
}
func uintOfIndex(index uint, i int64) (retVal uint) {
func uintOfIndex(index uint, i int64) uint {
if i >= 0 {
return uint(i)
} else {
@ -298,7 +300,7 @@ func uintOfIndex(index uint, i int64) (retVal uint) {
}
}
func route(index uint, p []ChannelsBl, blk Block) (retVal1 int64, retVal2 interface{}) {
func route(index uint, p []ChannelsBl, blk *Block) (retVal1 int64, retVal2 interface{}) {
intLayers := intListOfString(blk.get("layers"))
var layers []uint
@ -314,7 +316,7 @@ func route(index uint, p []ChannelsBl, blk Block) (retVal1 int64, retVal2 interf
return channels, Route{TsIdxs: layers}
}
func shortcut(index uint, p int64, blk Block) (retVal1 int64, retVal2 interface{}) {
func shortcut(index uint, p int64, blk *Block) (retVal1 int64, retVal2 interface{}) {
fromStr := blk.get("from")
from, err := strconv.ParseInt(fromStr, 10, 64)
@ -325,7 +327,7 @@ func shortcut(index uint, p int64, blk Block) (retVal1 int64, retVal2 interface{
return p, Shortcut{TsIdx: uintOfIndex(index, from)}
}
func yolo(p int64, blk Block) (retVal1 int64, retVal2 interface{}) {
func yolo(p int64, blk *Block) (retVal1 int64, retVal2 interface{}) {
classesStr := blk.get("classes")
classes, err := strconv.ParseInt(classesStr, 10, 64)
if err != nil {
@ -356,7 +358,7 @@ func yolo(p int64, blk Block) (retVal1 int64, retVal2 interface{}) {
}
// Apply f to a slice of tensor xs and replace xs values with f output.
func sliceApplyAndSet(xs ts.Tensor, start int64, len int64, f func(ts.Tensor) ts.Tensor) {
func sliceApplyAndSet(xs *ts.Tensor, start int64, len int64, f func(*ts.Tensor) *ts.Tensor) {
slice := xs.MustNarrow(2, start, len, false)
src := f(slice)
@ -365,7 +367,7 @@ func sliceApplyAndSet(xs ts.Tensor, start int64, len int64, f func(ts.Tensor) ts
slice.MustDrop()
}
func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (retVal ts.Tensor) {
func detect(xs *ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) *ts.Tensor {
device, err := xs.Device()
@ -396,7 +398,7 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (r
xOffset := a.MustView([]int64{-1, 1}, true)
yOffset := b.MustView([]int64{-1, 1}, true)
xyOffsetTmp1 := ts.MustCat([]ts.Tensor{xOffset, yOffset}, 1)
xyOffsetTmp1 := ts.MustCat([]ts.Tensor{*xOffset, *yOffset}, 1)
xyOffsetTmp2 := xyOffsetTmp1.MustRepeat([]int64{1, nanchors}, true)
xyOffsetTmp3 := xyOffsetTmp2.MustView([]int64{-1, 2}, true)
xyOffset := xyOffsetTmp3.MustUnsqueeze(0, true)
@ -417,23 +419,21 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (r
anchorsTmp3 := anchorsTmp2.MustRepeat([]int64{gridSize * gridSize, 1}, true)
anchorsTs := anchorsTmp3.MustUnsqueeze(0, true).MustTo(device, true)
sliceApplyAndSet(xsTs, 0, 2, func(xs ts.Tensor) (res ts.Tensor) {
sliceApplyAndSet(xsTs, 0, 2, func(xs *ts.Tensor) *ts.Tensor {
tmp := xs.MustSigmoid(false)
res = tmp.MustAdd(xyOffset, true)
return res
return tmp.MustAdd(xyOffset, true)
})
sliceApplyAndSet(xsTs, 4, classes+1, func(xs ts.Tensor) (res ts.Tensor) {
sliceApplyAndSet(xsTs, 4, classes+1, func(xs *ts.Tensor) *ts.Tensor {
return xs.MustSigmoid(false)
})
sliceApplyAndSet(xsTs, 2, 2, func(xs ts.Tensor) (res ts.Tensor) {
sliceApplyAndSet(xsTs, 2, 2, func(xs *ts.Tensor) *ts.Tensor {
tmp := xs.MustExp(false)
res = tmp.MustMul(anchorsTs, true)
return res
return tmp.MustMul(anchorsTs, true)
})
sliceApplyAndSet(xsTs, 0, 4, func(xs ts.Tensor) (res ts.Tensor) {
sliceApplyAndSet(xsTs, 0, 4, func(xs *ts.Tensor) *ts.Tensor {
return xs.MustMul1(ts.IntScalar(stride), false)
})
@ -441,7 +441,7 @@ func detect(xs ts.Tensor, imageHeight int64, classes int64, anchors []Anchor) (r
return xsTs
}
func (dn *Darknet) Height() (retVal int64) {
func (dn *Darknet) Height() int64 {
imageHeightStr := dn.get("height")
retVal, err := strconv.ParseInt(imageHeightStr, 10, 64)
if err != nil {
@ -451,7 +451,7 @@ func (dn *Darknet) Height() (retVal int64) {
return retVal
}
func (dn *Darknet) Width() (retVal int64) {
func (dn *Darknet) Width() int64 {
imageWidthStr := dn.get("width")
retVal, err := strconv.ParseInt(imageWidthStr, 10, 64)
if err != nil {
@ -461,7 +461,7 @@ func (dn *Darknet) Width() (retVal int64) {
return retVal
}
func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
func (dn *Darknet) BuildModel(vs *nn.Path) nn.FuncT {
var blocks []ChannelsBl // Param is a struct{int64, interface{}}
var prevChannels int64 = 3
@ -471,15 +471,15 @@ func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
switch *blk.BlockType {
case "convolutional":
channels, bl = conv(vs.Sub(fmt.Sprintf("%v", index)), uint(index), prevChannels, blk)
channels, bl = conv(vs.Sub(fmt.Sprintf("%v", index)), uint(index), prevChannels, &blk)
case "upsample":
channels, bl = upsample(prevChannels)
case "shortcut":
channels, bl = shortcut(uint(index), prevChannels, blk)
channels, bl = shortcut(uint(index), prevChannels, &blk)
case "route":
channels, bl = route(uint(index), blocks, blk)
channels, bl = route(uint(index), blocks, &blk)
case "yolo":
channels, bl = yolo(prevChannels, blk)
channels, bl = yolo(prevChannels, &blk)
default:
log.Fatalf("Unsupported block type: %v\n", *blk.BlockType)
}
@ -489,7 +489,7 @@ func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
imageHeight := dn.Height()
retVal = nn.NewFuncT(func(xs ts.Tensor, train bool) (res ts.Tensor) {
retVal := nn.NewFuncT(func(xs *ts.Tensor, train bool) *ts.Tensor {
var prevYs []ts.Tensor = make([]ts.Tensor, 0)
var detections []ts.Tensor = make([]ts.Tensor, 0)
@ -497,13 +497,13 @@ func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
// NOTE: we will delete all tensors in prevYs after looping
for _, b := range blocks {
blkTyp := reflect.TypeOf(b.Bl)
var ysTs ts.Tensor
var ysTs *ts.Tensor
switch blkTyp.Name() {
case "Layer":
layer := b.Bl.(Layer)
xsTs := xs
if len(prevYs) > 0 {
xsTs = prevYs[len(prevYs)-1] // last prevYs element
xsTs = &prevYs[len(prevYs)-1] // last prevYs element
}
ysTs = layer.Val.ForwardT(xsTs, train)
case "Route":
@ -516,7 +516,7 @@ func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
case "Shortcut":
from := b.Bl.(Shortcut).TsIdx
addTs := prevYs[int(from)]
addTs := &prevYs[int(from)]
last := prevYs[len(prevYs)-1]
ysTs = last.MustAdd(addTs, false)
case "Yolo":
@ -524,12 +524,12 @@ func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
anchors := b.Bl.(Yolo).Anchors
xsTs := xs
if len(prevYs) > 0 {
xsTs = prevYs[len(prevYs)-1]
xsTs = &prevYs[len(prevYs)-1]
}
dt := detect(xsTs, imageHeight, classes, anchors)
detections = append(detections, dt)
detections = append(detections, *dt)
ysTs = ts.NewTensor()
@ -537,10 +537,10 @@ func (dn *Darknet) BuildModel(vs nn.Path) (retVal nn.FuncT) {
// log.Fatalf("BuildModel - FuncT - Unsupported block type: %v\n", blkTyp.Name())
} // end of Switch
prevYs = append(prevYs, ysTs)
prevYs = append(prevYs, *ysTs)
} // end of For loop
res = ts.MustCat(detections, 1)
res := ts.MustCat(detections, 1)
// Now, free-up memory held up by prevYs
for _, t := range prevYs {

View File

@ -59,7 +59,7 @@ func Iou(b1, b2 Bbox) (retVal float64) {
}
// Assuming x1 <= x2 and y1 <= y2
func drawRect(t ts.Tensor, x1, x2, y1, y2 int64) {
func drawRect(t *ts.Tensor, x1, x2, y1, y2 int64) {
color := ts.MustOfSlice([]float64{0.0, 0.0, 1.0}).MustView([]int64{3, 1, 1}, true)
// NOTE: `narrow` will create a tensor (view) that share same storage with
@ -71,7 +71,7 @@ func drawRect(t ts.Tensor, x1, x2, y1, y2 int64) {
color.MustDrop()
}
func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor) {
func report(pred *ts.Tensor, img *ts.Tensor, w int64, h int64) *ts.Tensor {
size2, err := pred.Size2()
if err != nil {
log.Fatal(err)
@ -180,7 +180,7 @@ func report(pred ts.Tensor, img ts.Tensor, w int64, h int64) (retVal ts.Tensor)
}
imgTmp := image.MustMul1(ts.FloatScalar(255.0), true)
retVal = imgTmp.MustTotype(gotch.Uint8, true)
retVal := imgTmp.MustTotype(gotch.Uint8, true)
return retVal
}
@ -208,7 +208,7 @@ func main() {
log.Fatal(err)
}
var darknet Darknet = ParseConfig(configPath)
var darknet *Darknet = ParseConfig(configPath)
vs := nn.NewVarStore(gotch.CPU)
model := darknet.BuildModel(vs.Root())

View File

@ -68,7 +68,7 @@ func NewConvTranspose1D(vs *Path, inDim, outDim int64, ksizes []int64, cfg *Conv
var (
ws *ts.Tensor
bs *ts.Tensor
bs *ts.Tensor = ts.NewTensor()
)
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
@ -100,7 +100,7 @@ func NewConvTranspose2D(vs *Path, inDim, outDim int64, ksizes []int64, cfg *Conv
var (
ws *ts.Tensor
bs *ts.Tensor
bs *ts.Tensor = ts.NewTensor()
)
if cfg.Bias {
@ -130,7 +130,7 @@ func NewConvTranspose3D(vs *Path, inDim, outDim int64, ksizes []int64, cfg *Conv
var (
ws *ts.Tensor
bs *ts.Tensor
bs *ts.Tensor = ts.NewTensor()
)
if cfg.Bias {

View File

@ -74,7 +74,7 @@ type Conv1D struct {
func NewConv1D(vs *Path, inDim, outDim, k int64, cfg *Conv1DConfig) *Conv1D {
var (
ws *ts.Tensor
bs *ts.Tensor
bs *ts.Tensor = ts.NewTensor()
)
if cfg.Bias {
bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
@ -99,7 +99,7 @@ type Conv2D struct {
func NewConv2D(vs *Path, inDim, outDim int64, k int64, cfg *Conv2DConfig) *Conv2D {
var (
ws *ts.Tensor
bs *ts.Tensor
bs *ts.Tensor = ts.NewTensor()
)
if cfg.Bias {
bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
@ -124,7 +124,7 @@ type Conv3D struct {
func NewConv3D(vs *Path, inDim, outDim, k int64, cfg *Conv3DConfig) *Conv3D {
var (
ws *ts.Tensor
bs *ts.Tensor
bs *ts.Tensor = ts.NewTensor()
)
if cfg.Bias {
bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
@ -195,12 +195,12 @@ func NewConv(vs *Path, inDim, outDim int64, ksizes []int64, config interface{})
configT := reflect.TypeOf(config)
var (
ws *ts.Tensor
bs *ts.Tensor
bs *ts.Tensor = ts.NewTensor()
)
switch {
case len(ksizes) == 1 && configT.Name() == "Conv1DConfig":
cfg := config.(Conv1DConfig)
case len(ksizes) == 1 && configT.String() == "*nn.Conv1DConfig":
cfg := config.(*Conv1DConfig)
if cfg.Bias {
bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
}
@ -210,23 +210,23 @@ func NewConv(vs *Path, inDim, outDim int64, ksizes []int64, config interface{})
return &Conv1D{
Ws: ws,
Bs: bs,
Config: &cfg,
Config: cfg,
}
case len(ksizes) == 2 && configT.Name() == "Conv2DConfig":
cfg := config.(Conv2DConfig)
case len(ksizes) == 2 && configT.String() == "*nn.Conv2DConfig":
cfg := config.(*Conv2DConfig)
if cfg.Bias {
bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
}
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
weightSize = append(weightSize, ksizes...)
ws = vs.NewVar("weight", weightSize, config.(Conv2DConfig).WsInit)
ws = vs.NewVar("weight", weightSize, cfg.WsInit)
return &Conv2D{
Ws: ws,
Bs: bs,
Config: &cfg,
Config: cfg,
}
case len(ksizes) == 3 && configT.Name() == "Conv3DConfig":
cfg := config.(Conv3DConfig)
case len(ksizes) == 3 && configT.String() == "*nn.Conv3DConfig":
cfg := config.(*Conv3DConfig)
if cfg.Bias {
bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
}
@ -236,10 +236,10 @@ func NewConv(vs *Path, inDim, outDim int64, ksizes []int64, config interface{})
return &Conv3D{
Ws: ws,
Bs: bs,
Config: &cfg,
Config: cfg,
}
default:
err := fmt.Errorf("Expected nd length from 1 to 3. Got %v\n", len(ksizes))
err := fmt.Errorf("Expected nd length from 1 to 3. Got %v - configT name: '%v'\n", len(ksizes), configT.String())
panic(err)
}
}

View File

@ -125,22 +125,8 @@ func (s *SequentialT) IsEmpty() (retVal bool) {
// Implement ModuleT interface for SequentialT:
// ==========================================
/*
* func (s SequentialT) Forward(xs ts.Tensor) (retVal ts.Tensor) {
* if s.IsEmpty() {
* return xs.MustShallowClone()
* }
*
* // forward sequentially
* var currTs ts.Tensor = xs
* for i := 0; i < len(s.layers); i++ {
* currTs = s.layers[i].Forward(currTs)
* }
*
* return currTs
* }
* */
func (s *SequentialT) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) {
func (s *SequentialT) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
if s.IsEmpty() {
return xs.MustShallowClone()
}
@ -159,8 +145,7 @@ func (s *SequentialT) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) {
}
}
return
panic("Shouldn't reached here.")
}
// Add appends a layer after all the current layers.

View File

@ -154,7 +154,7 @@ func (ts *Tensor) MustTopK(k int64, dim int64, largest bool, sorted bool) (ts1,
// NOTE. `NLLLoss` is a version of `NllLoss` in tensor-generated
// with default weight, reduction and ignoreIndex
func (ts *Tensor) NLLLoss(target Tensor, del bool) (retVal *Tensor, err error) {
func (ts *Tensor) NLLLoss(target *Tensor, del bool) (retVal *Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
if del {
defer ts.MustDrop()
@ -174,7 +174,7 @@ func (ts *Tensor) NLLLoss(target Tensor, del bool) (retVal *Tensor, err error) {
return retVal, nil
}
func (ts *Tensor) MustNLLLoss(target Tensor, del bool) (retVal *Tensor) {
func (ts *Tensor) MustNLLLoss(target *Tensor, del bool) (retVal *Tensor) {
retVal, err := ts.NLLLoss(target, del)
if err != nil {
log.Fatal(err)