fix(nn/rnn): fixed wrong type assertion

This commit is contained in:
sugarme 2020-07-28 16:15:57 +10:00
parent 74480db4d6
commit a72ded34a3

View File

@ -267,8 +267,7 @@ func (g GRU) Seq(input ts.Tensor) (output ts.Tensor, state State) {
output, state = g.SeqInit(input, inState)
// Delete intermediate tensors in inState
inState.(LSTMState).Tensor1.MustDrop()
inState.(LSTMState).Tensor2.MustDrop()
inState.(GRUState).Tensor.MustDrop()
return output, state
}