added AdamW optimizer
This commit is contained in:
parent
50dd5b181c
commit
6f8ec3b69f
|
@ -386,6 +386,15 @@ func AtoAdam(learningRate, beta1, beta2, weightDecay float64) Coptimizer {
|
|||
return C.ato_adam(clearningRate, cbeta1, cbeta2, cweightDecay)
|
||||
}
|
||||
|
||||
func AtoAdamW(learningRate, beta1, beta2, weightDecay float64) Coptimizer {
|
||||
clearningRate := *(*C.double)(unsafe.Pointer(&learningRate))
|
||||
cbeta1 := *(*C.double)(unsafe.Pointer(&beta1))
|
||||
cbeta2 := *(*C.double)(unsafe.Pointer(&beta2))
|
||||
cweightDecay := *(*C.double)(unsafe.Pointer(&weightDecay))
|
||||
|
||||
return C.ato_adamw(clearningRate, cbeta1, cbeta2, cweightDecay)
|
||||
}
|
||||
|
||||
/*
|
||||
* optimizer ato_rms_prop(double learning_rate,
|
||||
* double alpha,
|
||||
|
|
|
@ -136,6 +136,42 @@ func (c *AdamConfig) Build(vs *VarStore, lr float64) (*Optimizer, error) {
|
|||
return defaultBuild(c, vs, lr)
|
||||
}
|
||||
|
||||
// AdamW optimizer:
|
||||
// ===============
|
||||
|
||||
type AdamWConfig struct {
|
||||
Beta1 float64
|
||||
Beta2 float64
|
||||
Wd float64
|
||||
}
|
||||
|
||||
// DefaultAdamConfig creates AdamConfig with default values
|
||||
func DefaultAdamWConfig() *AdamConfig {
|
||||
return &AdamConfig{
|
||||
Beta1: 0.9,
|
||||
Beta2: 0.999,
|
||||
Wd: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
// NewAdamConfig creates AdamConfig with specified values
|
||||
func NewAdamWConfig(beta1, beta2, wd float64) *AdamWConfig {
|
||||
return &AdamWConfig{
|
||||
Beta1: beta1,
|
||||
Beta2: beta2,
|
||||
Wd: wd,
|
||||
}
|
||||
}
|
||||
|
||||
// Implement OptimizerConfig interface for AdamConfig
|
||||
func (c *AdamWConfig) buildCOpt(lr float64) (*ts.COptimizer, error) {
|
||||
return ts.AdamW(lr, c.Beta1, c.Beta2, c.Wd)
|
||||
}
|
||||
|
||||
func (c *AdamWConfig) Build(vs *VarStore, lr float64) (*Optimizer, error) {
|
||||
return defaultBuild(c, vs, lr)
|
||||
}
|
||||
|
||||
// RMSProp optimizer:
|
||||
// ===============
|
||||
|
||||
|
|
|
@ -21,6 +21,17 @@ func Adam(lr, beta1, beta2, weightDecay float64) (*COptimizer, error) {
|
|||
return &COptimizer{coptimizer}, nil
|
||||
}
|
||||
|
||||
// AdamW returns AdamW optimizer
|
||||
func AdamW(lr, beta1, beta2, weightDecay float64) (*COptimizer, error) {
|
||||
coptimizer := lib.AtoAdamW(lr, beta1, beta2, weightDecay)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &COptimizer{coptimizer}, nil
|
||||
}
|
||||
|
||||
// RmsProp returns RMSProp optimizer
|
||||
func RmsProp(lr, alpha, eps, wd, momentum float64, centered bool) (*COptimizer, error) {
|
||||
var centeredCInt int
|
||||
|
|
Loading…
Reference in New Issue
Block a user