fix(KaimingUniformInit): fixed incorrect init of KaimingUniform method
This commit is contained in:
parent
31a3f0e587
commit
71fb5ae79b
|
@ -85,7 +85,7 @@ func runCNN1() {
|
||||||
// vs := nn.NewVarStore(gotch.CPU)
|
// vs := nn.NewVarStore(gotch.CPU)
|
||||||
path := vs.Root()
|
path := vs.Root()
|
||||||
net := newNet(&path)
|
net := newNet(&path)
|
||||||
opt, err := nn.DefaultAdamConfig().Build(vs, LrNN)
|
opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -136,7 +136,7 @@ func runCNN1() {
|
||||||
bImages.MustDrop()
|
bImages.MustDrop()
|
||||||
bLabels.MustDrop()
|
bLabels.MustDrop()
|
||||||
// logits.MustDrop()
|
// logits.MustDrop()
|
||||||
loss.MustDrop()
|
// loss.MustDrop()
|
||||||
}
|
}
|
||||||
|
|
||||||
// testAccuracy := ts.BatchAccuracyForLogitsIdx(net, testImages, testLabels, vs.Device(), 1024)
|
// testAccuracy := ts.BatchAccuracyForLogitsIdx(net, testImages, testLabels, vs.Device(), 1024)
|
||||||
|
@ -185,7 +185,7 @@ func runCNN2() {
|
||||||
bImages := item.Data.MustTo(vs.Device(), true)
|
bImages := item.Data.MustTo(vs.Device(), true)
|
||||||
bLabels := item.Label.MustTo(vs.Device(), true)
|
bLabels := item.Label.MustTo(vs.Device(), true)
|
||||||
|
|
||||||
_ = ts.MustGradSetEnabled(true)
|
// _ = ts.MustGradSetEnabled(true)
|
||||||
|
|
||||||
logits := net.ForwardT(bImages, true)
|
logits := net.ForwardT(bImages, true)
|
||||||
loss := logits.CrossEntropyForLogits(bLabels)
|
loss := logits.CrossEntropyForLogits(bLabels)
|
||||||
|
@ -199,10 +199,13 @@ func runCNN2() {
|
||||||
loss.MustDrop()
|
loss.MustDrop()
|
||||||
}
|
}
|
||||||
|
|
||||||
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
|
fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\n", epoch, lossVal)
|
||||||
|
|
||||||
fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\t Accuracy: %.2f\n", epoch, lossVal, testAcc*100)
|
// testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
|
||||||
|
// fmt.Printf("Epoch:\t %v\tLoss: \t %.2f\t Accuracy: %.2f\n", epoch, lossVal, testAcc*100)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
testAcc := ts.BatchAccuracyForLogits(net, ds.TestImages, ds.TestLabels, vs.Device(), batchCNN)
|
||||||
|
fmt.Printf("Loss: \t %.2f\t Accuracy: %.2f\n", lossVal, testAcc*100)
|
||||||
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
|
fmt.Printf("Taken time:\t%.2f mins\n", time.Since(startTime).Minutes())
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,7 +21,8 @@ func main() {
|
||||||
case "nn":
|
case "nn":
|
||||||
runNN()
|
runNN()
|
||||||
case "cnn":
|
case "cnn":
|
||||||
runCNN2()
|
// runCNN2()
|
||||||
|
runCNN1()
|
||||||
default:
|
default:
|
||||||
panic("No specified model to run")
|
panic("No specified model to run")
|
||||||
}
|
}
|
||||||
|
|
134
nn/conv-transpose.go
Normal file
134
nn/conv-transpose.go
Normal file
|
@ -0,0 +1,134 @@
|
||||||
|
package nn
|
||||||
|
|
||||||
|
// A two dimension transposed convolution layer.
|
||||||
|
|
||||||
|
import (
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConvTranspose1DConfig struct {
|
||||||
|
Stride []int64
|
||||||
|
Padding []int64
|
||||||
|
Dilation []int64
|
||||||
|
Groups int64
|
||||||
|
Bias bool
|
||||||
|
WsInit Init
|
||||||
|
BsInit Init
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConvTranspose2DConfig struct {
|
||||||
|
Stride []int64
|
||||||
|
Padding []int64
|
||||||
|
Dilation []int64
|
||||||
|
Groups int64
|
||||||
|
Bias bool
|
||||||
|
WsInit Init
|
||||||
|
BsInit Init
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConvTranspose3DConfig struct {
|
||||||
|
Stride []int64
|
||||||
|
Padding []int64
|
||||||
|
Dilation []int64
|
||||||
|
Groups int64
|
||||||
|
Bias bool
|
||||||
|
WsInit Init
|
||||||
|
BsInit Init
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultConvConfig create a default 1D ConvConfig
|
||||||
|
func DefaultConvTranspose1DConfig() ConvTranspose1DConfig {
|
||||||
|
return ConvTranspose1DConfig{
|
||||||
|
Stride: []int64{1},
|
||||||
|
Padding: []int64{0},
|
||||||
|
Dilation: []int64{1},
|
||||||
|
Groups: 1,
|
||||||
|
Bias: true,
|
||||||
|
WsInit: NewKaimingUniformInit(),
|
||||||
|
BsInit: NewConstInit(float64(0.0)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultConvConfig2D creates a default 2D ConvConfig
|
||||||
|
func DefaultConvTranspose2DConfig() ConvTranspose2DConfig {
|
||||||
|
return ConvTranspose2DConfig{
|
||||||
|
Stride: []int64{1, 1},
|
||||||
|
Padding: []int64{0, 0},
|
||||||
|
Dilation: []int64{1, 1},
|
||||||
|
Groups: 1,
|
||||||
|
Bias: true,
|
||||||
|
WsInit: NewKaimingUniformInit(),
|
||||||
|
BsInit: NewConstInit(float64(0.0)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConvTranspose1D struct {
|
||||||
|
Ws ts.Tensor
|
||||||
|
Bs ts.Tensor // optional
|
||||||
|
Config ConvTranspose1DConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConvTranspose1D(vs *Path, inDim, outDim, k int64, cfg ConvTranspose1DConfig) ConvTranspose1D {
|
||||||
|
var conv ConvTranspose1D
|
||||||
|
conv.Config = cfg
|
||||||
|
if cfg.Bias {
|
||||||
|
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||||
|
}
|
||||||
|
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||||
|
weightSize = append(weightSize, k)
|
||||||
|
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||||
|
|
||||||
|
return conv
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConvTranspose2D struct {
|
||||||
|
Ws ts.Tensor
|
||||||
|
Bs ts.Tensor // optional
|
||||||
|
Config ConvTranspose2DConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConvTranspose2D(vs *Path, inDim, outDim int64, k int64, cfg ConvTranspose2DConfig) ConvTranspose2D {
|
||||||
|
var conv ConvTranspose2D
|
||||||
|
conv.Config = cfg
|
||||||
|
if cfg.Bias {
|
||||||
|
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||||
|
}
|
||||||
|
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||||
|
weightSize = append(weightSize, k, k)
|
||||||
|
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||||
|
|
||||||
|
return conv
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConvTranspose3D struct {
|
||||||
|
Ws ts.Tensor
|
||||||
|
Bs ts.Tensor // optional
|
||||||
|
Config ConvTranspose3DConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConvTranspose3D(vs *Path, inDim, outDim, k int64, cfg ConvTranspose3DConfig) ConvTranspose3D {
|
||||||
|
var conv ConvTranspose3D
|
||||||
|
conv.Config = cfg
|
||||||
|
if cfg.Bias {
|
||||||
|
conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
|
||||||
|
}
|
||||||
|
weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
|
||||||
|
weightSize = append(weightSize, k, k, k)
|
||||||
|
conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)
|
||||||
|
|
||||||
|
return conv
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement Module for Conv1D, Conv2D, Conv3D:
|
||||||
|
// ============================================
|
||||||
|
|
||||||
|
/* func (c ConvTranspose1D) Forward(xs ts.Tensor) ts.Tensor {
|
||||||
|
* return ts.MustConvTranspose1D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
* func (c ConvTranspose2D) Forward(xs ts.Tensor) ts.Tensor {
|
||||||
|
* return ts.MustConvTranspose2D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
||||||
|
* }
|
||||||
|
* func (c ConvTranspose3D) Forward(xs ts.Tensor) ts.Tensor {
|
||||||
|
* return ts.MustConvTranspose3D(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.Dilation, c.Config.Groups)
|
||||||
|
* } */
|
31
nn/init.go
31
nn/init.go
|
@ -148,7 +148,13 @@ func NewKaimingUniformInit() kaimingUniformInit {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVal ts.Tensor) {
|
||||||
fanIn := factorial(uint64(len(dims) - 1))
|
var fanIn int64
|
||||||
|
if len(dims) == 1 {
|
||||||
|
log.Fatalf("KaimingUniformInit method call: dims (%v) should have length > 1", dims)
|
||||||
|
} else {
|
||||||
|
fanIn = product(dims[1:])
|
||||||
|
}
|
||||||
|
|
||||||
bound := math.Sqrt(1.0 / float64(fanIn))
|
bound := math.Sqrt(1.0 / float64(fanIn))
|
||||||
kind := gotch.Float.CInt()
|
kind := gotch.Float.CInt()
|
||||||
retVal = ts.MustZeros(dims, kind, device.CInt())
|
retVal = ts.MustZeros(dims, kind, device.CInt())
|
||||||
|
@ -157,6 +163,20 @@ func (k kaimingUniformInit) InitTensor(dims []int64, device gotch.Device) (retVa
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// product calculates product by multiplying elements
|
||||||
|
func product(dims []int64) (retVal int64) {
|
||||||
|
|
||||||
|
for i, v := range dims {
|
||||||
|
if i == 0 {
|
||||||
|
retVal = v
|
||||||
|
} else {
|
||||||
|
retVal = retVal * v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
func factorial(n uint64) (result uint64) {
|
func factorial(n uint64) (result uint64) {
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
result = n * factorial(n-1)
|
result = n * factorial(n-1)
|
||||||
|
@ -170,7 +190,14 @@ func (k kaimingUniformInit) Set(tensor ts.Tensor) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("uniformInit - Set method call error: %v\n", err)
|
log.Fatalf("uniformInit - Set method call error: %v\n", err)
|
||||||
}
|
}
|
||||||
fanIn := factorial(uint64(len(dims) - 1))
|
|
||||||
|
var fanIn int64
|
||||||
|
if len(dims) == 1 {
|
||||||
|
log.Fatalf("KaimingUniformInit Set method call: Tensor (%v) should have length > 1", tensor.MustSize())
|
||||||
|
} else {
|
||||||
|
fanIn = product(dims[1:])
|
||||||
|
}
|
||||||
|
|
||||||
bound := math.Sqrt(1.0 / float64(fanIn))
|
bound := math.Sqrt(1.0 / float64(fanIn))
|
||||||
tensor.Uniform_(-bound, bound)
|
tensor.Uniform_(-bound, bound)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user