added libtorch APIs GetLRs and ParamGroupNum
This commit is contained in:
parent
51bff4f402
commit
b02185df22
33
example/scheduler/main.go
Normal file
33
example/scheduler/main.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/nn"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
// ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
vs := nn.NewVarStore(gotch.CPU)
|
||||
model := vision.EfficientNetB4(vs.Root(), 1000)
|
||||
vs.Load("../../data/pretrained/efficientnet-b4.pt")
|
||||
|
||||
adamConfig := nn.DefaultAdamConfig()
|
||||
o, err := adamConfig.Build(vs, 0.001)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
ngroup := o.ParamGroupNum()
|
||||
lrs := o.GetLRs()
|
||||
|
||||
fmt.Printf("Number of param groups: %v\n", ngroup)
|
||||
fmt.Printf("Learning rates: %+v\n", lrs)
|
||||
|
||||
log.Print(model)
|
||||
|
||||
}
|
|
@ -440,6 +440,31 @@ func AtoSetLearningRate(coptimizer Coptimizer, learningRate float64) {
|
|||
C.ato_set_learning_rate(coptimizer, clearningRate)
|
||||
}
|
||||
|
||||
func AtoGetLearningRates(coptimizer Coptimizer) []float64 {
|
||||
cLRsPtr := (*C.double)(unsafe.Pointer(C.malloc(0)))
|
||||
cngroup := (*C.int)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
C.ato_get_learning_rates(coptimizer, cLRsPtr, cngroup)
|
||||
ngroup := *(*int)(unsafe.Pointer(cngroup))
|
||||
|
||||
var lrs []float64 = make([]float64, ngroup)
|
||||
var currPtr *C.double = cLRsPtr
|
||||
for i := 0; i < ngroup; i++ {
|
||||
lrs[i] = *(*float64)(unsafe.Pointer(currPtr))
|
||||
nextPtr := (*C.double)(unsafe.Pointer(uintptr(unsafe.Pointer(currPtr)) + unsafe.Sizeof(currPtr)))
|
||||
currPtr = nextPtr
|
||||
}
|
||||
|
||||
return lrs
|
||||
}
|
||||
|
||||
func AtoParamGroupNum(coptimizer Coptimizer) int64 {
|
||||
cpgNum := C.ato_param_group_num(coptimizer)
|
||||
|
||||
pgNum := *(*int64)(unsafe.Pointer(&cpgNum))
|
||||
return pgNum
|
||||
}
|
||||
|
||||
// void ato_set_momentum(optimizer, double momentum);
|
||||
func AtoSetMomentum(coptimizer Coptimizer, momentum float64) {
|
||||
cmomentum := *(*C.double)(unsafe.Pointer(&momentum))
|
||||
|
|
|
@ -590,6 +590,73 @@ void ato_set_learning_rate_group(optimizer t, size_t group, double learning_rate
|
|||
)
|
||||
}
|
||||
|
||||
// ============ set/get learning rates ==============================
|
||||
// TT. added for learning rate scheduler
|
||||
// lr scheduler APIs will be in Pytorch 1.9?
|
||||
// Ref. https://github.com/pytorch/pytorch/issues/50577
|
||||
template <class T>
|
||||
void set_lrs(optimizer t, std::vector<double> &learning_rates) {
|
||||
torch::optim::OptimizerOptions *d = &(t->defaults());
|
||||
if (auto p = dynamic_cast<T *>(d)) {
|
||||
for (std::size_t i = 0; i < t->param_groups().size(); i++) {
|
||||
auto ¶m_group = t->param_groups()[i];
|
||||
torch::optim::OptimizerOptions *d = &(param_group.options());
|
||||
if (auto p2 = dynamic_cast<T *>(d)) {
|
||||
p2->lr(learning_rates[i]);
|
||||
} else
|
||||
throw std::invalid_argument("unexpected param group type");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ato_set_learning_rates(optimizer t, std::vector<double> &learning_rates) {
|
||||
PROTECT(
|
||||
set_lrs<torch::optim::AdamOptions>(t, learning_rates);
|
||||
set_lrs<torch::optim::AdamWOptions>(t, learning_rates);
|
||||
set_lrs<torch::optim::RMSpropOptions>(t, learning_rates);
|
||||
set_lrs<torch::optim::SGDOptions>(t, learning_rates);
|
||||
)
|
||||
}
|
||||
|
||||
template <class T> void get_lrs(optimizer t, vector<double> &lrs) {
|
||||
torch::optim::OptimizerOptions *d = &(t->defaults());
|
||||
if (auto p = dynamic_cast<T *>(d)) {
|
||||
for (std::size_t i = 0; i < t->param_groups().size(); i++) {
|
||||
auto ¶m_group = t->param_groups()[i];
|
||||
torch::optim::OptimizerOptions *d = &(param_group.options());
|
||||
if (auto p2 = dynamic_cast<T *>(d)) {
|
||||
lrs[i] = p2->lr();
|
||||
} else
|
||||
throw std::invalid_argument("unexpected param group type");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void ato_get_learning_rates(optimizer t, double *lrs, int *param_group_num) {
|
||||
PROTECT(
|
||||
int ngroup = t->param_groups().size();
|
||||
static vector<double> learning_rates(ngroup);
|
||||
get_lrs<torch::optim::AdamOptions>(t, learning_rates);
|
||||
get_lrs<torch::optim::AdamWOptions>(t, learning_rates);
|
||||
get_lrs<torch::optim::RMSpropOptions>(t, learning_rates);
|
||||
get_lrs<torch::optim::SGDOptions>(t, learning_rates);
|
||||
|
||||
copy(learning_rates.begin(), learning_rates.end(), lrs);
|
||||
param_group_num[0] = ngroup;
|
||||
)
|
||||
}
|
||||
|
||||
int64_t ato_param_group_num(optimizer t) {
|
||||
PROTECT(
|
||||
return t->param_groups().size();
|
||||
)
|
||||
|
||||
return -1;
|
||||
}
|
||||
|
||||
|
||||
// ============ End of set/get learning rates ==============================
|
||||
|
||||
void ato_set_momentum(optimizer t, double momentum) {
|
||||
PROTECT(
|
||||
torch::optim::OptimizerOptions* d = &(t->defaults());
|
||||
|
|
|
@ -130,6 +130,12 @@ void ato_zero_grad(optimizer);
|
|||
void ato_step(optimizer);
|
||||
void ato_free(optimizer);
|
||||
|
||||
// TT. APIs for learning rate scheduler
|
||||
void ato_set_learning_rates(optimizer, double* learning_rates);
|
||||
//double *ato_get_learning_rates(optimizer);
|
||||
int64_t ato_param_group_num(optimizer);
|
||||
void ato_get_learning_rates(optimizer, double *lrs, int *ngroup);
|
||||
|
||||
scalar ats_int(int64_t);
|
||||
scalar ats_float(double);
|
||||
int64_t ats_to_int(scalar);
|
||||
|
@ -197,7 +203,7 @@ void ati_free(ivalue);
|
|||
#include "torch_api_generated.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
};
|
||||
}; // extern "C"
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
|
|
@ -276,6 +276,15 @@ func (opt *Optimizer) SetLR(lr float64) {
|
|||
}
|
||||
}
|
||||
|
||||
func (opt *Optimizer) GetLRs() []float64 {
|
||||
lrs, err := opt.opt.GetLearningRates()
|
||||
if err != nil {
|
||||
log.Fatalf("Optimizer - GetLRs method call error: %v\n", err)
|
||||
}
|
||||
|
||||
return lrs
|
||||
}
|
||||
|
||||
// SetMomentum sets the optimizer momentum.
|
||||
func (opt *Optimizer) SetMomentum(m float64) {
|
||||
err := opt.opt.SetMomentum(m)
|
||||
|
@ -283,3 +292,12 @@ func (opt *Optimizer) SetMomentum(m float64) {
|
|||
log.Fatalf("Optimizer - SetMomentum method call error: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (opt *Optimizer) ParamGroupNum() int64 {
|
||||
ngroup, err := opt.opt.ParamGroupNum()
|
||||
if err != nil {
|
||||
log.Fatalf("Optimizer - ParamGroupNum method call error: %v\n", err)
|
||||
}
|
||||
|
||||
return ngroup
|
||||
}
|
||||
|
|
|
@ -80,6 +80,27 @@ func (co *COptimizer) SetLearningRate(lr float64) error {
|
|||
return TorchErr()
|
||||
}
|
||||
|
||||
// GetLeanringRates get learning rates for the optimizer
|
||||
func (co *COptimizer) GetLearningRates() ([]float64, error) {
|
||||
lrs := lib.AtoGetLearningRates(co.coptimizer)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return lrs, nil
|
||||
}
|
||||
|
||||
func (co *COptimizer) ParamGroupNum() (int64, error) {
|
||||
ngroup := lib.AtoParamGroupNum(co.coptimizer)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
return ngroup, nil
|
||||
}
|
||||
|
||||
// SetMomentum sets a momentum for the optimizer
|
||||
func (co *COptimizer) SetMomentum(m float64) error {
|
||||
lib.AtoSetMomentum(co.coptimizer, m)
|
||||
|
|
Loading…
Reference in New Issue
Block a user