fixed incorrect API generation

sugarme 2023-08-05 13:02:56 +10:00
parent bdf252d831
commit e9278816b2
7 changed files with 382 additions and 401 deletions


@@ -82,6 +82,14 @@ let no_tensor_options =
   ; "randint_like"
   ; "randn_like" ]

+(* By default, scalar arguments that have a default value are not available on
+   the Rust side; this is to preserve the Rust API's simplicity, assuming that
+   these scalar arguments are not often overridden.
+   Adding a function name [foo] to [with_optional_scalar_args] results in having
+   explicit scalar arguments even if a default is present. *)
+let with_optional_scalar_args = Set.of_list (module String) [ "arange"; "baddbmm" ]
+
 (*
  * let prefixed_functions =
  *   Set.of_list
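The effect of listing `arange` and `baddbmm` in `with_optional_scalar_args` is that their defaulted scalar arguments (`step`, `beta`, `alpha`) become explicit parameters in the generated bindings, as the regenerated FFI further down shows. A minimal Go sketch of the resulting call shape, using stand-in types (the real `Ctensor`/`Cscalar` live in gotch's C binding package):

package main

// Stand-ins for the cgo types used by the generated FFI (illustration only).
type Ctensor uintptr
type Cscalar uintptr

// Post-commit signature of the generated wrapper: beta and alpha are now
// explicit even though libtorch gives them default values.
func AtgBaddbmm(ptr *Ctensor, self, batch1, batch2 Ctensor, beta, alpha Cscalar) {
	// the real implementation forwards to C.atg_baddbmm
}

func main() {
	var out Ctensor
	var self, batch1, batch2 Ctensor // assumed to wrap live tensors
	var beta, alpha Cscalar          // defaulted scalars, now passed explicitly
	AtgBaddbmm(&out, self, batch1, batch2, beta, alpha)
}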
@@ -133,24 +141,26 @@ module Func = struct
     | Double
     | DoubleOption
     | Tensor
-    | TensorOption
-        (* Tensor.t option *)
+    | TensorOption (* Tensor.t option *)
     | IntList
     | IntListOption
     | DoubleList
     | TensorOptList
     | TensorList
-    | TensorOptions
-        (* Tensor kind and device *)
+    | TensorOptions (* Tensor kind and device *)
     | Scalar
     | ScalarType
+    | ScalarTypeOption
     | Device
     | String
     | Layout
     | LayoutOption

   type arg =
-    {arg_name: string; arg_type: arg_type; default_value: string option}
+    { arg_name: string
+    ; arg_type: arg_type
+    ; default_value: string option
+    }

   (* `Func` type *)
   type t =
@@ -160,7 +170,8 @@ module Func = struct
     ; args: arg list (* ; returns: [`fixed of int | `dynamic] *)
     ; returns: [`fixed of int | `dynamic | `bool | `int64_t | `double | `nothing]
     ; (* number of tensors that are returned *)
-      kind: [`function_ | `method_] }
+      kind: [`function_ | `method_]
+    }

   let arg_type_of_string str ~is_nullable =
     match String.lowercase str with
@@ -175,108 +186,111 @@ module Func = struct
     | "const at::itensorlistref &" | "at::tensorlist" -> Some TensorList
     | "at::device" -> Some Device
     | "const at::scalar &" | "at::scalar" -> Some Scalar
-    | "at::scalartype" -> Some ScalarType
+    | "at::scalartype" -> if is_nullable then Some ScalarTypeOption else Some ScalarType
     | "c10::string_view" -> Some String
     | "at::layout" -> Some (if is_nullable then LayoutOption else Layout)
     | _ -> None

   let c_typed_args_list t =
     List.map t.args ~f:(fun { arg_name; arg_type; _ } ->
       match arg_type with
       | IntList | IntListOption ->
         Printf.sprintf "int64_t *%s_data, int %s_len" arg_name arg_name
       | DoubleList -> Printf.sprintf "double *%s_data, int %s_len" arg_name arg_name
       | TensorOptList | TensorList ->
         Printf.sprintf "tensor *%s_data, int %s_len" arg_name arg_name
       | TensorOptions -> Printf.sprintf "int %s_kind, int %s_device" arg_name arg_name
       | String -> Printf.sprintf "char* %s_ptr, int %s_len" arg_name arg_name
       | Int64Option -> Printf.sprintf "int64_t %s_v, uint8_t %s_null" arg_name arg_name
       | DoubleOption -> Printf.sprintf "double %s_v, uint8_t %s_null" arg_name arg_name
       | otherwise ->
         let simple_type_cstring =
           match otherwise with
           | Bool -> "int"
           | Int64 -> "int64_t"
           | Double -> "double"
           | Tensor -> "tensor"
           | TensorOption -> "tensor"
           | ScalarType -> "int"
+          | ScalarTypeOption -> "int"
           | Device -> "int"
           | Scalar -> "scalar"
           | Layout | LayoutOption -> "int8_t"
           | Int64Option
           | DoubleOption
           | String
           | IntList
           | IntListOption
           | DoubleList
           | TensorOptList
           | TensorList
           | TensorOptions -> assert false
         in
         Printf.sprintf "%s %s" simple_type_cstring arg_name)
     |> String.concat ~sep:", "

   let c_args_list args =
     List.map args ~f:(fun { arg_name; arg_type; _ } ->
       match arg_type with
       | Scalar | Tensor -> "*" ^ arg_name
       | Layout -> Printf.sprintf "static_cast<at::Layout>(%s)" arg_name
       | LayoutOption ->
         Printf.sprintf
           "(%s == -1 ? c10::nullopt : \
            c10::optional<at::Layout>(static_cast<at::Layout>(%s)))"
           arg_name
           arg_name
       | TensorOption -> Printf.sprintf "(%s ? *%s : torch::Tensor())" arg_name arg_name
       | Bool -> "(bool)" ^ arg_name
       | IntList -> Printf.sprintf "torch::IntArrayRef(%s_data, %s_len)" arg_name arg_name
       | IntListOption ->
         Printf.sprintf
           "%s_data == nullptr ? c10::nullopt : \
            c10::optional<torch::IntArrayRef>(torch::IntArrayRef(%s_data, %s_len))"
           arg_name
           arg_name
           arg_name
       | DoubleList ->
         Printf.sprintf "at::ArrayRef<double>(%s_data, %s_len)" arg_name arg_name
       | String -> Printf.sprintf "std::string(%s_ptr, %s_len)" arg_name arg_name
       | TensorOptList ->
         Printf.sprintf "of_carray_tensor_opt(%s_data, %s_len)" arg_name arg_name
       | TensorList -> Printf.sprintf "of_carray_tensor(%s_data, %s_len)" arg_name arg_name
       | TensorOptions ->
         Printf.sprintf
           "at::device(device_of_int(%s_device)).dtype(at::ScalarType(%s_kind))"
           arg_name
           arg_name
       | Int64Option ->
         Printf.sprintf
           "%s_null ? c10::nullopt : c10::optional<int64_t>(%s_v)"
           arg_name
           arg_name
       | DoubleOption ->
         Printf.sprintf
           "%s_null ? c10::nullopt : c10::optional<double>(%s_v)"
           arg_name
           arg_name
       | ScalarType -> Printf.sprintf "at::ScalarType(%s)" arg_name
+      | ScalarTypeOption ->
+        Printf.sprintf
+          "%s < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(%s))"
+          arg_name
+          arg_name
       | Device -> Printf.sprintf "device_of_int(%s)" arg_name
       | _ -> arg_name)
     |> String.concat ~sep:", "

   let c_call t =
     match t.kind with
     | `function_ -> Printf.sprintf "torch::%s(%s)" t.name (c_args_list t.args)
-    | `method_ -> (
-      match t.args with
-      | head :: tail ->
-        Printf.sprintf "%s->%s(%s)" head.arg_name t.name (c_args_list tail)
-      | [] ->
-        Printf.failwithf "Method calls should have at least one argument %s"
-          t.name () )
+    | `method_ ->
+      (match t.args with
+       | head :: tail ->
+         Printf.sprintf "%s->%s(%s)" head.arg_name t.name (c_args_list tail)
+       | [] ->
+         Printf.failwithf "Method calls should have at least one argument %s" t.name ())

   (*
    let replace_map =
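The new `ScalarTypeOption` case keeps the C interface unchanged (a plain `int`) and encodes "no dtype" with a negative sentinel that `c_args_list` turns into `c10::nullopt`. A small self-contained Go sketch of the same convention, mirroring the C++ strings the generator emits (illustrative only):

package main

import "fmt"

// scalarTypeArg mirrors the C++ expression emitted for ScalarTypeOption:
// a negative dtype means None (c10::nullopt), anything else is a concrete
// at::ScalarType index.
func scalarTypeArg(dtype int) string {
	if dtype < 0 {
		return "c10::nullopt" // use the operator's default dtype
	}
	return fmt.Sprintf("at::ScalarType(%d)", dtype)
}

func main() {
	fmt.Println(scalarTypeArg(-1)) // c10::nullopt
	fmt.Println(scalarTypeArg(6))  // at::ScalarType(6)
}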
@@ -289,6 +303,15 @@ module Func = struct
      ; ("to_device", "to_device_") ]
   *)

+  let operator_name t =
+    match String.lowercase t.operator_name with
+    | "scatter_reduce" ->
+      (* scatter_reduce is both an operator name and also obtained from the
+         scatter operator when using the reduce overload. *)
+      "_scatter_reduce"
+    | "scatter_reduce_" -> "_scatter_reduce_"
+    | other -> other
+
   let is_method t =
     List.exists t.args ~f:(fun arg ->
       match arg.arg_name with "self" -> true | _ -> false )
@@ -321,18 +344,16 @@ module Func = struct
         let single_param = Printf.sprintf "%s %s" an in
         match arg.arg_type with
         | Bool -> single_param "int32"
-        | Layout -> single_param "int8"
-        | LayoutOption -> single_param "int8"
+        | Layout | LayoutOption -> single_param "int8"
         | Int64 -> single_param "int64"
         | Double -> single_param "float64"
         | Tensor -> single_param "Ctensor"
         | TensorOption -> single_param "Ctensor"
         | Scalar -> single_param "Cscalar"
-        | ScalarType -> single_param "int32"
+        | ScalarType | ScalarTypeOption -> single_param "int32"
         | Device -> single_param "int32"
         | String -> single_param "string"
-        | IntList -> Printf.sprintf "%sData []int64, %sLen int" an an
-        | IntListOption -> Printf.sprintf "%sData []int64, %sLen int" an an
+        | IntList | IntListOption -> Printf.sprintf "%sData []int64, %sLen int" an an
         | DoubleList -> Printf.sprintf "%sData []float64, %sLen int" an an
         | TensorOptList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an
         | TensorList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an
@@ -353,14 +374,12 @@ module Func = struct
         | Double -> Printf.sprintf "c%s" an
         | Tensor -> Printf.sprintf "%s" an
         | TensorOption -> Printf.sprintf "%s" an
-        | Layout -> Printf.sprintf "c%s" an
-        | LayoutOption -> Printf.sprintf "c%s" an
+        | Layout | LayoutOption -> Printf.sprintf "c%s" an
         | Scalar -> single_param ""
-        | ScalarType -> Printf.sprintf "c%s" an
+        | ScalarType | ScalarTypeOption -> Printf.sprintf "c%s" an
         | Device -> Printf.sprintf "c%s" an
         | String -> Printf.sprintf "c%s, c%sLen" an an
-        | IntList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
-        | IntListOption -> Printf.sprintf "c%sDataPtr, c%sLen" an an
+        | IntList | IntListOption -> Printf.sprintf "c%sDataPtr, c%sLen" an an
         | DoubleList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
         | TensorOptList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
         | TensorList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
@@ -383,12 +402,10 @@ module Func = struct
           Printf.sprintf "\nc%s := *(*C.double)(unsafe.Pointer(&%s))" an an
         | Tensor -> ""
         | TensorOption -> ""
-        | Layout ->
-          Printf.sprintf "\nc%s := *(*C.int8_t)(unsafe.Pointer(&%s))" an an
-        | LayoutOption ->
+        | Layout | LayoutOption ->
           Printf.sprintf "\nc%s := *(*C.int8_t)(unsafe.Pointer(&%s))" an an
         | Scalar -> ""
-        | ScalarType ->
+        | ScalarType | ScalarTypeOption ->
           Printf.sprintf "\nc%s := *(*C.int)(unsafe.Pointer(&%s))" an an
         | Device ->
           Printf.sprintf "\nc%s := *(*C.int)(unsafe.Pointer(&%s))" an an
@@ -399,13 +416,7 @@ module Func = struct
             %sLen := len(%s)\n\
             c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))"
             an an an an an an
-        | IntList ->
-          Printf.sprintf
-            "\n\
-             c%sDataPtr := (*C.int64_t)(unsafe.Pointer(&%sData[0]))\n\
-             c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))"
-            an an an an
-        | IntListOption ->
+        | IntList | IntListOption ->
           Printf.sprintf
             "\n\
              c%sDataPtr := (*C.int64_t)(unsafe.Pointer(&%sData[0]))\n\
@@ -494,14 +505,12 @@ module Func = struct
       let go_arg_type =
         match arg.arg_type with
         | Bool -> "bool"
-        | Layout -> "Layout"
-        | LayoutOption -> "Layout"
+        | Layout | LayoutOption -> "Layout"
         | Int64 -> "int64"
         | Double -> "float64"
         | Tensor -> "*Tensor"
         | TensorOption -> "*Tensor"
-        | IntList -> "[]int64"
-        | IntListOption -> "[]int64"
+        | IntList | IntListOption -> "[]int64"
         | DoubleList -> "[]float64"
         | TensorOptList -> "[]*Tensor"
         | TensorList -> "[]*Tensor"
@@ -510,7 +519,7 @@ module Func = struct
         (* E.g. `type KindDevice struct{}` *)
         | TensorOptions -> "gotch.KindDevice"
         | Scalar -> "*Scalar"
-        | ScalarType -> "gotch.DType"
+        | ScalarType | ScalarTypeOption -> "gotch.DType"
         | Int64Option -> "[]int64"
         | DoubleOption -> "[]float64"
         | Device -> "gotch.Device"
@@ -603,7 +612,7 @@ module Func = struct
           else Printf.sprintf "%s.ctensor" name
         | Scalar -> Printf.sprintf "%s.cscalar" name
         | Bool -> Printf.sprintf "c%s" name
-        | ScalarType -> Printf.sprintf "%s.CInt()" name
+        | ScalarType | ScalarTypeOption -> Printf.sprintf "%s.CInt()" name
         | Device -> Printf.sprintf "%s.CInt()" name
         | TensorOptions ->
           Printf.sprintf "%sKind.CInt(), %sDevice.CInt()" name name
@@ -633,11 +642,10 @@ module Func = struct
         | Tensor -> ""
         | TensorOption -> ""
         | Scalar -> ""
-        | ScalarType -> ""
+        | ScalarType | ScalarTypeOption -> ""
         | Device -> ""
         | String -> ""
-        | IntList -> Printf.sprintf "%sLen := len(%s)\n" an an
-        | IntListOption -> Printf.sprintf "%sLen := len(%s)\n" an an
+        | IntList | IntListOption -> Printf.sprintf "%sLen := len(%s)\n" an an
         | DoubleList -> Printf.sprintf "%sLen := len(%s)\n" an an
         | Int64Option ->
           Printf.sprintf
@@ -667,8 +675,7 @@ module Func = struct
             "var c%s []lib.Ctensor\n\
             \ for _, t := range %s {c%s = append(c%s, t.ctensor)}\n"
             an an an an
-        | Layout -> ""
-        | LayoutOption -> ""
+        | Layout | LayoutOption -> ""
         | TensorOptions -> "" )
     |> String.concat ~sep:""
 end
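Taken together, the `write_ffi`/`write_wrapper` hunks above reduce to one rule: each `*Option` arg type reuses its base type's Go mapping. A runnable summary of the Go-side type mapping, transcribed from `go_arg_type` above (a sketch; the authoritative table is the OCaml):

package main

import "fmt"

// goArgType records the Go type emitted for each generator arg type;
// ScalarTypeOption now shares ScalarType's mapping, IntListOption shares
// IntList's, and LayoutOption shares Layout's.
var goArgType = map[string]string{
	"Bool": "bool", "Layout": "Layout", "LayoutOption": "Layout",
	"Int64": "int64", "Double": "float64",
	"Tensor": "*Tensor", "TensorOption": "*Tensor",
	"IntList": "[]int64", "IntListOption": "[]int64",
	"DoubleList": "[]float64",
	"TensorOptList": "[]*Tensor", "TensorList": "[]*Tensor",
	"TensorOptions": "gotch.KindDevice", "Scalar": "*Scalar",
	"ScalarType": "gotch.DType", "ScalarTypeOption": "gotch.DType",
	"Int64Option": "[]int64", "DoubleOption": "[]float64",
	"Device": "gotch.Device",
}

func main() {
	fmt.Println(goArgType["ScalarTypeOption"]) // gotch.DType
}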
@@ -679,117 +686,103 @@ let read_yaml filename =
   let funcs =
     (* Split the file to avoid Yaml.of_string_exn segfaulting. *)
     In_channel.with_file filename ~f:In_channel.input_lines
     |> List.group ~break:(fun _ l -> String.length l > 0 && Char.( = ) l.[0] '-')
     |> List.concat_map ~f:(fun lines ->
          Yaml.of_string_exn (String.concat lines ~sep:"\n") |> extract_list)
   in
   printf "Read %s, got %d functions.\n%!" filename (List.length funcs);
   List.filter_map funcs ~f:(fun yaml ->
     let map = extract_map yaml in
     let name = Map.find_exn map "name" |> extract_string in
     let operator_name = Map.find_exn map "operator_name" |> extract_string in
     let overload_name = Map.find_exn map "overload_name" |> extract_string in
     let deprecated = Map.find_exn map "deprecated" |> extract_bool in
     let method_of =
       Map.find_exn map "method_of" |> extract_list |> List.map ~f:extract_string
     in
     let arguments = Map.find_exn map "arguments" |> extract_list in
     let returns =
       let is_tensor returns =
         let returns = extract_map returns in
         let return_type = Map.find_exn returns "dynamic_type" |> extract_string in
         String.( = ) return_type "at::Tensor"
       in
       let returns = Map.find_exn map "returns" |> extract_list in
-      if List.for_all returns ~f:is_tensor then
-        Some (`fixed (List.length returns))
+      if List.is_empty returns
+      then Some `nothing
+      else if List.for_all returns ~f:is_tensor
+      then Some (`fixed (List.length returns))
       else (
         match returns with
         | [ returns ] ->
           let return_type =
             Map.find_exn (extract_map returns) "dynamic_type" |> extract_string
           in
           (match return_type with
            | "bool" -> Some `bool
            | "int64_t" -> Some `int64_t
            | "double" -> Some `double
            | "at::TensorList" | "dynamic_type: const c10::List<c10::optional<Tensor>> &"
              -> Some `dynamic
            | _ -> None)
         | [] | _ :: _ :: _ -> None)
     in
     let kind =
       if List.exists method_of ~f:(String.( = ) "namespace")
       then Some `function_
       else if List.exists method_of ~f:(String.( = ) "Tensor")
       then Some `method_
       else None
     in
     if (not deprecated)
        && (not
              (List.exists excluded_prefixes ~f:(fun prefix ->
                 String.is_prefix name ~prefix)))
        && (not
              (List.exists excluded_suffixes ~f:(fun suffix ->
                 String.is_suffix name ~suffix)))
        && not (Set.mem excluded_functions name)
     then
       Option.both returns kind
       |> Option.bind ~f:(fun (returns, kind) ->
            try
-             let args =
+             let args ~with_optional_scalar_args =
                List.filter_map arguments ~f:(fun arg ->
                  let arg = extract_map arg in
                  let arg_name = Map.find_exn arg "name" |> extract_string in
                  let arg_type = Map.find_exn arg "dynamic_type" |> extract_string in
                  let is_nullable =
                    Map.find arg "is_nullable"
                    |> Option.value_map ~default:false ~f:extract_bool
                  in
                  let default_value =
                    Map.find arg "default" |> Option.map ~f:extract_string
                  in
                  match Func.arg_type_of_string arg_type ~is_nullable with
-                 | Some Scalar when Option.is_some default_value && not is_nullable ->
-                   None
+                 | Some Scalar when Option.is_some default_value && not is_nullable ->
+                   if with_optional_scalar_args
+                   then Some { Func.arg_name; arg_type = Scalar; default_value }
+                   else None
                  | Some TensorOptions
                    when Option.is_some default_value && Set.mem no_tensor_options name ->
                    None
                  | Some arg_type ->
                    let arg_name =
                      match arg_name, arg_type with
                      | "self", Scalar -> "self_scalar"
                      | _, _ -> arg_name
                    in
                    Some { Func.arg_name; arg_type; default_value }
                  | None ->
                    if Option.is_some default_value then None else raise Not_a_simple_arg)
              in
-             Some
-               { Func.name
-               ; operator_name
-               ; overload_name
-               ; args
-               ; returns
-               ; kind }
+             let args =
+               args ~with_optional_scalar_args:(Set.mem with_optional_scalar_args name)
+             in
+             Some [ { Func.name; operator_name; overload_name; args; returns; kind } ]
            with
            | Not_a_simple_arg -> None)
     else None)

 let p out_channel s =
   Printf.ksprintf
@@ -803,72 +796,71 @@ let print_inline out_channel s =
 let write_cpp funcs filename =
   Out_channel.with_file (filename ^ ".cpp.h") ~f:(fun out_cpp ->
     Out_channel.with_file (filename ^ ".h") ~f:(fun out_h ->
       let pc s = p out_cpp s in
       let ph s = p out_h s in
       pc "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
       pc "";
       ph "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
       ph "";
       Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
         let c_typed_args_list = Func.c_typed_args_list func in
         match func.returns with
         | `nothing ->
           pc "void atg_%s(%s) {" exported_name c_typed_args_list;
           pc "  PROTECT(";
           pc "    %s;" (Func.c_call func);
           pc "  )";
           pc "}";
           pc "";
           ph "void atg_%s(%s);" exported_name c_typed_args_list
         | `fixed ntensors ->
           pc "void atg_%s(tensor *out__, %s) {" exported_name c_typed_args_list;
           pc "  PROTECT(";
           pc "    auto outputs__ = %s;" (Func.c_call func);
           if ntensors = 1
           then pc "    out__[0] = new torch::Tensor(outputs__);"
           else
             for i = 0 to ntensors - 1 do
               pc "    out__[%d] = new torch::Tensor(std::get<%d>(outputs__));" i i
             done;
           pc "  )";
           pc "}";
           pc "";
           ph "void atg_%s(tensor *, %s);" exported_name c_typed_args_list
         | `dynamic ->
           pc "tensor *atg_%s(%s) {" exported_name c_typed_args_list;
           pc "  PROTECT(";
           pc "    auto outputs__ = %s;" (Func.c_call func);
           (* the returned type is a C++ vector of tensors *)
           pc "    int sz = outputs__.size();";
           pc
             "    torch::Tensor **out__ = (torch::Tensor**)malloc((sz + 1) * \
              sizeof(torch::Tensor*));";
           pc "    for (int i = 0; i < sz; ++i)";
           pc "      out__[i] = new torch::Tensor(outputs__[i]);";
           pc "    out__[sz] = nullptr;";
           pc "    return out__;";
           pc "  )";
           pc "  return nullptr;";
           pc "}";
           pc "";
           ph "tensor *atg_%s(%s);" exported_name c_typed_args_list
         | (`bool | `int64_t | `double) as returns ->
           let c_type =
             match returns with
             | `bool -> "int"
             | `int64_t -> "int64_t"
             | `double -> "double"
           in
           pc "%s atg_%s(%s) {" c_type exported_name c_typed_args_list;
           pc "  PROTECT(";
           pc "    return %s;" (Func.c_call func);
           pc "  )";
           pc "  return 0;";
           pc "}";
           pc "";
           ph "%s atg_%s(%s);" c_type exported_name c_typed_args_list)))

 let write_wrapper funcs filename =
   Out_channel.with_file filename ~f:(fun out_ml ->
@@ -1402,54 +1394,43 @@ let methods =
   ; c "to" [ ca "self" Tensor; ca "device" Device ]
   ]

 let run ~yaml_filename ~cpp_filename ~ffi_filename ~must_wrapper_filename
   ~wrapper_filename =
-  let funcs = read_yaml yaml_filename in
+  let funcs = read_yaml yaml_filename |> List.concat in
   let funcs = methods @ funcs in
   printf "Generating code for %d functions.\n%!" (List.length funcs);
   (* Generate some unique names for overloaded functions. *)
   let funcs =
-    List.map funcs ~f:(fun func -> (String.lowercase func.operator_name, func))
+    List.map funcs ~f:(fun func -> Func.operator_name func, func)
     |> Map.of_alist_multi (module String)
     |> Map.to_alist
     |> List.concat_map ~f:(fun (name, funcs) ->
          match funcs with
          | [] -> assert false
          | [ func ] -> [ name, func ]
          | funcs ->
            let has_empty_overload =
              List.exists funcs ~f:(fun (func : Func.t) ->
                String.is_empty func.overload_name)
            in
            List.sort funcs ~compare:(fun (f1 : Func.t) (f2 : Func.t) ->
              match Int.compare (String.length f1.name) (String.length f2.name) with
              | 0 -> Int.compare (List.length f1.args) (List.length f2.args)
              | cmp -> cmp)
            |> List.mapi ~f:(fun index (func : Func.t) ->
-                 let operator_name = String.lowercase func.operator_name in
+                 let operator_name = Func.operator_name func in
                  let overload_name = String.lowercase func.overload_name in
                  let name =
                    if String.is_empty overload_name || (index = 0 && not has_empty_overload)
                    then operator_name
                    else if String.is_suffix operator_name ~suffix:"_"
                    then operator_name ^ overload_name ^ "_"
                    else operator_name ^ "_" ^ overload_name
                  in
                  name, func))
     |> Map.of_alist_exn (module String)
   in
   write_cpp funcs cpp_filename;
   write_ffi funcs ffi_filename;
   write_must_wrapper funcs must_wrapper_filename;
   write_wrapper funcs wrapper_filename
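The disambiguation rule in `run` appends the lowercased overload name to the operator name (with a trailing underscore kept last for in-place ops). A rough Go rendering, for illustration only; the authoritative logic is the OCaml above:

package main

import (
	"fmt"
	"strings"
)

// exportedName sketches how unique names are derived for overloaded operators.
func exportedName(operatorName, overloadName string, index int, hasEmptyOverload bool) string {
	overload := strings.ToLower(overloadName)
	switch {
	case overload == "" || (index == 0 && !hasEmptyOverload):
		return operatorName
	case strings.HasSuffix(operatorName, "_"):
		return operatorName + overload + "_"
	default:
		return operatorName + "_" + overload
	}
}

func main() {
	// e.g. the arange overloads visible in the generated FFI below:
	fmt.Println(exportedName("arange", "", 0, true))           // arange
	fmt.Println(exportedName("arange", "start", 1, true))      // arange_start
	fmt.Println(exportedName("arange", "start_step", 2, true)) // arange_start_step
}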


@@ -2666,10 +2666,10 @@ coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
 coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
 C.atg_arange_start(ptr, start , end , coptionsKind, coptionsDevice)
 }
-func AtgArangeStartStep(ptr *Ctensor, start Cscalar, end Cscalar, optionsKind int32, optionsDevice int32){
+func AtgArangeStartStep(ptr *Ctensor, start Cscalar, end Cscalar, step Cscalar, optionsKind int32, optionsDevice int32){
 coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
 coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
-C.atg_arange_start_step(ptr, start , end , coptionsKind, coptionsDevice)
+C.atg_arange_start_step(ptr, start , end , step , coptionsKind, coptionsDevice)
 }
 func AtgArccos(ptr *Ctensor, self Ctensor){
 C.atg_arccos(ptr, self)
@@ -3004,8 +3004,8 @@ cdivisorOverrideVal := *(*C.int64_t)(unsafe.Pointer(&divisorOverrideVal))
 cdivisorOverrideNull := *(*C.uint8_t)(unsafe.Pointer(&divisorOverrideNull))
 C.atg_avg_pool3d_out(ptr, out, self, ckernelSizeDataPtr, ckernelSizeLen, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cceilMode, ccountIncludePad, cdivisorOverrideVal, cdivisorOverrideNull)
 }
-func AtgBaddbmm(ptr *Ctensor, self Ctensor, batch1 Ctensor, batch2 Ctensor){
-C.atg_baddbmm(ptr, self, batch1, batch2)
+func AtgBaddbmm(ptr *Ctensor, self Ctensor, batch1 Ctensor, batch2 Ctensor, beta Cscalar, alpha Cscalar){
+C.atg_baddbmm(ptr, self, batch1, batch2, beta , alpha )
 }
 func AtgBaddbmm_(ptr *Ctensor, self Ctensor, batch1 Ctensor, batch2 Ctensor){
 C.atg_baddbmm_(ptr, self, batch1, batch2)
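Note the regenerated wrapper no longer drops `step`: before this commit `AtgArangeStartStep` forwarded only start/end, so the step argument never reached `torch::arange`. A self-contained sketch of the corrected call shape (stand-in types; how the kind/device ints are produced is assumed, e.g. via the usual `CInt()` helpers):

package main

// Stand-ins for the cgo types (illustration only).
type Ctensor uintptr
type Cscalar uintptr

// Fixed signature: step is part of the wrapper and forwarded to
// C.atg_arange_start_step in the real implementation.
func AtgArangeStartStep(ptr *Ctensor, start, end, step Cscalar, optionsKind, optionsDevice int32) {}

func main() {
	var out Ctensor
	var start, end, step Cscalar // assumed to wrap live scalars
	var kind, device int32       // e.g. from gotch's DType/Device CInt() helpers
	AtgArangeStartStep(&out, start, end, step, kind, device)
}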


@@ -2068,7 +2068,7 @@ void atg__slow_conv2d_backward(tensor *out__, tensor grad_input, tensor grad_wei
 void atg__sobol_engine_draw(tensor *out__, tensor quasi, int64_t n, tensor sobolstate, int64_t dimension, int64_t num_generated, int dtype) {
   PROTECT(
-    auto outputs__ = torch::_sobol_engine_draw(*quasi, n, *sobolstate, dimension, num_generated, at::ScalarType(dtype));
+    auto outputs__ = torch::_sobol_engine_draw(*quasi, n, *sobolstate, dimension, num_generated, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(std::get<0>(outputs__));
     out__[1] = new torch::Tensor(std::get<1>(outputs__));
   )
@@ -2223,28 +2223,28 @@ void atg__sparse_csc_tensor_unsafe(tensor *out__, tensor ccol_indices, tensor ro
 void atg__sparse_csr_prod(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::_sparse_csr_prod(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
+    auto outputs__ = torch::_sparse_csr_prod(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg__sparse_csr_prod_dim_dtype_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::_sparse_csr_prod_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
+    auto outputs__ = torch::_sparse_csr_prod_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg__sparse_csr_sum(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::_sparse_csr_sum(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
+    auto outputs__ = torch::_sparse_csr_sum(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg__sparse_csr_sum_dim_dtype_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::_sparse_csr_sum_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
+    auto outputs__ = torch::_sparse_csr_sum_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }
@@ -2279,7 +2279,7 @@ void atg__sparse_log_softmax_backward_data_out(tensor *out__, tensor out, tensor
 void atg__sparse_log_softmax_int(tensor *out__, tensor self, int64_t dim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::_sparse_log_softmax(*self, dim, at::ScalarType(dtype));
+    auto outputs__ = torch::_sparse_log_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }
@@ -2336,7 +2336,7 @@ void atg__sparse_softmax_backward_data_out(tensor *out__, tensor out, tensor gra
 void atg__sparse_softmax_int(tensor *out__, tensor self, int64_t dim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::_sparse_softmax(*self, dim, at::ScalarType(dtype));
+    auto outputs__ = torch::_sparse_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }
@@ -2629,14 +2629,14 @@ tensor *atg__to_cpu(tensor *tensors_data, int tensors_len) {
 void atg__to_dense(tensor *out__, tensor self, int dtype) {
   PROTECT(
-    auto outputs__ = self->_to_dense(at::ScalarType(dtype));
+    auto outputs__ = self->_to_dense(dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg__to_dense_out(tensor *out__, tensor out, tensor self, int dtype) {
   PROTECT(
-    auto outputs__ = torch::_to_dense_out(*out, *self, at::ScalarType(dtype));
+    auto outputs__ = torch::_to_dense_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }
@@ -3640,9 +3640,9 @@ void atg_arange_start(tensor *out__, scalar start, scalar end, int options_kind,
   )
 }

-void atg_arange_start_step(tensor *out__, scalar start, scalar end, int options_kind, int options_device) {
+void atg_arange_start_step(tensor *out__, scalar start, scalar end, scalar step, int options_kind, int options_device) {
   PROTECT(
-    auto outputs__ = torch::arange(*start, *end, at::device(device_of_int(options_device)).dtype(at::ScalarType(options_kind)));
+    auto outputs__ = torch::arange(*start, *end, *step, at::device(device_of_int(options_device)).dtype(at::ScalarType(options_kind)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }
@@ -4120,9 +4120,9 @@ void atg_avg_pool3d_out(tensor *out__, tensor out, tensor self, int64_t *kernel_
   )
 }

-void atg_baddbmm(tensor *out__, tensor self, tensor batch1, tensor batch2) {
+void atg_baddbmm(tensor *out__, tensor self, tensor batch1, tensor batch2, scalar beta, scalar alpha) {
   PROTECT(
-    auto outputs__ = torch::baddbmm(*self, *batch1, *batch2);
+    auto outputs__ = torch::baddbmm(*self, *batch1, *batch2, *beta, *alpha);
     out__[0] = new torch::Tensor(outputs__);
   )
 }
@@ -5910,14 +5910,14 @@ void atg_cummin_out(tensor *out__, tensor values, tensor indices, tensor self, i
 void atg_cumprod(tensor *out__, tensor self, int64_t dim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::cumprod(*self, dim, at::ScalarType(dtype));
+    auto outputs__ = torch::cumprod(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg_cumprod_(tensor *out__, tensor self, int64_t dim, int dtype) {
   PROTECT(
-    auto outputs__ = self->cumprod_(dim, at::ScalarType(dtype));
+    auto outputs__ = self->cumprod_(dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }
@@ -5931,28 +5931,28 @@ void atg_cumprod_backward(tensor *out__, tensor grad, tensor input, int64_t dim,
 void atg_cumprod_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::cumprod_out(*out, *self, dim, at::ScalarType(dtype));
+    auto outputs__ = torch::cumprod_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg_cumsum(tensor *out__, tensor self, int64_t dim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::cumsum(*self, dim, at::ScalarType(dtype));
+    auto outputs__ = torch::cumsum(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg_cumsum_(tensor *out__, tensor self, int64_t dim, int dtype) {
   PROTECT(
-    auto outputs__ = self->cumsum_(dim, at::ScalarType(dtype));
+    auto outputs__ = self->cumsum_(dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg_cumsum_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::cumsum_out(*out, *self, dim, at::ScalarType(dtype));
+    auto outputs__ = torch::cumsum_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }
@@ -9912,28 +9912,28 @@ void atg_linalg_multi_dot_out(tensor *out__, tensor out, tensor *tensors_data, i
 void atg_linalg_norm(tensor *out__, tensor self, scalar ord, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::linalg_norm(*self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
+    auto outputs__ = torch::linalg_norm(*self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg_linalg_norm_ord_str(tensor *out__, tensor self, char* ord_ptr, int ord_len, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::linalg_norm(*self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
+    auto outputs__ = torch::linalg_norm(*self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg_linalg_norm_ord_str_out(tensor *out__, tensor out, tensor self, char* ord_ptr, int ord_len, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::linalg_norm_out(*out, *self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
+    auto outputs__ = torch::linalg_norm_out(*out, *self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg_linalg_norm_out(tensor *out__, tensor out, tensor self, scalar ord, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::linalg_norm_out(*out, *self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
+    auto outputs__ = torch::linalg_norm_out(*out, *self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }
@@ -10314,14 +10314,14 @@ void atg_log_sigmoid_out(tensor *out__, tensor out, tensor self) {
 void atg_log_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::log_softmax(*self, dim, at::ScalarType(dtype));
+    auto outputs__ = torch::log_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg_log_softmax_int_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::log_softmax_out(*out, *self, dim, at::ScalarType(dtype));
+    auto outputs__ = torch::log_softmax_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }
@@ -10953,21 +10953,21 @@ void atg_maximum_out(tensor *out__, tensor out, tensor self, tensor other) {
 void atg_mean(tensor *out__, tensor self, int dtype) {
   PROTECT(
-    auto outputs__ = torch::mean(*self, at::ScalarType(dtype));
+    auto outputs__ = torch::mean(*self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg_mean_dim(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::mean(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
+    auto outputs__ = torch::mean(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg_mean_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::mean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
+    auto outputs__ = torch::mean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }
@@ -11742,14 +11742,14 @@ void atg_nan_to_num_out(tensor *out__, tensor out, tensor self, double nan_v, ui
 void atg_nanmean(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::nanmean(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
+    auto outputs__ = torch::nanmean(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg_nanmean_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::nanmean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
+    auto outputs__ = torch::nanmean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }
@@ -11814,14 +11814,14 @@ void atg_nanquantile_scalar_out(tensor *out__, tensor out, tensor self, double q
 void atg_nansum(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::nansum(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
+    auto outputs__ = torch::nansum(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg_nansum_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::nansum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
+    auto outputs__ = torch::nansum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }
@@ -11961,14 +11961,14 @@ void atg_native_norm_out(tensor *out__, tensor out, tensor self) {
 void atg_native_norm_scalaropt_dim_dtype(tensor *out__, tensor self, scalar p, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::native_norm(*self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
+    auto outputs__ = torch::native_norm(*self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg_native_norm_scalaropt_dim_dtype_out(tensor *out__, tensor out, tensor self, scalar p, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::native_norm_out(*out, *self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
+    auto outputs__ = torch::native_norm_out(*out, *self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }
@@ -12702,28 +12702,28 @@ void atg_prelu(tensor *out__, tensor self, tensor weight) {
 void atg_prod(tensor *out__, tensor self, int dtype) {
   PROTECT(
-    auto outputs__ = torch::prod(*self, at::ScalarType(dtype));
+    auto outputs__ = torch::prod(*self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg_prod_dim_int(tensor *out__, tensor self, int64_t dim, int keepdim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::prod(*self, dim, (bool)keepdim, at::ScalarType(dtype));
+    auto outputs__ = torch::prod(*self, dim, (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg_prod_int_out(tensor *out__, tensor out, tensor self, int64_t dim, int keepdim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::prod_out(*out, *self, dim, (bool)keepdim, at::ScalarType(dtype));
+    auto outputs__ = torch::prod_out(*out, *self, dim, (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg_prod_out(tensor *out__, tensor out, tensor self, int dtype) {
   PROTECT(
-    auto outputs__ = torch::prod_out(*out, *self, at::ScalarType(dtype));
+    auto outputs__ = torch::prod_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }
@@ -14593,14 +14593,14 @@ void atg_soft_margin_loss_out(tensor *out__, tensor out, tensor self, tensor tar
 void atg_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::softmax(*self, dim, at::ScalarType(dtype));
+    auto outputs__ = torch::softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }

 void atg_softmax_int_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
   PROTECT(
-    auto outputs__ = torch::softmax_out(*out, *self, dim, at::ScalarType(dtype));
+    auto outputs__ = torch::softmax_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
     out__[0] = new torch::Tensor(outputs__);
   )
 }
@ -15528,7 +15528,7 @@ void atg_special_log_ndtr_out(tensor *out__, tensor out, tensor self) {
void atg_special_log_softmax(tensor *out__, tensor self, int64_t dim, int dtype) { void atg_special_log_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
PROTECT( PROTECT(
auto outputs__ = torch::special_log_softmax(*self, dim, at::ScalarType(dtype)); auto outputs__ = torch::special_log_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__); out__[0] = new torch::Tensor(outputs__);
) )
} }
@ -15913,7 +15913,7 @@ void atg_special_sinc_out(tensor *out__, tensor out, tensor self) {
void atg_special_softmax(tensor *out__, tensor self, int64_t dim, int dtype) { void atg_special_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
PROTECT( PROTECT(
auto outputs__ = torch::special_softmax(*self, dim, at::ScalarType(dtype)); auto outputs__ = torch::special_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__); out__[0] = new torch::Tensor(outputs__);
) )
} }
@ -16437,28 +16437,28 @@ void atg_subtract_scalar_(tensor *out__, tensor self, scalar other) {
void atg_sum(tensor *out__, tensor self, int dtype) { void atg_sum(tensor *out__, tensor self, int dtype) {
PROTECT( PROTECT(
auto outputs__ = torch::sum(*self, at::ScalarType(dtype)); auto outputs__ = torch::sum(*self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__); out__[0] = new torch::Tensor(outputs__);
) )
} }
void atg_sum_dim_intlist(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) { void atg_sum_dim_intlist(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT( PROTECT(
auto outputs__ = torch::sum(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype)); auto outputs__ = torch::sum(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__); out__[0] = new torch::Tensor(outputs__);
) )
} }
void atg_sum_intlist_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) { void atg_sum_intlist_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
PROTECT( PROTECT(
auto outputs__ = torch::sum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype)); auto outputs__ = torch::sum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__); out__[0] = new torch::Tensor(outputs__);
) )
} }
void atg_sum_out(tensor *out__, tensor out, tensor self, int dtype) { void atg_sum_out(tensor *out__, tensor out, tensor self, int dtype) {
PROTECT( PROTECT(
auto outputs__ = torch::sum_out(*out, *self, at::ScalarType(dtype)); auto outputs__ = torch::sum_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__); out__[0] = new torch::Tensor(outputs__);
) )
} }
@ -16732,7 +16732,7 @@ void atg_to(tensor *out__, tensor self, int device) {
void atg_to_dense(tensor *out__, tensor self, int dtype) { void atg_to_dense(tensor *out__, tensor self, int dtype) {
PROTECT( PROTECT(
auto outputs__ = self->to_dense(at::ScalarType(dtype)); auto outputs__ = self->to_dense(dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__); out__[0] = new torch::Tensor(outputs__);
) )
} }
@ -16767,7 +16767,7 @@ void atg_to_dtype_layout(tensor *out__, tensor self, int options_kind, int optio
void atg_to_mkldnn(tensor *out__, tensor self, int dtype) { void atg_to_mkldnn(tensor *out__, tensor self, int dtype) {
PROTECT( PROTECT(
auto outputs__ = self->to_mkldnn(at::ScalarType(dtype)); auto outputs__ = self->to_mkldnn(dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__); out__[0] = new torch::Tensor(outputs__);
) )
} }
@ -16781,7 +16781,7 @@ void atg_to_mkldnn_backward(tensor *out__, tensor grad, tensor input) {
void atg_to_mkldnn_out(tensor *out__, tensor out, tensor self, int dtype) { void atg_to_mkldnn_out(tensor *out__, tensor out, tensor self, int dtype) {
PROTECT( PROTECT(
auto outputs__ = torch::to_mkldnn_out(*out, *self, at::ScalarType(dtype)); auto outputs__ = torch::to_mkldnn_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
out__[0] = new torch::Tensor(outputs__); out__[0] = new torch::Tensor(outputs__);
) )
} }
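Note: the recurring rewrite above replaces the unconditional at::ScalarType(dtype) cast with a sentinel check, so a negative dtype code arriving over the C ABI now maps to c10::nullopt (libtorch then keeps its default dtype) instead of constructing an invalid ScalarType. A minimal, self-contained Go sketch of that convention follows; describeDtype and the sample codes are illustrative, not part of the generated API:

    package main

    import "fmt"

    // describeDtype mirrors the wrapper convention above: any negative dtype
    // code means "not supplied", which the C++ side turns into c10::nullopt.
    func describeDtype(dtype int) string {
    	if dtype < 0 {
    		return "c10::nullopt (keep the default dtype)"
    	}
    	return fmt.Sprintf("c10::optional<at::ScalarType>(%d)", dtype)
    }

    func main() {
    	fmt.Println(describeDtype(-1)) // optional dtype omitted by the caller
    	fmt.Println(describeDtype(6)) // a concrete code (6 is kFloat in ATen's enum)
    }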


@@ -497,7 +497,7 @@ void atg_any_dim(tensor *, tensor self, int64_t dim, int keepdim);
 void atg_any_out(tensor *, tensor out, tensor self, int64_t dim, int keepdim);
 void atg_arange(tensor *, scalar end, int options_kind, int options_device);
 void atg_arange_start(tensor *, scalar start, scalar end, int options_kind, int options_device);
-void atg_arange_start_step(tensor *, scalar start, scalar end, int options_kind, int options_device);
+void atg_arange_start_step(tensor *, scalar start, scalar end, scalar step, int options_kind, int options_device);
 void atg_arccos(tensor *, tensor self);
 void atg_arccos_(tensor *, tensor self);
 void atg_arccos_out(tensor *, tensor out, tensor self);
@@ -563,7 +563,7 @@ void atg_avg_pool3d(tensor *, tensor self, int64_t *kernel_size_data, int kernel
 void atg_avg_pool3d_backward(tensor *, tensor grad_output, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override_v, uint8_t divisor_override_null);
 void atg_avg_pool3d_backward_grad_input(tensor *, tensor grad_input, tensor grad_output, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override_v, uint8_t divisor_override_null);
 void atg_avg_pool3d_out(tensor *, tensor out, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override_v, uint8_t divisor_override_null);
-void atg_baddbmm(tensor *, tensor self, tensor batch1, tensor batch2);
+void atg_baddbmm(tensor *, tensor self, tensor batch1, tensor batch2, scalar beta, scalar alpha);
 void atg_baddbmm_(tensor *, tensor self, tensor batch1, tensor batch2);
 void atg_baddbmm_out(tensor *, tensor out, tensor self, tensor batch1, tensor batch2);
 void atg_bartlett_window(tensor *, int64_t window_length, int options_kind, int options_device);


@@ -3961,9 +3961,9 @@ func MustArangeStart(start *Scalar, end *Scalar, optionsKind gotch.DType, option
 	return retVal
 }
-func MustArangeStartStep(start *Scalar, end *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor) {
-	retVal, err := ArangeStartStep(start, end, optionsKind, optionsDevice)
+func MustArangeStartStep(start *Scalar, end *Scalar, step *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor) {
+	retVal, err := ArangeStartStep(start, end, step, optionsKind, optionsDevice)
 	if err != nil { log.Fatal(err) }
 	return retVal
@@ -4465,9 +4465,9 @@ func(ts *Tensor) MustAvgPool3dOut(out *Tensor, kernelSize []int64, stride []int6
 	return retVal
 }
-func(ts *Tensor) MustBaddbmm(batch1 *Tensor, batch2 *Tensor, del bool)(retVal *Tensor) {
-	retVal, err := ts.Baddbmm(batch1, batch2, del)
+func(ts *Tensor) MustBaddbmm(batch1 *Tensor, batch2 *Tensor, beta *Scalar, alpha *Scalar, del bool)(retVal *Tensor) {
+	retVal, err := ts.Baddbmm(batch1, batch2, beta, alpha, del)
 	if err != nil { log.Fatal(err) }
 	return retVal
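The Must* wrappers now thread the explicit scalar arguments through. A sketch of an updated call site, assuming the generated MustZeros/MustRand constructors and ts.FloatScalar shown elsewhere in the package; the shapes are illustrative, and beta = 1, alpha = 1 recover the previous implicit defaults:

    // receiver: (b, n, p), batch1: (b, n, m), batch2: (b, m, p)
    self := ts.MustZeros([]int64{2, 3, 5}, gotch.Float, gotch.CPU)
    batch1 := ts.MustRand([]int64{2, 3, 4}, gotch.Float, gotch.CPU)
    batch2 := ts.MustRand([]int64{2, 4, 5}, gotch.Float, gotch.CPU)
    defer batch1.MustDrop()
    defer batch2.MustDrop()

    // out = beta*self + alpha*(batch1 @ batch2); del=true drops the receiver.
    out := self.MustBaddbmm(batch1, batch2, ts.FloatScalar(1.0), ts.FloatScalar(1.0), true)
    defer out.MustDrop()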


@@ -9327,10 +9327,10 @@ func ArangeStart(start *Scalar, end *Scalar, optionsKind gotch.DType, optionsDev
 // func.returns = `fixed 1`:
 // --------------------------
-func ArangeStartStep(start *Scalar, end *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor, err error) {
+func ArangeStartStep(start *Scalar, end *Scalar, step *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor, err error) {
 	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
-	lib.AtgArangeStartStep(ptr, start.cscalar, end.cscalar, optionsKind.CInt(), optionsDevice.CInt())
+	lib.AtgArangeStartStep(ptr, start.cscalar, end.cscalar, step.cscalar, optionsKind.CInt(), optionsDevice.CInt())
 	if err = TorchErr(); err != nil {
 		err = fmt.Errorf("ArangeStartStep() failed: %w", err)
 		return retVal, err
@@ -10585,11 +10585,11 @@ var cdivisorOverrideVal int64 = 0
 // func.returns = `fixed 1`:
 // --------------------------
-func(ts *Tensor) Baddbmm(batch1 *Tensor, batch2 *Tensor, del bool)(retVal *Tensor, err error) {
+func(ts *Tensor) Baddbmm(batch1 *Tensor, batch2 *Tensor, beta *Scalar, alpha *Scalar, del bool)(retVal *Tensor, err error) {
 	if del { defer ts.MustDrop() }
 	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
-	lib.AtgBaddbmm(ptr, ts.ctensor, batch1.ctensor, batch2.ctensor)
+	lib.AtgBaddbmm(ptr, ts.ctensor, batch1.ctensor, batch2.ctensor, beta.cscalar, alpha.cscalar)
 	if err = TorchErr(); err != nil {
 		err = fmt.Errorf("Baddbmm() failed: %w", err)
 		return retVal, err
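For callers that prefer explicit error handling over the log.Fatal in the Must* wrappers, a sketch of the fallible variant with the new step argument — assuming the usual gotch imports (the gotch root package and its ts tensor package) and illustrative values:

    start, end, step := ts.FloatScalar(0.0), ts.FloatScalar(10.0), ts.FloatScalar(2.5)
    r, err := ts.ArangeStartStep(start, end, step, gotch.Float, gotch.CPU)
    if err != nil {
    	log.Fatal(err)
    }
    defer r.MustDrop()
    fmt.Println(r.MustSize()) // [4]: 0.0, 2.5, 5.0, 7.5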


@@ -255,9 +255,9 @@ func rgb2Gray(x *ts.Tensor, outChanOpt ...int64) *ts.Tensor {
 	}
 	rgbTs := x.MustUnbind(-3, false)
-	r := &rgbTs[0]
-	g := &rgbTs[1]
-	b := &rgbTs[2]
+	r := rgbTs[0]
+	g := rgbTs[1]
+	b := rgbTs[2]
 	// This implementation closely follows the TF one:
 	// https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
@@ -453,7 +453,7 @@ func hsv2RGB(x *ts.Tensor) *ts.Tensor {
 	a3 := ts.MustStack([]*ts.Tensor{p, p, t, v, v, q}, -3)
 	a4 := ts.MustStack([]*ts.Tensor{a1, a2, a3}, -4)
-	out := ts.MustEinsum("...ijk, ...xijk -> ...xjk", []*ts.Tensor{mask, a4})
+	out := ts.MustEinsum("...ijk, ...xijk -> ...xjk", []*ts.Tensor{mask, a4}, []int64{0, 1})
 	// Delete intermediate tensors
 	h.MustDrop()
@@ -579,7 +579,7 @@ func crop(x *ts.Tensor, top, left, height, width int64) *ts.Tensor {
 	x2 := x1T.Idx(wNar)
 	x1T.MustDrop()
 	out := x2.MustT(true)
-	chans[i] = *out
+	chans[i] = out
 	}
 	cropTs := ts.MustStack(chans, 0)
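These vision call sites track the regenerated tensor APIs: MustUnbind now appears to hand back a slice of *ts.Tensor (so the address-of operators go away), MustEinsum takes the new explicit path argument, and the channel slice stacks pointers directly. A sketch of the updated unbind pattern, with an illustrative 3x32x32 input:

    img := ts.MustRand([]int64{3, 32, 32}, gotch.Float, gotch.CPU)
    rgb := img.MustUnbind(-3, true) // []*ts.Tensor; del=true drops img
    r, g, b := rgb[0], rgb[1], rgb[2]
    defer r.MustDrop()
    defer g.MustDrop()
    defer b.MustDrop()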