fixed nn/rnn NewLSTM() clashing weight names

sugarme 2022-01-21 11:05:11 +11:00
parent 6e07f2cca1
commit a4e5f38705

nn/rnn.go

@@ -1,6 +1,8 @@
package nn

import (
	"fmt"

	"github.com/sugarme/gotch"
	ts "github.com/sugarme/gotch/tensor"
)
@@ -90,17 +92,32 @@ func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) *LSTM {
	flatWeights := make([]ts.Tensor, 0)
	for i := 0; i < int(cfg.NumLayers); i++ {
		for n := 0; n < int(numDirections); n++ {
			if i != 0 {
				inDim = hiddenDim * numDirections
			}
			wIh := vs.KaimingUniform("w_ih", []int64{gateDim, inDim})
			wHh := vs.KaimingUniform("w_hh", []int64{gateDim, hiddenDim})
			bIh := vs.Zeros("b_ih", []int64{gateDim})
			bHh := vs.Zeros("b_hh", []int64{gateDim})
			if i != 0 {
				inDim = hiddenDim * numDirections
			}
			switch numDirections {
			case 1:
				wIh := vs.KaimingUniform(fmt.Sprintf("weight_ih_l%d", i), []int64{gateDim, inDim})
				wHh := vs.KaimingUniform(fmt.Sprintf("weight_hh_l%d", i), []int64{gateDim, hiddenDim})
				bIh := vs.Zeros(fmt.Sprintf("bias_ih_l%d", i), []int64{gateDim})
				bHh := vs.Zeros(fmt.Sprintf("bias_hh_l%d", i), []int64{gateDim})
				flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
			case 2: // bi-directional
				// forward
				wIh := vs.KaimingUniform(fmt.Sprintf("weight_ih_l%d", i), []int64{gateDim, inDim})
				wHh := vs.KaimingUniform(fmt.Sprintf("weight_hh_l%d", i), []int64{gateDim, hiddenDim})
				bIh := vs.Zeros(fmt.Sprintf("bias_ih_l%d", i), []int64{gateDim})
				bHh := vs.Zeros(fmt.Sprintf("bias_hh_l%d", i), []int64{gateDim})
				flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
				// reverse
				wIhR := vs.KaimingUniform(fmt.Sprintf("weight_ih_l%d_reverse", i), []int64{gateDim, inDim})
				wHhR := vs.KaimingUniform(fmt.Sprintf("weight_hh_l%d_reverse", i), []int64{gateDim, hiddenDim})
				bIhR := vs.Zeros(fmt.Sprintf("bias_ih_l%d_reverse", i), []int64{gateDim})
				bHhR := vs.Zeros(fmt.Sprintf("bias_hh_l%d_reverse", i), []int64{gateDim})
				flatWeights = append(flatWeights, *wIhR, *wHhR, *bIhR, *bHhR)
			}
		}
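
For reference, the new names follow PyTorch's LSTM parameter convention (weight_ih_l{k}, weight_hh_l{k}, bias_ih_l{k}, bias_hh_l{k}, plus a _reverse suffix for the backward direction), so each layer and direction registers a unique variable instead of re-using the old fixed "w_ih" / "w_hh" / "b_ih" / "b_hh" names that clashed. The standalone sketch below is not part of this commit; the layer count and bidirectional flag are made-up values used only to print the names the loop above would now register.

package main

import "fmt"

// Prints the parameter names the updated NewLSTM loop registers.
// numLayers and bidirectional are illustrative values, not taken from the commit.
func main() {
	numLayers := 2
	bidirectional := true

	// One suffix per direction: "" for forward, "_reverse" for the backward pass.
	suffixes := []string{""}
	if bidirectional {
		suffixes = append(suffixes, "_reverse")
	}

	for i := 0; i < numLayers; i++ {
		for _, sfx := range suffixes {
			fmt.Printf("weight_ih_l%d%s\n", i, sfx)
			fmt.Printf("weight_hh_l%d%s\n", i, sfx)
			fmt.Printf("bias_ih_l%d%s\n", i, sfx)
			fmt.Printf("bias_hh_l%d%s\n", i, sfx)
		}
	}
}

Each name now appears exactly once per (layer, direction), which also matches the entries found in a PyTorch nn.LSTM state_dict.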