added libtorch APIs GetLRs and ParamGroupNum

This commit is contained in:
sugarme 2021-06-04 23:07:59 +10:00
parent 51bff4f402
commit b02185df22
6 changed files with 171 additions and 1 deletion

example/scheduler/main.go (new file)
View File

@@ -0,0 +1,33 @@
package main

import (
	"fmt"
	"log"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	"github.com/sugarme/gotch/vision"
	// ts "github.com/sugarme/gotch/tensor"
)

func main() {
	// Build EfficientNet-B4 and load pretrained weights into the var store.
	vs := nn.NewVarStore(gotch.CPU)
	model := vision.EfficientNetB4(vs.Root(), 1000)
	if err := vs.Load("../../data/pretrained/efficientnet-b4.pt"); err != nil {
		log.Fatal(err)
	}

	// Build an Adam optimizer over the var store with an initial LR of 0.001.
	adamConfig := nn.DefaultAdamConfig()
	o, err := adamConfig.Build(vs, 0.001)
	if err != nil {
		log.Fatal(err)
	}

	// Exercise the new APIs added in this commit.
	ngroup := o.ParamGroupNum()
	lrs := o.GetLRs()

	fmt.Printf("Number of param groups: %v\n", ngroup)
	fmt.Printf("Learning rates: %+v\n", lrs)

	log.Print(model)
}
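
Not part of this commit: the new getters are groundwork for a learning-rate scheduler, and until one lands they can already drive a hand-rolled decay through the pre-existing SetLR. The sketch below reuses the *nn.Optimizer built in the example above and assumes SetLR applies the same rate to every param group.

// stepDecay cuts the learning rate by 10x every 10 epochs; call it once per
// epoch from the training loop. Only GetLRs and SetLR are used, both of which
// appear elsewhere in this diff.
func stepDecay(o *nn.Optimizer, epoch int) {
	if epoch > 0 && epoch%10 == 0 {
		lr := o.GetLRs()[0] // all groups share one rate under the assumption above
		o.SetLR(lr * 0.1)
	}
}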

View File

@@ -440,6 +440,31 @@ func AtoSetLearningRate(coptimizer Coptimizer, learningRate float64) {
C.ato_set_learning_rate(coptimizer, clearningRate)
}
// AtoGetLearningRates returns the current learning rate of every parameter
// group held by the optimizer.
func AtoGetLearningRates(coptimizer Coptimizer) []float64 {
	// Size the output buffer from the number of param groups: the C side
	// writes one double per group plus the group count.
	ngroup := int(C.ato_param_group_num(coptimizer))
	cLRsPtr := (*C.double)(C.malloc(C.size_t(ngroup) * C.size_t(unsafe.Sizeof(C.double(0)))))
	cngroup := (*C.int)(C.malloc(C.size_t(unsafe.Sizeof(C.int(0)))))
	defer C.free(unsafe.Pointer(cLRsPtr))
	defer C.free(unsafe.Pointer(cngroup))

	C.ato_get_learning_rates(coptimizer, cLRsPtr, cngroup)

	ngroup = int(*cngroup)
	lrs := make([]float64, ngroup)
	currPtr := cLRsPtr
	for i := 0; i < ngroup; i++ {
		lrs[i] = float64(*currPtr)
		// Advance by the size of a C double, not the size of a Go pointer.
		currPtr = (*C.double)(unsafe.Pointer(uintptr(unsafe.Pointer(currPtr)) + unsafe.Sizeof(*currPtr)))
	}

	return lrs
}
// AtoParamGroupNum returns the number of parameter groups in the optimizer.
func AtoParamGroupNum(coptimizer Coptimizer) int64 {
	cpgNum := C.ato_param_group_num(coptimizer)
	return int64(cpgNum)
}
// void ato_set_momentum(optimizer, double momentum);
func AtoSetMomentum(coptimizer Coptimizer, momentum float64) {
cmomentum := *(*C.double)(unsafe.Pointer(&momentum))

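Not part of this commit: on Go 1.17+ (an assumption; this commit predates that release), the manual pointer walk in AtoGetLearningRates above could be replaced by unsafe.Slice, roughly as follows.

// view the C buffer as []float64 of length ngroup, then copy it out so the
// result stays valid after the C memory is freed
tmp := unsafe.Slice((*float64)(unsafe.Pointer(cLRsPtr)), ngroup)
lrs := make([]float64, ngroup)
copy(lrs, tmp)
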
View File

@@ -590,6 +590,73 @@ void ato_set_learning_rate_group(optimizer t, size_t group, double learning_rate
)
}
// ============ set/get learning rates ==============================
// TT. added for learning rate scheduler
// lr scheduler APIs will be in PyTorch 1.9?
// Ref. https://github.com/pytorch/pytorch/issues/50577
// The concrete OptimizerOptions subclass depends on which optimizer was
// built, so each known type is tried via dynamic_cast and only the matching
// instantiation updates the per-group learning rates.
template <class T>
void set_lrs(optimizer t, std::vector<double> &learning_rates) {
  torch::optim::OptimizerOptions *d = &(t->defaults());
  if (dynamic_cast<T *>(d)) {
    for (std::size_t i = 0; i < t->param_groups().size(); i++) {
      auto &param_group = t->param_groups()[i];
      torch::optim::OptimizerOptions *opts = &(param_group.options());
      if (auto p = dynamic_cast<T *>(opts)) {
        p->lr(learning_rates[i]);
      } else {
        throw std::invalid_argument("unexpected param group type");
      }
    }
  }
}

// Takes a C-compatible double* (one value per param group) so the definition
// matches the declaration in the header.
void ato_set_learning_rates(optimizer t, double *learning_rates) {
  PROTECT(
    std::vector<double> lrs(learning_rates, learning_rates + t->param_groups().size());
    set_lrs<torch::optim::AdamOptions>(t, lrs);
    set_lrs<torch::optim::AdamWOptions>(t, lrs);
    set_lrs<torch::optim::RMSpropOptions>(t, lrs);
    set_lrs<torch::optim::SGDOptions>(t, lrs);
  )
}
template <class T> void get_lrs(optimizer t, std::vector<double> &lrs) {
  torch::optim::OptimizerOptions *d = &(t->defaults());
  if (dynamic_cast<T *>(d)) {
    for (std::size_t i = 0; i < t->param_groups().size(); i++) {
      auto &param_group = t->param_groups()[i];
      torch::optim::OptimizerOptions *opts = &(param_group.options());
      if (auto p = dynamic_cast<T *>(opts)) {
        lrs[i] = p->lr();
      } else {
        throw std::invalid_argument("unexpected param group type");
      }
    }
  }
}

// Fills `lrs` (caller-allocated, one double per param group) with the current
// learning rates and writes the group count to `param_group_num`.
void ato_get_learning_rates(optimizer t, double *lrs, int *param_group_num) {
  PROTECT(
    int ngroup = t->param_groups().size();
    // Not static: the group count can differ between optimizers and calls.
    std::vector<double> learning_rates(ngroup);
    get_lrs<torch::optim::AdamOptions>(t, learning_rates);
    get_lrs<torch::optim::AdamWOptions>(t, learning_rates);
    get_lrs<torch::optim::RMSpropOptions>(t, learning_rates);
    get_lrs<torch::optim::SGDOptions>(t, learning_rates);
    std::copy(learning_rates.begin(), learning_rates.end(), lrs);
    param_group_num[0] = ngroup;
  )
}
int64_t ato_param_group_num(optimizer t) {
PROTECT(
return t->param_groups().size();
)
return -1;
}
// ============ End of set/get learning rates ==============================
void ato_set_momentum(optimizer t, double momentum) {
PROTECT(
torch::optim::OptimizerOptions* d = &(t->defaults());

View File

@@ -130,6 +130,12 @@ void ato_zero_grad(optimizer);
void ato_step(optimizer);
void ato_free(optimizer);
// TT. APIs for learning rate scheduler
void ato_set_learning_rates(optimizer, double* learning_rates);
//double *ato_get_learning_rates(optimizer);
int64_t ato_param_group_num(optimizer);
void ato_get_learning_rates(optimizer, double *lrs, int *ngroup);
scalar ats_int(int64_t);
scalar ats_float(double);
int64_t ats_to_int(scalar);
@@ -197,7 +203,7 @@ void ati_free(ivalue);
#include "torch_api_generated.h"
#ifdef __cplusplus
};
}; // extern "C"
#endif
#endif

View File

@@ -276,6 +276,15 @@ func (opt *Optimizer) SetLR(lr float64) {
}
}
// GetLRs returns the current learning rate of every parameter group.
func (opt *Optimizer) GetLRs() []float64 {
	lrs, err := opt.opt.GetLearningRates()
	if err != nil {
		log.Fatalf("Optimizer - GetLRs method call error: %v\n", err)
	}
	return lrs
}
// SetMomentum sets the optimizer momentum.
func (opt *Optimizer) SetMomentum(m float64) {
err := opt.opt.SetMomentum(m)
@@ -283,3 +292,12 @@ func (opt *Optimizer) SetMomentum(m float64) {
log.Fatalf("Optimizer - SetMomentum method call error: %v\n", err)
}
}
// ParamGroupNum returns the number of parameter groups in the optimizer.
func (opt *Optimizer) ParamGroupNum() int64 {
	ngroup, err := opt.opt.ParamGroupNum()
	if err != nil {
		log.Fatalf("Optimizer - ParamGroupNum method call error: %v\n", err)
	}
	return ngroup
}

View File

@@ -80,6 +80,27 @@ func (co *COptimizer) SetLearningRate(lr float64) error {
return TorchErr()
}
// GetLearningRates gets the current learning rate of every parameter group in the optimizer.
func (co *COptimizer) GetLearningRates() ([]float64, error) {
	lrs := lib.AtoGetLearningRates(co.coptimizer)
	if err := TorchErr(); err != nil {
		return nil, err
	}
	return lrs, nil
}
// ParamGroupNum returns the number of parameter groups in the optimizer.
func (co *COptimizer) ParamGroupNum() (int64, error) {
	ngroup := lib.AtoParamGroupNum(co.coptimizer)
	if err := TorchErr(); err != nil {
		return -1, err
	}
	return ngroup, nil
}
// SetMomentum sets a momentum for the optimizer
func (co *COptimizer) SetMomentum(m float64) error {
lib.AtoSetMomentum(co.coptimizer, m)