WIP(example/translation): more update
This commit is contained in:
parent
98ca761d30
commit
a0877826c9
|
@ -81,7 +81,6 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor,
|
||||||
// hence, just get state[0] of 3D tensor state
|
// hence, just get state[0] of 3D tensor state
|
||||||
stateTs := state.Value()
|
stateTs := state.Value()
|
||||||
state0 := stateTs.Idx([]ts.TensorIndexer{ts.NewSelect(0)})
|
state0 := stateTs.Idx([]ts.TensorIndexer{ts.NewSelect(0)})
|
||||||
|
|
||||||
catTs := ts.MustCat([]ts.Tensor{forwardTs, state0}, 1)
|
catTs := ts.MustCat([]ts.Tensor{forwardTs, state0}, 1)
|
||||||
state0.MustDrop()
|
state0.MustDrop()
|
||||||
|
|
||||||
|
@ -93,7 +92,6 @@ func (d Decoder) forward(xs ts.Tensor, state nn.GRUState, encOutputs ts.Tensor,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
sz1 := size3[0]
|
sz1 := size3[0]
|
||||||
sz2 := size3[1]
|
sz2 := size3[1]
|
||||||
sz3 := size3[2]
|
sz3 := size3[2]
|
||||||
|
@ -151,7 +149,6 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
|
||||||
|
|
||||||
for _, v := range input {
|
for _, v := range input {
|
||||||
s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
|
s := ts.MustOfSlice([]int64{int64(v)}).MustTo(m.device, true)
|
||||||
|
|
||||||
outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
|
outTs, outState := m.encoder.forward(s, state.(nn.GRUState))
|
||||||
s.MustDrop()
|
s.MustDrop()
|
||||||
encOutputs = append(encOutputs, outTs)
|
encOutputs = append(encOutputs, outTs)
|
||||||
|
@ -169,13 +166,20 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
|
||||||
prev := m.decoderStart.MustShallowClone()
|
prev := m.decoderStart.MustShallowClone()
|
||||||
|
|
||||||
for _, s := range target {
|
for _, s := range target {
|
||||||
outTs, outState := m.decoder.forward(prev, state.(nn.GRUState), stackTs, true)
|
// TODO: fix memory leak at decoder.forward
|
||||||
|
outTsTest, outState := m.decoder.forward(prev, state.(nn.GRUState), stackTs, true)
|
||||||
|
outTsTest.MustDrop()
|
||||||
state.(nn.GRUState).Tensor.MustDrop()
|
state.(nn.GRUState).Tensor.MustDrop()
|
||||||
state = outState
|
state = outState
|
||||||
|
|
||||||
|
// NOTE. fake outTs to fix mem leak
|
||||||
|
outTs := ts.MustZeros([]int64{1, 1, 2815}, gotch.Float, m.device)
|
||||||
|
|
||||||
targetTs := ts.MustOfSlice([]int64{int64(s)}).MustTo(m.device, true)
|
targetTs := ts.MustOfSlice([]int64{int64(s)}).MustTo(m.device, true)
|
||||||
currLoss := outTs.MustView([]int64{1, -1}, false).MustNllLoss(targetTs, ts.NewTensor(), int64(1), -100, false)
|
|
||||||
|
outTsView := outTs.MustView([]int64{1, -1}, false)
|
||||||
|
currLoss := outTsView.MustNLLLoss(targetTs, true)
|
||||||
|
targetTs.MustDrop()
|
||||||
|
|
||||||
loss.MustAdd_(currLoss)
|
loss.MustAdd_(currLoss)
|
||||||
currLoss.MustDrop()
|
currLoss.MustDrop()
|
||||||
|
@ -184,6 +188,9 @@ func (m *Model) trainLoss(input []int, target []int) (retVal ts.Tensor) {
|
||||||
noUseTs.MustDrop()
|
noUseTs.MustDrop()
|
||||||
|
|
||||||
if m.decoderEos == outTs.Int64Values()[0] {
|
if m.decoderEos == outTs.Int64Values()[0] {
|
||||||
|
prev.MustDrop()
|
||||||
|
prev = output
|
||||||
|
outTs.MustDrop()
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -286,11 +293,11 @@ func main() {
|
||||||
|
|
||||||
model := newModel(vs.Root(), ilang, olang, HiddenSize)
|
model := newModel(vs.Root(), ilang, olang, HiddenSize)
|
||||||
|
|
||||||
optConfig := nn.DefaultAdamConfig()
|
// optConfig := nn.DefaultAdamConfig()
|
||||||
opt, err := optConfig.Build(vs, LearningRate)
|
// opt, err := optConfig.Build(vs, LearningRate)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
log.Fatal(err)
|
// log.Fatal(err)
|
||||||
}
|
// }
|
||||||
|
|
||||||
lossStats := newLossStats()
|
lossStats := newLossStats()
|
||||||
|
|
||||||
|
@ -301,8 +308,8 @@ func main() {
|
||||||
input := pair.Val1
|
input := pair.Val1
|
||||||
target := pair.Val2
|
target := pair.Val2
|
||||||
loss := model.trainLoss(input, target)
|
loss := model.trainLoss(input, target)
|
||||||
opt.BackwardStep(loss)
|
// opt.BackwardStep(loss)
|
||||||
lossStats.update(loss.Float64Values()[0] / float64(len(target)))
|
// lossStats.update(loss.Float64Values()[0] / float64(len(target)))
|
||||||
loss.MustDrop()
|
loss.MustDrop()
|
||||||
|
|
||||||
if i%1000 == 0 {
|
if i%1000 == 0 {
|
||||||
|
@ -311,11 +318,11 @@ func main() {
|
||||||
idx := rand.Intn(len(pairs))
|
idx := rand.Intn(len(pairs))
|
||||||
in := pairs[idx].Val1
|
in := pairs[idx].Val1
|
||||||
tgt := pairs[idx].Val2
|
tgt := pairs[idx].Val2
|
||||||
predict := model.predict(in)
|
// predict := model.predict(in)
|
||||||
|
|
||||||
fmt.Printf("input: %v\n", ilang.SeqToString(in))
|
fmt.Printf("input: %v\n", ilang.SeqToString(in))
|
||||||
fmt.Printf("target: %v\n", olang.SeqToString(tgt))
|
fmt.Printf("target: %v\n", olang.SeqToString(tgt))
|
||||||
fmt.Printf("ouput: %v\n", olang.SeqToString(predict))
|
// fmt.Printf("ouput: %v\n", olang.SeqToString(predict))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -161,13 +161,11 @@ func (ts Tensor) NLLLoss(target Tensor, del bool) (retVal Tensor, err error) {
|
||||||
defer ts.MustDrop()
|
defer ts.MustDrop()
|
||||||
}
|
}
|
||||||
|
|
||||||
weight := NewTensor()
|
|
||||||
|
|
||||||
reduction := int64(1) // Mean of loss
|
reduction := int64(1) // Mean of loss
|
||||||
ignoreIndex := int64(-100)
|
ignoreIndex := int64(-100)
|
||||||
defer C.free(unsafe.Pointer(ptr))
|
// defer C.free(unsafe.Pointer(ptr))
|
||||||
|
|
||||||
lib.AtgNLLLoss(ptr, ts.ctensor, target.ctensor, weight.ctensor, reduction, ignoreIndex)
|
lib.AtgNllLoss(ptr, ts.ctensor, target.ctensor, nil, reduction, ignoreIndex)
|
||||||
if err = TorchErr(); err != nil {
|
if err = TorchErr(); err != nil {
|
||||||
return retVal, err
|
return retVal, err
|
||||||
}
|
}
|
||||||
|
@ -177,8 +175,8 @@ func (ts Tensor) NLLLoss(target Tensor, del bool) (retVal Tensor, err error) {
|
||||||
return retVal, nil
|
return retVal, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts Tensor) MustNllLoss(target Tensor, del bool) (retVal Tensor) {
|
func (ts Tensor) MustNLLLoss(target Tensor, del bool) (retVal Tensor) {
|
||||||
retVal, err := ts.NllLoss(target, del)
|
retVal, err := ts.NLLLoss(target, del)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user