gotch/nn/batch-norm.go

package nn

// A batch-normalization layer.

import (
	"log"

	ts "github.com/sugarme/gotch/tensor"
)

// BatchNormConfig holds the configuration options for a batch-normalization layer.
type BatchNormConfig struct {
	CudnnEnable bool    // use the cuDNN implementation when available
	Eps         float64 // small constant added to the variance for numerical stability
	Momentum    float64 // momentum used to update the running mean and variance
	WsInit      Init    // initializer for the learnable scale (weight)
	BsInit      Init    // initializer for the learnable shift (bias)
}

// DefaultBatchNormConfig returns the default configuration: cuDNN enabled,
// eps = 1e-5, momentum = 0.1, uniform weight init and constant-zero bias init.
func DefaultBatchNormConfig() *BatchNormConfig {
	return &BatchNormConfig{
		CudnnEnable: true,
		Eps:         1e-5,
		Momentum:    0.1,
		WsInit:      NewUniformInit(0.0, 1.0),
		BsInit:      NewConstInit(0.0),
	}
}
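
// A minimal configuration sketch, kept as a comment so the file itself is
// unchanged at build time. It only uses names defined in this file, except
// `path`, which stands for a *Path the caller is assumed to already have;
// the values are illustrative, not recommendations:
//
//	config := DefaultBatchNormConfig()
//	config.Momentum = 0.01 // slower running-stat updates
//	bn := BatchNorm2D(path, 64, config)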

// BatchNorm is a batch-normalization layer.
type BatchNorm struct {
	config      *BatchNormConfig
	RunningMean *ts.Tensor // running estimate of the per-channel mean (not trained)
	RunningVar  *ts.Tensor // running estimate of the per-channel variance (not trained)
	Ws          *ts.Tensor // learnable scale (gamma)
	Bs          *ts.Tensor // learnable shift (beta)
	Nd          uint       // number of spatial dimensions (1, 2 or 3)
}

// NewBatchNorm creates a new BatchNorm layer with `nd` spatial dimensions and
// `outDim` channels, registering its variables under the given path.
func NewBatchNorm(vs *Path, nd uint, outDim int64, config *BatchNormConfig) *BatchNorm {
	return &BatchNorm{
		config:      config,
		RunningMean: vs.ZerosNoTrain("running_mean", []int64{outDim}),
		RunningVar:  vs.OnesNoTrain("running_var", []int64{outDim}),
		Ws:          vs.NewVar("weight", []int64{outDim}, config.WsInit),
		Bs:          vs.NewVar("bias", []int64{outDim}, config.BsInit),
		Nd:          nd, // used by ForwardT's input-shape checks
	}
}
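
// A minimal construction sketch, kept as a comment. It assumes the VarStore
// helpers from this package (NewVarStore, Root, Sub) and gotch.CPU from the
// root gotch package; the layer sizes are illustrative:
//
//	vs := NewVarStore(gotch.CPU)
//	bn := NewBatchNorm(vs.Root().Sub("bn"), 2, 64, DefaultBatchNormConfig())
//	// "weight", "bias", "running_mean" and "running_var" now live under the
//	// "bn" sub-path, each with shape [64].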

// BatchNorm1D applies Batch Normalization over a three-dimensional input.
//
// The input shape is assumed to be (N, C, L); a two-dimensional (N, C) input
// is also accepted. Normalization is performed over the first batch
// dimension N.
func BatchNorm1D(vs *Path, outDim int64, config *BatchNormConfig) *BatchNorm {
	return NewBatchNorm(vs, 1, outDim, config)
}

// BatchNorm2D applies Batch Normalization over a four-dimensional input.
//
// The input shape is assumed to be (N, C, H, W). Normalization
// is performed over the first batch dimension N.
func BatchNorm2D(vs *Path, outDim int64, config *BatchNormConfig) *BatchNorm {
	return NewBatchNorm(vs, 2, outDim, config)
}

// BatchNorm3D applies Batch Normalization over a five-dimensional input.
//
// The input shape is assumed to be (N, C, D, H, W). Normalization
// is performed over the first batch dimension N.
func BatchNorm3D(vs *Path, outDim int64, config *BatchNormConfig) *BatchNorm {
	return NewBatchNorm(vs, 3, outDim, config)
}
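
// A shape-oriented sketch, kept as a comment: pick the constructor that matches
// the number of spatial dimensions of the input. `vs` stands for a *Path the
// caller already has, the Sub calls assume the sub-path helper mentioned above,
// and the channel counts are illustrative:
//
//	bn1 := BatchNorm1D(vs.Sub("bn1"), 128, DefaultBatchNormConfig()) // (N, 128, L) sequences
//	bn2 := BatchNorm2D(vs.Sub("bn2"), 64, DefaultBatchNormConfig())  // (N, 64, H, W) images
//	bn3 := BatchNorm3D(vs.Sub("bn3"), 32, DefaultBatchNormConfig())  // (N, 32, D, H, W) volumes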

// Implement ModuleT interface for BatchNorm:
// ==========================================

// ForwardT applies batch normalization to `xs`. When train is true the batch
// statistics are used and the running estimates are updated; otherwise the
// stored running estimates are used.
func (bn *BatchNorm) ForwardT(xs *ts.Tensor, train bool) (retVal *ts.Tensor) {
	dim := xs.Dim()

	if bn.Nd == 1 && dim != 2 && dim != 3 {
		log.Fatalf("Expected an input tensor with 2 or 3 dims, got a tensor of shape %v\n", xs.MustSize())
	}

	if bn.Nd > 1 && int(dim) != int(bn.Nd)+2 {
		log.Fatalf("Expected an input tensor with %v dims, got a tensor of shape %v\n", bn.Nd+2, xs.MustSize())
	}

	return ts.MustBatchNorm(xs, bn.Ws, bn.Bs, bn.RunningMean, bn.RunningVar, train, bn.config.Momentum, bn.config.Eps, bn.config.CudnnEnable)
}
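
// A minimal forward-pass sketch, kept as a comment. It assumes a tensor factory
// with a (size, dtype, device) signature such as ts.MustOnes, plus gotch.Float
// and gotch.CPU from the root package; `bn` is a BatchNorm built as above:
//
//	xs := ts.MustOnes([]int64{8, 64, 32, 32}, gotch.Float, gotch.CPU)
//	trainOut := bn.ForwardT(xs, true)  // training: batch stats, running stats updated
//	evalOut := bn.ForwardT(xs, false)  // inference: running stats are used
//	_, _ = trainOut, evalOut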