WIP(example/translation): more update

sugarme 2020-08-01 19:32:19 +10:00
parent b77fa54eb0
commit 2bd23fcc89
3 changed files with 31 additions and 21 deletions

Changed file 1 of 3:

@@ -88,5 +88,5 @@ func (l *Lang) SeqToString(seq []int) (retVal string) {
         words = append(words, w)
     }
-    return strings.Join(words, "")
+    return strings.Join(words, " ")
 }
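
This one-character fix makes decoded sentences readable: tokens are now joined with spaces instead of concatenated. For illustration (sample tokens assumed):

    words := []string{"je", "suis", "content"}
    strings.Join(words, "")  // "jesuiscontent"
    strings.Join(words, " ") // "je suis content"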

Changed file 2 of 3:

@@ -45,7 +45,7 @@ func newEncoder(vs nn.Path, inDim, hiddenDim int64) (retVal Encoder) {
 func (e Encoder) forward(xs ts.Tensor, state nn.GRUState) (retTs ts.Tensor, retState nn.GRUState) {
-    retTs = e.embedding.Forward(xs).MustView([]int64{1, -1}, false)
+    retTs = e.embedding.Forward(xs).MustView([]int64{1, -1}, true)
     retState = e.gru.Step(retTs, state).(nn.GRUState)
     return retTs, retState
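
The flipped boolean is gotch's `del` flag: `true` asks `MustView` to drop its receiver once the view is built, freeing the intermediate tensor that `e.embedding.Forward(xs)` allocates in C land. A sketch of the equivalent long form, assuming those `del` semantics:

    emb := e.embedding.Forward(xs)              // C-allocated intermediate
    retTs = emb.MustView([]int64{1, -1}, false) // del=false keeps emb alive
    emb.MustDrop()                              // so it must be freed by hand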
@@ -83,6 +83,7 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor,
     state0 := stateTs.Idx([]ts.TensorIndexer{ts.NewSelect(0)})
     catTs := ts.MustCat([]ts.Tensor{forwardTs, state0}, 1)
     state0.MustDrop()
+    // NOTE. d.attn Ws shape : [512, 10]
     appliedTs := catTs.Apply(d.attn)
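
The noted [512, 10] weight shape fits the classic attention-decoder layout. A shape walk under assumed dimensions (hiddenDim = 256 and MaxLength = 10 are illustrative, not confirmed by this diff):

    // forwardTs : [1, 256]  embedded decoder input
    // state0    : [1, 256]  layer 0 of the GRU hidden state
    // catTs     : [1, 512]  ts.MustCat(..., 1)
    // appliedTs : [1, 10]   catTs.Apply(d.attn): one score per source position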
@@ -104,15 +105,22 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor,
         shape := []int64{sz1, MaxLength - sz2, sz3}
         zerosTs := ts.MustZeros(shape, gotch.Float, d.device)
         encOutputsTs = ts.MustCat([]ts.Tensor{encOutputs, zerosTs}, 1)
+        zerosTs.MustDrop()
     }
     attnApplied := attnWeights.MustBmm(encOutputsTs, true).MustSqueeze1(1, true)
     fmt.Printf("attnApplied shape: %v\n", attnApplied.MustSize())
     fmt.Printf("xs shape: %v\n", forwardTs.MustSize())
     attnWeights.MustDrop()
     encOutputsTs.MustDrop()
-    xsTs := ts.MustCat([]ts.Tensor{forwardTs, attnApplied}, 1).Apply(d.attnCombine).MustRelu(true)
+    cTs := ts.MustCat([]ts.Tensor{forwardTs, attnApplied}, 1)
+    forwardTs.MustDrop()
+    attnApplied.MustDrop()
+    aTs := cTs.Apply(d.attnCombine)
+    cTs.MustDrop()
+    xsTs := aTs.MustRelu(true)
     retState = d.gru.Step(xsTs, state).(nn.GRUState)
+    xsTs.MustDrop()
     retTs = d.linear.Forward(retState.Value()).MustLogSoftmax(-1, gotch.Float, true)
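
The removed one-liner left the `MustCat` result, along with `forwardTs` and `attnApplied`, with no handles left to free them; the replacement spells the chain out so every intermediate is dropped as soon as it is dead. The pattern in general form, a minimal sketch with hypothetical tensors x, y and module m:

    a := ts.MustCat([]ts.Tensor{x, y}, 1) // new C tensor, owned by a
    b := a.Apply(m)                       // Apply allocates; it does not free a
    a.MustDrop()                          // so drop a explicitly
    c := b.MustRelu(true)                 // del=true frees b inside the call
    // ... use c, then c.MustDrop()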
@@ -139,14 +147,13 @@ func newModel(vs nn.Path, ilang Lang, olang Lang, hiddenDim int64) (retVal Model
 func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
     state := m.encoder.gru.ZeroState(1)
-    fmt.Printf("state shape: %v\n", state.(nn.GRUState).Value().MustSize())
     var encOutputs []ts.Tensor
     for _, v := range input {
         s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
         outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
         s.MustDrop()
         encOutputs = append(encOutputs, outTs)
         state.(nn.GRUState).Tensor.MustDrop()
         state = outState
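
Each pass through the encoder loop produces a fresh state, so the stale state's C tensor is dropped before the Go variable is re-pointed; otherwise one hidden-state tensor would leak per source token:

    state.(nn.GRUState).Tensor.MustDrop() // free the stale C tensor first
    state = outState                      // then overwrite the Go handle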
@@ -168,17 +175,13 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
         state = outState
         targetTs := ts.MustOfSlice([]int64{int64(s)}).MustTo(m.device, true)
-        fmt.Printf("targetTs shape: %v\n", targetTs.MustSize())
-        fmt.Printf("outTs shape: %v\n", outTs.MustSize())
-        // TODO: fix the error: input tensor should be 1D or 2D at /pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:35
-        ws := ts.NewTensor()
-        currLoss := outTs.MustNllLoss(targetTs, ws, int64(1), -100, false)
+        currLoss := outTs.MustView([]int64{1, -1}, false).MustNllLoss(targetTs, ts.NewTensor(), int64(1), -100, false)
         loss.MustAdd_(currLoss)
         currLoss.MustDrop()
-        _, output := outTs.MustTopK(1, -1, true, true)
+        noUseTs, output := outTs.MustTopK(1, -1, true, true)
+        noUseTs.MustDrop()
         if m.decoderEos == outTs.Int64Values()[0] {
             break
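
This resolves the TODO'd error: `nll_loss` only accepts 1D or 2D (N, C) input, and the decoder output still carries the GRU's layer dimension, so it arrives as 3D. `MustView([]int64{1, -1}, false)` flattens it to (1, C), which the one-element `targetTs` can index. The remaining arguments match PyTorch's defaults: reduction 1 is mean and -100 is the default ignore_index. A sketch of the assumed shapes (hidden and vocab sizes illustrative):

    // retState.Value()                 -> [1, 1, 256]   (layers, batch, hidden)
    // d.linear + MustLogSoftmax(-1)    -> [1, 1, vocab] 3D: rejected by nll_loss
    // .MustView([]int64{1, -1}, false) -> [1, vocab]    2D (N, C): accepted
    // targetTs                         -> [1]           one class index per row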
@@ -186,8 +189,13 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
         prev.MustDrop()
         prev = output
+        outTs.MustDrop()
     }
+    state.(nn.GRUState).Tensor.MustDrop()
+    stackTs.MustDrop()
+    prev.MustDrop()
     return loss
 }
@@ -293,12 +301,12 @@ func main() {
         input := pair.Val1
         target := pair.Val2
         loss := model.trainLoss(input, target)
-        panic("reached")
         opt.BackwardStep(loss)
         lossStats.update(loss.Float64Values()[0] / float64(len(target)))
         loss.MustDrop()
         if i%1000 == 0 {
-            fmt.Printf("%v %v\n", i, lossStats.avgAndReset())
+            fmt.Printf("Trained %v samples - Avg. Loss: %v\n", i, lossStats.avgAndReset())
             for predIdx := 1; predIdx <= 5; predIdx++ {
                 idx := rand.Intn(len(pairs))
                 in := pairs[idx].Val1

Changed file 3 of 3:

@@ -1,8 +1,6 @@
 package nn

 import (
-    "fmt"
     "github.com/sugarme/gotch"
     ts "github.com/sugarme/gotch/tensor"
 )
@@ -228,6 +226,12 @@ func NewGRU(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
         }
     }
+    if vs.Device().IsCuda() {
+        // NOTE. 3 is for GRU
+        // ref. rnn.cpp in Pytorch
+        ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 3, hiddenDim, cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional)
+    }
     return GRU{
         flatWeights: flatWeights,
         hiddenDim:   hiddenDim,
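
cuDNN expects an RNN's weights in one contiguous buffer, so flattening them once at construction avoids repacking on every call. An annotated sketch of the same call, with parameter names assumed from ATen's `_cudnn_rnn_flatten_weight` (the 3 is the cuDNN RNN mode enum: 0 RNN_RELU, 1 RNN_TANH, 2 LSTM, 3 GRU, which is what the NOTE refers to):

    ts.Must_CudnnRnnFlattenWeight(
        flatWeights,       // per-layer {wIh, wHh, bIh, bHh} tensors
        4,                 // weightStride0: 4 tensors per layer
        inDim,             // inputSize
        3,                 // mode: GRU
        hiddenDim,         // hiddenSize
        cfg.NumLayers,
        cfg.BatchFirst,
        cfg.Bidirectional,
    )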
@@ -256,14 +260,12 @@ func (g GRU) ZeroState(batchDim int64) (retVal State) {
 func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
     unsqueezedInput := input.MustUnsqueeze(1, false)
-    fmt.Printf("unsqueezed input shape: %v\n", unsqueezedInput.MustSize())
-    fmt.Printf("inState Ts size: %v\n", inState.(GRUState).Tensor.MustSize())
     output, state := g.SeqInit(unsqueezedInput, inState)
     // NOTE: though we won't use `output`, it is a Ctensor created in C land, so
     // it should be cleaned up here to prevent memory hold-up.
     output.MustDrop()
     unsqueezedInput.MustDrop()
     return state
 }
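
With the debug prints gone (and the `fmt` import with them), `Step` is a thin wrapper: unsqueeze a sequence dimension of 1, run one `SeqInit` call, keep only the state. A minimal usage sketch from inside package nn (`NewVarStore`, `DefaultRNNConfig`, and the dimensions are assumptions for illustration):

    vs := NewVarStore(gotch.CPU)
    gru := NewGRU(vs.Root(), 10, 20, DefaultRNNConfig()) // inDim=10, hiddenDim=20
    state := gru.ZeroState(1)                            // batch size 1
    x := ts.MustZeros([]int64{1, 10}, gotch.Float, gotch.CPU)
    next := gru.Step(x, state)                           // one timestep
    x.MustDrop()
    state.(GRUState).Tensor.MustDrop()                   // free the stale state
    next.(GRUState).Tensor.MustDrop()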