switched back to lib.ato_add_parameters_old as param group is not updated yet

sugarme 2020-11-03 00:14:56 +11:00
parent c38c909977
commit ca8e51fad8
4 changed files with 30 additions and 1 deletion

View File

@@ -401,7 +401,24 @@ func AtoSgd(learningRate, momentum, dampening, weightDecay float64, nesterov int
return C.ato_sgd(clearningRate, cmomentum, cdampening, cweightDecay, cnesterov)
}
// NOTE. Backward compat while param group (#261) is not updated yet.
// C signature: void ato_add_parameters_old(optimizer, tensor *, int ntensors);
func AtoAddParametersOld(coptimizer Coptimizer, tensors []Ctensor, ntensors int) {
    var ctensors []C.tensor
    for i := 0; i < len(tensors); i++ {
        ctensors = append(ctensors, (C.tensor)(tensors[i]))
    }
    cntensors := C.int(ntensors)
    // Just pass a pointer to the first element of the ctensors slice;
    // C sees it as a contiguous array of ntensors tensors.
    C.ato_add_parameters_old(coptimizer, &ctensors[0], cntensors)
}
// NOTE. This function is not working correctly. DO NOT USE!
// TODO. Update it to match the new C signature:
//   void ato_add_parameters(optimizer, tensor, size_t group);
func AtoAddParameters(coptimizer Coptimizer, tensors []Ctensor, ntensors int) {
    var ctensors []C.tensor

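The wrapper above uses the standard cgo idiom for passing an array: hand C a pointer to the first element of a Go slice together with its length. A minimal, self-contained sketch of that pattern (the C helper sum_ints is illustrative only, not part of gotch):

package main

/*
#include <stdio.h>
static void sum_ints(int *xs, int n) {
    int s = 0;
    for (int i = 0; i < n; i++) s += xs[i];
    printf("sum=%d\n", s);
}
*/
import "C"

func main() {
    xs := []C.int{1, 2, 3}
    // A pointer to the first element plus a length is how a Go slice
    // crosses into C, exactly as AtoAddParametersOld passes its tensors.
    C.sum_ints(&xs[0], C.int(len(xs)))
}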
View File

@@ -530,6 +530,14 @@ optimizer ato_sgd(double learning_rate,
return nullptr;
}
// NOTE. Backward compat while param group (#261) is not updated yet.
void ato_add_parameters_old(optimizer t, tensor *tensors, int ntensors) {
  PROTECT(
    // Append every tensor to the optimizer's first (default) param group.
    for (int i = 0; i < ntensors; ++i)
      t->param_groups()[0].params().push_back(*(tensors[i]));
  )
}
void ato_add_parameters(optimizer t, tensor tensor, size_t group) {
  PROTECT(
    auto &groups = t->param_groups();

View File

@@ -116,6 +116,9 @@ optimizer ato_rms_prop(double learning_rate, double alpha, double eps,
double weight_decay, double momentum, int centered);
optimizer ato_sgd(double learning_rate, double momentum, double dampening,
double weight_decay, int nesterov);
// NOTE. Backward compat: switched back while param group (#261) is not updated yet.
void ato_add_parameters_old(optimizer, tensor *, int ntensors);
void ato_add_parameters(optimizer, tensor, size_t group);
void ato_set_learning_rate(optimizer, double learning_rate);
void ato_set_momentum(optimizer, double momentum);

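The header now declares both entry points side by side. For orientation, a hedged sketch (not part of this commit) of how a Go binding for the new per-tensor signature might look, and how it could reproduce the old bulk-add into param group 0; the names AtoAddParameter and addAllToGroupZero are hypothetical:

// Hypothetical binding for: void ato_add_parameters(optimizer, tensor, size_t group);
func AtoAddParameter(coptimizer Coptimizer, tensor Ctensor, group uint) {
    cgroup := C.size_t(group)
    C.ato_add_parameters(coptimizer, (C.tensor)(tensor), cgroup)
}

// Hypothetical helper: the old behaviour expressed via the new API,
// appending every tensor to param group 0.
func addAllToGroupZero(coptimizer Coptimizer, tensors []Ctensor) {
    for _, t := range tensors {
        AtoAddParameter(coptimizer, t, 0)
    }
}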
View File

@@ -67,7 +67,8 @@ func (co *COptimizer) AddParameters(tensors []Tensor) error {
    ntensors := len(tensors)

    // NOTE. Temporary switch back (replacing the lib.AtoAddParameters call)
    // as param group is not updated yet!
    lib.AtoAddParametersOld(co.coptimizer, ctensors, ntensors)

    return TorchErr()
}
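For callers the switch is transparent: AddParameters keeps its signature and now routes through the backward-compat binding until param group support (#261) lands. A hedged usage sketch, assuming code in the same package with an already-built COptimizer co and a slice of trainable tensors params (names illustrative):

// Hand all trainable tensors to the optimizer in one call; any C-side
// failure surfaces through the returned Torch error.
if err := co.AddParameters(params); err != nil {
    log.Fatal(err) // assumes "log" is imported
}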