added conv3dconfig

This commit is contained in:
sugarme 2021-08-03 12:36:10 +10:00
parent 6c38d54cec
commit 620fccf452
2 changed files with 76 additions and 0 deletions

View File

@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
- Added Conv3DConfig and Conv3DConfig Option
## [Nofix]
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.

View File

@ -200,6 +200,81 @@ type Conv3DConfig struct {
BsInit Init
}
// Conv3DConfigOpt is option type for Conv3DConfig.
type Conv3DConfigOpt func(*Conv3DConfig)
// WithStride3D adds stride 3D option.
func WithStride3D(val int64) Conv3DConfigOpt {
return func(cfg *Conv3DConfig) {
cfg.Stride = []int64{val, val}
}
}
// WithPadding3D adds padding 3D option.
func WithPadding3D(val int64) Conv3DConfigOpt {
return func(cfg *Conv3DConfig) {
cfg.Padding = []int64{val, val}
}
}
// WithDilation3D adds dilation 3D option.
func WithDilation3D(val int64) Conv3DConfigOpt {
return func(cfg *Conv3DConfig) {
cfg.Dilation = []int64{val, val}
}
}
// WithGroup3D adds group 3D option.
func WithGroup3D(val int64) Conv3DConfigOpt {
return func(cfg *Conv3DConfig) {
cfg.Groups = val
}
}
// WithBias3D adds bias 3D option.
func WithBias3D(val bool) Conv3DConfigOpt {
return func(cfg *Conv3DConfig) {
cfg.Bias = val
}
}
// WithWsInit3D adds WsInit 3D option.
func WithWsInit3D(val Init) Conv3DConfigOpt {
return func(cfg *Conv3DConfig) {
cfg.WsInit = val
}
}
// WithBsInit adds BsInit 3D option.
func WithBsInit3D(val Init) Conv3DConfigOpt {
return func(cfg *Conv3DConfig) {
cfg.BsInit = val
}
}
// DefaultConvConfig3D creates a default 3D ConvConfig
func DefaultConv3DConfig() *Conv3DConfig {
return &Conv3DConfig{
Stride: []int64{1, 1},
Padding: []int64{0, 0},
Dilation: []int64{1, 1},
Groups: 1,
Bias: true,
WsInit: NewKaimingUniformInit(),
BsInit: NewConstInit(float64(0.0)),
}
}
// NewConv3DConfig creates Conv3DConfig.
func NewConv3DConfig(opts ...Conv3DConfigOpt) *Conv3DConfig {
cfg := DefaultConv3DConfig()
for _, o := range opts {
o(cfg)
}
return cfg
}
// Conv1D is convolution 1D struct.
type Conv1D struct {
Ws *ts.Tensor