fix(nn/rnn): fixed wrong type assertion
This commit is contained in:
parent
74480db4d6
commit
a72ded34a3
|
@ -267,8 +267,7 @@ func (g GRU) Seq(input ts.Tensor) (output ts.Tensor, state State) {
|
|||
output, state = g.SeqInit(input, inState)
|
||||
|
||||
// Delete intermediate tensors in inState
|
||||
inState.(LSTMState).Tensor1.MustDrop()
|
||||
inState.(LSTMState).Tensor2.MustDrop()
|
||||
inState.(GRUState).Tensor.MustDrop()
|
||||
|
||||
return output, state
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user