diff --git a/example/mnist/nn.go b/example/mnist/nn.go
index 09d386b..f4e0640 100644
--- a/example/mnist/nn.go
+++ b/example/mnist/nn.go
@@ -16,7 +16,7 @@ const (
 	LabelNN    int64  = 10
 	MnistDirNN string = "../../data/mnist"

-	epochsNN = 200
+	epochsNN = 500
 	LrNN     = 1e-2
 )

@@ -28,9 +28,9 @@ func netInit(vs nn.Path) ts.Module {
 	n.Add(nn.NewLinear(vs, ImageDimNN, HiddenNodesNN, *nn.DefaultLinearConfig()))

-	// n.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
-	// 	return xs.MustRelu(true)
-	// }))
+	n.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
+		return xs.MustRelu(false)
+	}))

 	n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, *nn.DefaultLinearConfig()))
 	// n.Add(nn.NewLinear(vs, ImageDimNN, LabelNN, nn.DefaultLinearConfig()))

diff --git a/nn/sequential.go b/nn/sequential.go
index 220d0a6..800c41a 100644
--- a/nn/sequential.go
+++ b/nn/sequential.go
@@ -4,6 +4,7 @@ package nn

 import (
 	ts "github.com/sugarme/gotch/tensor"
+	// "reflect"
 )

 // Sequential is a layer (container) that combines multiple other layers.
@@ -20,7 +21,7 @@ func Seq() Sequential {
 //====================

 // Len returns number of sub-layers embedded in this layer
-func (s *Sequential) Len() (retVal int64) {
+func (s Sequential) Len() (retVal int64) {
 	return int64(len(s.layers))
 }

@@ -74,7 +75,7 @@ func WithUint8(n uint8) func() uint8 {
 // ==========================================

 // Forward implements Module interface for Sequential
-func (s Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
+func (s *Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
 	if s.IsEmpty() {
 		return xs.MustShallowClone()
 	}
@@ -88,8 +89,8 @@ func (s Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
 	} else if i == len(s.layers)-1 {
 		return s.layers[i].Forward(outs[i-1])
 	} else {
-		outs[i+1] = s.layers[i].Forward(outs[i-1])
-		defer outs[i+1].MustDrop()
+		outs[i] = s.layers[i].Forward(outs[i-1])
+		defer outs[i].MustDrop()
 	}
 }