diff --git a/example/translation/lang.go b/example/translation/lang.go
index b427ef1..080c4c1 100644
--- a/example/translation/lang.go
+++ b/example/translation/lang.go
@@ -88,5 +88,5 @@ func (l *Lang) SeqToString(seq []int) (retVal string) {
 		words = append(words, w)
 	}
 
-	return strings.Join(words, "")
+	return strings.Join(words, " ")
 }
diff --git a/example/translation/main.go b/example/translation/main.go
index 2930d92..4902d5a 100644
--- a/example/translation/main.go
+++ b/example/translation/main.go
@@ -45,7 +45,7 @@ func newEncoder(vs nn.Path, inDim, hiddenDim int64) (retVal Encoder) {
 
 func (e Encoder) forward(xs ts.Tensor, state nn.GRUState) (retTs ts.Tensor, retState nn.GRUState) {
 
-	retTs = e.embedding.Forward(xs).MustView([]int64{1, -1}, false)
+	retTs = e.embedding.Forward(xs).MustView([]int64{1, -1}, true)
 	retState = e.gru.Step(retTs, state).(nn.GRUState)
 
 	return retTs, retState
@@ -83,6 +83,7 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor,
 	state0 := stateTs.Idx([]ts.TensorIndexer{ts.NewSelect(0)})
 	catTs := ts.MustCat([]ts.Tensor{forwardTs, state0}, 1)
+	state0.MustDrop()
 
 	// NOTE. d.attn Ws shape : [512, 10]
 	appliedTs := catTs.Apply(d.attn)
@@ -104,15 +105,22 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor,
 		shape := []int64{sz1, MaxLength - sz2, sz3}
 		zerosTs := ts.MustZeros(shape, gotch.Float, d.device)
 		encOutputsTs = ts.MustCat([]ts.Tensor{encOutputs, zerosTs}, 1)
+		zerosTs.MustDrop()
 	}
 
 	attnApplied := attnWeights.MustBmm(encOutputsTs, true).MustSqueeze1(1, true)
-	fmt.Printf("attnApplied shape: %v\n", attnApplied.MustSize())
-	fmt.Printf("xs shape: %v\n", forwardTs.MustSize())
+	attnWeights.MustDrop()
+	encOutputsTs.MustDrop()
 
-	xsTs := ts.MustCat([]ts.Tensor{forwardTs, attnApplied}, 1).Apply(d.attnCombine).MustRelu(true)
+	cTs := ts.MustCat([]ts.Tensor{forwardTs, attnApplied}, 1)
+	forwardTs.MustDrop()
+	attnApplied.MustDrop()
+	aTs := cTs.Apply(d.attnCombine)
+	cTs.MustDrop()
+	xsTs := aTs.MustRelu(true)
 
 	retState = d.gru.Step(xsTs, state).(nn.GRUState)
+	xsTs.MustDrop()
 
 	retTs = d.linear.Forward(retState.Value()).MustLogSoftmax(-1, gotch.Float, true)
@@ -139,14 +147,13 @@ func newModel(vs nn.Path, ilang Lang, olang Lang, hiddenDim int64) (retVal Model
 
 func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
 	state := m.encoder.gru.ZeroState(1)
-	fmt.Printf("state shape: %v\n", state.(nn.GRUState).Value().MustSize())
 	var encOutputs []ts.Tensor
 
 	for _, v := range input {
 		s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
 		outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
-
+		s.MustDrop()
 		encOutputs = append(encOutputs, outTs)
 
 		state.(nn.GRUState).Tensor.MustDrop()
 		state = outState
@@ -168,17 +175,13 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
 		state = outState
 
 		targetTs := ts.MustOfSlice([]int64{int64(s)}).MustTo(m.device, true)
-		fmt.Printf("targetTs shape: %v\n", targetTs.MustSize())
-		fmt.Printf("outTs shape: %v\n", outTs.MustSize())
-
-		// TODO: fix the error: input tensor should be 1D or 2D at /pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:35
-		ws := ts.NewTensor()
-		currLoss := outTs.MustNllLoss(targetTs, ws, int64(1), -100, false)
+		currLoss := outTs.MustView([]int64{1, -1}, false).MustNllLoss(targetTs, ts.NewTensor(), int64(1), -100, false)
 		loss.MustAdd_(currLoss)
 		currLoss.MustDrop()
 
-		_, output := outTs.MustTopK(1, -1, true, true)
+		noUseTs, output := outTs.MustTopK(1, -1, true, true)
+		noUseTs.MustDrop()
 
 		if m.decoderEos == outTs.Int64Values()[0] {
 			break
 		}
@@ -186,8 +189,13 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
 
 		prev.MustDrop()
 		prev = output
+		outTs.MustDrop()
 	}
 
+	state.(nn.GRUState).Tensor.MustDrop()
+	stackTs.MustDrop()
+	prev.MustDrop()
+
 	return loss
 }
@@ -293,12 +301,12 @@ func main() {
 		input := pair.Val1
 		target := pair.Val2
 		loss := model.trainLoss(input, target)
-		panic("reached")
 		opt.BackwardStep(loss)
 		lossStats.update(loss.Float64Values()[0] / float64(len(target)))
+		loss.MustDrop()
 
 		if i%1000 == 0 {
-			fmt.Printf("%v %v\n", i, lossStats.avgAndReset())
+			fmt.Printf("Trained %v samples - Avg. Loss: %v\n", i, lossStats.avgAndReset())
 			for predIdx := 1; predIdx <= 5; predIdx++ {
 				idx := rand.Intn(len(pairs))
 				in := pairs[idx].Val1
diff --git a/nn/rnn.go b/nn/rnn.go
index fb2873b..cffbb6b 100644
--- a/nn/rnn.go
+++ b/nn/rnn.go
@@ -1,8 +1,6 @@
 package nn
 
 import (
-	"fmt"
-
 	"github.com/sugarme/gotch"
 	ts "github.com/sugarme/gotch/tensor"
 )
@@ -228,6 +226,12 @@ func NewGRU(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
 		}
 	}
 
+	if vs.Device().IsCuda() {
+		// NOTE. 3 is for GRU
+		// ref. rnn.cpp in Pytorch
+		ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 3, hiddenDim, cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional)
+	}
+
 	return GRU{
 		flatWeights: flatWeights,
 		hiddenDim:   hiddenDim,
@@ -256,14 +260,12 @@ func (g GRU) ZeroState(batchDim int64) (retVal State) {
 
 func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
 	unsqueezedInput := input.MustUnsqueeze(1, false)
-	fmt.Printf("unsqueezed input shape: %v\n", unsqueezedInput.MustSize())
-	fmt.Printf("inState Ts size: %v\n", inState.(GRUState).Tensor.MustSize())
-
 	output, state := g.SeqInit(unsqueezedInput, inState)
 
 	// NOTE: though we won't use `output`, it is a Ctensor created in C land, so
 	// it should be cleaned up here to prevent memory hold-up.
 	output.MustDrop()
+	unsqueezedInput.MustDrop()
 
 	return state
 }
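A note on the pattern this patch applies throughout: gotch tensors wrap C-allocated storage that Go's garbage collector does not track, so every intermediate ts.Tensor must be freed exactly once, either by passing del=true to the Must* call that consumes it or by calling MustDrop() once it is no longer needed. Below is a minimal sketch of that ownership rule, using only calls that appear in the patch; the variable names and the gotch.CPU device are illustrative, not from the patch itself.

    package main

    import (
        "github.com/sugarme/gotch"
        ts "github.com/sugarme/gotch/tensor"
    )

    func main() {
        a := ts.MustOfSlice([]int64{1, 2, 3})  // C tensor #1
        b := a.MustTo(gotch.CPU, true)         // del=true frees #1 and creates #2
        c := b.MustView([]int64{1, -1}, false) // del=false leaves #2 alive ...
        b.MustDrop()                           // ... so it must be dropped by hand
        c.MustDrop()                           // the last owner frees the result
    }

Missing any one of these drops leaks the corresponding C tensor on every call, which is why the training loop was holding memory before this change.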
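The MustNllLoss change deserves a similar note: the removed TODO quotes the underlying CUDA error, and the constraint it points at is that with a 1D target tensor, nll_loss expects a 2D [batch, classes] input, while the decoder emits a 1D log-probability vector per step; viewing it as [1, -1] supplies the missing batch dimension. A hedged sketch of the call shape, with made-up values and assuming MustOfSlice accepts a float64 slice:

    package main

    import (
        "fmt"

        ts "github.com/sugarme/gotch/tensor"
    )

    func main() {
        // 1D log-probabilities over three classes for a single decoding step.
        logProbs := ts.MustOfSlice([]float64{-0.1, -2.3, -3.0})
        target := ts.MustOfSlice([]int64{0}) // index of the correct class

        // View as [1, 3] so nll_loss sees [batch, classes]; the remaining
        // arguments mirror the patch: empty weight tensor, reduction=1 (mean),
        // ignore_index=-100, del=true to free the viewed input.
        loss := logProbs.MustView([]int64{1, -1}, true).MustNllLoss(target, ts.NewTensor(), 1, -100, true)
        fmt.Printf("loss: %v\n", loss.Float64Values()[0]) // -logProbs[0][0] = 0.1

        loss.MustDrop()
        target.MustDrop()
    }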