// gotch/nn/conv-transpose.go

package nn

// Transposed convolution layers (1D, 2D and 3D) and their configurations.

import (
	"log"

	"github.com/sugarme/gotch/ts"
)

// ConvTranspose1DConfig is the configuration for a one-dimensional
// transposed convolution layer.
type ConvTranspose1DConfig struct {
	Stride        []int64
	Padding       []int64
	OutputPadding []int64
	Dilation      []int64
	Groups        int64
	Bias          bool
	WsInit        Init
	BsInit        Init
}

// ConvTranspose2DConfig is the configuration for a two-dimensional
// transposed convolution layer.
type ConvTranspose2DConfig struct {
	Stride        []int64
	Padding       []int64
	OutputPadding []int64
	Dilation      []int64
	Groups        int64
	Bias          bool
	WsInit        Init
	BsInit        Init
}

// ConvTranspose3DConfig is the configuration for a three-dimensional
// transposed convolution layer.
type ConvTranspose3DConfig struct {
	Stride        []int64
	Padding       []int64
	OutputPadding []int64
	Dilation      []int64
	Groups        int64
	Bias          bool
	WsInit        Init
	BsInit        Init
}

// DefaultConvTranspose1DConfig creates a default ConvTranspose1DConfig:
// stride 1, no padding or output padding, dilation 1, a single group,
// a bias term, Kaiming-uniform weight init and zero bias init.
func DefaultConvTranspose1DConfig() *ConvTranspose1DConfig {
	return &ConvTranspose1DConfig{
		Stride:        []int64{1},
		Padding:       []int64{0},
		OutputPadding: []int64{0},
		Dilation:      []int64{1},
		Groups:        1,
		Bias:          true,
		WsInit:        NewKaimingUniformInit(),
		BsInit:        NewConstInit(float64(0.0)),
	}
}
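
// A minimal usage sketch, assuming a caller that starts from the default 1D config and
// overrides fields before constructing a layer; the stride/padding values below are
// illustrative only, not package defaults.
//
//	cfg := DefaultConvTranspose1DConfig()
//	cfg.Stride = []int64{2} // upsample the length dimension
//	cfg.Padding = []int64{1}
//	cfg.OutputPadding = []int64{1}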

// ConvTranspose1D is a one-dimensional transposed convolution layer.
type ConvTranspose1D struct {
	Ws     *ts.Tensor
	Bs     *ts.Tensor // optional
	Config *ConvTranspose1DConfig
}

// NewConvTranspose1D creates a ConvTranspose1D layer. ksizes must hold exactly
// one kernel size.
func NewConvTranspose1D(vs *Path, inDim, outDim int64, ksizes []int64, cfg *ConvTranspose1DConfig) *ConvTranspose1D {
	if len(ksizes) != 1 {
		log.Fatalf("NewConvTranspose1D: expected 1 kernel size, got %v\n", len(ksizes))
	}

	var (
		ws *ts.Tensor
		bs *ts.Tensor = ts.NewTensor()
	)

	// Weight variable with shape [outDim, inDim/groups, k].
	weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
	weightSize = append(weightSize, ksizes...)
	ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)

	// Bias variable with shape [outDim]; left as an empty tensor when Bias is false.
	if cfg.Bias {
		bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
	}

	return &ConvTranspose1D{
		Ws:     ws,
		Bs:     bs,
		Config: cfg,
	}
}
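
// Construction sketch, assuming path is a *Path obtained from a variable store; the
// 16 input channels, 8 output channels and kernel size of 3 are illustrative values.
//
//	deconv := NewConvTranspose1D(path, 16, 8, []int64{3}, DefaultConvTranspose1DConfig())
//	// With Groups=1 this registers "weight" with shape [8, 16, 3] and "bias" with shape [8].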

// ConvTranspose2D is a two-dimensional transposed convolution layer.
type ConvTranspose2D struct {
	Ws     *ts.Tensor
	Bs     *ts.Tensor // optional
	Config *ConvTranspose2DConfig
}

// NewConvTranspose2D creates a ConvTranspose2D layer. ksizes must hold exactly
// two kernel sizes (height, width).
func NewConvTranspose2D(vs *Path, inDim, outDim int64, ksizes []int64, cfg *ConvTranspose2DConfig) *ConvTranspose2D {
	if len(ksizes) != 2 {
		log.Fatalf("NewConvTranspose2D: expected 2 kernel sizes, got %v\n", len(ksizes))
	}

	var (
		ws *ts.Tensor
		bs *ts.Tensor = ts.NewTensor()
	)

	// Bias variable with shape [outDim]; left as an empty tensor when Bias is false.
	if cfg.Bias {
		bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
	}

	// Weight variable with shape [outDim, inDim/groups, kH, kW].
	weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
	weightSize = append(weightSize, ksizes...)
	ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)

	return &ConvTranspose2D{
		Ws:     ws,
		Bs:     bs,
		Config: cfg,
	}
}
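
// Usage sketch for the 2D case, assuming path is a *Path; the config is written as a
// struct literal (only a 1D default constructor is defined above), and the channel
// counts and 4x4 kernel are illustrative assumptions.
//
//	cfg := &ConvTranspose2DConfig{
//		Stride:        []int64{2, 2},
//		Padding:       []int64{1, 1},
//		OutputPadding: []int64{0, 0},
//		Dilation:      []int64{1, 1},
//		Groups:        1,
//		Bias:          true,
//		WsInit:        NewKaimingUniformInit(),
//		BsInit:        NewConstInit(0.0),
//	}
//	upconv := NewConvTranspose2D(path, 64, 32, []int64{4, 4}, cfg)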

// ConvTranspose3D is a three-dimensional transposed convolution layer.
type ConvTranspose3D struct {
	Ws     *ts.Tensor
	Bs     *ts.Tensor // optional
	Config *ConvTranspose3DConfig
}

// NewConvTranspose3D creates a ConvTranspose3D layer. ksizes must hold exactly
// three kernel sizes (depth, height, width).
func NewConvTranspose3D(vs *Path, inDim, outDim int64, ksizes []int64, cfg *ConvTranspose3DConfig) *ConvTranspose3D {
	if len(ksizes) != 3 {
		log.Fatalf("NewConvTranspose3D: expected 3 kernel sizes, got %v\n", len(ksizes))
	}

	var (
		ws *ts.Tensor
		bs *ts.Tensor = ts.NewTensor()
	)

	// Bias variable with shape [outDim]; left as an empty tensor when Bias is false.
	if cfg.Bias {
		bs = vs.MustNewVar("bias", []int64{outDim}, cfg.BsInit)
	}

	// Weight variable with shape [outDim, inDim/groups, kD, kH, kW].
	weightSize := []int64{outDim, int64(inDim / cfg.Groups)}
	weightSize = append(weightSize, ksizes...)
	ws = vs.MustNewVar("weight", weightSize, cfg.WsInit)

	return &ConvTranspose3D{
		Ws:     ws,
		Bs:     bs,
		Config: cfg,
	}
}

// Implement Module for ConvTranspose1D, ConvTranspose2D, ConvTranspose3D:
// ========================================================================

// Forward applies the one-dimensional transposed convolution to xs.
func (c *ConvTranspose1D) Forward(xs *ts.Tensor) *ts.Tensor {
	return ts.MustConvTranspose1d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Groups, c.Config.Dilation)
}

// Forward applies the two-dimensional transposed convolution to xs.
func (c *ConvTranspose2D) Forward(xs *ts.Tensor) *ts.Tensor {
	return ts.MustConvTranspose2d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Groups, c.Config.Dilation)
}

// Forward applies the three-dimensional transposed convolution to xs.
func (c *ConvTranspose3D) Forward(xs *ts.Tensor) *ts.Tensor {
	return ts.MustConvTranspose3d(xs, c.Ws, c.Bs, c.Config.Stride, c.Config.Padding, c.Config.OutputPadding, c.Config.Groups, c.Config.Dilation)
}
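
// Note on output sizes, following standard transposed-convolution arithmetic (the same
// formula documented for PyTorch's ConvTranspose*d), per spatial dimension:
//
//	out = (in-1)*stride - 2*padding + dilation*(kernel-1) + output_padding + 1
//
// For example, stride 2, kernel 4, padding 1 and output padding 0 exactly doubles the
// spatial size of the input, a common upsampling choice in decoder networks.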