BREAKING CHANGE(nn/conv): removed pointer receiver at NewConv2D

This commit is contained in:
sugarme 2020-07-07 10:40:05 +10:00
parent ccb5ea95b6
commit f8ce5d0635
10 changed files with 38 additions and 16 deletions

View File

@ -28,11 +28,11 @@ type Net struct {
fc2 nn.Linear
}
func newNet(vs *nn.Path) Net {
func newNet(vs nn.Path) Net {
conv1 := nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig())
conv2 := nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig())
fc1 := nn.NewLinear(*vs, 1024, 1024, nn.DefaultLinearConfig())
fc2 := nn.NewLinear(*vs, 1024, 10, nn.DefaultLinearConfig())
fc1 := nn.NewLinear(vs, 1024, 1024, nn.DefaultLinearConfig())
fc2 := nn.NewLinear(vs, 1024, 10, nn.DefaultLinearConfig())
return Net{
conv1,
@ -83,8 +83,7 @@ func runCNN1() {
cuda := gotch.CudaBuilder(0)
vs := nn.NewVarStore(cuda.CudaIfAvailable())
// vs := nn.NewVarStore(gotch.CPU)
path := vs.Root()
net := newNet(&path)
net := newNet(vs.Root())
opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)
if err != nil {
log.Fatal(err)
@ -161,8 +160,7 @@ func runCNN2() {
cuda := gotch.CudaBuilder(0)
vs := nn.NewVarStore(cuda.CudaIfAvailable())
path := vs.Root()
net := newNet(&path)
net := newNet(vs.Root())
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
if err != nil {
log.Fatal(err)

View File

@ -54,7 +54,7 @@ func main() {
// Pre-compute the final activations.
linear := nn.NewLinear(vs.Root(), 512, dataset.Labels, *nn.DefaultLinearConfig())
linear := nn.NewLinear(vs.Root(), 512, dataset.Labels, nn.DefaultLinearConfig())
sgd, err := nn.DefaultSGDConfig().Build(vs, 1e-3)
if err != nil {
log.Fatal(err)

View File

@ -90,7 +90,7 @@ type Conv2D struct {
Config Conv2DConfig
}
func NewConv2D(vs *Path, inDim, outDim int64, k int64, cfg Conv2DConfig) Conv2D {
func NewConv2D(vs Path, inDim, outDim int64, k int64, cfg Conv2DConfig) Conv2D {
var conv Conv2D
conv.Config = cfg
if cfg.Bias {
@ -179,7 +179,7 @@ func NewConv(vs Path, inDim, outDim int64, ksizes []int64, config interface{}) C
case len(ksizes) == 1 && configVal.Type() == reflect.TypeOf(Conv1DConfig{}):
return NewConv1D(&vs, inDim, outDim, ksizes[0], config.(Conv1DConfig))
case len(ksizes) == 2 && configVal.Type() == reflect.TypeOf(Conv2DConfig{}):
return NewConv2D(&vs, inDim, outDim, ksizes[0], config.(Conv2DConfig))
return NewConv2D(vs, inDim, outDim, ksizes[0], config.(Conv2DConfig))
case len(ksizes) == 3 && configVal.Type() == reflect.TypeOf(Conv3DConfig{}):
return NewConv3D(&vs, inDim, outDim, ksizes[0], config.(Conv3DConfig))

View File

@ -13,7 +13,7 @@ func anConv2d(p nn.Path, cIn, cOut, ksize, padding, stride int64) (retVal nn.Con
config.Stride = []int64{stride, stride}
config.Padding = []int64{padding, padding}
return nn.NewConv2D(&p, cIn, cOut, ksize, config)
return nn.NewConv2D(p, cIn, cOut, ksize, config)
}
func anMaxPool2d(xs ts.Tensor, ksize, stride int64) (retVal ts.Tensor) {

View File

@ -18,7 +18,7 @@ func dnConv2d(p nn.Path, cIn, cOut, ksize, padding, stride int64) (retVal nn.Con
config.Padding = []int64{padding, padding}
config.Bias = false
return nn.NewConv2D(&p, cIn, cOut, ksize, config)
return nn.NewConv2D(p, cIn, cOut, ksize, config)
}
func denseLayer(p nn.Path, cIn, bnSize, growth int64) (retVal ts.ModuleT) {

View File

@ -75,7 +75,7 @@ func (p params) roundFilters(filters int64) (retVal int64) {
// Conv2D with same padding
func enConv2d(vs nn.Path, i, o, k int64, c nn.Conv2DConfig, train bool) (retVal ts.ModuleT) {
conv2d := nn.NewConv2D(&vs, i, o, k, c)
conv2d := nn.NewConv2D(vs, i, o, k, c)
s := c.Stride
return nn.NewFunc(func(xs ts.Tensor) (res ts.Tensor) {

View File

@ -20,7 +20,7 @@ func convBn(p nn.Path, cIn, cOut, ksize, pad, stride int64) (retVal ts.ModuleT)
seq := nn.SeqT()
convP := p.Sub("conv")
seq.Add(nn.NewConv2D(&convP, cIn, cOut, ksize, convConfig))
seq.Add(nn.NewConv2D(convP, cIn, cOut, ksize, convConfig))
seq.Add(nn.BatchNorm2D(p.Sub("bn"), cOut, bnConfig))

24
vision/mobilenet.go Normal file
View File

@ -0,0 +1,24 @@
package vision
// MobileNet V2 implementation.
// https://ai.googleblog.com/2018/04/mobilenetv2-next-generation-of-on.html
import (
"github.com/sugarme/gotch/nn"
ts "github.com/sugarme/gotch/tensor"
)
// Conv2D + BatchNorm2D + ReLU6
func cbr(p nn.Path, cIn, cOut, ks, stride, g int64) (retVal ts.ModuleT) {
config := nn.DefaultConv2DConfig()
config.Stride = []int64{stride, stride}
pad := (ks - 1) / 2
config.Padding = []int64{pad, pad}
config.Groups = g
config.Bias = false
seq := nn.SeqT()
seq.Add(nn.NewConv2D(p.Sub("0"), cIn, cOut, ks, config))
return seq
}

View File

@ -18,7 +18,7 @@ func conv2d(path nn.Path, cIn, cOut, ksize, padding, stride int64) (retVal nn.Co
config.Padding = []int64{padding, padding}
config.Bias = false
return nn.NewConv2D(&path, cIn, cOut, ksize, config)
return nn.NewConv2D(path, cIn, cOut, ksize, config)
}
func downSample(path nn.Path, cIn, cOut, stride int64) (retVal ts.ModuleT) {

View File

@ -57,7 +57,7 @@ func vggConv2d(path nn.Path, cIn, cOut int64) (retVal nn.Conv2D) {
config.Stride = []int64{1, 1}
config.Padding = []int64{1, 1}
return nn.NewConv2D(&path, cIn, cOut, 3, config)
return nn.NewConv2D(path, cIn, cOut, 3, config)
}
func vgg(path nn.Path, config [][]int64, nclasses int64, batchNorm bool) nn.SequentialT {