added conv3dconfig
This commit is contained in:
parent
6c38d54cec
commit
620fccf452
|
@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
|||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
- Added Conv3DConfig and Conv3DConfig Option
|
||||
|
||||
## [Nofix]
|
||||
- ctype `long` caused compiling error in MacOS as noted on [#44]. Not working on linux box.
|
||||
|
|
75
nn/conv.go
75
nn/conv.go
|
@ -200,6 +200,81 @@ type Conv3DConfig struct {
|
|||
BsInit Init
|
||||
}
|
||||
|
||||
// Conv3DConfigOpt is option type for Conv3DConfig.
|
||||
type Conv3DConfigOpt func(*Conv3DConfig)
|
||||
|
||||
// WithStride3D adds stride 3D option.
|
||||
func WithStride3D(val int64) Conv3DConfigOpt {
|
||||
return func(cfg *Conv3DConfig) {
|
||||
cfg.Stride = []int64{val, val}
|
||||
}
|
||||
}
|
||||
|
||||
// WithPadding3D adds padding 3D option.
|
||||
func WithPadding3D(val int64) Conv3DConfigOpt {
|
||||
return func(cfg *Conv3DConfig) {
|
||||
cfg.Padding = []int64{val, val}
|
||||
}
|
||||
}
|
||||
|
||||
// WithDilation3D adds dilation 3D option.
|
||||
func WithDilation3D(val int64) Conv3DConfigOpt {
|
||||
return func(cfg *Conv3DConfig) {
|
||||
cfg.Dilation = []int64{val, val}
|
||||
}
|
||||
}
|
||||
|
||||
// WithGroup3D adds group 3D option.
|
||||
func WithGroup3D(val int64) Conv3DConfigOpt {
|
||||
return func(cfg *Conv3DConfig) {
|
||||
cfg.Groups = val
|
||||
}
|
||||
}
|
||||
|
||||
// WithBias3D adds bias 3D option.
|
||||
func WithBias3D(val bool) Conv3DConfigOpt {
|
||||
return func(cfg *Conv3DConfig) {
|
||||
cfg.Bias = val
|
||||
}
|
||||
}
|
||||
|
||||
// WithWsInit3D adds WsInit 3D option.
|
||||
func WithWsInit3D(val Init) Conv3DConfigOpt {
|
||||
return func(cfg *Conv3DConfig) {
|
||||
cfg.WsInit = val
|
||||
}
|
||||
}
|
||||
|
||||
// WithBsInit adds BsInit 3D option.
|
||||
func WithBsInit3D(val Init) Conv3DConfigOpt {
|
||||
return func(cfg *Conv3DConfig) {
|
||||
cfg.BsInit = val
|
||||
}
|
||||
}
|
||||
|
||||
// DefaultConvConfig3D creates a default 3D ConvConfig
|
||||
func DefaultConv3DConfig() *Conv3DConfig {
|
||||
return &Conv3DConfig{
|
||||
Stride: []int64{1, 1},
|
||||
Padding: []int64{0, 0},
|
||||
Dilation: []int64{1, 1},
|
||||
Groups: 1,
|
||||
Bias: true,
|
||||
WsInit: NewKaimingUniformInit(),
|
||||
BsInit: NewConstInit(float64(0.0)),
|
||||
}
|
||||
}
|
||||
|
||||
// NewConv3DConfig creates Conv3DConfig.
|
||||
func NewConv3DConfig(opts ...Conv3DConfigOpt) *Conv3DConfig {
|
||||
cfg := DefaultConv3DConfig()
|
||||
for _, o := range opts {
|
||||
o(cfg)
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Conv1D is convolution 1D struct.
|
||||
type Conv1D struct {
|
||||
Ws *ts.Tensor
|
||||
|
|
Loading…
Reference in New Issue
Block a user