fix(nn/sequential): fixed wrong index at Forward method
This commit is contained in:
parent
ed709027c0
commit
b792c6af3c
|
@ -16,7 +16,7 @@ const (
|
||||||
LabelNN int64 = 10
|
LabelNN int64 = 10
|
||||||
MnistDirNN string = "../../data/mnist"
|
MnistDirNN string = "../../data/mnist"
|
||||||
|
|
||||||
epochsNN = 200
|
epochsNN = 500
|
||||||
|
|
||||||
LrNN = 1e-2
|
LrNN = 1e-2
|
||||||
)
|
)
|
||||||
|
@ -28,9 +28,9 @@ func netInit(vs nn.Path) ts.Module {
|
||||||
|
|
||||||
n.Add(nn.NewLinear(vs, ImageDimNN, HiddenNodesNN, *nn.DefaultLinearConfig()))
|
n.Add(nn.NewLinear(vs, ImageDimNN, HiddenNodesNN, *nn.DefaultLinearConfig()))
|
||||||
|
|
||||||
// n.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
n.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||||
// return xs.MustRelu(true)
|
return xs.MustRelu(false)
|
||||||
// }))
|
}))
|
||||||
|
|
||||||
n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, *nn.DefaultLinearConfig()))
|
n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, *nn.DefaultLinearConfig()))
|
||||||
// n.Add(nn.NewLinear(vs, ImageDimNN, LabelNN, nn.DefaultLinearConfig()))
|
// n.Add(nn.NewLinear(vs, ImageDimNN, LabelNN, nn.DefaultLinearConfig()))
|
||||||
|
|
|
@ -4,6 +4,7 @@ package nn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
ts "github.com/sugarme/gotch/tensor"
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
// "reflect"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Sequential is a layer (container) that combines multiple other layers.
|
// Sequential is a layer (container) that combines multiple other layers.
|
||||||
|
@ -20,7 +21,7 @@ func Seq() Sequential {
|
||||||
//====================
|
//====================
|
||||||
|
|
||||||
// Len returns number of sub-layers embedded in this layer
|
// Len returns number of sub-layers embedded in this layer
|
||||||
func (s *Sequential) Len() (retVal int64) {
|
func (s Sequential) Len() (retVal int64) {
|
||||||
return int64(len(s.layers))
|
return int64(len(s.layers))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -74,7 +75,7 @@ func WithUint8(n uint8) func() uint8 {
|
||||||
// ==========================================
|
// ==========================================
|
||||||
|
|
||||||
// Forward implements Module interface for Sequential
|
// Forward implements Module interface for Sequential
|
||||||
func (s Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
func (s *Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||||
if s.IsEmpty() {
|
if s.IsEmpty() {
|
||||||
return xs.MustShallowClone()
|
return xs.MustShallowClone()
|
||||||
}
|
}
|
||||||
|
@ -88,8 +89,8 @@ func (s Sequential) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||||
} else if i == len(s.layers)-1 {
|
} else if i == len(s.layers)-1 {
|
||||||
return s.layers[i].Forward(outs[i-1])
|
return s.layers[i].Forward(outs[i-1])
|
||||||
} else {
|
} else {
|
||||||
outs[i+1] = s.layers[i].Forward(outs[i-1])
|
outs[i] = s.layers[i].Forward(outs[i-1])
|
||||||
defer outs[i+1].MustDrop()
|
defer outs[i].MustDrop()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user