Fixed clashing weight names in nn/rnn NewLSTM()
This commit is contained in:
parent
6e07f2cca1
commit
a4e5f38705
35
nn/rnn.go
35
nn/rnn.go
|
@ -1,6 +1,8 @@
|
|||
package nn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
@ -90,17 +92,32 @@ func NewLSTM(vs *Path, inDim, hiddenDim int64, cfg *RNNConfig) *LSTM {
|
|||
flatWeights := make([]ts.Tensor, 0)
|
||||
|
||||
for i := 0; i < int(cfg.NumLayers); i++ {
|
||||
for n := 0; n < int(numDirections); n++ {
|
||||
if i != 0 {
|
||||
inDim = hiddenDim * numDirections
|
||||
}
|
||||
|
||||
wIh := vs.KaimingUniform("w_ih", []int64{gateDim, inDim})
|
||||
wHh := vs.KaimingUniform("w_hh", []int64{gateDim, hiddenDim})
|
||||
bIh := vs.Zeros("b_ih", []int64{gateDim})
|
||||
bHh := vs.Zeros("b_hh", []int64{gateDim})
|
||||
if i != 0 {
|
||||
inDim = hiddenDim * numDirections
|
||||
}
|
||||
switch numDirections {
|
||||
case 1:
|
||||
wIh := vs.KaimingUniform(fmt.Sprintf("weight_ih_l%d", i), []int64{gateDim, inDim})
|
||||
wHh := vs.KaimingUniform(fmt.Sprintf("weight_hh_l%d", i), []int64{gateDim, hiddenDim})
|
||||
bIh := vs.Zeros(fmt.Sprintf("bias_ih_l%d", i), []int64{gateDim})
|
||||
bHh := vs.Zeros(fmt.Sprintf("bias_hh_l%d", i), []int64{gateDim})
|
||||
|
||||
flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
|
||||
|
||||
case 2: // bi-directional
|
||||
// forward
|
||||
wIh := vs.KaimingUniform(fmt.Sprintf("weight_ih_l%d", i), []int64{gateDim, inDim})
|
||||
wHh := vs.KaimingUniform(fmt.Sprintf("weight_hh_l%d", i), []int64{gateDim, hiddenDim})
|
||||
bIh := vs.Zeros(fmt.Sprintf("bias_ih_l%d", i), []int64{gateDim})
|
||||
bHh := vs.Zeros(fmt.Sprintf("bias_hh_l%d", i), []int64{gateDim})
|
||||
flatWeights = append(flatWeights, *wIh, *wHh, *bIh, *bHh)
|
||||
|
||||
// reverse
|
||||
wIhR := vs.KaimingUniform(fmt.Sprintf("weight_ih_l%d_reverse", i), []int64{gateDim, inDim})
|
||||
wHhR := vs.KaimingUniform(fmt.Sprintf("weight_hh_l%d_reverse", i), []int64{gateDim, hiddenDim})
|
||||
bIhR := vs.Zeros(fmt.Sprintf("bias_ih_l%d_reverse", i), []int64{gateDim})
|
||||
bHhR := vs.Zeros(fmt.Sprintf("bias_hh_l%d_reverse", i), []int64{gateDim})
|
||||
flatWeights = append(flatWeights, *wIhR, *wHhR, *bIhR, *bHhR)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user