fixed nn.Seq Forward nil pointer if layers length = 1

This commit is contained in:
sugarme 2021-05-15 17:50:08 +10:00
parent 9135395bce
commit 720beffa62

View File

@ -81,6 +81,10 @@ func (s *Sequential) Forward(xs *ts.Tensor) (retVal *ts.Tensor) {
return xs.MustShallowClone()
}
if len(s.layers) == 1 {
return s.layers[0].Forward(xs)
}
// forward sequentially
outs := make([]ts.Tensor, len(s.layers))
for i := 0; i < len(s.layers); i++ {