diff --git a/example/scheduler/main.go b/example/scheduler/main.go
new file mode 100644
index 0000000..66a16db
--- /dev/null
+++ b/example/scheduler/main.go
@@ -0,0 +1,32 @@
+package main
+
+import (
+	"fmt"
+	"log"
+
+	"github.com/sugarme/gotch"
+	"github.com/sugarme/gotch/nn"
+	"github.com/sugarme/gotch/vision"
+)
+
+func main() {
+	vs := nn.NewVarStore(gotch.CPU)
+	model := vision.EfficientNetB4(vs.Root(), 1000)
+	if err := vs.Load("../../data/pretrained/efficientnet-b4.pt"); err != nil {
+		log.Fatal(err)
+	}
+
+	adamConfig := nn.DefaultAdamConfig()
+	o, err := adamConfig.Build(vs, 0.001)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	ngroup := o.ParamGroupNum()
+	lrs := o.GetLRs()
+
+	fmt.Printf("Number of param groups: %v\n", ngroup)
+	fmt.Printf("Learning rates: %+v\n", lrs)
+
+	log.Print(model)
+}
diff --git a/libtch/tensor.go b/libtch/tensor.go
index 19568bf..cc82f32 100644
--- a/libtch/tensor.go
+++ b/libtch/tensor.go
@@ -440,6 +440,35 @@ func AtoSetLearningRate(coptimizer Coptimizer, learningRate float64) {
 	C.ato_set_learning_rate(coptimizer, clearningRate)
 }
 
+// AtoGetLearningRates returns the learning rate of each parameter group.
+func AtoGetLearningRates(coptimizer Coptimizer) []float64 {
+	// Size the C buffer from the number of param groups; writing into a
+	// zero-byte malloc would corrupt the heap.
+	ngroup := int(C.ato_param_group_num(coptimizer))
+
+	cLRsPtr := (*C.double)(C.malloc(C.size_t(ngroup) * C.size_t(unsafe.Sizeof(C.double(0)))))
+	cngroup := (*C.int)(C.malloc(C.size_t(unsafe.Sizeof(C.int(0)))))
+	defer C.free(unsafe.Pointer(cLRsPtr))
+	defer C.free(unsafe.Pointer(cngroup))
+
+	C.ato_get_learning_rates(coptimizer, cLRsPtr, cngroup)
+
+	lrs := make([]float64, ngroup)
+	currPtr := cLRsPtr
+	for i := 0; i < ngroup; i++ {
+		lrs[i] = float64(*currPtr)
+		// Advance by the size of a C double, not the size of a pointer.
+		currPtr = (*C.double)(unsafe.Pointer(uintptr(unsafe.Pointer(currPtr)) + unsafe.Sizeof(C.double(0))))
+	}
+
+	return lrs
+}
+
+// AtoParamGroupNum returns the number of parameter groups of the optimizer.
+func AtoParamGroupNum(coptimizer Coptimizer) int64 {
+	return int64(C.ato_param_group_num(coptimizer))
+}
+
 // void ato_set_momentum(optimizer, double momentum);
 func AtoSetMomentum(coptimizer Coptimizer, momentum float64) {
 	cmomentum := *(*C.double)(unsafe.Pointer(&momentum))
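Note on the cgo pattern above: the binding allocates on the C heap, lets C fill the buffer, then walks it back into a Go slice. Getting the buffer size and the pointer stride right is the whole game, so here is a minimal, self-contained sketch of the same pattern outside gotch; `fill_doubles` is a hypothetical stand-in for `ato_get_learning_rates`:

```go
package main

/*
#include <stdlib.h>

// Hypothetical stand-in for ato_get_learning_rates: fills `out` with n values.
static void fill_doubles(double *out, int n) {
	for (int i = 0; i < n; i++) {
		out[i] = 0.001 * (i + 1);
	}
}
*/
import "C"

import (
	"fmt"
	"unsafe"
)

func main() {
	n := 3

	// Allocate exactly n doubles on the C heap and release them when done.
	buf := (*C.double)(C.malloc(C.size_t(n) * C.size_t(unsafe.Sizeof(C.double(0)))))
	defer C.free(unsafe.Pointer(buf))

	C.fill_doubles(buf, C.int(n))

	// Walk the C array back into a Go slice, stepping by sizeof(double).
	out := make([]float64, n)
	for i := 0; i < n; i++ {
		p := unsafe.Pointer(uintptr(unsafe.Pointer(buf)) + uintptr(i)*unsafe.Sizeof(C.double(0)))
		out[i] = float64(*(*C.double)(p))
	}

	fmt.Println(out) // [0.001 0.002 0.003]
}
```

On Go 1.17+ the manual stride arithmetic can be replaced with `unsafe.Slice((*float64)(unsafe.Pointer(buf)), n)`.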
diff --git a/libtch/torch_api.cpp b/libtch/torch_api.cpp
index addd6b0..a0e31a5 100644
--- a/libtch/torch_api.cpp
+++ b/libtch/torch_api.cpp
@@ -590,6 +590,74 @@ void ato_set_learning_rate_group(optimizer t, size_t group, double learning_rate
   )
 }
 
+// ============ set/get learning rates ==============================
+// TT. added for learning rate scheduler
+// lr scheduler APIs will be in PyTorch 1.9?
+// Ref. https://github.com/pytorch/pytorch/issues/50577
+template <typename T>
+void set_lrs(optimizer t, std::vector<double> &learning_rates) {
+  torch::optim::OptimizerOptions *d = &(t->defaults());
+  if (auto p = dynamic_cast<T *>(d)) {
+    for (std::size_t i = 0; i < t->param_groups().size(); i++) {
+      auto &param_group = t->param_groups()[i];
+      torch::optim::OptimizerOptions *opts = &(param_group.options());
+      if (auto p2 = dynamic_cast<T *>(opts)) {
+        p2->lr(learning_rates[i]);
+      } else
+        throw std::invalid_argument("unexpected param group type");
+    }
+  }
+}
+
+void ato_set_learning_rates(optimizer t, double *learning_rates) {
+  PROTECT(
+    std::vector<double> lrs(learning_rates,
+                            learning_rates + t->param_groups().size());
+    set_lrs<torch::optim::AdamOptions>(t, lrs);
+    set_lrs<torch::optim::AdamWOptions>(t, lrs);
+    set_lrs<torch::optim::RMSpropOptions>(t, lrs);
+    set_lrs<torch::optim::SGDOptions>(t, lrs);
+  )
+}
+
+template <typename T>
+void get_lrs(optimizer t, std::vector<double> &lrs) {
+  torch::optim::OptimizerOptions *d = &(t->defaults());
+  if (auto p = dynamic_cast<T *>(d)) {
+    for (std::size_t i = 0; i < t->param_groups().size(); i++) {
+      auto &param_group = t->param_groups()[i];
+      torch::optim::OptimizerOptions *opts = &(param_group.options());
+      if (auto p2 = dynamic_cast<T *>(opts)) {
+        lrs[i] = p2->lr();
+      } else
+        throw std::invalid_argument("unexpected param group type");
+    }
+  }
+}
+
+void ato_get_learning_rates(optimizer t, double *lrs, int *param_group_num) {
+  PROTECT(
+    int ngroup = t->param_groups().size();
+    std::vector<double> learning_rates(ngroup);
+    get_lrs<torch::optim::AdamOptions>(t, learning_rates);
+    get_lrs<torch::optim::AdamWOptions>(t, learning_rates);
+    get_lrs<torch::optim::RMSpropOptions>(t, learning_rates);
+    get_lrs<torch::optim::SGDOptions>(t, learning_rates);
+
+    std::copy(learning_rates.begin(), learning_rates.end(), lrs);
+    param_group_num[0] = ngroup;
+  )
+}
+
+int64_t ato_param_group_num(optimizer t) {
+  PROTECT(
+    return t->param_groups().size();
+  )
+  return -1;
+}
+
+// ============ End of set/get learning rates ==============================
+
 void ato_set_momentum(optimizer t, double momentum) {
   PROTECT(
     torch::optim::OptimizerOptions* d = &(t->defaults());
diff --git a/libtch/torch_api.h b/libtch/torch_api.h
index a9420cc..a585277 100644
--- a/libtch/torch_api.h
+++ b/libtch/torch_api.h
@@ -130,6 +130,11 @@ void ato_zero_grad(optimizer);
 void ato_step(optimizer);
 void ato_free(optimizer);
 
+// TT. APIs for learning rate scheduler
+void ato_set_learning_rates(optimizer, double *learning_rates);
+int64_t ato_param_group_num(optimizer);
+void ato_get_learning_rates(optimizer, double *lrs, int *ngroup);
+
 scalar ats_int(int64_t);
 scalar ats_float(double);
 int64_t ats_to_int(scalar);
@@ -197,7 +202,7 @@ void ati_free(ivalue);
 #include "torch_api_generated.h"
 
 #ifdef __cplusplus
-};
+}; // extern "C"
 #endif
 
 #endif
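The C++ helpers probe each concrete options class with `dynamic_cast` because libtorch exposes param-group options only through the `OptimizerOptions` base class. The Go analogue of that dispatch is a type switch; below is a small illustrative sketch in which the option types are invented for the example and are not gotch or libtorch types:

```go
package main

import "fmt"

// adamOptions and sgdOptions are invented for this sketch; they only
// model per-group option objects like torch::optim::AdamOptions.
type adamOptions struct{ lr float64 }
type sgdOptions struct{ lr float64 }

// setGroupLR is the Go analogue of the dynamic_cast chain in set_lrs<T>:
// try each concrete type until one matches, otherwise report the mismatch.
func setGroupLR(opts interface{}, lr float64) error {
	switch o := opts.(type) {
	case *adamOptions:
		o.lr = lr
	case *sgdOptions:
		o.lr = lr
	default:
		return fmt.Errorf("unexpected param group type %T", opts)
	}
	return nil
}

func main() {
	groups := []interface{}{&adamOptions{lr: 0.01}, &sgdOptions{lr: 0.1}}
	for i, g := range groups {
		if err := setGroupLR(g, 0.001); err != nil {
			fmt.Printf("group %d: %v\n", i, err)
		}
	}
	fmt.Printf("%+v %+v\n", groups[0], groups[1])
}
```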
diff --git a/nn/optimizer.go b/nn/optimizer.go
index d2d85e6..f432fee 100644
--- a/nn/optimizer.go
+++ b/nn/optimizer.go
@@ -276,6 +276,16 @@ func (opt *Optimizer) SetLR(lr float64) {
 	}
 }
 
+// GetLRs returns the current learning rate of each parameter group.
+func (opt *Optimizer) GetLRs() []float64 {
+	lrs, err := opt.opt.GetLearningRates()
+	if err != nil {
+		log.Fatalf("Optimizer - GetLRs method call error: %v\n", err)
+	}
+
+	return lrs
+}
+
 // SetMomentum sets the optimizer momentum.
 func (opt *Optimizer) SetMomentum(m float64) {
 	err := opt.opt.SetMomentum(m)
@@ -283,3 +293,13 @@ func (opt *Optimizer) SetMomentum(m float64) {
 		log.Fatalf("Optimizer - SetMomentum method call error: %v\n", err)
 	}
 }
+
+// ParamGroupNum returns the number of parameter groups held by the optimizer.
+func (opt *Optimizer) ParamGroupNum() int64 {
+	ngroup, err := opt.opt.ParamGroupNum()
+	if err != nil {
+		log.Fatalf("Optimizer - ParamGroupNum method call error: %v\n", err)
+	}
+
+	return ngroup
+}
diff --git a/tensor/optimizer.go b/tensor/optimizer.go
index 10d7da9..91ecb79 100644
--- a/tensor/optimizer.go
+++ b/tensor/optimizer.go
@@ -80,6 +80,28 @@ func (co *COptimizer) SetLearningRate(lr float64) error {
 	return TorchErr()
 }
 
+// GetLearningRates returns the learning rates for the optimizer.
+func (co *COptimizer) GetLearningRates() ([]float64, error) {
+	lrs := lib.AtoGetLearningRates(co.coptimizer)
+
+	if err := TorchErr(); err != nil {
+		return nil, err
+	}
+
+	return lrs, nil
+}
+
+// ParamGroupNum returns the number of parameter groups of the optimizer.
+func (co *COptimizer) ParamGroupNum() (int64, error) {
+	ngroup := lib.AtoParamGroupNum(co.coptimizer)
+
+	if err := TorchErr(); err != nil {
+		return -1, err
+	}
+
+	return ngroup, nil
+}
+
 // SetMomentum sets a momentum for the optimizer
 func (co *COptimizer) SetMomentum(m float64) error {
 	lib.AtoSetMomentum(co.coptimizer, m)
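With `GetLRs`, `SetLR`, and `ParamGroupNum` exposed, a scheduler can live in user code until the upstream libtorch scheduler APIs land. Below is a sketch of step-wise exponential decay built on this surface; the `gamma` value and the training loop are illustrative only:

```go
package main

import (
	"fmt"
	"log"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	"github.com/sugarme/gotch/vision"
)

// expDecay multiplies the current learning rate by gamma: a minimal
// step scheduler built on the GetLRs/SetLR surface added in this patch.
func expDecay(opt *nn.Optimizer, gamma float64) {
	lrs := opt.GetLRs()
	// SetLR applies one rate across param groups, so the first group's
	// current value is representative here.
	opt.SetLR(lrs[0] * gamma)
}

func main() {
	vs := nn.NewVarStore(gotch.CPU)
	_ = vision.EfficientNetB4(vs.Root(), 1000) // populate vs with trainable params

	opt, err := nn.DefaultAdamConfig().Build(vs, 0.01)
	if err != nil {
		log.Fatal(err)
	}

	for epoch := 0; epoch < 3; epoch++ {
		// ... forward pass, backward pass, opt.Step() would run here ...
		expDecay(opt, 0.9)
		fmt.Printf("epoch %d: lrs = %v\n", epoch, opt.GetLRs())
	}
}
```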