fix(example/translation): more memory fix

This commit is contained in:
sugarme 2020-08-03 15:40:04 +10:00
parent a0877826c9
commit 4721ffe670
2 changed files with 15 additions and 20 deletions

View File

@ -73,19 +73,20 @@ func newDecoder(vs nn.Path, hiddenDim, outDim int64) (retVal Decoder) {
}
func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor, isTraining bool) (retTs ts.Tensor, retState nn.GRUState) {
forwardTsTmp := d.embedding.Forward(xs)
forwardTsTmp.MustDropout_(0.1, isTraining)
forwardTs := forwardTsTmp.MustView([]int64{1, -1}, true)
// NOTE. forwardTs shape: [1, 256] state [1, 1, 256]
// hence, just get state[0] of 3D tensor state
stateTs := state.Value()
state0 := stateTs.Idx([]ts.TensorIndexer{ts.NewSelect(0)})
catTs := ts.MustCat([]ts.Tensor{forwardTs, state0}, 1)
state0.MustDrop()
stateTs := state.Value().MustShallowClone().MustView([]int64{1, -1}, true)
catTs := ts.MustCat([]ts.Tensor{forwardTs, stateTs}, 1)
stateTs.MustDrop()
// NOTE. d.attn Ws shape : [512, 10]
appliedTs := catTs.Apply(d.attn)
catTs.MustDrop()
attnWeights := appliedTs.MustUnsqueeze(0, true)
size3, err := encOutputs.Size3()
@ -107,7 +108,6 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor,
}
attnApplied := attnWeights.MustBmm(encOutputsTs, true).MustSqueeze1(1, true)
attnWeights.MustDrop()
encOutputsTs.MustDrop()
cTs := ts.MustCat([]ts.Tensor{forwardTs, attnApplied}, 1)
@ -167,14 +167,10 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
for _, s := range target {
// TODO: fix memory leak at decoder.forward
outTsTest, outState := m.decoder.forward(prev, state.(nn.GRUState), stackTs, true)
outTsTest.MustDrop()
outTs, outState := m.decoder.forward(prev, state.(nn.GRUState), stackTs, true)
state.(nn.GRUState).Tensor.MustDrop()
state = outState
// NOTE. fake outTs to fix mem leak
outTs := ts.MustZeros([]int64{1, 1, 2815}, gotch.Float, m.device)
targetTs := ts.MustOfSlice([]int64{int64(s)}).MustTo(m.device, true)
outTsView := outTs.MustView([]int64{1, -1}, false)
@ -293,11 +289,11 @@ func main() {
model := newModel(vs.Root(), ilang, olang, HiddenSize)
// optConfig := nn.DefaultAdamConfig()
// opt, err := optConfig.Build(vs, LearningRate)
// if err != nil {
// log.Fatal(err)
// }
optConfig := nn.DefaultAdamConfig()
opt, err := optConfig.Build(vs, LearningRate)
if err != nil {
log.Fatal(err)
}
lossStats := newLossStats()
@ -308,8 +304,8 @@ func main() {
input := pair.Val1
target := pair.Val2
loss := model.trainLoss(input, target)
// opt.BackwardStep(loss)
// lossStats.update(loss.Float64Values()[0] / float64(len(target)))
opt.BackwardStep(loss)
lossStats.update(loss.Float64Values()[0] / float64(len(target)))
loss.MustDrop()
if i%1000 == 0 {
@ -318,11 +314,11 @@ func main() {
idx := rand.Intn(len(pairs))
in := pairs[idx].Val1
tgt := pairs[idx].Val2
// predict := model.predict(in)
predict := model.predict(in)
fmt.Printf("input: %v\n", ilang.SeqToString(in))
fmt.Printf("target: %v\n", olang.SeqToString(tgt))
// fmt.Printf("ouput: %v\n", olang.SeqToString(predict))
fmt.Printf("ouput: %v\n", olang.SeqToString(predict))
}
}
}

View File

@ -259,7 +259,6 @@ func (g GRU) ZeroState(batchDim int64) (retVal State) {
func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
unsqueezedInput := input.MustUnsqueeze(1, false)
output, state := g.SeqInit(unsqueezedInput, inState)
// NOTE: though we won't use `output`, it is a Ctensor created in C land, so