BREAKING CHANGE(nn/linear): removed pointer receiver
This commit is contained in:
parent
4b91f15865
commit
41f3f8af2a
|
@ -31,14 +31,14 @@ type Net struct {
|
||||||
func newNet(vs *nn.Path) Net {
|
func newNet(vs *nn.Path) Net {
|
||||||
conv1 := nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig())
|
conv1 := nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig())
|
||||||
conv2 := nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig())
|
conv2 := nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig())
|
||||||
fc1 := nn.NewLinear(*vs, 1024, 1024, *nn.DefaultLinearConfig())
|
fc1 := nn.NewLinear(*vs, 1024, 1024, nn.DefaultLinearConfig())
|
||||||
fc2 := nn.NewLinear(*vs, 1024, 10, *nn.DefaultLinearConfig())
|
fc2 := nn.NewLinear(*vs, 1024, 10, nn.DefaultLinearConfig())
|
||||||
|
|
||||||
return Net{
|
return Net{
|
||||||
conv1,
|
conv1,
|
||||||
conv2,
|
conv2,
|
||||||
*fc1,
|
fc1,
|
||||||
*fc2}
|
fc2}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (n Net) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
func (n Net) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
||||||
|
|
|
@ -26,13 +26,13 @@ var l nn.Linear
|
||||||
func netInit(vs nn.Path) ts.Module {
|
func netInit(vs nn.Path) ts.Module {
|
||||||
n := nn.Seq()
|
n := nn.Seq()
|
||||||
|
|
||||||
n.Add(nn.NewLinear(vs, ImageDimNN, HiddenNodesNN, *nn.DefaultLinearConfig()))
|
n.Add(nn.NewLinear(vs, ImageDimNN, HiddenNodesNN, nn.DefaultLinearConfig()))
|
||||||
|
|
||||||
n.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
n.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||||
return xs.MustRelu(false)
|
return xs.MustRelu(false)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, *nn.DefaultLinearConfig()))
|
n.Add(nn.NewLinear(vs, HiddenNodesNN, LabelNN, nn.DefaultLinearConfig()))
|
||||||
|
|
||||||
return &n
|
return &n
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,6 +20,13 @@ func (fn Func) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||||
return fn.f(xs)
|
return fn.f(xs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ForwardT implements ModuleT for Func object as well.
|
||||||
|
//
|
||||||
|
// NOTE: train param will not be used.
|
||||||
|
func (fn Func) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
||||||
|
return fn.f(xs)
|
||||||
|
}
|
||||||
|
|
||||||
type FuncT struct {
|
type FuncT struct {
|
||||||
f func(ts.Tensor, bool) ts.Tensor
|
f func(ts.Tensor, bool) ts.Tensor
|
||||||
}
|
}
|
||||||
|
|
19
nn/linear.go
19
nn/linear.go
|
@ -18,8 +18,8 @@ type LinearConfig struct {
|
||||||
|
|
||||||
// DefaultLinearConfig creates default LinearConfig with
|
// DefaultLinearConfig creates default LinearConfig with
|
||||||
// weights initiated using KaimingUniform and Bias is set to true
|
// weights initiated using KaimingUniform and Bias is set to true
|
||||||
func DefaultLinearConfig() *LinearConfig {
|
func DefaultLinearConfig() LinearConfig {
|
||||||
return &LinearConfig{
|
return LinearConfig{
|
||||||
WsInit: NewKaimingUniformInit(),
|
WsInit: NewKaimingUniformInit(),
|
||||||
BsInit: nil,
|
BsInit: nil,
|
||||||
Bias: true,
|
Bias: true,
|
||||||
|
@ -37,7 +37,7 @@ type Linear struct {
|
||||||
// inDim - input dimension (x) [input features - columns]
|
// inDim - input dimension (x) [input features - columns]
|
||||||
// outDim - output dimension (y) [output features - columns]
|
// outDim - output dimension (y) [output features - columns]
|
||||||
// NOTE: w will have shape{outDim, inDim}; b will have shape{outDim}
|
// NOTE: w will have shape{outDim, inDim}; b will have shape{outDim}
|
||||||
func NewLinear(vs Path, inDim, outDim int64, c LinearConfig) *Linear {
|
func NewLinear(vs Path, inDim, outDim int64, c LinearConfig) Linear {
|
||||||
|
|
||||||
var bs ts.Tensor
|
var bs ts.Tensor
|
||||||
// bs has size of output dimension
|
// bs has size of output dimension
|
||||||
|
@ -55,7 +55,7 @@ func NewLinear(vs Path, inDim, outDim int64, c LinearConfig) *Linear {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Linear{
|
return Linear{
|
||||||
Ws: vs.NewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false),
|
Ws: vs.NewVar("weight", []int64{outDim, inDim}, c.WsInit).MustT(false),
|
||||||
Bs: bs,
|
Bs: bs,
|
||||||
}
|
}
|
||||||
|
@ -89,7 +89,16 @@ func NewLinear(vs Path, inDim, outDim int64, c LinearConfig) *Linear {
|
||||||
// 1 1 1
|
// 1 1 1
|
||||||
// 1 1 1
|
// 1 1 1
|
||||||
// 1 1 1 ]
|
// 1 1 1 ]
|
||||||
func (l *Linear) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
func (l Linear) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
||||||
|
|
||||||
|
mul := xs.MustMatMul(l.Ws, false)
|
||||||
|
return mul.MustAdd(l.Bs, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForwardT implements ModuleT interface for Linear layer.
|
||||||
|
//
|
||||||
|
// NOTE: train param will not be used.
|
||||||
|
func (l Linear) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
||||||
|
|
||||||
mul := xs.MustMatMul(l.Ws, false)
|
mul := xs.MustMatMul(l.Ws, false)
|
||||||
return mul.MustAdd(l.Bs, true)
|
return mul.MustAdd(l.Bs, true)
|
||||||
|
|
|
@ -1310,6 +1310,43 @@ func MustDropout(input Tensor, p float64, train bool) (retVal Tensor) {
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Dropout(p float64, train bool, del bool) (retVal Tensor, err error) {
|
||||||
|
|
||||||
|
if del {
|
||||||
|
defer ts.MustDrop()
|
||||||
|
}
|
||||||
|
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
|
||||||
|
var ctrain int
|
||||||
|
switch train {
|
||||||
|
case true:
|
||||||
|
ctrain = 1
|
||||||
|
case false:
|
||||||
|
ctrain = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
lib.AtgDropout(ptr, ts.ctensor, p, ctrain)
|
||||||
|
err = TorchErr()
|
||||||
|
if err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Tensor{ctensor: *ptr}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustDropout(p float64, train bool, del bool) (retVal Tensor) {
|
||||||
|
retVal, err := ts.Dropout(p, train, del)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
func (ts Tensor) Dropout_(p float64, train bool) {
|
func (ts Tensor) Dropout_(p float64, train bool) {
|
||||||
|
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
|
|
@ -84,7 +84,7 @@ func resnet(path nn.Path, nclasses int64, c1, c2, c3, c4 int64) (retVal nn.FuncT
|
||||||
if nclasses > 0 {
|
if nclasses > 0 {
|
||||||
// With final layer
|
// With final layer
|
||||||
linearConfig := nn.DefaultLinearConfig()
|
linearConfig := nn.DefaultLinearConfig()
|
||||||
fc := nn.NewLinear(path.Sub("fc"), 512, nclasses, *linearConfig)
|
fc := nn.NewLinear(path.Sub("fc"), 512, nclasses, linearConfig)
|
||||||
|
|
||||||
return nn.NewFuncT(func(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
return nn.NewFuncT(func(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
||||||
c1 := xs.Apply(conv1)
|
c1 := xs.Apply(conv1)
|
||||||
|
@ -207,7 +207,7 @@ func bottleneckResnet(path nn.Path, nclasses int64, c1, c2, c3, c4 int64) (retVa
|
||||||
layer4 := bottleneckLayer(path.Sub("layer4"), 4*256, 512, 2, c4)
|
layer4 := bottleneckLayer(path.Sub("layer4"), 4*256, 512, 2, c4)
|
||||||
|
|
||||||
if nclasses > 0 {
|
if nclasses > 0 {
|
||||||
fc := nn.NewLinear(path.Sub("fc"), 4*512, nclasses, *nn.DefaultLinearConfig())
|
fc := nn.NewLinear(path.Sub("fc"), 4*512, nclasses, nn.DefaultLinearConfig())
|
||||||
|
|
||||||
return nn.NewFuncT(func(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
return nn.NewFuncT(func(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
||||||
c1 := xs.Apply(conv1)
|
c1 := xs.Apply(conv1)
|
||||||
|
|
|
@ -77,41 +77,43 @@ func vgg(path nn.Path, config [][]int64, nclasses int64, batchNorm bool) nn.Sequ
|
||||||
seq.Add(nn.BatchNorm2D(f.Sub(fmt.Sprintf("%v", bnLen)), cOut, nn.DefaultBatchNormConfig()))
|
seq.Add(nn.BatchNorm2D(f.Sub(fmt.Sprintf("%v", bnLen)), cOut, nn.DefaultBatchNormConfig()))
|
||||||
}
|
}
|
||||||
|
|
||||||
seq.AddFn(func(xs ts.Tensor) ts.Tensor {
|
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||||
return xs.MustRelu(false)
|
return xs.MustRelu(false)
|
||||||
})
|
}))
|
||||||
|
|
||||||
cIn = cOut
|
cIn = cOut
|
||||||
} // end of inner For loop
|
} // end of inner For loop
|
||||||
|
|
||||||
seq.AddFn(func(xs ts.Tensor) ts.Tensor {
|
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||||
return xs.MaxPool2DDefault(2, false)
|
return xs.MaxPool2DDefault(2, false)
|
||||||
})
|
}))
|
||||||
} // end of outer For loop
|
} // end of outer For loop
|
||||||
|
|
||||||
seq.AddFn(func(xs ts.Tensor) ts.Tensor {
|
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||||
return xs.FlatView()
|
return xs.FlatView()
|
||||||
})
|
}))
|
||||||
|
|
||||||
seq.Add(nn.NewLinear(c.Sub(fmt.Sprint("0")), 512*7*7, 4096, *nn.DefaultLinearConfig()))
|
seq.Add(nn.NewLinear(c.Sub(fmt.Sprint("0")), 512*7*7, 4096, nn.DefaultLinearConfig()))
|
||||||
|
|
||||||
seq.AddFn(func(xs ts.Tensor) ts.Tensor {
|
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||||
return xs.MustRelu(false)
|
return xs.MustRelu(false)
|
||||||
})
|
}))
|
||||||
|
|
||||||
seq.AddFnT(func(xs ts.Tensor, train bool) ts.Tensor {
|
seq.AddFn(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
||||||
return xs.Dropout(0.5, train)
|
return xs.MustDropout(0.5, train, false)
|
||||||
})
|
}))
|
||||||
|
|
||||||
seq.Add(nn.NewLinear(c.Sub(fmt.Sprint("3")), 4096, 4096, *nn.DefaultLinearConfig()))
|
seq.Add(nn.NewLinear(c.Sub(fmt.Sprint("3")), 4096, 4096, nn.DefaultLinearConfig()))
|
||||||
|
|
||||||
seq.AddFn(func(xs ts.Tensor) ts.Tensor {
|
seq.AddFn(nn.NewFunc(func(xs ts.Tensor) ts.Tensor {
|
||||||
return xs.MustRelu(false)
|
return xs.MustRelu(false)
|
||||||
})
|
}))
|
||||||
|
|
||||||
seq.AddFnT(func(xs ts.Tensor, train bool) ts.Tensor {
|
seq.AddFn(nn.NewFuncT(func(xs ts.Tensor, train bool) ts.Tensor {
|
||||||
return xs.Dropout(0.5, train)
|
return xs.MustDropout(0.5, train, false)
|
||||||
})
|
}))
|
||||||
seq.Add(nn.NewLinear(c.Sub(fmt.Sprint("6")), 4096, nclasses, *nn.DefaultLinearConfig()))
|
|
||||||
|
seq.Add(nn.NewLinear(c.Sub(fmt.Sprint("6")), 4096, nclasses, nn.DefaultLinearConfig()))
|
||||||
|
|
||||||
return seq
|
return seq
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user