56 lines
1.3 KiB
Go
56 lines
1.3 KiB
Go
package nn
|
|
|
|
// A layer-normalization layer.
|
|
import (
|
|
ts "github.com/sugarme/gotch/tensor"
|
|
)
|
|
|
|
// LayerNormConfig holds the options NewLayerNorm uses to build a LayerNorm.
// DefaultLayerNormConfig returns a ready-to-use instance.
type LayerNormConfig struct {
	// CudnnEnable is forwarded unchanged to ts.MustLayerNorm as its final
	// argument. Presumably it allows libtorch to pick cuDNN kernels (confirm
	// against gotch).
	CudnnEnable bool
	// Eps is forwarded to ts.MustLayerNorm. Presumably it is the small constant
	// added to the variance for numerical stability (confirm against gotch).
	Eps float64
	// ElementwiseAffine controls whether NewLayerNorm creates the "weight" and
	// "bias" variables. When false, the layer's Ws and Bs stay zero-valued.
	ElementwiseAffine bool
	// WsInit initializes the "weight" variable. It is only used when
	// ElementwiseAffine is true.
	WsInit Init
	// BsInit initializes the "bias" variable. It is only used when
	// ElementwiseAffine is true.
	BsInit Init
}
|
|
|
|
func DefaultLayerNormConfig() LayerNormConfig {
|
|
return LayerNormConfig{
|
|
CudnnEnable: true,
|
|
Eps: 1e-5,
|
|
ElementwiseAffine: true,
|
|
WsInit: NewConstInit(1.0),
|
|
BsInit: NewConstInit(0.0),
|
|
}
|
|
}
|
|
|
|
// LayerNorm is a layer-normalization layer; construct it with NewLayerNorm.
type LayerNorm struct {
	// Config is the configuration the layer was built with. Forward reads its
	// Eps and CudnnEnable fields.
	Config LayerNormConfig
	// Ws is the "weight" variable created via vs.NewVar when
	// Config.ElementwiseAffine is true. Otherwise it is the zero ts.Tensor,
	// which ts.MustLayerNorm presumably treats as "no weight" (confirm in gotch).
	Ws ts.Tensor // optional
	// Bs is the "bias" variable created via vs.NewVar when
	// Config.ElementwiseAffine is true. Otherwise it is the zero ts.Tensor,
	// which ts.MustLayerNorm presumably treats as "no bias" (confirm in gotch).
	Bs ts.Tensor // optional
	// NormalizedShape is the shape given to NewLayerNorm. It is the shape of
	// Ws and Bs, and Forward passes it to ts.MustLayerNorm.
	NormalizedShape []int64
}
|
|
|
|
func NewLayerNorm(vs Path, normalizedShape []int64, config LayerNormConfig) LayerNorm {
|
|
|
|
var (
|
|
ws ts.Tensor
|
|
bs ts.Tensor
|
|
)
|
|
if config.ElementwiseAffine {
|
|
ws = vs.NewVar("weight", normalizedShape, config.WsInit)
|
|
bs = vs.NewVar("bias", normalizedShape, config.BsInit)
|
|
}
|
|
|
|
return LayerNorm{config, ws, bs, normalizedShape}
|
|
}
|
|
|
|
// Implement Module interface for LayerNorm:
|
|
// =========================================
|
|
|
|
func (ln LayerNorm) Forward(xs ts.Tensor) (retVal ts.Tensor) {
|
|
|
|
return ts.MustLayerNorm(xs, ln.NormalizedShape, ln.Ws, ln.Bs, ln.Config.Eps, ln.Config.CudnnEnable)
|
|
}
|