fix(nn/sequential): fixed wrong index at Forward method

This commit is contained in:
sugarme 2020-06-23 15:21:16 +10:00
parent ed709027c0
commit b792c6af3c
2 changed files with 9 additions and 8 deletions

View File

@ -16,7 +16,7 @@ const (
LabelNN int64 = 10
MnistDirNN string = "../../data/mnist"
epochsNN = 200
epochsNN = 500
LrNN = 1e-2
)
@ -28,9 +28,9 @@ func netInit(vs nn.Path) ts.Module {
n.Add(nn.NewLinear(vs, ImageDimNN, HiddenNodesNN, *nn.DefaultLinearConfig()))
// n.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
// return xs.MustRelu(true)
// }))
n.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
return xs.MustRelu(false)
}))
n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, *nn.DefaultLinearConfig()))
// n.Add(nn.NewLinear(vs, ImageDimNN, LabelNN, nn.DefaultLinearConfig()))

View File

@ -4,6 +4,7 @@ package nn
import (
ts "github.com/sugarme/gotch/tensor"
// "reflect"
)
// Sequential is a layer (container) that combines multiple other layers.
@ -20,7 +21,7 @@ func Seq() Sequential {
//====================
// Len returns number of sub-layers embedded in this layer
func (s *Sequential) Len() (retVal int64) {
func (s Sequential) Len() (retVal int64) {
return int64(len(s.layers))
}
@ -74,7 +75,7 @@ func WithUint8(n uint8) func() uint8 {
// ==========================================
// Forward implements Module interface for Sequential
func (s Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
func (s *Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
if s.IsEmpty() {
return xs.MustShallowClone()
}
@ -88,8 +89,8 @@ func (s Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
} else if i == len(s.layers)-1 {
return s.layers[i].Forward(outs[i-1])
} else {
outs[i+1] = s.layers[i].Forward(outs[i-1])
defer outs[i+1].MustDrop()
outs[i] = s.layers[i].Forward(outs[i-1])
defer outs[i].MustDrop()
}
}