WIP(example/translation): more update

Author: sugarme
Date:   2020-08-03 09:56:59 +10:00
Parent: 98ca761d30
Commit: a0877826c9

2 changed files with 26 additions and 21 deletions

File 1 of 2:

@@ -81,7 +81,6 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor,
     // hence, just get state[0] of 3D tensor state
     stateTs := state.Value()
     state0 := stateTs.Idx([]ts.TensorIndexer{ts.NewSelect(0)})
     catTs := ts.MustCat([]ts.Tensor{forwardTs, state0}, 1)
     state0.MustDrop()
@@ -93,7 +92,6 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor,
     if err != nil {
         log.Fatal(err)
     }
     sz1 := size3[0]
     sz2 := size3[1]
     sz3 := size3[2]
@@ -151,7 +149,6 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
     for _, v := range input {
         s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
         outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
         s.MustDrop()
         encOutputs = append(encOutputs, outTs)
@@ -169,13 +166,20 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
     prev := m.decoderStart.MustShallowClone()
     for _, s := range target {
-        outTs, outState := m.decoder.forward(prev, state.(nn.GRUState), stackTs, true)
+        // TODO: fix memory leak at decoder.forward
+        outTsTest, outState := m.decoder.forward(prev, state.(nn.GRUState), stackTs, true)
+        outTsTest.MustDrop()
         state.(nn.GRUState).Tensor.MustDrop()
         state = outState
+        // NOTE. fake outTs to fix mem leak
+        outTs := ts.MustZeros([]int64{1, 1, 2815}, gotch.Float, m.device)
         targetTs := ts.MustOfSlice([]int64{int64(s)}).MustTo(m.device, true)
-        currLoss := outTs.MustView([]int64{1, -1}, false).MustNllLoss(targetTs, ts.NewTensor(), int64(1), -100, false)
+        outTsView := outTs.MustView([]int64{1, -1}, false)
+        currLoss := outTsView.MustNLLLoss(targetTs, true)
+        targetTs.MustDrop()
         loss.MustAdd_(currLoss)
         currLoss.MustDrop()
@@ -184,6 +188,9 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
         noUseTs.MustDrop()
         if m.decoderEos == outTs.Int64Values()[0] {
+            prev.MustDrop()
+            prev = output
+            outTs.MustDrop()
             break
         }
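
The hunks above isolate the leak by dropping the suspect decoder output as soon as it is produced and substituting a freshly allocated zero tensor of the same shape, so the rest of the loop still runs: if memory stops growing, the leak is inside decoder.forward. Below is a minimal, self-contained sketch of that probe pattern; leakProbe is a hypothetical helper and the [1, 1, 2815] shape simply mirrors the example, so treat this as an illustration rather than code from the repository.

package main

import (
    "fmt"

    "github.com/sugarme/gotch"
    ts "github.com/sugarme/gotch/tensor"
)

// leakProbe drops a suspect tensor immediately and hands back a zero
// tensor of the same shape as a stand-in, so downstream code keeps
// running while the suspect allocation is released. If memory stops
// growing with the probe in place, the leak is in the call that
// produced the suspect tensor; if it keeps growing, it is downstream.
func leakProbe(suspect ts.Tensor, shape []int64, device gotch.Device) ts.Tensor {
    suspect.MustDrop()
    return ts.MustZeros(shape, gotch.Float, device)
}

func main() {
    device := gotch.CPU
    // Stand-in for a decoder output of shape [1, 1, 2815].
    out := ts.MustZeros([]int64{1, 1, 2815}, gotch.Float, device)
    fake := leakProbe(out, []int64{1, 1, 2815}, device)
    fmt.Printf("stand-in shape: %v\n", fake.MustSize())
    fake.MustDrop()
}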
@@ -286,11 +293,11 @@ func main() {
     model := newModel(vs.Root(), ilang, olang, HiddenSize)
-    optConfig := nn.DefaultAdamConfig()
-    opt, err := optConfig.Build(vs, LearningRate)
-    if err != nil {
-        log.Fatal(err)
-    }
+    // optConfig := nn.DefaultAdamConfig()
+    // opt, err := optConfig.Build(vs, LearningRate)
+    // if err != nil {
+    // 	log.Fatal(err)
+    // }
     lossStats := newLossStats()
@@ -301,8 +308,8 @@ func main() {
         input := pair.Val1
         target := pair.Val2
         loss := model.trainLoss(input, target)
-        opt.BackwardStep(loss)
-        lossStats.update(loss.Float64Values()[0] / float64(len(target)))
+        // opt.BackwardStep(loss)
+        // lossStats.update(loss.Float64Values()[0] / float64(len(target)))
         loss.MustDrop()
         if i%1000 == 0 {
@@ -311,11 +318,11 @@ func main() {
             idx := rand.Intn(len(pairs))
             in := pairs[idx].Val1
             tgt := pairs[idx].Val2
-            predict := model.predict(in)
+            // predict := model.predict(in)
             fmt.Printf("input: %v\n", ilang.SeqToString(in))
             fmt.Printf("target: %v\n", olang.SeqToString(tgt))
-            fmt.Printf("ouput: %v\n", olang.SeqToString(predict))
+            // fmt.Printf("ouput: %v\n", olang.SeqToString(predict))
         }
     }
 }
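
With the optimizer build, backward step, loss bookkeeping, and prediction all commented out, the loop now exercises only model.trainLoss, which narrows the memory measurement to the forward pass. For reference, the disabled training step reads as follows once restored; this is reconstructed from the commented-out lines above using the example's own names, not verified against the full file.

// Optimizer setup, before the training loop:
optConfig := nn.DefaultAdamConfig()
opt, err := optConfig.Build(vs, LearningRate)
if err != nil {
    log.Fatal(err)
}

// Inside the training loop:
loss := model.trainLoss(input, target)
opt.BackwardStep(loss)                                           // gradient step
lossStats.update(loss.Float64Values()[0] / float64(len(target))) // per-token loss
loss.MustDrop()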

File 2 of 2:

@@ -161,13 +161,11 @@ func (ts Tensor) NLLLoss(target Tensor, del bool) (retVal Tensor, err error) {
         defer ts.MustDrop()
     }
-    weight := NewTensor()
     reduction := int64(1) // Mean of loss
     ignoreIndex := int64(-100)
-    defer C.free(unsafe.Pointer(ptr))
-    lib.AtgNLLLoss(ptr, ts.ctensor, target.ctensor, weight.ctensor, reduction, ignoreIndex)
+    // defer C.free(unsafe.Pointer(ptr))
+    lib.AtgNllLoss(ptr, ts.ctensor, target.ctensor, nil, reduction, ignoreIndex)
     if err = TorchErr(); err != nil {
         return retVal, err
     }
@@ -177,8 +175,8 @@ func (ts Tensor) NLLLoss(target Tensor, del bool) (retVal Tensor, err error) {
     return retVal, nil
 }
 
-func (ts Tensor) MustNllLoss(target Tensor, del bool) (retVal Tensor) {
-    retVal, err := ts.NllLoss(target, del)
+func (ts Tensor) MustNLLLoss(target Tensor, del bool) (retVal Tensor) {
+    retVal, err := ts.NLLLoss(target, del)
     if err != nil {
         log.Fatal(err)
     }
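
The second file tightens the wrapper itself: the mixed-case NllLoss/MustNllLoss names become NLLLoss/MustNLLLoss, the weight tensor argument is dropped in favor of passing nil to the C binding, and mean reduction plus ignoreIndex = -100 are fixed inside the function, so callers now pass only (target, del). A hypothetical call in the style of the example follows; the stand-in logits and target value are made up for illustration.

package main

import (
    "fmt"

    "github.com/sugarme/gotch"
    ts "github.com/sugarme/gotch/tensor"
)

func main() {
    // Stand-in [1, C] log-probabilities and a single-element class target.
    out := ts.MustZeros([]int64{1, 2815}, gotch.Float, gotch.CPU)
    target := ts.MustOfSlice([]int64{int64(5)})
    // del=true drops `out` after the call; mean reduction and
    // ignoreIndex = -100 are now baked into NLLLoss.
    loss := out.MustNLLLoss(target, true)
    fmt.Printf("loss: %v\n", loss.Float64Values())
    loss.MustDrop()
    target.MustDrop()
}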