// gotch/libtch/torch_api.cpp
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/passes/fixup_trace_scope_blocks.h>
#include <torch/csrc/jit/passes/normalize_ops.h>
#include <torch/torch.h>
#include <ATen/autocast_mode.h>
#include <torch/script.h>
#include <stdexcept>
#include <vector>
#include <cstring> // memcpy, strlen, strdup
#include "torch_api.h"
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "stb_image_write.h"
#define STB_IMAGE_RESIZE_IMPLEMENTATION
#include "stb_image_resize.h"
using namespace std;
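// Error-handling convention: the PROTECT macro (defined in torch_api.h) wraps
// each call in a try/catch and records the exception message in the global
// torch_last_err; the foreign caller polls it through get_and_reset_last_err
// and takes ownership of the returned C string.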
char *get_and_reset_last_err() {
char *tmp = torch_last_err;
torch_last_err = nullptr;
return tmp;
}
void at_manual_seed(int64_t seed) { torch::manual_seed(seed); }
vector<torch::Tensor> of_carray_tensor(torch::Tensor **vs, int len) {
vector<torch::Tensor> result;
for (int i = 0; i < len; ++i)
result.push_back(*(vs[i]));
return result;
}
c10::List<c10::optional<torch::Tensor>> of_carray_tensor_opt(torch::Tensor **vs, int len) {
vector<c10::optional<torch::Tensor>> result;
for (int i = 0; i < len; ++i) {
result.push_back(vs[i] != nullptr ? c10::optional<torch::Tensor>(*(vs[i])) : c10::nullopt);
}
return c10::List<c10::optional<torch::Tensor>>(result);
}
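// Device convention used throughout this API: a negative id selects the CPU,
// any other value is interpreted as a CUDA device ordinal.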
at::Device device_of_int(int d) {
if (d < 0)
return at::Device(at::kCPU);
return at::Device(at::kCUDA, /*index=*/d);
}
tensor at_new_tensor() {
PROTECT(return new torch::Tensor();)
return nullptr;
}
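// Wraps caller-owned memory via torch::from_blob, i.e. without copying: the
// caller must keep `data` alive and unchanged for the lifetime of the
// returned tensor.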
tensor at_tensor_of_blob(void *data, int64_t *dims, size_t ndims,
int64_t *strides, size_t nstrides, int type,
int device) {
PROTECT(at::TensorOptions blobOptions = at::TensorOptions()
.device(device_of_int(device))
.dtype(torch::ScalarType(type));
return new torch::Tensor(torch::from_blob(
data, torch::IntArrayRef(dims, ndims),
torch::IntArrayRef(strides, nstrides), blobOptions));)
return nullptr;
}
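// Unlike at_tensor_of_blob above, this allocates a fresh tensor and copies the
// caller's buffer into it, so the buffer may be freed once the call returns.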
tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims,
                         size_t element_size_in_bytes, int type) {
  PROTECT(
    torch::Tensor tensor = torch::zeros(torch::IntArrayRef(dims, ndims),
                                        torch::ScalarType(type));
    if (element_size_in_bytes != tensor.element_size())
      throw std::invalid_argument("incoherent element sizes in bytes");
    void *tensor_data = tensor.data_ptr();
    memcpy(tensor_data, vs, tensor.numel() * element_size_in_bytes);
    return new torch::Tensor(tensor);
  )
  return nullptr;
}
void at_copy_data(tensor tensor, void *vs, size_t numel,
                  size_t elt_size_in_bytes) {
  PROTECT(
    if (elt_size_in_bytes != tensor->element_size())
      throw std::invalid_argument("incoherent element sizes in bytes");
    if (numel > tensor->numel())
      throw std::invalid_argument("target numel is larger than tensor numel");
    if (tensor->device().type() != at::kCPU) {
      torch::Tensor tmp_tensor = tensor->to(at::kCPU).contiguous();
      memcpy(vs, tmp_tensor.data_ptr(), numel * elt_size_in_bytes);
    } else {
      auto tmp_tensor = tensor->contiguous();
      memcpy(vs, tmp_tensor.data_ptr(), numel * elt_size_in_bytes);
    }
  )
}
tensor at_shallow_clone(tensor t) {
PROTECT(return new torch::Tensor(*t);)
return nullptr;
}
void *at_data_ptr(tensor t) {
PROTECT(return t->data_ptr();)
return nullptr;
}
int at_defined(tensor t) {
PROTECT(return t->defined();)
return -1;
}
int at_is_mkldnn(tensor t) {
PROTECT(return t->is_mkldnn();)
return -1;
}
int at_is_sparse(tensor t) {
PROTECT(return t->is_sparse();)
return -1;
}
size_t at_dim(tensor t) {
PROTECT(return t->dim();)
return -1;
}
void at_shape(tensor t, int64_t *dims) {
PROTECT(int i = 0; for (int64_t dim : t->sizes()) dims[i++] = dim;)
}
void at_stride(tensor t, int64_t *dims) {
PROTECT(int i = 0; for (int64_t dim : t->strides()) dims[i++] = dim;)
}
int at_scalar_type(tensor t) {
PROTECT(return static_cast<int>(t->scalar_type());)
return -1;
}
// Note: the older single-tensor op at::_amp_non_finite_check_and_unscale_ was
// superseded by the _foreach_ variant used below.
void at__amp_non_finite_check_and_unscale(tensor t, tensor found_inf, tensor inf_scale) {
PROTECT(
at::_amp_foreach_non_finite_check_and_unscale_(*t, *found_inf, *inf_scale);
)
}
void at_autocast_clear_cache() { at::autocast::clear_cache(); }
int at_autocast_decrement_nesting() {
PROTECT(return at::autocast::decrement_nesting();)
return -1;
}
int at_autocast_increment_nesting() {
PROTECT(return at::autocast::increment_nesting();)
return -1;
}
bool at_autocast_is_enabled() {
PROTECT(return at::autocast::is_enabled();)
return -1;
}
bool at_autocast_set_enabled(bool b) {
PROTECT(bool is_enabled = at::autocast::is_enabled();
at::autocast::set_enabled(b); return is_enabled;)
return -1;
}
int at_device(tensor t) {
PROTECT(auto device = t->device(); if (device.type() == at::kCPU) return -1;
if (device.type() == at::kCUDA) return device.index();)
return -2;
}
void at_backward(tensor t, int keep_graph, int create_graph) {
PROTECT(t->backward({}, keep_graph, create_graph);)
}
int at_requires_grad(tensor t) {
PROTECT(return t->requires_grad();)
return -1;
}
int at_grad_set_enabled(int b) {
PROTECT(bool is_enabled = torch::autograd::GradMode::is_enabled();
torch::autograd::GradMode::set_enabled(b); return is_enabled;)
return -1;
}
tensor at_get(tensor t, int index) {
PROTECT(return new torch::Tensor((*t)[index]);)
return nullptr;
}
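// Walks `indexes` one dimension at a time (tensor[i0][i1]...) and reads the
// resulting zero-dim tensor out as a scalar of type T.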
template <typename T>
T at_value_at_indexes(tensor t, int64_t *indexes, int indexes_len) {
  PROTECT(
    torch::Tensor tensor = *t;
    for (int i = 0; i < indexes_len; ++i) tensor = tensor[indexes[i]];
    return tensor.item<T>();
  )
  return T();
}
double at_double_value_at_indexes(tensor t, int64_t *indexes, int indexes_len) {
return at_value_at_indexes<double>(t, indexes, indexes_len);
}
int64_t at_int64_value_at_indexes(tensor t, int64_t *indexes, int indexes_len) {
return at_value_at_indexes<int64_t>(t, indexes, indexes_len);
}
template <typename T>
void at_set_value_at_indexes(tensor t, int *indexes, int indexes_len, T v) {
  PROTECT(
    torch::Tensor tensor = *t;
    for (int i = 0; i < indexes_len; ++i) tensor = tensor[indexes[i]];
    tensor.fill_(v);
  )
}
void at_set_double_value_at_indexes(tensor t, int *indexes, int indexes_len,
double v) {
at_set_value_at_indexes<double>(t, indexes, indexes_len, v);
}
void at_set_int64_value_at_indexes(tensor t, int *indexes, int indexes_len,
int64_t v) {
at_set_value_at_indexes<int64_t>(t, indexes, indexes_len, v);
}
void at_fill_double(tensor t, double v) { PROTECT(t->fill_(v);) }
void at_fill_int64(tensor t, int64_t v) { PROTECT(t->fill_(v);) }
void at_print(tensor t) {
  PROTECT(cout << *t << endl;)
}
char *at_to_string(tensor t, int line_size) {
PROTECT(std::ostringstream oss; torch::print(oss, *t, line_size);
return strdup(oss.str().c_str());)
return nullptr;
}
void at_copy_(tensor dst, tensor src) { PROTECT(dst->copy_(*src);) }
void at_save(tensor t, char *filename) { PROTECT(torch::save(*t, filename);) }
void at_save_multi(tensor *tensors, char **tensor_names, int ntensors,
                   char *filename) {
  PROTECT(
    torch::serialize::OutputArchive archive;
    for (int i = 0; i < ntensors; ++i)
      archive.write(std::string(tensor_names[i]), *(tensors[i]), /*buffer=*/false);
    archive.save_to(filename);
  )
}
void at_load_multi(tensor *tensors, char **tensor_names, int ntensors,
                   char *filename) {
  PROTECT(
    torch::serialize::InputArchive archive;
    archive.load_from(std::string(filename));
    vector<torch::Tensor> ts(ntensors);
    for (int i = 0; i < ntensors; ++i)
      archive.read(std::string(tensor_names[i]), ts[i]);
    // Only allocate the new tensors now so that if an exception is raised
    // during [read], no memory has to be freed.
    for (int i = 0; i < ntensors; ++i)
      tensors[i] = new torch::Tensor(ts[i]);
  )
}
void at_load_callback(char *filename, void *data,
                      void (*f)(void *, char *, tensor)) {
  PROTECT(
    auto module = torch::jit::load(filename);
    for (const auto &p : module.named_parameters())
      f(data, (char *)p.name.c_str(), new torch::Tensor(p.value));
  )
}
void at_load_callback_with_device(char *filename, void *data,
                                  void (*f)(void *, char *, tensor),
                                  int device_id) {
  PROTECT(
    auto module = torch::jit::load(filename, device_of_int(device_id));
    for (const auto &p : module.named_parameters())
      f(data, (char *)p.name.c_str(), new torch::Tensor(p.value));
  )
}
void at_load_multi_(tensor *tensors, char **tensor_names, int ntensors,
                    char *filename) {
  PROTECT(
    torch::NoGradGuard no_grad;
    torch::serialize::InputArchive archive;
    archive.load_from(std::string(filename));
    for (int i = 0; i < ntensors; ++i) {
      if (tensors[i]->device().type() == at::kCPU)
        archive.read(std::string(tensor_names[i]), *(tensors[i]));
      else {
        torch::Tensor tmp_tensor =
            torch::empty_like(*(tensors[i]), at::device(at::kCPU));
        archive.read(std::string(tensor_names[i]), tmp_tensor);
        tensors[i]->copy_(tmp_tensor);
      }
    }
  )
}
tensor at_load(char *filename) {
PROTECT(torch::Tensor tensor; torch::load(tensor, filename);
return new torch::Tensor(tensor);)
return nullptr;
}
tensor at_load_image(char *filename) {
  PROTECT(
    int w = -1; int h = -1; int c = -1;
    void *data = stbi_load(filename, &w, &h, &c, 3);
    if (data == nullptr) throw std::invalid_argument(stbi_failure_reason());
    torch::Tensor tensor = torch::zeros({h, w, 3}, at::ScalarType::Byte);
    memcpy(tensor.data_ptr(), data, h * w * 3);
    stbi_image_free(data);
    return new torch::Tensor(tensor);
  )
  return nullptr;
}
bool ends_with(const char *str, const char *suffix) {
  size_t suffix_len = strlen(suffix);
  size_t str_len = strlen(str);
  if (str_len < suffix_len) return false;
  for (size_t i = 1; i <= suffix_len; ++i)
    if (str[str_len - i] != suffix[suffix_len - i]) return false;
  return true;
}
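// Assumes a 3-d HWC byte tensor (as produced by at_load_image); only the rank
// is checked here, not the dtype. The output format is chosen from the file
// extension, defaulting to PNG.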
int at_save_image(tensor tensor, char *filename) {
  PROTECT(
    auto sizes = tensor->sizes();
    if (sizes.size() != 3)
      throw std::invalid_argument("invalid number of dimensions, should be 3");
    int h = sizes[0]; int w = sizes[1]; int c = sizes[2];
    auto tmp_tensor = tensor->contiguous();
    void *tensor_data = tmp_tensor.data_ptr();
    if (ends_with(filename, ".jpg"))
      return stbi_write_jpg(filename, w, h, c, tensor_data, 90);
    if (ends_with(filename, ".bmp"))
      return stbi_write_bmp(filename, w, h, c, tensor_data);
    if (ends_with(filename, ".tga"))
      return stbi_write_tga(filename, w, h, c, tensor_data);
    return stbi_write_png(filename, w, h, c, tensor_data, 0);
  )
  return -1;
}
int at_get_num_interop_threads() {
PROTECT(return at::get_num_interop_threads();)
return -1;
}
int at_get_num_threads() {
PROTECT(return at::get_num_threads();)
return -1;
}
void at_set_num_interop_threads(int n_threads) {
PROTECT(at::set_num_interop_threads(n_threads);)
}
void at_set_num_threads(int n_threads) {
PROTECT(at::set_num_threads(n_threads);)
}
tensor at_resize_image(tensor tensor, int out_w, int out_h) {
PROTECT(
auto sizes = tensor->sizes();
if (sizes.size() != 3) throw std::invalid_argument(
"invalid number of dimensions, should be 3");
int h = sizes[0]; int w = sizes[1]; int c = sizes[2];
auto tmp_tensor = tensor->contiguous();
const unsigned char *tensor_data = (unsigned char *)tmp_tensor.data_ptr();
torch::Tensor out = torch::zeros({out_h, out_w, c}, at::ScalarType::Byte);
stbir_resize_uint8(tensor_data, w, h, 0, (unsigned char *)out.data_ptr(),
out_w, out_h, 0, c);
return new torch::Tensor(out);)
return nullptr;
}
void at_free(tensor t) { delete (t); }
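// Drives the autograd engine directly: `tensors` are the roots, seeded with
// ones-like gradients, and `inputs` are the leaves whose gradients are
// returned through `outputs` instead of being accumulated into .grad.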
void at_run_backward(tensor *tensors, int ntensors, tensor *inputs, int ninputs,
                     tensor *outputs, int keep_graph, int create_graph) {
  PROTECT(
    vector<torch::autograd::Edge> roots;
    for (int i = 0; i < ntensors; ++i)
      roots.push_back(torch::autograd::impl::gradient_edge(*tensors[i]));
    vector<torch::autograd::Edge> inputs_;
    for (int i = 0; i < ninputs; ++i) {
      if (!inputs[i]->requires_grad())
        throw std::invalid_argument("one of the input tensors does not require grad");
      inputs_.push_back(torch::autograd::impl::gradient_edge(*inputs[i]));
    }
    vector<torch::autograd::Variable> grads;
    for (int i = 0; i < ntensors; ++i)
      grads.push_back(torch::ones_like(*tensors[i]));
    auto vl = torch::autograd::Engine::get_default_engine().execute(
        roots, grads, keep_graph, create_graph, /*accumulate_grad=*/false, inputs_);
    for (int i = 0; i < ninputs; ++i)
      outputs[i] = static_cast<tensor>(new torch::autograd::Variable(vl[i]));
  )
}
optimizer ato_adam(double learning_rate, double beta1, double beta2,
double weight_decay) {
PROTECT(auto options = torch::optim::AdamOptions(learning_rate)
.betas(std::tuple<double, double>(beta1, beta2))
.weight_decay(weight_decay);
return new torch::optim::Adam(vector<torch::Tensor>(), options);)
return nullptr;
}
optimizer ato_adamw(double learning_rate, double beta1, double beta2,
double weight_decay) {
PROTECT(auto options = torch::optim::AdamWOptions(learning_rate)
.betas(std::tuple<double, double>(beta1, beta2))
.weight_decay(weight_decay);
return new torch::optim::AdamW(vector<torch::Tensor>(), options);)
return nullptr;
}
optimizer ato_rms_prop(double learning_rate, double alpha, double eps,
double weight_decay, double momentum, int centered) {
PROTECT(auto options = torch::optim::RMSpropOptions(learning_rate)
.alpha(alpha)
.eps(eps)
.weight_decay(weight_decay)
.momentum(momentum)
.centered(centered != 0);
return new torch::optim::RMSprop(vector<torch::Tensor>(), options);)
return nullptr;
}
optimizer ato_sgd(double learning_rate, double momentum, double dampening,
double weight_decay, int nesterov) {
PROTECT(auto options = torch::optim::SGDOptions(learning_rate)
.momentum(momentum)
.dampening(dampening)
.weight_decay(weight_decay)
.nesterov(nesterov);
return new torch::optim::SGD(vector<torch::Tensor>(), options);)
return nullptr;
}
// NOTE. backward compat as param group (#261) not updated yet.
// Deprecated
void ato_add_parameters_old(optimizer t, tensor *tensors, int ntensors) {
  PROTECT(
    for (int i = 0; i < ntensors; ++i)
      t->param_groups()[0].params().push_back(*(tensors[i]));
  )
}
void ato_add_parameter(optimizer t, tensor tensor, size_t group) {
  PROTECT(
    auto &groups = t->param_groups();
    while (groups.size() <= group) {
      groups.push_back(
          torch::optim::OptimizerParamGroup({}, t->defaults().clone()));
    }
    groups[group].params().push_back(*tensor);
  )
}
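// Learning rates live on the concrete per-optimizer options types, so the
// helpers below probe each of the four supported types with dynamic_cast;
// only the matching instantiation takes effect.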
template <class T> void set_lr(optimizer t, double learning_rate) {
torch::optim::OptimizerOptions *d = &(t->defaults());
if (auto p = dynamic_cast<T *>(d)) {
p->lr(learning_rate);
for (auto &param_group : t->param_groups()) {
torch::optim::OptimizerOptions *d = &(param_group.options());
if (auto p2 = dynamic_cast<T *>(d)) {
p2->lr(learning_rate);
} else
throw std::invalid_argument("unexpected param group type");
}
}
}
void ato_set_learning_rate(optimizer t, double learning_rate) {
PROTECT(set_lr<torch::optim::AdamOptions>(t, learning_rate);
set_lr<torch::optim::AdamWOptions>(t, learning_rate);
set_lr<torch::optim::RMSpropOptions>(t, learning_rate);
set_lr<torch::optim::SGDOptions>(t, learning_rate);)
}
template <class T>
void set_lr_group(optimizer t, size_t group, double learning_rate) {
auto &param_group = t->param_groups().at(group);
torch::optim::OptimizerOptions *d = &(param_group.options());
if (auto p = dynamic_cast<T *>(d)) {
p->lr(learning_rate);
}
}
void ato_set_learning_rate_group(optimizer t, size_t group,
double learning_rate) {
PROTECT(set_lr_group<torch::optim::AdamOptions>(t, group, learning_rate);
set_lr_group<torch::optim::AdamWOptions>(t, group, learning_rate);
set_lr_group<torch::optim::RMSpropOptions>(t, group, learning_rate);
set_lr_group<torch::optim::SGDOptions>(t, group, learning_rate);)
}
void ato_constant_pad_nd(tensor *out__, tensor self, int64_t *pad_data, int pad_len, scalar value) {
PROTECT(
auto outputs__ = torch::constant_pad_nd(*self, torch::IntArrayRef(pad_data, pad_len), *value);
out__[0] = new torch::Tensor(outputs__);
)
}
// ============ set/get learning rates ==============================
// TT. added for learning rate scheduler
// lr scheduler APIs will be in Pytorch 1.9?
// Ref. https://github.com/pytorch/pytorch/issues/50577
template <class T> void set_lrs(optimizer t, double *learning_rates) {
torch::optim::OptimizerOptions *d = &(t->defaults());
if (auto p = dynamic_cast<T *>(d)) {
for (size_t i = 0; i < t->param_groups().size(); i++) {
auto &param_group = t->param_groups()[i];
torch::optim::OptimizerOptions *d = &(param_group.options());
if (auto p2 = dynamic_cast<T *>(d)) {
p2->lr(learning_rates[i]);
} else
throw std::invalid_argument("unexpected param group type");
}
}
}
void ato_set_learning_rates(optimizer t, double *lrs, int lrs_num) {
  PROTECT(
    int ngroup = t->param_groups().size();
    if (lrs == nullptr)
      throw std::invalid_argument("Input learning rates should not be null");
    if (ngroup != lrs_num)
      throw std::invalid_argument("Size of input learning rates is unequal to "
                                  "number of parameter groups.");
    set_lrs<torch::optim::AdamOptions>(t, lrs);
    set_lrs<torch::optim::AdamWOptions>(t, lrs);
    set_lrs<torch::optim::RMSpropOptions>(t, lrs);
    set_lrs<torch::optim::SGDOptions>(t, lrs);
  )
}
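// Mirror of set_lrs above: copies each param group's learning rate into the
// caller-provided vector, which must be sized to the number of groups.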
template <class T> void get_lrs(optimizer t, vector<double> &lrs) {
torch::optim::OptimizerOptions *d = &(t->defaults());
if (auto p = dynamic_cast<T *>(d)) {
for (std::size_t i = 0; i < t->param_groups().size(); i++) {
auto &param_group = t->param_groups()[i];
torch::optim::OptimizerOptions *d = &(param_group.options());
if (auto p2 = dynamic_cast<T *>(d)) {
lrs[i] = p2->lr();
} else
throw std::invalid_argument("unexpected param group type");
}
}
}
void ato_get_learning_rates(optimizer t, double *lrs, int *param_group_num) {
  PROTECT(
    int ngroup = t->param_groups().size();
    vector<double> learning_rates(ngroup);
    get_lrs<torch::optim::AdamOptions>(t, learning_rates);
    get_lrs<torch::optim::AdamWOptions>(t, learning_rates);
    get_lrs<torch::optim::RMSpropOptions>(t, learning_rates);
    get_lrs<torch::optim::SGDOptions>(t, learning_rates);
    for (int i = 0; i < ngroup; i++) lrs[i] = learning_rates[i];
    param_group_num[0] = ngroup;
  )
}
int64_t ato_param_group_num(optimizer t) {
PROTECT(return t->param_groups().size();)
return -1;
}
void ato_add_param_group(optimizer t, tensor *tensors, int ntensors) {
  PROTECT(
    // Ref. https://github.com/pytorch/pytorch/blob/45c5bac87049a9a0593d4d5f060cc8d4b0f83db8/test/cpp/api/optim.cpp#L153-L154
    std::vector<torch::Tensor> params;
    for (int i = 0; i < ntensors; ++i) params.push_back(*tensors[i]);
    auto &params_groups = t->param_groups();
    params_groups.push_back(torch::optim::OptimizerParamGroup(params));
    std::unique_ptr<torch::optim::OptimizerOptions> opt =
        params_groups[0].options().clone();
    int ngroup = params_groups.size();
    params_groups[ngroup - 1].set_options(std::move(opt));
    // Set a default learning rate; callers can override it afterwards.
    double default_lr = 0.001;
    ato_set_learning_rate_group(t, ngroup - 1, default_lr);
  )
}
// ============ End of set/get learning rates ==============================
void ato_set_momentum(optimizer t, double momentum) {
PROTECT(
torch::optim::OptimizerOptions *d = &(t->defaults());
if (auto adam = dynamic_cast<torch::optim::AdamOptions *>(d)) {
auto betas = adam->betas();
adam->betas(std::tuple<double, double>(momentum, get<1>(betas)));
for (auto &param_group : t->param_groups()) {
torch::optim::OptimizerOptions *d = &(param_group.options());
if (auto adam2 = dynamic_cast<torch::optim::AdamOptions *>(d)) {
adam2->betas(std::tuple<double, double>(momentum, get<1>(betas)));
} else
throw std::invalid_argument("unexpected param group type");
}
} else if (auto adamw = dynamic_cast<torch::optim::AdamWOptions *>(d)) {
auto betas = adamw->betas();
adamw->betas(std::tuple<double, double>(momentum, get<1>(betas)));
for (auto &param_group : t->param_groups()) {
torch::optim::OptimizerOptions *d = &(param_group.options());
if (auto adamw2 = dynamic_cast<torch::optim::AdamWOptions *>(d)) {
adamw2->betas(std::tuple<double, double>(momentum, get<1>(betas)));
} else
throw std::invalid_argument("unexpected param group type");
}
} else if (auto rms = dynamic_cast<torch::optim::RMSpropOptions *>(d)) {
for (auto &param_group : t->param_groups()) {
torch::optim::OptimizerOptions *d = &(param_group.options());
if (auto rms2 = dynamic_cast<torch::optim::RMSpropOptions *>(d)) {
rms2->momentum(momentum);
} else
throw std::invalid_argument("unexpected param group type");
}
} else if (auto sgd = dynamic_cast<torch::optim::SGDOptions *>(d)) {
sgd->momentum(momentum);
for (auto &param_group : t->param_groups()) {
torch::optim::OptimizerOptions *d = &(param_group.options());
if (auto sgd2 = dynamic_cast<torch::optim::SGDOptions *>(d)) {
sgd2->momentum(momentum);
} else
throw std::invalid_argument("unexpected param group type");
}
} else throw std::invalid_argument("unexpected optimizer");)
}
void ato_set_momentum_group(optimizer t, size_t group, double momentum) {
  PROTECT(
    auto &param_group = t->param_groups().at(group);
    torch::optim::OptimizerOptions *d = &(param_group.options());
    if (auto adam = dynamic_cast<torch::optim::AdamOptions *>(d)) {
      auto betas = adam->betas();
      adam->betas(std::tuple<double, double>(momentum, get<1>(betas)));
    } else if (auto adamw = dynamic_cast<torch::optim::AdamWOptions *>(d)) {
      auto betas = adamw->betas();
      adamw->betas(std::tuple<double, double>(momentum, get<1>(betas)));
    } else if (auto rms = dynamic_cast<torch::optim::RMSpropOptions *>(d)) {
      rms->momentum(momentum);
    } else if (auto sgd = dynamic_cast<torch::optim::SGDOptions *>(d)) {
      sgd->momentum(momentum);
    } else {
      throw std::invalid_argument("unexpected optimizer");
    }
  )
}
template <class T> void set_weight_decay(optimizer t, double weight_decay) {
torch::optim::OptimizerOptions *d = &(t->defaults());
if (auto p = dynamic_cast<T *>(d)) {
p->weight_decay(weight_decay);
for (auto &param_group : t->param_groups()) {
torch::optim::OptimizerOptions *d = &(param_group.options());
if (auto p2 = dynamic_cast<T *>(d)) {
p2->weight_decay(weight_decay);
} else
throw std::invalid_argument("unexpected param group type");
}
}
}
void ato_set_weight_decay(optimizer t, double weight_decay) {
PROTECT(set_weight_decay<torch::optim::AdamOptions>(t, weight_decay);
set_weight_decay<torch::optim::AdamWOptions>(t, weight_decay);
set_weight_decay<torch::optim::RMSpropOptions>(t, weight_decay);
set_weight_decay<torch::optim::SGDOptions>(t, weight_decay);)
}
template <class T>
void set_weight_decay_group(optimizer t, size_t group, double weight_decay) {
auto &param_group = t->param_groups().at(group);
torch::optim::OptimizerOptions *d = &(param_group.options());
if (auto p = dynamic_cast<T *>(d)) {
p->weight_decay(weight_decay);
}
}
void ato_set_weight_decay_group(optimizer t, size_t group,
                                double weight_decay) {
  PROTECT(
    set_weight_decay_group<torch::optim::AdamOptions>(t, group, weight_decay);
    set_weight_decay_group<torch::optim::AdamWOptions>(t, group, weight_decay);
    set_weight_decay_group<torch::optim::RMSpropOptions>(t, group, weight_decay);
    set_weight_decay_group<torch::optim::SGDOptions>(t, group, weight_decay);
  )
}
void ato_zero_grad(optimizer t) { PROTECT(t->zero_grad();) }
void ato_step(optimizer t) { PROTECT(t->step();) }
void ato_free(optimizer t) { delete (t); }
scalar ats_int(int64_t v) {
PROTECT(return new torch::Scalar(v);)
return nullptr;
}
scalar ats_float(double v) {
PROTECT(return new torch::Scalar(v);)
return nullptr;
}
int64_t ats_to_int(scalar s) {
PROTECT(return s->toLong();)
return -1;
}
double ats_to_float(scalar s) {
PROTECT(return s->toDouble();)
return 0.;
}
char *ats_to_string(scalar s) {
PROTECT(using namespace at; std::ostringstream oss; oss << (*s);
return strdup(oss.str().c_str());)
return nullptr;
}
void ats_free(scalar s) { delete (s); }
int atc_cuda_device_count() {
PROTECT(return torch::cuda::device_count();)
return -1;
}
int atc_cuda_is_available() {
PROTECT(return torch::cuda::is_available();)
return -1;
}
int atc_cudnn_is_available() {
PROTECT(return torch::cuda::cudnn_is_available();)
return -1;
}
void atc_set_benchmark_cudnn(int b) {
at::globalContext().setBenchmarkCuDNN(b);
}
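// TorchScript (torch::jit) modules: the atm_load* functions deserialize a
// module saved from Python with torch.jit.save, from a file or from an
// in-memory buffer, optionally mapping it onto a given device.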
module atm_load(char *filename) {
PROTECT(return new torch::jit::script::Module(torch::jit::load(filename));)
return nullptr;
}
module atm_load_on_device(char *filename, int device) {
PROTECT(return new torch::jit::script::Module(
torch::jit::load(filename, device_of_int(device)));)
return nullptr;
}
module atm_load_str(char *data, size_t sz) {
PROTECT(std::istringstream stream(std::string(data, sz));
return new torch::jit::script::Module(torch::jit::load(stream));)
return nullptr;
}
module atm_load_str_on_device(char *data, size_t sz, int device) {
PROTECT(std::istringstream stream(std::string(data, sz));
return new torch::jit::script::Module(
torch::jit::load(stream, device_of_int(device)));)
return nullptr;
}
tensor atm_forward(module m, tensor *tensors, int ntensors) {
PROTECT(std::vector<torch::jit::IValue> inputs;
for (int i = 0; i < ntensors; ++i) inputs.push_back(*(tensors[i]));
torch::jit::IValue output = m->forward(std::move(inputs));
if (!output.isTensor()) throw std::invalid_argument(
"forward did not return a tensor");
return new torch::Tensor(output.toTensor());)
return nullptr;
}
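// A minimal sketch of the expected call pattern from the C side (all names
// below are illustrative, not part of this API):
//
//   module m = atm_load("model.pt");
//   tensor inputs[1] = {input};
//   tensor out = atm_forward(m, inputs, 1);
//   if (char *err = get_and_reset_last_err()) { /* report and free err */ }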
ivalue atm_forward_(module m, ivalue *ivalues, int nivalues) {
PROTECT(std::vector<torch::jit::IValue> inputs;
for (int i = 0; i < nivalues; ++i) inputs.push_back(*(ivalues[i]));
torch::jit::IValue output = m->forward(std::move(inputs));
return new torch::jit::IValue(output);)
return nullptr;
}
tensor atm_method(module m, char *method_name, tensor *tensors, int ntensors) {
PROTECT(std::vector<torch::jit::IValue> inputs;
for (int i = 0; i < ntensors; ++i) inputs.push_back(*(tensors[i]));
torch::jit::IValue output =
m->get_method(method_name)(std::move(inputs));
if (!output.isTensor()) throw std::invalid_argument(
"method did not return a tensor");
return new torch::Tensor(output.toTensor());)
return nullptr;
}
ivalue atm_method_(module m, char *method_name, ivalue *ivalues, int nivalues) {
PROTECT(std::vector<torch::jit::IValue> inputs;
for (int i = 0; i < nivalues; ++i) inputs.push_back(*(ivalues[i]));
torch::jit::IValue output =
m->get_method(method_name)(std::move(inputs));
return new torch::jit::IValue(output);)
return nullptr;
}
void atm_eval(module m) { PROTECT(m->eval();) }
void atm_train(module m) { PROTECT(m->train();) }
void atm_free(module m) { delete (m); }
void atm_save(module m, char *filename) { PROTECT(m->save(filename);) }
void atm_to(module m, int device, int dtype, bool non_blocking) {
PROTECT(m->to(device_of_int(device), at::ScalarType(dtype), non_blocking);)
}
int atm_get_profiling_mode() {
PROTECT(return torch::jit::getProfilingMode();)
return 0;
}
void atm_set_profiling_mode(int b) {
PROTECT(torch::jit::getProfilingMode() = (bool)b;)
}
void atm_named_parameters(module m, void *data,
                          void (*f)(void *, char *, tensor)) {
  PROTECT(
    for (const auto &p : m->named_parameters())
      f(data, (char *)p.name.c_str(), new torch::Tensor(p.value));
  )
}
ivalue ati_tensor(tensor t) {
PROTECT(return new torch::jit::IValue(*t);)
return nullptr;
}
ivalue ati_int(int64_t i) {
PROTECT(return new torch::jit::IValue(i);)
return nullptr;
}
ivalue ati_double(double d) {
PROTECT(return new torch::jit::IValue(d);)
return nullptr;
}
ivalue ati_bool(int i) {
PROTECT(return new torch::jit::IValue((bool)i);)
return nullptr;
}
ivalue ati_string(char *s) {
PROTECT(string str(s); return new torch::jit::IValue(str);)
return nullptr;
}
ivalue ati_none() {
PROTECT(return new torch::jit::IValue();)
return nullptr;
}
ivalue ati_tuple(ivalue *is, int nvalues) {
PROTECT(vector<torch::jit::IValue> vec;
for (int i = 0; i < nvalues; ++i) vec.push_back(*(is[i]));
return new torch::jit::IValue(torch::ivalue::Tuple::create(vec));)
return nullptr;
}
ivalue ati_generic_list(ivalue *is, int nvalues) {
PROTECT(c10::List<torch::jit::IValue> vec(c10::AnyType::get());
for (int i = 0; i < nvalues; ++i) vec.push_back(*(is[i]));
return new torch::jit::IValue(c10::List<torch::jit::IValue>(vec));)
return nullptr;
}
ivalue ati_generic_dict(ivalue *is, int nvalues) {
c10::Dict<torch::jit::IValue, torch::jit::IValue> dict(c10::AnyType::get(),
c10::AnyType::get());
PROTECT(for (int i = 0; i < nvalues; ++i)
dict.insert(*(is[2 * i]), *(is[2 * i + 1]));
return new torch::jit::IValue(dict);)
return nullptr;
}
ivalue ati_int_list(int64_t *is, int nvalues) {
PROTECT(c10::List<int64_t> vec;
for (int i = 0; i < nvalues; ++i) vec.push_back(is[i]);
return new torch::jit::IValue(vec);)
return nullptr;
}
ivalue ati_double_list(double *is, int nvalues) {
PROTECT(c10::List<double> vec;
for (int i = 0; i < nvalues; ++i) vec.push_back(is[i]);
return new torch::jit::IValue(vec);)
return nullptr;
}
ivalue ati_bool_list(char *is, int nvalues) {
PROTECT(c10::List<bool> vec;
for (int i = 0; i < nvalues; ++i) vec.push_back(is[i] != 0);
return new torch::jit::IValue(vec);)
return nullptr;
}
ivalue ati_string_list(char **is, int nvalues) {
PROTECT(c10::List<string> vec;
for (int i = 0; i < nvalues; ++i) vec.push_back(string(is[i]));
return new torch::jit::IValue(vec);)
return nullptr;
}
ivalue ati_tensor_list(tensor *is, int nvalues) {
PROTECT(c10::List<at::Tensor> vec;
for (int i = 0; i < nvalues; ++i) vec.push_back(*(is[i]));
return new torch::jit::IValue(vec);)
return nullptr;
}
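// The tag numbering below is part of the FFI contract and must stay in sync
// with the binding side: 0 None, 1 Tensor, 2 Double, 3 Int, 4 Bool, 5 Tuple,
// 6 IntList, 7 DoubleList, 8 BoolList, 9 String, 10 TensorList,
// 12 GenericList, 13 GenericDict (11 is unused here).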
int ati_tag(ivalue i) {
  PROTECT(
    if (i->isNone()) return 0;
    else if (i->isTensor()) return 1;
    else if (i->isDouble()) return 2;
    else if (i->isInt()) return 3;
    else if (i->isBool()) return 4;
    else if (i->isTuple()) return 5;
    else if (i->isIntList()) return 6;
    else if (i->isDoubleList()) return 7;
    else if (i->isBoolList()) return 8;
    else if (i->isString()) return 9;
    else if (i->isTensorList()) return 10;
    else if (i->isList()) return 12;
    else if (i->isGenericDict()) return 13;
    throw std::invalid_argument(("unsupported tag " + i->tagKind()).c_str());
    return -1;
  )
  return -1;
}
int64_t ati_to_int(ivalue i) {
PROTECT(return i->toInt();)
return -1;
}
double ati_to_double(ivalue i) {
PROTECT(return i->toDouble();)
return 0;
}
int ati_to_bool(ivalue i) {
PROTECT(return i->toBool();)
return -1;
}
char *ati_to_string(ivalue i) {
PROTECT(auto str = i->toStringRef(); return strdup(str.c_str());)
return nullptr;
}
tensor ati_to_tensor(ivalue i) {
PROTECT(return new torch::Tensor(i->toTensor());)
return nullptr;
}
int ati_length(ivalue i) {
  PROTECT(
    if (i->isTuple()) return i->toTuple()->elements().size();
    else if (i->isIntList()) return i->toIntList().size();
    else if (i->isDoubleList()) return i->toDoubleList().size();
    else if (i->isBoolList()) return i->toBoolList().size();
    else if (i->isString()) return i->toStringRef().size();
    else if (i->isTensorList()) return i->toTensorList().size();
    else if (i->isList()) return i->toList().size();
    else if (i->isGenericDict()) return i->toGenericDict().size();
    throw std::invalid_argument(("unsupported tag for length " + i->tagKind()).c_str());
    return -1;
  )
  return -1;
}
int ati_tuple_length(ivalue i) {
PROTECT(return i->toTuple()->elements().size();)
return -1;
}
void ati_to_tuple(ivalue i, ivalue *outputs, int noutputs) {
  PROTECT(
    auto vec = i->toTuple()->elements();
    if (vec.size() != (size_t)noutputs)
      throw std::invalid_argument("unexpected tuple size");
    for (int k = 0; k < noutputs; ++k)
      outputs[k] = new torch::jit::IValue(vec[k]);
  )
}
void ati_to_generic_list(ivalue i, ivalue *outputs, int noutputs) {
  PROTECT(
    auto vec = i->toList();
    if (vec.size() != (size_t)noutputs)
      throw std::invalid_argument("unexpected list size");
    for (int k = 0; k < noutputs; ++k)
      outputs[k] = new torch::jit::IValue(vec[k]);
  )
}
void ati_to_generic_dict(ivalue i, ivalue *outputs, int noutputs) {
  PROTECT(
    auto dict = i->toGenericDict();
    if (dict.size() != (size_t)noutputs)
      throw std::invalid_argument("unexpected dict size");
    // Key/value pairs are written out flat, so `outputs` must have room for
    // 2 * noutputs ivalues.
    int k = 0;
    for (auto it = dict.begin(); it != dict.end(); ++it) {
      outputs[k++] = new torch::jit::IValue(it->key());
      outputs[k++] = new torch::jit::IValue(it->value());
    }
  )
}
void ati_to_int_list(ivalue i, int64_t *outputs, int noutputs) {
  PROTECT(
    auto vec = i->toIntList();
    if (vec.size() != (size_t)noutputs)
      throw std::invalid_argument("unexpected list<int> size");
    for (int k = 0; k < noutputs; ++k) outputs[k] = vec[k];
  )
}
void ati_to_double_list(ivalue i, double *outputs, int noutputs) {
  PROTECT(
    auto vec = i->toDoubleList();
    if (vec.size() != (size_t)noutputs)
      throw std::invalid_argument("unexpected list<double> size");
    for (int k = 0; k < noutputs; ++k) outputs[k] = vec[k];
  )
}
void ati_to_bool_list(ivalue i, char *outputs, int noutputs) {
  PROTECT(
    auto vec = i->toBoolList();
    if (vec.size() != (size_t)noutputs)
      throw std::invalid_argument("unexpected list<bool> size");
    for (int k = 0; k < noutputs; ++k) outputs[k] = vec[k];
  )
}
void ati_to_tensor_list(ivalue i, tensor *outputs, int noutputs) {
  PROTECT(
    auto vec = i->toTensorList();
    if (vec.size() != (size_t)noutputs)
      throw std::invalid_argument("unexpected list<tensor> size");
    for (int k = 0; k < noutputs; ++k) outputs[k] = new torch::Tensor(vec[k]);
  )
}
void ati_free(ivalue i) { delete (i); }
#include "torch_api_generated.cpp.h"