added libtorch APIs GetLRs and ParamGroupNum
This commit is contained in:
parent
51bff4f402
commit
b02185df22
33
example/scheduler/main.go
Normal file
33
example/scheduler/main.go
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
"github.com/sugarme/gotch/nn"
|
||||||
|
"github.com/sugarme/gotch/vision"
|
||||||
|
// ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
|
||||||
|
vs := nn.NewVarStore(gotch.CPU)
|
||||||
|
model := vision.EfficientNetB4(vs.Root(), 1000)
|
||||||
|
vs.Load("../../data/pretrained/efficientnet-b4.pt")
|
||||||
|
|
||||||
|
adamConfig := nn.DefaultAdamConfig()
|
||||||
|
o, err := adamConfig.Build(vs, 0.001)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ngroup := o.ParamGroupNum()
|
||||||
|
lrs := o.GetLRs()
|
||||||
|
|
||||||
|
fmt.Printf("Number of param groups: %v\n", ngroup)
|
||||||
|
fmt.Printf("Learning rates: %+v\n", lrs)
|
||||||
|
|
||||||
|
log.Print(model)
|
||||||
|
|
||||||
|
}
|
|
@ -440,6 +440,31 @@ func AtoSetLearningRate(coptimizer Coptimizer, learningRate float64) {
|
||||||
C.ato_set_learning_rate(coptimizer, clearningRate)
|
C.ato_set_learning_rate(coptimizer, clearningRate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AtoGetLearningRates(coptimizer Coptimizer) []float64 {
|
||||||
|
cLRsPtr := (*C.double)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
cngroup := (*C.int)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
|
||||||
|
C.ato_get_learning_rates(coptimizer, cLRsPtr, cngroup)
|
||||||
|
ngroup := *(*int)(unsafe.Pointer(cngroup))
|
||||||
|
|
||||||
|
var lrs []float64 = make([]float64, ngroup)
|
||||||
|
var currPtr *C.double = cLRsPtr
|
||||||
|
for i := 0; i < ngroup; i++ {
|
||||||
|
lrs[i] = *(*float64)(unsafe.Pointer(currPtr))
|
||||||
|
nextPtr := (*C.double)(unsafe.Pointer(uintptr(unsafe.Pointer(currPtr)) + unsafe.Sizeof(currPtr)))
|
||||||
|
currPtr = nextPtr
|
||||||
|
}
|
||||||
|
|
||||||
|
return lrs
|
||||||
|
}
|
||||||
|
|
||||||
|
func AtoParamGroupNum(coptimizer Coptimizer) int64 {
|
||||||
|
cpgNum := C.ato_param_group_num(coptimizer)
|
||||||
|
|
||||||
|
pgNum := *(*int64)(unsafe.Pointer(&cpgNum))
|
||||||
|
return pgNum
|
||||||
|
}
|
||||||
|
|
||||||
// void ato_set_momentum(optimizer, double momentum);
|
// void ato_set_momentum(optimizer, double momentum);
|
||||||
func AtoSetMomentum(coptimizer Coptimizer, momentum float64) {
|
func AtoSetMomentum(coptimizer Coptimizer, momentum float64) {
|
||||||
cmomentum := *(*C.double)(unsafe.Pointer(&momentum))
|
cmomentum := *(*C.double)(unsafe.Pointer(&momentum))
|
||||||
|
|
|
@ -590,6 +590,73 @@ void ato_set_learning_rate_group(optimizer t, size_t group, double learning_rate
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============ set/get learning rates ==============================
|
||||||
|
// TT. added for learning rate scheduler
|
||||||
|
// lr scheduler APIs will be in Pytorch 1.9?
|
||||||
|
// Ref. https://github.com/pytorch/pytorch/issues/50577
|
||||||
|
template <class T>
|
||||||
|
void set_lrs(optimizer t, std::vector<double> &learning_rates) {
|
||||||
|
torch::optim::OptimizerOptions *d = &(t->defaults());
|
||||||
|
if (auto p = dynamic_cast<T *>(d)) {
|
||||||
|
for (std::size_t i = 0; i < t->param_groups().size(); i++) {
|
||||||
|
auto ¶m_group = t->param_groups()[i];
|
||||||
|
torch::optim::OptimizerOptions *d = &(param_group.options());
|
||||||
|
if (auto p2 = dynamic_cast<T *>(d)) {
|
||||||
|
p2->lr(learning_rates[i]);
|
||||||
|
} else
|
||||||
|
throw std::invalid_argument("unexpected param group type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ato_set_learning_rates(optimizer t, std::vector<double> &learning_rates) {
|
||||||
|
PROTECT(
|
||||||
|
set_lrs<torch::optim::AdamOptions>(t, learning_rates);
|
||||||
|
set_lrs<torch::optim::AdamWOptions>(t, learning_rates);
|
||||||
|
set_lrs<torch::optim::RMSpropOptions>(t, learning_rates);
|
||||||
|
set_lrs<torch::optim::SGDOptions>(t, learning_rates);
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class T> void get_lrs(optimizer t, vector<double> &lrs) {
|
||||||
|
torch::optim::OptimizerOptions *d = &(t->defaults());
|
||||||
|
if (auto p = dynamic_cast<T *>(d)) {
|
||||||
|
for (std::size_t i = 0; i < t->param_groups().size(); i++) {
|
||||||
|
auto ¶m_group = t->param_groups()[i];
|
||||||
|
torch::optim::OptimizerOptions *d = &(param_group.options());
|
||||||
|
if (auto p2 = dynamic_cast<T *>(d)) {
|
||||||
|
lrs[i] = p2->lr();
|
||||||
|
} else
|
||||||
|
throw std::invalid_argument("unexpected param group type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ato_get_learning_rates(optimizer t, double *lrs, int *param_group_num) {
|
||||||
|
PROTECT(
|
||||||
|
int ngroup = t->param_groups().size();
|
||||||
|
static vector<double> learning_rates(ngroup);
|
||||||
|
get_lrs<torch::optim::AdamOptions>(t, learning_rates);
|
||||||
|
get_lrs<torch::optim::AdamWOptions>(t, learning_rates);
|
||||||
|
get_lrs<torch::optim::RMSpropOptions>(t, learning_rates);
|
||||||
|
get_lrs<torch::optim::SGDOptions>(t, learning_rates);
|
||||||
|
|
||||||
|
copy(learning_rates.begin(), learning_rates.end(), lrs);
|
||||||
|
param_group_num[0] = ngroup;
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t ato_param_group_num(optimizer t) {
|
||||||
|
PROTECT(
|
||||||
|
return t->param_groups().size();
|
||||||
|
)
|
||||||
|
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// ============ End of set/get learning rates ==============================
|
||||||
|
|
||||||
void ato_set_momentum(optimizer t, double momentum) {
|
void ato_set_momentum(optimizer t, double momentum) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
torch::optim::OptimizerOptions* d = &(t->defaults());
|
torch::optim::OptimizerOptions* d = &(t->defaults());
|
||||||
|
|
|
@ -130,6 +130,12 @@ void ato_zero_grad(optimizer);
|
||||||
void ato_step(optimizer);
|
void ato_step(optimizer);
|
||||||
void ato_free(optimizer);
|
void ato_free(optimizer);
|
||||||
|
|
||||||
|
// TT. APIs for learning rate scheduler
|
||||||
|
void ato_set_learning_rates(optimizer, double* learning_rates);
|
||||||
|
//double *ato_get_learning_rates(optimizer);
|
||||||
|
int64_t ato_param_group_num(optimizer);
|
||||||
|
void ato_get_learning_rates(optimizer, double *lrs, int *ngroup);
|
||||||
|
|
||||||
scalar ats_int(int64_t);
|
scalar ats_int(int64_t);
|
||||||
scalar ats_float(double);
|
scalar ats_float(double);
|
||||||
int64_t ats_to_int(scalar);
|
int64_t ats_to_int(scalar);
|
||||||
|
@ -197,7 +203,7 @@ void ati_free(ivalue);
|
||||||
#include "torch_api_generated.h"
|
#include "torch_api_generated.h"
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
};
|
}; // extern "C"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -276,6 +276,15 @@ func (opt *Optimizer) SetLR(lr float64) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (opt *Optimizer) GetLRs() []float64 {
|
||||||
|
lrs, err := opt.opt.GetLearningRates()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Optimizer - GetLRs method call error: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return lrs
|
||||||
|
}
|
||||||
|
|
||||||
// SetMomentum sets the optimizer momentum.
|
// SetMomentum sets the optimizer momentum.
|
||||||
func (opt *Optimizer) SetMomentum(m float64) {
|
func (opt *Optimizer) SetMomentum(m float64) {
|
||||||
err := opt.opt.SetMomentum(m)
|
err := opt.opt.SetMomentum(m)
|
||||||
|
@ -283,3 +292,12 @@ func (opt *Optimizer) SetMomentum(m float64) {
|
||||||
log.Fatalf("Optimizer - SetMomentum method call error: %v\n", err)
|
log.Fatalf("Optimizer - SetMomentum method call error: %v\n", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (opt *Optimizer) ParamGroupNum() int64 {
|
||||||
|
ngroup, err := opt.opt.ParamGroupNum()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Optimizer - ParamGroupNum method call error: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ngroup
|
||||||
|
}
|
||||||
|
|
|
@ -80,6 +80,27 @@ func (co *COptimizer) SetLearningRate(lr float64) error {
|
||||||
return TorchErr()
|
return TorchErr()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetLeanringRates get learning rates for the optimizer
|
||||||
|
func (co *COptimizer) GetLearningRates() ([]float64, error) {
|
||||||
|
lrs := lib.AtoGetLearningRates(co.coptimizer)
|
||||||
|
|
||||||
|
if err := TorchErr(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return lrs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (co *COptimizer) ParamGroupNum() (int64, error) {
|
||||||
|
ngroup := lib.AtoParamGroupNum(co.coptimizer)
|
||||||
|
|
||||||
|
if err := TorchErr(); err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ngroup, nil
|
||||||
|
}
|
||||||
|
|
||||||
// SetMomentum sets a momentum for the optimizer
|
// SetMomentum sets a momentum for the optimizer
|
||||||
func (co *COptimizer) SetMomentum(m float64) error {
|
func (co *COptimizer) SetMomentum(m float64) error {
|
||||||
lib.AtoSetMomentum(co.coptimizer, m)
|
lib.AtoSetMomentum(co.coptimizer, m)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user