feat(nn/batch-norm): completed

sugarme 2020-06-25 20:24:10 +10:00
parent dd267ebf1c
commit 8f095ea43a
3 changed files with 137 additions and 0 deletions


@@ -508,3 +508,14 @@ func AtgLayerNorm(ptr *Ctensor, input Ctensor, normalizedShapeData []int64, norm
    C.atg_layer_norm(ptr, input, cnormalizedShapeDataPtr, cnormalizedShapeLen, weight, bias, ceps, ccudnnEnable)
}

// void atg_batch_norm(tensor *, tensor input, tensor weight, tensor bias, tensor running_mean, tensor running_var, int training, double momentum, double eps, int cudnn_enabled);
func AtgBatchNorm(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, runningMean Ctensor, runningVar Ctensor, training int, momentum float64, eps float64, cudnnEnable int) {
    ctraining := *(*C.int)(unsafe.Pointer(&training))
    cmomentum := *(*C.double)(unsafe.Pointer(&momentum))
    ceps := *(*C.double)(unsafe.Pointer(&eps))
    ccudnnEnable := *(*C.int)(unsafe.Pointer(&cudnnEnable))

    C.atg_batch_norm(ptr, input, weight, bias, runningMean, runningVar, ctraining, cmomentum, ceps, ccudnnEnable)
}
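A note on the conversion style above: the binding reinterprets each Go scalar's memory as the corresponding C type rather than converting by value. A minimal sketch of the same pattern with plain Go types (int32 standing in for C.int); it reads only the low four bytes of the Go int, which is fine for small flag values on little-endian platforms, while a plain int32 conversion would be the portable alternative:

package main

import (
    "fmt"
    "unsafe"
)

func main() {
    training := 1 // Go int: 8 bytes on most 64-bit platforms

    // Mirrors *(*C.int)(unsafe.Pointer(&training)) in the binding above:
    // reinterpret the int's first 4 bytes as an int32.
    ctraining := *(*int32)(unsafe.Pointer(&training))
    fmt.Println(ctraining) // prints 1 on little-endian machines

    // The portable equivalent is a plain value conversion:
    fmt.Println(int32(training))
}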

nn/batch-norm.go Normal file

@@ -0,0 +1,92 @@
package nn

// A batch-normalization layer.

import (
    "log"

    ts "github.com/sugarme/gotch/tensor"
)
// Batch-normalization config.
type BatchNormConfig struct {
    CudnnEnable bool
    Eps         float64
    Momentum    float64
    WsInit      Init
    BsInit      Init
}

func DefaultBatchNormConfig() BatchNormConfig {
    return BatchNormConfig{
        CudnnEnable: true,
        Eps:         1e-5,
        Momentum:    0.1,
        WsInit:      NewUniformInit(0.0, 1.0),
        BsInit:      NewConstInit(0.0),
    }
}
// A batch-normalization layer.
type BatchNorm struct {
    config      BatchNormConfig
    RunningMean ts.Tensor
    RunningVar  ts.Tensor
    Ws          ts.Tensor
    Bs          ts.Tensor
    Nd          uint
}

// NewBatchNorm creates a new BatchNorm layer.
func NewBatchNorm(vs Path, nd uint, outDim int64, config BatchNormConfig) BatchNorm {
    return BatchNorm{
        config:      config,
        RunningMean: vs.ZerosNoTrain("running_mean", []int64{outDim}),
        RunningVar:  vs.OnesNoTrain("running_var", []int64{outDim}),
        Ws:          vs.NewVar("weight", []int64{outDim}, config.WsInit),
        Bs:          vs.NewVar("bias", []int64{outDim}, config.BsInit),
        Nd:          nd, // record the dimensionality so ForwardT can validate input rank
    }
}
// BatchNorm1D applies Batch Normalization over a three-dimensional input.
//
// The input shape is assumed to be (N, C, L). Normalization
// is performed over the first batch dimension N.
func BatchNorm1D(vs Path, outDim int64, config BatchNormConfig) BatchNorm {
    return NewBatchNorm(vs, 1, outDim, config)
}

// BatchNorm2D applies Batch Normalization over a four-dimensional input.
//
// The input shape is assumed to be (N, C, H, W). Normalization
// is performed over the first batch dimension N.
func BatchNorm2D(vs Path, outDim int64, config BatchNormConfig) BatchNorm {
    return NewBatchNorm(vs, 2, outDim, config)
}

// BatchNorm3D applies Batch Normalization over a five-dimensional input.
//
// The input shape is assumed to be (N, C, D, H, W). Normalization
// is performed over the first batch dimension N.
func BatchNorm3D(vs Path, outDim int64, config BatchNormConfig) BatchNorm {
    return NewBatchNorm(vs, 3, outDim, config)
}
// Implement ModuleT interface for BatchNorm:
// ==========================================

// ForwardT applies batch normalization to xs, updating the running
// statistics when train is true and normalizing with them when it is false.
func (bn BatchNorm) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
    dim := xs.Dim()

    if bn.Nd == 1 && dim != 2 && dim != 3 {
        log.Fatalf("Expected an input tensor with 2 or 3 dims, got %v\n", xs.MustSize())
    }

    if bn.Nd > 1 && int(dim) != int(bn.Nd)+2 {
        log.Fatalf("Expected an input tensor with %v dims, got %v\n", bn.Nd+2, xs.MustSize())
    }

    return ts.MustBatchNorm(xs, bn.Ws, bn.Bs, bn.RunningMean, bn.RunningVar, train, bn.config.Momentum, bn.config.Eps, bn.config.CudnnEnable)
}
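Putting the new layer together, a minimal usage sketch that also exercises the config. This is hedged: nn.NewVarStore, VarStore.Root, and ts.MustRandn are assumed from the rest of the library and are not part of this commit.

package main

import (
    "github.com/sugarme/gotch"
    "github.com/sugarme/gotch/nn"
    ts "github.com/sugarme/gotch/tensor"
)

func main() {
    // Assumed API: the variable store and its root path are not shown in this diff.
    vs := nn.NewVarStore(gotch.CPU)

    config := nn.DefaultBatchNormConfig()
    config.Momentum = 0.01 // slower running-stat updates than the 0.1 default

    // BatchNorm2D expects (N, C, H, W) input; 16 is the channel count C.
    bn := nn.BatchNorm2D(vs.Root(), 16, config)

    // Assumed tensor constructor; the exact creation API may differ.
    xs := ts.MustRandn([]int64{8, 16, 32, 32}, gotch.Float, gotch.CPU)

    // train=true normalizes with batch statistics and updates
    // RunningMean/RunningVar; train=false normalizes with the running stats.
    out := bn.ForwardT(xs, true)
    _ = out
}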


@@ -1477,3 +1477,37 @@ func MustLayerNorm(input Tensor, normalizedShape []int64, weight, bias Tensor, e

    return retVal
}

func BatchNorm(input Tensor, weight, bias, runningMean, runningVar Tensor, train bool, momentum float64, eps float64, cudnnEnable bool) (retVal Tensor, err error) {
    // Allocate one pointer's worth of space for the result: the C side writes
    // the new tensor's handle into this slot (malloc(0) would be too small).
    ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(C.size_t(unsafe.Sizeof(uintptr(0))))))

    ccudnnEnable := 0
    if cudnnEnable {
        ccudnnEnable = 1
    }

    ctrain := 0
    if train {
        ctrain = 1
    }

    lib.AtgBatchNorm(ptr, input.ctensor, weight.ctensor, bias.ctensor, runningMean.ctensor, runningVar.ctensor, ctrain, momentum, eps, ccudnnEnable)
    err = TorchErr()
    if err != nil {
        return retVal, err
    }

    retVal = Tensor{ctensor: *ptr}

    return retVal, nil
}

func MustBatchNorm(input Tensor, weight, bias, runningMean, runningVar Tensor, train bool, momentum float64, eps float64, cudnnEnable bool) (retVal Tensor) {
    retVal, err := BatchNorm(input, weight, bias, runningMean, runningVar, train, momentum, eps, cudnnEnable)
    if err != nil {
        log.Fatal(err)
    }

    return retVal
}
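BatchNorm/MustBatchNorm follow the library's paired convention: the base function surfaces libtorch failures through TorchErr() as a Go error, and the Must variant turns any error into log.Fatal. A short sketch of choosing between them, assuming bn and xs are built as in the previous example:

package example

import (
    "log"

    "github.com/sugarme/gotch/nn"
    ts "github.com/sugarme/gotch/tensor"
)

// normalize demonstrates both call styles against a layer's tensors.
func normalize(bn nn.BatchNorm, xs ts.Tensor) ts.Tensor {
    // Functional form: the caller handles the error explicitly.
    out, err := ts.BatchNorm(xs, bn.Ws, bn.Bs, bn.RunningMean, bn.RunningVar,
        true, 0.1, 1e-5, true)
    if err != nil {
        log.Fatalf("batch norm failed: %v", err)
    }

    // Must form: the same call, but any libtorch error is fatal inside the
    // wrapper. Handy in short scripts; the error-returning form suits servers.
    _ = ts.MustBatchNorm(xs, bn.Ws, bn.Bs, bn.RunningMean, bn.RunningVar,
        true, 0.1, 1e-5, true)

    return out
}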