// torch_api.h -- flat C-compatible API surface wrapping libtorch
// (tensors, optimizers, scalars, CUDA utilities, TorchScript modules, IValues).
// NOTE(review): identifiers containing a double underscore are reserved for
// the implementation (C++ [lex.name]); the guard name is kept for
// compatibility but should eventually be renamed (or replaced by #pragma once).
#ifndef __TORCH_API_H__
#define __TORCH_API_H__

#include <stdbool.h> /* bool in the C prototypes below (harmless in C++) */
#include <stddef.h>  /* size_t in the prototypes below */
#include <stdint.h>

#ifdef __cplusplus

#include <cstring> // strdup, used by PROTECT
#include <stdexcept>

#include <torch/torch.h>

// `using namespace std;` was deliberately removed: importing a namespace at
// header scope leaks it into every translation unit that includes this file.
// PROTECT qualifies std::exception explicitly instead.

// Last error message captured inside a PROTECT() region, one slot per thread.
// `inline` (C++17) so every translation unit including this header shares a
// single definition instead of violating the one-definition rule.
inline thread_local char *torch_last_err = nullptr;

extern "C" {
// Opaque handles handed to the C side; each wraps a heap-allocated libtorch
// object and is released by the matching *_free function below.
typedef torch::Tensor *tensor;
typedef torch::Scalar *scalar;
typedef torch::optim::Optimizer *optimizer;
typedef torch::jit::script::Module *module;
typedef torch::jit::IValue *ivalue;

// Runs `x`, converting any thrown C++ exception into a thread-local error
// string that callers retrieve via get_and_reset_last_err().
#define PROTECT(x)                                                             \
  try {                                                                        \
    x                                                                          \
  } catch (const std::exception &e) {                                          \
    torch_last_err = strdup(e.what());                                         \
  }
#else
// From plain C the handles are opaque void pointers.
typedef void *tensor;
typedef void *optimizer;
typedef void *scalar;
typedef void *module;
typedef void *ivalue;
#endif
// Returns (and clears) the error recorded by PROTECT for the calling thread;
// NULL when no error is pending. The caller owns the returned string.
char *get_and_reset_last_err(); // thread-local
// Seeds torch's global random number generator.
void at_manual_seed(int64_t);
// Creates a fresh (undefined) tensor handle.
tensor at_new_tensor();
// Wraps caller-provided memory as a tensor -- presumably without copying;
// TODO(review): confirm ownership/copy semantics against the implementation.
// `type` is a scalar-type code and `device` a device index, as used
// throughout this API.
tensor at_tensor_of_blob(void *data, int64_t *dims, size_t ndims,
                         int64_t *strides, size_t nstrides, int type,
                         int device);
// Builds a tensor from `ndims`-dimensional raw data of the given element size.
tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims,
                         size_t element_size_in_bytes, int type);
// Copies `numel` elements out of the tensor into the buffer `vs`.
// NOTE(review): the first parameter is named `tensor`, shadowing the typedef.
void at_copy_data(tensor tensor, void *vs, size_t numel,
                  size_t element_size_in_bytes);
tensor at_shallow_clone(tensor);

// Introspection helpers. The int-returning predicates use 0/1 as booleans.
void *at_data_ptr(tensor);
int at_defined(tensor);
int at_is_mkldnn(tensor);
int at_is_sparse(tensor);
int at_device(tensor);
size_t at_dim(tensor);
// Fill a caller-provided array (capacity >= at_dim) with sizes / strides.
void at_shape(tensor, int64_t *);
void at_stride(tensor, int64_t *);
int at_scalar_type(tensor);
int at_is_contiguous(tensor);
void at__amp_non_finite_check_and_unscale(tensor, tensor, tensor);

// Autocast (automatic mixed precision) state management.
void at_autocast_clear_cache();
int at_autocast_decrement_nesting();
int at_autocast_increment_nesting();
bool at_autocast_is_enabled();
// Presumably returns the previous enabled state -- TODO(review): confirm.
bool at_autocast_set_enabled(bool b);

// Backpropagates through a tensor; the two ints look like keep_graph /
// create_graph flags (cf. at_run_backward below) -- TODO(review): confirm.
void at_backward(tensor, int, int);
int at_requires_grad(tensor);
// Toggles global grad mode; returns the previous setting as an int.
int at_grad_set_enabled(int);
// Selects the sub-tensor at `index` along the first dimension -- TODO(review):
// confirm against the implementation.
tensor at_get(tensor, int index);
// In-place fill with a constant value.
void at_fill_double(tensor, double);
void at_fill_int64(tensor, int64_t);

// Element access at a multi-dimensional index (one entry per dimension).
double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
int64_t at_int64_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
// NOTE(review): the setters take `int *indexes` while the getters above take
// `int64_t *indexes`; the asymmetry looks unintentional, but changing it
// would break the ABI for existing callers, so it is only flagged here.
void at_set_double_value_at_indexes(tensor, int *indexes, int indexes_len,
                                    double v);
void at_set_int64_value_at_indexes(tensor, int *indexes, int indexes_len,
                                   int64_t v);
// Copies src's contents into dst (in-place, following the `_` suffix
// convention used by torch for mutating operations).
void at_copy_(tensor dst, tensor src);

void at_print(tensor);
// Renders the tensor as text; the caller owns the returned string.
char *at_to_string(tensor, int line_size);
// Single-tensor (de)serialization.
void at_save(tensor, char *filename);
tensor at_load(char *filename);
// Image helpers; at_save_image returns a status code.
tensor at_load_image(char *filename);
int at_save_image(tensor, char *filename);
tensor at_resize_image(tensor, int w, int h);

// Multi-tensor archives: parallel arrays of tensors and their names.
void at_save_multi(tensor *tensors, char **tensor_names, int ntensors,
                   char *filename);
/* [at_load_multi] takes as input an array of nullptr for [tensors]. */
void at_load_multi(tensor *tensors, char **tensor_names, int ntensors,
                   char *filename);
/* [at_load_multi_] takes as input an array of allocation [tensors]. */
void at_load_multi_(tensor *tensors, char **tensor_names, int ntensors,
                    char *filename);
// Streams each named tensor in `filename` to the callback `f`, passing the
// opaque `data` pointer through unchanged.
void at_load_callback(char *filename, void *data,
                      void (*f)(void *, char *, tensor));
void at_load_callback_with_device(char *filename, void *data,
                                  void (*f)(void *, char *, tensor),
                                  int device_id);

// Inter-op / intra-op thread-pool configuration.
int at_get_num_interop_threads();

int at_get_num_threads();

void at_set_num_interop_threads(int n_threads);

void at_set_num_threads(int n_threads);

// Releases a tensor handle created by the at_* constructors above.
void at_free(tensor);

// Runs backward on `tensors`, writing the gradients of `inputs` into
// `outputs`; keep_graph / create_graph are 0/1 flags.
void at_run_backward(tensor *tensors, int ntensors, tensor *inputs, int ninputs,
                     tensor *outputs, int keep_graph, int create_graph);
// Optimizer constructors; every returned handle is released with ato_free.
optimizer ato_adam(double learning_rate, double beta1, double beta2,
                   double weight_decay);
optimizer ato_adamw(double learning_rate, double beta1, double beta2,
                    double weight_decay);
optimizer ato_rms_prop(double learning_rate, double alpha, double eps,
                       double weight_decay, double momentum, int centered);
optimizer ato_sgd(double learning_rate, double momentum, double dampening,
                  double weight_decay, int nesterov);
// NOTE. switch back as param group #261 not updated yet.
// Backward compat
void ato_add_parameters_old(optimizer, tensor *, int ntensors);
// Registers a parameter tensor under the given parameter-group index.
void ato_add_parameter(optimizer, tensor, size_t group);
// Plain setters presumably apply to all parameter groups; `_group` variants
// target a single group index -- TODO(review): confirm in the implementation.
void ato_set_learning_rate(optimizer, double learning_rate);
void ato_set_momentum(optimizer, double momentum);
void ato_set_learning_rate_group(optimizer, size_t group, double learning_rate);
void ato_set_momentum_group(optimizer, size_t group, double momentum);
void ato_set_weight_decay(optimizer t, double weight_decay);
void ato_set_weight_decay_group(optimizer t, size_t group, double weight_decay);
void ato_zero_grad(optimizer);
void ato_step(optimizer);
void ato_free(optimizer);

// TT. APIs for learning rate scheduler
// Sets per-group learning rates from an array of `lrs_num` entries.
void ato_set_learning_rates(optimizer, double *learning_rates, int lrs_num);
int64_t ato_param_group_num(optimizer);
// Fills `lrs` with the current learning rates and stores the count in *ngroup.
void ato_get_learning_rates(optimizer, double *lrs, int *ngroup);
void ato_add_param_group(optimizer, tensor *params, int param_num);
// TT. added option pad value. Original generated API `atg_constant_pad_nd` no
// option of adding pad value.
// First parameter receives the result tensor (output-pointer convention of
// the generated atg_* API).
void ato_constant_pad_nd(tensor *, tensor self, int64_t *pad_data, int pad_len,
                         scalar value);

// Scalar construction and conversion; ats_free releases the handle.
scalar ats_int(int64_t);
scalar ats_float(double);
int64_t ats_to_int(scalar);
double ats_to_float(scalar);
// Caller owns the returned string.
char *ats_to_string(scalar);
void ats_free(scalar);
// CUDA / cuDNN queries and configuration. The int-returning predicates use
// 0/1 as booleans.
int atc_cuda_device_count();
int atc_cuda_is_available();
int atc_cudnn_is_available();
// Enables/disables cuDNN benchmark mode (autotuning of conv algorithms).
void atc_set_benchmark_cudnn(int b);
// Blocks until all queued work on the given device has finished.
void atc_synchronize(int64_t device_index);

// TT. added for testing qt
// ref. https://github.com/pytorch/pytorch/issues/14959
int atc_get_device();
void atc_set_device(int device_index);
// TorchScript modules: load from a file path or an in-memory buffer of `sz`
// bytes, optionally pinned to a device. Handles are released with atm_free.
module atm_load(char *);
module atm_load_on_device(char *, int device);
module atm_load_str(char *, size_t sz);
module atm_load_str_on_device(char *, size_t sz, int device);
// forward/method over tensors; the trailing-underscore variants operate on
// generic ivalues instead.
tensor atm_forward(module, tensor *tensors, int ntensors);
ivalue atm_forward_(module, ivalue *ivalues, int nivalues);
tensor atm_method(module, char *method_name, tensor *tensors, int ntensors);
ivalue atm_method_(module, char *method_name, ivalue *ivalues, int nivalues);
void atm_free(module);
// Moves/casts the module; `dtype` is a scalar-type code as elsewhere here.
void atm_to(module m, int device, int dtype, bool non_blocking);
void atm_save(module m, char *);
// JIT profiling-mode flag (get returns the current setting).
int atm_get_profiling_mode();
void atm_set_profiling_mode(int);
// Invokes `f(data, name, tensor)` for each named parameter of the module.
void atm_named_parameters(module, void *data,
                          void (*f)(void *, char *, tensor));
// Switches between evaluation and training mode.
void atm_eval(module);
void atm_train(module);
// IValue constructors; every returned handle is released with ati_free.
// List/tuple constructors take an array plus its length.
ivalue ati_none();
ivalue ati_tensor(tensor);
ivalue ati_int(int64_t);
ivalue ati_double(double);
ivalue ati_bool(int);
ivalue ati_string(char *);
ivalue ati_tuple(ivalue *, int);
ivalue ati_generic_list(ivalue *, int);
ivalue ati_generic_dict(ivalue *, int);
ivalue ati_int_list(int64_t *, int);
ivalue ati_double_list(double *, int);
// Bool lists are passed as char arrays (one 0/1 byte per element).
ivalue ati_bool_list(char *, int);
ivalue ati_string_list(char **, int);
ivalue ati_tensor_list(tensor *, int);

// IValue accessors; the *_list/tuple variants fill a caller-provided array
// whose capacity is given by the trailing int.
tensor ati_to_tensor(ivalue);
int64_t ati_to_int(ivalue);
double ati_to_double(ivalue);
char *ati_to_string(ivalue);
int ati_to_bool(ivalue);
int ati_length(ivalue);
int ati_tuple_length(ivalue);
void ati_to_tuple(ivalue, ivalue *, int);
void ati_to_generic_list(ivalue, ivalue *, int);
void ati_to_generic_dict(ivalue, ivalue *, int);
void ati_to_int_list(ivalue, int64_t *, int);
void ati_to_double_list(ivalue, double *, int);
void ati_to_bool_list(ivalue, char *, int);
void ati_to_tensor_list(ivalue, tensor *, int);

// Numeric tag identifying the ivalue's dynamic type.
int ati_tag(ivalue);

void ati_free(ivalue);

// Auto-generated per-operator bindings (one atg_* prototype per torch op).
#include "torch_api_generated.h"
#ifdef __cplusplus
|
|
}; // extern "C"
|
|
|
|
std::vector<torch::Tensor> of_carray_tensor(torch::Tensor **vs, int len);
|
|
at::Device device_of_int(int d);
|
|
c10::List<c10::optional<torch::Tensor>> of_carray_tensor_opt(torch::Tensor **vs,
|
|
int len);
|
|
|
|
#endif
|
|
#endif
|