added libtorch APIs get/set optimizer learning rates
commit dbab26c5aa
parent b02185df22
@@ -6,16 +6,18 @@ import (
	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	"github.com/sugarme/gotch/vision"
	// ts "github.com/sugarme/gotch/tensor"
	ts "github.com/sugarme/gotch/tensor"
	// "github.com/sugarme/gotch/vision"
)

func main() {

	vs := nn.NewVarStore(gotch.CPU)
	model := vision.EfficientNetB4(vs.Root(), 1000)
	vs.Load("../../data/pretrained/efficientnet-b4.pt")
	device := gotch.CPU
	vs := nn.NewVarStore(device)
	// model := vision.EfficientNetB4(vs.Root(), 1000)
	// vs.Load("../../data/pretrained/efficientnet-b4.pt")

	model := newNet(vs.Root())
	adamConfig := nn.DefaultAdamConfig()
	o, err := adamConfig.Build(vs, 0.001)
	if err != nil {

@@ -28,6 +30,56 @@ func main() {
	fmt.Printf("Number of param groups: %v\n", ngroup)
	fmt.Printf("Learning rates: %+v\n", lrs)

	log.Print(model)
	newLRs := []float64{0.005}
	o.SetLRs(newLRs)
	fmt.Printf("New LRs: %+v\n", o.GetLRs())

	zerosTs := ts.MustZeros([]int64{2, 2}, gotch.Float, device)
	onesTs := ts.MustOnes([]int64{3, 5}, gotch.Float, device)

	o.AddParamGroup([]ts.Tensor{*zerosTs, *onesTs})
	fmt.Printf("New num of param groups: %v\n", o.ParamGroupNum())

	fmt.Printf("New LRs: %+v\n", o.GetLRs())

	// Set new lrs
	newLRs = []float64{0.0003, 0.0006}
	o.SetLRs(newLRs)
	fmt.Printf("New LRs: %+v\n", o.GetLRs())

	log.Print(model)
}

type Net struct {
	conv1 *nn.Conv2D
	conv2 *nn.Conv2D
	fc    *nn.Linear
}

func newNet(vs *nn.Path) *Net {
	conv1 := nn.NewConv2D(vs, 1, 16, 2, nn.DefaultConv2DConfig())
	conv2 := nn.NewConv2D(vs, 16, 10, 2, nn.DefaultConv2DConfig())
	fc := nn.NewLinear(vs, 10, 10, nn.DefaultLinearConfig())

	return &Net{
		conv1,
		conv2,
		fc,
	}
}

func (n Net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
	xs = xs.MustView([]int64{-1, 1, 8, 8}, false)

	outC1 := xs.Apply(n.conv1)
	outMP1 := outC1.MaxPool2DDefault(2, true)
	defer outMP1.MustDrop()

	outC2 := outMP1.Apply(n.conv2)
	outMP2 := outC2.MaxPool2DDefault(2, true)
	outView2 := outMP2.MustView([]int64{-1, 10}, true)
	defer outView2.MustDrop()

	outFC := outView2.Apply(n.fc)
	return outFC.MustRelu(true)
}
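
For reference, a small hypothetical sketch (not part of this commit) of how the per-group machinery introduced further down in this diff is meant to be used end to end: parameters created under a Path that has called SetGroup land in that optimizer parameter group, and SetLRs then assigns one learning rate per group. The "backbone"/"head" sub-paths and layer sizes are invented for illustration.

package main

import (
	"log"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
)

func main() {
	vs := nn.NewVarStore(gotch.CPU)
	root := vs.Root()

	// Parameters created under "backbone" stay in the default group 0.
	backbone := root.Sub("backbone")
	_ = nn.NewLinear(backbone, 64, 32, nn.DefaultLinearConfig())

	// Parameters created under "head" are tagged with group 1 via the new
	// Path.SetGroup API, so the optimizer places them in a second group.
	head := root.Sub("head")
	head.SetGroup(1)
	_ = nn.NewLinear(head, 32, 10, nn.DefaultLinearConfig())

	o, err := nn.DefaultAdamConfig().Build(vs, 0.001)
	if err != nil {
		log.Fatal(err)
	}

	// One learning rate per parameter group: group 0 (backbone), group 1 (head).
	o.SetLRs([]float64{0.0001, 0.001})
	log.Printf("groups: %v, LRs: %v", o.ParamGroupNum(), o.GetLRs())
}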
@@ -9,10 +9,29 @@ package libtch
import "C"

import (
	"bytes"
	"encoding/binary"
	"log"
	"strings"
	"unsafe"
)

var nativeEndian binary.ByteOrder

func init() {
	buf := [2]byte{}
	*(*uint16)(unsafe.Pointer(&buf[0])) = uint16(0xABCD)

	switch buf {
	case [2]byte{0xCD, 0xAB}:
		nativeEndian = binary.LittleEndian
	case [2]byte{0xAB, 0xCD}:
		nativeEndian = binary.BigEndian
	default:
		panic("Could not determine native endianness.")
	}
}

// NOTE: C.tensor is a C pointer to torch::Tensor
type Ctensor = C.tensor
type Cscalar = C.scalar

@@ -418,20 +437,9 @@ func AtoAddParametersOld(coptimizer Coptimizer, tensors []Ctensor, ntensors int)
	C.ato_add_parameters_old(coptimizer, &ctensors[0], cntensors)
}

// NOTE. This function is not working correctly. Need to update!!!
// DO NOT USE!!!!!
// TODO. updated
func AtoAddParameters(coptimizer Coptimizer, tensors []Ctensor, ntensors int) {

	var ctensors []C.tensor
	for i := 0; i < len(tensors); i++ {
		ctensors = append(ctensors, (C.tensor)(tensors[i]))
	}

	cntensors := *(*C.size_t)(unsafe.Pointer(&ntensors))

	// Just give pointer to the first element of ctensors slice
	C.ato_add_parameters(coptimizer, ctensors[0], cntensors)
func AtoAddParameter(coptimizer Coptimizer, tensor Ctensor, group uint) {
	cgroup := *(*C.ulong)(unsafe.Pointer(&group))
	C.ato_add_parameter(coptimizer, tensor, cgroup)
}

// void ato_set_learning_rate(optimizer, double learning_rate);

@@ -458,6 +466,24 @@ func AtoGetLearningRates(coptimizer Coptimizer) []float64 {
	return lrs
}

func AtoSetLearningRates(coptimizer Coptimizer, lrs []float64) {
	elementNum := len(lrs)
	eltSizeInBytes := 8 // float64 takes 8 Bytes
	nbytes := int(eltSizeInBytes) * int(elementNum)
	dataPtr := C.malloc(C.size_t(nbytes))
	defer C.free(unsafe.Pointer(dataPtr))
	dataSlice := (*[1 << 32]byte)(dataPtr)[:nbytes:nbytes] // 4294967296
	buf := bytes.NewBuffer(dataSlice[:0:nbytes])
	if err := binary.Write(buf, nativeEndian, lrs); err != nil {
		log.Fatal(err)
	}

	clrs := (*C.double)(dataPtr)
	lrsNum := len(lrs)
	clrsNum := *(*C.int)(unsafe.Pointer(&lrsNum))
	C.ato_set_learning_rates(coptimizer, clrs, clrsNum)
}

func AtoParamGroupNum(coptimizer Coptimizer) int64 {
	cpgNum := C.ato_param_group_num(coptimizer)

@@ -465,6 +491,16 @@ func AtoParamGroupNum(coptimizer Coptimizer) int64 {
	return pgNum
}

func AtoAddParamGroup(coptimizer Coptimizer, tensors []Ctensor, ntensors int) {
	var ctensors []C.tensor
	for i := 0; i < len(tensors); i++ {
		ctensors = append(ctensors, (C.tensor)(tensors[i]))
	}
	cntensors := *(*C.int)(unsafe.Pointer(&ntensors))

	C.ato_add_param_group(coptimizer, &ctensors[0], cntensors)
}

// void ato_set_momentum(optimizer, double momentum);
func AtoSetMomentum(coptimizer Coptimizer, momentum float64) {
	cmomentum := *(*C.double)(unsafe.Pointer(&momentum))
@@ -531,6 +531,7 @@ optimizer ato_sgd(double learning_rate,
}

// NOTE. backward compat as param group (#261) not updated yet.
// Deprecated
void ato_add_parameters_old(optimizer t, tensor *tensors, int ntensors) {
  PROTECT(
    for (int i = 0; i < ntensors; ++i)

@@ -538,7 +539,7 @@ void ato_add_parameters_old(optimizer t, tensor *tensors, int ntensors) {
  )
}

void ato_add_parameters(optimizer t, tensor tensor, size_t group) {
void ato_add_parameter(optimizer t, tensor tensor, size_t group) {
  PROTECT(
    auto &groups = t->param_groups();
    while (groups.size() <= group) {

@@ -595,10 +596,10 @@ void ato_set_learning_rate_group(optimizer t, size_t group, double learning_rate
// lr scheduler APIs will be in Pytorch 1.9?
// Ref. https://github.com/pytorch/pytorch/issues/50577
template <class T>
void set_lrs(optimizer t, std::vector<double> &learning_rates) {
void set_lrs(optimizer t, double *learning_rates) {
  torch::optim::OptimizerOptions *d = &(t->defaults());
  if (auto p = dynamic_cast<T *>(d)) {
    for (std::size_t i = 0; i < t->param_groups().size(); i++) {
    for (int i = 0; i < t->param_groups().size(); i++) {
      auto &param_group = t->param_groups()[i];
      torch::optim::OptimizerOptions *d = &(param_group.options());
      if (auto p2 = dynamic_cast<T *>(d)) {

@@ -609,12 +610,19 @@ void set_lrs(optimizer t, std::vector<double> &learning_rates) {
  }
}

void ato_set_learning_rates(optimizer t, std::vector<double> &learning_rates) {
void ato_set_learning_rates(optimizer t, double *lrs, int lrs_num) {
  PROTECT(
    set_lrs<torch::optim::AdamOptions>(t, learning_rates);
    set_lrs<torch::optim::AdamWOptions>(t, learning_rates);
    set_lrs<torch::optim::RMSpropOptions>(t, learning_rates);
    set_lrs<torch::optim::SGDOptions>(t, learning_rates);
    int ngroup = t->param_groups().size();
    if (lrs == nullptr){
      throw std::invalid_argument("Input learning rates should not be null");
    }
    if (ngroup != lrs_num){
      throw std::invalid_argument("Size of input learning rates is unequal to number of parameter groups.");
    }
    set_lrs<torch::optim::AdamOptions>(t, lrs);
    set_lrs<torch::optim::AdamWOptions>(t, lrs);
    set_lrs<torch::optim::RMSpropOptions>(t, lrs);
    set_lrs<torch::optim::SGDOptions>(t, lrs);
  )
}

@@ -641,7 +649,9 @@ void ato_get_learning_rates(optimizer t, double *lrs, int *param_group_num) {
    get_lrs<torch::optim::RMSpropOptions>(t, learning_rates);
    get_lrs<torch::optim::SGDOptions>(t, learning_rates);

    copy(learning_rates.begin(), learning_rates.end(), lrs);
    for (int i = 0; i < ngroup; i++){
      lrs[i] = learning_rates[i];
    }
    param_group_num[0] = ngroup;
  )
}

@@ -654,6 +664,24 @@ int64_t ato_param_group_num(optimizer t) {
  return -1;
}

void ato_add_param_group(optimizer t, tensor *tensors, int ntensors){
  PROTECT(
    // ref. https://github.com/pytorch/pytorch/blob/45c5bac87049a9a0593d4d5f060cc8d4b0f83db8/test/cpp/api/optim.cpp#L153-L154
    std::vector<torch::Tensor> params;
    for (int i = 0; i < ntensors; ++i){
      params.push_back(*tensors[i]);
    }
    auto& params_groups = t->param_groups();
    params_groups.push_back(torch::optim::OptimizerParamGroup(params));
    std::unique_ptr<torch::optim::OptimizerOptions> opt = params_groups[0].options().clone();
    int ngroup = params_groups.size();
    params_groups[ngroup - 1].set_options(std::move(opt));
    // set default learning rate.
    double default_lr = 0.001;
    ato_set_learning_rate_group(t, ngroup - 1, default_lr);
  )
}

// ============ End of set/get learning rates ==============================
@@ -119,7 +119,7 @@ optimizer ato_sgd(double learning_rate, double momentum, double dampening,
// NOTE. switch back as param group #261 not updated yet.
// Backward compat
void ato_add_parameters_old(optimizer, tensor *, int ntensors);
void ato_add_parameters(optimizer, tensor, size_t group);
void ato_add_parameter(optimizer, tensor, size_t group);
void ato_set_learning_rate(optimizer, double learning_rate);
void ato_set_momentum(optimizer, double momentum);
void ato_set_learning_rate_group(optimizer, size_t group, double learning_rate);

@@ -131,10 +131,10 @@ void ato_step(optimizer);
void ato_free(optimizer);

// TT. APIs for learning rate scheduler
void ato_set_learning_rates(optimizer, double* learning_rates);
//double *ato_get_learning_rates(optimizer);
void ato_set_learning_rates(optimizer, double* learning_rates, int lrs_num);
int64_t ato_param_group_num(optimizer);
void ato_get_learning_rates(optimizer, double *lrs, int *ngroup);
void ato_add_param_group(optimizer, tensor *params, int param_num);

scalar ats_int(int64_t);
scalar ats_float(double);
@@ -3,6 +3,7 @@ package nn
// Optimizers to be used for gradient-descent based training.

import (
	"fmt"
	"log"

	ts "github.com/sugarme/gotch/tensor"

@@ -40,20 +41,15 @@ func defaultBuild(config OptimizerConfig, vs *VarStore, lr float64) (retVal *Opt
		return retVal, err
	}

	var parameters []ts.Tensor
	for _, v := range vs.Vars.TrainableVariables {
		param := v.MustShallowClone()
		parameters = append(parameters, *param)
	}

	if len(vs.Vars.TrainableVariables) > 0 {
		if err = opt.AddParameters(vs.Vars.TrainableVariables); err != nil {
			return retVal, err
		for _, v := range vs.Vars.TrainableVariables {
			if err = opt.AddParameter(v.Tensor, v.Group); err != nil {
				err = fmt.Errorf("Optimizer defaultBuild - AddParameter failed: %w\n", err)
				return nil, err
			}
		}
	}

	// TODO: should we clone or copy?

	return &Optimizer{
		opt: opt,
		// variables: vs.Vars,

@@ -230,21 +226,16 @@ func (opt *Optimizer) Step() {

// BackwardStep applies a backward pass, updates the gradients, and performs an optimization step.
func (opt *Optimizer) BackwardStep(loss *ts.Tensor) {

	opt.addMissingVariables()

	err := opt.opt.ZeroGrad()
	if err != nil {
		log.Fatalf("Optimizer - BackwardStep method call - ZeroGrad error: %v\n", err)
	}

	loss.MustBackward()

	err = opt.opt.Step()
	if err != nil {
		log.Fatalf("Optimizer - BackwardStep method call - Step() error: %v\n", err)
	}

}

// BackwardStepClip applies a backward pass, updates the gradients, and performs an optimization step.

@@ -252,23 +243,39 @@ func (opt *Optimizer) BackwardStep(loss *ts.Tensor) {
// The gradients are clipped based on `max` before being applied.
func (opt *Optimizer) BackwardStepClip(loss *ts.Tensor, max float64) {
	opt.addMissingVariables()

	err := opt.opt.ZeroGrad()
	if err != nil {
		log.Fatalf("Optimizer - BackwardStepClip method call - ZeroGrad error: %v\n", err)
	}

	loss.MustBackward()

	opt.ClipGradValue(max)

	err = opt.opt.Step()
	if err != nil {
		log.Fatalf("Optimizer - BackwardStepClip method call - Step() error: %v\n", err)
	}
}

// TODO. Clips gradient L2 norm over all trainable parameters.
//
// The norm is computed over all gradients together, as if they were
// concatenated into a single vector.
func (opt *Optimizer) ClipGradNorm(max float64) {
	// TODO.
	log.Fatalf("Not implemented yet!")
}

// TODO. Applies a backward pass, updates the gradients, and performs an optimization step.
//
// The gradients L2 norm is clipped based on `max`.
func (opt *Optimizer) BackwardStepClipNorm(loss *ts.Tensor, max float64) {
	// TODO.
	log.Fatalf("Not implemented yet!")
}

// SetLR sets the optimizer learning rate.
//
// NOTE. it sets a SINGLE value of learning rate for all parameter groups.
// Most of the time, there's one parameter group.
func (opt *Optimizer) SetLR(lr float64) {
	err := opt.opt.SetLearningRate(lr)
	if err != nil {

@@ -285,6 +292,14 @@ func (opt *Optimizer) GetLRs() []float64 {
	return lrs
}

// SetLRs sets learning rates for ALL parameter groups respectively.
func (opt *Optimizer) SetLRs(lrs []float64) {
	err := opt.opt.SetLearningRates(lrs)
	if err != nil {
		log.Fatalf("Optimizer - SetLRs method call error: %v\n", err)
	}
}

// SetMomentum sets the optimizer momentum.
func (opt *Optimizer) SetMomentum(m float64) {
	err := opt.opt.SetMomentum(m)

@@ -301,3 +316,10 @@ func (opt *Optimizer) ParamGroupNum() int64 {

	return ngroup
}

func (opt *Optimizer) AddParamGroup(tensors []ts.Tensor) {
	err := opt.opt.AddParamGroup(tensors)
	if err != nil {
		log.Fatalf("Optimizer - AddParamGroup method call error: %v\n", err)
	}
}
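
A hypothetical helper (not part of this commit) illustrating the contract of the two methods above: on the C side, ato_add_param_group starts a freshly added group at a default learning rate of 0.001, and SetLRs must then receive exactly ParamGroupNum() values, one per group. The package name, tensor shape, and rates are invented for illustration.

package example

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	ts "github.com/sugarme/gotch/tensor"
)

// addExtraGroup appends one more parameter group and gives it its own rate.
func addExtraGroup(opt *nn.Optimizer) {
	extra := ts.MustZeros([]int64{10}, gotch.Float, gotch.CPU)
	opt.AddParamGroup([]ts.Tensor{*extra})

	lrs := opt.GetLRs()    // e.g. [0.001 0.001] right after the group is added
	lrs[len(lrs)-1] = 1e-4 // override the default rate of the new group
	opt.SetLRs(lrs)        // len(lrs) must equal opt.ParamGroupNum()
	fmt.Printf("groups: %v, LRs: %v\n", opt.ParamGroupNum(), opt.GetLRs())
}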
@@ -14,6 +14,11 @@ import (
// SEP is a separator to separate path elements in the tensor names.
const SEP = "."

type Var struct {
	Tensor *ts.Tensor
	Group  uint // optimizer parameter group
}

// Variables represents a collection of tensors.
//
// NOTE: When the variable store is frozen, trainable is still set to true,

@@ -21,7 +26,8 @@ const SEP = "."
type Variables struct {
	mutex              *sync.Mutex
	NamedVariables     map[string]*ts.Tensor
	TrainableVariables []ts.Tensor
	TrainableVariables []Var
	// TrainableVariables []ts.Tensor
}

// VarStore is used to store variables used by one or multiple layers.

@@ -35,6 +41,7 @@ type VarStore struct {
type Path struct {
	path     []string
	varstore *VarStore
	group    uint // optimizer parameter group
}

// Entry holds an entry corresponding to a given name in Path.

@@ -49,7 +56,8 @@ func NewVarStore(device gotch.Device) *VarStore {
	variables := Variables{
		mutex:              &sync.Mutex{},
		NamedVariables:     make(map[string]*ts.Tensor, 0),
		TrainableVariables: make([]ts.Tensor, 0),
		TrainableVariables: make([]Var, 0),
		// TrainableVariables: make([]ts.Tensor, 0),
	}

	return &VarStore{

@@ -88,9 +96,9 @@ func (vs *VarStore) TrainableVariables() []ts.Tensor {
	vs.Vars.mutex.Lock()
	defer vs.Vars.mutex.Unlock()

	retVal := vs.Vars.TrainableVariables
	for _, t := range vs.Vars.TrainableVariables {
		retVal = append(retVal, *t.MustShallowClone())
	var retVal []ts.Tensor
	for _, v := range vs.Vars.TrainableVariables {
		retVal = append(retVal, *v.Tensor.MustShallowClone())
	}

	return retVal

@@ -119,6 +127,7 @@ func (vs *VarStore) Root() *Path {
	return &Path{
		path:     []string{},
		varstore: vs,
		group:    0,
	}
}

@@ -257,7 +266,7 @@ func (vs *VarStore) Freeze() {
	defer vs.Vars.mutex.Unlock()

	for _, v := range vs.Vars.TrainableVariables {
		_, err := v.SetRequiresGrad(false, false)
		_, err := v.Tensor.SetRequiresGrad(false, false)
		if err != nil {
			log.Fatalf("Freeze() Error: %v\n", err)
		}

@@ -272,7 +281,7 @@ func (vs *VarStore) Unfreeze() {
	defer vs.Vars.mutex.Unlock()

	for _, v := range vs.Vars.TrainableVariables {
		_, err := v.SetRequiresGrad(true, false)
		_, err := v.Tensor.SetRequiresGrad(true, false)
		if err != nil {
			log.Fatalf("Unfreeze() Error: %v\n", err)
		}

@@ -328,6 +337,7 @@ func (p *Path) Sub(str string) *Path {
	return &Path{
		path:     path,
		varstore: p.varstore,
		group:    p.group,
	}
}

@@ -375,7 +385,11 @@ func (p *Path) add(name string, newTs *ts.Tensor, trainable bool) *ts.Tensor {
	}

	if trainable {
		p.varstore.Vars.TrainableVariables = append(p.varstore.Vars.TrainableVariables, *tensor)
		v := Var{
			Tensor: tensor,
			Group:  p.group,
		}
		p.varstore.Vars.TrainableVariables = append(p.varstore.Vars.TrainableVariables, v)
	}

	p.varstore.Vars.NamedVariables[path] = tensor

@@ -409,7 +423,11 @@ func (p *Path) getOrAddWithLock(name string, tensor *ts.Tensor, trainable bool,
	}

	if trainable {
		variables.TrainableVariables = append(variables.TrainableVariables, *ttensor)
		v := Var{
			Tensor: ttensor,
			Group:  p.group,
		}
		variables.TrainableVariables = append(variables.TrainableVariables, v)
	}

	variables.NamedVariables[path] = ttensor

@@ -417,6 +435,10 @@ func (p *Path) getOrAddWithLock(name string, tensor *ts.Tensor, trainable bool,
	return ttensor
}

func (p *Path) SetGroup(g uint) {
	p.group = g
}

// ZerosNoTrain creates a new variable initialized with zeros.
//
// The new variable is named according to the name parameter and
@@ -59,7 +59,6 @@ func Sgd(lr, momentum, dampening, wd float64, nesterov bool) (*COptimizer, error

// AddParameters adds parameters as a slice of tensors to optimizer
func (co *COptimizer) AddParameters(tensors []Tensor) error {

	var ctensors []lib.Ctensor
	for _, t := range tensors {
		ctensors = append(ctensors, t.ctensor)

@@ -73,6 +72,13 @@ func (co *COptimizer) AddParameters(tensors []Tensor) error {
	return TorchErr()
}

// AddParameter adds a single parameter to a parameter group.
func (co *COptimizer) AddParameter(param *Tensor, group uint) error {
	lib.AtoAddParameter(co.coptimizer, param.ctensor, group)

	return TorchErr()
}

// SetLearningRate sets learning rate for the optimizer
func (co *COptimizer) SetLearningRate(lr float64) error {
	lib.AtoSetLearningRate(co.coptimizer, lr)

@@ -91,6 +97,15 @@ func (co *COptimizer) GetLearningRates() ([]float64, error) {
	return lrs, nil
}

func (co *COptimizer) SetLearningRates(lrs []float64) error {
	lib.AtoSetLearningRates(co.coptimizer, lrs)
	if err := TorchErr(); err != nil {
		return err
	}

	return nil
}

func (co *COptimizer) ParamGroupNum() (int64, error) {
	ngroup := lib.AtoParamGroupNum(co.coptimizer)

@@ -101,6 +116,18 @@ func (co *COptimizer) ParamGroupNum() (int64, error) {
	return ngroup, nil
}

func (co *COptimizer) AddParamGroup(tensors []Tensor) error {
	var ctensors []lib.Ctensor
	for _, t := range tensors {
		ctensors = append(ctensors, t.ctensor)
	}

	ntensors := len(tensors)

	lib.AtoAddParamGroup(co.coptimizer, ctensors, ntensors)
	return TorchErr()
}

// SetMomentum sets a momentum for the optimizer
func (co *COptimizer) SetMomentum(m float64) error {
	lib.AtoSetMomentum(co.coptimizer, m)