initial commit
This commit is contained in:
commit
0b2b0c1bf2
21
.gitignore
vendored
Normal file
21
.gitignore
vendored
Normal file
|
@ -0,0 +1,21 @@
|
|||
.directory
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
*.dat
|
||||
*.log
|
||||
*.bak
|
||||
*.out
|
||||
*.dot
|
||||
*.rs
|
||||
|
||||
*.txt
|
||||
*.json
|
||||
|
||||
target/
|
||||
_build/
|
||||
data/
|
||||
gen/.merlin
|
||||
**/*.rs.bk
|
||||
Cargo.lock
|
||||
__pycache__
|
1
dune-project
Normal file
1
dune-project
Normal file
|
@ -0,0 +1 @@
|
|||
(lang dune 2.2)
|
3
gen/dune
Normal file
3
gen/dune
Normal file
|
@ -0,0 +1,3 @@
|
|||
(executables
|
||||
(names gen)
|
||||
(libraries base stdio yaml))
|
655
gen/gen.ml
Normal file
655
gen/gen.ml
Normal file
|
@ -0,0 +1,655 @@
|
|||
(* Automatically generate the C++ -> C -> rust bindings.
|
||||
This takes as input the Descriptions.yaml file that gets generated when
|
||||
building PyTorch from source.
|
||||
|
||||
Run with: dune exec gen/gen.exe
|
||||
*)
|
||||
open Base
|
||||
open Stdio
|
||||
|
||||
let excluded_functions =
|
||||
Set.of_list
|
||||
(module String)
|
||||
[
|
||||
"multi_margin_loss";
|
||||
"multi_margin_loss_out";
|
||||
"log_softmax_backward_data";
|
||||
"softmax_backward_data";
|
||||
"clone";
|
||||
"copy_";
|
||||
"conv_transpose2d_backward_out";
|
||||
"conv_transpose3d_backward_out";
|
||||
"slow_conv_transpose2d_backward_out";
|
||||
"slow_conv_transpose3d_backward_out";
|
||||
"slow_conv3d_backward_out";
|
||||
"normal";
|
||||
"_cufft_set_plan_cache_max_size";
|
||||
"_cufft_clear_plan_cache";
|
||||
"backward";
|
||||
"set_data";
|
||||
"_amp_non_finite_check_and_unscale_";
|
||||
"_cummin_helper";
|
||||
"_cummax_helper";
|
||||
"retain_grad";
|
||||
]
|
||||
|
||||
let no_tensor_options =
|
||||
Set.of_list
|
||||
(module String)
|
||||
[
|
||||
"zeros_like";
|
||||
"empty_like";
|
||||
"full_like";
|
||||
"ones_like";
|
||||
"rand_like";
|
||||
"randint_like";
|
||||
"randn_like";
|
||||
]
|
||||
|
||||
let prefixed_functions =
|
||||
Set.of_list
|
||||
(module String)
|
||||
[ "add"; "add_"; "div"; "div_"; "mul"; "mul_"; "sub"; "sub_"; "nll_loss" ]
|
||||
|
||||
let excluded_prefixes = [ "_thnn_"; "_th_"; "thnn_"; "th_" ]
|
||||
|
||||
let excluded_suffixes = [ "_forward"; "_forward_out" ]
|
||||
|
||||
let yaml_error yaml ~msg =
|
||||
Printf.failwithf "%s, %s" msg (Yaml.to_string_exn yaml) ()
|
||||
|
||||
let extract_bool = function
|
||||
| `Bool b -> b
|
||||
| `String "true" -> true
|
||||
| `String "false" -> false
|
||||
| yaml -> yaml_error yaml ~msg:"expected bool"
|
||||
|
||||
let extract_list = function
|
||||
| `A l -> l
|
||||
| yaml -> yaml_error yaml ~msg:"expected list"
|
||||
|
||||
let extract_map = function
|
||||
| `O map -> Map.of_alist_exn (module String) map
|
||||
| yaml -> yaml_error yaml ~msg:"expected map"
|
||||
|
||||
let extract_string = function
|
||||
| `String s -> s
|
||||
(* The yaml spec for torch uses n which is converted to a bool. *)
|
||||
| `Bool b -> if b then "y" else "n"
|
||||
| `Float f -> Float.to_string f
|
||||
| yaml -> yaml_error yaml ~msg:"expected string"
|
||||
|
||||
module Func = struct
|
||||
type arg_type =
|
||||
| Bool
|
||||
| Int64
|
||||
| Double
|
||||
| Tensor
|
||||
| TensorOption
|
||||
| IntList
|
||||
| TensorList
|
||||
| TensorOptions
|
||||
| Scalar
|
||||
| ScalarType
|
||||
| Device
|
||||
| String
|
||||
|
||||
type arg = {
|
||||
arg_name : string;
|
||||
arg_type : arg_type;
|
||||
default_value : string option;
|
||||
}
|
||||
|
||||
type t = {
|
||||
name : string;
|
||||
args : arg list;
|
||||
returns : [ `fixed of int | `dynamic ];
|
||||
(* number of tensors that are returned *)
|
||||
kind : [ `function_ | `method_ ];
|
||||
}
|
||||
|
||||
let arg_type_of_string str ~is_nullable =
|
||||
match String.lowercase str with
|
||||
| "bool" -> Some Bool
|
||||
| "int64_t" -> Some Int64
|
||||
| "double" -> Some Double
|
||||
| "booltensor" | "indextensor" | "tensor" ->
|
||||
Some (if is_nullable then TensorOption else Tensor)
|
||||
| "tensoroptions" -> Some TensorOptions
|
||||
| "intarrayref" | "intlist" -> Some IntList
|
||||
| "tensorlist" -> Some TensorList
|
||||
| "device" -> Some Device
|
||||
| "scalar" -> Some Scalar
|
||||
| "scalartype" -> Some ScalarType
|
||||
| "std::string" -> Some String
|
||||
| _ -> None
|
||||
|
||||
let c_typed_args_list t =
|
||||
List.map t.args ~f:(fun { arg_name; arg_type; _ } ->
|
||||
match arg_type with
|
||||
| IntList ->
|
||||
Printf.sprintf "int64_t *%s_data, int %s_len" arg_name arg_name
|
||||
| TensorList ->
|
||||
Printf.sprintf "tensor *%s_data, int %s_len" arg_name arg_name
|
||||
| TensorOptions ->
|
||||
Printf.sprintf "int %s_kind, int %s_device" arg_name arg_name
|
||||
| String -> Printf.sprintf "char* %s_ptr, int %s_len" arg_name arg_name
|
||||
| otherwise ->
|
||||
let simple_type_cstring =
|
||||
match otherwise with
|
||||
| Bool -> "int"
|
||||
| Int64 -> "int64_t"
|
||||
| Double -> "double"
|
||||
| Tensor -> "tensor"
|
||||
| TensorOption -> "tensor"
|
||||
| ScalarType -> "int"
|
||||
| Device -> "int"
|
||||
| Scalar -> "scalar"
|
||||
| String | IntList | TensorList | TensorOptions -> assert false
|
||||
in
|
||||
Printf.sprintf "%s %s" simple_type_cstring arg_name)
|
||||
|> String.concat ~sep:", "
|
||||
|
||||
let c_args_list args =
|
||||
List.map args ~f:(fun { arg_name; arg_type; _ } ->
|
||||
match arg_type with
|
||||
| Scalar | Tensor -> "*" ^ arg_name
|
||||
| TensorOption ->
|
||||
Printf.sprintf "(%s ? *%s : torch::Tensor())" arg_name arg_name
|
||||
| Bool -> "(bool)" ^ arg_name
|
||||
| IntList ->
|
||||
Printf.sprintf "torch::IntArrayRef(%s_data, %s_len)" arg_name
|
||||
arg_name
|
||||
| String ->
|
||||
Printf.sprintf "std::string(%s_ptr, %s_len)" arg_name arg_name
|
||||
| TensorList ->
|
||||
Printf.sprintf "of_carray_tensor(%s_data, %s_len)" arg_name arg_name
|
||||
| TensorOptions ->
|
||||
Printf.sprintf
|
||||
"at::device(device_of_int(%s_device)).dtype(at::ScalarType(%s_kind))"
|
||||
arg_name arg_name
|
||||
| ScalarType -> Printf.sprintf "at::ScalarType(%s)" arg_name
|
||||
| Device -> Printf.sprintf "device_of_int(%s)" arg_name
|
||||
| _ -> arg_name)
|
||||
|> String.concat ~sep:", "
|
||||
|
||||
let c_call t =
|
||||
match t.kind with
|
||||
| `function_ -> Printf.sprintf "torch::%s(%s)" t.name (c_args_list t.args)
|
||||
| `method_ -> (
|
||||
match t.args with
|
||||
| head :: tail ->
|
||||
Printf.sprintf "%s->%s(%s)" head.arg_name t.name (c_args_list tail)
|
||||
| [] ->
|
||||
Printf.failwithf "Method calls should have at least one argument %s"
|
||||
t.name () )
|
||||
|
||||
let replace_map =
|
||||
Map.of_alist_exn
|
||||
(module String)
|
||||
[
|
||||
("t", "tr");
|
||||
("where", "where_");
|
||||
("view", "view_");
|
||||
("unsafe", "unsafe_");
|
||||
]
|
||||
|
||||
let rust_name name =
|
||||
let name =
|
||||
Map.find replace_map name |> Option.value ~default:name
|
||||
|> String.lowercase
|
||||
|> String.substr_replace_all ~pattern:"__" ~with_:"_"
|
||||
in
|
||||
if String.is_prefix name ~prefix:"_" then "internal" ^ name else name
|
||||
|
||||
let c_rust_args_list t =
|
||||
List.map t.args ~f:(fun arg ->
|
||||
let an = arg.arg_name in
|
||||
let single_param = Printf.sprintf "%s_: %s" an in
|
||||
match arg.arg_type with
|
||||
| Bool -> single_param "c_int"
|
||||
| Int64 -> single_param "i64"
|
||||
| Double -> single_param "f64"
|
||||
| Tensor -> single_param "*mut C_tensor"
|
||||
| TensorOption -> single_param "*mut C_tensor"
|
||||
| Scalar -> single_param "*mut C_scalar"
|
||||
| ScalarType -> single_param "c_int"
|
||||
| Device -> single_param "c_int"
|
||||
| String -> Printf.sprintf "%s_ptr: *const u8, %s_len: c_int" an an
|
||||
| IntList -> Printf.sprintf "%s_data: *const i64, %s_len: c_int" an an
|
||||
| TensorList ->
|
||||
Printf.sprintf "%s_data: *const *mut C_tensor, %s_len: c_int" an an
|
||||
| TensorOptions ->
|
||||
Printf.sprintf "%s_kind: c_int, %s_device: c_int" an an)
|
||||
|> String.concat ~sep:", "
|
||||
|
||||
let self_name = "self"
|
||||
|
||||
let input_name = "input"
|
||||
|
||||
let self_tensor arg =
|
||||
match arg.arg_type with
|
||||
| Tensor -> String.( = ) arg.arg_name self_name
|
||||
| _ -> false
|
||||
|
||||
let input_tensor arg =
|
||||
match arg.arg_type with
|
||||
| Tensor -> String.( = ) arg.arg_name input_name
|
||||
| _ -> false
|
||||
|
||||
let type_parameters t =
|
||||
let needs_scalar_parameter =
|
||||
List.exists t.args ~f:(fun arg ->
|
||||
match arg.arg_type with Scalar -> true | _ -> false)
|
||||
in
|
||||
let needs_type_parameter =
|
||||
List.exists t.args ~f:(fun arg ->
|
||||
match arg.arg_type with
|
||||
| TensorList | TensorOption -> true
|
||||
| _ -> false)
|
||||
in
|
||||
if needs_type_parameter && needs_scalar_parameter then
|
||||
"<T: Borrow<Tensor>, S: Into<Scalar>>"
|
||||
else if needs_type_parameter then "<T: Borrow<Tensor>>"
|
||||
else if needs_scalar_parameter then "<S: Into<Scalar>>"
|
||||
else ""
|
||||
|
||||
let rust_args_list t =
|
||||
match List.partition_tf t.args ~f:self_tensor with
|
||||
| [ self ], args_list -> (Some self, args_list)
|
||||
| _, _ -> (
|
||||
match List.partition_tf t.args ~f:input_tensor with
|
||||
| [ self ], args_list -> (Some self, args_list)
|
||||
| _, _ -> (None, t.args) )
|
||||
|
||||
let rust_typed_args_list t =
|
||||
let to_string args =
|
||||
List.map args ~f:(fun arg ->
|
||||
let rust_arg_type =
|
||||
match arg.arg_type with
|
||||
| Bool -> "bool"
|
||||
| Int64 ->
|
||||
if String.( = ) arg.arg_name "reduction" then "crate::Reduction"
|
||||
else "i64"
|
||||
| Double -> "f64"
|
||||
| Tensor -> "&Tensor"
|
||||
| TensorOption -> "Option<T>"
|
||||
| IntList -> "&[i64]"
|
||||
| TensorList -> "&[T]"
|
||||
| String -> "&str"
|
||||
| TensorOptions -> "(Kind, Device)"
|
||||
| Scalar -> "S"
|
||||
| ScalarType -> "Kind"
|
||||
| Device -> "Device"
|
||||
in
|
||||
Printf.sprintf "%s: %s" (rust_name arg.arg_name) rust_arg_type)
|
||||
|> String.concat ~sep:", "
|
||||
in
|
||||
let self_arg =
|
||||
if String.is_suffix t.name ~suffix:"_" then "&mut self" else "&self"
|
||||
in
|
||||
match List.partition_tf t.args ~f:self_tensor with
|
||||
| [ self ], args_list ->
|
||||
( Some self.arg_name,
|
||||
Printf.sprintf "%s, %s" self_arg (to_string args_list) )
|
||||
| _, _ -> (
|
||||
match List.partition_tf t.args ~f:input_tensor with
|
||||
| [ self ], args_list ->
|
||||
( Some self.arg_name,
|
||||
Printf.sprintf "%s, %s" self_arg (to_string args_list) )
|
||||
| _, _ -> (None, to_string t.args) )
|
||||
|
||||
let rust_return_type t ~fallible =
|
||||
let returns =
|
||||
match t.returns with
|
||||
| `fixed 1 -> "Tensor"
|
||||
| `fixed v ->
|
||||
List.init v ~f:(fun _ -> "Tensor")
|
||||
|> String.concat ~sep:", " |> Printf.sprintf "(%s)"
|
||||
| `dynamic -> "Vec<Tensor>"
|
||||
in
|
||||
if fallible then Printf.sprintf " -> failure::Fallible<%s>" returns
|
||||
else Printf.sprintf " -> %s" returns
|
||||
|
||||
let rust_binding_args t ~self =
|
||||
List.map t.args ~f:(fun arg ->
|
||||
let name =
|
||||
if Option.value_map self ~default:false ~f:(String.( = ) arg.arg_name)
|
||||
then "self"
|
||||
else rust_name arg.arg_name
|
||||
in
|
||||
match arg.arg_type with
|
||||
| Tensor -> Printf.sprintf "%s.c_tensor" name
|
||||
| Scalar -> Printf.sprintf "%s.into().c_scalar" name
|
||||
| Bool -> Printf.sprintf "if %s { 1 } else { 0 }" name
|
||||
| ScalarType -> Printf.sprintf "%s.c_int()" name
|
||||
| Device -> Printf.sprintf "%s.c_int()" name
|
||||
| TensorOptions -> Printf.sprintf "%s.0.c_int(), %s.1.c_int()" name name
|
||||
| String -> Printf.sprintf "%s.as_ptr(), %s.len() as i32" name name
|
||||
| IntList -> Printf.sprintf "%s.as_ptr(), %s.len() as i32" name name
|
||||
| TensorList ->
|
||||
Printf.sprintf "ptr_list(%s).as_ptr(), %s.len() as i32" name name
|
||||
| TensorOption ->
|
||||
Printf.sprintf
|
||||
"%s.map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor)" name
|
||||
| Int64 when String.( = ) name "reduction" -> "reduction.to_int()"
|
||||
| _ -> name)
|
||||
|> String.concat ~sep:",\n "
|
||||
end
|
||||
|
||||
exception Not_a_simple_arg
|
||||
|
||||
let read_yaml filename =
|
||||
let funcs =
|
||||
(* Split the file to avoid Yaml.of_string_exn segfaulting. *)
|
||||
In_channel.with_file filename ~f:In_channel.input_lines
|
||||
|> List.group ~break:(fun _ l ->
|
||||
String.length l > 0 && Char.( = ) l.[0] '-')
|
||||
|> List.concat_map ~f:(fun lines ->
|
||||
Yaml.of_string_exn (String.concat lines ~sep:"\n") |> extract_list)
|
||||
in
|
||||
printf "Read %s, got %d functions.\n%!" filename (List.length funcs);
|
||||
List.filter_map funcs ~f:(fun yaml ->
|
||||
let map = extract_map yaml in
|
||||
let name = Map.find_exn map "name" |> extract_string in
|
||||
let deprecated = Map.find_exn map "deprecated" |> extract_bool in
|
||||
let method_of =
|
||||
Map.find_exn map "method_of"
|
||||
|> extract_list |> List.map ~f:extract_string
|
||||
in
|
||||
let arguments = Map.find_exn map "arguments" |> extract_list in
|
||||
let returns =
|
||||
let is_tensor returns =
|
||||
let returns = extract_map returns in
|
||||
let return_type =
|
||||
Map.find_exn returns "dynamic_type" |> extract_string
|
||||
in
|
||||
String.( = ) return_type "Tensor"
|
||||
|| String.( = ) return_type "BoolTensor"
|
||||
|| String.( = ) return_type "IndexTensor"
|
||||
in
|
||||
let returns = Map.find_exn map "returns" |> extract_list in
|
||||
if List.for_all returns ~f:is_tensor then
|
||||
Some (`fixed (List.length returns))
|
||||
else
|
||||
match returns with
|
||||
| [ returns ] ->
|
||||
let return_type =
|
||||
Map.find_exn (extract_map returns) "dynamic_type"
|
||||
|> extract_string
|
||||
in
|
||||
if String.( = ) return_type "TensorList" then Some `dynamic
|
||||
else None
|
||||
| [] | _ :: _ :: _ -> None
|
||||
in
|
||||
let kind =
|
||||
if List.exists method_of ~f:(String.( = ) "namespace") then
|
||||
Some `function_
|
||||
else if List.exists method_of ~f:(String.( = ) "Tensor") then
|
||||
Some `method_
|
||||
else None
|
||||
in
|
||||
if
|
||||
(not deprecated)
|
||||
&& (not
|
||||
(List.exists excluded_prefixes ~f:(fun prefix ->
|
||||
String.is_prefix name ~prefix)))
|
||||
&& (not
|
||||
(List.exists excluded_suffixes ~f:(fun suffix ->
|
||||
String.is_suffix name ~suffix)))
|
||||
&& not (Set.mem excluded_functions name)
|
||||
then
|
||||
Option.both returns kind
|
||||
|> Option.bind ~f:(fun (returns, kind) ->
|
||||
try
|
||||
let args =
|
||||
List.filter_map arguments ~f:(fun arg ->
|
||||
let arg = extract_map arg in
|
||||
let arg_name =
|
||||
Map.find_exn arg "name" |> extract_string
|
||||
in
|
||||
let arg_type =
|
||||
Map.find_exn arg "dynamic_type" |> extract_string
|
||||
in
|
||||
let is_nullable =
|
||||
Map.find arg "is_nullable"
|
||||
|> Option.value_map ~default:false ~f:extract_bool
|
||||
in
|
||||
let default_value =
|
||||
Map.find arg "default" |> Option.map ~f:extract_string
|
||||
in
|
||||
match Func.arg_type_of_string arg_type ~is_nullable with
|
||||
| Some Scalar
|
||||
when Option.is_some default_value && not is_nullable ->
|
||||
None
|
||||
| Some TensorOptions
|
||||
when Option.is_some default_value
|
||||
&& Set.mem no_tensor_options name ->
|
||||
None
|
||||
| Some arg_type ->
|
||||
let arg_name =
|
||||
match (arg_name, arg_type) with
|
||||
| "self", Scalar -> "self_scalar"
|
||||
| _, _ -> arg_name
|
||||
in
|
||||
Some { Func.arg_name; arg_type; default_value }
|
||||
| None ->
|
||||
if Option.is_some default_value then None
|
||||
else raise Not_a_simple_arg)
|
||||
in
|
||||
Some { Func.name; args; returns; kind }
|
||||
with Not_a_simple_arg -> None)
|
||||
else None)
|
||||
|
||||
let p out_channel s =
|
||||
Printf.ksprintf
|
||||
(fun line ->
|
||||
Out_channel.output_string out_channel line;
|
||||
Out_channel.output_char out_channel '\n')
|
||||
s
|
||||
|
||||
let write_cpp funcs filename =
|
||||
Out_channel.with_file (filename ^ ".cpp.h") ~f:(fun out_cpp ->
|
||||
Out_channel.with_file (filename ^ ".h") ~f:(fun out_h ->
|
||||
let pc s = p out_cpp s in
|
||||
let ph s = p out_h s in
|
||||
pc "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
|
||||
pc "";
|
||||
ph "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
|
||||
ph "";
|
||||
Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
|
||||
let c_typed_args_list = Func.c_typed_args_list func in
|
||||
match func.returns with
|
||||
| `fixed ntensors ->
|
||||
pc "void atg_%s(tensor *out__, %s) {" exported_name
|
||||
c_typed_args_list;
|
||||
pc " PROTECT(";
|
||||
pc " auto outputs__ = %s;" (Func.c_call func);
|
||||
if ntensors = 1 then
|
||||
pc " out__[0] = new torch::Tensor(outputs__);"
|
||||
else
|
||||
for i = 0 to ntensors - 1 do
|
||||
pc
|
||||
" out__[%d] = new \
|
||||
torch::Tensor(std::get<%d>(outputs__));"
|
||||
i i
|
||||
done;
|
||||
pc " )";
|
||||
pc "}";
|
||||
pc "";
|
||||
ph "void atg_%s(tensor *, %s);" exported_name
|
||||
c_typed_args_list
|
||||
| `dynamic ->
|
||||
pc "tensor *atg_%s(%s) {" exported_name c_typed_args_list;
|
||||
pc " PROTECT(";
|
||||
pc " auto outputs__ = %s;" (Func.c_call func);
|
||||
(* the returned type is a C++ vector of tensors *)
|
||||
pc " int sz = outputs__.size();";
|
||||
pc
|
||||
" torch::Tensor **out__ = (torch::Tensor**)malloc((sz + \
|
||||
1) * sizeof(torch::Tensor*));";
|
||||
pc " for (int i = 0; i < sz; ++i)";
|
||||
pc " out__[i] = new torch::Tensor(outputs__[i]);";
|
||||
pc " out__[sz] = nullptr;";
|
||||
pc " return out__;";
|
||||
pc " )";
|
||||
pc " return nullptr;";
|
||||
pc "}";
|
||||
pc "";
|
||||
ph "tensor *atg_%s(%s);" exported_name c_typed_args_list)))
|
||||
|
||||
let write_fallible_wrapper funcs filename =
|
||||
Out_channel.with_file filename ~f:(fun out_ml ->
|
||||
let pm s = p out_ml s in
|
||||
pm "/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND! */";
|
||||
pm "#[allow(clippy::all)]";
|
||||
pm "use torch_sys::*;";
|
||||
pm "use torch_sys::c_generated::*;";
|
||||
pm "use crate::{Device, Kind, Scalar, Tensor};";
|
||||
pm "use std::convert::Into;";
|
||||
pm "use std::borrow::Borrow;";
|
||||
pm "";
|
||||
pm "fn ptr_list<T: Borrow<Tensor>>(l: &[T]) -> Vec<*mut C_tensor> {";
|
||||
pm " l.iter().map(|x| x.borrow().c_tensor).collect()";
|
||||
pm "}";
|
||||
pm "";
|
||||
pm "impl Tensor {";
|
||||
Map.iteri funcs ~f:(fun ~key:exported_name ~data:(func : Func.t) ->
|
||||
let rust_name = Func.rust_name exported_name in
|
||||
let self, rust_args_list = Func.rust_typed_args_list func in
|
||||
pm "";
|
||||
pm " pub fn f_%s%s(" rust_name (Func.type_parameters func);
|
||||
pm " %s" rust_args_list;
|
||||
pm " )%s {" (Func.rust_return_type func ~fallible:true);
|
||||
match func.returns with
|
||||
| `dynamic ->
|
||||
pm " let c_tensors = unsafe_torch_err!({";
|
||||
pm " atg_%s(" exported_name;
|
||||
pm " %s)});" (Func.rust_binding_args func ~self);
|
||||
pm " let mut r__ = vec![];";
|
||||
pm " let mut i = 0;";
|
||||
pm " loop {";
|
||||
pm " let c__ = unsafe{*c_tensors.add(i)};";
|
||||
pm " if c__.is_null() { break }";
|
||||
pm " r__.push(Tensor {c_tensor: c__});";
|
||||
pm " i += 1;";
|
||||
pm " }";
|
||||
pm " unsafe{libc::free(c_tensors as *mut libc::c_void)}";
|
||||
pm " Ok(r__)";
|
||||
pm " }"
|
||||
| `fixed ntensors ->
|
||||
pm " let mut c_tensors = [std::ptr::null_mut(); %d];"
|
||||
ntensors;
|
||||
pm " unsafe_torch_err!({";
|
||||
pm " atg_%s(c_tensors.as_mut_ptr()," exported_name;
|
||||
pm " %s" (Func.rust_binding_args func ~self);
|
||||
pm " ) });";
|
||||
let returns =
|
||||
if ntensors = 1 then "Tensor { c_tensor: c_tensors[0] }"
|
||||
else
|
||||
List.init ntensors
|
||||
~f:(Printf.sprintf "Tensor { c_tensor: c_tensors[%d] }")
|
||||
|> String.concat ~sep:", " |> Printf.sprintf "(%s)"
|
||||
in
|
||||
pm " Ok(%s)" returns;
|
||||
pm " }");
|
||||
pm "}")
|
||||
|
||||
let write_wrapper funcs filename =
|
||||
Out_channel.with_file filename ~f:(fun out_ml ->
|
||||
let pm s = p out_ml s in
|
||||
pm "/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND! */";
|
||||
pm "#[allow(clippy::all)]";
|
||||
pm "use crate::{Device, Kind, Scalar, Tensor};";
|
||||
pm "use std::convert::Into;";
|
||||
pm "use std::borrow::Borrow;";
|
||||
pm "";
|
||||
pm "impl Tensor {";
|
||||
Map.iteri funcs ~f:(fun ~key:exported_name ~data:(func : Func.t) ->
|
||||
let rust_name = Func.rust_name exported_name in
|
||||
let rust_name, fallible_rust_name =
|
||||
if Set.mem prefixed_functions func.name then
|
||||
("g_" ^ rust_name, "f_" ^ rust_name)
|
||||
else (rust_name, "f_" ^ rust_name)
|
||||
in
|
||||
pm "";
|
||||
pm " pub fn %s%s(" rust_name (Func.type_parameters func);
|
||||
let _self, rust_args_list = Func.rust_typed_args_list func in
|
||||
pm " %s" rust_args_list;
|
||||
pm " )%s {" (Func.rust_return_type func ~fallible:false);
|
||||
let self, rust_args_list = Func.rust_args_list func in
|
||||
let self = if Option.is_some self then "self." else "Tensor::" in
|
||||
let rust_args_list =
|
||||
List.map rust_args_list ~f:(fun arg ->
|
||||
Func.rust_name arg.Func.arg_name)
|
||||
|> String.concat ~sep:", "
|
||||
in
|
||||
pm " %s%s(%s).unwrap()" self fallible_rust_name rust_args_list;
|
||||
pm " }");
|
||||
pm "}")
|
||||
|
||||
let write_ffi funcs filename =
|
||||
Out_channel.with_file filename ~f:(fun out_ml ->
|
||||
let pm s = p out_ml s in
|
||||
pm "/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND! */";
|
||||
pm "#[allow(clippy::all)]";
|
||||
pm "use crate::{C_scalar, C_tensor};";
|
||||
pm "use libc::c_int;";
|
||||
pm "";
|
||||
pm "extern \"C\" {";
|
||||
Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
|
||||
match func.Func.returns with
|
||||
| `fixed _ ->
|
||||
pm " pub fn atg_%s(out__: *mut *mut C_tensor, %s);"
|
||||
exported_name
|
||||
(Func.c_rust_args_list func)
|
||||
| `dynamic ->
|
||||
pm " pub fn atg_%s(%s) -> *mut *mut C_tensor;" exported_name
|
||||
(Func.c_rust_args_list func));
|
||||
pm "}")
|
||||
|
||||
let methods =
|
||||
let c name args = { Func.name; args; returns = `fixed 1; kind = `method_ } in
|
||||
let ca arg_name arg_type =
|
||||
{ Func.arg_name; arg_type; default_value = None }
|
||||
in
|
||||
[
|
||||
c "grad" [ ca "self" Tensor ];
|
||||
c "set_requires_grad" [ ca "self" Tensor; ca "r" Bool ];
|
||||
c "toType" [ ca "self" Tensor; ca "scalar_type" ScalarType ];
|
||||
c "to" [ ca "self" Tensor; ca "device" Device ];
|
||||
]
|
||||
|
||||
let run ~yaml_filename ~cpp_filename ~ffi_filename ~wrapper_filename
|
||||
~fallible_wrapper_filename =
|
||||
let funcs = read_yaml yaml_filename in
|
||||
let funcs = methods @ funcs in
|
||||
printf "Generating code for %d functions.\n%!" (List.length funcs);
|
||||
(* Generate some unique names for overloaded functions. *)
|
||||
let funcs =
|
||||
List.map funcs ~f:(fun func -> (String.lowercase func.name, func))
|
||||
|> Map.of_alist_multi (module String)
|
||||
|> Map.to_alist
|
||||
|> List.concat_map ~f:(fun (name, funcs) ->
|
||||
match funcs with
|
||||
| [] -> assert false
|
||||
| [ func ] -> [ (name, func) ]
|
||||
| funcs ->
|
||||
List.sort funcs ~compare:(fun (f1 : Func.t) (f2 : Func.t) ->
|
||||
Int.compare (List.length f1.args) (List.length f2.args))
|
||||
|> List.mapi ~f:(fun i func ->
|
||||
( (if i = 0 then name else Printf.sprintf "%s%d" name i),
|
||||
func )))
|
||||
|> Map.of_alist_exn (module String)
|
||||
in
|
||||
write_cpp funcs cpp_filename;
|
||||
write_ffi funcs ffi_filename;
|
||||
write_wrapper funcs wrapper_filename;
|
||||
write_fallible_wrapper funcs fallible_wrapper_filename
|
||||
|
||||
let () =
|
||||
run ~yaml_filename:"third_party/pytorch/Declarations-v1.5.0.yaml"
|
||||
~cpp_filename:"torch-sys/libtch/torch_api_generated"
|
||||
~ffi_filename:"torch-sys/src/c_generated.rs"
|
||||
~wrapper_filename:"src/wrappers/tensor_generated.rs"
|
||||
~fallible_wrapper_filename:"src/wrappers/tensor_fallible_generated.rs"
|
1
gen/gen.mli
Normal file
1
gen/gen.mli
Normal file
|
@ -0,0 +1 @@
|
|||
(* Intentionally left blank. *)
|
60108
third_party/pytorch/Declarations-v1.4.0.yaml
vendored
Normal file
60108
third_party/pytorch/Declarations-v1.4.0.yaml
vendored
Normal file
File diff suppressed because it is too large
Load Diff
63549
third_party/pytorch/Declarations-v1.5.0.yaml
vendored
Normal file
63549
third_party/pytorch/Declarations-v1.5.0.yaml
vendored
Normal file
File diff suppressed because it is too large
Load Diff
70
third_party/pytorch/LICENSE
vendored
Normal file
70
third_party/pytorch/LICENSE
vendored
Normal file
|
@ -0,0 +1,70 @@
|
|||
From PyTorch:
|
||||
|
||||
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
||||
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
||||
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
||||
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
||||
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
||||
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
||||
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
||||
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
||||
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
||||
|
||||
From Caffe2:
|
||||
|
||||
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
|
||||
|
||||
All contributions by Facebook:
|
||||
Copyright (c) 2016 Facebook Inc.
|
||||
|
||||
All contributions by Google:
|
||||
Copyright (c) 2015 Google Inc.
|
||||
All rights reserved.
|
||||
|
||||
All contributions by Yangqing Jia:
|
||||
Copyright (c) 2015 Yangqing Jia
|
||||
All rights reserved.
|
||||
|
||||
All contributions from Caffe:
|
||||
Copyright(c) 2013, 2014, 2015, the respective contributors
|
||||
All rights reserved.
|
||||
|
||||
All other contributions:
|
||||
Copyright(c) 2015, 2016 the respective contributors
|
||||
All rights reserved.
|
||||
|
||||
Caffe2 uses a copyright model similar to Caffe: each contributor holds
|
||||
copyright over their contributions to Caffe2. The project versioning records
|
||||
all such contribution and copyright details. If a contributor wants to further
|
||||
mark their specific copyright on a particular contribution, they should
|
||||
indicate their copyright solely in the commit message of the change when it is
|
||||
committed.
|
||||
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright
|
||||
notice, this list of conditions and the following disclaimer in the
|
||||
documentation and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
|
||||
and IDIAP Research Institute nor the names of its contributors may be
|
||||
used to endorse or promote products derived from this software without
|
||||
specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
||||
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
POSSIBILITY OF SUCH DAMAGE.
|
3
third_party/pytorch/README
vendored
Normal file
3
third_party/pytorch/README
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
The Declarations-B.yaml files included in this directory have been obtained
|
||||
by compiling PyTorch from source using the code in branch B from the official
|
||||
GitHub repo: https://github.com/pytorch/pytorch
|
12
torch-sys/libtch/dummy_cuda_dependency.cpp
Normal file
12
torch-sys/libtch/dummy_cuda_dependency.cpp
Normal file
|
@ -0,0 +1,12 @@
|
|||
extern "C" {
|
||||
void dummy_cuda_dependency();
|
||||
}
|
||||
|
||||
namespace at {
|
||||
namespace cuda {
|
||||
int warp_size();
|
||||
}
|
||||
}
|
||||
void dummy_cuda_dependency() {
|
||||
at::cuda::warp_size();
|
||||
}
|
6
torch-sys/libtch/fake_cuda_dependency.cpp
Normal file
6
torch-sys/libtch/fake_cuda_dependency.cpp
Normal file
|
@ -0,0 +1,6 @@
|
|||
extern "C" {
|
||||
void dummy_cuda_dependency();
|
||||
}
|
||||
|
||||
void dummy_cuda_dependency() {
|
||||
}
|
7530
torch-sys/libtch/stb_image.h
Normal file
7530
torch-sys/libtch/stb_image.h
Normal file
File diff suppressed because it is too large
Load Diff
2627
torch-sys/libtch/stb_image_resize.h
Normal file
2627
torch-sys/libtch/stb_image_resize.h
Normal file
File diff suppressed because it is too large
Load Diff
1621
torch-sys/libtch/stb_image_write.h
Normal file
1621
torch-sys/libtch/stb_image_write.h
Normal file
File diff suppressed because it is too large
Load Diff
894
torch-sys/libtch/torch_api.cpp
Normal file
894
torch-sys/libtch/torch_api.cpp
Normal file
|
@ -0,0 +1,894 @@
|
|||
#include<torch/csrc/autograd/engine.h>
|
||||
#include<torch/torch.h>
|
||||
#include<torch/script.h>
|
||||
#include<stdexcept>
|
||||
#include<vector>
|
||||
#include "torch_api.h"
|
||||
|
||||
#define STB_IMAGE_IMPLEMENTATION
|
||||
#include "stb_image.h"
|
||||
|
||||
#define STB_IMAGE_WRITE_IMPLEMENTATION
|
||||
#include "stb_image_write.h"
|
||||
|
||||
#define STB_IMAGE_RESIZE_IMPLEMENTATION
|
||||
#include "stb_image_resize.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
char *get_and_reset_last_err() {
|
||||
char *tmp = torch_last_err;
|
||||
torch_last_err = nullptr;
|
||||
return tmp;
|
||||
}
|
||||
|
||||
void at_manual_seed(int64_t seed) {
|
||||
torch::manual_seed(seed);
|
||||
}
|
||||
|
||||
vector<torch::Tensor> of_carray_tensor(torch::Tensor **vs, int len) {
|
||||
vector<torch::Tensor> result;
|
||||
for (int i = 0; i < len; ++i) result.push_back(*(vs[i]));
|
||||
return result;
|
||||
}
|
||||
|
||||
at::Device device_of_int(int d) {
|
||||
if (d < 0) return at::Device(at::kCPU);
|
||||
return at::Device(at::kCUDA, /*index=*/d);
|
||||
}
|
||||
tensor at_new_tensor() {
|
||||
PROTECT(
|
||||
return new torch::Tensor();
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims, size_t element_size_in_bytes, int type) {
|
||||
PROTECT(
|
||||
torch::Tensor tensor = torch::zeros(torch::IntArrayRef(dims, ndims), torch::ScalarType(type));
|
||||
if (element_size_in_bytes != tensor.element_size())
|
||||
throw std::invalid_argument("incoherent element sizes in bytes");
|
||||
void *tensor_data = tensor.data_ptr();
|
||||
memcpy(tensor_data, vs, tensor.numel() * element_size_in_bytes);
|
||||
return new torch::Tensor(tensor);
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void at_copy_data(tensor tensor, void *vs, size_t numel, size_t elt_size_in_bytes) {
|
||||
PROTECT(
|
||||
if (elt_size_in_bytes != tensor->element_size())
|
||||
throw std::invalid_argument("incoherent element sizes in bytes");
|
||||
if (numel > tensor->numel())
|
||||
throw std::invalid_argument("target numel is larger than tensor numel");
|
||||
if (tensor->device().type() != at::kCPU) {
|
||||
torch::Tensor tmp_tensor = tensor->to(at::kCPU).contiguous();
|
||||
void *tensor_data = tmp_tensor.data_ptr();
|
||||
memcpy(vs, tensor_data, numel * elt_size_in_bytes);
|
||||
}
|
||||
else {
|
||||
auto tmp_tensor = tensor->contiguous();
|
||||
void *tensor_data = tmp_tensor.data_ptr();
|
||||
memcpy(vs, tensor_data, numel * elt_size_in_bytes);
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
tensor at_shallow_clone(tensor t) {
|
||||
PROTECT(return new torch::Tensor(*t);)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void *at_data_ptr(tensor t) {
|
||||
PROTECT(return t->data_ptr();)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
||||
int at_defined(tensor t) {
|
||||
PROTECT(return t->defined();)
|
||||
return -1;
|
||||
}
|
||||
|
||||
int at_is_sparse(tensor t) {
|
||||
PROTECT(return t->is_sparse();)
|
||||
return -1;
|
||||
}
|
||||
|
||||
size_t at_dim(tensor t) {
|
||||
PROTECT(return t->dim();)
|
||||
return -1;
|
||||
}
|
||||
|
||||
void at_shape(tensor t, int64_t *dims) {
|
||||
PROTECT(
|
||||
int i = 0;
|
||||
for (int64_t dim : t->sizes()) dims[i++] = dim;
|
||||
)
|
||||
}
|
||||
|
||||
int at_scalar_type(tensor t) {
|
||||
PROTECT(
|
||||
return static_cast<int>(t->scalar_type());
|
||||
)
|
||||
return -1;
|
||||
}
|
||||
|
||||
int at_device(tensor t) {
|
||||
PROTECT(
|
||||
auto device = t->device();
|
||||
if (device.type() == at::kCPU) return -1;
|
||||
if (device.type() == at::kCUDA) return device.index();
|
||||
)
|
||||
return -2;
|
||||
}
|
||||
|
||||
void at_backward(tensor t, int keep_graph, int create_graph) {
|
||||
PROTECT(t->backward({}, keep_graph, create_graph);)
|
||||
}
|
||||
|
||||
int at_requires_grad(tensor t) {
|
||||
PROTECT(return t->requires_grad();)
|
||||
return -1;
|
||||
}
|
||||
|
||||
int at_grad_set_enabled(int b) {
|
||||
PROTECT(
|
||||
bool is_enabled = torch::autograd::GradMode::is_enabled();
|
||||
torch::autograd::GradMode::set_enabled(b);
|
||||
return is_enabled;
|
||||
)
|
||||
return -1;
|
||||
}
|
||||
|
||||
tensor at_get(tensor t, int index) {
|
||||
PROTECT(return new torch::Tensor((*t)[index]);)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T at_value_at_indexes(tensor t, int64_t *indexes, int indexes_len) {
|
||||
PROTECT(
|
||||
torch::Tensor tensor = *t;
|
||||
for (int i = 0; i < indexes_len; ++i) {
|
||||
tensor = tensor[indexes[i]];
|
||||
}
|
||||
return tensor.item<T>();
|
||||
)
|
||||
return T();
|
||||
}
|
||||
|
||||
double at_double_value_at_indexes(tensor t, int64_t *indexes, int indexes_len) {
|
||||
return at_value_at_indexes<double>(t, indexes, indexes_len);
|
||||
}
|
||||
|
||||
int64_t at_int64_value_at_indexes(tensor t, int64_t *indexes, int indexes_len) {
|
||||
return at_value_at_indexes<int64_t>(t, indexes, indexes_len);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void at_set_value_at_indexes(tensor t, int *indexes, int indexes_len, T v) {
|
||||
PROTECT(
|
||||
torch::Tensor tensor = *t;
|
||||
for (int i = 0; i < indexes_len; ++i) {
|
||||
tensor = tensor[indexes[i]];
|
||||
}
|
||||
tensor.fill_(v);
|
||||
)
|
||||
}
|
||||
|
||||
void at_set_double_value_at_indexes(tensor t, int *indexes, int indexes_len, double v) {
|
||||
at_set_value_at_indexes<double>(t, indexes, indexes_len, v);
|
||||
}
|
||||
|
||||
void at_set_int64_value_at_indexes(tensor t, int *indexes, int indexes_len, int64_t v) {
|
||||
at_set_value_at_indexes<int64_t>(t, indexes, indexes_len, v);
|
||||
}
|
||||
|
||||
void at_fill_double(tensor t, double v) {
|
||||
PROTECT(t->fill_(v);)
|
||||
}
|
||||
|
||||
void at_fill_int64(tensor t, int64_t v) {
|
||||
PROTECT(t->fill_(v);)
|
||||
}
|
||||
|
||||
void at_print(tensor t) {
|
||||
PROTECT(
|
||||
torch::Tensor *tensor = (torch::Tensor*)t;
|
||||
cout << *tensor << endl;
|
||||
)
|
||||
}
|
||||
|
||||
char *at_to_string(tensor t, int line_size) {
|
||||
PROTECT(
|
||||
std::ostringstream oss;
|
||||
torch::print(oss, *t, line_size);
|
||||
return strdup(oss.str().c_str());
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void at_copy_(tensor dst, tensor src) {
|
||||
PROTECT(
|
||||
dst->copy_(*src);
|
||||
)
|
||||
}
|
||||
|
||||
void at_save(tensor t, char *filename) {
|
||||
PROTECT(torch::save(*t, filename);)
|
||||
}
|
||||
|
||||
void at_save_multi(tensor *tensors, char **tensor_names, int ntensors, char *filename) {
|
||||
PROTECT(
|
||||
torch::serialize::OutputArchive archive;
|
||||
for (int i = 0; i < ntensors; ++i)
|
||||
archive.write(std::string(tensor_names[i]), *(tensors[i]), /* buffer=*/ false);
|
||||
archive.save_to(filename);
|
||||
)
|
||||
}
|
||||
|
||||
void at_load_multi(tensor *tensors, char **tensor_names, int ntensors, char *filename) {
|
||||
PROTECT(
|
||||
torch::serialize::InputArchive archive;
|
||||
archive.load_from(std::string(filename));
|
||||
vector<torch::Tensor> ts(ntensors);
|
||||
for (int i = 0; i < ntensors; ++i)
|
||||
archive.read(std::string(tensor_names[i]), ts[i]);
|
||||
// Only allocate the new tensor now so that if there is an exception raised during
|
||||
// [read], no memory has to be freed.
|
||||
for (int i = 0; i < ntensors; ++i)
|
||||
tensors[i] = new torch::Tensor(ts[i]);
|
||||
)
|
||||
}
|
||||
|
||||
void at_load_callback(char *filename, void *data, void (*f)(void *, char *, tensor)) {
|
||||
PROTECT(
|
||||
auto module = torch::jit::load(filename);
|
||||
for (const auto &p : module.named_parameters()) {
|
||||
auto v = p.value;
|
||||
f(data, (char*)p.name.c_str(), new torch::Tensor(v));
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
void at_load_callback_with_device(char *filename, void *data, void (*f)(void *, char *, tensor), int device_id) {
|
||||
PROTECT(
|
||||
auto module = torch::jit::load(filename, device_of_int(device_id));
|
||||
for (const auto &p : module.named_parameters()) {
|
||||
auto v = p.value;
|
||||
f(data, (char*)p.name.c_str(), new torch::Tensor(v));
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
void at_load_multi_(tensor *tensors, char **tensor_names, int ntensors, char *filename) {
|
||||
PROTECT(
|
||||
torch::NoGradGuard no_grad;
|
||||
torch::serialize::InputArchive archive;
|
||||
archive.load_from(std::string(filename));
|
||||
for (int i = 0; i < ntensors; ++i) {
|
||||
if (tensors[i]->device().type() == at::kCPU)
|
||||
archive.read(std::string(tensor_names[i]), *(tensors[i]));
|
||||
else {
|
||||
torch::Tensor tmp_tensor = torch::empty_like(*(tensors[i]), at::device(at::kCPU));
|
||||
archive.read(std::string(tensor_names[i]), tmp_tensor);
|
||||
tensors[i]->copy_(tmp_tensor);
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
tensor at_load(char *filename) {
|
||||
PROTECT(
|
||||
torch::Tensor tensor;
|
||||
torch::load(tensor, filename);
|
||||
return new torch::Tensor(tensor);
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensor at_load_image(char *filename) {
|
||||
PROTECT(
|
||||
int w = -1;
|
||||
int h = -1;
|
||||
int c = -1;
|
||||
void *data = stbi_load(filename, &w, &h, &c, 3);
|
||||
if (data == nullptr)
|
||||
throw std::invalid_argument(stbi_failure_reason());
|
||||
torch::Tensor tensor = torch::zeros({ h, w, 3 }, at::ScalarType::Byte);
|
||||
memcpy(tensor.data_ptr(), data, h * w * 3);
|
||||
free(data);
|
||||
return new torch::Tensor(tensor);
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool ends_with(const char *str, const char *suffix) {
|
||||
int suffix_len = strlen(suffix);
|
||||
int str_len = strlen(str);
|
||||
if (str_len < suffix_len) return false;
|
||||
for (int i = 1; i <= suffix_len; ++i)
|
||||
if (str[str_len-i] != suffix[suffix_len-i]) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
int at_save_image(tensor tensor, char *filename) {
|
||||
PROTECT(
|
||||
auto sizes = tensor->sizes();
|
||||
if (sizes.size() != 3)
|
||||
throw std::invalid_argument("invalid number of dimensions, should be 3");
|
||||
int h = sizes[0];
|
||||
int w = sizes[1];
|
||||
int c = sizes[2];
|
||||
auto tmp_tensor = tensor->contiguous();
|
||||
void *tensor_data = tmp_tensor.data_ptr();
|
||||
if (ends_with(filename, ".jpg"))
|
||||
return stbi_write_jpg(filename, w, h, c, tensor_data, 90);
|
||||
if (ends_with(filename, ".bmp"))
|
||||
return stbi_write_bmp(filename, w, h, c, tensor_data);
|
||||
if (ends_with(filename, ".tga"))
|
||||
return stbi_write_tga(filename, w, h, c, tensor_data);
|
||||
return stbi_write_png(filename, w, h, c, tensor_data, 0);
|
||||
)
|
||||
return -1;
|
||||
}
|
||||
|
||||
int at_get_num_interop_threads() {
|
||||
PROTECT(return at::get_num_interop_threads();)
|
||||
return -1;
|
||||
}
|
||||
|
||||
int at_get_num_threads() {
|
||||
PROTECT(return at::get_num_threads();)
|
||||
return -1;
|
||||
}
|
||||
|
||||
void at_set_num_interop_threads(int n_threads) {
|
||||
PROTECT(at::set_num_interop_threads(n_threads);)
|
||||
}
|
||||
|
||||
void at_set_num_threads(int n_threads) {
|
||||
PROTECT(at::set_num_threads(n_threads);)
|
||||
}
|
||||
|
||||
tensor at_resize_image(tensor tensor, int out_w, int out_h) {
|
||||
PROTECT(
|
||||
auto sizes = tensor->sizes();
|
||||
if (sizes.size() != 3)
|
||||
throw std::invalid_argument("invalid number of dimensions, should be 3");
|
||||
int h = sizes[0];
|
||||
int w = sizes[1];
|
||||
int c = sizes[2];
|
||||
auto tmp_tensor = tensor->contiguous();
|
||||
const unsigned char *tensor_data = (unsigned char*)tmp_tensor.data_ptr();
|
||||
torch::Tensor out = torch::zeros({ out_h, out_w, c }, at::ScalarType::Byte);
|
||||
stbir_resize_uint8(tensor_data, w, h, 0, (unsigned char*)out.data_ptr(), out_w, out_h, 0, c);
|
||||
return new torch::Tensor(out);
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void at_free(tensor t) {
|
||||
delete(t);
|
||||
}
|
||||
|
||||
void at_run_backward(tensor *tensors,
|
||||
int ntensors,
|
||||
tensor *inputs,
|
||||
int ninputs,
|
||||
tensor *outputs,
|
||||
int keep_graph,
|
||||
int create_graph) {
|
||||
PROTECT(
|
||||
vector<torch::autograd::Edge> roots;
|
||||
for (int i = 0; i < ntensors; ++i)
|
||||
roots.push_back(torch::autograd::impl::gradient_edge(*tensors[i]));
|
||||
|
||||
vector<torch::autograd::Edge> inputs_;
|
||||
for (int i = 0; i < ninputs; ++i) {
|
||||
if (!inputs[i]->requires_grad())
|
||||
throw std::invalid_argument("one of the input tensor does not use set_requires_grad");
|
||||
inputs_.push_back(torch::autograd::impl::gradient_edge(*inputs[i]));
|
||||
}
|
||||
|
||||
vector<torch::autograd::Variable> grads;
|
||||
for (int i = 0; i < ntensors; ++i)
|
||||
grads.push_back(torch::ones_like(*tensors[i]));
|
||||
|
||||
auto vl = torch::autograd::Engine::get_default_engine().execute(roots, grads, keep_graph, create_graph, inputs_);
|
||||
for (int i = 0; i < ninputs; ++i) {
|
||||
outputs[i] = static_cast<tensor>(new torch::autograd::Variable(vl[i]));
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
optimizer ato_adam(double learning_rate,
|
||||
double beta1,
|
||||
double beta2,
|
||||
double weight_decay) {
|
||||
PROTECT(
|
||||
auto options =
|
||||
torch::optim::AdamOptions(learning_rate)
|
||||
.betas(std::tuple<double, double>(beta1, beta2))
|
||||
.weight_decay(weight_decay);
|
||||
return new torch::optim::Adam(vector<torch::Tensor>(), options);
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
optimizer ato_rms_prop(double learning_rate,
|
||||
double alpha,
|
||||
double eps,
|
||||
double weight_decay,
|
||||
double momentum,
|
||||
int centered) {
|
||||
PROTECT(
|
||||
auto options =
|
||||
torch::optim::RMSpropOptions(learning_rate)
|
||||
.alpha(alpha)
|
||||
.eps(eps)
|
||||
.weight_decay(weight_decay)
|
||||
.momentum(momentum)
|
||||
.centered(centered != 0);
|
||||
return new torch::optim::RMSprop(vector<torch::Tensor>(), options);
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
optimizer ato_sgd(double learning_rate,
|
||||
double momentum,
|
||||
double dampening,
|
||||
double weight_decay,
|
||||
int nesterov) {
|
||||
PROTECT(
|
||||
auto options =
|
||||
torch::optim::SGDOptions(learning_rate)
|
||||
.momentum(momentum)
|
||||
.dampening(dampening)
|
||||
.weight_decay(weight_decay)
|
||||
.nesterov(nesterov);
|
||||
return new torch::optim::SGD(vector<torch::Tensor>(), options);
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void ato_add_parameters(optimizer t, tensor *tensors, int ntensors) {
|
||||
PROTECT(
|
||||
for (int i = 0; i < ntensors; ++i)
|
||||
t->param_groups()[0].params().push_back(*(tensors[i]));
|
||||
)
|
||||
}
|
||||
|
||||
void ato_set_learning_rate(optimizer t, double learning_rate) {
|
||||
PROTECT(
|
||||
torch::optim::OptimizerOptions* d = &(t->defaults());
|
||||
if (auto adam = dynamic_cast<torch::optim::AdamOptions*>(d))
|
||||
adam->lr(learning_rate);
|
||||
else if (auto rms = dynamic_cast<torch::optim::RMSpropOptions*>(d))
|
||||
rms->lr(learning_rate);
|
||||
else if (auto sgd = dynamic_cast<torch::optim::SGDOptions*>(d))
|
||||
sgd->lr(learning_rate);
|
||||
else
|
||||
throw std::invalid_argument("unexpected optimizer");
|
||||
)
|
||||
}
|
||||
|
||||
void ato_set_momentum(optimizer t, double momentum) {
|
||||
PROTECT(
|
||||
torch::optim::OptimizerOptions* d = &(t->defaults());
|
||||
if (auto adam = dynamic_cast<torch::optim::AdamOptions*>(d)) {
|
||||
auto betas = adam->betas();
|
||||
adam->betas(std::tuple<double, double>(momentum, get<1>(betas)));
|
||||
}
|
||||
else if (auto rms = dynamic_cast<torch::optim::RMSpropOptions*>(d))
|
||||
rms->momentum(momentum);
|
||||
else if (auto sgd = dynamic_cast<torch::optim::SGDOptions*>(d))
|
||||
sgd->momentum(momentum);
|
||||
else
|
||||
throw std::invalid_argument("unexpected optimizer");
|
||||
)
|
||||
}
|
||||
|
||||
void ato_zero_grad(optimizer t) {
|
||||
PROTECT(t->zero_grad();)
|
||||
}
|
||||
|
||||
void ato_step(optimizer t) {
|
||||
PROTECT(t->step();)
|
||||
}
|
||||
|
||||
void ato_free(optimizer t) {
|
||||
delete(t);
|
||||
}
|
||||
|
||||
scalar ats_int(int64_t v) {
|
||||
PROTECT(return new torch::Scalar(v);)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
scalar ats_float(double v) {
|
||||
PROTECT(return new torch::Scalar(v);)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int64_t ats_to_int(scalar s) {
|
||||
PROTECT(return s->toLong();)
|
||||
return -1;
|
||||
}
|
||||
|
||||
double ats_to_float(scalar s) {
|
||||
PROTECT(return s->toDouble();)
|
||||
return 0.;
|
||||
}
|
||||
|
||||
char *ats_to_string(scalar s) {
|
||||
PROTECT(
|
||||
using namespace at;
|
||||
std::ostringstream oss;
|
||||
oss << (*s);
|
||||
return strdup(oss.str().c_str());
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void ats_free(scalar s) {
|
||||
delete(s);
|
||||
}
|
||||
|
||||
int atc_cuda_device_count() {
|
||||
PROTECT(return torch::cuda::device_count();)
|
||||
return -1;
|
||||
}
|
||||
|
||||
int atc_cuda_is_available() {
|
||||
PROTECT(return torch::cuda::is_available();)
|
||||
return -1;
|
||||
}
|
||||
|
||||
int atc_cudnn_is_available() {
|
||||
PROTECT(return torch::cuda::cudnn_is_available();)
|
||||
return -1;
|
||||
}
|
||||
|
||||
void atc_set_benchmark_cudnn(int b) {
|
||||
at::globalContext().setBenchmarkCuDNN(b);
|
||||
}
|
||||
|
||||
module atm_load(char *filename) {
|
||||
PROTECT(
|
||||
return new torch::jit::script::Module(torch::jit::load(filename));
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
module atm_load_str(char *data, size_t sz) {
|
||||
PROTECT(
|
||||
std::istringstream stream(std::string(data, sz));
|
||||
return new torch::jit::script::Module(torch::jit::load(stream));
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensor atm_forward(module m, tensor *tensors, int ntensors) {
|
||||
PROTECT(
|
||||
std::vector<torch::jit::IValue> inputs;
|
||||
for (int i = 0; i < ntensors; ++i)
|
||||
inputs.push_back(*(tensors[i]));
|
||||
torch::jit::IValue output = m->forward(inputs);
|
||||
if (!output.isTensor())
|
||||
throw std::invalid_argument("forward did not return a tensor");
|
||||
return new torch::Tensor(output.toTensor());
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ivalue atm_forward_(module m,
|
||||
ivalue *ivalues,
|
||||
int nivalues) {
|
||||
PROTECT(
|
||||
std::vector<torch::jit::IValue> inputs;
|
||||
for (int i = 0; i < nivalues; ++i)
|
||||
inputs.push_back(*(ivalues[i]));
|
||||
torch::jit::IValue output = m->forward(inputs);
|
||||
return new torch::jit::IValue(output);
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void atm_free(module m) {
|
||||
delete(m);
|
||||
}
|
||||
|
||||
void atm_to(module m, int device, int dtype, bool non_blocking) {
|
||||
PROTECT(
|
||||
m->to(device_of_int(device), at::ScalarType(dtype), non_blocking);
|
||||
)
|
||||
}
|
||||
|
||||
ivalue ati_tensor(tensor t) {
|
||||
PROTECT(
|
||||
return new torch::jit::IValue(*t);
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ivalue ati_int(int64_t i) {
|
||||
PROTECT(
|
||||
return new torch::jit::IValue(i);
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ivalue ati_double(double d) {
|
||||
PROTECT(
|
||||
return new torch::jit::IValue(d);
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ivalue ati_bool(int i) {
|
||||
PROTECT(
|
||||
return new torch::jit::IValue((bool)i);
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ivalue ati_string(char *s) {
|
||||
PROTECT(
|
||||
string str(s);
|
||||
return new torch::jit::IValue(str);
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ivalue ati_none() {
|
||||
PROTECT(
|
||||
return new torch::jit::IValue();
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ivalue ati_tuple(ivalue *is, int nvalues) {
|
||||
PROTECT(
|
||||
vector<torch::jit::IValue> vec;
|
||||
for (int i = 0; i < nvalues; ++i) vec.push_back(*(is[i]));
|
||||
return new torch::jit::IValue(torch::ivalue::Tuple::create(vec));
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ivalue ati_generic_list(ivalue *is, int nvalues) {
|
||||
PROTECT(
|
||||
c10::List<torch::jit::IValue> vec(c10::AnyType::get());
|
||||
for (int i = 0; i < nvalues; ++i) vec.push_back(*(is[i]));
|
||||
return new torch::jit::IValue(c10::List<torch::jit::IValue>(vec));
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ivalue ati_generic_dict(ivalue *is, int nvalues) {
|
||||
c10::Dict<torch::jit::IValue, torch::jit::IValue> dict(c10::AnyType::get(), c10::AnyType::get());
|
||||
PROTECT(
|
||||
for (int i = 0; i < nvalues; ++i) dict.insert(*(is[2*i]), *(is[2*i+1]));
|
||||
return new torch::jit::IValue(dict);
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ivalue ati_int_list(int64_t *is, int nvalues) {
|
||||
PROTECT(
|
||||
c10::List<int64_t> vec;
|
||||
for (int i = 0; i < nvalues; ++i) vec.push_back(is[i]);
|
||||
return new torch::jit::IValue(vec);
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ivalue ati_double_list(double *is, int nvalues) {
|
||||
PROTECT(
|
||||
c10::List<double> vec;
|
||||
for (int i = 0; i < nvalues; ++i) vec.push_back(is[i]);
|
||||
return new torch::jit::IValue(vec);
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ivalue ati_bool_list(char *is, int nvalues) {
|
||||
PROTECT(
|
||||
c10::List<bool> vec;
|
||||
for (int i = 0; i < nvalues; ++i) vec.push_back(is[i] != 0);
|
||||
return new torch::jit::IValue(vec);
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ivalue ati_tensor_list(tensor *is, int nvalues) {
|
||||
PROTECT(
|
||||
c10::List<at::Tensor> vec;
|
||||
for (int i = 0; i < nvalues; ++i) vec.push_back(*(is[i]));
|
||||
return new torch::jit::IValue(vec);
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int ati_tag(ivalue i) {
|
||||
PROTECT(
|
||||
if (i->isNone()) return 0;
|
||||
else if (i->isTensor()) return 1;
|
||||
else if (i->isDouble()) return 2;
|
||||
else if (i->isInt()) return 3;
|
||||
else if (i->isBool()) return 4;
|
||||
else if (i->isTuple()) return 5;
|
||||
else if (i->isIntList()) return 6;
|
||||
else if (i->isDoubleList()) return 7;
|
||||
else if (i->isBoolList()) return 8;
|
||||
else if (i->isString()) return 9;
|
||||
else if (i->isTensorList()) return 10;
|
||||
else if (i->isList()) return 12;
|
||||
else if (i->isGenericDict()) return 13;
|
||||
throw std::invalid_argument(("unsupported tag" + i->tagKind()).c_str());
|
||||
return -1;
|
||||
)
|
||||
return -1;
|
||||
}
|
||||
|
||||
int64_t ati_to_int(ivalue i) {
|
||||
PROTECT(
|
||||
return i->toInt();
|
||||
)
|
||||
return -1;
|
||||
}
|
||||
|
||||
double ati_to_double(ivalue i) {
|
||||
PROTECT(
|
||||
return i->toDouble();
|
||||
)
|
||||
return 0;
|
||||
}
|
||||
|
||||
int ati_to_bool(ivalue i) {
|
||||
PROTECT(
|
||||
return i->toBool();
|
||||
)
|
||||
return -1;
|
||||
}
|
||||
|
||||
char *ati_to_string(ivalue i) {
|
||||
PROTECT(
|
||||
auto str = i->toStringRef();
|
||||
return strdup(str.c_str());
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
tensor ati_to_tensor(ivalue i) {
|
||||
PROTECT(
|
||||
return new torch::Tensor(i->toTensor());
|
||||
)
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
int ati_length(ivalue i) {
|
||||
PROTECT(
|
||||
if (i->isTuple()) return i->toTuple()->elements().size();
|
||||
else if (i->isIntList()) return i->toIntList().size();
|
||||
else if (i->isDoubleList()) return i->toDoubleList().size();
|
||||
else if (i->isBoolList()) return i->toBoolList().size();
|
||||
else if (i->isString()) return i->toStringRef().size();
|
||||
else if (i->isTensorList()) return i->toTensorList().size();
|
||||
else if (i->isList()) return i->toList().size();
|
||||
else if (i->isGenericDict()) return i->toGenericDict().size();
|
||||
throw std::invalid_argument(("unsupported tag for length " + i->tagKind()).c_str());
|
||||
return -1;
|
||||
)
|
||||
return -1;
|
||||
}
|
||||
|
||||
int ati_tuple_length(ivalue i) {
|
||||
PROTECT(
|
||||
return i->toTuple()->elements().size();
|
||||
)
|
||||
return -1;
|
||||
}
|
||||
|
||||
void ati_to_tuple(ivalue i,
|
||||
ivalue *outputs,
|
||||
int noutputs) {
|
||||
PROTECT(
|
||||
auto vec = i->toTuple()->elements();
|
||||
if (vec.size() != noutputs) {
|
||||
throw std::invalid_argument("unexpected tuple size");
|
||||
}
|
||||
for (int i = 0; i < noutputs; ++i)
|
||||
outputs[i] = new torch::jit::IValue(vec[i]);
|
||||
)
|
||||
}
|
||||
|
||||
void ati_to_generic_list(ivalue i,
|
||||
ivalue *outputs,
|
||||
int noutputs) {
|
||||
PROTECT(
|
||||
auto vec = i->toList();
|
||||
if (vec.size() != noutputs) {
|
||||
throw std::invalid_argument("unexpected list size");
|
||||
}
|
||||
for (int i = 0; i < noutputs; ++i)
|
||||
outputs[i] = new torch::jit::IValue(vec[i]);
|
||||
)
|
||||
}
|
||||
|
||||
void ati_to_generic_dict(ivalue i,
|
||||
ivalue *outputs,
|
||||
int noutputs) {
|
||||
PROTECT(
|
||||
auto dict = i->toGenericDict();
|
||||
if (dict.size() != noutputs) {
|
||||
throw std::invalid_argument("unexpected dict size");
|
||||
}
|
||||
int k = 0;
|
||||
for (auto it = dict.begin(); it != dict.end(); ++it) {
|
||||
outputs[k++] = new torch::jit::IValue(it->key());
|
||||
outputs[k++] = new torch::jit::IValue(it->value());
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
void ati_to_int_list(ivalue i,
|
||||
int64_t *outputs,
|
||||
int noutputs) {
|
||||
PROTECT(
|
||||
auto vec = i->toIntList();
|
||||
if (vec.size() != noutputs) {
|
||||
throw std::invalid_argument("unexpected list size");
|
||||
}
|
||||
for (int i = 0; i < noutputs; ++i)
|
||||
outputs[i] = vec[i];
|
||||
)
|
||||
}
|
||||
|
||||
void ati_to_double_list(ivalue i,
|
||||
double *outputs,
|
||||
int noutputs) {
|
||||
PROTECT(
|
||||
auto vec = i->toDoubleList();
|
||||
if (vec.size() != noutputs) {
|
||||
throw std::invalid_argument("unexpected list size");
|
||||
}
|
||||
for (int i = 0; i < noutputs; ++i)
|
||||
outputs[i] = vec[i];
|
||||
)
|
||||
}
|
||||
|
||||
void ati_to_bool_list(ivalue i,
|
||||
char *outputs,
|
||||
int noutputs) {
|
||||
PROTECT(
|
||||
auto vec = i->toBoolList();
|
||||
if (vec.size() != noutputs) {
|
||||
throw std::invalid_argument("unexpected list size");
|
||||
}
|
||||
for (int i = 0; i < noutputs; ++i)
|
||||
outputs[i] = vec[i];
|
||||
)
|
||||
}
|
||||
|
||||
void ati_to_tensor_list(ivalue i,
|
||||
tensor *outputs,
|
||||
int noutputs) {
|
||||
PROTECT(
|
||||
auto vec = i->toTensorList();
|
||||
if (vec.size() != noutputs) {
|
||||
throw std::invalid_argument("unexpected tuple size");
|
||||
}
|
||||
for (int i = 0; i < noutputs; ++i)
|
||||
outputs[i] = new torch::Tensor(vec[i]);
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
void ati_free(ivalue i) {
|
||||
delete(i);
|
||||
}
|
||||
|
||||
#include "torch_api_generated.cpp.h"
|
175
torch-sys/libtch/torch_api.h
Normal file
175
torch-sys/libtch/torch_api.h
Normal file
|
@ -0,0 +1,175 @@
|
|||
#ifndef __TORCH_API_H__
|
||||
#define __TORCH_API_H__
|
||||
#include<stdint.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
thread_local char *torch_last_err = nullptr;
|
||||
|
||||
extern "C" {
|
||||
typedef torch::Tensor *tensor;
|
||||
typedef torch::Scalar *scalar;
|
||||
typedef torch::optim::Optimizer *optimizer;
|
||||
typedef torch::jit::script::Module *module;
|
||||
typedef torch::jit::IValue *ivalue;
|
||||
#define PROTECT(x) \
|
||||
try { \
|
||||
x \
|
||||
} catch (const exception& e) { \
|
||||
torch_last_err = strdup(e.what()); \
|
||||
}
|
||||
#else
|
||||
typedef void *tensor;
|
||||
typedef void *optimizer;
|
||||
typedef void *scalar;
|
||||
typedef void *module;
|
||||
typedef void *ivalue;
|
||||
#endif
|
||||
|
||||
char *get_and_reset_last_err(); // thread-local
|
||||
void at_manual_seed(int64_t);
|
||||
tensor at_new_tensor();
|
||||
tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims, size_t element_size_in_bytes, int type);
|
||||
void at_copy_data(tensor tensor, void *vs, size_t numel, size_t element_size_in_bytes);
|
||||
tensor at_shallow_clone(tensor);
|
||||
|
||||
void *at_data_ptr(tensor);
|
||||
int at_defined(tensor);
|
||||
int at_is_sparse(tensor);
|
||||
int at_device(tensor);
|
||||
size_t at_dim(tensor);
|
||||
void at_shape(tensor, int64_t *);
|
||||
int at_scalar_type(tensor);
|
||||
|
||||
void at_backward(tensor, int, int);
|
||||
int at_requires_grad(tensor);
|
||||
int at_grad_set_enabled(int);
|
||||
|
||||
tensor at_get(tensor, int index);
|
||||
void at_fill_double(tensor, double);
|
||||
void at_fill_int64(tensor, int64_t);
|
||||
|
||||
double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
|
||||
int64_t at_int64_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
|
||||
void at_set_double_value_at_indexes(tensor, int *indexes, int indexes_len, double v);
|
||||
void at_set_int64_value_at_indexes(tensor, int *indexes, int indexes_len, int64_t v);
|
||||
|
||||
void at_copy_(tensor dst, tensor src);
|
||||
|
||||
void at_print(tensor);
|
||||
char *at_to_string(tensor, int line_size);
|
||||
void at_save(tensor, char *filename);
|
||||
tensor at_load(char *filename);
|
||||
tensor at_load_image(char *filename);
|
||||
int at_save_image(tensor, char *filename);
|
||||
tensor at_resize_image(tensor, int w, int h);
|
||||
|
||||
void at_save_multi(tensor *tensors, char **tensor_names, int ntensors, char *filename);
|
||||
/* [at_load_multi] takes as input an array of nullptr for [tensors]. */
|
||||
void at_load_multi(tensor *tensors, char **tensor_names, int ntensors, char *filename);
|
||||
/* [at_load_multi_] takes as input an array of allocation [tensors]. */
|
||||
void at_load_multi_(tensor *tensors, char **tensor_names, int ntensors, char *filename);
|
||||
|
||||
void at_load_callback(char *filename, void *data, void (*f)(void *, char *, tensor));
|
||||
void at_load_callback_with_device(char *filename, void *data, void (*f)(void *, char *, tensor), int device_id);
|
||||
|
||||
int at_get_num_interop_threads();
|
||||
|
||||
int at_get_num_threads();
|
||||
|
||||
void at_set_num_interop_threads(int n_threads);
|
||||
|
||||
void at_set_num_threads(int n_threads);
|
||||
|
||||
void at_free(tensor);
|
||||
|
||||
void at_run_backward(tensor *tensors,
|
||||
int ntensors,
|
||||
tensor *inputs,
|
||||
int ninputs,
|
||||
tensor *outputs,
|
||||
int keep_graph,
|
||||
int create_graph);
|
||||
|
||||
optimizer ato_adam(double learning_rate,
|
||||
double beta1,
|
||||
double beta2,
|
||||
double weight_decay);
|
||||
optimizer ato_rms_prop(double learning_rate,
|
||||
double alpha,
|
||||
double eps,
|
||||
double weight_decay,
|
||||
double momentum,
|
||||
int centered);
|
||||
optimizer ato_sgd(double learning_rate,
|
||||
double momentum,
|
||||
double dampening,
|
||||
double weight_decay,
|
||||
int nesterov);
|
||||
void ato_add_parameters(optimizer, tensor *, int ntensors);
|
||||
void ato_set_learning_rate(optimizer, double learning_rate);
|
||||
void ato_set_momentum(optimizer, double momentum);
|
||||
void ato_zero_grad(optimizer);
|
||||
void ato_step(optimizer);
|
||||
void ato_free(optimizer);
|
||||
|
||||
scalar ats_int(int64_t);
|
||||
scalar ats_float(double);
|
||||
int64_t ats_to_int(scalar);
|
||||
double ats_to_float(scalar);
|
||||
char *ats_to_string(scalar);
|
||||
void ats_free(scalar);
|
||||
|
||||
int atc_cuda_device_count();
|
||||
int atc_cuda_is_available();
|
||||
int atc_cudnn_is_available();
|
||||
void atc_set_benchmark_cudnn(int b);
|
||||
|
||||
module atm_load(char *);
|
||||
module atm_load_str(char *, size_t sz);
|
||||
tensor atm_forward(module, tensor *tensors, int ntensors);
|
||||
ivalue atm_forward_(module,
|
||||
ivalue *ivalues,
|
||||
int nivalues);
|
||||
void atm_free(module);
|
||||
void atm_to(module m, int device, int dtype, bool non_blocking);
|
||||
|
||||
ivalue ati_none();
|
||||
ivalue ati_tensor(tensor);
|
||||
ivalue ati_int(int64_t);
|
||||
ivalue ati_double(double);
|
||||
ivalue ati_bool(int);
|
||||
ivalue ati_string(char *);
|
||||
ivalue ati_tuple(ivalue *, int);
|
||||
ivalue ati_generic_list(ivalue *, int);
|
||||
ivalue ati_generic_dict(ivalue *, int);
|
||||
ivalue ati_int_list(int64_t *, int);
|
||||
ivalue ati_double_list(double *, int);
|
||||
ivalue ati_bool_list(char *, int);
|
||||
ivalue ati_tensor_list(tensor *, int);
|
||||
|
||||
tensor ati_to_tensor(ivalue);
|
||||
int64_t ati_to_int(ivalue);
|
||||
double ati_to_double(ivalue);
|
||||
char *ati_to_string(ivalue);
|
||||
int ati_to_bool(ivalue);
|
||||
int ati_length(ivalue);
|
||||
int ati_tuple_length(ivalue);
|
||||
void ati_to_tuple(ivalue, ivalue *, int);
|
||||
void ati_to_generic_list(ivalue, ivalue *, int);
|
||||
void ati_to_generic_dict(ivalue, ivalue *, int);
|
||||
void ati_to_int_list(ivalue, int64_t *, int);
|
||||
void ati_to_double_list(ivalue, double *, int);
|
||||
void ati_to_bool_list(ivalue, char *, int);
|
||||
void ati_to_tensor_list(ivalue, tensor *, int);
|
||||
|
||||
int ati_tag(ivalue);
|
||||
|
||||
void ati_free(ivalue);
|
||||
|
||||
#include "torch_api_generated.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
};
|
||||
#endif
|
||||
|
||||
#endif
|
8108
torch-sys/libtch/torch_api_generated.cpp.h
Normal file
8108
torch-sys/libtch/torch_api_generated.cpp.h
Normal file
File diff suppressed because it is too large
Load Diff
1131
torch-sys/libtch/torch_api_generated.h
Normal file
1131
torch-sys/libtch/torch_api_generated.h
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user