fixed conv3D config

This commit is contained in:
sugarme 2021-08-03 17:56:25 +10:00
parent 620fccf452
commit 880a1b25df

View File

@ -206,21 +206,21 @@ type Conv3DConfigOpt func(*Conv3DConfig)
// WithStride3D adds stride 3D option.
func WithStride3D(val int64) Conv3DConfigOpt {
return func(cfg *Conv3DConfig) {
cfg.Stride = []int64{val, val}
cfg.Stride = []int64{val, val, val}
}
}
// WithPadding3D adds padding 3D option.
func WithPadding3D(val int64) Conv3DConfigOpt {
return func(cfg *Conv3DConfig) {
cfg.Padding = []int64{val, val}
cfg.Padding = []int64{val, val, val}
}
}
// WithDilation3D adds dilation 3D option.
func WithDilation3D(val int64) Conv3DConfigOpt {
return func(cfg *Conv3DConfig) {
cfg.Dilation = []int64{val, val}
cfg.Dilation = []int64{val, val, val}
}
}
@ -255,9 +255,9 @@ func WithBsInit3D(val Init) Conv3DConfigOpt {
// DefaultConvConfig3D creates a default 3D ConvConfig
func DefaultConv3DConfig() *Conv3DConfig {
return &Conv3DConfig{
Stride: []int64{1, 1},
Padding: []int64{0, 0},
Dilation: []int64{1, 1},
Stride: []int64{1, 1, 1},
Padding: []int64{0, 0, 0},
Dilation: []int64{1, 1, 1},
Groups: 1,
Bias: true,
WsInit: NewKaimingUniformInit(),