fixed incorrect APIs generation

sugarme 2023-08-05 13:02:56 +10:00
parent bdf252d831
commit e9278816b2
7 changed files with 382 additions and 401 deletions

View File

@ -82,6 +82,14 @@ let no_tensor_options =
; "randint_like"
; "randn_like" ]
(* By default, scalar arguments that have a default value are not exposed on
the Go side; this preserves the API's simplicity, assuming such scalar
arguments are rarely overridden.
Adding a function name [foo] to [with_optional_scalar_args] results in
explicit scalar arguments even if a default is present. *)
let with_optional_scalar_args = Set.of_list (module String) [ "arange"; "baddbmm" ]
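For illustration, the practical effect on the generated Go surface can be sketched as follows (a hedged sketch with stand-in types; the real before/after signatures for baddbmm appear in the wrapper diffs further down):

package sketch

// Tensor and Scalar stand in for gotch's ts.Tensor and ts.Scalar.
type Tensor struct{}
type Scalar struct{}

// Before: the generator silently dropped scalar arguments carrying a
// default, so beta and alpha could not be passed from Go at all.
func (t *Tensor) baddbmmBefore(batch1, batch2 *Tensor, del bool) *Tensor { return nil }

// After: listing "baddbmm" in with_optional_scalar_args keeps the
// defaulted scalars as explicit parameters.
func (t *Tensor) baddbmmAfter(batch1, batch2 *Tensor, beta, alpha *Scalar, del bool) *Tensor { return nil }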
(*
* let prefixed_functions =
* Set.of_list
@ -133,24 +141,26 @@ module Func = struct
| Double
| DoubleOption
| Tensor
| TensorOption
(* Tensor.t option *)
| TensorOption (* Tensor.t option *)
| IntList
| IntListOption
| DoubleList
| TensorOptList
| TensorList
| TensorOptions
(* Tensor kind and device *)
| TensorOptions (* Tensor kind and device *)
| Scalar
| ScalarType
| ScalarTypeOption
| Device
| String
| Layout
| LayoutOption
type arg =
{arg_name: string; arg_type: arg_type; default_value: string option}
{ arg_name: string
; arg_type: arg_type
; default_value: string option
}
(* `Func` type *)
type t =
@ -160,7 +170,8 @@ module Func = struct
; args: arg list (* ; returns: [`fixed of int | `dynamic] *)
; returns: [`fixed of int | `dynamic | `bool | `int64_t | `double | `nothing]
; (* number of tensors that are returned *)
kind: [`function_ | `method_] }
kind: [`function_ | `method_]
}
let arg_type_of_string str ~is_nullable =
match String.lowercase str with
@ -175,108 +186,111 @@ module Func = struct
| "const at::itensorlistref &" | "at::tensorlist" -> Some TensorList
| "at::device" -> Some Device
| "const at::scalar &" | "at::scalar" -> Some Scalar
| "at::scalartype" -> Some ScalarType
| "at::scalartype" -> if is_nullable then Some ScalarTypeOption else Some ScalarType
| "c10::string_view" -> Some String
| "at::layout" -> Some (if is_nullable then LayoutOption else Layout)
| _ -> None
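The key change here is that is_nullable is now consulted for at::scalartype, mirroring the existing treatment of at::layout. A hedged Go mirror of the new dispatch, for illustration only (names hypothetical):

package sketch

import "strings"

// argTypeOfString mirrors the OCaml above: a nullable at::ScalarType now
// maps to a distinct ScalarTypeOption kind instead of the plain
// ScalarType it was previously collapsed to.
func argTypeOfString(cppType string, isNullable bool) string {
	switch strings.ToLower(cppType) {
	case "at::scalartype":
		if isNullable {
			return "ScalarTypeOption"
		}
		return "ScalarType"
	case "at::layout":
		if isNullable {
			return "LayoutOption"
		}
		return "Layout"
	}
	return "" // other cases elided
}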
let c_typed_args_list t =
List.map t.args ~f:(fun { arg_name; arg_type; _ } ->
match arg_type with
| IntList | IntListOption ->
Printf.sprintf "int64_t *%s_data, int %s_len" arg_name arg_name
| DoubleList -> Printf.sprintf "double *%s_data, int %s_len" arg_name arg_name
| TensorOptList | TensorList ->
Printf.sprintf "tensor *%s_data, int %s_len" arg_name arg_name
| TensorOptions -> Printf.sprintf "int %s_kind, int %s_device" arg_name arg_name
| String -> Printf.sprintf "char* %s_ptr, int %s_len" arg_name arg_name
| Int64Option -> Printf.sprintf "int64_t %s_v, uint8_t %s_null" arg_name arg_name
| DoubleOption -> Printf.sprintf "double %s_v, uint8_t %s_null" arg_name arg_name
| otherwise ->
let simple_type_cstring =
match otherwise with
| Bool -> "int"
| Int64 -> "int64_t"
| Double -> "double"
| Tensor -> "tensor"
| TensorOption -> "tensor"
| ScalarType -> "int"
| Device -> "int"
| Scalar -> "scalar"
| Layout | LayoutOption -> "int8_t"
| Int64Option
| DoubleOption
| String
| IntList
| IntListOption
| DoubleList
| TensorOptList
| TensorList
| TensorOptions -> assert false
in
Printf.sprintf "%s %s" simple_type_cstring arg_name)
match arg_type with
| IntList | IntListOption ->
Printf.sprintf "int64_t *%s_data, int %s_len" arg_name arg_name
| DoubleList -> Printf.sprintf "double *%s_data, int %s_len" arg_name arg_name
| TensorOptList | TensorList ->
Printf.sprintf "tensor *%s_data, int %s_len" arg_name arg_name
| TensorOptions -> Printf.sprintf "int %s_kind, int %s_device" arg_name arg_name
| String -> Printf.sprintf "char* %s_ptr, int %s_len" arg_name arg_name
| Int64Option -> Printf.sprintf "int64_t %s_v, uint8_t %s_null" arg_name arg_name
| DoubleOption -> Printf.sprintf "double %s_v, uint8_t %s_null" arg_name arg_name
| otherwise ->
let simple_type_cstring =
match otherwise with
| Bool -> "int"
| Int64 -> "int64_t"
| Double -> "double"
| Tensor -> "tensor"
| TensorOption -> "tensor"
| ScalarType -> "int"
| ScalarTypeOption -> "int"
| Device -> "int"
| Scalar -> "scalar"
| Layout | LayoutOption -> "int8_t"
| Int64Option
| DoubleOption
| String
| IntList
| IntListOption
| DoubleList
| TensorOptList
| TensorList
| TensorOptions -> assert false
in
Printf.sprintf "%s %s" simple_type_cstring arg_name)
|> String.concat ~sep:", "
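Two calling conventions are worth spelling out: list arguments cross the C boundary as a (data pointer, length) pair, and optional numbers as a (value, null-flag) pair. A hedged Go-side sketch of the optional-int64 half (gotch represents Int64Option as a []int64 slice, as the wrapper hunks below show):

package sketch

// int64OptionToC flattens gotch's []int64 encoding of an optional int64
// into the (int64_t v, uint8_t null) pair the generated C signature
// expects; the C++ side substitutes c10::nullopt when the flag is set.
func int64OptionToC(v []int64) (cv int64, cnull uint8) {
	if len(v) == 0 {
		return 0, 1
	}
	return v[0], 0
}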
let c_args_list args =
List.map args ~f:(fun { arg_name; arg_type; _ } ->
match arg_type with
| Scalar | Tensor -> "*" ^ arg_name
| Layout -> Printf.sprintf "static_cast<at::Layout>(%s)" arg_name
| LayoutOption ->
Printf.sprintf
"(%s == -1 ? c10::nullopt : \
c10::optional<at::Layout>(static_cast<at::Layout>(%s)))"
arg_name
arg_name
| TensorOption -> Printf.sprintf "(%s ? *%s : torch::Tensor())" arg_name arg_name
| Bool -> "(bool)" ^ arg_name
| IntList ->
Printf.sprintf "torch::IntArrayRef(%s_data, %s_len)" arg_name arg_name
| IntListOption ->
Printf.sprintf
"%s_data == nullptr ? c10::nullopt : \
c10::optional<torch::IntArrayRef>(torch::IntArrayRef(%s_data, %s_len))"
arg_name
arg_name
arg_name
| DoubleList ->
Printf.sprintf "at::ArrayRef<double>(%s_data, %s_len)" arg_name arg_name
| String -> Printf.sprintf "std::string(%s_ptr, %s_len)" arg_name arg_name
| TensorOptList ->
Printf.sprintf "of_carray_tensor_opt(%s_data, %s_len)" arg_name arg_name
| TensorList ->
Printf.sprintf "of_carray_tensor(%s_data, %s_len)" arg_name arg_name
| TensorOptions ->
Printf.sprintf
"at::device(device_of_int(%s_device)).dtype(at::ScalarType(%s_kind))"
arg_name
arg_name
| Int64Option ->
Printf.sprintf
"%s_null ? c10::nullopt : c10::optional<int64_t>(%s_v)"
arg_name
arg_name
| DoubleOption ->
Printf.sprintf
"%s_null ? c10::nullopt : c10::optional<double>(%s_v)"
arg_name
arg_name
| ScalarType -> Printf.sprintf "at::ScalarType(%s)" arg_name
| Device -> Printf.sprintf "device_of_int(%s)" arg_name
| _ -> arg_name)
match arg_type with
| Scalar | Tensor -> "*" ^ arg_name
| Layout -> Printf.sprintf "static_cast<at::Layout>(%s)" arg_name
| LayoutOption ->
Printf.sprintf
"(%s == -1 ? c10::nullopt : \
c10::optional<at::Layout>(static_cast<at::Layout>(%s)))"
arg_name
arg_name
| TensorOption -> Printf.sprintf "(%s ? *%s : torch::Tensor())" arg_name arg_name
| Bool -> "(bool)" ^ arg_name
| IntList -> Printf.sprintf "torch::IntArrayRef(%s_data, %s_len)" arg_name arg_name
| IntListOption ->
Printf.sprintf
"%s_data == nullptr ? c10::nullopt : \
c10::optional<torch::IntArrayRef>(torch::IntArrayRef(%s_data, %s_len))"
arg_name
arg_name
arg_name
| DoubleList ->
Printf.sprintf "at::ArrayRef<double>(%s_data, %s_len)" arg_name arg_name
| String -> Printf.sprintf "std::string(%s_ptr, %s_len)" arg_name arg_name
| TensorOptList ->
Printf.sprintf "of_carray_tensor_opt(%s_data, %s_len)" arg_name arg_name
| TensorList -> Printf.sprintf "of_carray_tensor(%s_data, %s_len)" arg_name arg_name
| TensorOptions ->
Printf.sprintf
"at::device(device_of_int(%s_device)).dtype(at::ScalarType(%s_kind))"
arg_name
arg_name
| Int64Option ->
Printf.sprintf
"%s_null ? c10::nullopt : c10::optional<int64_t>(%s_v)"
arg_name
arg_name
| DoubleOption ->
Printf.sprintf
"%s_null ? c10::nullopt : c10::optional<double>(%s_v)"
arg_name
arg_name
| ScalarType -> Printf.sprintf "at::ScalarType(%s)" arg_name
| ScalarTypeOption ->
Printf.sprintf
"%s < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(%s))"
arg_name
arg_name
| Device -> Printf.sprintf "device_of_int(%s)" arg_name
| _ -> arg_name)
|> String.concat ~sep:", "
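The new ScalarTypeOption branch is the core of this fix: a nullable dtype used to be force-cast with at::ScalarType(dtype), which mishandles the "no dtype" case; it now crosses the boundary as a plain int where any negative value selects c10::nullopt. A hedged Go-side sketch of that convention (hypothetical helper, not part of gotch):

package sketch

// scalarTypeOptionToC encodes an optional dtype as a single C int.
// The generated C++ evaluates:
//   dtype < 0 ? c10::nullopt
//             : c10::optional<at::ScalarType>(at::ScalarType(dtype))
func scalarTypeOptionToC(dtype int32, hasDtype bool) int32 {
	if !hasDtype {
		return -1
	}
	return dtype
}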
let c_call t =
match t.kind with
| `function_ -> Printf.sprintf "torch::%s(%s)" t.name (c_args_list t.args)
| `method_ -> (
match t.args with
| head :: tail ->
Printf.sprintf "%s->%s(%s)" head.arg_name t.name (c_args_list tail)
| [] ->
Printf.failwithf "Method calls should have at least one argument %s"
t.name () )
| `method_ ->
(match t.args with
| head :: tail ->
Printf.sprintf "%s->%s(%s)" head.arg_name t.name (c_args_list tail)
| [] ->
Printf.failwithf "Method calls should have at least one argument %s" t.name ())
(*
let replace_map =
@ -289,6 +303,15 @@ module Func = struct
; ("to_device", "to_device_") ]
*)
let operator_name t =
match String.lowercase t.operator_name with
| "scatter_reduce" ->
(* scatter_reduce is both an operator name in its own right and the name
synthesized from the scatter operator's reduce overload. *)
"_scatter_reduce"
| "scatter_reduce_" -> "_scatter_reduce_"
| other -> other
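A hedged Go mirror of this disambiguation rule, for illustration:

package sketch

import "strings"

// uniqueOperatorName renames the standalone scatter_reduce operator so
// it cannot collide with the name later synthesized for scatter's
// reduce overload (scatter + "_" + reduce).
func uniqueOperatorName(op string) string {
	switch op := strings.ToLower(op); op {
	case "scatter_reduce":
		return "_scatter_reduce"
	case "scatter_reduce_":
		return "_scatter_reduce_"
	default:
		return op
	}
}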
let is_method t =
List.exists t.args ~f:(fun arg ->
match arg.arg_name with "self" -> true | _ -> false )
@ -321,18 +344,16 @@ module Func = struct
let single_param = Printf.sprintf "%s %s" an in
match arg.arg_type with
| Bool -> single_param "int32"
| Layout -> single_param "int8"
| LayoutOption -> single_param "int8"
| Layout | LayoutOption -> single_param "int8"
| Int64 -> single_param "int64"
| Double -> single_param "float64"
| Tensor -> single_param "Ctensor"
| TensorOption -> single_param "Ctensor"
| Scalar -> single_param "Cscalar"
| ScalarType -> single_param "int32"
| ScalarType | ScalarTypeOption -> single_param "int32"
| Device -> single_param "int32"
| String -> single_param "string"
| IntList -> Printf.sprintf "%sData []int64, %sLen int" an an
| IntListOption -> Printf.sprintf "%sData []int64, %sLen int" an an
| IntList | IntListOption -> Printf.sprintf "%sData []int64, %sLen int" an an
| DoubleList -> Printf.sprintf "%sData []float64, %sLen int" an an
| TensorOptList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an
| TensorList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an
@ -353,14 +374,12 @@ module Func = struct
| Double -> Printf.sprintf "c%s" an
| Tensor -> Printf.sprintf "%s" an
| TensorOption -> Printf.sprintf "%s" an
| Layout -> Printf.sprintf "c%s" an
| LayoutOption -> Printf.sprintf "c%s" an
| Layout | LayoutOption -> Printf.sprintf "c%s" an
| Scalar -> single_param ""
| ScalarType -> Printf.sprintf "c%s" an
| ScalarType | ScalarTypeOption -> Printf.sprintf "c%s" an
| Device -> Printf.sprintf "c%s" an
| String -> Printf.sprintf "c%s, c%sLen" an an
| IntList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
| IntListOption -> Printf.sprintf "c%sDataPtr, c%sLen" an an
| IntList | IntListOption -> Printf.sprintf "c%sDataPtr, c%sLen" an an
| DoubleList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
| TensorOptList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
| TensorList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
@ -383,12 +402,10 @@ module Func = struct
Printf.sprintf "\nc%s := *(*C.double)(unsafe.Pointer(&%s))" an an
| Tensor -> ""
| TensorOption -> ""
| Layout ->
Printf.sprintf "\nc%s := *(*C.int8_t)(unsafe.Pointer(&%s))" an an
| LayoutOption ->
| Layout | LayoutOption ->
Printf.sprintf "\nc%s := *(*C.int8_t)(unsafe.Pointer(&%s))" an an
| Scalar -> ""
| ScalarType ->
| ScalarType | ScalarTypeOption ->
Printf.sprintf "\nc%s := *(*C.int)(unsafe.Pointer(&%s))" an an
| Device ->
Printf.sprintf "\nc%s := *(*C.int)(unsafe.Pointer(&%s))" an an
@ -399,13 +416,7 @@ module Func = struct
%sLen := len(%s)\n\
c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))"
an an an an an an
| IntList ->
Printf.sprintf
"\n\
c%sDataPtr := (*C.int64_t)(unsafe.Pointer(&%sData[0]))\n\
c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))"
an an an an
| IntListOption ->
| IntList | IntListOption ->
Printf.sprintf
"\n\
c%sDataPtr := (*C.int64_t)(unsafe.Pointer(&%sData[0]))\n\
@ -494,14 +505,12 @@ module Func = struct
let go_arg_type =
match arg.arg_type with
| Bool -> "bool"
| Layout -> "Layout"
| LayoutOption -> "Layout"
| Layout | LayoutOption -> "Layout"
| Int64 -> "int64"
| Double -> "float64"
| Tensor -> "*Tensor"
| TensorOption -> "*Tensor"
| IntList -> "[]int64"
| IntListOption -> "[]int64"
| IntList | IntListOption -> "[]int64"
| DoubleList -> "[]float64"
| TensorOptList -> "[]*Tensor"
| TensorList -> "[]*Tensor"
@ -510,7 +519,7 @@ module Func = struct
(* E.g. `type KindDevice struct{}` *)
| TensorOptions -> "gotch.KindDevice"
| Scalar -> "*Scalar"
| ScalarType -> "gotch.DType"
| ScalarType | ScalarTypeOption -> "gotch.DType"
| Int64Option -> "[]int64"
| DoubleOption -> "[]float64"
| Device -> "gotch.Device"
@ -603,7 +612,7 @@ module Func = struct
else Printf.sprintf "%s.ctensor" name
| Scalar -> Printf.sprintf "%s.cscalar" name
| Bool -> Printf.sprintf "c%s" name
| ScalarType -> Printf.sprintf "%s.CInt()" name
| ScalarType | ScalarTypeOption -> Printf.sprintf "%s.CInt()" name
| Device -> Printf.sprintf "%s.CInt()" name
| TensorOptions ->
Printf.sprintf "%sKind.CInt(), %sDevice.CInt()" name name
@ -633,11 +642,10 @@ module Func = struct
| Tensor -> ""
| TensorOption -> ""
| Scalar -> ""
| ScalarType -> ""
| ScalarType | ScalarTypeOption -> ""
| Device -> ""
| String -> ""
| IntList -> Printf.sprintf "%sLen := len(%s)\n" an an
| IntListOption -> Printf.sprintf "%sLen := len(%s)\n" an an
| IntList | IntListOption -> Printf.sprintf "%sLen := len(%s)\n" an an
| DoubleList -> Printf.sprintf "%sLen := len(%s)\n" an an
| Int64Option ->
Printf.sprintf
@ -667,8 +675,7 @@ module Func = struct
"var c%s []lib.Ctensor\n\
\ for _, t := range %s {c%s = append(c%s, t.ctensor)}\n"
an an an an
| Layout -> ""
| LayoutOption -> ""
| Layout | LayoutOption -> ""
| TensorOptions -> "" )
|> String.concat ~sep:""
end
@ -679,117 +686,103 @@ let read_yaml filename =
let funcs =
(* Split the file to avoid Yaml.of_string_exn segfaulting. *)
In_channel.with_file filename ~f:In_channel.input_lines
|> List.group ~break:(fun _ l ->
String.length l > 0 && Char.( = ) l.[0] '-' )
|> List.group ~break:(fun _ l -> String.length l > 0 && Char.( = ) l.[0] '-')
|> List.concat_map ~f:(fun lines ->
Yaml.of_string_exn (String.concat lines ~sep:"\n") |> extract_list
)
Yaml.of_string_exn (String.concat lines ~sep:"\n") |> extract_list)
in
printf "Read %s, got %d functions.\n%!" filename (List.length funcs) ;
printf "Read %s, got %d functions.\n%!" filename (List.length funcs);
List.filter_map funcs ~f:(fun yaml ->
let map = extract_map yaml in
let name = Map.find_exn map "name" |> extract_string in
let operator_name = Map.find_exn map "operator_name" |> extract_string in
let overload_name = Map.find_exn map "overload_name" |> extract_string in
let deprecated = Map.find_exn map "deprecated" |> extract_bool in
let method_of =
Map.find_exn map "method_of"
|> extract_list |> List.map ~f:extract_string
let map = extract_map yaml in
let name = Map.find_exn map "name" |> extract_string in
let operator_name = Map.find_exn map "operator_name" |> extract_string in
let overload_name = Map.find_exn map "overload_name" |> extract_string in
let deprecated = Map.find_exn map "deprecated" |> extract_bool in
let method_of =
Map.find_exn map "method_of" |> extract_list |> List.map ~f:extract_string
in
let arguments = Map.find_exn map "arguments" |> extract_list in
let returns =
let is_tensor returns =
let returns = extract_map returns in
let return_type = Map.find_exn returns "dynamic_type" |> extract_string in
String.( = ) return_type "at::Tensor"
in
let arguments = Map.find_exn map "arguments" |> extract_list in
let returns =
let is_tensor returns =
let returns = extract_map returns in
let returns = Map.find_exn map "returns" |> extract_list in
if List.is_empty returns
then Some `nothing
else if List.for_all returns ~f:is_tensor
then Some (`fixed (List.length returns))
else (
match returns with
| [ returns ] ->
let return_type =
Map.find_exn returns "dynamic_type" |> extract_string
Map.find_exn (extract_map returns) "dynamic_type" |> extract_string
in
String.( = ) return_type "at::Tensor"
in
let returns = Map.find_exn map "returns" |> extract_list in
if List.for_all returns ~f:is_tensor then
Some (`fixed (List.length returns))
else
match returns with
| [returns] -> (
let return_type =
Map.find_exn (extract_map returns) "dynamic_type"
|> extract_string
in
match return_type with
| "bool" -> Some `bool
| "int64_t" -> Some `int64_t
| "double" -> Some `double
| "at::TensorList"
|"dynamic_type: const c10::List<c10::optional<Tensor>> &" ->
Some `dynamic
| _ -> None )
| [] | _ :: _ :: _ -> None
in
let kind =
if List.exists method_of ~f:(String.( = ) "namespace") then
Some `function_
else if List.exists method_of ~f:(String.( = ) "Tensor") then
Some `method_
else None
in
if
(not deprecated)
&& (not
(List.exists excluded_prefixes ~f:(fun prefix ->
String.is_prefix name ~prefix )))
&& (not
(List.exists excluded_suffixes ~f:(fun suffix ->
String.is_suffix name ~suffix )))
&& not (Set.mem excluded_functions name)
then
Option.both returns kind
|> Option.bind ~f:(fun (returns, kind) ->
try
let args =
List.filter_map arguments ~f:(fun arg ->
let arg = extract_map arg in
let arg_name =
Map.find_exn arg "name" |> extract_string
in
let arg_type =
Map.find_exn arg "dynamic_type" |> extract_string
in
let is_nullable =
Map.find arg "is_nullable"
|> Option.value_map ~default:false ~f:extract_bool
in
let default_value =
Map.find arg "default" |> Option.map ~f:extract_string
in
match Func.arg_type_of_string arg_type ~is_nullable with
| Some Scalar
when Option.is_some default_value && not is_nullable
->
None
| Some TensorOptions
when Option.is_some default_value
&& Set.mem no_tensor_options name ->
None
| Some arg_type ->
let arg_name =
match (arg_name, arg_type) with
| "self", Scalar -> "self_scalar"
| _, _ -> arg_name
in
Some {Func.arg_name; arg_type; default_value}
| None ->
if Option.is_some default_value then None
else raise Not_a_simple_arg )
(match return_type with
| "bool" -> Some `bool
| "int64_t" -> Some `int64_t
| "double" -> Some `double
| "at::TensorList" | "dynamic_type: const c10::List<c10::optional<Tensor>> &"
-> Some `dynamic
| _ -> None)
| [] | _ :: _ :: _ -> None)
in
let kind =
if List.exists method_of ~f:(String.( = ) "namespace")
then Some `function_
else if List.exists method_of ~f:(String.( = ) "Tensor")
then Some `method_
else None
in
if (not deprecated)
&& (not
(List.exists excluded_prefixes ~f:(fun prefix ->
String.is_prefix name ~prefix)))
&& (not
(List.exists excluded_suffixes ~f:(fun suffix ->
String.is_suffix name ~suffix)))
&& not (Set.mem excluded_functions name)
then
Option.both returns kind
|> Option.bind ~f:(fun (returns, kind) ->
try
let args ~with_optional_scalar_args =
List.filter_map arguments ~f:(fun arg ->
let arg = extract_map arg in
let arg_name = Map.find_exn arg "name" |> extract_string in
let arg_type = Map.find_exn arg "dynamic_type" |> extract_string in
let is_nullable =
Map.find arg "is_nullable"
|> Option.value_map ~default:false ~f:extract_bool
in
Some
{ Func.name
; operator_name
; overload_name
; args
; returns
; kind }
with Not_a_simple_arg -> None )
else None )
let default_value =
Map.find arg "default" |> Option.map ~f:extract_string
in
match Func.arg_type_of_string arg_type ~is_nullable with
| Some Scalar when Option.is_some default_value && not is_nullable ->
if with_optional_scalar_args
then Some { Func.arg_name; arg_type = Scalar; default_value }
else None
| Some TensorOptions
when Option.is_some default_value && Set.mem no_tensor_options name ->
None
| Some arg_type ->
let arg_name =
match arg_name, arg_type with
| "self", Scalar -> "self_scalar"
| _, _ -> arg_name
in
Some { Func.arg_name; arg_type; default_value }
| None ->
if Option.is_some default_value then None else raise Not_a_simple_arg)
in
let args =
args ~with_optional_scalar_args:(Set.mem with_optional_scalar_args name)
in
Some [ { Func.name; operator_name; overload_name; args; returns; kind } ]
with
| Not_a_simple_arg -> None)
else None)
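The heart of the commit is this filter: previously, any Scalar argument carrying a default value was unconditionally dropped, which is exactly why atg_arange_start_step lost its step and atg_baddbmm lost beta and alpha. A hedged Go mirror of the new decision (hypothetical helper):

package sketch

// keepArg mirrors read_yaml's argument filter after this commit for the
// two special cases; every other recognized argument is always kept.
func keepArg(kind string, hasDefault, isNullable, optedIn, noTensorOptions bool) bool {
	switch {
	case kind == "Scalar" && hasDefault && !isNullable:
		// Old behavior: always dropped. New behavior: kept when the
		// function is listed in with_optional_scalar_args.
		return optedIn
	case kind == "TensorOptions" && hasDefault && noTensorOptions:
		return false
	default:
		return true
	}
}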
let p out_channel s =
Printf.ksprintf
@ -803,72 +796,71 @@ let print_inline out_channel s =
let write_cpp funcs filename =
Out_channel.with_file (filename ^ ".cpp.h") ~f:(fun out_cpp ->
Out_channel.with_file (filename ^ ".h") ~f:(fun out_h ->
let pc s = p out_cpp s in
let ph s = p out_h s in
pc "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
pc "";
ph "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
ph "";
Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
let c_typed_args_list = Func.c_typed_args_list func in
match func.returns with
| `nothing ->
pc "void atg_%s(%s) {" exported_name c_typed_args_list;
pc " PROTECT(";
pc " %s;" (Func.c_call func);
pc " )";
pc "}";
pc "";
ph "void atg_%s(%s);" exported_name c_typed_args_list
| `fixed ntensors ->
pc "void atg_%s(tensor *out__, %s) {" exported_name c_typed_args_list;
pc " PROTECT(";
pc " auto outputs__ = %s;" (Func.c_call func);
if ntensors = 1
then pc " out__[0] = new torch::Tensor(outputs__);"
else
for i = 0 to ntensors - 1 do
pc " out__[%d] = new torch::Tensor(std::get<%d>(outputs__));" i i
done;
pc " )";
pc "}";
pc "";
ph "void atg_%s(tensor *, %s);" exported_name c_typed_args_list
| `dynamic ->
pc "tensor *atg_%s(%s) {" exported_name c_typed_args_list;
pc " PROTECT(";
pc " auto outputs__ = %s;" (Func.c_call func);
(* the returned type is a C++ vector of tensors *)
pc " int sz = outputs__.size();";
pc
" torch::Tensor **out__ = (torch::Tensor**)malloc((sz + 1) * \
sizeof(torch::Tensor*));";
pc " for (int i = 0; i < sz; ++i)";
pc " out__[i] = new torch::Tensor(outputs__[i]);";
pc " out__[sz] = nullptr;";
pc " return out__;";
pc " )";
pc " return nullptr;";
pc "}";
pc "";
ph "tensor *atg_%s(%s);" exported_name c_typed_args_list
| (`bool | `int64_t | `double) as returns ->
let c_type =
match returns with
| `bool -> "int"
| `int64_t -> "int64_t"
| `double -> "double"
in
pc "%s atg_%s(%s) {" c_type exported_name c_typed_args_list;
pc " PROTECT(";
pc " return %s;" (Func.c_call func);
pc " )";
pc " return 0;";
pc "}";
pc "";
ph "%s atg_%s(%s);" c_type exported_name c_typed_args_list)))
Out_channel.with_file (filename ^ ".h") ~f:(fun out_h ->
let pc s = p out_cpp s in
let ph s = p out_h s in
pc "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
pc "";
ph "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
ph "";
Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
let c_typed_args_list = Func.c_typed_args_list func in
match func.returns with
| `nothing ->
pc "void atg_%s(%s) {" exported_name c_typed_args_list;
pc " PROTECT(";
pc " %s;" (Func.c_call func);
pc " )";
pc "}";
pc "";
ph "void atg_%s(%s);" exported_name c_typed_args_list
| `fixed ntensors ->
pc "void atg_%s(tensor *out__, %s) {" exported_name c_typed_args_list;
pc " PROTECT(";
pc " auto outputs__ = %s;" (Func.c_call func);
if ntensors = 1
then pc " out__[0] = new torch::Tensor(outputs__);"
else
for i = 0 to ntensors - 1 do
pc " out__[%d] = new torch::Tensor(std::get<%d>(outputs__));" i i
done;
pc " )";
pc "}";
pc "";
ph "void atg_%s(tensor *, %s);" exported_name c_typed_args_list
| `dynamic ->
pc "tensor *atg_%s(%s) {" exported_name c_typed_args_list;
pc " PROTECT(";
pc " auto outputs__ = %s;" (Func.c_call func);
(* the returned type is a C++ vector of tensors *)
pc " int sz = outputs__.size();";
pc
" torch::Tensor **out__ = (torch::Tensor**)malloc((sz + 1) * \
sizeof(torch::Tensor*));";
pc " for (int i = 0; i < sz; ++i)";
pc " out__[i] = new torch::Tensor(outputs__[i]);";
pc " out__[sz] = nullptr;";
pc " return out__;";
pc " )";
pc " return nullptr;";
pc "}";
pc "";
ph "tensor *atg_%s(%s);" exported_name c_typed_args_list
| (`bool | `int64_t | `double) as returns ->
let c_type =
match returns with
| `bool -> "int"
| `int64_t -> "int64_t"
| `double -> "double"
in
pc "%s atg_%s(%s) {" c_type exported_name c_typed_args_list;
pc " PROTECT(";
pc " return %s;" (Func.c_call func);
pc " )";
pc " return 0;";
pc "}";
pc "";
ph "%s atg_%s(%s);" c_type exported_name c_typed_args_list)))
let write_wrapper funcs filename =
Out_channel.with_file filename ~f:(fun out_ml ->
@ -1402,54 +1394,43 @@ let methods =
; c "to" [ ca "self" Tensor; ca "device" Device ]
]
let run ~yaml_filename ~cpp_filename ~ffi_filename ~must_wrapper_filename
~wrapper_filename =
let funcs = read_yaml yaml_filename in
let funcs = read_yaml yaml_filename |> List.concat in
let funcs = methods @ funcs in
printf "Generating code for %d functions.\n%!" (List.length funcs) ;
printf "Generating code for %d functions.\n%!" (List.length funcs);
(* Generate some unique names for overloaded functions. *)
let funcs =
List.map funcs ~f:(fun func -> (String.lowercase func.operator_name, func))
List.map funcs ~f:(fun func -> Func.operator_name func, func)
|> Map.of_alist_multi (module String)
|> Map.to_alist
|> List.concat_map ~f:(fun (name, funcs) ->
match funcs with
| [] -> assert false
| [func] -> [(name, func)]
| funcs ->
let has_empty_overload =
List.exists funcs ~f:(fun (func : Func.t) ->
String.is_empty func.overload_name )
in
List.sort funcs ~compare:(fun (f1 : Func.t) (f2 : Func.t) ->
match
Int.compare (String.length f1.name)
(String.length f2.name)
with
| 0 ->
Int.compare (List.length f1.args) (List.length f2.args)
| cmp -> cmp )
|> List.mapi ~f:(fun index (func : Func.t) ->
let operator_name =
String.lowercase func.operator_name
in
let overload_name =
String.lowercase func.overload_name
in
let name =
if
String.is_empty overload_name
|| (index = 0 && not has_empty_overload)
then operator_name
else if String.is_suffix operator_name ~suffix:"_" then
operator_name ^ overload_name ^ "_"
else operator_name ^ "_" ^ overload_name
in
(name, func) ) )
match funcs with
| [] -> assert false
| [ func ] -> [ name, func ]
| funcs ->
let has_empty_overload =
List.exists funcs ~f:(fun (func : Func.t) ->
String.is_empty func.overload_name)
in
List.sort funcs ~compare:(fun (f1 : Func.t) (f2 : Func.t) ->
match Int.compare (String.length f1.name) (String.length f2.name) with
| 0 -> Int.compare (List.length f1.args) (List.length f2.args)
| cmp -> cmp)
|> List.mapi ~f:(fun index (func : Func.t) ->
let operator_name = Func.operator_name func in
let overload_name = String.lowercase func.overload_name in
let name =
if String.is_empty overload_name || (index = 0 && not has_empty_overload)
then operator_name
else if String.is_suffix operator_name ~suffix:"_"
then operator_name ^ overload_name ^ "_"
else operator_name ^ "_" ^ overload_name
in
name, func))
|> Map.of_alist_exn (module String)
in
write_cpp funcs cpp_filename ;
write_cpp funcs cpp_filename;
write_ffi funcs ffi_filename ;
write_must_wrapper funcs must_wrapper_filename ;
write_wrapper funcs wrapper_filename
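For reference, the unique-name scheme above can be mirrored in Go (hypothetical helper): operator arange with overload start_step yields arange_start_step, while an in-place operator ending in an underscore keeps that underscore at the end of the combined name:

package sketch

import "strings"

// exportedName derives the exported name for one overload of an
// operator, given its position in the sorted overload list. The
// operatorName is assumed already lowercased and disambiguated by
// Func.operator_name, as in the OCaml.
func exportedName(operatorName, overloadName string, index int, hasEmptyOverload bool) string {
	ov := strings.ToLower(overloadName)
	if ov == "" || (index == 0 && !hasEmptyOverload) {
		return operatorName
	}
	if strings.HasSuffix(operatorName, "_") {
		return operatorName + ov + "_"
	}
	return operatorName + "_" + ov
}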

View File

@ -2666,10 +2666,10 @@ coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
C.atg_arange_start(ptr, start , end , coptionsKind, coptionsDevice)
}
func AtgArangeStartStep(ptr *Ctensor, start Cscalar, end Cscalar, optionsKind int32, optionsDevice int32){
func AtgArangeStartStep(ptr *Ctensor, start Cscalar, end Cscalar, step Cscalar, optionsKind int32, optionsDevice int32){
coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
C.atg_arange_start_step(ptr, start , end , coptionsKind, coptionsDevice)
C.atg_arange_start_step(ptr, start , end , step , coptionsKind, coptionsDevice)
}
func AtgArccos(ptr *Ctensor, self Ctensor){
C.atg_arccos(ptr, self)
@ -3004,8 +3004,8 @@ cdivisorOverrideVal := *(*C.int64_t)(unsafe.Pointer(&divisorOverrideVal))
cdivisorOverrideNull := *(*C.uint8_t)(unsafe.Pointer(&divisorOverrideNull))
C.atg_avg_pool3d_out(ptr, out, self, ckernelSizeDataPtr, ckernelSizeLen, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cceilMode, ccountIncludePad, cdivisorOverrideVal, cdivisorOverrideNull)
}
func AtgBaddbmm(ptr *Ctensor, self Ctensor, batch1 Ctensor, batch2 Ctensor){
C.atg_baddbmm(ptr, self, batch1, batch2)
func AtgBaddbmm(ptr *Ctensor, self Ctensor, batch1 Ctensor, batch2 Ctensor, beta Cscalar, alpha Cscalar){
C.atg_baddbmm(ptr, self, batch1, batch2, beta , alpha )
}
func AtgBaddbmm_(ptr *Ctensor, self Ctensor, batch1 Ctensor, batch2 Ctensor){
C.atg_baddbmm_(ptr, self, batch1, batch2)

View File

@ -2068,7 +2068,7 @@ void atg__slow_conv2d_backward(tensor *out__, tensor grad_input, tensor grad_wei
void atg__sobol_engine_draw(tensor *out__, tensor quasi, int64_t n, tensor sobolstate, int64_t dimension, int64_t num_generated, int dtype) {
PROTECT(
auto outputs__ = torch::_sobol_engine_draw(*quasi, n, *sobolstate, dimension, num_generated, at::ScalarType(dtype));
auto outputs__ = torch::_sobol_engine_draw(*quasi, n, *sobolstate, dimension, num_generated, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(std::get<0>(outputs__));
out__[1] = new torch::Tensor(std::get<1>(outputs__));
)
@ -2223,28 +2223,28 @@ void atg__sparse_csc_tensor_unsafe(tensor *out__, tensor ccol_indices, tensor ro
void atg__sparse_csr_prod(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::_sparse_csr_prod(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::_sparse_csr_prod(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg__sparse_csr_prod_dim_dtype_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::_sparse_csr_prod_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::_sparse_csr_prod_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg__sparse_csr_sum(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::_sparse_csr_sum(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::_sparse_csr_sum(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg__sparse_csr_sum_dim_dtype_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::_sparse_csr_sum_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::_sparse_csr_sum_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -2279,7 +2279,7 @@ void atg__sparse_log_softmax_backward_data_out(tensor *out__, tensor out, tensor
void atg__sparse_log_softmax_int(tensor *out__, tensor self, int64_t dim, int dtype) {
PROTECT(
auto outputs__ = torch::_sparse_log_softmax(*self, dim, at::ScalarType(dtype));
auto outputs__ = torch::_sparse_log_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -2336,7 +2336,7 @@ void atg__sparse_softmax_backward_data_out(tensor *out__, tensor out, tensor gra
void atg__sparse_softmax_int(tensor *out__, tensor self, int64_t dim, int dtype) {
PROTECT(
auto outputs__ = torch::_sparse_softmax(*self, dim, at::ScalarType(dtype));
auto outputs__ = torch::_sparse_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -2629,14 +2629,14 @@ tensor *atg__to_cpu(tensor *tensors_data, int tensors_len) {
void atg__to_dense(tensor *out__, tensor self, int dtype) {
PROTECT(
auto outputs__ = self->_to_dense(at::ScalarType(dtype));
auto outputs__ = self->_to_dense(dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg__to_dense_out(tensor *out__, tensor out, tensor self, int dtype) {
PROTECT(
auto outputs__ = torch::_to_dense_out(*out, *self, at::ScalarType(dtype));
auto outputs__ = torch::_to_dense_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -3640,9 +3640,9 @@ void atg_arange_start(tensor *out__, scalar start, scalar end, int options_kind,
)
}
void atg_arange_start_step(tensor *out__, scalar start, scalar end, int options_kind, int options_device) {
void atg_arange_start_step(tensor *out__, scalar start, scalar end, scalar step, int options_kind, int options_device) {
PROTECT(
auto outputs__ = torch::arange(*start, *end, at::device(device_of_int(options_device)).dtype(at::ScalarType(options_kind)));
auto outputs__ = torch::arange(*start, *end, *step, at::device(device_of_int(options_device)).dtype(at::ScalarType(options_kind)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -4120,9 +4120,9 @@ void atg_avg_pool3d_out(tensor *out__, tensor out, tensor self, int64_t *kernel_
)
}
void atg_baddbmm(tensor *out__, tensor self, tensor batch1, tensor batch2) {
void atg_baddbmm(tensor *out__, tensor self, tensor batch1, tensor batch2, scalar beta, scalar alpha) {
PROTECT(
auto outputs__ = torch::baddbmm(*self, *batch1, *batch2);
auto outputs__ = torch::baddbmm(*self, *batch1, *batch2, *beta, *alpha);
out__[0] = new torch::Tensor(outputs__);
)
}
@ -5910,14 +5910,14 @@ void atg_cummin_out(tensor *out__, tensor values, tensor indices, tensor self, i
void atg_cumprod(tensor *out__, tensor self, int64_t dim, int dtype) {
PROTECT(
auto outputs__ = torch::cumprod(*self, dim, at::ScalarType(dtype));
auto outputs__ = torch::cumprod(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_cumprod_(tensor *out__, tensor self, int64_t dim, int dtype) {
PROTECT(
auto outputs__ = self->cumprod_(dim, at::ScalarType(dtype));
auto outputs__ = self->cumprod_(dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -5931,28 +5931,28 @@ void atg_cumprod_backward(tensor *out__, tensor grad, tensor input, int64_t dim,
void atg_cumprod_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
PROTECT(
auto outputs__ = torch::cumprod_out(*out, *self, dim, at::ScalarType(dtype));
auto outputs__ = torch::cumprod_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_cumsum(tensor *out__, tensor self, int64_t dim, int dtype) {
PROTECT(
auto outputs__ = torch::cumsum(*self, dim, at::ScalarType(dtype));
auto outputs__ = torch::cumsum(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_cumsum_(tensor *out__, tensor self, int64_t dim, int dtype) {
PROTECT(
auto outputs__ = self->cumsum_(dim, at::ScalarType(dtype));
auto outputs__ = self->cumsum_(dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_cumsum_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
PROTECT(
auto outputs__ = torch::cumsum_out(*out, *self, dim, at::ScalarType(dtype));
auto outputs__ = torch::cumsum_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -9912,28 +9912,28 @@ void atg_linalg_multi_dot_out(tensor *out__, tensor out, tensor *tensors_data, i
void atg_linalg_norm(tensor *out__, tensor self, scalar ord, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::linalg_norm(*self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::linalg_norm(*self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_linalg_norm_ord_str(tensor *out__, tensor self, char* ord_ptr, int ord_len, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::linalg_norm(*self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::linalg_norm(*self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_linalg_norm_ord_str_out(tensor *out__, tensor out, tensor self, char* ord_ptr, int ord_len, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::linalg_norm_out(*out, *self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::linalg_norm_out(*out, *self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_linalg_norm_out(tensor *out__, tensor out, tensor self, scalar ord, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::linalg_norm_out(*out, *self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::linalg_norm_out(*out, *self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -10314,14 +10314,14 @@ void atg_log_sigmoid_out(tensor *out__, tensor out, tensor self) {
void atg_log_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
PROTECT(
auto outputs__ = torch::log_softmax(*self, dim, at::ScalarType(dtype));
auto outputs__ = torch::log_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_log_softmax_int_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
PROTECT(
auto outputs__ = torch::log_softmax_out(*out, *self, dim, at::ScalarType(dtype));
auto outputs__ = torch::log_softmax_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -10953,21 +10953,21 @@ void atg_maximum_out(tensor *out__, tensor out, tensor self, tensor other) {
void atg_mean(tensor *out__, tensor self, int dtype) {
PROTECT(
auto outputs__ = torch::mean(*self, at::ScalarType(dtype));
auto outputs__ = torch::mean(*self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_mean_dim(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::mean(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::mean(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_mean_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::mean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::mean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -11742,14 +11742,14 @@ void atg_nan_to_num_out(tensor *out__, tensor out, tensor self, double nan_v, ui
void atg_nanmean(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::nanmean(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::nanmean(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_nanmean_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::nanmean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::nanmean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -11814,14 +11814,14 @@ void atg_nanquantile_scalar_out(tensor *out__, tensor out, tensor self, double q
void atg_nansum(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::nansum(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::nansum(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_nansum_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::nansum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::nansum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -11961,14 +11961,14 @@ void atg_native_norm_out(tensor *out__, tensor out, tensor self) {
void atg_native_norm_scalaropt_dim_dtype(tensor *out__, tensor self, scalar p, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::native_norm(*self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::native_norm(*self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_native_norm_scalaropt_dim_dtype_out(tensor *out__, tensor out, tensor self, scalar p, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::native_norm_out(*out, *self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::native_norm_out(*out, *self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -12702,28 +12702,28 @@ void atg_prelu(tensor *out__, tensor self, tensor weight) {
void atg_prod(tensor *out__, tensor self, int dtype) {
PROTECT(
auto outputs__ = torch::prod(*self, at::ScalarType(dtype));
auto outputs__ = torch::prod(*self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_prod_dim_int(tensor *out__, tensor self, int64_t dim, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::prod(*self, dim, (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::prod(*self, dim, (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_prod_int_out(tensor *out__, tensor out, tensor self, int64_t dim, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::prod_out(*out, *self, dim, (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::prod_out(*out, *self, dim, (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_prod_out(tensor *out__, tensor out, tensor self, int dtype) {
PROTECT(
auto outputs__ = torch::prod_out(*out, *self, at::ScalarType(dtype));
auto outputs__ = torch::prod_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -14593,14 +14593,14 @@ void atg_soft_margin_loss_out(tensor *out__, tensor out, tensor self, tensor tar
void atg_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
PROTECT(
auto outputs__ = torch::softmax(*self, dim, at::ScalarType(dtype));
auto outputs__ = torch::softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_softmax_int_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
PROTECT(
auto outputs__ = torch::softmax_out(*out, *self, dim, at::ScalarType(dtype));
auto outputs__ = torch::softmax_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -15528,7 +15528,7 @@ void atg_special_log_ndtr_out(tensor *out__, tensor out, tensor self) {
void atg_special_log_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
PROTECT(
auto outputs__ = torch::special_log_softmax(*self, dim, at::ScalarType(dtype));
auto outputs__ = torch::special_log_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -15913,7 +15913,7 @@ void atg_special_sinc_out(tensor *out__, tensor out, tensor self) {
void atg_special_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
PROTECT(
auto outputs__ = torch::special_softmax(*self, dim, at::ScalarType(dtype));
auto outputs__ = torch::special_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -16437,28 +16437,28 @@ void atg_subtract_scalar_(tensor *out__, tensor self, scalar other) {
void atg_sum(tensor *out__, tensor self, int dtype) {
PROTECT(
auto outputs__ = torch::sum(*self, at::ScalarType(dtype));
auto outputs__ = torch::sum(*self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_sum_dim_intlist(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::sum(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::sum(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_sum_intlist_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT(
auto outputs__ = torch::sum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
auto outputs__ = torch::sum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
void atg_sum_out(tensor *out__, tensor out, tensor self, int dtype) {
PROTECT(
auto outputs__ = torch::sum_out(*out, *self, at::ScalarType(dtype));
auto outputs__ = torch::sum_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -16732,7 +16732,7 @@ void atg_to(tensor *out__, tensor self, int device) {
void atg_to_dense(tensor *out__, tensor self, int dtype) {
PROTECT(
auto outputs__ = self->to_dense(at::ScalarType(dtype));
auto outputs__ = self->to_dense(dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -16767,7 +16767,7 @@ void atg_to_dtype_layout(tensor *out__, tensor self, int options_kind, int optio
void atg_to_mkldnn(tensor *out__, tensor self, int dtype) {
PROTECT(
auto outputs__ = self->to_mkldnn(at::ScalarType(dtype));
auto outputs__ = self->to_mkldnn(dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}
@ -16781,7 +16781,7 @@ void atg_to_mkldnn_backward(tensor *out__, tensor grad, tensor input) {
void atg_to_mkldnn_out(tensor *out__, tensor out, tensor self, int dtype) {
PROTECT(
auto outputs__ = torch::to_mkldnn_out(*out, *self, at::ScalarType(dtype));
auto outputs__ = torch::to_mkldnn_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__);
)
}

View File

@ -497,7 +497,7 @@ void atg_any_dim(tensor *, tensor self, int64_t dim, int keepdim);
void atg_any_out(tensor *, tensor out, tensor self, int64_t dim, int keepdim);
void atg_arange(tensor *, scalar end, int options_kind, int options_device);
void atg_arange_start(tensor *, scalar start, scalar end, int options_kind, int options_device);
void atg_arange_start_step(tensor *, scalar start, scalar end, int options_kind, int options_device);
void atg_arange_start_step(tensor *, scalar start, scalar end, scalar step, int options_kind, int options_device);
void atg_arccos(tensor *, tensor self);
void atg_arccos_(tensor *, tensor self);
void atg_arccos_out(tensor *, tensor out, tensor self);
@ -563,7 +563,7 @@ void atg_avg_pool3d(tensor *, tensor self, int64_t *kernel_size_data, int kernel
void atg_avg_pool3d_backward(tensor *, tensor grad_output, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override_v, uint8_t divisor_override_null);
void atg_avg_pool3d_backward_grad_input(tensor *, tensor grad_input, tensor grad_output, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override_v, uint8_t divisor_override_null);
void atg_avg_pool3d_out(tensor *, tensor out, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override_v, uint8_t divisor_override_null);
void atg_baddbmm(tensor *, tensor self, tensor batch1, tensor batch2);
void atg_baddbmm(tensor *, tensor self, tensor batch1, tensor batch2, scalar beta, scalar alpha);
void atg_baddbmm_(tensor *, tensor self, tensor batch1, tensor batch2);
void atg_baddbmm_out(tensor *, tensor out, tensor self, tensor batch1, tensor batch2);
void atg_bartlett_window(tensor *, int64_t window_length, int options_kind, int options_device);

View File

@ -3961,9 +3961,9 @@ func MustArangeStart(start *Scalar, end *Scalar, optionsKind gotch.DType, option
return retVal
}
func MustArangeStartStep(start *Scalar, end *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor) {
func MustArangeStartStep(start *Scalar, end *Scalar, step *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor) {
retVal, err := ArangeStartStep(start, end, optionsKind, optionsDevice)
retVal, err := ArangeStartStep(start, end, step, optionsKind, optionsDevice)
if err != nil { log.Fatal(err) }
return retVal
@ -4465,9 +4465,9 @@ func(ts *Tensor) MustAvgPool3dOut(out *Tensor, kernelSize []int64, stride []int6
return retVal
}
func(ts *Tensor) MustBaddbmm(batch1 *Tensor, batch2 *Tensor, del bool)(retVal *Tensor) {
func(ts *Tensor) MustBaddbmm(batch1 *Tensor, batch2 *Tensor, beta *Scalar, alpha *Scalar, del bool)(retVal *Tensor) {
retVal, err := ts.Baddbmm(batch1, batch2, del)
retVal, err := ts.Baddbmm(batch1, batch2, beta, alpha, del)
if err != nil { log.Fatal(err) }
return retVal

View File

@ -9327,10 +9327,10 @@ func ArangeStart(start *Scalar, end *Scalar, optionsKind gotch.DType, optionsDev
// func.returns = `fixed 1`:
// --------------------------
func ArangeStartStep(start *Scalar, end *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor, err error) {
func ArangeStartStep(start *Scalar, end *Scalar, step *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgArangeStartStep(ptr, start.cscalar, end.cscalar, optionsKind.CInt(), optionsDevice.CInt())
lib.AtgArangeStartStep(ptr, start.cscalar, end.cscalar, step.cscalar, optionsKind.CInt(), optionsDevice.CInt())
if err = TorchErr(); err != nil {
err = fmt.Errorf("ArangeStartStep() failed: %w", err)
return retVal, err
@ -10585,11 +10585,11 @@ var cdivisorOverrideVal int64 = 0
// func.returns = `fixed 1`:
// --------------------------
func(ts *Tensor) Baddbmm(batch1 *Tensor, batch2 *Tensor, del bool)(retVal *Tensor, err error) {
func(ts *Tensor) Baddbmm(batch1 *Tensor, batch2 *Tensor, beta *Scalar, alpha *Scalar, del bool)(retVal *Tensor, err error) {
if del { defer ts.MustDrop() }
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgBaddbmm(ptr, ts.ctensor, batch1.ctensor, batch2.ctensor)
lib.AtgBaddbmm(ptr, ts.ctensor, batch1.ctensor, batch2.ctensor, beta.cscalar, alpha.cscalar)
if err = TorchErr(); err != nil {
err = fmt.Errorf("Baddbmm() failed: %w", err)
return retVal, err
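Taken together, the regenerated wrappers now forward every argument. A hedged usage sketch of the two corrected entry points (assuming the usual gotch helpers FloatScalar, MustZeros, MustOnes and MustSize):

package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/ts"
)

func main() {
	// arange(start, end, step): the step scalar is now forwarded.
	r := ts.MustArangeStartStep(ts.FloatScalar(0), ts.FloatScalar(10), ts.FloatScalar(2), gotch.Float, gotch.CPU)
	fmt.Println(r.MustSize()) // [5] -> 0, 2, 4, 6, 8
	r.MustDrop()

	// baddbmm: beta and alpha are now explicit instead of silently
	// falling back to their defaults.
	self := ts.MustZeros([]int64{2, 3, 5}, gotch.Float, gotch.CPU)
	b1 := ts.MustOnes([]int64{2, 3, 4}, gotch.Float, gotch.CPU)
	b2 := ts.MustOnes([]int64{2, 4, 5}, gotch.Float, gotch.CPU)
	out := self.MustBaddbmm(b1, b2, ts.FloatScalar(1), ts.FloatScalar(0.5), true)
	fmt.Println(out.MustSize()) // [2 3 5]
	out.MustDrop()
	b1.MustDrop()
	b2.MustDrop()
}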

View File

@ -255,9 +255,9 @@ func rgb2Gray(x *ts.Tensor, outChanOpt ...int64) *ts.Tensor {
}
rgbTs := x.MustUnbind(-3, false)
r := &rgbTs[0]
g := &rgbTs[1]
b := &rgbTs[2]
r := rgbTs[0]
g := rgbTs[1]
b := rgbTs[2]
// This implementation closely follows the TF one:
// https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
@ -453,7 +453,7 @@ func hsv2RGB(x *ts.Tensor) *ts.Tensor {
a3 := ts.MustStack([]*ts.Tensor{p, p, t, v, v, q}, -3)
a4 := ts.MustStack([]*ts.Tensor{a1, a2, a3}, -4)
out := ts.MustEinsum("...ijk, ...xijk -> ...xjk", []*ts.Tensor{mask, a4})
out := ts.MustEinsum("...ijk, ...xijk -> ...xjk", []*ts.Tensor{mask, a4}, []int64{0, 1})
// Delete intermediate tensors
h.MustDrop()
@ -579,7 +579,7 @@ func crop(x *ts.Tensor, top, left, height, width int64) *ts.Tensor {
x2 := x1T.Idx(wNar)
x1T.MustDrop()
out := x2.MustT(true)
chans[i] = *out
chans[i] = out
}
cropTs := ts.MustStack(chans, 0)
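The vision fixes follow from the same regeneration: MustUnbind apparently hands back a slice of tensor pointers now, and MustEinsum takes an explicit contraction-path argument, as in the hsv2RGB call above. A hedged standalone sketch of the updated einsum call shape (assuming the generated MustRand helper):

package main

import (
	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/ts"
)

func main() {
	a := ts.MustRand([]int64{2, 3}, gotch.Float, gotch.CPU)
	b := ts.MustRand([]int64{3, 4}, gotch.Float, gotch.CPU)
	// The trailing []int64{0, 1} is the contraction path over the two
	// operands, mirroring the updated call in hsv2RGB.
	out := ts.MustEinsum("ij,jk->ik", []*ts.Tensor{a, b}, []int64{0, 1})
	out.MustDrop()
	a.MustDrop()
	b.MustDrop()
}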