package ts

import (
	"log"

	lib "git.andr3h3nriqu3s.com/andr3/gotch/libtch"
)

// COptimizer wraps a libtorch C optimizer handle.
type COptimizer struct {
	coptimizer lib.Coptimizer
}

// Adam returns an Adam optimizer.
func Adam(lr, beta1, beta2, weightDecay float64) (*COptimizer, error) {
	coptimizer := lib.AtoAdam(lr, beta1, beta2, weightDecay)

	if err := TorchErr(); err != nil {
		return nil, err
	}

	return &COptimizer{coptimizer}, nil
}
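
// A minimal usage sketch for the constructors in this file; `params` stands in
// for a hypothetical slice of trainable *Tensor values obtained elsewhere:
//
//	opt, err := Adam(1e-3, 0.9, 0.999, 0.0)
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer opt.Drop()
//	if err := opt.AddParameters(params); err != nil {
//		log.Fatal(err)
//	}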

// AdamW returns an AdamW optimizer.
func AdamW(lr, beta1, beta2, weightDecay float64) (*COptimizer, error) {
	coptimizer := lib.AtoAdamW(lr, beta1, beta2, weightDecay)

	if err := TorchErr(); err != nil {
		return nil, err
	}

	return &COptimizer{coptimizer}, nil
}

// RmsProp returns an RMSProp optimizer.
func RmsProp(lr, alpha, eps, wd, momentum float64, centered bool) (*COptimizer, error) {
	// libtorch expects the flag as a C int.
	centeredCInt := 0
	if centered {
		centeredCInt = 1
	}

	coptimizer := lib.AtoRmsProp(lr, alpha, eps, wd, momentum, centeredCInt)
	if err := TorchErr(); err != nil {
		return nil, err
	}

	return &COptimizer{coptimizer}, nil
}

// Sgd returns an SGD optimizer.
func Sgd(lr, momentum, dampening, wd float64, nesterov bool) (*COptimizer, error) {
	// libtorch expects the flag as a C int.
	nesterovCInt := 0
	if nesterov {
		nesterovCInt = 1
	}

	coptimizer := lib.AtoSgd(lr, momentum, dampening, wd, nesterovCInt)
	if err := TorchErr(); err != nil {
		return nil, err
	}

	return &COptimizer{coptimizer}, nil
}
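
// A sketch of the two flag-taking constructors above; the boolean flags are
// forwarded to libtorch as C ints (1/0), and all hyperparameter values here
// are illustrative only:
//
//	// Centered RMSProp with a small weight decay.
//	rms, err := RmsProp(1e-3, 0.99, 1e-8, 1e-4, 0.9, true)
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer rms.Drop()
//
//	// SGD with Nesterov momentum; note that libtorch generally requires
//	// zero dampening when nesterov is true.
//	sgd, err := Sgd(1e-2, 0.9, 0.0, 1e-4, true)
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer sgd.Drop()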

// AddParameters adds parameters, as a slice of tensors, to the optimizer.
func (co *COptimizer) AddParameters(tensors []*Tensor) error {
	var ctensors []lib.Ctensor
	for _, t := range tensors {
		ctensors = append(ctensors, t.ctensor)
	}

	ntensors := len(tensors)

	// NOTE: temporarily switched back to the old API, as parameter-group
	// support is not updated yet.
	lib.AtoAddParametersOld(co.coptimizer, ctensors, ntensors)

	return TorchErr()
}

// AddParameter adds a single parameter to the given parameter group.
func (co *COptimizer) AddParameter(param *Tensor, group uint) error {
	lib.AtoAddParameter(co.coptimizer, param.ctensor, group)

	return TorchErr()
}

// SetLearningRate sets the learning rate for the optimizer.
func (co *COptimizer) SetLearningRate(lr float64) error {
	lib.AtoSetLearningRate(co.coptimizer, lr)

	return TorchErr()
}

// GetLearningRates gets the learning rates of the optimizer.
func (co *COptimizer) GetLearningRates() ([]float64, error) {
	lrs := lib.AtoGetLearningRates(co.coptimizer)

	if err := TorchErr(); err != nil {
		return nil, err
	}

	return lrs, nil
}

// SetLearningRates sets the learning rates for the optimizer.
func (co *COptimizer) SetLearningRates(lrs []float64) error {
	lib.AtoSetLearningRates(co.coptimizer, lrs)

	return TorchErr()
}
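
// A sketch of simple learning-rate decay built on GetLearningRates and
// SetLearningRates; the decay factor is illustrative:
//
//	lrs, err := opt.GetLearningRates()
//	if err != nil {
//		log.Fatal(err)
//	}
//	for i := range lrs {
//		lrs[i] *= 0.1 // decay each group's learning rate tenfold
//	}
//	if err := opt.SetLearningRates(lrs); err != nil {
//		log.Fatal(err)
//	}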

// ParamGroupNum returns the number of parameter groups in the optimizer.
func (co *COptimizer) ParamGroupNum() (int64, error) {
	ngroup := lib.AtoParamGroupNum(co.coptimizer)

	if err := TorchErr(); err != nil {
		return -1, err
	}

	return ngroup, nil
}

// AddParamGroup adds a new parameter group holding the given tensors to the optimizer.
func (co *COptimizer) AddParamGroup(tensors []*Tensor) error {
	var ctensors []lib.Ctensor
	for _, t := range tensors {
		ctensors = append(ctensors, t.ctensor)
	}

	ntensors := len(tensors)

	lib.AtoAddParamGroup(co.coptimizer, ctensors, ntensors)

	return TorchErr()
}
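
// A sketch of working with parameter groups via the methods above. `headParams`
// stands in for a hypothetical slice of *Tensor (e.g. a model head) that should
// train at its own learning rate; it assumes a newly added group is appended
// last, as is typical in libtorch:
//
//	if err := opt.AddParamGroup(headParams); err != nil {
//		log.Fatal(err)
//	}
//	ngroup, err := opt.ParamGroupNum()
//	if err != nil {
//		log.Fatal(err)
//	}
//	lrs, err := opt.GetLearningRates()
//	if err != nil {
//		log.Fatal(err)
//	}
//	lrs[ngroup-1] = 1e-2 // give the new (last) group its own learning rate
//	if err := opt.SetLearningRates(lrs); err != nil {
//		log.Fatal(err)
//	}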

// SetMomentum sets the momentum for the optimizer.
func (co *COptimizer) SetMomentum(m float64) error {
	lib.AtoSetMomentum(co.coptimizer, m)

	return TorchErr()
}

// ZeroGrad sets all gradients to zero.
func (co *COptimizer) ZeroGrad() error {
	lib.AtoZeroGrad(co.coptimizer)

	return TorchErr()
}

// Step performs a single optimization step, updating the parameters.
func (co *COptimizer) Step() error {
	lib.AtoStep(co.coptimizer)

	return TorchErr()
}
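
// A typical training-step sketch built on ZeroGrad and Step; `loss` stands in
// for a hypothetical scalar *Tensor, and MustBackward is assumed to be the
// backward-pass helper defined elsewhere in this package:
//
//	if err := opt.ZeroGrad(); err != nil {
//		log.Fatal(err)
//	}
//	loss.MustBackward()
//	if err := opt.Step(); err != nil {
//		log.Fatal(err)
//	}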

// Drop removes the optimizer and frees up its memory.
func (co *COptimizer) Drop() {
	lib.AtoFree(co.coptimizer)

	if err := TorchErr(); err != nil {
		log.Fatal(err)
	}
}