fix(nn/sequential): fixed wrong index at Forward method
This commit is contained in:
parent
ed709027c0
commit
b792c6af3c
|
@ -16,7 +16,7 @@ const (
|
|||
LabelNN int64 = 10
|
||||
MnistDirNN string = "../../data/mnist"
|
||||
|
||||
epochsNN = 200
|
||||
epochsNN = 500
|
||||
|
||||
LrNN = 1e-2
|
||||
)
|
||||
|
@ -28,9 +28,9 @@ func netInit(vs nn.Path) ts.Module {
|
|||
|
||||
n.Add(nn.NewLinear(vs, ImageDimNN, HiddenNodesNN, *nn.DefaultLinearConfig()))
|
||||
|
||||
// n.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||
// return xs.MustRelu(true)
|
||||
// }))
|
||||
n.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||
return xs.MustRelu(false)
|
||||
}))
|
||||
|
||||
n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, *nn.DefaultLinearConfig()))
|
||||
// n.Add(nn.NewLinear(vs, ImageDimNN, LabelNN, nn.DefaultLinearConfig()))
|
||||
|
|
|
@ -4,6 +4,7 @@ package nn
|
|||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
// "reflect"
|
||||
)
|
||||
|
||||
// Sequential is a layer (container) that combines multiple other layers.
|
||||
|
@ -20,7 +21,7 @@ func Seq() Sequential {
|
|||
//====================
|
||||
|
||||
// Len returns number of sub-layers embedded in this layer
|
||||
func (s *Sequential) Len() (retVal int64) {
|
||||
func (s Sequential) Len() (retVal int64) {
|
||||
return int64(len(s.layers))
|
||||
}
|
||||
|
||||
|
@ -74,7 +75,7 @@ func WithUint8(n uint8) func() uint8 {
|
|||
// ==========================================
|
||||
|
||||
// Forward implements Module interface for Sequential
|
||||
func (s Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||
func (s *Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||
if s.IsEmpty() {
|
||||
return xs.MustShallowClone()
|
||||
}
|
||||
|
@ -88,8 +89,8 @@ func (s Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
|||
} else if i == len(s.layers)-1 {
|
||||
return s.layers[i].Forward(outs[i-1])
|
||||
} else {
|
||||
outs[i+1] = s.layers[i].Forward(outs[i-1])
|
||||
defer outs[i+1].MustDrop()
|
||||
outs[i] = s.layers[i].Forward(outs[i-1])
|
||||
defer outs[i].MustDrop()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user