WIP(example/translation): more update
This commit is contained in:
parent
b77fa54eb0
commit
2bd23fcc89
|
@ -88,5 +88,5 @@ func (l *Lang) SeqToString(seq []int) (retVal string) {
|
||||||
words = append(words, w)
|
words = append(words, w)
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Join(words, "")
|
return strings.Join(words, " ")
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,7 @@ func newEncoder(vs nn.Path, inDim, hiddenDim int64) (retVal Encoder) {
|
||||||
|
|
||||||
func (e Encoder) forward(xs ts.Tensor, state nn.GRUState) (retTs ts.Tensor, retState nn.GRUState) {
|
func (e Encoder) forward(xs ts.Tensor, state nn.GRUState) (retTs ts.Tensor, retState nn.GRUState) {
|
||||||
|
|
||||||
retTs = e.embedding.Forward(xs).MustView([]int64{1, -1}, false)
|
retTs = e.embedding.Forward(xs).MustView([]int64{1, -1}, true)
|
||||||
retState = e.gru.Step(retTs, state).(nn.GRUState)
|
retState = e.gru.Step(retTs, state).(nn.GRUState)
|
||||||
|
|
||||||
return retTs, retState
|
return retTs, retState
|
||||||
|
@ -83,6 +83,7 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor,
|
||||||
state0 := stateTs.Idx([]ts.TensorIndexer{ts.NewSelect(0)})
|
state0 := stateTs.Idx([]ts.TensorIndexer{ts.NewSelect(0)})
|
||||||
|
|
||||||
catTs := ts.MustCat([]ts.Tensor{forwardTs, state0}, 1)
|
catTs := ts.MustCat([]ts.Tensor{forwardTs, state0}, 1)
|
||||||
|
state0.MustDrop()
|
||||||
|
|
||||||
// NOTE. d.attn Ws shape : [512, 10]
|
// NOTE. d.attn Ws shape : [512, 10]
|
||||||
appliedTs := catTs.Apply(d.attn)
|
appliedTs := catTs.Apply(d.attn)
|
||||||
|
@ -104,15 +105,22 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor,
|
||||||
shape := []int64{sz1, MaxLength - sz2, sz3}
|
shape := []int64{sz1, MaxLength - sz2, sz3}
|
||||||
zerosTs := ts.MustZeros(shape, gotch.Float, d.device)
|
zerosTs := ts.MustZeros(shape, gotch.Float, d.device)
|
||||||
encOutputsTs = ts.MustCat([]ts.Tensor{encOutputs, zerosTs}, 1)
|
encOutputsTs = ts.MustCat([]ts.Tensor{encOutputs, zerosTs}, 1)
|
||||||
|
zerosTs.MustDrop()
|
||||||
}
|
}
|
||||||
|
|
||||||
attnApplied := attnWeights.MustBmm(encOutputsTs, true).MustSqueeze1(1, true)
|
attnApplied := attnWeights.MustBmm(encOutputsTs, true).MustSqueeze1(1, true)
|
||||||
fmt.Printf("attnApplied shape: %v\n", attnApplied.MustSize())
|
attnWeights.MustDrop()
|
||||||
fmt.Printf("xs shape: %v\n", forwardTs.MustSize())
|
encOutputsTs.MustDrop()
|
||||||
|
|
||||||
xsTs := ts.MustCat([]ts.Tensor{forwardTs, attnApplied}, 1).Apply(d.attnCombine).MustRelu(true)
|
cTs := ts.MustCat([]ts.Tensor{forwardTs, attnApplied}, 1)
|
||||||
|
forwardTs.MustDrop()
|
||||||
|
attnApplied.MustDrop()
|
||||||
|
aTs := cTs.Apply(d.attnCombine)
|
||||||
|
cTs.MustDrop()
|
||||||
|
xsTs := aTs.MustRelu(true)
|
||||||
|
|
||||||
retState = d.gru.Step(xsTs, state).(nn.GRUState)
|
retState = d.gru.Step(xsTs, state).(nn.GRUState)
|
||||||
|
xsTs.MustDrop()
|
||||||
|
|
||||||
retTs = d.linear.Forward(retState.Value()).MustLogSoftmax(-1, gotch.Float, true)
|
retTs = d.linear.Forward(retState.Value()).MustLogSoftmax(-1, gotch.Float, true)
|
||||||
|
|
||||||
|
@ -139,14 +147,13 @@ func newModel(vs nn.Path, ilang Lang, olang Lang, hiddenDim int64) (retVal Model
|
||||||
|
|
||||||
func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
|
func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
|
||||||
state := m.encoder.gru.ZeroState(1)
|
state := m.encoder.gru.ZeroState(1)
|
||||||
fmt.Printf("state shape: %v\n", state.(nn.GRUState).Value().MustSize())
|
|
||||||
var encOutputs []ts.Tensor
|
var encOutputs []ts.Tensor
|
||||||
|
|
||||||
for _, v := range input {
|
for _, v := range input {
|
||||||
s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
|
s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
|
||||||
|
|
||||||
outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
|
outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
|
||||||
|
s.MustDrop()
|
||||||
encOutputs = append(encOutputs, outTs)
|
encOutputs = append(encOutputs, outTs)
|
||||||
state.(nn.GRUState).Tensor.MustDrop()
|
state.(nn.GRUState).Tensor.MustDrop()
|
||||||
state = outState
|
state = outState
|
||||||
|
@ -168,17 +175,13 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
|
||||||
state = outState
|
state = outState
|
||||||
|
|
||||||
targetTs := ts.MustOfSlice([]int64{int64(s)}).MustTo(m.device, true)
|
targetTs := ts.MustOfSlice([]int64{int64(s)}).MustTo(m.device, true)
|
||||||
fmt.Printf("targetTs shape: %v\n", targetTs.MustSize())
|
currLoss := outTs.MustView([]int64{1, -1}, false).MustNllLoss(targetTs, ts.NewTensor(), int64(1), -100, false)
|
||||||
fmt.Printf("outTs shape: %v\n", outTs.MustSize())
|
|
||||||
|
|
||||||
// TODO: fix the error: input tensor should be 1D or 2D at /pytorch/aten/src/THCUNN/generic/ClassNLLCriterion.cu:35
|
|
||||||
ws := ts.NewTensor()
|
|
||||||
currLoss := outTs.MustNllLoss(targetTs, ws, int64(1), -100, false)
|
|
||||||
|
|
||||||
loss.MustAdd_(currLoss)
|
loss.MustAdd_(currLoss)
|
||||||
currLoss.MustDrop()
|
currLoss.MustDrop()
|
||||||
|
|
||||||
_, output := outTs.MustTopK(1, -1, true, true)
|
noUseTs, output := outTs.MustTopK(1, -1, true, true)
|
||||||
|
noUseTs.MustDrop()
|
||||||
|
|
||||||
if m.decoderEos == outTs.Int64Values()[0] {
|
if m.decoderEos == outTs.Int64Values()[0] {
|
||||||
break
|
break
|
||||||
|
@ -186,8 +189,13 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
|
||||||
|
|
||||||
prev.MustDrop()
|
prev.MustDrop()
|
||||||
prev = output
|
prev = output
|
||||||
|
outTs.MustDrop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
state.(nn.GRUState).Tensor.MustDrop()
|
||||||
|
stackTs.MustDrop()
|
||||||
|
prev.MustDrop()
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -293,12 +301,12 @@ func main() {
|
||||||
input := pair.Val1
|
input := pair.Val1
|
||||||
target := pair.Val2
|
target := pair.Val2
|
||||||
loss := model.trainLoss(input, target)
|
loss := model.trainLoss(input, target)
|
||||||
panic("reached")
|
|
||||||
opt.BackwardStep(loss)
|
opt.BackwardStep(loss)
|
||||||
lossStats.update(loss.Float64Values()[0] / float64(len(target)))
|
lossStats.update(loss.Float64Values()[0] / float64(len(target)))
|
||||||
|
loss.MustDrop()
|
||||||
|
|
||||||
if i%1000 == 0 {
|
if i%1000 == 0 {
|
||||||
fmt.Printf("%v %v\n", i, lossStats.avgAndReset())
|
fmt.Printf("Trained %v samples - Avg. Loss: %v\n", i, lossStats.avgAndReset())
|
||||||
for predIdx := 1; predIdx <= 5; predIdx++ {
|
for predIdx := 1; predIdx <= 5; predIdx++ {
|
||||||
idx := rand.Intn(len(pairs))
|
idx := rand.Intn(len(pairs))
|
||||||
in := pairs[idx].Val1
|
in := pairs[idx].Val1
|
||||||
|
|
12
nn/rnn.go
12
nn/rnn.go
|
@ -1,8 +1,6 @@
|
||||||
package nn
|
package nn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
"github.com/sugarme/gotch"
|
||||||
ts "github.com/sugarme/gotch/tensor"
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
)
|
)
|
||||||
|
@ -228,6 +226,12 @@ func NewGRU(vs Path, inDim, hiddenDim int64, cfg RNNConfig) (retVal GRU) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if vs.Device().IsCuda() {
|
||||||
|
// NOTE. 3 is for GRU
|
||||||
|
// ref. rnn.cpp in Pytorch
|
||||||
|
ts.Must_CudnnRnnFlattenWeight(flatWeights, 4, inDim, 3, hiddenDim, cfg.NumLayers, cfg.BatchFirst, cfg.Bidirectional)
|
||||||
|
}
|
||||||
|
|
||||||
return GRU{
|
return GRU{
|
||||||
flatWeights: flatWeights,
|
flatWeights: flatWeights,
|
||||||
hiddenDim: hiddenDim,
|
hiddenDim: hiddenDim,
|
||||||
|
@ -256,14 +260,12 @@ func (g GRU) ZeroState(batchDim int64) (retVal State) {
|
||||||
func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
|
func (g GRU) Step(input ts.Tensor, inState State) (retVal State) {
|
||||||
unsqueezedInput := input.MustUnsqueeze(1, false)
|
unsqueezedInput := input.MustUnsqueeze(1, false)
|
||||||
|
|
||||||
fmt.Printf("unsqueezed input shape: %v\n", unsqueezedInput.MustSize())
|
|
||||||
fmt.Printf("inState Ts size: %v\n", inState.(GRUState).Tensor.MustSize())
|
|
||||||
|
|
||||||
output, state := g.SeqInit(unsqueezedInput, inState)
|
output, state := g.SeqInit(unsqueezedInput, inState)
|
||||||
|
|
||||||
// NOTE: though we won't use `output`, it is a Ctensor created in C land, so
|
// NOTE: though we won't use `output`, it is a Ctensor created in C land, so
|
||||||
// it should be cleaned up here to prevent memory hold-up.
|
// it should be cleaned up here to prevent memory hold-up.
|
||||||
output.MustDrop()
|
output.MustDrop()
|
||||||
|
unsqueezedInput.MustDrop()
|
||||||
|
|
||||||
return state
|
return state
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user