From ca8e51fad82df33e317008c64f06a6a8acd40f5e Mon Sep 17 00:00:00 2001 From: sugarme Date: Tue, 3 Nov 2020 00:14:56 +1100 Subject: [PATCH] switched back to lib.ato_add_parameters_old as param group not updated yet --- libtch/tensor.go | 17 +++++++++++++++++ libtch/torch_api.cpp | 8 ++++++++ libtch/torch_api.h | 3 +++ tensor/optimizer.go | 3 ++- 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/libtch/tensor.go b/libtch/tensor.go index b7999a5..4e60c4c 100644 --- a/libtch/tensor.go +++ b/libtch/tensor.go @@ -401,7 +401,24 @@ func AtoSgd(learningRate, momentum, dampening, weightDecay float64, nesterov int return C.ato_sgd(clearningRate, cmomentum, cdampening, cweightDecay, cnesterov) } +// NOTE. Backward compat for param group not updated (#261) // void ato_add_parameters(optimizer, tensor *, int ntensors); +func AtoAddParametersOld(coptimizer Coptimizer, tensors []Ctensor, ntensors int) { + + var ctensors []C.tensor + for i := 0; i < len(tensors); i++ { + ctensors = append(ctensors, (C.tensor)(tensors[i])) + } + + cntensors := *(*C.int)(unsafe.Pointer(&ntensors)) + + // Just give pointer to the first element of ctensors slice + C.ato_add_parameters_old(coptimizer, &ctensors[0], cntensors) +} + +// NOTE. This function is not working correctly. Need to update!!! +// DO NOT USE!!!!! +// TODO. updated func AtoAddParameters(coptimizer Coptimizer, tensors []Ctensor, ntensors int) { var ctensors []C.tensor diff --git a/libtch/torch_api.cpp b/libtch/torch_api.cpp index 1aa4262..0462519 100644 --- a/libtch/torch_api.cpp +++ b/libtch/torch_api.cpp @@ -530,6 +530,14 @@ optimizer ato_sgd(double learning_rate, return nullptr; } +// NOTE. backward compat as param group (#261) not updated yet. +void ato_add_parameters_old(optimizer t, tensor *tensors, int ntensors) { + PROTECT( + for (int i = 0; i < ntensors; ++i) + t->param_groups()[0].params().push_back(*(tensors[i])); + ) +} + void ato_add_parameters(optimizer t, tensor tensor, size_t group) { PROTECT( auto &groups = t->param_groups(); diff --git a/libtch/torch_api.h b/libtch/torch_api.h index 42a6948..8e29e46 100644 --- a/libtch/torch_api.h +++ b/libtch/torch_api.h @@ -116,6 +116,9 @@ optimizer ato_rms_prop(double learning_rate, double alpha, double eps, double weight_decay, double momentum, int centered); optimizer ato_sgd(double learning_rate, double momentum, double dampening, double weight_decay, int nesterov); +// NOTE. switch back as param group #261 not updated yet. +// Backward compat +void ato_add_parameters_old(optimizer, tensor *, int ntensors); void ato_add_parameters(optimizer, tensor, size_t group); void ato_set_learning_rate(optimizer, double learning_rate); void ato_set_momentum(optimizer, double momentum); diff --git a/tensor/optimizer.go b/tensor/optimizer.go index 6be1460..10d7da9 100644 --- a/tensor/optimizer.go +++ b/tensor/optimizer.go @@ -67,7 +67,8 @@ func (co *COptimizer) AddParameters(tensors []Tensor) error { ntensors := len(tensors) - lib.AtoAddParameters(co.coptimizer, ctensors, ntensors) + // NOTE. temporary switch back as param group not updated yet! + lib.AtoAddParametersOld(co.coptimizer, ctensors, ntensors) return TorchErr() }