From a0877826c9eeebb74b625eb1b6b24cd2d8020c19 Mon Sep 17 00:00:00 2001 From: sugarme Date: Mon, 3 Aug 2020 09:56:59 +1000 Subject: [PATCH] WIP(example/translation): more update --- example/translation/main.go | 37 ++++++++++++++++++++++--------------- tensor/patch.go | 10 ++++------ 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/example/translation/main.go b/example/translation/main.go index 4902d5a..d3da39c 100644 --- a/example/translation/main.go +++ b/example/translation/main.go @@ -81,7 +81,6 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor, // hence, just get state[0] of 3D tensor state stateTs := state.Value() state0 := stateTs.Idx([]ts.TensorIndexer{ts.NewSelect(0)}) - catTs := ts.MustCat([]ts.Tensor{forwardTs, state0}, 1) state0.MustDrop() @@ -93,7 +92,6 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor, if err != nil { log.Fatal(err) } - sz1 := size3[0] sz2 := size3[1] sz3 := size3[2] @@ -151,7 +149,6 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) { for _, v := range input { s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true) - outTs, outState := m.encoder.forward(s, state.(nn.GRUState)) s.MustDrop() encOutputs = append(encOutputs, outTs) @@ -169,13 +166,20 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) { prev := m.decoderStart.MustShallowClone() for _, s := range target { - outTs, outState := m.decoder.forward(prev, state.(nn.GRUState), stackTs, true) - + // TODO: fix memory leak at decoder.forward + outTsTest, outState := m.decoder.forward(prev, state.(nn.GRUState), stackTs, true) + outTsTest.MustDrop() state.(nn.GRUState).Tensor.MustDrop() state = outState + // NOTE. fake outTs to fix mem leak + outTs := ts.MustZeros([]int64{1, 1, 2815}, gotch.Float, m.device) + targetTs := ts.MustOfSlice([]int64{int64(s)}).MustTo(m.device, true) - currLoss := outTs.MustView([]int64{1, -1}, false).MustNllLoss(targetTs, ts.NewTensor(), int64(1), -100, false) + + outTsView := outTs.MustView([]int64{1, -1}, false) + currLoss := outTsView.MustNLLLoss(targetTs, true) + targetTs.MustDrop() loss.MustAdd_(currLoss) currLoss.MustDrop() @@ -184,6 +188,9 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) { noUseTs.MustDrop() if m.decoderEos == outTs.Int64Values()[0] { + prev.MustDrop() + prev = output + outTs.MustDrop() break } @@ -286,11 +293,11 @@ func main() { model := newModel(vs.Root(), ilang, olang, HiddenSize) - optConfig := nn.DefaultAdamConfig() - opt, err := optConfig.Build(vs, LearningRate) - if err != nil { - log.Fatal(err) - } + // optConfig := nn.DefaultAdamConfig() + // opt, err := optConfig.Build(vs, LearningRate) + // if err != nil { + // log.Fatal(err) + // } lossStats := newLossStats() @@ -301,8 +308,8 @@ func main() { input := pair.Val1 target := pair.Val2 loss := model.trainLoss(input, target) - opt.BackwardStep(loss) - lossStats.update(loss.Float64Values()[0] / float64(len(target))) + // opt.BackwardStep(loss) + // lossStats.update(loss.Float64Values()[0] / float64(len(target))) loss.MustDrop() if i%1000 == 0 { @@ -311,11 +318,11 @@ func main() { idx := rand.Intn(len(pairs)) in := pairs[idx].Val1 tgt := pairs[idx].Val2 - predict := model.predict(in) + // predict := model.predict(in) fmt.Printf("input: %v\n", ilang.SeqToString(in)) fmt.Printf("target: %v\n", olang.SeqToString(tgt)) - fmt.Printf("ouput: %v\n", olang.SeqToString(predict)) + // fmt.Printf("ouput: %v\n", olang.SeqToString(predict)) } } } diff --git a/tensor/patch.go b/tensor/patch.go index 2d96582..0dcac81 100644 --- a/tensor/patch.go +++ b/tensor/patch.go @@ -161,13 +161,11 @@ func (ts Tensor) NLLLoss(target Tensor, del bool) (retVal Tensor, err error) { defer ts.MustDrop() } - weight := NewTensor() - reduction := int64(1) // Mean of loss ignoreIndex := int64(-100) - defer C.free(unsafe.Pointer(ptr)) + // defer C.free(unsafe.Pointer(ptr)) - lib.AtgNLLLoss(ptr, ts.ctensor, target.ctensor, weight.ctensor, reduction, ignoreIndex) + lib.AtgNllLoss(ptr, ts.ctensor, target.ctensor, nil, reduction, ignoreIndex) if err = TorchErr(); err != nil { return retVal, err } @@ -177,8 +175,8 @@ func (ts Tensor) NLLLoss(target Tensor, del bool) (retVal Tensor, err error) { return retVal, nil } -func (ts Tensor) MustNllLoss(target Tensor, del bool) (retVal Tensor) { - retVal, err := ts.NllLoss(target, del) +func (ts Tensor) MustNLLLoss(target Tensor, del bool) (retVal Tensor) { + retVal, err := ts.NLLLoss(target, del) if err != nil { log.Fatal(err) }