feat(nn/batch-norm): completed
This commit is contained in:
parent
dd267ebf1c
commit
8f095ea43a
|
@ -508,3 +508,14 @@ func AtgLayerNorm(ptr *Ctensor, input Ctensor, normalizedShapeData []int64, norm
|
|||
|
||||
C.atg_layer_norm(ptr, input, cnormalizedShapeDataPtr, cnormalizedShapeLen, weight, bias, ceps, ccudnnEnable)
|
||||
}
|
||||
|
||||
// void atg_batch_norm(tensor *, tensor input, tensor weight, tensor bias, tensor running_mean, tensor running_var, int training, double momentum, double eps, int cudnn_enabled);
|
||||
func AtgBatchNorm(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, runningMean Ctensor, runningVar Ctensor, training int, momentum float64, eps float64, cudnnEnable int) {
|
||||
|
||||
ctraining := *(*C.int)(unsafe.Pointer(&training))
|
||||
cmomentum := *(*C.double)(unsafe.Pointer(&momentum))
|
||||
ceps := *(*C.double)(unsafe.Pointer(&eps))
|
||||
ccudnnEnable := *(*C.int)(unsafe.Pointer(&cudnnEnable))
|
||||
|
||||
C.atg_batch_norm(ptr, input, weight, bias, runningMean, runningVar, ctraining, cmomentum, ceps, ccudnnEnable)
|
||||
}
|
||||
|
|
92
nn/batch-norm.go
Normal file
92
nn/batch-norm.go
Normal file
|
@ -0,0 +1,92 @@
|
|||
package nn
|
||||
|
||||
// A batch-normalization layer.
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// Batch-normalization config.
//
// Used by NewBatchNorm to size and initialize a BatchNorm layer and by the
// forward pass to parameterize the underlying batch-norm call.
type BatchNormConfig struct {
	CudnnEnable bool    // forwarded as the cudnn_enabled flag of the backend batch-norm call
	Eps float64         // small constant added to the variance for numerical stability
	Momentum float64    // momentum used when updating the running statistics
	WsInit Init         // initializer for the learnable weight (gamma) tensor
	BsInit Init         // initializer for the learnable bias (beta) tensor
}
|
||||
|
||||
func DefaultBatchNormConfig() BatchNormConfig {
|
||||
return BatchNormConfig{
|
||||
CudnnEnable: true,
|
||||
Eps: 1e-5,
|
||||
Momentum: 0.1,
|
||||
WsInit: NewUniformInit(0.0, 1.0),
|
||||
BsInit: NewConstInit(0.0),
|
||||
}
|
||||
}
|
||||
|
||||
// A batch-normalization layer.
type BatchNorm struct {
	config BatchNormConfig      // hyper-parameters consumed by ForwardT
	RunningMean ts.Tensor       // per-channel running mean (created non-trainable)
	RunningVar ts.Tensor        // per-channel running variance (created non-trainable)
	Ws ts.Tensor                // learnable weight (gamma) tensor
	Bs ts.Tensor                // learnable bias (beta) tensor
	Nd uint                     // spatial dimensionality (1, 2 or 3); ForwardT expects input rank Nd+2
}
|
||||
|
||||
// NewBatchNorm creates a new BatchNorm layer
|
||||
func NewBatchNorm(vs Path, nd uint, outDim int64, config BatchNormConfig) BatchNorm {
|
||||
return BatchNorm{
|
||||
config: config,
|
||||
RunningMean: vs.ZerosNoTrain("running_mean", []int64{outDim}),
|
||||
RunningVar: vs.OnesNoTrain("running_var", []int64{outDim}),
|
||||
Ws: vs.NewVar("weight", []int64{outDim}, config.WsInit),
|
||||
Bs: vs.NewVar("bias", []int64{outDim}, config.BsInit),
|
||||
}
|
||||
}
|
||||
|
||||
// Applies Batch Normalization over a three dimension input.
|
||||
//
|
||||
// The input shape is assumed to be (N, C, L). Normalization
|
||||
// is performed over the first batch dimension N.
|
||||
func BatchNorm1D(vs Path, outDim int64, config BatchNormConfig) BatchNorm {
|
||||
return NewBatchNorm(vs, 1, outDim, config)
|
||||
}
|
||||
|
||||
// Applies Batch Normalization over a four dimension input.
|
||||
//
|
||||
// The input shape is assumed to be (N, C, H, W). Normalization
|
||||
// is performed over the first batch dimension N.
|
||||
func BatchNorm2D(vs Path, outDim int64, config BatchNormConfig) BatchNorm {
|
||||
return NewBatchNorm(vs, 2, outDim, config)
|
||||
}
|
||||
|
||||
// Applies Batch Normalization over a five dimension input.
|
||||
//
|
||||
// The input shape is assumed to be (N, C, D, H, W). Normalization
|
||||
// is performed over the first batch dimension N.
|
||||
func BatchNorm3D(vs Path, outDim int64, config BatchNormConfig) BatchNorm {
|
||||
return NewBatchNorm(vs, 3, outDim, config)
|
||||
}
|
||||
|
||||
// Implement ModuleT interface for BatchNorm:
|
||||
// ==========================================
|
||||
|
||||
func (bn BatchNorm) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
|
||||
|
||||
dim := xs.Dim()
|
||||
|
||||
if bn.Nd == 1 && dim != 2 && dim != 3 {
|
||||
log.Fatalf("Expected an input tensor with 2 or 3 dims, got %v\n", xs.MustSize())
|
||||
}
|
||||
|
||||
if bn.Nd > 1 && int(dim) != int(bn.Nd)+2 {
|
||||
log.Fatalf("Expected an input tensor with %v dims, got %v\n", bn.Nd+2, xs.MustSize())
|
||||
}
|
||||
|
||||
return ts.MustBatchNorm(xs, bn.Ws, bn.Bs, bn.RunningMean, bn.RunningVar, train, bn.config.Momentum, bn.config.Eps, bn.config.CudnnEnable)
|
||||
|
||||
}
|
|
@ -1477,3 +1477,37 @@ func MustLayerNorm(input Tensor, normalizedShape []int64, weight, bias Tensor, e
|
|||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func BatchNorm(input Tensor, weight, bias, runningMean, runningVar Tensor, train bool, momentum float64, eps float64, cudnnEnable bool) (retVal Tensor, err error) {
|
||||
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
ccudnnEnable := 0
|
||||
if cudnnEnable {
|
||||
ccudnnEnable = 1
|
||||
}
|
||||
ctrain := 0
|
||||
if train {
|
||||
ctrain = 1
|
||||
}
|
||||
|
||||
lib.AtgBatchNorm(ptr, input.ctensor, weight.ctensor, bias.ctensor, runningMean.ctensor, runningVar.ctensor, ctrain, momentum, eps, ccudnnEnable)
|
||||
err = TorchErr()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
|
||||
}
|
||||
|
||||
func MustBatchNorm(input Tensor, weight, bias, runningMean, runningVar Tensor, train bool, momentum float64, eps float64, cudnnEnable bool) (retVal Tensor) {
|
||||
retVal, err := BatchNorm(input, weight, bias, runningMean, runningVar, train, momentum, eps, cudnnEnable)
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user