gotch/nn/layer-norm.go
2020-06-25 17:40:01 +10:00

56 lines
1.3 KiB
Go

package nn
// This file implements a layer-normalization layer (LayerNorm).
import (
ts "github.com/sugarme/gotch/tensor"
)
// LayerNormConfig holds the options used to build and run a LayerNorm layer.
//
// CudnnEnable and Eps are forwarded unchanged to ts.MustLayerNorm by
// LayerNorm.Forward. When ElementwiseAffine is true, NewLayerNorm creates a
// "weight" variable initialised with WsInit and a "bias" variable initialised
// with BsInit.
type LayerNormConfig struct {
	CudnnEnable       bool    // cudnn flag passed to ts.MustLayerNorm
	Eps               float64 // epsilon passed to ts.MustLayerNorm
	ElementwiseAffine bool    // whether NewLayerNorm creates weight/bias variables
	WsInit, BsInit    Init    // initialisers for the weight and bias variables
}
// DefaultLayerNormConfig returns a LayerNormConfig with CudnnEnable set to
// true, Eps set to 1e-5, ElementwiseAffine set to true, the weight initialiser
// set to the constant 1.0 and the bias initialiser set to the constant 0.0.
func DefaultLayerNormConfig() LayerNormConfig {
	var cfg LayerNormConfig
	cfg.CudnnEnable = true
	cfg.Eps = 1e-5
	cfg.ElementwiseAffine = true
	cfg.WsInit = NewConstInit(1.0)
	cfg.BsInit = NewConstInit(0.0)
	return cfg
}
// LayerNorm is a layer-normalization layer.
//
// Forward passes NormalizedShape, Ws, Bs and the Eps/CudnnEnable options from
// Config to ts.MustLayerNorm. Ws and Bs are only created by NewLayerNorm when
// Config.ElementwiseAffine is true; otherwise they hold the zero value of
// ts.Tensor.
type LayerNorm struct {
	Config          LayerNormConfig
	Ws, Bs          ts.Tensor // optional learnable weight and bias
	NormalizedShape []int64
}
// NewLayerNorm creates a LayerNorm over normalizedShape using config.
//
// When config.ElementwiseAffine is true, a "weight" variable initialised with
// config.WsInit and a "bias" variable initialised with config.BsInit are
// created under vs, each with shape normalizedShape. Otherwise Ws and Bs are
// left as zero-value tensors.
//
// normalizedShape is copied, so mutating the caller's slice afterwards does
// not change the shape the layer uses.
func NewLayerNorm(vs Path, normalizedShape []int64, config LayerNormConfig) LayerNorm {
	// Defensive copy: the layer reuses this shape on every Forward call, and a
	// slice shares its backing array with the caller.
	shape := make([]int64, len(normalizedShape))
	copy(shape, normalizedShape)

	var ws, bs ts.Tensor
	if config.ElementwiseAffine {
		ws = vs.NewVar("weight", shape, config.WsInit)
		bs = vs.NewVar("bias", shape, config.BsInit)
	}

	return LayerNorm{
		Config:          config,
		Ws:              ws,
		Bs:              bs,
		NormalizedShape: shape,
	}
}
// Implement Module interface for LayerNorm:
// =========================================
// Forward applies the layer to xs by calling ts.MustLayerNorm with the
// layer's NormalizedShape, weight (Ws), bias (Bs), Config.Eps and
// Config.CudnnEnable.
// NOTE(review): as a Must* helper, ts.MustLayerNorm presumably panics on
// failure instead of returning an error — confirm in the ts package.
func (ln LayerNorm) Forward(xs ts.Tensor) (retVal ts.Tensor) {
	cfg := ln.Config
	retVal = ts.MustLayerNorm(xs, ln.NormalizedShape, ln.Ws, ln.Bs, cfg.Eps, cfg.CudnnEnable)
	return retVal
}