gotch/nn/conv-transpose.go
2020-07-22 15:56:30 +10:00

139 lines
3.7 KiB
Go

package nn
// Transposed convolution layers (1D, 2D and 3D) and their configurations.
import (
"log"
ts "github.com/sugarme/gotch/tensor"
)
// ConvTranspose1DConfig holds the hyper-parameters of a ConvTranspose1D layer.
// DefaultConvTranspose1DConfig returns a pre-filled instance.
//
// Stride, Padding, OutputPadding, Dilation and Groups are forwarded unchanged
// to ts.MustConvTranspose1d by ConvTranspose1D.Forward.
type ConvTranspose1DConfig struct {
	Stride        []int64
	Padding       []int64
	OutputPadding []int64
	Dilation      []int64
	Groups        int64 // also used by NewConvTranspose1D to size the weight; must be non-zero
	Bias          bool  // whether NewConvTranspose1D creates a "bias" variable
	WsInit        Init  // initialiser for the "weight" variable
	BsInit        Init  // initialiser for the "bias" variable (used only when Bias is true)
}
// ConvTranspose2DConfig holds the hyper-parameters of a ConvTranspose2D layer.
//
// Stride, Padding, OutputPadding, Dilation and Groups are forwarded unchanged
// to ts.MustConvTranspose2d by ConvTranspose2D.Forward.
//
// NOTE(review): unlike the 1D variant there is no default constructor here, so
// every field must be set explicitly; a zero Groups causes a division by zero
// in NewConvTranspose2D.
type ConvTranspose2DConfig struct {
	Stride        []int64
	Padding       []int64
	OutputPadding []int64
	Dilation      []int64
	Groups        int64 // also used by NewConvTranspose2D to size the weight; must be non-zero
	Bias          bool  // whether NewConvTranspose2D creates a "bias" variable
	WsInit        Init  // initialiser for the "weight" variable
	BsInit        Init  // initialiser for the "bias" variable (used only when Bias is true)
}
// ConvTranspose3DConfig holds the hyper-parameters of a ConvTranspose3D layer.
//
// Stride, Padding, OutputPadding, Dilation and Groups are forwarded unchanged
// to ts.MustConvTranspose3d by ConvTranspose3D.Forward.
//
// NOTE(review): unlike the 1D variant there is no default constructor here, so
// every field must be set explicitly; a zero Groups causes a division by zero
// in NewConvTranspose3D.
type ConvTranspose3DConfig struct {
	Stride        []int64
	Padding       []int64
	OutputPadding []int64
	Dilation      []int64
	Groups        int64 // also used by NewConvTranspose3D to size the weight; must be non-zero
	Bias          bool  // whether NewConvTranspose3D creates a "bias" variable
	WsInit        Init  // initialiser for the "weight" variable
	BsInit        Init  // initialiser for the "bias" variable (used only when Bias is true)
}
// DefaultConvTranspose1DConfig returns a ConvTranspose1DConfig with stride 1,
// padding 0, output padding 0, dilation 1, a single group and a bias term.
// Weights are initialised with NewKaimingUniformInit() and the bias with a
// constant 0.0.
func DefaultConvTranspose1DConfig() ConvTranspose1DConfig {
	var cfg ConvTranspose1DConfig

	// One entry per spatial dimension: a 1D layer has exactly one.
	cfg.Stride = []int64{1}
	cfg.Padding = []int64{0}
	cfg.OutputPadding = []int64{0}
	cfg.Dilation = []int64{1}

	cfg.Groups = 1
	cfg.Bias = true

	cfg.WsInit = NewKaimingUniformInit()
	cfg.BsInit = NewConstInit(float64(0.0))

	return cfg
}
// ConvTranspose1D is a 1D transposed convolution layer built by
// NewConvTranspose1D.
type ConvTranspose1D struct {
	Ws ts.Tensor // "weight" variable created by NewConvTranspose1D
	// Bs is the "bias" variable. It is set only when Config.Bias is true and is
	// otherwise left as the zero-value ts.Tensor.
	// NOTE(review): confirm ts.MustConvTranspose1d treats a zero-value tensor as "no bias".
	Bs     ts.Tensor // optional
	Config ConvTranspose1DConfig
}
// NewConvTranspose1D creates a 1D transposed convolution layer and registers
// its parameters under vs.
//
// inDim and outDim are the input and output channel counts, and ksizes must
// contain exactly one kernel size. The "weight" variable has shape
// [inDim, outDim/cfg.Groups, ksizes[0]]. This is the layout PyTorch's
// conv_transpose1d expects, and it is the reverse of an ordinary convolution's
// [outDim, inDim/groups, k]. When cfg.Bias is true, a "bias" variable of shape
// [outDim] is created first.
//
// The process is terminated via log.Fatalf if len(ksizes) != 1, or if
// cfg.Groups is not positive or does not divide both inDim and outDim.
func NewConvTranspose1D(vs *Path, inDim, outDim int64, ksizes []int64, cfg ConvTranspose1DConfig) ConvTranspose1D {
	if len(ksizes) != 1 {
		log.Fatalf("NewConvTranspose1D method call: Kernel size should be 1. Got %v\n", len(ksizes))
	}
	// Guard against a zero Groups value (integer divide-by-zero below) and
	// against channel counts that would otherwise be silently truncated.
	if cfg.Groups <= 0 || inDim%cfg.Groups != 0 || outDim%cfg.Groups != 0 {
		log.Fatalf("NewConvTranspose1D method call: Groups should be positive and divide both inDim (%v) and outDim (%v). Got %v\n", inDim, outDim, cfg.Groups)
	}

	conv := ConvTranspose1D{Config: cfg}
	if cfg.Bias {
		conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
	}

	// Transposed convolutions store weights as [in, out/groups, k...].
	weightSize := append([]int64{inDim, outDim / cfg.Groups}, ksizes...)
	conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)

	return conv
}
// ConvTranspose2D is a 2D transposed convolution layer built by
// NewConvTranspose2D.
type ConvTranspose2D struct {
	Ws ts.Tensor // "weight" variable created by NewConvTranspose2D
	// Bs is the "bias" variable. It is set only when Config.Bias is true and is
	// otherwise left as the zero-value ts.Tensor.
	// NOTE(review): confirm ts.MustConvTranspose2d treats a zero-value tensor as "no bias".
	Bs     ts.Tensor // optional
	Config ConvTranspose2DConfig
}
// NewConvTranspose2D creates a 2D transposed convolution layer and registers
// its parameters under vs.
//
// inDim and outDim are the input and output channel counts, and ksizes must
// contain exactly two kernel sizes. The "weight" variable has shape
// [inDim, outDim/cfg.Groups, ksizes[0], ksizes[1]]. This is the layout
// PyTorch's conv_transpose2d expects, and it is the reverse of an ordinary
// convolution's [outDim, inDim/groups, kH, kW]. When cfg.Bias is true, a
// "bias" variable of shape [outDim] is created first.
//
// The process is terminated via log.Fatalf if len(ksizes) != 2, or if
// cfg.Groups is not positive or does not divide both inDim and outDim.
func NewConvTranspose2D(vs *Path, inDim, outDim int64, ksizes []int64, cfg ConvTranspose2DConfig) ConvTranspose2D {
	if len(ksizes) != 2 {
		log.Fatalf("NewConvTranspose2D method call: Kernel size should be 2. Got %v\n", len(ksizes))
	}
	// Guard against a zero Groups value (integer divide-by-zero below) and
	// against channel counts that would otherwise be silently truncated.
	if cfg.Groups <= 0 || inDim%cfg.Groups != 0 || outDim%cfg.Groups != 0 {
		log.Fatalf("NewConvTranspose2D method call: Groups should be positive and divide both inDim (%v) and outDim (%v). Got %v\n", inDim, outDim, cfg.Groups)
	}

	conv := ConvTranspose2D{Config: cfg}
	if cfg.Bias {
		conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
	}

	// Transposed convolutions store weights as [in, out/groups, k...].
	weightSize := append([]int64{inDim, outDim / cfg.Groups}, ksizes...)
	conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)

	return conv
}
// ConvTranspose3D is a 3D transposed convolution layer built by
// NewConvTranspose3D.
type ConvTranspose3D struct {
	Ws ts.Tensor // "weight" variable created by NewConvTranspose3D
	// Bs is the "bias" variable. It is set only when Config.Bias is true and is
	// otherwise left as the zero-value ts.Tensor.
	// NOTE(review): confirm ts.MustConvTranspose3d treats a zero-value tensor as "no bias".
	Bs     ts.Tensor // optional
	Config ConvTranspose3DConfig
}
// NewConvTranspose3D creates a 3D transposed convolution layer and registers
// its parameters under vs.
//
// inDim and outDim are the input and output channel counts, and ksizes must
// contain exactly three kernel sizes. The "weight" variable has shape
// [inDim, outDim/cfg.Groups, ksizes[0], ksizes[1], ksizes[2]]. This is the
// layout PyTorch's conv_transpose3d expects, and it is the reverse of an
// ordinary convolution's [outDim, inDim/groups, kD, kH, kW]. When cfg.Bias is
// true, a "bias" variable of shape [outDim] is created first.
//
// The process is terminated via log.Fatalf if len(ksizes) != 3, or if
// cfg.Groups is not positive or does not divide both inDim and outDim.
func NewConvTranspose3D(vs *Path, inDim, outDim int64, ksizes []int64, cfg ConvTranspose3DConfig) ConvTranspose3D {
	if len(ksizes) != 3 {
		log.Fatalf("NewConvTranspose3D method call: Kernel size should be 3. Got %v\n", len(ksizes))
	}
	// Guard against a zero Groups value (integer divide-by-zero below) and
	// against channel counts that would otherwise be silently truncated.
	if cfg.Groups <= 0 || inDim%cfg.Groups != 0 || outDim%cfg.Groups != 0 {
		log.Fatalf("NewConvTranspose3D method call: Groups should be positive and divide both inDim (%v) and outDim (%v). Got %v\n", inDim, outDim, cfg.Groups)
	}

	conv := ConvTranspose3D{Config: cfg}
	if cfg.Bias {
		conv.Bs = vs.NewVar("bias", []int64{outDim}, cfg.BsInit)
	}

	// Transposed convolutions store weights as [in, out/groups, k...].
	weightSize := append([]int64{inDim, outDim / cfg.Groups}, ksizes...)
	conv.Ws = vs.NewVar("weight", weightSize, cfg.WsInit)

	return conv
}
// Implement Module for ConvTranspose1D, ConvTranspose2D, ConvTranspose3D:
// =======================================================================

// Forward applies the 1D transposed convolution to xs. It returns the result of
// ts.MustConvTranspose1d called with the layer's weight, bias and config.
// c.Bs is the zero-value tensor when the layer was built without a bias.
func (c ConvTranspose1D) Forward(xs ts.Tensor) ts.Tensor {
	cfg := c.Config
	// Argument order follows the underlying op: groups comes before dilation.
	return ts.MustConvTranspose1d(
		xs, c.Ws, c.Bs,
		cfg.Stride, cfg.Padding, cfg.OutputPadding,
		cfg.Groups, cfg.Dilation,
	)
}
// Forward applies the 2D transposed convolution to xs. It returns the result of
// ts.MustConvTranspose2d called with the layer's weight, bias and config.
// c.Bs is the zero-value tensor when the layer was built without a bias.
func (c ConvTranspose2D) Forward(xs ts.Tensor) ts.Tensor {
	cfg := c.Config
	// Argument order follows the underlying op: groups comes before dilation.
	return ts.MustConvTranspose2d(
		xs, c.Ws, c.Bs,
		cfg.Stride, cfg.Padding, cfg.OutputPadding,
		cfg.Groups, cfg.Dilation,
	)
}
// Forward applies the 3D transposed convolution to xs. It returns the result of
// ts.MustConvTranspose3d called with the layer's weight, bias and config.
// c.Bs is the zero-value tensor when the layer was built without a bias.
func (c ConvTranspose3D) Forward(xs ts.Tensor) ts.Tensor {
	cfg := c.Config
	// Argument order follows the underlying op: groups comes before dilation.
	return ts.MustConvTranspose3d(
		xs, c.Ws, c.Bs,
		cfg.Stride, cfg.Padding, cfg.OutputPadding,
		cfg.Groups, cfg.Dilation,
	)
}