feat(wrapper/optimizer): added optimizer.go
parent 3de38ffa27
commit 6b092c129f
@@ -15,6 +15,7 @@ import (
// NOTE: C.tensor is a C pointer to torch::Tensor
type Ctensor = C.tensor
type Cscalar = C.scalar
type Coptimizer = C.optimizer

type NamedCtensor struct {
    Name string
@@ -318,3 +319,98 @@ func AtGradSetEnabled(b int) int {
    cretVal := C.at_grad_set_enabled(cbool)
    return *(*int)(unsafe.Pointer(&cretVal))
}

/*
 * optimizer ato_adam(double learning_rate,
 *                    double beta1,
 *                    double beta2,
 *                    double weight_decay);
 */
func AtoAdam(learningRate, beta1, beta2, weightDecay float64) Coptimizer {
    clearningRate := *(*C.double)(unsafe.Pointer(&learningRate))
    cbeta1 := *(*C.double)(unsafe.Pointer(&beta1))
    cbeta2 := *(*C.double)(unsafe.Pointer(&beta2))
    cweightDecay := *(*C.double)(unsafe.Pointer(&weightDecay))

    return C.ato_adam(clearningRate, cbeta1, cbeta2, cweightDecay)
}
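
Aside: the *(*C.double)(unsafe.Pointer(&learningRate)) pattern reinterprets the float64's bits as a C.double in place. Since Go's float64 and C's double are both 64-bit IEEE 754 values on the platforms cgo supports, a plain conversion would behave identically; a simpler equivalent sketch (not what this commit uses):

    clearningRate := C.double(learningRate)
    cbeta1 := C.double(beta1)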

/*
 * optimizer ato_rms_prop(double learning_rate,
 *                        double alpha,
 *                        double eps,
 *                        double weight_decay,
 *                        double momentum,
 *                        int centered);
 */
func AtoRmsProp(learningRate, alpha, eps, weightDecay, momentum float64, centered int) Coptimizer {
    clearningRate := *(*C.double)(unsafe.Pointer(&learningRate))
    calpha := *(*C.double)(unsafe.Pointer(&alpha))
    ceps := *(*C.double)(unsafe.Pointer(&eps))
    cweightDecay := *(*C.double)(unsafe.Pointer(&weightDecay))
    cmomentum := *(*C.double)(unsafe.Pointer(&momentum))
    ccentered := *(*C.int)(unsafe.Pointer(&centered))

    return C.ato_rms_prop(clearningRate, calpha, ceps, cweightDecay, cmomentum, ccentered)
}
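
A direct call to this raw binding might look like the following sketch (the hyperparameter values are illustrative, not defaults mandated by the library); the trailing argument is the centered flag as a 0/1 int, which the higher-level RmsProp wrapper below derives from a bool:

    copt := AtoRmsProp(0.01, 0.99, 1e-8, 0.0, 0.0, 1)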

/*
 * optimizer ato_sgd(double learning_rate,
 *                   double momentum,
 *                   double dampening,
 *                   double weight_decay,
 *                   int nesterov);
 */
func AtoSgd(learningRate, momentum, dampening, weightDecay float64, nesterov int) Coptimizer {
    clearningRate := *(*C.double)(unsafe.Pointer(&learningRate))
    cmomentum := *(*C.double)(unsafe.Pointer(&momentum))
    cdampening := *(*C.double)(unsafe.Pointer(&dampening))
    cweightDecay := *(*C.double)(unsafe.Pointer(&weightDecay))
    cnesterov := *(*C.int)(unsafe.Pointer(&nesterov))

    return C.ato_sgd(clearningRate, cmomentum, cdampening, cweightDecay, cnesterov)
}

// void ato_add_parameters(optimizer, tensor *, int ntensors);
func AtoAddParameters(coptimizer Coptimizer, tensors []Ctensor, ntensors int) {
    var ctensors []C.tensor
    for i := 0; i < len(tensors); i++ {
        ctensors = append(ctensors, (C.tensor)(tensors[i]))
    }

    cntensors := *(*C.int)(unsafe.Pointer(&ntensors))

    // Just pass a pointer to the first element of the ctensors slice
    C.ato_add_parameters(coptimizer, &ctensors[0], cntensors)
}
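
Note that &ctensors[0] indexes into the slice, so calling AtoAddParameters with an empty tensors slice would panic. A defensive variant might guard first (a sketch, assuming an empty parameter list can simply be skipped):

    if len(ctensors) == 0 {
        return
    }
    C.ato_add_parameters(coptimizer, &ctensors[0], cntensors)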

// void ato_set_learning_rate(optimizer, double learning_rate);
func AtoSetLearningRate(coptimizer Coptimizer, learningRate float64) {
    clearningRate := *(*C.double)(unsafe.Pointer(&learningRate))
    C.ato_set_learning_rate(coptimizer, clearningRate)
}

// void ato_set_momentum(optimizer, double momentum);
func AtoSetMomentum(coptimizer Coptimizer, momentum float64) {
    cmomentum := *(*C.double)(unsafe.Pointer(&momentum))

    C.ato_set_momentum(coptimizer, cmomentum)
}

// void ato_zero_grad(optimizer);
func AtoZeroGrad(coptimizer Coptimizer) {
    C.ato_zero_grad(coptimizer)
}

// void ato_step(optimizer);
func AtoStep(coptimizer Coptimizer) {
    C.ato_step(coptimizer)
}

// void ato_free(optimizer);
func AtoFree(coptimizer Coptimizer) {
    C.ato_free(coptimizer)
}

wrapper/optimizer.go (new file, 119 lines)

@@ -0,0 +1,119 @@
package wrapper

import (
    "log"

    lib "github.com/sugarme/gotch/libtch"
)

type COptimizer struct {
    coptimizer lib.Coptimizer
}

// Adam returns an Adam optimizer.
func Adam(lr, beta1, beta2, weightDecay float64) (retVal COptimizer, err error) {
    coptimizer := lib.AtoAdam(lr, beta1, beta2, weightDecay)

    err = TorchErr()
    if err != nil {
        return retVal, err
    }

    retVal = COptimizer{coptimizer}
    return retVal, nil
}
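
A minimal usage sketch for this constructor (the beta and weight-decay values are illustrative, not defaults from this commit):

    opt, err := Adam(0.001, 0.9, 0.999, 0.0)
    if err != nil {
        log.Fatal(err)
    }
    defer opt.Drop()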

// RmsProp returns an RMSProp optimizer.
func RmsProp(lr, alpha, eps, wd, momentum float64, centered bool) (retVal COptimizer, err error) {
    var centeredCInt int
    switch centered {
    case true:
        centeredCInt = 1
    case false:
        centeredCInt = 0
    }

    coptimizer := lib.AtoRmsProp(lr, alpha, eps, wd, momentum, centeredCInt)
    err = TorchErr()
    if err != nil {
        return retVal, err
    }

    retVal = COptimizer{coptimizer}

    return retVal, nil
}
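
The switch over a bool covers both cases explicitly; an equivalent and more common Go idiom is a single if statement (a sketch with identical behavior):

    centeredCInt := 0
    if centered {
        centeredCInt = 1
    }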

// Sgd returns an SGD optimizer.
func Sgd(lr, momentum, dampening, wd float64, nesterov bool) (retVal COptimizer, err error) {
    var nesterovCInt int
    switch nesterov {
    case true:
        nesterovCInt = 1
    case false:
        nesterovCInt = 0
    }

    coptimizer := lib.AtoSgd(lr, momentum, dampening, wd, nesterovCInt)
    err = TorchErr()
    if err != nil {
        return retVal, err
    }

    retVal = COptimizer{coptimizer}

    return retVal, nil
}

// AddParameters adds a slice of tensors as parameters to the optimizer.
func (co COptimizer) AddParameters(tensors []Tensor) (err error) {
    var ctensors []lib.Ctensor
    for _, t := range tensors {
        ctensors = append(ctensors, t.ctensor)
    }

    ntensors := len(tensors)

    lib.AtoAddParameters(co.coptimizer, ctensors, ntensors)

    return TorchErr()
}
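
For example (a sketch; w and b stand in for hypothetical Tensor values obtained elsewhere in the wrapper package):

    if err := opt.AddParameters([]Tensor{w, b}); err != nil {
        log.Fatal(err)
    }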

// SetLearningRate sets the learning rate for the optimizer.
func (co COptimizer) SetLearningRate(lr float64) (err error) {
    lib.AtoSetLearningRate(co.coptimizer, lr)

    return TorchErr()
}
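
SetLearningRate makes simple schedules possible from the Go side; for instance, a step decay every 10 epochs (an illustrative sketch, not part of this commit):

    lr := 0.1
    for epoch := 1; epoch <= 30; epoch++ {
        // ... train one epoch ...
        if epoch%10 == 0 {
            lr *= 0.1
            if err := opt.SetLearningRate(lr); err != nil {
                log.Fatal(err)
            }
        }
    }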

// SetMomentum sets the momentum for the optimizer.
func (co COptimizer) SetMomentum(m float64) (err error) {
    lib.AtoSetMomentum(co.coptimizer, m)

    return TorchErr()
}

// ZeroGrad sets the gradients of all parameters to zero.
func (co COptimizer) ZeroGrad() (err error) {
    lib.AtoZeroGrad(co.coptimizer)

    return TorchErr()
}

// Step performs a single optimization step, updating the parameters.
func (co COptimizer) Step() (err error) {
    lib.AtoStep(co.coptimizer)

    return TorchErr()
}

// Drop removes the optimizer and frees its memory.
func (co COptimizer) Drop() {
    lib.AtoFree(co.coptimizer)

    if err := TorchErr(); err != nil {
        log.Fatal(err)
    }
}
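
Putting the pieces together, a single training step would look roughly like this (a sketch; params is a hypothetical []Tensor, and the backward pass on the loss tensor comes from the tensor API rather than this file):

    opt, err := Sgd(0.01, 0.9, 0.0, 0.0, true)
    if err != nil {
        log.Fatal(err)
    }
    defer opt.Drop()

    if err := opt.AddParameters(params); err != nil {
        log.Fatal(err)
    }

    if err := opt.ZeroGrad(); err != nil {
        log.Fatal(err)
    }
    // ... forward pass; then compute gradients on the loss (assumed tensor API) ...
    if err := opt.Step(); err != nil {
        log.Fatal(err)
    }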