From 71fb5ae79b0c29cd209d2727e8c4a90b39d7ab9a Mon Sep 17 00:00:00 2001
From: sugarme
Date: Wed, 24 Jun 2020 12:47:10 +1000
Subject: [PATCH] fix(KaimingUniformInit): fixed incorrect init of
 KaimingUniform method

---
 example/mnist/cnn.go  |  13 ++--
 example/mnist/main.go |   3 +-
 nn/conv-transpose.go  | 134 ++++++++++++++++++++++++++++++++++++++++++
 nn/init.go            |  31 +++++++++-
 4 files changed, 173 insertions(+), 8 deletions(-)
 create mode 100644 nn/conv-transpose.go

diff --git a/example/mnist/cnn.go b/example/mnist/cnn.go
index 6db47c9..f23cb72 100644
--- a/example/mnist/cnn.go
+++ b/example/mnist/cnn.go
@@ -85,7 +85,7 @@ func runCNN1() {
 	// vs := nn.NewVarStore(gotch.CPU)
 	path := vs.Root()
 	net := newNet(&path)
-	opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
+	opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)
 	if err != nil {
 		log.Fatal(err)
 	}
@@ -136,7 +136,7 @@ func runCNN1() {
 			bImages.MustDrop()
 			bLabels.MustDrop()
 			// logits.MustDrop()
-			loss.MustDrop()
+			// loss.MustDrop()
 		}
 
 		// testAccuracy := ts.BatchAccuracyForLogitsIdx(net, testImages, testLabels, vs.Device(), 1024)
@@ -185,7 +185,7 @@ func runCNN2() {
 			bImages := item.Data.MustTo(vs.Device(), true)
 			bLabels := item.Label.MustTo(vs.Device(), true)
 
-			_ = ts.MustGradSetEnabled(true)
+			// _ = ts.MustGradSetEnabled(true)
 
 			logits := net.ForwardT(bImages, true)
 			loss := logits.CrossEntropyForLogits(bLabels)
@@ -199,10 +199,13 @@ func runCNN2() {
 			loss.MustDrop()
 		}
 
-		testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
+		fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, lossVal)
 
-		fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\t Accuracy: %.2f\n", epoch, lossVal, testAcc*100)
+		// testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
+		// fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\t Accuracy: %.2f\n", epoch, lossVal, testAcc*100)
 	}
+	testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
+	fmt.Printf("Loss: \t %.2f\t Accuracy: %.2f\n", lossVal, testAcc*100)
 
 	fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
 }
diff --git a/example/mnist/main.go b/example/mnist/main.go
index 1064903..4061897 100644
--- a/example/mnist/main.go
+++ b/example/mnist/main.go
@@ -21,7 +21,8 @@ func main() {
 	case "nn":
 		runNN()
 	case "cnn":
-		runCNN2()
+		// runCNN2()
+		runCNN1()
 	default:
 		panic("No specified model to run")
 	}
diff --git a/nn/conv-transpose.go b/nn/conv-transpose.go
new file mode 100644
index 0000000..422a741
--- /dev/null
+++ b/nn/conv-transpose.go
@@ -0,0 +1,134 @@
+package nn
+
+// Transposed convolution layers (1D, 2D and 3D).
+
+import (
+	ts "github.com/sugarme/gotch/tensor"
+)
+
+type ConvTranspose1DConfig struct {
+	Stride   []int64
+	Padding  []int64
+	Dilation []int64
+	Groups   int64
+	Bias     bool
+	WsInit   Init
+	BsInit   Init
+}
+
+type ConvTranspose2DConfig struct {
+	Stride   []int64
+	Padding  []int64
+	Dilation []int64
+	Groups   int64
+	Bias     bool
+	WsInit   Init
+	BsInit   Init
+}
+
+type ConvTranspose3DConfig struct {
+	Stride   []int64
+	Padding  []int64
+	Dilation []int64
+	Groups   int64
+	Bias     bool
+	WsInit   Init
+	BsInit   Init
+}
+
+// DefaultConvTranspose1DConfig creates a default ConvTranspose1DConfig
+func DefaultConvTranspose1DConfig() ConvTranspose1DConfig {
+	return ConvTranspose1DConfig{
+		Stride:   []int64{1},
+		Padding:  []int64{0},
+		Dilation: []int64{1},
+		Groups:   1,
+		Bias:     true,
+		WsInit:   NewKaimingUniformInit(),
+		BsInit:   NewConstInit(float64(0.0)),
+	}
+}
+
+// DefaultConvTranspose2DConfig creates a default ConvTranspose2DConfig
+func DefaultConvTranspose2DConfig() ConvTranspose2DConfig {
+	return ConvTranspose2DConfig{
+		Stride:   []int64{1, 1},
+		Padding:  []int64{0, 0},
+		Dilation: []int64{1, 1},
+		Groups:   1,
+		Bias:     true,
+		WsInit:   NewKaimingUniformInit(),
+		BsInit:   NewConstInit(float64(0.0)),
+	}
+}
+
+type ConvTranspose1D struct {
+	Ws     ts.Tensor
+	Bs     ts.Tensor // optional
+	Config ConvTranspose1DConfig
+}
+
+func NewConvTranspose1D(vs *Path, inDim, outDim, k int64, cfg ConvTranspose1DConfig) ConvTranspose1D {
+	var conv ConvTranspose1D
+	conv.Config = cfg
+	if cfg.Bias {
+		conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
+	}
+	weightSize := []int64{inDim, outDim / cfg.Groups} // transposed conv weight shape: (inDim, outDim/groups, k)
+	weightSize = append(weightSize, k)
+	conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
+
+	return conv
+}
+
+type ConvTranspose2D struct {
+	Ws     ts.Tensor
+	Bs     ts.Tensor // optional
+	Config ConvTranspose2DConfig
+}
+
+func NewConvTranspose2D(vs *Path, inDim, outDim, k int64, cfg ConvTranspose2DConfig) ConvTranspose2D {
+	var conv ConvTranspose2D
+	conv.Config = cfg
+	if cfg.Bias {
+		conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
+	}
+	weightSize := []int64{inDim, outDim / cfg.Groups} // transposed conv weight shape: (inDim, outDim/groups, k, k)
+	weightSize = append(weightSize, k, k)
+	conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
+
+	return conv
+}
+
+type ConvTranspose3D struct {
+	Ws     ts.Tensor
+	Bs     ts.Tensor // optional
+	Config ConvTranspose3DConfig
+}
+
+func NewConvTranspose3D(vs *Path, inDim, outDim, k int64, cfg ConvTranspose3DConfig) ConvTranspose3D {
+	var conv ConvTranspose3D
+	conv.Config = cfg
+	if cfg.Bias {
+		conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
+	}
+	weightSize := []int64{inDim, outDim / cfg.Groups} // transposed conv weight shape: (inDim, outDim/groups, k, k, k)
+	weightSize = append(weightSize, k, k, k)
+	conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
+
+	return conv
+}
+
+// Implement Module for ConvTranspose1D, ConvTranspose2D, ConvTranspose3D:
+// =======================================================================
+
+/* func (c ConvTranspose1D) Forward(xs ts.Tensor) ts.Tensor {
+ *   return ts.MustConvTranspose1D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
+ * }
+ *
+ * func (c ConvTranspose2D) Forward(xs ts.Tensor) ts.Tensor {
+ *   return ts.MustConvTranspose2D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
+ * }
+ * func (c ConvTranspose3D) Forward(xs ts.Tensor) ts.Tensor {
+ *   return ts.MustConvTranspose3D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
+ * } */
diff --git a/nn/init.go b/nn/init.go
index afd6e5e..f011f2d 100644
--- a/nn/init.go
+++ b/nn/init.go
@@ -148,7 +148,13 @@ func NewKaimingUniformInit() kaimingUniformInit {
 }
 
 func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
-	fanIn := factorial(uint64(len(dims) - 1))
+	var fanIn int64
+	if len(dims) < 2 {
+		log.Fatalf("KaimingUniformInit InitTensor method call: dims (%v) should have length > 1", dims)
+	} else {
+		fanIn = product(dims[1:])
+	}
+
 	bound := math.Sqrt(1.0 / float64(fanIn))
 	kind := gotch.Float.CInt()
 	retVal = ts.MustZeros(dims, kind, device.CInt())
@@ -157,6 +163,20 @@ func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVa
 	return retVal
 }
 
+// product calculates the product of the elements of dims
+func product(dims []int64) (retVal int64) {
+
+	for i, v := range dims {
+		if i == 0 {
+			retVal = v
+		} else {
+			retVal = retVal * v
+		}
+	}
+
+	return retVal
+}
+
 func factorial(n uint64) (result uint64) {
 	if n > 0 {
 		result = n * factorial(n-1)
@@ -170,7 +190,14 @@ func (k kaimingUniformInit) Set(tensor ts.Tensor) {
 	if err != nil {
 		log.Fatalf("uniformInit - Set method call error: %v\n", err)
 	}
-	fanIn := factorial(uint64(len(dims) - 1))
+
+	var fanIn int64
+	if len(dims) < 2 {
+		log.Fatalf("KaimingUniformInit Set method call: Tensor (%v) should have length > 1", tensor.MustSize())
+	} else {
+		fanIn = product(dims[1:])
+	}
+
 	bound := math.Sqrt(1.0 / float64(fanIn))
 	tensor.Uniform_(-bound, bound)
 }
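
Below is a minimal standalone sketch (assuming only the Go standard library; the []int64{64, 32, 3, 3} weight shape is an arbitrary example, not taken from the patch) of the fan-in computation this patch introduces. The old code computed fan-in as factorial(len(dims) - 1), i.e. 3! = 6 for every 4-dimensional weight; Kaiming initialization instead takes fan-in as the product of all dimensions after the first, so the uniform bound sqrt(1/fanIn) shrinks as the layer gets wider.

package main

import (
	"fmt"
	"math"
)

// product mirrors the helper added to nn/init.go: the product of all elements of dims.
func product(dims []int64) (retVal int64) {
	for i, v := range dims {
		if i == 0 {
			retVal = v
		} else {
			retVal = retVal * v
		}
	}
	return retVal
}

func main() {
	// Example 2D convolution weight shape: (outDim, inDim/groups, kH, kW).
	dims := []int64{64, 32, 3, 3}

	fanIn := product(dims[1:])               // 32*3*3 = 288 (old code would have used 3! = 6)
	bound := math.Sqrt(1.0 / float64(fanIn)) // ~0.0589; weights are drawn from U(-bound, bound)

	fmt.Printf("fanIn=%d bound=%.4f\n", fanIn, bound)
}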