initial commit

This commit is contained in:
sugarme 2020-05-22 23:43:09 +10:00
commit 0b2b0c1bf2
18 changed files with 146515 additions and 0 deletions

21
.gitignore vendored Normal file
View File

@ -0,0 +1,21 @@
.directory
*.swp
*.swo
*.dat
*.log
*.bak
*.out
*.dot
*.rs
*.txt
*.json
target/
_build/
data/
gen/.merlin
**/*.rs.bk
Cargo.lock
__pycache__

1
dune-project Normal file
View File

@ -0,0 +1 @@
(lang dune 2.2)

3
gen/dune Normal file
View File

@ -0,0 +1,3 @@
(executables
(names gen)
(libraries base stdio yaml))

655
gen/gen.ml Normal file
View File

@ -0,0 +1,655 @@
(* Automatically generate the C++ -> C -> rust bindings.
This takes as input the Descriptions.yaml file that gets generated when
building PyTorch from source.
Run with: dune exec gen/gen.exe
*)
open Base
open Stdio
(* Functions that appear in Declarations.yaml but for which no binding should
   be generated, e.g. because they are implemented manually elsewhere or are
   internal helpers that make no sense through the C API. *)
let excluded_functions =
  Set.of_list
    (module String)
    [
      "multi_margin_loss";
      "multi_margin_loss_out";
      "log_softmax_backward_data";
      "softmax_backward_data";
      "clone";
      "copy_";
      "conv_transpose2d_backward_out";
      "conv_transpose3d_backward_out";
      "slow_conv_transpose2d_backward_out";
      "slow_conv_transpose3d_backward_out";
      "slow_conv3d_backward_out";
      "normal";
      "_cufft_set_plan_cache_max_size";
      "_cufft_clear_plan_cache";
      "backward";
      "set_data";
      "_amp_non_finite_check_and_unscale_";
      "_cummin_helper";
      "_cummax_helper";
      "retain_grad";
    ]
(* Functions for which the TensorOptions argument is dropped when it has a
   default value (see the TensorOptions case in [read_yaml] below). *)
let no_tensor_options =
  Set.of_list
    (module String)
    [
      "zeros_like";
      "empty_like";
      "full_like";
      "ones_like";
      "rand_like";
      "randint_like";
      "randn_like";
    ]
(* Functions whose generated Rust wrapper gets a "g_" prefix so as not to
   clash with manually written wrappers of the same name. *)
let prefixed_functions =
  Set.of_list
    (module String)
    [ "add"; "add_"; "div"; "div_"; "mul"; "mul_"; "sub"; "sub_"; "nll_loss" ]
(* Any function whose name starts/ends with one of these is skipped. *)
let excluded_prefixes = [ "_thnn_"; "_th_"; "thnn_"; "th_" ]
let excluded_suffixes = [ "_forward"; "_forward_out" ]
(* Raise a [Failure] combining [msg] with a rendering of the offending yaml
   value, to make parse errors diagnosable. *)
let yaml_error yaml ~msg =
  Printf.failwithf "%s, %s" msg (Yaml.to_string_exn yaml) ()
(* Coerce a yaml scalar to a bool, also accepting the string forms
   "true"/"false"; fails on anything else. *)
let extract_bool yaml =
  match yaml with
  | `Bool b -> b
  | `String "true" -> true
  | `String "false" -> false
  | other -> yaml_error other ~msg:"expected bool"
(* Coerce a yaml value to a list of yaml values; fails otherwise. *)
let extract_list yaml =
  match yaml with
  | `A items -> items
  | other -> yaml_error other ~msg:"expected list"
(* Coerce a yaml mapping to a string-keyed map; fails on duplicate keys or on
   a non-mapping value. *)
let extract_map yaml =
  match yaml with
  | `O bindings -> Map.of_alist_exn (module String) bindings
  | other -> yaml_error other ~msg:"expected map"
(* Coerce a yaml scalar to a string; bools and floats are stringified since
   the torch yaml stores them in places where strings are expected. *)
let extract_string = function
  | `String s -> s
  (* The yaml spec for torch uses n which is converted to a bool. *)
  | `Bool b -> if b then "y" else "n"
  | `Float f -> Float.to_string f
  | yaml -> yaml_error yaml ~msg:"expected string"
(* In-memory representation of a torch function together with the pretty
   printers used to emit the C shim, the Rust extern declarations and the
   Rust wrapper arguments. *)
module Func = struct
  (* Argument types supported by the generator, as mapped from the
     "dynamic_type" field of Declarations.yaml. *)
  type arg_type =
    | Bool
    | Int64
    | Double
    | Tensor
    | TensorOption
    | IntList
    | TensorList
    | TensorOptions
    | Scalar
    | ScalarType
    | Device
    | String

  type arg = {
    arg_name : string;
    arg_type : arg_type;
    default_value : string option;
  }

  type t = {
    name : string;
    args : arg list;
    returns : [ `fixed of int | `dynamic ];
    (* number of tensors that are returned *)
    kind : [ `function_ | `method_ ];
  }

  (* Map a yaml dynamic_type string (case-insensitively) to an [arg_type];
     nullable tensors become [TensorOption]. Returns [None] for unsupported
     types. *)
  let arg_type_of_string str ~is_nullable =
    match String.lowercase str with
    | "bool" -> Some Bool
    | "int64_t" -> Some Int64
    | "double" -> Some Double
    | "booltensor" | "indextensor" | "tensor" ->
        Some (if is_nullable then TensorOption else Tensor)
    | "tensoroptions" -> Some TensorOptions
    | "intarrayref" | "intlist" -> Some IntList
    | "tensorlist" -> Some TensorList
    | "device" -> Some Device
    | "scalar" -> Some Scalar
    | "scalartype" -> Some ScalarType
    | "std::string" -> Some String
    | _ -> None

  (* C parameter list for the shim function signature. Aggregate types
     (lists, strings, options) are flattened into pointer + length or
     kind + device integer pairs. *)
  let c_typed_args_list t =
    List.map t.args ~f:(fun { arg_name; arg_type; _ } ->
        match arg_type with
        | IntList ->
            Printf.sprintf "int64_t *%s_data, int %s_len" arg_name arg_name
        | TensorList ->
            Printf.sprintf "tensor *%s_data, int %s_len" arg_name arg_name
        | TensorOptions ->
            Printf.sprintf "int %s_kind, int %s_device" arg_name arg_name
        | String -> Printf.sprintf "char* %s_ptr, int %s_len" arg_name arg_name
        | otherwise ->
            let simple_type_cstring =
              match otherwise with
              | Bool -> "int"
              | Int64 -> "int64_t"
              | Double -> "double"
              | Tensor -> "tensor"
              | TensorOption -> "tensor"
              | ScalarType -> "int"
              | Device -> "int"
              | Scalar -> "scalar"
              (* Aggregate types are all handled in the outer match. *)
              | String | IntList | TensorList | TensorOptions -> assert false
            in
            Printf.sprintf "%s %s" simple_type_cstring arg_name)
    |> String.concat ~sep:", "

  (* C++ expressions converting the flattened C parameters back into the
     values expected by the torch:: call. *)
  let c_args_list args =
    List.map args ~f:(fun { arg_name; arg_type; _ } ->
        match arg_type with
        | Scalar | Tensor -> "*" ^ arg_name
        | TensorOption ->
            (* A null pointer stands for an absent optional tensor. *)
            Printf.sprintf "(%s ? *%s : torch::Tensor())" arg_name arg_name
        | Bool -> "(bool)" ^ arg_name
        | IntList ->
            Printf.sprintf "torch::IntArrayRef(%s_data, %s_len)" arg_name
              arg_name
        | String ->
            Printf.sprintf "std::string(%s_ptr, %s_len)" arg_name arg_name
        | TensorList ->
            Printf.sprintf "of_carray_tensor(%s_data, %s_len)" arg_name arg_name
        | TensorOptions ->
            Printf.sprintf
              "at::device(device_of_int(%s_device)).dtype(at::ScalarType(%s_kind))"
              arg_name arg_name
        | ScalarType -> Printf.sprintf "at::ScalarType(%s)" arg_name
        | Device -> Printf.sprintf "device_of_int(%s)" arg_name
        | _ -> arg_name)
    |> String.concat ~sep:", "

  (* The C++ call expression: a namespace-level torch:: call for functions,
     a call through the first (tensor) argument for methods. *)
  let c_call t =
    match t.kind with
    | `function_ -> Printf.sprintf "torch::%s(%s)" t.name (c_args_list t.args)
    | `method_ -> (
        match t.args with
        | head :: tail ->
            Printf.sprintf "%s->%s(%s)" head.arg_name t.name (c_args_list tail)
        | [] ->
            Printf.failwithf "Method calls should have at least one argument %s"
              t.name () )

  (* Torch names that clash with Rust keywords/idioms and their Rust-side
     replacements. *)
  let replace_map =
    Map.of_alist_exn
      (module String)
      [
        ("t", "tr");
        ("where", "where_");
        ("view", "view_");
        ("unsafe", "unsafe_");
      ]

  (* Rust identifier for a torch function: apply [replace_map], lowercase,
     collapse double underscores, and turn a leading underscore into an
     "internal" prefix (Rust treats leading "_" as an unused binding). *)
  let rust_name name =
    let name =
      Map.find replace_map name |> Option.value ~default:name
      |> String.lowercase
      |> String.substr_replace_all ~pattern:"__" ~with_:"_"
    in
    if String.is_prefix name ~prefix:"_" then "internal" ^ name else name

  (* Rust parameter list for the extern "C" declaration mirroring
     [c_typed_args_list]. *)
  let c_rust_args_list t =
    List.map t.args ~f:(fun arg ->
        let an = arg.arg_name in
        (* Partially applied: takes the Rust type as second format arg. *)
        let single_param = Printf.sprintf "%s_: %s" an in
        match arg.arg_type with
        | Bool -> single_param "c_int"
        | Int64 -> single_param "i64"
        | Double -> single_param "f64"
        | Tensor -> single_param "*mut C_tensor"
        | TensorOption -> single_param "*mut C_tensor"
        | Scalar -> single_param "*mut C_scalar"
        | ScalarType -> single_param "c_int"
        | Device -> single_param "c_int"
        | String -> Printf.sprintf "%s_ptr: *const u8, %s_len: c_int" an an
        | IntList -> Printf.sprintf "%s_data: *const i64, %s_len: c_int" an an
        | TensorList ->
            Printf.sprintf "%s_data: *const *mut C_tensor, %s_len: c_int" an an
        | TensorOptions ->
            Printf.sprintf "%s_kind: c_int, %s_device: c_int" an an)
    |> String.concat ~sep:", "

  (* An argument named "self" (or, failing that, "input") of type Tensor is
     mapped to the Rust method receiver. *)
  let self_name = "self"
  let input_name = "input"

  let self_tensor arg =
    match arg.arg_type with
    | Tensor -> String.( = ) arg.arg_name self_name
    | _ -> false

  let input_tensor arg =
    match arg.arg_type with
    | Tensor -> String.( = ) arg.arg_name input_name
    | _ -> false

  (* Generic parameters for the Rust wrapper: T for tensor borrows, S for
     scalar conversions, depending on the argument types involved. *)
  let type_parameters t =
    let needs_scalar_parameter =
      List.exists t.args ~f:(fun arg ->
          match arg.arg_type with Scalar -> true | _ -> false)
    in
    let needs_type_parameter =
      List.exists t.args ~f:(fun arg ->
          match arg.arg_type with
          | TensorList | TensorOption -> true
          | _ -> false)
    in
    if needs_type_parameter && needs_scalar_parameter then
      "<T: Borrow<Tensor>, S: Into<Scalar>>"
    else if needs_type_parameter then "<T: Borrow<Tensor>>"
    else if needs_scalar_parameter then "<S: Into<Scalar>>"
    else ""

  (* Split off the receiver argument (self, else input) from the rest;
     returns (None, all args) when neither is present. *)
  let rust_args_list t =
    match List.partition_tf t.args ~f:self_tensor with
    | [ self ], args_list -> (Some self, args_list)
    | _, _ -> (
        match List.partition_tf t.args ~f:input_tensor with
        | [ self ], args_list -> (Some self, args_list)
        | _, _ -> (None, t.args) )

  (* Typed Rust parameter list for the wrapper, with the receiver rendered as
     &self / &mut self (mutable for trailing-underscore in-place methods).
     Returns the receiver's original name (if any) and the parameter text. *)
  let rust_typed_args_list t =
    let to_string args =
      List.map args ~f:(fun arg ->
          let rust_arg_type =
            match arg.arg_type with
            | Bool -> "bool"
            | Int64 ->
                (* "reduction" ints are surfaced as the typed enum. *)
                if String.( = ) arg.arg_name "reduction" then "crate::Reduction"
                else "i64"
            | Double -> "f64"
            | Tensor -> "&Tensor"
            | TensorOption -> "Option<T>"
            | IntList -> "&[i64]"
            | TensorList -> "&[T]"
            | String -> "&str"
            | TensorOptions -> "(Kind, Device)"
            | Scalar -> "S"
            | ScalarType -> "Kind"
            | Device -> "Device"
          in
          Printf.sprintf "%s: %s" (rust_name arg.arg_name) rust_arg_type)
      |> String.concat ~sep:", "
    in
    let self_arg =
      if String.is_suffix t.name ~suffix:"_" then "&mut self" else "&self"
    in
    match List.partition_tf t.args ~f:self_tensor with
    | [ self ], args_list ->
        ( Some self.arg_name,
          Printf.sprintf "%s, %s" self_arg (to_string args_list) )
    | _, _ -> (
        match List.partition_tf t.args ~f:input_tensor with
        | [ self ], args_list ->
            ( Some self.arg_name,
              Printf.sprintf "%s, %s" self_arg (to_string args_list) )
        | _, _ -> (None, to_string t.args) )

  (* Rust return type text: a tensor, a tuple of tensors, or a Vec for
     dynamic arity; wrapped in failure::Fallible when [fallible]. *)
  let rust_return_type t ~fallible =
    let returns =
      match t.returns with
      | `fixed 1 -> "Tensor"
      | `fixed v ->
          List.init v ~f:(fun _ -> "Tensor")
          |> String.concat ~sep:", " |> Printf.sprintf "(%s)"
      | `dynamic -> "Vec<Tensor>"
    in
    if fallible then Printf.sprintf " -> failure::Fallible<%s>" returns
    else Printf.sprintf " -> %s" returns

  (* Rust expressions converting wrapper arguments into the raw values passed
     to the extern atg_ function; [self] is the receiver's original name so
     that it is rendered as "self". *)
  let rust_binding_args t ~self =
    List.map t.args ~f:(fun arg ->
        let name =
          if Option.value_map self ~default:false ~f:(String.( = ) arg.arg_name)
          then "self"
          else rust_name arg.arg_name
        in
        match arg.arg_type with
        | Tensor -> Printf.sprintf "%s.c_tensor" name
        | Scalar -> Printf.sprintf "%s.into().c_scalar" name
        | Bool -> Printf.sprintf "if %s { 1 } else { 0 }" name
        | ScalarType -> Printf.sprintf "%s.c_int()" name
        | Device -> Printf.sprintf "%s.c_int()" name
        | TensorOptions -> Printf.sprintf "%s.0.c_int(), %s.1.c_int()" name name
        | String -> Printf.sprintf "%s.as_ptr(), %s.len() as i32" name name
        | IntList -> Printf.sprintf "%s.as_ptr(), %s.len() as i32" name name
        | TensorList ->
            Printf.sprintf "ptr_list(%s).as_ptr(), %s.len() as i32" name name
        | TensorOption ->
            Printf.sprintf
              "%s.map_or(std::ptr::null_mut(), |t| t.borrow().c_tensor)" name
        | Int64 when String.( = ) name "reduction" -> "reduction.to_int()"
        | _ -> name)
    |> String.concat ~sep:",\n "
end
(* Raised while parsing an argument the generator cannot represent; caught in
   [read_yaml] to skip the whole function. *)
exception Not_a_simple_arg

(* Parse Declarations.yaml into a list of [Func.t], skipping deprecated and
   excluded entries as well as functions with unsupported argument or return
   types. *)
let read_yaml filename =
  let funcs =
    (* Split the file to avoid Yaml.of_string_exn segfaulting. *)
    In_channel.with_file filename ~f:In_channel.input_lines
    |> List.group ~break:(fun _ l ->
           String.length l > 0 && Char.( = ) l.[0] '-')
    |> List.concat_map ~f:(fun lines ->
           Yaml.of_string_exn (String.concat lines ~sep:"\n") |> extract_list)
  in
  printf "Read %s, got %d functions.\n%!" filename (List.length funcs);
  List.filter_map funcs ~f:(fun yaml ->
      let map = extract_map yaml in
      let name = Map.find_exn map "name" |> extract_string in
      let deprecated = Map.find_exn map "deprecated" |> extract_bool in
      let method_of =
        Map.find_exn map "method_of"
        |> extract_list |> List.map ~f:extract_string
      in
      let arguments = Map.find_exn map "arguments" |> extract_list in
      (* Supported returns: all-tensor (fixed arity) or a single TensorList
         (dynamic arity); anything else yields None and drops the function. *)
      let returns =
        let is_tensor returns =
          let returns = extract_map returns in
          let return_type =
            Map.find_exn returns "dynamic_type" |> extract_string
          in
          String.( = ) return_type "Tensor"
          || String.( = ) return_type "BoolTensor"
          || String.( = ) return_type "IndexTensor"
        in
        let returns = Map.find_exn map "returns" |> extract_list in
        if List.for_all returns ~f:is_tensor then
          Some (`fixed (List.length returns))
        else
          match returns with
          | [ returns ] ->
              let return_type =
                Map.find_exn (extract_map returns) "dynamic_type"
                |> extract_string
              in
              if String.( = ) return_type "TensorList" then Some `dynamic
              else None
          | [] | _ :: _ :: _ -> None
      in
      (* "namespace" entries become free functions, "Tensor" entries become
         methods; anything else is dropped. *)
      let kind =
        if List.exists method_of ~f:(String.( = ) "namespace") then
          Some `function_
        else if List.exists method_of ~f:(String.( = ) "Tensor") then
          Some `method_
        else None
      in
      if
        (not deprecated)
        && (not
              (List.exists excluded_prefixes ~f:(fun prefix ->
                   String.is_prefix name ~prefix)))
        && (not
              (List.exists excluded_suffixes ~f:(fun suffix ->
                   String.is_suffix name ~suffix)))
        && not (Set.mem excluded_functions name)
      then
        Option.both returns kind
        |> Option.bind ~f:(fun (returns, kind) ->
               try
                 let args =
                   List.filter_map arguments ~f:(fun arg ->
                       let arg = extract_map arg in
                       let arg_name =
                         Map.find_exn arg "name" |> extract_string
                       in
                       let arg_type =
                         Map.find_exn arg "dynamic_type" |> extract_string
                       in
                       let is_nullable =
                         Map.find arg "is_nullable"
                         |> Option.value_map ~default:false ~f:extract_bool
                       in
                       let default_value =
                         Map.find arg "default" |> Option.map ~f:extract_string
                       in
                       match Func.arg_type_of_string arg_type ~is_nullable with
                       (* Defaulted scalars and (for no_tensor_options
                          functions) defaulted TensorOptions are dropped from
                          the signature rather than bound. *)
                       | Some Scalar
                         when Option.is_some default_value && not is_nullable ->
                           None
                       | Some TensorOptions
                         when Option.is_some default_value
                              && Set.mem no_tensor_options name ->
                           None
                       | Some arg_type ->
                           let arg_name =
                             (* Avoid clashing with the Rust receiver. *)
                             match (arg_name, arg_type) with
                             | "self", Scalar -> "self_scalar"
                             | _, _ -> arg_name
                           in
                           Some { Func.arg_name; arg_type; default_value }
                       | None ->
                           (* Unsupported types are tolerated when they have a
                              default (the default is used); otherwise skip
                              the entire function. *)
                           if Option.is_some default_value then None
                           else raise Not_a_simple_arg)
                 in
                 Some { Func.name; args; returns; kind }
               with Not_a_simple_arg -> None)
      else None)
(* [p oc fmt ...] writes one formatted line, newline included, to [oc]. *)
let p out_channel s =
  Printf.ksprintf
    (fun line -> Out_channel.output_string out_channel (line ^ "\n"))
    s
(* Emit the C shim: [filename].cpp.h holds the atg_ function definitions and
   [filename].h their prototypes. Fixed-arity results are written into a
   caller-provided out__ array; dynamic results are returned as a
   null-terminated malloc'd array. *)
let write_cpp funcs filename =
  Out_channel.with_file (filename ^ ".cpp.h") ~f:(fun out_cpp ->
      Out_channel.with_file (filename ^ ".h") ~f:(fun out_h ->
          let pc s = p out_cpp s in
          let ph s = p out_h s in
          pc "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
          pc "";
          ph "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
          ph "";
          Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
              let c_typed_args_list = Func.c_typed_args_list func in
              match func.returns with
              | `fixed ntensors ->
                  pc "void atg_%s(tensor *out__, %s) {" exported_name
                    c_typed_args_list;
                  pc " PROTECT(";
                  pc " auto outputs__ = %s;" (Func.c_call func);
                  (* A single result is the value itself; multiple results
                     come back as a std::tuple. *)
                  if ntensors = 1 then
                    pc " out__[0] = new torch::Tensor(outputs__);"
                  else
                    for i = 0 to ntensors - 1 do
                      pc
                        " out__[%d] = new \
                         torch::Tensor(std::get<%d>(outputs__));"
                        i i
                    done;
                  pc " )";
                  pc "}";
                  pc "";
                  ph "void atg_%s(tensor *, %s);" exported_name
                    c_typed_args_list
              | `dynamic ->
                  pc "tensor *atg_%s(%s) {" exported_name c_typed_args_list;
                  pc " PROTECT(";
                  pc " auto outputs__ = %s;" (Func.c_call func);
                  (* the returned type is a C++ vector of tensors *)
                  pc " int sz = outputs__.size();";
                  pc
                    " torch::Tensor **out__ = (torch::Tensor**)malloc((sz + \
                     1) * sizeof(torch::Tensor*));";
                  pc " for (int i = 0; i < sz; ++i)";
                  pc " out__[i] = new torch::Tensor(outputs__[i]);";
                  pc " out__[sz] = nullptr;";
                  pc " return out__;";
                  pc " )";
                  pc " return nullptr;";
                  pc "}";
                  pc "";
                  ph "tensor *atg_%s(%s);" exported_name c_typed_args_list)))
(* Emit the fallible Rust wrappers: one f_* method per function, calling the
   extern atg_ binding inside unsafe_torch_err! and returning
   failure::Fallible results. *)
let write_fallible_wrapper funcs filename =
  Out_channel.with_file filename ~f:(fun out_ml ->
      let pm s = p out_ml s in
      pm "/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND! */";
      pm "#[allow(clippy::all)]";
      pm "use torch_sys::*;";
      pm "use torch_sys::c_generated::*;";
      pm "use crate::{Device, Kind, Scalar, Tensor};";
      pm "use std::convert::Into;";
      pm "use std::borrow::Borrow;";
      pm "";
      pm "fn ptr_list<T: Borrow<Tensor>>(l: &[T]) -> Vec<*mut C_tensor> {";
      pm " l.iter().map(|x| x.borrow().c_tensor).collect()";
      pm "}";
      pm "";
      pm "impl Tensor {";
      Map.iteri funcs ~f:(fun ~key:exported_name ~data:(func : Func.t) ->
          let rust_name = Func.rust_name exported_name in
          let self, rust_args_list = Func.rust_typed_args_list func in
          pm "";
          pm " pub fn f_%s%s(" rust_name (Func.type_parameters func);
          pm " %s" rust_args_list;
          pm " )%s {" (Func.rust_return_type func ~fallible:true);
          match func.returns with
          | `dynamic ->
              (* Walk the null-terminated array returned by the C side, then
                 free it (allocated with malloc over there). *)
              pm " let c_tensors = unsafe_torch_err!({";
              pm " atg_%s(" exported_name;
              pm " %s)});" (Func.rust_binding_args func ~self);
              pm " let mut r__ = vec![];";
              pm " let mut i = 0;";
              pm " loop {";
              pm " let c__ = unsafe{*c_tensors.add(i)};";
              pm " if c__.is_null() { break }";
              pm " r__.push(Tensor {c_tensor: c__});";
              pm " i += 1;";
              pm " }";
              pm " unsafe{libc::free(c_tensors as *mut libc::c_void)}";
              pm " Ok(r__)";
              pm " }"
          | `fixed ntensors ->
              (* Results come back through a stack array of out pointers. *)
              pm " let mut c_tensors = [std::ptr::null_mut(); %d];"
                ntensors;
              pm " unsafe_torch_err!({";
              pm " atg_%s(c_tensors.as_mut_ptr()," exported_name;
              pm " %s" (Func.rust_binding_args func ~self);
              pm " ) });";
              let returns =
                if ntensors = 1 then "Tensor { c_tensor: c_tensors[0] }"
                else
                  List.init ntensors
                    ~f:(Printf.sprintf "Tensor { c_tensor: c_tensors[%d] }")
                  |> String.concat ~sep:", " |> Printf.sprintf "(%s)"
              in
              pm " Ok(%s)" returns;
              pm " }");
      pm "}")
(* Emit the panicking Rust wrappers: thin methods that call the matching f_*
   fallible wrapper and unwrap. Names in [prefixed_functions] get a g_
   prefix to avoid clashing with hand-written wrappers. *)
let write_wrapper funcs filename =
  Out_channel.with_file filename ~f:(fun out_ml ->
      let pm s = p out_ml s in
      pm "/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND! */";
      pm "#[allow(clippy::all)]";
      pm "use crate::{Device, Kind, Scalar, Tensor};";
      pm "use std::convert::Into;";
      pm "use std::borrow::Borrow;";
      pm "";
      pm "impl Tensor {";
      Map.iteri funcs ~f:(fun ~key:exported_name ~data:(func : Func.t) ->
          let rust_name = Func.rust_name exported_name in
          let rust_name, fallible_rust_name =
            if Set.mem prefixed_functions func.name then
              ("g_" ^ rust_name, "f_" ^ rust_name)
            else (rust_name, "f_" ^ rust_name)
          in
          pm "";
          pm " pub fn %s%s(" rust_name (Func.type_parameters func);
          let _self, rust_args_list = Func.rust_typed_args_list func in
          pm " %s" rust_args_list;
          pm " )%s {" (Func.rust_return_type func ~fallible:false);
          (* Dispatch through self when there is a receiver, otherwise as an
             associated function. *)
          let self, rust_args_list = Func.rust_args_list func in
          let self = if Option.is_some self then "self." else "Tensor::" in
          let rust_args_list =
            List.map rust_args_list ~f:(fun arg ->
                Func.rust_name arg.Func.arg_name)
            |> String.concat ~sep:", "
          in
          pm " %s%s(%s).unwrap()" self fallible_rust_name rust_args_list;
          pm " }");
      pm "}")
(* Emit the Rust extern "C" declarations matching the C shim emitted by
   [write_cpp]. *)
let write_ffi funcs filename =
  Out_channel.with_file filename ~f:(fun out_ml ->
      let pm s = p out_ml s in
      pm "/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND! */";
      pm "#[allow(clippy::all)]";
      pm "use crate::{C_scalar, C_tensor};";
      pm "use libc::c_int;";
      pm "";
      pm "extern \"C\" {";
      Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
          match func.Func.returns with
          | `fixed _ ->
              pm " pub fn atg_%s(out__: *mut *mut C_tensor, %s);"
                exported_name
                (Func.c_rust_args_list func)
          | `dynamic ->
              pm " pub fn atg_%s(%s) -> *mut *mut C_tensor;" exported_name
                (Func.c_rust_args_list func));
      pm "}")
(* A few hand-specified methods that are not present (or not usable) in the
   yaml declarations; each returns a single tensor. *)
let methods =
  let c name args = { Func.name; args; returns = `fixed 1; kind = `method_ } in
  let ca arg_name arg_type =
    { Func.arg_name; arg_type; default_value = None }
  in
  [
    c "grad" [ ca "self" Tensor ];
    c "set_requires_grad" [ ca "self" Tensor; ca "r" Bool ];
    c "toType" [ ca "self" Tensor; ca "scalar_type" ScalarType ];
    c "to" [ ca "self" Tensor; ca "device" Device ];
  ]
(* Read the yaml declarations, add the manual methods, disambiguate
   overloaded names, and emit all four generated files. *)
let run ~yaml_filename ~cpp_filename ~ffi_filename ~wrapper_filename
    ~fallible_wrapper_filename =
  let funcs = read_yaml yaml_filename in
  let funcs = methods @ funcs in
  printf "Generating code for %d functions.\n%!" (List.length funcs);
  (* Generate some unique names for overloaded functions. *)
  let funcs =
    List.map funcs ~f:(fun func -> (String.lowercase func.name, func))
    |> Map.of_alist_multi (module String)
    |> Map.to_alist
    |> List.concat_map ~f:(fun (name, funcs) ->
           match funcs with
           | [] -> assert false
           | [ func ] -> [ (name, func) ]
           | funcs ->
               (* Overloads are indexed by increasing argument count; the
                  smallest overload keeps the plain name. *)
               List.sort funcs ~compare:(fun (f1 : Func.t) (f2 : Func.t) ->
                   Int.compare (List.length f1.args) (List.length f2.args))
               |> List.mapi ~f:(fun i func ->
                      ( (if i = 0 then name else Printf.sprintf "%s%d" name i),
                        func )))
    |> Map.of_alist_exn (module String)
  in
  write_cpp funcs cpp_filename;
  write_ffi funcs ffi_filename;
  write_wrapper funcs wrapper_filename;
  write_fallible_wrapper funcs fallible_wrapper_filename
(* Entry point; all paths are relative to the repository root
   (run with: dune exec gen/gen.exe). *)
let () =
  run ~yaml_filename:"third_party/pytorch/Declarations-v1.5.0.yaml"
    ~cpp_filename:"torch-sys/libtch/torch_api_generated"
    ~ffi_filename:"torch-sys/src/c_generated.rs"
    ~wrapper_filename:"src/wrappers/tensor_generated.rs"
    ~fallible_wrapper_filename:"src/wrappers/tensor_fallible_generated.rs"

1
gen/gen.mli Normal file
View File

@ -0,0 +1 @@
(* Intentionally left blank. *)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

70
third_party/pytorch/LICENSE vendored Normal file
View File

@ -0,0 +1,70 @@
From PyTorch:
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
Copyright (c) 2011-2013 NYU (Clement Farabet)
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
From Caffe2:
Copyright (c) 2016-present, Facebook Inc. All rights reserved.
All contributions by Facebook:
Copyright (c) 2016 Facebook Inc.
All contributions by Google:
Copyright (c) 2015 Google Inc.
All rights reserved.
All contributions by Yangqing Jia:
Copyright (c) 2015 Yangqing Jia
All rights reserved.
All contributions from Caffe:
Copyright(c) 2013, 2014, 2015, the respective contributors
All rights reserved.
All other contributions:
Copyright(c) 2015, 2016 the respective contributors
All rights reserved.
Caffe2 uses a copyright model similar to Caffe: each contributor holds
copyright over their contributions to Caffe2. The project versioning records
all such contribution and copyright details. If a contributor wants to further
mark their specific copyright on a particular contribution, they should
indicate their copyright solely in the commit message of the change when it is
committed.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
and IDIAP Research Institute nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.

3
third_party/pytorch/README vendored Normal file
View File

@ -0,0 +1,3 @@
The Declarations-B.yaml files included in this directory have been obtained
by compiling PyTorch from source using the code in branch B from the official
GitHub repo: https://github.com/pytorch/pytorch

View File

@ -0,0 +1,12 @@
extern "C" {
void dummy_cuda_dependency();
}
namespace at {
namespace cuda {
int warp_size();
}
}
void dummy_cuda_dependency() {
at::cuda::warp_size();
}

View File

@ -0,0 +1,6 @@
extern "C" {
void dummy_cuda_dependency();
}
void dummy_cuda_dependency() {
}

7530
torch-sys/libtch/stb_image.h Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,894 @@
#include<torch/csrc/autograd/engine.h>
#include<torch/torch.h>
#include<torch/script.h>
#include<stdexcept>
#include<vector>
#include "torch_api.h"
#define STB_IMAGE_IMPLEMENTATION
#include "stb_image.h"
#define STB_IMAGE_WRITE_IMPLEMENTATION
#include "stb_image_write.h"
#define STB_IMAGE_RESIZE_IMPLEMENTATION
#include "stb_image_resize.h"
using namespace std;
// Transfer ownership of the pending error message (if any) to the caller and
// clear the slot; returns nullptr when no error is recorded.
char *get_and_reset_last_err() {
  char *err = torch_last_err;
  torch_last_err = nullptr;
  return err;
}
// Seed the torch random number generators for reproducibility.
void at_manual_seed(int64_t seed) {
  torch::manual_seed(seed);
}

// Copy [len] tensor pointers into a vector of (dereferenced) tensors as
// expected by torch APIs taking a TensorList.
vector<torch::Tensor> of_carray_tensor(torch::Tensor **vs, int len) {
  vector<torch::Tensor> result;
  for (int i = 0; i < len; ++i) result.push_back(*(vs[i]));
  return result;
}
// Decode the C API device convention: negative values select the CPU,
// non-negative values select the CUDA device with that index.
at::Device device_of_int(int d) {
  if (d >= 0) return at::Device(at::kCUDA, /*index=*/d);
  return at::Device(at::kCPU);
}
// Allocate a fresh undefined tensor handle; nullptr on error.
tensor at_new_tensor() {
  PROTECT(
    return new torch::Tensor();
  )
  return nullptr;
}

// Build a tensor with the given shape and scalar type, filled with a copy of
// the raw buffer [vs]; element_size_in_bytes must match the scalar type's
// element size. Returns nullptr on error.
tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims, size_t element_size_in_bytes, int type) {
  PROTECT(
    torch::Tensor tensor = torch::zeros(torch::IntArrayRef(dims, ndims), torch::ScalarType(type));
    if (element_size_in_bytes != tensor.element_size())
      throw std::invalid_argument("incoherent element sizes in bytes");
    void *tensor_data = tensor.data_ptr();
    memcpy(tensor_data, vs, tensor.numel() * element_size_in_bytes);
    return new torch::Tensor(tensor);
  )
  return nullptr;
}
// Copy up to [numel] elements of tensor data out into the buffer [vs].
// Non-CPU tensors are first moved to the CPU; in both cases a contiguous
// temporary is used so that memcpy sees a flat layout.
void at_copy_data(tensor tensor, void *vs, size_t numel, size_t elt_size_in_bytes) {
  PROTECT(
    if (elt_size_in_bytes != tensor->element_size())
      throw std::invalid_argument("incoherent element sizes in bytes");
    if (numel > tensor->numel())
      throw std::invalid_argument("target numel is larger than tensor numel");
    if (tensor->device().type() != at::kCPU) {
      torch::Tensor tmp_tensor = tensor->to(at::kCPU).contiguous();
      void *tensor_data = tmp_tensor.data_ptr();
      memcpy(vs, tensor_data, numel * elt_size_in_bytes);
    }
    else {
      auto tmp_tensor = tensor->contiguous();
      void *tensor_data = tmp_tensor.data_ptr();
      memcpy(vs, tensor_data, numel * elt_size_in_bytes);
    }
  )
}
// New handle referring to the same underlying tensor; nullptr on error.
tensor at_shallow_clone(tensor t) {
  PROTECT(return new torch::Tensor(*t);)
  return nullptr;
}

// Raw pointer to the tensor's data; nullptr on error.
void *at_data_ptr(tensor t) {
  PROTECT(return t->data_ptr();)
  return nullptr;
}

// 1 if defined, 0 if not, -1 on error.
int at_defined(tensor t) {
  PROTECT(return t->defined();)
  return -1;
}

// 1 if sparse, 0 if not, -1 on error.
int at_is_sparse(tensor t) {
  PROTECT(return t->is_sparse();)
  return -1;
}

// Number of dimensions; (size_t)-1 on error.
size_t at_dim(tensor t) {
  PROTECT(return t->dim();)
  return -1;
}

// Write the sizes into [dims]; the caller must provide at_dim(t) slots.
void at_shape(tensor t, int64_t *dims) {
  PROTECT(
    int i = 0;
    for (int64_t dim : t->sizes()) dims[i++] = dim;
  )
}

// Scalar type as the integer value of at::ScalarType; -1 on error.
int at_scalar_type(tensor t) {
  PROTECT(
    return static_cast<int>(t->scalar_type());
  )
  return -1;
}

// Device encoding: -1 for CPU, the CUDA device index otherwise, -2 on error.
int at_device(tensor t) {
  PROTECT(
    auto device = t->device();
    if (device.type() == at::kCPU) return -1;
    if (device.type() == at::kCUDA) return device.index();
  )
  return -2;
}
// Backpropagate from [t] with an implicit gradient; keep_graph/create_graph
// are forwarded to torch::Tensor::backward.
void at_backward(tensor t, int keep_graph, int create_graph) {
  PROTECT(t->backward({}, keep_graph, create_graph);)
}

// 1 if the tensor requires grad, 0 if not, -1 on error.
int at_requires_grad(tensor t) {
  PROTECT(return t->requires_grad();)
  return -1;
}

// Set the global grad mode to [b] and return the previous value; -1 on error.
int at_grad_set_enabled(int b) {
  PROTECT(
    bool is_enabled = torch::autograd::GradMode::is_enabled();
    torch::autograd::GradMode::set_enabled(b);
    return is_enabled;
  )
  return -1;
}

// Index along the first dimension, returning a new handle; nullptr on error.
tensor at_get(tensor t, int index) {
  PROTECT(return new torch::Tensor((*t)[index]);)
  return nullptr;
}
// Drill down through [indexes_len] successive indexings of [t] and return the
// resulting scalar as T; T() on error.
template<typename T>
T at_value_at_indexes(tensor t, int64_t *indexes, int indexes_len) {
  PROTECT(
    torch::Tensor tensor = *t;
    for (int i = 0; i < indexes_len; ++i) {
      tensor = tensor[indexes[i]];
    }
    return tensor.item<T>();
  )
  return T();
}

double at_double_value_at_indexes(tensor t, int64_t *indexes, int indexes_len) {
  return at_value_at_indexes<double>(t, indexes, indexes_len);
}

int64_t at_int64_value_at_indexes(tensor t, int64_t *indexes, int indexes_len) {
  return at_value_at_indexes<int64_t>(t, indexes, indexes_len);
}

// Drill down through [indexes_len] successive indexings of [t] and fill the
// resulting view with [v].
// NOTE(review): indexes are int here but int64_t in the getters above —
// presumably historical; confirm against the C header before changing.
template<typename T>
void at_set_value_at_indexes(tensor t, int *indexes, int indexes_len, T v) {
  PROTECT(
    torch::Tensor tensor = *t;
    for (int i = 0; i < indexes_len; ++i) {
      tensor = tensor[indexes[i]];
    }
    tensor.fill_(v);
  )
}

void at_set_double_value_at_indexes(tensor t, int *indexes, int indexes_len, double v) {
  at_set_value_at_indexes<double>(t, indexes, indexes_len, v);
}

void at_set_int64_value_at_indexes(tensor t, int *indexes, int indexes_len, int64_t v) {
  at_set_value_at_indexes<int64_t>(t, indexes, indexes_len, v);
}
// In-place fill with a double value.
void at_fill_double(tensor t, double v) {
  PROTECT(t->fill_(v);)
}

// In-place fill with an int64 value.
void at_fill_int64(tensor t, int64_t v) {
  PROTECT(t->fill_(v);)
}

// Print the tensor on stdout with the default torch formatting.
void at_print(tensor t) {
  PROTECT(
    torch::Tensor *tensor = (torch::Tensor*)t;
    cout << *tensor << endl;
  )
}

// Render the tensor to a strdup'd C string (caller frees); nullptr on error.
char *at_to_string(tensor t, int line_size) {
  PROTECT(
    std::ostringstream oss;
    torch::print(oss, *t, line_size);
    return strdup(oss.str().c_str());
  )
  return nullptr;
}

// In-place copy of *src into *dst.
void at_copy_(tensor dst, tensor src) {
  PROTECT(
    dst->copy_(*src);
  )
}

// Serialize a single tensor to [filename] with torch::save.
void at_save(tensor t, char *filename) {
  PROTECT(torch::save(*t, filename);)
}
// Save [ntensors] named tensors into a single archive file.
void at_save_multi(tensor *tensors, char **tensor_names, int ntensors, char *filename) {
  PROTECT(
    torch::serialize::OutputArchive archive;
    for (int i = 0; i < ntensors; ++i)
      archive.write(std::string(tensor_names[i]), *(tensors[i]), /* buffer=*/ false);
    archive.save_to(filename);
  )
}

// Load [ntensors] named tensors from an archive, writing freshly allocated
// handles into [tensors].
void at_load_multi(tensor *tensors, char **tensor_names, int ntensors, char *filename) {
  PROTECT(
    torch::serialize::InputArchive archive;
    archive.load_from(std::string(filename));
    vector<torch::Tensor> ts(ntensors);
    for (int i = 0; i < ntensors; ++i)
      archive.read(std::string(tensor_names[i]), ts[i]);
    // Only allocate the new tensor now so that if there is an exception raised during
    // [read], no memory has to be freed.
    for (int i = 0; i < ntensors; ++i)
      tensors[i] = new torch::Tensor(ts[i]);
  )
}

// Load a TorchScript module and invoke [f] once per named parameter with a
// freshly allocated tensor handle (ownership passes to the callback).
void at_load_callback(char *filename, void *data, void (*f)(void *, char *, tensor)) {
  PROTECT(
    auto module = torch::jit::load(filename);
    for (const auto &p : module.named_parameters()) {
      auto v = p.value;
      f(data, (char*)p.name.c_str(), new torch::Tensor(v));
    }
  )
}

// Same as at_load_callback but maps the module onto the device encoded by
// [device_id] (see device_of_int).
void at_load_callback_with_device(char *filename, void *data, void (*f)(void *, char *, tensor), int device_id) {
  PROTECT(
    auto module = torch::jit::load(filename, device_of_int(device_id));
    for (const auto &p : module.named_parameters()) {
      auto v = p.value;
      f(data, (char*)p.name.c_str(), new torch::Tensor(v));
    }
  )
}
// Load named tensors from an archive *into* existing tensors in place.
// Non-CPU destinations are read via a CPU staging tensor then copied over.
// Gradient tracking is disabled for the duration of the copies.
void at_load_multi_(tensor *tensors, char **tensor_names, int ntensors, char *filename) {
  PROTECT(
    torch::NoGradGuard no_grad;
    torch::serialize::InputArchive archive;
    archive.load_from(std::string(filename));
    for (int i = 0; i < ntensors; ++i) {
      if (tensors[i]->device().type() == at::kCPU)
        archive.read(std::string(tensor_names[i]), *(tensors[i]));
      else {
        torch::Tensor tmp_tensor = torch::empty_like(*(tensors[i]), at::device(at::kCPU));
        archive.read(std::string(tensor_names[i]), tmp_tensor);
        tensors[i]->copy_(tmp_tensor);
      }
    }
  )
}

// Load a single tensor saved with at_save/torch::save; nullptr on error.
tensor at_load(char *filename) {
  PROTECT(
    torch::Tensor tensor;
    torch::load(tensor, filename);
    return new torch::Tensor(tensor);
  )
  return nullptr;
}
// Load an image file as an HxWx3 byte tensor (forced to RGB regardless of the
// file's channel count). Returns nullptr on error; decoding failures raise
// std::invalid_argument with the stb_image failure reason.
tensor at_load_image(char *filename) {
  PROTECT(
    int w = -1;
    int h = -1;
    int c = -1;
    // Request 3 channels; [c] still receives the file's original count.
    unsigned char *data = stbi_load(filename, &w, &h, &c, 3);
    if (data == nullptr)
      throw std::invalid_argument(stbi_failure_reason());
    torch::Tensor tensor = torch::zeros({ h, w, 3 }, at::ScalarType::Byte);
    // size_t arithmetic avoids int overflow for very large images.
    memcpy(tensor.data_ptr(), data, (size_t)h * (size_t)w * 3);
    // Release through the stb_image API rather than free(): stb_image may be
    // compiled with custom STBI_MALLOC/STBI_FREE allocators.
    stbi_image_free(data);
    return new torch::Tensor(tensor);
  )
  return nullptr;
}
// Returns true when the nul-terminated string [str] ends with [suffix].
// An empty suffix matches any string.
bool ends_with(const char *str, const char *suffix) {
  // size_t (the strlen return type) avoids int truncation on huge inputs.
  size_t suffix_len = strlen(suffix);
  size_t str_len = strlen(str);
  if (str_len < suffix_len) return false;
  // Compare the trailing bytes in one call instead of a manual loop.
  return memcmp(str + (str_len - suffix_len), suffix, suffix_len) == 0;
}
// Write an HxWxC byte tensor as an image; the format is chosen from the
// filename extension (.jpg quality 90, .bmp, .tga, anything else PNG with
// stride 0 = tightly packed). Returns the stb_image_write status, -1 on error.
int at_save_image(tensor tensor, char *filename) {
  PROTECT(
    auto sizes = tensor->sizes();
    if (sizes.size() != 3)
      throw std::invalid_argument("invalid number of dimensions, should be 3");
    int h = sizes[0];
    int w = sizes[1];
    int c = sizes[2];
    // Contiguous temporary so the writers see a flat row-major buffer.
    auto tmp_tensor = tensor->contiguous();
    void *tensor_data = tmp_tensor.data_ptr();
    if (ends_with(filename, ".jpg"))
      return stbi_write_jpg(filename, w, h, c, tensor_data, 90);
    if (ends_with(filename, ".bmp"))
      return stbi_write_bmp(filename, w, h, c, tensor_data);
    if (ends_with(filename, ".tga"))
      return stbi_write_tga(filename, w, h, c, tensor_data);
    return stbi_write_png(filename, w, h, c, tensor_data, 0);
  )
  return -1;
}
// Getters/setters for the torch thread pools; getters return -1 on error.
int at_get_num_interop_threads() {
  PROTECT(return at::get_num_interop_threads();)
  return -1;
}

int at_get_num_threads() {
  PROTECT(return at::get_num_threads();)
  return -1;
}

void at_set_num_interop_threads(int n_threads) {
  PROTECT(at::set_num_interop_threads(n_threads);)
}

void at_set_num_threads(int n_threads) {
  PROTECT(at::set_num_threads(n_threads);)
}
// Resizes a { h, w, c } uint8 tensor to { out_h, out_w, c } using
// stb_image_resize. Returns a newly allocated tensor owned by the caller,
// or nullptr on error.
tensor at_resize_image(tensor tensor, int out_w, int out_h) {
  PROTECT(
    auto sizes = tensor->sizes();
    if (sizes.size() != 3)
      throw std::invalid_argument("invalid number of dimensions, should be 3");
    int h = sizes[0];
    int w = sizes[1];
    int c = sizes[2];
    // contiguous() guarantees a dense HWC buffer for stbir below.
    auto tmp_tensor = tensor->contiguous();
    const unsigned char *tensor_data = (unsigned char*)tmp_tensor.data_ptr();
    torch::Tensor out = torch::zeros({ out_h, out_w, c }, at::ScalarType::Byte);
    // Stride 0 = packed rows (w * c bytes) for both input and output.
    stbir_resize_uint8(tensor_data, w, h, 0, (unsigned char*)out.data_ptr(), out_w, out_h, 0, c);
    return new torch::Tensor(out);
  )
  return nullptr;
}
// Releases a tensor handle previously returned by this API.
void at_free(tensor t) {
  delete(t);
}
// Runs the autograd engine backward from the [ntensors] root tensors down to
// the [ninputs] input tensors, seeding each root with a ones-like gradient.
// The gradient w.r.t. each input is written to [outputs] (one new tensor per
// input, owned by the caller). Every input must have requires_grad set, else
// an invalid_argument error is recorded via PROTECT.
void at_run_backward(tensor *tensors,
                     int ntensors,
                     tensor *inputs,
                     int ninputs,
                     tensor *outputs,
                     int keep_graph,
                     int create_graph) {
  PROTECT(
    vector<torch::autograd::Edge> roots;
    for (int i = 0; i < ntensors; ++i)
      roots.push_back(torch::autograd::impl::gradient_edge(*tensors[i]));
    vector<torch::autograd::Edge> inputs_;
    for (int i = 0; i < ninputs; ++i) {
      if (!inputs[i]->requires_grad())
        throw std::invalid_argument("one of the input tensor does not use set_requires_grad");
      inputs_.push_back(torch::autograd::impl::gradient_edge(*inputs[i]));
    }
    // Seed gradient for each root: d(root)/d(root) = ones.
    vector<torch::autograd::Variable> grads;
    for (int i = 0; i < ntensors; ++i)
      grads.push_back(torch::ones_like(*tensors[i]));
    auto vl = torch::autograd::Engine::get_default_engine().execute(roots, grads, keep_graph, create_graph, inputs_);
    for (int i = 0; i < ninputs; ++i) {
      outputs[i] = static_cast<tensor>(new torch::autograd::Variable(vl[i]));
    }
  )
}
// Creates an Adam optimizer with the given hyper-parameters and an initially
// empty parameter list (add parameters with ato_add_parameters). Returns a
// caller-owned handle, or nullptr on error.
optimizer ato_adam(double learning_rate,
                   double beta1,
                   double beta2,
                   double weight_decay) {
  PROTECT(
    auto options =
      torch::optim::AdamOptions(learning_rate)
        .betas(std::tuple<double, double>(beta1, beta2))
        .weight_decay(weight_decay);
    return new torch::optim::Adam(vector<torch::Tensor>(), options);
  )
  return nullptr;
}
// Creates an RMSprop optimizer with the given hyper-parameters and an
// initially empty parameter list. [centered] is treated as a boolean flag.
// Returns a caller-owned handle, or nullptr on error.
optimizer ato_rms_prop(double learning_rate,
                       double alpha,
                       double eps,
                       double weight_decay,
                       double momentum,
                       int centered) {
  PROTECT(
    auto options =
      torch::optim::RMSpropOptions(learning_rate)
        .alpha(alpha)
        .eps(eps)
        .weight_decay(weight_decay)
        .momentum(momentum)
        .centered(centered != 0);
    return new torch::optim::RMSprop(vector<torch::Tensor>(), options);
  )
  return nullptr;
}
// Creates an SGD optimizer with the given hyper-parameters and an initially
// empty parameter list. [nesterov] is treated as a boolean flag. Returns a
// caller-owned handle, or nullptr on error.
optimizer ato_sgd(double learning_rate,
                  double momentum,
                  double dampening,
                  double weight_decay,
                  int nesterov) {
  PROTECT(
    auto options =
      torch::optim::SGDOptions(learning_rate)
        .momentum(momentum)
        .dampening(dampening)
        .weight_decay(weight_decay)
        .nesterov(nesterov);
    return new torch::optim::SGD(vector<torch::Tensor>(), options);
  )
  return nullptr;
}
// Appends [ntensors] tensors to the optimizer's first parameter group.
// Assumes at least one parameter group exists (the constructors above create
// group 0 via the initial empty parameter vector).
void ato_add_parameters(optimizer t, tensor *tensors, int ntensors) {
  PROTECT(
    for (int i = 0; i < ntensors; ++i)
      t->param_groups()[0].params().push_back(*(tensors[i]));
  )
}
// Updates the learning rate stored in the optimizer's default options.
// Dispatches on the concrete options type (Adam / RMSprop / SGD); any other
// optimizer records an invalid_argument error via PROTECT.
void ato_set_learning_rate(optimizer t, double learning_rate) {
  PROTECT(
    torch::optim::OptimizerOptions* d = &(t->defaults());
    if (auto adam = dynamic_cast<torch::optim::AdamOptions*>(d))
      adam->lr(learning_rate);
    else if (auto rms = dynamic_cast<torch::optim::RMSpropOptions*>(d))
      rms->lr(learning_rate);
    else if (auto sgd = dynamic_cast<torch::optim::SGDOptions*>(d))
      sgd->lr(learning_rate);
    else
      throw std::invalid_argument("unexpected optimizer");
  )
}
// Updates the momentum-like hyper-parameter in the optimizer's default
// options. For Adam this maps onto beta1 (beta2 is preserved); for RMSprop
// and SGD it is the plain momentum field. Other optimizers record an
// invalid_argument error via PROTECT.
void ato_set_momentum(optimizer t, double momentum) {
  PROTECT(
    torch::optim::OptimizerOptions* d = &(t->defaults());
    if (auto adam = dynamic_cast<torch::optim::AdamOptions*>(d)) {
      auto betas = adam->betas();
      // Replace beta1, keep the existing beta2.
      adam->betas(std::tuple<double, double>(momentum, get<1>(betas)));
    }
    else if (auto rms = dynamic_cast<torch::optim::RMSpropOptions*>(d))
      rms->momentum(momentum);
    else if (auto sgd = dynamic_cast<torch::optim::SGDOptions*>(d))
      sgd->momentum(momentum);
    else
      throw std::invalid_argument("unexpected optimizer");
  )
}
// Zeroes the gradients of all parameters managed by the optimizer.
void ato_zero_grad(optimizer t) {
  PROTECT(t->zero_grad();)
}
// Performs a single optimization step.
void ato_step(optimizer t) {
  PROTECT(t->step();)
}
// Releases an optimizer handle previously returned by this API.
void ato_free(optimizer t) {
  delete(t);
}
// Wraps an int64 into a caller-owned torch::Scalar handle (nullptr on error).
scalar ats_int(int64_t v) {
  PROTECT(return new torch::Scalar(v);)
  return nullptr;
}
// Wraps a double into a caller-owned torch::Scalar handle (nullptr on error).
scalar ats_float(double v) {
  PROTECT(return new torch::Scalar(v);)
  return nullptr;
}
// Extracts the scalar as int64. NOTE(review): -1 doubles as the error
// sentinel, so a legitimate -1 value is indistinguishable from failure;
// callers must check torch_last_err.
int64_t ats_to_int(scalar s) {
  PROTECT(return s->toLong();)
  return -1;
}
// Extracts the scalar as double (0. doubles as the error sentinel).
double ats_to_float(scalar s) {
  PROTECT(return s->toDouble();)
  return 0.;
}
// Formats the scalar with operator<< and returns a strdup'ed C string the
// caller must free(); nullptr on error.
char *ats_to_string(scalar s) {
  PROTECT(
    using namespace at;
    std::ostringstream oss;
    oss << (*s);
    return strdup(oss.str().c_str());
  )
  return nullptr;
}
// Releases a scalar handle previously returned by this API.
void ats_free(scalar s) {
  delete(s);
}
// Returns the number of visible CUDA devices, or -1 on error.
int atc_cuda_device_count() {
  PROTECT(return torch::cuda::device_count();)
  return -1;
}
// Returns non-zero if CUDA is available, 0 otherwise, -1 on error.
int atc_cuda_is_available() {
  PROTECT(return torch::cuda::is_available();)
  return -1;
}
// Returns non-zero if cuDNN is available, 0 otherwise, -1 on error.
int atc_cudnn_is_available() {
  PROTECT(return torch::cuda::cudnn_is_available();)
  return -1;
}
void atc_set_benchmark_cudnn(int b) {
at::globalContext().setBenchmarkCuDNN(b);
}
// Loads a TorchScript module from [filename]. Returns a caller-owned handle
// (release with atm_free), or nullptr on error.
module atm_load(char *filename) {
  PROTECT(
    return new torch::jit::script::Module(torch::jit::load(filename));
  )
  return nullptr;
}
// Loads a TorchScript module from an in-memory buffer of [sz] bytes.
// Returns a caller-owned handle, or nullptr on error.
module atm_load_str(char *data, size_t sz) {
  PROTECT(
    std::istringstream stream(std::string(data, sz));
    return new torch::jit::script::Module(torch::jit::load(stream));
  )
  return nullptr;
}
// Runs the module's forward method on [ntensors] tensor inputs and returns
// the result as a new caller-owned tensor. Records an invalid_argument error
// (and returns nullptr) when forward returns a non-tensor IValue.
tensor atm_forward(module m, tensor *tensors, int ntensors) {
  PROTECT(
    std::vector<torch::jit::IValue> inputs;
    for (int i = 0; i < ntensors; ++i)
      inputs.push_back(*(tensors[i]));
    torch::jit::IValue output = m->forward(inputs);
    if (!output.isTensor())
      throw std::invalid_argument("forward did not return a tensor");
    return new torch::Tensor(output.toTensor());
  )
  return nullptr;
}
// Generic variant of atm_forward: takes [nivalues] IValue inputs and returns
// the raw IValue result (caller-owned; release with ati_free). Use this when
// the module's forward returns something other than a single tensor.
ivalue atm_forward_(module m,
                    ivalue *ivalues,
                    int nivalues) {
  PROTECT(
    std::vector<torch::jit::IValue> inputs;
    for (int i = 0; i < nivalues; ++i)
      inputs.push_back(*(ivalues[i]));
    torch::jit::IValue output = m->forward(inputs);
    return new torch::jit::IValue(output);
  )
  return nullptr;
}
// Releases a module handle previously returned by this API.
void atm_free(module m) {
  delete(m);
}
// Moves/casts the module in place to the given device and scalar type.
// [device] is decoded by device_of_int (defined earlier in this file);
// [dtype] is an at::ScalarType enum value.
void atm_to(module m, int device, int dtype, bool non_blocking) {
  PROTECT(
    m->to(device_of_int(device), at::ScalarType(dtype), non_blocking);
  )
}
// The ati_* constructors below wrap plain C values into caller-owned
// torch::jit::IValue handles (release with ati_free); all return nullptr on
// error with the message recorded in torch_last_err.

// Wraps a tensor (copied by value; the original handle stays caller-owned).
ivalue ati_tensor(tensor t) {
  PROTECT(
    return new torch::jit::IValue(*t);
  )
  return nullptr;
}
// Wraps an int64.
ivalue ati_int(int64_t i) {
  PROTECT(
    return new torch::jit::IValue(i);
  )
  return nullptr;
}
// Wraps a double.
ivalue ati_double(double d) {
  PROTECT(
    return new torch::jit::IValue(d);
  )
  return nullptr;
}
// Wraps a boolean ([i] is treated as a flag).
ivalue ati_bool(int i) {
  PROTECT(
    return new torch::jit::IValue((bool)i);
  )
  return nullptr;
}
// Wraps a NUL-terminated string (copied).
ivalue ati_string(char *s) {
  PROTECT(
    string str(s);
    return new torch::jit::IValue(str);
  )
  return nullptr;
}
// Wraps None.
ivalue ati_none() {
  PROTECT(
    return new torch::jit::IValue();
  )
  return nullptr;
}
// Builds a tuple IValue from [nvalues] existing IValues (copied).
ivalue ati_tuple(ivalue *is, int nvalues) {
  PROTECT(
    vector<torch::jit::IValue> vec;
    for (int i = 0; i < nvalues; ++i) vec.push_back(*(is[i]));
    return new torch::jit::IValue(torch::ivalue::Tuple::create(vec));
  )
  return nullptr;
}
// Builds a heterogeneous (Any-typed) list IValue from [nvalues] IValues.
ivalue ati_generic_list(ivalue *is, int nvalues) {
  PROTECT(
    c10::List<torch::jit::IValue> vec(c10::AnyType::get());
    for (int i = 0; i < nvalues; ++i) vec.push_back(*(is[i]));
    return new torch::jit::IValue(c10::List<torch::jit::IValue>(vec));
  )
  return nullptr;
}
// Builds an (Any -> Any) dict IValue from [nvalues] key/value pairs laid out
// as [key0, val0, key1, val1, ...] in [is] ([is] therefore holds 2*nvalues
// entries). Returns a caller-owned handle, or nullptr on error.
ivalue ati_generic_dict(ivalue *is, int nvalues) {
  PROTECT(
    // Construct the dict inside PROTECT: previously it was built outside the
    // macro, so a throwing constructor could unwind across the C boundary.
    c10::Dict<torch::jit::IValue, torch::jit::IValue> dict(c10::AnyType::get(), c10::AnyType::get());
    for (int i = 0; i < nvalues; ++i) dict.insert(*(is[2*i]), *(is[2*i+1]));
    return new torch::jit::IValue(dict);
  )
  return nullptr;
}
// Builds an int list IValue from [nvalues] int64 values.
ivalue ati_int_list(int64_t *is, int nvalues) {
  PROTECT(
    c10::List<int64_t> vec;
    for (int i = 0; i < nvalues; ++i) vec.push_back(is[i]);
    return new torch::jit::IValue(vec);
  )
  return nullptr;
}
// Builds a double list IValue from [nvalues] double values.
ivalue ati_double_list(double *is, int nvalues) {
  PROTECT(
    c10::List<double> vec;
    for (int i = 0; i < nvalues; ++i) vec.push_back(is[i]);
    return new torch::jit::IValue(vec);
  )
  return nullptr;
}
// Builds a bool list IValue; each char in [is] is treated as a flag.
ivalue ati_bool_list(char *is, int nvalues) {
  PROTECT(
    c10::List<bool> vec;
    for (int i = 0; i < nvalues; ++i) vec.push_back(is[i] != 0);
    return new torch::jit::IValue(vec);
  )
  return nullptr;
}
// Builds a tensor list IValue from [nvalues] tensor handles (copied).
ivalue ati_tensor_list(tensor *is, int nvalues) {
  PROTECT(
    c10::List<at::Tensor> vec;
    for (int i = 0; i < nvalues; ++i) vec.push_back(*(is[i]));
    return new torch::jit::IValue(vec);
  )
  return nullptr;
}
// Maps an IValue's runtime type to a small integer tag consumed by the
// language binding: 0=None 1=Tensor 2=Double 3=Int 4=Bool 5=Tuple 6=IntList
// 7=DoubleList 8=BoolList 9=String 10=TensorList 12=GenericList
// 13=GenericDict. Returns -1 (with torch_last_err set) for any other type.
// NOTE(review): tag 11 is skipped -- presumably reserved/retired; confirm the
// binding side agrees before reusing it.
// Order matters: the specialized list checks (isIntList, ...) must precede
// the catch-all isList().
int ati_tag(ivalue i) {
  PROTECT(
    if (i->isNone()) return 0;
    else if (i->isTensor()) return 1;
    else if (i->isDouble()) return 2;
    else if (i->isInt()) return 3;
    else if (i->isBool()) return 4;
    else if (i->isTuple()) return 5;
    else if (i->isIntList()) return 6;
    else if (i->isDoubleList()) return 7;
    else if (i->isBoolList()) return 8;
    else if (i->isString()) return 9;
    else if (i->isTensorList()) return 10;
    else if (i->isList()) return 12;
    else if (i->isGenericDict()) return 13;
    throw std::invalid_argument(("unsupported tag" + i->tagKind()).c_str());
    return -1;
  )
  return -1;
}
// Extracts an int IValue. NOTE(review): -1 doubles as the error sentinel;
// callers must check torch_last_err to disambiguate.
int64_t ati_to_int(ivalue i) {
  PROTECT(
    return i->toInt();
  )
  return -1;
}
// Extracts a double IValue (0 doubles as the error sentinel).
double ati_to_double(ivalue i) {
  PROTECT(
    return i->toDouble();
  )
  return 0;
}
// Extracts a bool IValue as 0/1, or -1 on error.
int ati_to_bool(ivalue i) {
  PROTECT(
    return i->toBool();
  )
  return -1;
}
// Extracts a string IValue as a strdup'ed C string the caller must free();
// nullptr on error.
char *ati_to_string(ivalue i) {
  PROTECT(
    auto str = i->toStringRef();
    return strdup(str.c_str());
  )
  return nullptr;
}
// Extracts a tensor IValue as a new caller-owned tensor handle; nullptr on
// error.
tensor ati_to_tensor(ivalue i) {
  PROTECT(
    return new torch::Tensor(i->toTensor());
  )
  return nullptr;
}
// Returns the element count of a container-like IValue (tuple, typed lists,
// string byte length, generic list, dict entry count). Unsupported types
// record an invalid_argument error and yield -1.
int ati_length(ivalue i) {
  PROTECT(
    if (i->isTuple()) return i->toTuple()->elements().size();
    else if (i->isIntList()) return i->toIntList().size();
    else if (i->isDoubleList()) return i->toDoubleList().size();
    else if (i->isBoolList()) return i->toBoolList().size();
    else if (i->isString()) return i->toStringRef().size();
    else if (i->isTensorList()) return i->toTensorList().size();
    else if (i->isList()) return i->toList().size();
    else if (i->isGenericDict()) return i->toGenericDict().size();
    throw std::invalid_argument(("unsupported tag for length " + i->tagKind()).c_str());
    return -1;
  )
  return -1;
}
// Returns the arity of a tuple IValue, or -1 on error (e.g. not a tuple).
int ati_tuple_length(ivalue i) {
  PROTECT(
    return i->toTuple()->elements().size();
  )
  return -1;
}
// Copies the elements of the tuple IValue [i] into [outputs] as new
// caller-owned IValue handles. [noutputs] must equal the tuple arity;
// otherwise an invalid_argument error is recorded via PROTECT and nothing
// is written.
void ati_to_tuple(ivalue i,
                  ivalue *outputs,
                  int noutputs) {
  PROTECT(
    auto elements = i->toTuple()->elements();
    if (elements.size() != noutputs) {
      throw std::invalid_argument("unexpected tuple size");
    }
    // Use a distinct loop-variable name; the original shadowed parameter [i].
    for (int idx = 0; idx < noutputs; ++idx)
      outputs[idx] = new torch::jit::IValue(elements[idx]);
  )
}
// Copies the elements of a generic list IValue into [outputs] as new
// caller-owned IValue handles. [noutputs] must equal the list length.
void ati_to_generic_list(ivalue i,
                         ivalue *outputs,
                         int noutputs) {
  PROTECT(
    auto vec = i->toList();
    if (vec.size() != noutputs) {
      throw std::invalid_argument("unexpected list size");
    }
    // Note: the loop variable shadows the [i] parameter; vec was extracted
    // beforehand so this is harmless here.
    for (int i = 0; i < noutputs; ++i)
      outputs[i] = new torch::jit::IValue(vec[i]);
  )
}
// Flattens a generic dict IValue into [outputs] as alternating key/value
// IValue handles: [key0, val0, key1, val1, ...]. [noutputs] is the number of
// *entries* (so [outputs] must have room for 2*noutputs handles).
void ati_to_generic_dict(ivalue i,
                         ivalue *outputs,
                         int noutputs) {
  PROTECT(
    auto dict = i->toGenericDict();
    if (dict.size() != noutputs) {
      throw std::invalid_argument("unexpected dict size");
    }
    int k = 0;
    for (auto it = dict.begin(); it != dict.end(); ++it) {
      outputs[k++] = new torch::jit::IValue(it->key());
      outputs[k++] = new torch::jit::IValue(it->value());
    }
  )
}
// Copies an int-list IValue into the caller-provided [outputs] buffer.
// [noutputs] must equal the list length.
void ati_to_int_list(ivalue i,
                  int64_t *outputs,
                  int noutputs) {
  PROTECT(
    auto vec = i->toIntList();
    if (vec.size() != noutputs) {
      throw std::invalid_argument("unexpected list size");
    }
    for (int i = 0; i < noutputs; ++i)
      outputs[i] = vec[i];
  )
}
// Copies a double-list IValue into [outputs]; [noutputs] must match.
void ati_to_double_list(ivalue i,
                        double *outputs,
                        int noutputs) {
  PROTECT(
    auto vec = i->toDoubleList();
    if (vec.size() != noutputs) {
      throw std::invalid_argument("unexpected list size");
    }
    for (int i = 0; i < noutputs; ++i)
      outputs[i] = vec[i];
  )
}
// Copies a bool-list IValue into [outputs] as 0/1 chars; [noutputs] must
// match.
void ati_to_bool_list(ivalue i,
                      char *outputs,
                      int noutputs) {
  PROTECT(
    auto vec = i->toBoolList();
    if (vec.size() != noutputs) {
      throw std::invalid_argument("unexpected list size");
    }
    for (int i = 0; i < noutputs; ++i)
      outputs[i] = vec[i];
  )
}
// Copies a tensor-list IValue into [outputs] as new caller-owned tensor
// handles. [noutputs] must equal the list length; otherwise an
// invalid_argument error is recorded via PROTECT.
void ati_to_tensor_list(ivalue i,
                        tensor *outputs,
                        int noutputs) {
  PROTECT(
    auto vec = i->toTensorList();
    if (vec.size() != noutputs) {
      // Fixed copy-paste error message: this is a list, not a tuple
      // (matches the sibling ati_to_*_list extractors).
      throw std::invalid_argument("unexpected list size");
    }
    for (int i = 0; i < noutputs; ++i)
      outputs[i] = new torch::Tensor(vec[i]);
  )
}
// Releases an IValue handle previously returned by this API.
void ati_free(ivalue i) {
  delete(i);
}
#include "torch_api_generated.cpp.h"

View File

@ -0,0 +1,175 @@
#ifndef __TORCH_API_H__
#define __TORCH_API_H__
#include<stdint.h>
#ifdef __cplusplus
// Thread-local slot where PROTECT stores the last C++ exception message
// (retrieved and cleared by get_and_reset_last_err).
// NOTE(review): this is a *definition* in a header; it is only safe while the
// header is included from a single translation unit -- confirm before reuse.
thread_local char *torch_last_err = nullptr;
extern "C" {
typedef torch::Tensor *tensor;
typedef torch::Scalar *scalar;
typedef torch::optim::Optimizer *optimizer;
typedef torch::jit::script::Module *module;
typedef torch::jit::IValue *ivalue;
// Wraps a statement list so any C++ exception is caught and recorded in
// torch_last_err instead of unwinding across the C ABI boundary.
// std::exception is fully qualified so the macro no longer depends on a
// `using namespace std;` being in effect in the including translation unit.
#define PROTECT(x) \
  try { \
    x \
  } catch (const std::exception& e) { \
      torch_last_err = strdup(e.what()); \
  }
#else
// Opaque handles for plain C callers.
typedef void *tensor;
typedef void *optimizer;
typedef void *scalar;
typedef void *module;
typedef void *ivalue;
#endif
/* Error handling: returns and clears the thread-local last error message. */
char *get_and_reset_last_err(); // thread-local
/* Tensor creation, inspection and element access. */
void at_manual_seed(int64_t);
tensor at_new_tensor();
tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims, size_t element_size_in_bytes, int type);
void at_copy_data(tensor tensor, void *vs, size_t numel, size_t element_size_in_bytes);
tensor at_shallow_clone(tensor);
void *at_data_ptr(tensor);
int at_defined(tensor);
int at_is_sparse(tensor);
int at_device(tensor);
size_t at_dim(tensor);
void at_shape(tensor, int64_t *);
int at_scalar_type(tensor);
void at_backward(tensor, int, int);
int at_requires_grad(tensor);
int at_grad_set_enabled(int);
tensor at_get(tensor, int index);
void at_fill_double(tensor, double);
void at_fill_int64(tensor, int64_t);
double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
int64_t at_int64_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
void at_set_double_value_at_indexes(tensor, int *indexes, int indexes_len, double v);
void at_set_int64_value_at_indexes(tensor, int *indexes, int indexes_len, int64_t v);
void at_copy_(tensor dst, tensor src);
/* Printing, serialization and image helpers. */
void at_print(tensor);
char *at_to_string(tensor, int line_size);
void at_save(tensor, char *filename);
tensor at_load(char *filename);
tensor at_load_image(char *filename);
int at_save_image(tensor, char *filename);
tensor at_resize_image(tensor, int w, int h);
void at_save_multi(tensor *tensors, char **tensor_names, int ntensors, char *filename);
/* [at_load_multi] takes as input an array of nullptr for [tensors]. */
void at_load_multi(tensor *tensors, char **tensor_names, int ntensors, char *filename);
/* [at_load_multi_] takes as input an array of allocation [tensors]. */
void at_load_multi_(tensor *tensors, char **tensor_names, int ntensors, char *filename);
void at_load_callback(char *filename, void *data, void (*f)(void *, char *, tensor));
void at_load_callback_with_device(char *filename, void *data, void (*f)(void *, char *, tensor), int device_id);
/* Threading controls and autograd entry point. */
int at_get_num_interop_threads();
int at_get_num_threads();
void at_set_num_interop_threads(int n_threads);
void at_set_num_threads(int n_threads);
void at_free(tensor);
void at_run_backward(tensor *tensors,
                      int ntensors,
                      tensor *inputs,
                      int ninputs,
                      tensor *outputs,
                      int keep_graph,
                      int create_graph);
/* Optimizers (ato_*): construction, parameter management and stepping. */
optimizer ato_adam(double learning_rate,
                   double beta1,
                   double beta2,
                   double weight_decay);
optimizer ato_rms_prop(double learning_rate,
                       double alpha,
                       double eps,
                       double weight_decay,
                       double momentum,
                       int centered);
optimizer ato_sgd(double learning_rate,
                  double momentum,
                  double dampening,
                  double weight_decay,
                  int nesterov);
void ato_add_parameters(optimizer, tensor *, int ntensors);
void ato_set_learning_rate(optimizer, double learning_rate);
void ato_set_momentum(optimizer, double momentum);
void ato_zero_grad(optimizer);
void ato_step(optimizer);
void ato_free(optimizer);
/* Scalars (ats_*). */
scalar ats_int(int64_t);
scalar ats_float(double);
int64_t ats_to_int(scalar);
double ats_to_float(scalar);
char *ats_to_string(scalar);
void ats_free(scalar);
/* CUDA / cuDNN runtime queries (atc_*). */
int atc_cuda_device_count();
int atc_cuda_is_available();
int atc_cudnn_is_available();
void atc_set_benchmark_cudnn(int b);
/* TorchScript modules (atm_*). */
module atm_load(char *);
module atm_load_str(char *, size_t sz);
tensor atm_forward(module, tensor *tensors, int ntensors);
ivalue atm_forward_(module,
                    ivalue *ivalues,
                    int nivalues);
void atm_free(module);
void atm_to(module m, int device, int dtype, bool non_blocking);
/* IValues (ati_*): constructors, extractors and the tag discriminator. */
ivalue ati_none();
ivalue ati_tensor(tensor);
ivalue ati_int(int64_t);
ivalue ati_double(double);
ivalue ati_bool(int);
ivalue ati_string(char *);
ivalue ati_tuple(ivalue *, int);
ivalue ati_generic_list(ivalue *, int);
ivalue ati_generic_dict(ivalue *, int);
ivalue ati_int_list(int64_t *, int);
ivalue ati_double_list(double *, int);
ivalue ati_bool_list(char *, int);
ivalue ati_tensor_list(tensor *, int);
tensor ati_to_tensor(ivalue);
int64_t ati_to_int(ivalue);
double ati_to_double(ivalue);
char *ati_to_string(ivalue);
int ati_to_bool(ivalue);
int ati_length(ivalue);
int ati_tuple_length(ivalue);
void ati_to_tuple(ivalue, ivalue *, int);
void ati_to_generic_list(ivalue, ivalue *, int);
void ati_to_generic_dict(ivalue, ivalue *, int);
void ati_to_int_list(ivalue, int64_t *, int);
void ati_to_double_list(ivalue, double *, int);
void ati_to_bool_list(ivalue, char *, int);
void ati_to_tensor_list(ivalue, tensor *, int);
int ati_tag(ivalue);
void ati_free(ivalue);
#include "torch_api_generated.h"
#ifdef __cplusplus
};
#endif
#endif

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff