WIP(example/translation): more update
This commit is contained in:
parent
b77fa54eb0
commit
2bd23fcc89
|
@ -88,5 +88,5 @@ func (l *Lang) SeqToString(seq []int) (retVal string) {
|
|||
words = append(words, w)
|
||||
}
|
||||
|
||||
return strings.Join(words, "")
|
||||
return strings.Join(words, " ")
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ func newEncoder(vs nn.Path, inDim, hiddenDim int64) (retVal Encoder) {
|
|||
|
||||
func (e Encoder) forward(xs ts.Tensor, state nn.GRUState) (retTs ts.Tensor, retState nn.GRUState) {
|
||||
|
||||
retTs = e.embedding.Forward(xs).MustView([]int64{1, -1}, false)
|
||||
retTs = e.embedding.Forward(xs).MustView([]int64{1, -1}, true)
|
||||
retState = e.gru.Step(retTs, state).(nn.GRUState)
|
||||
|
||||
return retTs, retState
|
||||
|
@ -83,6 +83,7 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor,
|
|||
state0 := stateTs.Idx([]ts.TensorIndexer{ts.NewSelect(0)})
|
||||
|
||||
catTs := ts.MustCat([]ts.Tensor{forwardTs, state0}, 1)
|
||||
state0.MustDrop()
|
||||
|
||||
// NOTE. d.attn Ws shape : [512, 10]
|
||||
appliedTs := catTs.Apply(d.attn)
|
||||
|
@ -104,15 +105,22 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor,
|
|||
shape := []int64{sz1, MaxLength - sz2, sz3}
|
||||
zerosTs := ts.MustZeros(shape, gotch.Float, d.device)
|
||||
encOutputsTs = ts.MustCat([]ts.Tensor{encOutputs, zerosTs}, 1)
|
||||
zerosTs.MustDrop()
|
||||
}
|
||||
|
||||
attnApplied := attnWeights.MustBmm(encOutputsTs, true).MustSqueeze1(1, true)
|
||||
fmt.Printf("attnApplied shape: %v\n", attnApplied.MustSize())
|
||||
fmt.Printf("xs shape: %v\n", forwardTs.MustSize())
|
||||
attnWeights.MustDrop()
|
||||
encOutputsTs.MustDrop()
|
||||
|
||||
xsTs := ts.MustCat([]ts.Tensor{forwardTs, attnApplied}, 1).Apply(d.attnCombine).MustRelu(true)
|
||||
cTs := ts.MustCat([]ts.Tensor{forwardTs, attnApplied}, 1)
|
||||
forwardTs.MustDrop()
|
||||
attnApplied.MustDrop()
|
||||
aTs := cTs.Apply(d.attnCombine)
|
||||
cTs.MustDrop()
|
||||
xsTs := aTs.MustRelu(true)
|
||||
|
||||
retState = d.gru.Step(xsTs, state).(nn.GRUState)
|
||||
xsTs.MustDrop()
|
||||
|
||||
retTs = d.linear.Forward(retState.Value()).MustLogSoftmax(-1, gotch.Float, true)
|
||||
|
||||
|
@ -139,14 +147,13 @@ func newModel(vs nn.Path, ilang Lang, olang Lang, hiddenDim int64) (retVal Model
|
|||
|
||||
func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
|
||||
state := m.encoder.gru.ZeroState(1)
|
||||
fmt.Printf("state shape: %v\n", state.(nn.GRUState).Value().MustSize())
|
||||
var encOutputs []ts.Tensor
|
||||
|
||||
for _, v := range input {
|
||||
s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
|
||||
|
||||
outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
|
||||
|
||||
s.MustDrop()
|
||||
encOutputs = append(encOutputs, outTs)
|
||||
state.(nn.GRUState).Tensor.MustDrop()
|
||||
state = outState
|
||||
|
@ -168,17 +175,13 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
|
|||
state = outState
|
||||
|
||||
targetTs := ts.MustOfSlice([]int64{int64(s)}).MustTo(m.device, true)
|
||||
fmt.Printf("targetTs shape: %v\n", targetTs.MustSize())
|
||||
fmt.Printf("outTs shape: %v\n", outTs.MustSize())
|
||||
|
||||
// TODO: fix the error: input tensor should be 1D or 2D at /pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:35
|
||||
ws := ts.NewTensor()
|
||||
currLoss := outTs.MustNllLoss(targetTs, ws, int64(1), -100, false)
|
||||
currLoss := outTs.MustView([]int64{1, -1}, false).MustNllLoss(targetTs, ts.NewTensor(), int64(1), -100, false)
|
||||
|
||||
loss.MustAdd_(currLoss)
|
||||
currLoss.MustDrop()
|
||||
|
||||
_, output := outTs.MustTopK(1, -1, true, true)
|
||||
noUseTs, output := outTs.MustTopK(1, -1, true, true)
|
||||
noUseTs.MustDrop()
|
||||
|
||||
if m.decoderEos == outTs.Int64Values()[0] {
|
||||
break
|
||||
|
@ -186,8 +189,13 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
|
|||
|
||||
prev.MustDrop()
|
||||
prev = output
|
||||
outTs.MustDrop()
|
||||
}
|
||||
|
||||
state.(nn.GRUState).Tensor.MustDrop()
|
||||
stackTs.MustDrop()
|
||||
prev.MustDrop()
|
||||
|
||||
return loss
|
||||
|
||||
}
|
||||
|
@ -293,12 +301,12 @@ func main() {
|
|||
input := pair.Val1
|
||||
target := pair.Val2
|
||||
loss := model.trainLoss(input, target)
|
||||
panic("reached")
|
||||
opt.BackwardStep(loss)
|
||||
lossStats.update(loss.Float64Values()[0] / float64(len(target)))
|
||||
loss.MustDrop()
|
||||
|
||||
if i%1000 == 0 {
|
||||
fmt.Printf("%v %v\n", i, lossStats.avgAndReset())
|
||||
fmt.Printf("Trained %v samples - Avg. Loss: %v\n", i, lossStats.avgAndReset())
|
||||
for predIdx := 1; predIdx <= 5; predIdx++ {
|
||||
idx := rand.Intn(len(pairs))
|
||||
in := pairs[idx].Val1
|
||||
|
|
12
nn/rnn.go
12
nn/rnn.go
|
@ -1,8 +1,6 @@
|
|||
package nn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
@ -228,6 +226,12 @@ func NewGRU(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
|
|||
}
|
||||
}
|
||||
|
||||
if vs.Device().IsCuda() {
|
||||
// NOTE. 3 is for GRU
|
||||
// ref. rnn.cpp in Pytorch
|
||||
ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 3, hiddenDim, cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional)
|
||||
}
|
||||
|
||||
return GRU{
|
||||
flatWeights: flatWeights,
|
||||
hiddenDim: hiddenDim,
|
||||
|
@ -256,14 +260,12 @@ func (g GRU) ZeroState(batchDim int64) (retVal State) {
|
|||
func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
|
||||
unsqueezedInput := input.MustUnsqueeze(1, false)
|
||||
|
||||
fmt.Printf("unsqueezed input shape: %v\n", unsqueezedInput.MustSize())
|
||||
fmt.Printf("inState Ts size: %v\n", inState.(GRUState).Tensor.MustSize())
|
||||
|
||||
output, state := g.SeqInit(unsqueezedInput, inState)
|
||||
|
||||
// NOTE: though we won't use `output`, it is a Ctensor created in C land, so
|
||||
// it should be cleaned up here to prevent memory hold-up.
|
||||
output.MustDrop()
|
||||
unsqueezedInput.MustDrop()
|
||||
|
||||
return state
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user