WIP(example/translation): more update

This commit is contained in:
sugarme 2020-08-01 19:32:19 +10:00
parent b77fa54eb0
commit 2bd23fcc89
3 changed files with 31 additions and 21 deletions

View File

@ -88,5 +88,5 @@ func (l *Lang) SeqToString(seq []int) (retVal string) {
words = append(words, w) words = append(words, w)
} }
return strings.Join(words, "") return strings.Join(words, " ")
} }

View File

@ -45,7 +45,7 @@ func newEncoder(vs nn.Path, inDim, hiddenDim int64) (retVal Encoder) {
func (e Encoder) forward(xs ts.Tensor, state nn.GRUState) (retTs ts.Tensor, retState nn.GRUState) { func (e Encoder) forward(xs ts.Tensor, state nn.GRUState) (retTs ts.Tensor, retState nn.GRUState) {
retTs = e.embedding.Forward(xs).MustView([]int64{1, -1}, false) retTs = e.embedding.Forward(xs).MustView([]int64{1, -1}, true)
retState = e.gru.Step(retTs, state).(nn.GRUState) retState = e.gru.Step(retTs, state).(nn.GRUState)
return retTs, retState return retTs, retState
@ -83,6 +83,7 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor,
state0 := stateTs.Idx([]ts.TensorIndexer{ts.NewSelect(0)}) state0 := stateTs.Idx([]ts.TensorIndexer{ts.NewSelect(0)})
catTs := ts.MustCat([]ts.Tensor{forwardTs, state0}, 1) catTs := ts.MustCat([]ts.Tensor{forwardTs, state0}, 1)
state0.MustDrop()
// NOTE. d.attn Ws shape : [512, 10] // NOTE. d.attn Ws shape : [512, 10]
appliedTs := catTs.Apply(d.attn) appliedTs := catTs.Apply(d.attn)
@ -104,15 +105,22 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor,
shape := []int64{sz1, MaxLength - sz2, sz3} shape := []int64{sz1, MaxLength - sz2, sz3}
zerosTs := ts.MustZeros(shape, gotch.Float, d.device) zerosTs := ts.MustZeros(shape, gotch.Float, d.device)
encOutputsTs = ts.MustCat([]ts.Tensor{encOutputs, zerosTs}, 1) encOutputsTs = ts.MustCat([]ts.Tensor{encOutputs, zerosTs}, 1)
zerosTs.MustDrop()
} }
attnApplied := attnWeights.MustBmm(encOutputsTs, true).MustSqueeze1(1, true) attnApplied := attnWeights.MustBmm(encOutputsTs, true).MustSqueeze1(1, true)
fmt.Printf("attnApplied shape: %v\n", attnApplied.MustSize()) attnWeights.MustDrop()
fmt.Printf("xs shape: %v\n", forwardTs.MustSize()) encOutputsTs.MustDrop()
xsTs := ts.MustCat([]ts.Tensor{forwardTs, attnApplied}, 1).Apply(d.attnCombine).MustRelu(true) cTs := ts.MustCat([]ts.Tensor{forwardTs, attnApplied}, 1)
forwardTs.MustDrop()
attnApplied.MustDrop()
aTs := cTs.Apply(d.attnCombine)
cTs.MustDrop()
xsTs := aTs.MustRelu(true)
retState = d.gru.Step(xsTs, state).(nn.GRUState) retState = d.gru.Step(xsTs, state).(nn.GRUState)
xsTs.MustDrop()
retTs = d.linear.Forward(retState.Value()).MustLogSoftmax(-1, gotch.Float, true) retTs = d.linear.Forward(retState.Value()).MustLogSoftmax(-1, gotch.Float, true)
@ -139,14 +147,13 @@ func newModel(vs nn.Path, ilang Lang, olang Lang, hiddenDim int64) (retVal Model
func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) { func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
state := m.encoder.gru.ZeroState(1) state := m.encoder.gru.ZeroState(1)
fmt.Printf("state shape: %v\n", state.(nn.GRUState).Value().MustSize())
var encOutputs []ts.Tensor var encOutputs []ts.Tensor
for _, v := range input { for _, v := range input {
s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true) s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
outTs, outState := m.encoder.forward(s, state.(nn.GRUState)) outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
s.MustDrop()
encOutputs = append(encOutputs, outTs) encOutputs = append(encOutputs, outTs)
state.(nn.GRUState).Tensor.MustDrop() state.(nn.GRUState).Tensor.MustDrop()
state = outState state = outState
@ -168,17 +175,13 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
state = outState state = outState
targetTs := ts.MustOfSlice([]int64{int64(s)}).MustTo(m.device, true) targetTs := ts.MustOfSlice([]int64{int64(s)}).MustTo(m.device, true)
fmt.Printf("targetTs shape: %v\n", targetTs.MustSize()) currLoss := outTs.MustView([]int64{1, -1}, false).MustNllLoss(targetTs, ts.NewTensor(), int64(1), -100, false)
fmt.Printf("outTs shape: %v\n", outTs.MustSize())
// TODO: fix the error: input tensor should be 1D or 2D at /pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:35
ws := ts.NewTensor()
currLoss := outTs.MustNllLoss(targetTs, ws, int64(1), -100, false)
loss.MustAdd_(currLoss) loss.MustAdd_(currLoss)
currLoss.MustDrop() currLoss.MustDrop()
_, output := outTs.MustTopK(1, -1, true, true) noUseTs, output := outTs.MustTopK(1, -1, true, true)
noUseTs.MustDrop()
if m.decoderEos == outTs.Int64Values()[0] { if m.decoderEos == outTs.Int64Values()[0] {
break break
@ -186,8 +189,13 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
prev.MustDrop() prev.MustDrop()
prev = output prev = output
outTs.MustDrop()
} }
state.(nn.GRUState).Tensor.MustDrop()
stackTs.MustDrop()
prev.MustDrop()
return loss return loss
} }
@ -293,12 +301,12 @@ func main() {
input := pair.Val1 input := pair.Val1
target := pair.Val2 target := pair.Val2
loss := model.trainLoss(input, target) loss := model.trainLoss(input, target)
panic("reached")
opt.BackwardStep(loss) opt.BackwardStep(loss)
lossStats.update(loss.Float64Values()[0] / float64(len(target))) lossStats.update(loss.Float64Values()[0] / float64(len(target)))
loss.MustDrop()
if i%1000 == 0 { if i%1000 == 0 {
fmt.Printf("%v %v\n", i, lossStats.avgAndReset()) fmt.Printf("Trained %v samples - Avg. Loss: %v\n", i, lossStats.avgAndReset())
for predIdx := 1; predIdx <= 5; predIdx++ { for predIdx := 1; predIdx <= 5; predIdx++ {
idx := rand.Intn(len(pairs)) idx := rand.Intn(len(pairs))
in := pairs[idx].Val1 in := pairs[idx].Val1

View File

@ -1,8 +1,6 @@
package nn package nn
import ( import (
"fmt"
"github.com/sugarme/gotch" "github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor" ts "github.com/sugarme/gotch/tensor"
) )
@ -228,6 +226,12 @@ func NewGRU(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
} }
} }
if vs.Device().IsCuda() {
// NOTE. 3 is for GRU
// ref. rnn.cpp in Pytorch
ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 3, hiddenDim, cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional)
}
return GRU{ return GRU{
flatWeights: flatWeights, flatWeights: flatWeights,
hiddenDim: hiddenDim, hiddenDim: hiddenDim,
@ -256,14 +260,12 @@ func (g GRU) ZeroState(batchDim int64) (retVal State) {
func (g GRU) Step(input ts.Tensor, inState State) (retVal State) { func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
unsqueezedInput := input.MustUnsqueeze(1, false) unsqueezedInput := input.MustUnsqueeze(1, false)
fmt.Printf("unsqueezed input shape: %v\n", unsqueezedInput.MustSize())
fmt.Printf("inState Ts size: %v\n", inState.(GRUState).Tensor.MustSize())
output, state := g.SeqInit(unsqueezedInput, inState) output, state := g.SeqInit(unsqueezedInput, inState)
// NOTE: though we won't use `output`, it is a Ctensor created in C land, so // NOTE: though we won't use `output`, it is a Ctensor created in C land, so
// it should be cleaned up here to prevent memory hold-up. // it should be cleaned up here to prevent memory hold-up.
output.MustDrop() output.MustDrop()
unsqueezedInput.MustDrop()
return state return state
} }