fixed incorrect APIs generation
This commit is contained in:
parent
bdf252d831
commit
e9278816b2
647
gen/gen.ml
647
gen/gen.ml
|
@ -82,6 +82,14 @@ let no_tensor_options =
|
|||
; "randint_like"
|
||||
; "randn_like" ]
|
||||
|
||||
(* By default, scalar argument that have a default value are not available on
|
||||
the Rust side, this is to preserve the Rust api simplicity assuming that
|
||||
these scalars arguments are not often overriden.
|
||||
Adding function name [foo] in [with_optional_scalar_args] results in having
|
||||
explicit scalar arguments even if a default is present. *)
|
||||
let with_optional_scalar_args = Set.of_list (module String) [ "arange"; "baddbmm" ]
|
||||
|
||||
|
||||
(*
|
||||
* let prefixed_functions =
|
||||
* Set.of_list
|
||||
|
@ -133,24 +141,26 @@ module Func = struct
|
|||
| Double
|
||||
| DoubleOption
|
||||
| Tensor
|
||||
| TensorOption
|
||||
(* Tensor.t option *)
|
||||
| TensorOption (* Tensor.t option *)
|
||||
| IntList
|
||||
| IntListOption
|
||||
| DoubleList
|
||||
| TensorOptList
|
||||
| TensorList
|
||||
| TensorOptions
|
||||
(* Tensor kind and device *)
|
||||
| TensorOptions (* Tensor kind and device *)
|
||||
| Scalar
|
||||
| ScalarType
|
||||
| ScalarTypeOption
|
||||
| Device
|
||||
| String
|
||||
| Layout
|
||||
| LayoutOption
|
||||
|
||||
type arg =
|
||||
{arg_name: string; arg_type: arg_type; default_value: string option}
|
||||
{ arg_name: string
|
||||
; arg_type: arg_type
|
||||
; default_value: string option
|
||||
}
|
||||
|
||||
(* `Func` type *)
|
||||
type t =
|
||||
|
@ -160,7 +170,8 @@ module Func = struct
|
|||
; args: arg list (* ; returns: [`fixed of int | `dynamic] *)
|
||||
; returns: [`fixed of int | `dynamic | `bool | `int64_t | `double | `nothing]
|
||||
; (* number of tensors that are returned *)
|
||||
kind: [`function_ | `method_] }
|
||||
kind: [`function_ | `method_]
|
||||
}
|
||||
|
||||
let arg_type_of_string str ~is_nullable =
|
||||
match String.lowercase str with
|
||||
|
@ -175,108 +186,111 @@ module Func = struct
|
|||
| "const at::itensorlistref &" | "at::tensorlist" -> Some TensorList
|
||||
| "at::device" -> Some Device
|
||||
| "const at::scalar &" | "at::scalar" -> Some Scalar
|
||||
| "at::scalartype" -> Some ScalarType
|
||||
| "at::scalartype" -> if is_nullable then Some ScalarTypeOption else Some ScalarType
|
||||
| "c10::string_view" -> Some String
|
||||
| "at::layout" -> Some (if is_nullable then LayoutOption else Layout)
|
||||
| _ -> None
|
||||
|
||||
|
||||
let c_typed_args_list t =
|
||||
List.map t.args ~f:(fun { arg_name; arg_type; _ } ->
|
||||
match arg_type with
|
||||
| IntList | IntListOption ->
|
||||
Printf.sprintf "int64_t *%s_data, int %s_len" arg_name arg_name
|
||||
| DoubleList -> Printf.sprintf "double *%s_data, int %s_len" arg_name arg_name
|
||||
| TensorOptList | TensorList ->
|
||||
Printf.sprintf "tensor *%s_data, int %s_len" arg_name arg_name
|
||||
| TensorOptions -> Printf.sprintf "int %s_kind, int %s_device" arg_name arg_name
|
||||
| String -> Printf.sprintf "char* %s_ptr, int %s_len" arg_name arg_name
|
||||
| Int64Option -> Printf.sprintf "int64_t %s_v, uint8_t %s_null" arg_name arg_name
|
||||
| DoubleOption -> Printf.sprintf "double %s_v, uint8_t %s_null" arg_name arg_name
|
||||
| otherwise ->
|
||||
let simple_type_cstring =
|
||||
match otherwise with
|
||||
| Bool -> "int"
|
||||
| Int64 -> "int64_t"
|
||||
| Double -> "double"
|
||||
| Tensor -> "tensor"
|
||||
| TensorOption -> "tensor"
|
||||
| ScalarType -> "int"
|
||||
| Device -> "int"
|
||||
| Scalar -> "scalar"
|
||||
| Layout | LayoutOption -> "int8_t"
|
||||
| Int64Option
|
||||
| DoubleOption
|
||||
| String
|
||||
| IntList
|
||||
| IntListOption
|
||||
| DoubleList
|
||||
| TensorOptList
|
||||
| TensorList
|
||||
| TensorOptions -> assert false
|
||||
in
|
||||
Printf.sprintf "%s %s" simple_type_cstring arg_name)
|
||||
match arg_type with
|
||||
| IntList | IntListOption ->
|
||||
Printf.sprintf "int64_t *%s_data, int %s_len" arg_name arg_name
|
||||
| DoubleList -> Printf.sprintf "double *%s_data, int %s_len" arg_name arg_name
|
||||
| TensorOptList | TensorList ->
|
||||
Printf.sprintf "tensor *%s_data, int %s_len" arg_name arg_name
|
||||
| TensorOptions -> Printf.sprintf "int %s_kind, int %s_device" arg_name arg_name
|
||||
| String -> Printf.sprintf "char* %s_ptr, int %s_len" arg_name arg_name
|
||||
| Int64Option -> Printf.sprintf "int64_t %s_v, uint8_t %s_null" arg_name arg_name
|
||||
| DoubleOption -> Printf.sprintf "double %s_v, uint8_t %s_null" arg_name arg_name
|
||||
| otherwise ->
|
||||
let simple_type_cstring =
|
||||
match otherwise with
|
||||
| Bool -> "int"
|
||||
| Int64 -> "int64_t"
|
||||
| Double -> "double"
|
||||
| Tensor -> "tensor"
|
||||
| TensorOption -> "tensor"
|
||||
| ScalarType -> "int"
|
||||
| ScalarTypeOption -> "int"
|
||||
| Device -> "int"
|
||||
| Scalar -> "scalar"
|
||||
| Layout | LayoutOption -> "int8_t"
|
||||
| Int64Option
|
||||
| DoubleOption
|
||||
| String
|
||||
| IntList
|
||||
| IntListOption
|
||||
| DoubleList
|
||||
| TensorOptList
|
||||
| TensorList
|
||||
| TensorOptions -> assert false
|
||||
in
|
||||
Printf.sprintf "%s %s" simple_type_cstring arg_name)
|
||||
|> String.concat ~sep:", "
|
||||
|
||||
let c_args_list args =
|
||||
List.map args ~f:(fun { arg_name; arg_type; _ } ->
|
||||
match arg_type with
|
||||
| Scalar | Tensor -> "*" ^ arg_name
|
||||
| Layout -> Printf.sprintf "static_cast<at::Layout>(%s)" arg_name
|
||||
| LayoutOption ->
|
||||
Printf.sprintf
|
||||
"(%s == -1 ? c10::nullopt : \
|
||||
c10::optional<at::Layout>(static_cast<at::Layout>(%s)))"
|
||||
arg_name
|
||||
arg_name
|
||||
| TensorOption -> Printf.sprintf "(%s ? *%s : torch::Tensor())" arg_name arg_name
|
||||
| Bool -> "(bool)" ^ arg_name
|
||||
| IntList ->
|
||||
Printf.sprintf "torch::IntArrayRef(%s_data, %s_len)" arg_name arg_name
|
||||
| IntListOption ->
|
||||
Printf.sprintf
|
||||
"%s_data == nullptr ? c10::nullopt : \
|
||||
c10::optional<torch::IntArrayRef>(torch::IntArrayRef(%s_data, %s_len))"
|
||||
arg_name
|
||||
arg_name
|
||||
arg_name
|
||||
| DoubleList ->
|
||||
Printf.sprintf "at::ArrayRef<double>(%s_data, %s_len)" arg_name arg_name
|
||||
| String -> Printf.sprintf "std::string(%s_ptr, %s_len)" arg_name arg_name
|
||||
| TensorOptList ->
|
||||
Printf.sprintf "of_carray_tensor_opt(%s_data, %s_len)" arg_name arg_name
|
||||
| TensorList ->
|
||||
Printf.sprintf "of_carray_tensor(%s_data, %s_len)" arg_name arg_name
|
||||
| TensorOptions ->
|
||||
Printf.sprintf
|
||||
"at::device(device_of_int(%s_device)).dtype(at::ScalarType(%s_kind))"
|
||||
arg_name
|
||||
arg_name
|
||||
| Int64Option ->
|
||||
Printf.sprintf
|
||||
"%s_null ? c10::nullopt : c10::optional<int64_t>(%s_v)"
|
||||
arg_name
|
||||
arg_name
|
||||
| DoubleOption ->
|
||||
Printf.sprintf
|
||||
"%s_null ? c10::nullopt : c10::optional<double>(%s_v)"
|
||||
arg_name
|
||||
arg_name
|
||||
| ScalarType -> Printf.sprintf "at::ScalarType(%s)" arg_name
|
||||
| Device -> Printf.sprintf "device_of_int(%s)" arg_name
|
||||
| _ -> arg_name)
|
||||
match arg_type with
|
||||
| Scalar | Tensor -> "*" ^ arg_name
|
||||
| Layout -> Printf.sprintf "static_cast<at::Layout>(%s)" arg_name
|
||||
| LayoutOption ->
|
||||
Printf.sprintf
|
||||
"(%s == -1 ? c10::nullopt : \
|
||||
c10::optional<at::Layout>(static_cast<at::Layout>(%s)))"
|
||||
arg_name
|
||||
arg_name
|
||||
| TensorOption -> Printf.sprintf "(%s ? *%s : torch::Tensor())" arg_name arg_name
|
||||
| Bool -> "(bool)" ^ arg_name
|
||||
| IntList -> Printf.sprintf "torch::IntArrayRef(%s_data, %s_len)" arg_name arg_name
|
||||
| IntListOption ->
|
||||
Printf.sprintf
|
||||
"%s_data == nullptr ? c10::nullopt : \
|
||||
c10::optional<torch::IntArrayRef>(torch::IntArrayRef(%s_data, %s_len))"
|
||||
arg_name
|
||||
arg_name
|
||||
arg_name
|
||||
| DoubleList ->
|
||||
Printf.sprintf "at::ArrayRef<double>(%s_data, %s_len)" arg_name arg_name
|
||||
| String -> Printf.sprintf "std::string(%s_ptr, %s_len)" arg_name arg_name
|
||||
| TensorOptList ->
|
||||
Printf.sprintf "of_carray_tensor_opt(%s_data, %s_len)" arg_name arg_name
|
||||
| TensorList -> Printf.sprintf "of_carray_tensor(%s_data, %s_len)" arg_name arg_name
|
||||
| TensorOptions ->
|
||||
Printf.sprintf
|
||||
"at::device(device_of_int(%s_device)).dtype(at::ScalarType(%s_kind))"
|
||||
arg_name
|
||||
arg_name
|
||||
| Int64Option ->
|
||||
Printf.sprintf
|
||||
"%s_null ? c10::nullopt : c10::optional<int64_t>(%s_v)"
|
||||
arg_name
|
||||
arg_name
|
||||
| DoubleOption ->
|
||||
Printf.sprintf
|
||||
"%s_null ? c10::nullopt : c10::optional<double>(%s_v)"
|
||||
arg_name
|
||||
arg_name
|
||||
| ScalarType -> Printf.sprintf "at::ScalarType(%s)" arg_name
|
||||
| ScalarTypeOption ->
|
||||
Printf.sprintf
|
||||
"%s < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(%s))"
|
||||
arg_name
|
||||
arg_name
|
||||
| Device -> Printf.sprintf "device_of_int(%s)" arg_name
|
||||
| _ -> arg_name)
|
||||
|> String.concat ~sep:", "
|
||||
|
||||
|
||||
let c_call t =
|
||||
match t.kind with
|
||||
| `function_ -> Printf.sprintf "torch::%s(%s)" t.name (c_args_list t.args)
|
||||
| `method_ -> (
|
||||
match t.args with
|
||||
| head :: tail ->
|
||||
Printf.sprintf "%s->%s(%s)" head.arg_name t.name (c_args_list tail)
|
||||
| [] ->
|
||||
Printf.failwithf "Method calls should have at least one argument %s"
|
||||
t.name () )
|
||||
| `method_ ->
|
||||
(match t.args with
|
||||
| head :: tail ->
|
||||
Printf.sprintf "%s->%s(%s)" head.arg_name t.name (c_args_list tail)
|
||||
| [] ->
|
||||
Printf.failwithf "Method calls should have at least one argument %s" t.name ())
|
||||
|
||||
(*
|
||||
let replace_map =
|
||||
|
@ -289,6 +303,15 @@ module Func = struct
|
|||
; ("to_device", "to_device_") ]
|
||||
*)
|
||||
|
||||
let operator_name t =
|
||||
match String.lowercase t.operator_name with
|
||||
| "scatter_reduce" ->
|
||||
(* scatter_reduce is both an operator name and also obtained from the
|
||||
scatter operator when using the reduce overload. *)
|
||||
"_scatter_reduce"
|
||||
| "scatter_reduce_" -> "_scatter_reduce_"
|
||||
| other -> other
|
||||
|
||||
let is_method t =
|
||||
List.exists t.args ~f:(fun arg ->
|
||||
match arg.arg_name with "self" -> true | _ -> false )
|
||||
|
@ -321,18 +344,16 @@ module Func = struct
|
|||
let single_param = Printf.sprintf "%s %s" an in
|
||||
match arg.arg_type with
|
||||
| Bool -> single_param "int32"
|
||||
| Layout -> single_param "int8"
|
||||
| LayoutOption -> single_param "int8"
|
||||
| Layout | LayoutOption -> single_param "int8"
|
||||
| Int64 -> single_param "int64"
|
||||
| Double -> single_param "float64"
|
||||
| Tensor -> single_param "Ctensor"
|
||||
| TensorOption -> single_param "Ctensor"
|
||||
| Scalar -> single_param "Cscalar"
|
||||
| ScalarType -> single_param "int32"
|
||||
| ScalarType | ScalarTypeOption -> single_param "int32"
|
||||
| Device -> single_param "int32"
|
||||
| String -> single_param "string"
|
||||
| IntList -> Printf.sprintf "%sData []int64, %sLen int" an an
|
||||
| IntListOption -> Printf.sprintf "%sData []int64, %sLen int" an an
|
||||
| IntList | IntListOption -> Printf.sprintf "%sData []int64, %sLen int" an an
|
||||
| DoubleList -> Printf.sprintf "%sData []float64, %sLen int" an an
|
||||
| TensorOptList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an
|
||||
| TensorList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an
|
||||
|
@ -353,14 +374,12 @@ module Func = struct
|
|||
| Double -> Printf.sprintf "c%s" an
|
||||
| Tensor -> Printf.sprintf "%s" an
|
||||
| TensorOption -> Printf.sprintf "%s" an
|
||||
| Layout -> Printf.sprintf "c%s" an
|
||||
| LayoutOption -> Printf.sprintf "c%s" an
|
||||
| Layout | LayoutOption -> Printf.sprintf "c%s" an
|
||||
| Scalar -> single_param ""
|
||||
| ScalarType -> Printf.sprintf "c%s" an
|
||||
| ScalarType | ScalarTypeOption -> Printf.sprintf "c%s" an
|
||||
| Device -> Printf.sprintf "c%s" an
|
||||
| String -> Printf.sprintf "c%s, c%sLen" an an
|
||||
| IntList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
||||
| IntListOption -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
||||
| IntList | IntListOption -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
||||
| DoubleList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
||||
| TensorOptList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
||||
| TensorList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
||||
|
@ -383,12 +402,10 @@ module Func = struct
|
|||
Printf.sprintf "\nc%s := *(*C.double)(unsafe.Pointer(&%s))" an an
|
||||
| Tensor -> ""
|
||||
| TensorOption -> ""
|
||||
| Layout ->
|
||||
Printf.sprintf "\nc%s := *(*C.int8_t)(unsafe.Pointer(&%s))" an an
|
||||
| LayoutOption ->
|
||||
| Layout | LayoutOption ->
|
||||
Printf.sprintf "\nc%s := *(*C.int8_t)(unsafe.Pointer(&%s))" an an
|
||||
| Scalar -> ""
|
||||
| ScalarType ->
|
||||
| ScalarType | ScalarTypeOption ->
|
||||
Printf.sprintf "\nc%s := *(*C.int)(unsafe.Pointer(&%s))" an an
|
||||
| Device ->
|
||||
Printf.sprintf "\nc%s := *(*C.int)(unsafe.Pointer(&%s))" an an
|
||||
|
@ -399,13 +416,7 @@ module Func = struct
|
|||
%sLen := len(%s)\n\
|
||||
c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))"
|
||||
an an an an an an
|
||||
| IntList ->
|
||||
Printf.sprintf
|
||||
"\n\
|
||||
c%sDataPtr := (*C.int64_t)(unsafe.Pointer(&%sData[0]))\n\
|
||||
c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))"
|
||||
an an an an
|
||||
| IntListOption ->
|
||||
| IntList | IntListOption ->
|
||||
Printf.sprintf
|
||||
"\n\
|
||||
c%sDataPtr := (*C.int64_t)(unsafe.Pointer(&%sData[0]))\n\
|
||||
|
@ -494,14 +505,12 @@ module Func = struct
|
|||
let go_arg_type =
|
||||
match arg.arg_type with
|
||||
| Bool -> "bool"
|
||||
| Layout -> "Layout"
|
||||
| LayoutOption -> "Layout"
|
||||
| Layout | LayoutOption -> "Layout"
|
||||
| Int64 -> "int64"
|
||||
| Double -> "float64"
|
||||
| Tensor -> "*Tensor"
|
||||
| TensorOption -> "*Tensor"
|
||||
| IntList -> "[]int64"
|
||||
| IntListOption -> "[]int64"
|
||||
| IntList | IntListOption -> "[]int64"
|
||||
| DoubleList -> "[]float64"
|
||||
| TensorOptList -> "[]*Tensor"
|
||||
| TensorList -> "[]*Tensor"
|
||||
|
@ -510,7 +519,7 @@ module Func = struct
|
|||
(* E.g. `type KindDevice struct{}` *)
|
||||
| TensorOptions -> "gotch.KindDevice"
|
||||
| Scalar -> "*Scalar"
|
||||
| ScalarType -> "gotch.DType"
|
||||
| ScalarType | ScalarTypeOption -> "gotch.DType"
|
||||
| Int64Option -> "[]int64"
|
||||
| DoubleOption -> "[]float64"
|
||||
| Device -> "gotch.Device"
|
||||
|
@ -603,7 +612,7 @@ module Func = struct
|
|||
else Printf.sprintf "%s.ctensor" name
|
||||
| Scalar -> Printf.sprintf "%s.cscalar" name
|
||||
| Bool -> Printf.sprintf "c%s" name
|
||||
| ScalarType -> Printf.sprintf "%s.CInt()" name
|
||||
| ScalarType | ScalarTypeOption -> Printf.sprintf "%s.CInt()" name
|
||||
| Device -> Printf.sprintf "%s.CInt()" name
|
||||
| TensorOptions ->
|
||||
Printf.sprintf "%sKind.CInt(), %sDevice.CInt()" name name
|
||||
|
@ -633,11 +642,10 @@ module Func = struct
|
|||
| Tensor -> ""
|
||||
| TensorOption -> ""
|
||||
| Scalar -> ""
|
||||
| ScalarType -> ""
|
||||
| ScalarType | ScalarTypeOption -> ""
|
||||
| Device -> ""
|
||||
| String -> ""
|
||||
| IntList -> Printf.sprintf "%sLen := len(%s)\n" an an
|
||||
| IntListOption -> Printf.sprintf "%sLen := len(%s)\n" an an
|
||||
| IntList | IntListOption -> Printf.sprintf "%sLen := len(%s)\n" an an
|
||||
| DoubleList -> Printf.sprintf "%sLen := len(%s)\n" an an
|
||||
| Int64Option ->
|
||||
Printf.sprintf
|
||||
|
@ -667,8 +675,7 @@ module Func = struct
|
|||
"var c%s []lib.Ctensor\n\
|
||||
\ for _, t := range %s {c%s = append(c%s, t.ctensor)}\n"
|
||||
an an an an
|
||||
| Layout -> ""
|
||||
| LayoutOption -> ""
|
||||
| Layout | LayoutOption -> ""
|
||||
| TensorOptions -> "" )
|
||||
|> String.concat ~sep:""
|
||||
end
|
||||
|
@ -679,117 +686,103 @@ let read_yaml filename =
|
|||
let funcs =
|
||||
(* Split the file to avoid Yaml.of_string_exn segfaulting. *)
|
||||
In_channel.with_file filename ~f:In_channel.input_lines
|
||||
|> List.group ~break:(fun _ l ->
|
||||
String.length l > 0 && Char.( = ) l.[0] '-' )
|
||||
|> List.group ~break:(fun _ l -> String.length l > 0 && Char.( = ) l.[0] '-')
|
||||
|> List.concat_map ~f:(fun lines ->
|
||||
Yaml.of_string_exn (String.concat lines ~sep:"\n") |> extract_list
|
||||
)
|
||||
Yaml.of_string_exn (String.concat lines ~sep:"\n") |> extract_list)
|
||||
in
|
||||
printf "Read %s, got %d functions.\n%!" filename (List.length funcs) ;
|
||||
printf "Read %s, got %d functions.\n%!" filename (List.length funcs);
|
||||
List.filter_map funcs ~f:(fun yaml ->
|
||||
let map = extract_map yaml in
|
||||
let name = Map.find_exn map "name" |> extract_string in
|
||||
let operator_name = Map.find_exn map "operator_name" |> extract_string in
|
||||
let overload_name = Map.find_exn map "overload_name" |> extract_string in
|
||||
let deprecated = Map.find_exn map "deprecated" |> extract_bool in
|
||||
let method_of =
|
||||
Map.find_exn map "method_of"
|
||||
|> extract_list |> List.map ~f:extract_string
|
||||
let map = extract_map yaml in
|
||||
let name = Map.find_exn map "name" |> extract_string in
|
||||
let operator_name = Map.find_exn map "operator_name" |> extract_string in
|
||||
let overload_name = Map.find_exn map "overload_name" |> extract_string in
|
||||
let deprecated = Map.find_exn map "deprecated" |> extract_bool in
|
||||
let method_of =
|
||||
Map.find_exn map "method_of" |> extract_list |> List.map ~f:extract_string
|
||||
in
|
||||
let arguments = Map.find_exn map "arguments" |> extract_list in
|
||||
let returns =
|
||||
let is_tensor returns =
|
||||
let returns = extract_map returns in
|
||||
let return_type = Map.find_exn returns "dynamic_type" |> extract_string in
|
||||
String.( = ) return_type "at::Tensor"
|
||||
in
|
||||
let arguments = Map.find_exn map "arguments" |> extract_list in
|
||||
let returns =
|
||||
let is_tensor returns =
|
||||
let returns = extract_map returns in
|
||||
let returns = Map.find_exn map "returns" |> extract_list in
|
||||
if List.is_empty returns
|
||||
then Some `nothing
|
||||
else if List.for_all returns ~f:is_tensor
|
||||
then Some (`fixed (List.length returns))
|
||||
else (
|
||||
match returns with
|
||||
| [ returns ] ->
|
||||
let return_type =
|
||||
Map.find_exn returns "dynamic_type" |> extract_string
|
||||
Map.find_exn (extract_map returns) "dynamic_type" |> extract_string
|
||||
in
|
||||
String.( = ) return_type "at::Tensor"
|
||||
in
|
||||
let returns = Map.find_exn map "returns" |> extract_list in
|
||||
if List.for_all returns ~f:is_tensor then
|
||||
Some (`fixed (List.length returns))
|
||||
else
|
||||
match returns with
|
||||
| [returns] -> (
|
||||
let return_type =
|
||||
Map.find_exn (extract_map returns) "dynamic_type"
|
||||
|> extract_string
|
||||
in
|
||||
match return_type with
|
||||
| "bool" -> Some `bool
|
||||
| "int64_t" -> Some `int64_t
|
||||
| "double" -> Some `double
|
||||
| "at::TensorList"
|
||||
|"dynamic_type: const c10::List<c10::optional<Tensor>> &" ->
|
||||
Some `dynamic
|
||||
| _ -> None )
|
||||
| [] | _ :: _ :: _ -> None
|
||||
in
|
||||
let kind =
|
||||
if List.exists method_of ~f:(String.( = ) "namespace") then
|
||||
Some `function_
|
||||
else if List.exists method_of ~f:(String.( = ) "Tensor") then
|
||||
Some `method_
|
||||
else None
|
||||
in
|
||||
if
|
||||
(not deprecated)
|
||||
&& (not
|
||||
(List.exists excluded_prefixes ~f:(fun prefix ->
|
||||
String.is_prefix name ~prefix )))
|
||||
&& (not
|
||||
(List.exists excluded_suffixes ~f:(fun suffix ->
|
||||
String.is_suffix name ~suffix )))
|
||||
&& not (Set.mem excluded_functions name)
|
||||
then
|
||||
Option.both returns kind
|
||||
|> Option.bind ~f:(fun (returns, kind) ->
|
||||
try
|
||||
let args =
|
||||
List.filter_map arguments ~f:(fun arg ->
|
||||
let arg = extract_map arg in
|
||||
let arg_name =
|
||||
Map.find_exn arg "name" |> extract_string
|
||||
in
|
||||
let arg_type =
|
||||
Map.find_exn arg "dynamic_type" |> extract_string
|
||||
in
|
||||
let is_nullable =
|
||||
Map.find arg "is_nullable"
|
||||
|> Option.value_map ~default:false ~f:extract_bool
|
||||
in
|
||||
let default_value =
|
||||
Map.find arg "default" |> Option.map ~f:extract_string
|
||||
in
|
||||
match Func.arg_type_of_string arg_type ~is_nullable with
|
||||
| Some Scalar
|
||||
when Option.is_some default_value && not is_nullable
|
||||
->
|
||||
None
|
||||
| Some TensorOptions
|
||||
when Option.is_some default_value
|
||||
&& Set.mem no_tensor_options name ->
|
||||
None
|
||||
| Some arg_type ->
|
||||
let arg_name =
|
||||
match (arg_name, arg_type) with
|
||||
| "self", Scalar -> "self_scalar"
|
||||
| _, _ -> arg_name
|
||||
in
|
||||
Some {Func.arg_name; arg_type; default_value}
|
||||
| None ->
|
||||
if Option.is_some default_value then None
|
||||
else raise Not_a_simple_arg )
|
||||
(match return_type with
|
||||
| "bool" -> Some `bool
|
||||
| "int64_t" -> Some `int64_t
|
||||
| "double" -> Some `double
|
||||
| "at::TensorList" | "dynamic_type: const c10::List<c10::optional<Tensor>> &"
|
||||
-> Some `dynamic
|
||||
| _ -> None)
|
||||
| [] | _ :: _ :: _ -> None)
|
||||
in
|
||||
let kind =
|
||||
if List.exists method_of ~f:(String.( = ) "namespace")
|
||||
then Some `function_
|
||||
else if List.exists method_of ~f:(String.( = ) "Tensor")
|
||||
then Some `method_
|
||||
else None
|
||||
in
|
||||
if (not deprecated)
|
||||
&& (not
|
||||
(List.exists excluded_prefixes ~f:(fun prefix ->
|
||||
String.is_prefix name ~prefix)))
|
||||
&& (not
|
||||
(List.exists excluded_suffixes ~f:(fun suffix ->
|
||||
String.is_suffix name ~suffix)))
|
||||
&& not (Set.mem excluded_functions name)
|
||||
then
|
||||
Option.both returns kind
|
||||
|> Option.bind ~f:(fun (returns, kind) ->
|
||||
try
|
||||
let args ~with_optional_scalar_args =
|
||||
List.filter_map arguments ~f:(fun arg ->
|
||||
let arg = extract_map arg in
|
||||
let arg_name = Map.find_exn arg "name" |> extract_string in
|
||||
let arg_type = Map.find_exn arg "dynamic_type" |> extract_string in
|
||||
let is_nullable =
|
||||
Map.find arg "is_nullable"
|
||||
|> Option.value_map ~default:false ~f:extract_bool
|
||||
in
|
||||
Some
|
||||
{ Func.name
|
||||
; operator_name
|
||||
; overload_name
|
||||
; args
|
||||
; returns
|
||||
; kind }
|
||||
with Not_a_simple_arg -> None )
|
||||
else None )
|
||||
let default_value =
|
||||
Map.find arg "default" |> Option.map ~f:extract_string
|
||||
in
|
||||
match Func.arg_type_of_string arg_type ~is_nullable with
|
||||
| Some Scalar when Option.is_some default_value && not is_nullable ->
|
||||
if with_optional_scalar_args
|
||||
then Some { Func.arg_name; arg_type = Scalar; default_value }
|
||||
else None
|
||||
| Some TensorOptions
|
||||
when Option.is_some default_value && Set.mem no_tensor_options name ->
|
||||
None
|
||||
| Some arg_type ->
|
||||
let arg_name =
|
||||
match arg_name, arg_type with
|
||||
| "self", Scalar -> "self_scalar"
|
||||
| _, _ -> arg_name
|
||||
in
|
||||
Some { Func.arg_name; arg_type; default_value }
|
||||
| None ->
|
||||
if Option.is_some default_value then None else raise Not_a_simple_arg)
|
||||
in
|
||||
let args =
|
||||
args ~with_optional_scalar_args:(Set.mem with_optional_scalar_args name)
|
||||
in
|
||||
Some [ { Func.name; operator_name; overload_name; args; returns; kind } ]
|
||||
with
|
||||
| Not_a_simple_arg -> None)
|
||||
else None)
|
||||
|
||||
let p out_channel s =
|
||||
Printf.ksprintf
|
||||
|
@ -803,72 +796,71 @@ let print_inline out_channel s =
|
|||
|
||||
let write_cpp funcs filename =
|
||||
Out_channel.with_file (filename ^ ".cpp.h") ~f:(fun out_cpp ->
|
||||
Out_channel.with_file (filename ^ ".h") ~f:(fun out_h ->
|
||||
let pc s = p out_cpp s in
|
||||
let ph s = p out_h s in
|
||||
pc "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
|
||||
pc "";
|
||||
ph "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
|
||||
ph "";
|
||||
Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
|
||||
let c_typed_args_list = Func.c_typed_args_list func in
|
||||
match func.returns with
|
||||
| `nothing ->
|
||||
pc "void atg_%s(%s) {" exported_name c_typed_args_list;
|
||||
pc " PROTECT(";
|
||||
pc " %s;" (Func.c_call func);
|
||||
pc " )";
|
||||
pc "}";
|
||||
pc "";
|
||||
ph "void atg_%s(%s);" exported_name c_typed_args_list
|
||||
| `fixed ntensors ->
|
||||
pc "void atg_%s(tensor *out__, %s) {" exported_name c_typed_args_list;
|
||||
pc " PROTECT(";
|
||||
pc " auto outputs__ = %s;" (Func.c_call func);
|
||||
if ntensors = 1
|
||||
then pc " out__[0] = new torch::Tensor(outputs__);"
|
||||
else
|
||||
for i = 0 to ntensors - 1 do
|
||||
pc " out__[%d] = new torch::Tensor(std::get<%d>(outputs__));" i i
|
||||
done;
|
||||
pc " )";
|
||||
pc "}";
|
||||
pc "";
|
||||
ph "void atg_%s(tensor *, %s);" exported_name c_typed_args_list
|
||||
| `dynamic ->
|
||||
pc "tensor *atg_%s(%s) {" exported_name c_typed_args_list;
|
||||
pc " PROTECT(";
|
||||
pc " auto outputs__ = %s;" (Func.c_call func);
|
||||
(* the returned type is a C++ vector of tensors *)
|
||||
pc " int sz = outputs__.size();";
|
||||
pc
|
||||
" torch::Tensor **out__ = (torch::Tensor**)malloc((sz + 1) * \
|
||||
sizeof(torch::Tensor*));";
|
||||
pc " for (int i = 0; i < sz; ++i)";
|
||||
pc " out__[i] = new torch::Tensor(outputs__[i]);";
|
||||
pc " out__[sz] = nullptr;";
|
||||
pc " return out__;";
|
||||
pc " )";
|
||||
pc " return nullptr;";
|
||||
pc "}";
|
||||
pc "";
|
||||
ph "tensor *atg_%s(%s);" exported_name c_typed_args_list
|
||||
| (`bool | `int64_t | `double) as returns ->
|
||||
let c_type =
|
||||
match returns with
|
||||
| `bool -> "int"
|
||||
| `int64_t -> "int64_t"
|
||||
| `double -> "double"
|
||||
in
|
||||
pc "%s atg_%s(%s) {" c_type exported_name c_typed_args_list;
|
||||
pc " PROTECT(";
|
||||
pc " return %s;" (Func.c_call func);
|
||||
pc " )";
|
||||
pc " return 0;";
|
||||
pc "}";
|
||||
pc "";
|
||||
ph "%s atg_%s(%s);" c_type exported_name c_typed_args_list)))
|
||||
|
||||
Out_channel.with_file (filename ^ ".h") ~f:(fun out_h ->
|
||||
let pc s = p out_cpp s in
|
||||
let ph s = p out_h s in
|
||||
pc "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
|
||||
pc "";
|
||||
ph "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
|
||||
ph "";
|
||||
Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
|
||||
let c_typed_args_list = Func.c_typed_args_list func in
|
||||
match func.returns with
|
||||
| `nothing ->
|
||||
pc "void atg_%s(%s) {" exported_name c_typed_args_list;
|
||||
pc " PROTECT(";
|
||||
pc " %s;" (Func.c_call func);
|
||||
pc " )";
|
||||
pc "}";
|
||||
pc "";
|
||||
ph "void atg_%s(%s);" exported_name c_typed_args_list
|
||||
| `fixed ntensors ->
|
||||
pc "void atg_%s(tensor *out__, %s) {" exported_name c_typed_args_list;
|
||||
pc " PROTECT(";
|
||||
pc " auto outputs__ = %s;" (Func.c_call func);
|
||||
if ntensors = 1
|
||||
then pc " out__[0] = new torch::Tensor(outputs__);"
|
||||
else
|
||||
for i = 0 to ntensors - 1 do
|
||||
pc " out__[%d] = new torch::Tensor(std::get<%d>(outputs__));" i i
|
||||
done;
|
||||
pc " )";
|
||||
pc "}";
|
||||
pc "";
|
||||
ph "void atg_%s(tensor *, %s);" exported_name c_typed_args_list
|
||||
| `dynamic ->
|
||||
pc "tensor *atg_%s(%s) {" exported_name c_typed_args_list;
|
||||
pc " PROTECT(";
|
||||
pc " auto outputs__ = %s;" (Func.c_call func);
|
||||
(* the returned type is a C++ vector of tensors *)
|
||||
pc " int sz = outputs__.size();";
|
||||
pc
|
||||
" torch::Tensor **out__ = (torch::Tensor**)malloc((sz + 1) * \
|
||||
sizeof(torch::Tensor*));";
|
||||
pc " for (int i = 0; i < sz; ++i)";
|
||||
pc " out__[i] = new torch::Tensor(outputs__[i]);";
|
||||
pc " out__[sz] = nullptr;";
|
||||
pc " return out__;";
|
||||
pc " )";
|
||||
pc " return nullptr;";
|
||||
pc "}";
|
||||
pc "";
|
||||
ph "tensor *atg_%s(%s);" exported_name c_typed_args_list
|
||||
| (`bool | `int64_t | `double) as returns ->
|
||||
let c_type =
|
||||
match returns with
|
||||
| `bool -> "int"
|
||||
| `int64_t -> "int64_t"
|
||||
| `double -> "double"
|
||||
in
|
||||
pc "%s atg_%s(%s) {" c_type exported_name c_typed_args_list;
|
||||
pc " PROTECT(";
|
||||
pc " return %s;" (Func.c_call func);
|
||||
pc " )";
|
||||
pc " return 0;";
|
||||
pc "}";
|
||||
pc "";
|
||||
ph "%s atg_%s(%s);" c_type exported_name c_typed_args_list)))
|
||||
|
||||
let write_wrapper funcs filename =
|
||||
Out_channel.with_file filename ~f:(fun out_ml ->
|
||||
|
@ -1402,54 +1394,43 @@ let methods =
|
|||
; c "to" [ ca "self" Tensor; ca "device" Device ]
|
||||
]
|
||||
|
||||
|
||||
let run ~yaml_filename ~cpp_filename ~ffi_filename ~must_wrapper_filename
|
||||
~wrapper_filename =
|
||||
let funcs = read_yaml yaml_filename in
|
||||
let funcs = read_yaml yaml_filename |> List.concat in
|
||||
let funcs = methods @ funcs in
|
||||
printf "Generating code for %d functions.\n%!" (List.length funcs) ;
|
||||
printf "Generating code for %d functions.\n%!" (List.length funcs);
|
||||
(* Generate some unique names for overloaded functions. *)
|
||||
let funcs =
|
||||
List.map funcs ~f:(fun func -> (String.lowercase func.operator_name, func))
|
||||
List.map funcs ~f:(fun func -> Func.operator_name func, func)
|
||||
|> Map.of_alist_multi (module String)
|
||||
|> Map.to_alist
|
||||
|> List.concat_map ~f:(fun (name, funcs) ->
|
||||
match funcs with
|
||||
| [] -> assert false
|
||||
| [func] -> [(name, func)]
|
||||
| funcs ->
|
||||
let has_empty_overload =
|
||||
List.exists funcs ~f:(fun (func : Func.t) ->
|
||||
String.is_empty func.overload_name )
|
||||
in
|
||||
List.sort funcs ~compare:(fun (f1 : Func.t) (f2 : Func.t) ->
|
||||
match
|
||||
Int.compare (String.length f1.name)
|
||||
(String.length f2.name)
|
||||
with
|
||||
| 0 ->
|
||||
Int.compare (List.length f1.args) (List.length f2.args)
|
||||
| cmp -> cmp )
|
||||
|> List.mapi ~f:(fun index (func : Func.t) ->
|
||||
let operator_name =
|
||||
String.lowercase func.operator_name
|
||||
in
|
||||
let overload_name =
|
||||
String.lowercase func.overload_name
|
||||
in
|
||||
let name =
|
||||
if
|
||||
String.is_empty overload_name
|
||||
|| (index = 0 && not has_empty_overload)
|
||||
then operator_name
|
||||
else if String.is_suffix operator_name ~suffix:"_" then
|
||||
operator_name ^ overload_name ^ "_"
|
||||
else operator_name ^ "_" ^ overload_name
|
||||
in
|
||||
(name, func) ) )
|
||||
match funcs with
|
||||
| [] -> assert false
|
||||
| [ func ] -> [ name, func ]
|
||||
| funcs ->
|
||||
let has_empty_overload =
|
||||
List.exists funcs ~f:(fun (func : Func.t) ->
|
||||
String.is_empty func.overload_name)
|
||||
in
|
||||
List.sort funcs ~compare:(fun (f1 : Func.t) (f2 : Func.t) ->
|
||||
match Int.compare (String.length f1.name) (String.length f2.name) with
|
||||
| 0 -> Int.compare (List.length f1.args) (List.length f2.args)
|
||||
| cmp -> cmp)
|
||||
|> List.mapi ~f:(fun index (func : Func.t) ->
|
||||
let operator_name = Func.operator_name func in
|
||||
let overload_name = String.lowercase func.overload_name in
|
||||
let name =
|
||||
if String.is_empty overload_name || (index = 0 && not has_empty_overload)
|
||||
then operator_name
|
||||
else if String.is_suffix operator_name ~suffix:"_"
|
||||
then operator_name ^ overload_name ^ "_"
|
||||
else operator_name ^ "_" ^ overload_name
|
||||
in
|
||||
name, func))
|
||||
|> Map.of_alist_exn (module String)
|
||||
in
|
||||
write_cpp funcs cpp_filename ;
|
||||
write_cpp funcs cpp_filename;
|
||||
write_ffi funcs ffi_filename ;
|
||||
write_must_wrapper funcs must_wrapper_filename ;
|
||||
write_wrapper funcs wrapper_filename
|
||||
|
|
|
@ -2666,10 +2666,10 @@ coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
|
|||
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
|
||||
C.atg_arange_start(ptr, start , end , coptionsKind, coptionsDevice)
|
||||
}
|
||||
func AtgArangeStartStep(ptr *Ctensor, start Cscalar, end Cscalar, optionsKind int32, optionsDevice int32){
|
||||
func AtgArangeStartStep(ptr *Ctensor, start Cscalar, end Cscalar, step Cscalar, optionsKind int32, optionsDevice int32){
|
||||
coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
|
||||
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
|
||||
C.atg_arange_start_step(ptr, start , end , coptionsKind, coptionsDevice)
|
||||
C.atg_arange_start_step(ptr, start , end , step , coptionsKind, coptionsDevice)
|
||||
}
|
||||
func AtgArccos(ptr *Ctensor, self Ctensor){
|
||||
C.atg_arccos(ptr, self)
|
||||
|
@ -3004,8 +3004,8 @@ cdivisorOverrideVal := *(*C.int64_t)(unsafe.Pointer(&divisorOverrideVal))
|
|||
cdivisorOverrideNull := *(*C.uint8_t)(unsafe.Pointer(&divisorOverrideNull))
|
||||
C.atg_avg_pool3d_out(ptr, out, self, ckernelSizeDataPtr, ckernelSizeLen, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cceilMode, ccountIncludePad, cdivisorOverrideVal, cdivisorOverrideNull)
|
||||
}
|
||||
func AtgBaddbmm(ptr *Ctensor, self Ctensor, batch1 Ctensor, batch2 Ctensor){
|
||||
C.atg_baddbmm(ptr, self, batch1, batch2)
|
||||
func AtgBaddbmm(ptr *Ctensor, self Ctensor, batch1 Ctensor, batch2 Ctensor, beta Cscalar, alpha Cscalar){
|
||||
C.atg_baddbmm(ptr, self, batch1, batch2, beta , alpha )
|
||||
}
|
||||
func AtgBaddbmm_(ptr *Ctensor, self Ctensor, batch1 Ctensor, batch2 Ctensor){
|
||||
C.atg_baddbmm_(ptr, self, batch1, batch2)
|
||||
|
|
|
@ -2068,7 +2068,7 @@ void atg__slow_conv2d_backward(tensor *out__, tensor grad_input, tensor grad_wei
|
|||
|
||||
void atg__sobol_engine_draw(tensor *out__, tensor quasi, int64_t n, tensor sobolstate, int64_t dimension, int64_t num_generated, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::_sobol_engine_draw(*quasi, n, *sobolstate, dimension, num_generated, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::_sobol_engine_draw(*quasi, n, *sobolstate, dimension, num_generated, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(std::get<0>(outputs__));
|
||||
out__[1] = new torch::Tensor(std::get<1>(outputs__));
|
||||
)
|
||||
|
@ -2223,28 +2223,28 @@ void atg__sparse_csc_tensor_unsafe(tensor *out__, tensor ccol_indices, tensor ro
|
|||
|
||||
void atg__sparse_csr_prod(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::_sparse_csr_prod(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::_sparse_csr_prod(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg__sparse_csr_prod_dim_dtype_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::_sparse_csr_prod_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::_sparse_csr_prod_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg__sparse_csr_sum(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::_sparse_csr_sum(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::_sparse_csr_sum(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg__sparse_csr_sum_dim_dtype_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::_sparse_csr_sum_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::_sparse_csr_sum_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -2279,7 +2279,7 @@ void atg__sparse_log_softmax_backward_data_out(tensor *out__, tensor out, tensor
|
|||
|
||||
void atg__sparse_log_softmax_int(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::_sparse_log_softmax(*self, dim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::_sparse_log_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -2336,7 +2336,7 @@ void atg__sparse_softmax_backward_data_out(tensor *out__, tensor out, tensor gra
|
|||
|
||||
void atg__sparse_softmax_int(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::_sparse_softmax(*self, dim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::_sparse_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -2629,14 +2629,14 @@ tensor *atg__to_cpu(tensor *tensors_data, int tensors_len) {
|
|||
|
||||
void atg__to_dense(tensor *out__, tensor self, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = self->_to_dense(at::ScalarType(dtype));
|
||||
auto outputs__ = self->_to_dense(dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg__to_dense_out(tensor *out__, tensor out, tensor self, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::_to_dense_out(*out, *self, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::_to_dense_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -3640,9 +3640,9 @@ void atg_arange_start(tensor *out__, scalar start, scalar end, int options_kind,
|
|||
)
|
||||
}
|
||||
|
||||
void atg_arange_start_step(tensor *out__, scalar start, scalar end, int options_kind, int options_device) {
|
||||
void atg_arange_start_step(tensor *out__, scalar start, scalar end, scalar step, int options_kind, int options_device) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::arange(*start, *end, at::device(device_of_int(options_device)).dtype(at::ScalarType(options_kind)));
|
||||
auto outputs__ = torch::arange(*start, *end, *step, at::device(device_of_int(options_device)).dtype(at::ScalarType(options_kind)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -4120,9 +4120,9 @@ void atg_avg_pool3d_out(tensor *out__, tensor out, tensor self, int64_t *kernel_
|
|||
)
|
||||
}
|
||||
|
||||
void atg_baddbmm(tensor *out__, tensor self, tensor batch1, tensor batch2) {
|
||||
void atg_baddbmm(tensor *out__, tensor self, tensor batch1, tensor batch2, scalar beta, scalar alpha) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::baddbmm(*self, *batch1, *batch2);
|
||||
auto outputs__ = torch::baddbmm(*self, *batch1, *batch2, *beta, *alpha);
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -5910,14 +5910,14 @@ void atg_cummin_out(tensor *out__, tensor values, tensor indices, tensor self, i
|
|||
|
||||
void atg_cumprod(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::cumprod(*self, dim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::cumprod(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_cumprod_(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = self->cumprod_(dim, at::ScalarType(dtype));
|
||||
auto outputs__ = self->cumprod_(dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -5931,28 +5931,28 @@ void atg_cumprod_backward(tensor *out__, tensor grad, tensor input, int64_t dim,
|
|||
|
||||
void atg_cumprod_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::cumprod_out(*out, *self, dim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::cumprod_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_cumsum(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::cumsum(*self, dim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::cumsum(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_cumsum_(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = self->cumsum_(dim, at::ScalarType(dtype));
|
||||
auto outputs__ = self->cumsum_(dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_cumsum_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::cumsum_out(*out, *self, dim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::cumsum_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -9912,28 +9912,28 @@ void atg_linalg_multi_dot_out(tensor *out__, tensor out, tensor *tensors_data, i
|
|||
|
||||
void atg_linalg_norm(tensor *out__, tensor self, scalar ord, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::linalg_norm(*self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::linalg_norm(*self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_linalg_norm_ord_str(tensor *out__, tensor self, char* ord_ptr, int ord_len, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::linalg_norm(*self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::linalg_norm(*self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_linalg_norm_ord_str_out(tensor *out__, tensor out, tensor self, char* ord_ptr, int ord_len, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::linalg_norm_out(*out, *self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::linalg_norm_out(*out, *self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_linalg_norm_out(tensor *out__, tensor out, tensor self, scalar ord, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::linalg_norm_out(*out, *self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::linalg_norm_out(*out, *self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -10314,14 +10314,14 @@ void atg_log_sigmoid_out(tensor *out__, tensor out, tensor self) {
|
|||
|
||||
void atg_log_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::log_softmax(*self, dim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::log_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_log_softmax_int_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::log_softmax_out(*out, *self, dim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::log_softmax_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -10953,21 +10953,21 @@ void atg_maximum_out(tensor *out__, tensor out, tensor self, tensor other) {
|
|||
|
||||
void atg_mean(tensor *out__, tensor self, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::mean(*self, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::mean(*self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_mean_dim(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::mean(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::mean(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_mean_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::mean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::mean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -11742,14 +11742,14 @@ void atg_nan_to_num_out(tensor *out__, tensor out, tensor self, double nan_v, ui
|
|||
|
||||
void atg_nanmean(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::nanmean(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::nanmean(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_nanmean_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::nanmean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::nanmean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -11814,14 +11814,14 @@ void atg_nanquantile_scalar_out(tensor *out__, tensor out, tensor self, double q
|
|||
|
||||
void atg_nansum(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::nansum(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::nansum(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_nansum_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::nansum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::nansum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -11961,14 +11961,14 @@ void atg_native_norm_out(tensor *out__, tensor out, tensor self) {
|
|||
|
||||
void atg_native_norm_scalaropt_dim_dtype(tensor *out__, tensor self, scalar p, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::native_norm(*self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::native_norm(*self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_native_norm_scalaropt_dim_dtype_out(tensor *out__, tensor out, tensor self, scalar p, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::native_norm_out(*out, *self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::native_norm_out(*out, *self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -12702,28 +12702,28 @@ void atg_prelu(tensor *out__, tensor self, tensor weight) {
|
|||
|
||||
void atg_prod(tensor *out__, tensor self, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::prod(*self, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::prod(*self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_prod_dim_int(tensor *out__, tensor self, int64_t dim, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::prod(*self, dim, (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::prod(*self, dim, (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_prod_int_out(tensor *out__, tensor out, tensor self, int64_t dim, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::prod_out(*out, *self, dim, (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::prod_out(*out, *self, dim, (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_prod_out(tensor *out__, tensor out, tensor self, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::prod_out(*out, *self, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::prod_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -14593,14 +14593,14 @@ void atg_soft_margin_loss_out(tensor *out__, tensor out, tensor self, tensor tar
|
|||
|
||||
void atg_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::softmax(*self, dim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_softmax_int_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::softmax_out(*out, *self, dim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::softmax_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -15528,7 +15528,7 @@ void atg_special_log_ndtr_out(tensor *out__, tensor out, tensor self) {
|
|||
|
||||
void atg_special_log_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::special_log_softmax(*self, dim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::special_log_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -15913,7 +15913,7 @@ void atg_special_sinc_out(tensor *out__, tensor out, tensor self) {
|
|||
|
||||
void atg_special_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::special_softmax(*self, dim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::special_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -16437,28 +16437,28 @@ void atg_subtract_scalar_(tensor *out__, tensor self, scalar other) {
|
|||
|
||||
void atg_sum(tensor *out__, tensor self, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::sum(*self, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::sum(*self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_sum_dim_intlist(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::sum(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::sum(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_sum_intlist_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::sum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::sum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
||||
void atg_sum_out(tensor *out__, tensor out, tensor self, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::sum_out(*out, *self, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::sum_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -16732,7 +16732,7 @@ void atg_to(tensor *out__, tensor self, int device) {
|
|||
|
||||
void atg_to_dense(tensor *out__, tensor self, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = self->to_dense(at::ScalarType(dtype));
|
||||
auto outputs__ = self->to_dense(dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -16767,7 +16767,7 @@ void atg_to_dtype_layout(tensor *out__, tensor self, int options_kind, int optio
|
|||
|
||||
void atg_to_mkldnn(tensor *out__, tensor self, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = self->to_mkldnn(at::ScalarType(dtype));
|
||||
auto outputs__ = self->to_mkldnn(dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
@ -16781,7 +16781,7 @@ void atg_to_mkldnn_backward(tensor *out__, tensor grad, tensor input) {
|
|||
|
||||
void atg_to_mkldnn_out(tensor *out__, tensor out, tensor self, int dtype) {
|
||||
PROTECT(
|
||||
auto outputs__ = torch::to_mkldnn_out(*out, *self, at::ScalarType(dtype));
|
||||
auto outputs__ = torch::to_mkldnn_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||
out__[0] = new torch::Tensor(outputs__);
|
||||
)
|
||||
}
|
||||
|
|
|
@ -497,7 +497,7 @@ void atg_any_dim(tensor *, tensor self, int64_t dim, int keepdim);
|
|||
void atg_any_out(tensor *, tensor out, tensor self, int64_t dim, int keepdim);
|
||||
void atg_arange(tensor *, scalar end, int options_kind, int options_device);
|
||||
void atg_arange_start(tensor *, scalar start, scalar end, int options_kind, int options_device);
|
||||
void atg_arange_start_step(tensor *, scalar start, scalar end, int options_kind, int options_device);
|
||||
void atg_arange_start_step(tensor *, scalar start, scalar end, scalar step, int options_kind, int options_device);
|
||||
void atg_arccos(tensor *, tensor self);
|
||||
void atg_arccos_(tensor *, tensor self);
|
||||
void atg_arccos_out(tensor *, tensor out, tensor self);
|
||||
|
@ -563,7 +563,7 @@ void atg_avg_pool3d(tensor *, tensor self, int64_t *kernel_size_data, int kernel
|
|||
void atg_avg_pool3d_backward(tensor *, tensor grad_output, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override_v, uint8_t divisor_override_null);
|
||||
void atg_avg_pool3d_backward_grad_input(tensor *, tensor grad_input, tensor grad_output, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override_v, uint8_t divisor_override_null);
|
||||
void atg_avg_pool3d_out(tensor *, tensor out, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override_v, uint8_t divisor_override_null);
|
||||
void atg_baddbmm(tensor *, tensor self, tensor batch1, tensor batch2);
|
||||
void atg_baddbmm(tensor *, tensor self, tensor batch1, tensor batch2, scalar beta, scalar alpha);
|
||||
void atg_baddbmm_(tensor *, tensor self, tensor batch1, tensor batch2);
|
||||
void atg_baddbmm_out(tensor *, tensor out, tensor self, tensor batch1, tensor batch2);
|
||||
void atg_bartlett_window(tensor *, int64_t window_length, int options_kind, int options_device);
|
||||
|
|
|
@ -3961,9 +3961,9 @@ func MustArangeStart(start *Scalar, end *Scalar, optionsKind gotch.DType, option
|
|||
return retVal
|
||||
}
|
||||
|
||||
func MustArangeStartStep(start *Scalar, end *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor) {
|
||||
func MustArangeStartStep(start *Scalar, end *Scalar, step *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor) {
|
||||
|
||||
retVal, err := ArangeStartStep(start, end, optionsKind, optionsDevice)
|
||||
retVal, err := ArangeStartStep(start, end, step, optionsKind, optionsDevice)
|
||||
if err != nil { log.Fatal(err) }
|
||||
|
||||
return retVal
|
||||
|
@ -4465,9 +4465,9 @@ func(ts *Tensor) MustAvgPool3dOut(out *Tensor, kernelSize []int64, stride []int6
|
|||
return retVal
|
||||
}
|
||||
|
||||
func(ts *Tensor) MustBaddbmm(batch1 *Tensor, batch2 *Tensor, del bool)(retVal *Tensor) {
|
||||
func(ts *Tensor) MustBaddbmm(batch1 *Tensor, batch2 *Tensor, beta *Scalar, alpha *Scalar, del bool)(retVal *Tensor) {
|
||||
|
||||
retVal, err := ts.Baddbmm(batch1, batch2, del)
|
||||
retVal, err := ts.Baddbmm(batch1, batch2, beta, alpha, del)
|
||||
if err != nil { log.Fatal(err) }
|
||||
|
||||
return retVal
|
||||
|
|
|
@ -9327,10 +9327,10 @@ func ArangeStart(start *Scalar, end *Scalar, optionsKind gotch.DType, optionsDev
|
|||
// func.returns = `fixed 1`:
|
||||
// --------------------------
|
||||
|
||||
func ArangeStartStep(start *Scalar, end *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor, err error) {
|
||||
func ArangeStartStep(start *Scalar, end *Scalar, step *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgArangeStartStep(ptr, start.cscalar, end.cscalar, optionsKind.CInt(), optionsDevice.CInt())
|
||||
lib.AtgArangeStartStep(ptr, start.cscalar, end.cscalar, step.cscalar, optionsKind.CInt(), optionsDevice.CInt())
|
||||
if err = TorchErr(); err != nil {
|
||||
err = fmt.Errorf("ArangeStartStep() failed: %w", err)
|
||||
return retVal, err
|
||||
|
@ -10585,11 +10585,11 @@ var cdivisorOverrideVal int64 = 0
|
|||
// func.returns = `fixed 1`:
|
||||
// --------------------------
|
||||
|
||||
func(ts *Tensor) Baddbmm(batch1 *Tensor, batch2 *Tensor, del bool)(retVal *Tensor, err error) {
|
||||
func(ts *Tensor) Baddbmm(batch1 *Tensor, batch2 *Tensor, beta *Scalar, alpha *Scalar, del bool)(retVal *Tensor, err error) {
|
||||
if del { defer ts.MustDrop() }
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgBaddbmm(ptr, ts.ctensor, batch1.ctensor, batch2.ctensor)
|
||||
lib.AtgBaddbmm(ptr, ts.ctensor, batch1.ctensor, batch2.ctensor, beta.cscalar, alpha.cscalar)
|
||||
if err = TorchErr(); err != nil {
|
||||
err = fmt.Errorf("Baddbmm() failed: %w", err)
|
||||
return retVal, err
|
||||
|
|
|
@ -255,9 +255,9 @@ func rgb2Gray(x *ts.Tensor, outChanOpt ...int64) *ts.Tensor {
|
|||
}
|
||||
|
||||
rgbTs := x.MustUnbind(-3, false)
|
||||
r := &rgbTs[0]
|
||||
g := &rgbTs[1]
|
||||
b := &rgbTs[2]
|
||||
r := rgbTs[0]
|
||||
g := rgbTs[1]
|
||||
b := rgbTs[2]
|
||||
|
||||
// This implementation closely follows the TF one:
|
||||
// https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
|
||||
|
@ -453,7 +453,7 @@ func hsv2RGB(x *ts.Tensor) *ts.Tensor {
|
|||
a3 := ts.MustStack([]*ts.Tensor{p, p, t, v, v, q}, -3)
|
||||
a4 := ts.MustStack([]*ts.Tensor{a1, a2, a3}, -4)
|
||||
|
||||
out := ts.MustEinsum("...ijk, ...xijk -> ...xjk", []*ts.Tensor{mask, a4})
|
||||
out := ts.MustEinsum("...ijk, ...xijk -> ...xjk", []*ts.Tensor{mask, a4}, []int64{0, 1})
|
||||
|
||||
// Delete intermediate tensors
|
||||
h.MustDrop()
|
||||
|
@ -579,7 +579,7 @@ func crop(x *ts.Tensor, top, left, height, width int64) *ts.Tensor {
|
|||
x2 := x1T.Idx(wNar)
|
||||
x1T.MustDrop()
|
||||
out := x2.MustT(true)
|
||||
chans[i] = *out
|
||||
chans[i] = out
|
||||
}
|
||||
|
||||
cropTs := ts.MustStack(chans, 0)
|
||||
|
|
Loading…
Reference in New Issue
Block a user