added AdamW optimizer

This commit is contained in:
sugarme 2021-06-12 18:40:42 +10:00
parent 50dd5b181c
commit 6f8ec3b69f
3 changed files with 56 additions and 0 deletions

View File

@ -386,6 +386,15 @@ func AtoAdam(learningRate, beta1, beta2, weightDecay float64) Coptimizer {
return C.ato_adam(clearningRate, cbeta1, cbeta2, cweightDecay)
}
func AtoAdamW(learningRate, beta1, beta2, weightDecay float64) Coptimizer {
clearningRate := *(*C.double)(unsafe.Pointer(&learningRate))
cbeta1 := *(*C.double)(unsafe.Pointer(&beta1))
cbeta2 := *(*C.double)(unsafe.Pointer(&beta2))
cweightDecay := *(*C.double)(unsafe.Pointer(&weightDecay))
return C.ato_adamw(clearningRate, cbeta1, cbeta2, cweightDecay)
}
/*
* optimizer ato_rms_prop(double learning_rate,
* double alpha,

View File

@ -136,6 +136,42 @@ func (c *AdamConfig) Build(vs *VarStore, lr float64) (*Optimizer, error) {
return defaultBuild(c, vs, lr)
}
// AdamW optimizer:
// ===============
type AdamWConfig struct {
Beta1 float64
Beta2 float64
Wd float64
}
// DefaultAdamConfig creates AdamConfig with default values
func DefaultAdamWConfig() *AdamConfig {
return &AdamConfig{
Beta1: 0.9,
Beta2: 0.999,
Wd: 0.0,
}
}
// NewAdamConfig creates AdamConfig with specified values
func NewAdamWConfig(beta1, beta2, wd float64) *AdamWConfig {
return &AdamWConfig{
Beta1: beta1,
Beta2: beta2,
Wd: wd,
}
}
// Implement OptimizerConfig interface for AdamConfig
func (c *AdamWConfig) buildCOpt(lr float64) (*ts.COptimizer, error) {
return ts.AdamW(lr, c.Beta1, c.Beta2, c.Wd)
}
func (c *AdamWConfig) Build(vs *VarStore, lr float64) (*Optimizer, error) {
return defaultBuild(c, vs, lr)
}
// RMSProp optimizer:
// ===============

View File

@ -21,6 +21,17 @@ func Adam(lr, beta1, beta2, weightDecay float64) (*COptimizer, error) {
return &COptimizer{coptimizer}, nil
}
// AdamW returns AdamW optimizer
func AdamW(lr, beta1, beta2, weightDecay float64) (*COptimizer, error) {
coptimizer := lib.AtoAdamW(lr, beta1, beta2, weightDecay)
if err := TorchErr(); err != nil {
return nil, err
}
return &COptimizer{coptimizer}, nil
}
// RmsProp returns RMSProp optimizer
func RmsProp(lr, alpha, eps, wd, momentum float64, centered bool) (*COptimizer, error) {
var centeredCInt int