feat(wrapper/optimizer): added optimizer.go

This commit is contained in:
sugarme 2020-06-11 16:10:29 +10:00
parent 3de38ffa27
commit 6b092c129f
2 changed files with 215 additions and 0 deletions

View File

@ -15,6 +15,7 @@ import (
// NOTE: C.tensor is a C pointer to torch::Tensor
type Ctensor = C.tensor
type Cscalar = C.scalar
type Coptimizer = C.optimizer
type NamedCtensor struct {
Name string
@ -318,3 +319,98 @@ func AtGradSetEnabled(b int) int {
cretVal := C.at_grad_set_enabled(cbool)
return *(*int)(unsafe.Pointer(&cretVal))
}
/*
* optimizer ato_adam(double learning_rate,
* double beta1,
* double beta2,
* double weight_decay);
* */
func AtoAdam(learningRate, beta1, beta2, weightDecay float64) Coptimizer {
clearningRate := *(*C.double)(unsafe.Pointer(&learningRate))
cbeta1 := *(*C.double)(unsafe.Pointer(&beta1))
cbeta2 := *(*C.double)(unsafe.Pointer(&beta2))
cweightDecay := *(*C.double)(unsafe.Pointer(&weightDecay))
return C.ato_adam(clearningRate, cbeta1, cbeta2, cweightDecay)
}
/*
* optimizer ato_rms_prop(double learning_rate,
* double alpha,
* double eps,
* double weight_decay,
* double momentum,
* int centered);
* */
func AtoRmsProp(learningRate, alpha, eps, weightDecay, momentum float64, centered int) Coptimizer {
clearningRate := *(*C.double)(unsafe.Pointer(&learningRate))
calpha := *(*C.double)(unsafe.Pointer(&alpha))
ceps := *(*C.double)(unsafe.Pointer(&eps))
cweightDecay := *(*C.double)(unsafe.Pointer(&weightDecay))
cmomentum := *(*C.double)(unsafe.Pointer(&momentum))
ccentered := *(*C.int)(unsafe.Pointer(&centered))
return C.ato_rms_prop(clearningRate, calpha, ceps, cweightDecay, cmomentum, ccentered)
}
/*
* optimizer ato_sgd(double learning_rate,
* double momentum,
* double dampening,
* double weight_decay,
* int nesterov);
* */
func AtoSgd(learningRate, momentum, dampening, weightDecay float64, nesterov int) Coptimizer {
clearningRate := *(*C.double)(unsafe.Pointer(&learningRate))
cmomentum := *(*C.double)(unsafe.Pointer(&momentum))
cdampening := *(*C.double)(unsafe.Pointer(&dampening))
cweightDecay := *(*C.double)(unsafe.Pointer(&weightDecay))
cnesterov := *(*C.int)(unsafe.Pointer(&nesterov))
return C.ato_sgd(clearningRate, cmomentum, cdampening, cweightDecay, cnesterov)
}
// void ato_add_parameters(optimizer, tensor *, int ntensors);
func AtoAddParameters(coptimizer Coptimizer, tensors []Ctensor, ntensors int) {
var ctensors []C.tensor
for i := 0; i < len(tensors); i++ {
ctensors = append(ctensors, (C.tensor)(tensors[i]))
}
cntensors := *(*C.int)(unsafe.Pointer(&ntensors))
// Just give pointer to the first element of ctensors slice
C.ato_add_parameters(coptimizer, &ctensors[0], cntensors)
}
// void ato_set_learning_rate(optimizer, double learning_rate);
func AtoSetLearningRate(coptimizer Coptimizer, learningRate float64) {
clearningRate := *(*C.double)(unsafe.Pointer(&learningRate))
C.ato_set_learning_rate(coptimizer, clearningRate)
}
// void ato_set_momentum(optimizer, double momentum);
func AtoSetMomentum(coptimizer Coptimizer, momentum float64) {
cmomentum := *(*C.double)(unsafe.Pointer(&momentum))
C.ato_set_momentum(coptimizer, cmomentum)
}
// void ato_zero_grad(optimizer);
func AtoZeroGrad(coptimizer Coptimizer) {
C.ato_zero_grad(coptimizer)
}
// void ato_step(optimizer);
func AtoStep(coptimizer Coptimizer) {
C.ato_step(coptimizer)
}
// void ato_free(optimizer);
func AtoFree(coptimizer Coptimizer) {
C.ato_free(coptimizer)
}

119
wrapper/optimizer.go Normal file
View File

@ -0,0 +1,119 @@
package wrapper
import (
"log"
lib "github.com/sugarme/gotch/libtch"
)
type COptimizer struct {
coptimizer lib.Coptimizer
}
// Adam returns Adam optimizer
func Adam(lr, beta1, beta2, weightDecay float64) (retVal COptimizer, err error) {
coptimizer := lib.AtoAdam(lr, beta1, beta2, weightDecay)
err = TorchErr()
if err != nil {
return retVal, err
}
retVal = COptimizer{coptimizer}
return retVal, nil
}
// RmsProp returns RMSProp optimizer
func RmsProp(lr, alpha, eps, wd, momentum float64, centered bool) (retVal COptimizer, err error) {
var centeredCInt int
switch centered {
case true:
centeredCInt = 1
case false:
centeredCInt = 0
}
coptimizer := lib.AtoRmsProp(lr, alpha, eps, wd, momentum, centeredCInt)
err = TorchErr()
if err != nil {
return retVal, err
}
retVal = COptimizer{coptimizer}
return retVal, nil
}
// Sgd returns SGD optimizer
func Sgd(lr, momentum, dampening, wd float64, nesterov bool) (retVal COptimizer, err error) {
var nesterovCInt int
switch nesterov {
case true:
nesterovCInt = 1
case false:
nesterovCInt = 0
}
coptimizer := lib.AtoSgd(lr, momentum, dampening, wd, nesterovCInt)
err = TorchErr()
if err != nil {
return retVal, err
}
retVal = COptimizer{coptimizer}
return retVal, nil
}
// AddParameters adds parameters as a slice of tensors to optimizer
func (co COptimizer) AddParameters(tensors []Tensor) (err error) {
var ctensors []lib.Ctensor
for _, t := range tensors {
ctensors = append(ctensors, t.ctensor)
}
ntensors := len(tensors)
lib.AtoAddParameters(co.coptimizer, ctensors, ntensors)
return TorchErr()
}
// SetLeanringRate sets learning rate for the optimizer
func (co COptimizer) SetLearningRate(lr float64) (err error) {
lib.AtoSetLearningRate(co.coptimizer, lr)
return TorchErr()
}
// SetMomentum sets a momentum for the optimizer
func (co COptimizer) SetMomentum(m float64) (err error) {
lib.AtoSetMomentum(co.coptimizer, m)
return TorchErr()
}
// ZeroGrad sets gradients to zero
func (co COptimizer) ZeroGrad() (err error) {
lib.AtoZeroGrad(co.coptimizer)
return TorchErr()
}
// Steps proceeds optimizer
func (co COptimizer) Step() (err error) {
lib.AtoStep(co.coptimizer)
return TorchErr()
}
// Drop removes optimizer and frees up memory.
func (co COptimizer) Drop() {
lib.AtoFree(co.coptimizer)
if err := TorchErr(); err != nil {
log.Fatal(err)
}
}