fixed incorrect APIs generation
This commit is contained in:
parent
bdf252d831
commit
e9278816b2
647
gen/gen.ml
647
gen/gen.ml
|
@ -82,6 +82,14 @@ let no_tensor_options =
|
||||||
; "randint_like"
|
; "randint_like"
|
||||||
; "randn_like" ]
|
; "randn_like" ]
|
||||||
|
|
||||||
|
(* By default, scalar argument that have a default value are not available on
|
||||||
|
the Rust side, this is to preserve the Rust api simplicity assuming that
|
||||||
|
these scalars arguments are not often overriden.
|
||||||
|
Adding function name [foo] in [with_optional_scalar_args] results in having
|
||||||
|
explicit scalar arguments even if a default is present. *)
|
||||||
|
let with_optional_scalar_args = Set.of_list (module String) [ "arange"; "baddbmm" ]
|
||||||
|
|
||||||
|
|
||||||
(*
|
(*
|
||||||
* let prefixed_functions =
|
* let prefixed_functions =
|
||||||
* Set.of_list
|
* Set.of_list
|
||||||
|
@ -133,24 +141,26 @@ module Func = struct
|
||||||
| Double
|
| Double
|
||||||
| DoubleOption
|
| DoubleOption
|
||||||
| Tensor
|
| Tensor
|
||||||
| TensorOption
|
| TensorOption (* Tensor.t option *)
|
||||||
(* Tensor.t option *)
|
|
||||||
| IntList
|
| IntList
|
||||||
| IntListOption
|
| IntListOption
|
||||||
| DoubleList
|
| DoubleList
|
||||||
| TensorOptList
|
| TensorOptList
|
||||||
| TensorList
|
| TensorList
|
||||||
| TensorOptions
|
| TensorOptions (* Tensor kind and device *)
|
||||||
(* Tensor kind and device *)
|
|
||||||
| Scalar
|
| Scalar
|
||||||
| ScalarType
|
| ScalarType
|
||||||
|
| ScalarTypeOption
|
||||||
| Device
|
| Device
|
||||||
| String
|
| String
|
||||||
| Layout
|
| Layout
|
||||||
| LayoutOption
|
| LayoutOption
|
||||||
|
|
||||||
type arg =
|
type arg =
|
||||||
{arg_name: string; arg_type: arg_type; default_value: string option}
|
{ arg_name: string
|
||||||
|
; arg_type: arg_type
|
||||||
|
; default_value: string option
|
||||||
|
}
|
||||||
|
|
||||||
(* `Func` type *)
|
(* `Func` type *)
|
||||||
type t =
|
type t =
|
||||||
|
@ -160,7 +170,8 @@ module Func = struct
|
||||||
; args: arg list (* ; returns: [`fixed of int | `dynamic] *)
|
; args: arg list (* ; returns: [`fixed of int | `dynamic] *)
|
||||||
; returns: [`fixed of int | `dynamic | `bool | `int64_t | `double | `nothing]
|
; returns: [`fixed of int | `dynamic | `bool | `int64_t | `double | `nothing]
|
||||||
; (* number of tensors that are returned *)
|
; (* number of tensors that are returned *)
|
||||||
kind: [`function_ | `method_] }
|
kind: [`function_ | `method_]
|
||||||
|
}
|
||||||
|
|
||||||
let arg_type_of_string str ~is_nullable =
|
let arg_type_of_string str ~is_nullable =
|
||||||
match String.lowercase str with
|
match String.lowercase str with
|
||||||
|
@ -175,108 +186,111 @@ module Func = struct
|
||||||
| "const at::itensorlistref &" | "at::tensorlist" -> Some TensorList
|
| "const at::itensorlistref &" | "at::tensorlist" -> Some TensorList
|
||||||
| "at::device" -> Some Device
|
| "at::device" -> Some Device
|
||||||
| "const at::scalar &" | "at::scalar" -> Some Scalar
|
| "const at::scalar &" | "at::scalar" -> Some Scalar
|
||||||
| "at::scalartype" -> Some ScalarType
|
| "at::scalartype" -> if is_nullable then Some ScalarTypeOption else Some ScalarType
|
||||||
| "c10::string_view" -> Some String
|
| "c10::string_view" -> Some String
|
||||||
| "at::layout" -> Some (if is_nullable then LayoutOption else Layout)
|
| "at::layout" -> Some (if is_nullable then LayoutOption else Layout)
|
||||||
| _ -> None
|
| _ -> None
|
||||||
|
|
||||||
|
|
||||||
let c_typed_args_list t =
|
let c_typed_args_list t =
|
||||||
List.map t.args ~f:(fun { arg_name; arg_type; _ } ->
|
List.map t.args ~f:(fun { arg_name; arg_type; _ } ->
|
||||||
match arg_type with
|
match arg_type with
|
||||||
| IntList | IntListOption ->
|
| IntList | IntListOption ->
|
||||||
Printf.sprintf "int64_t *%s_data, int %s_len" arg_name arg_name
|
Printf.sprintf "int64_t *%s_data, int %s_len" arg_name arg_name
|
||||||
| DoubleList -> Printf.sprintf "double *%s_data, int %s_len" arg_name arg_name
|
| DoubleList -> Printf.sprintf "double *%s_data, int %s_len" arg_name arg_name
|
||||||
| TensorOptList | TensorList ->
|
| TensorOptList | TensorList ->
|
||||||
Printf.sprintf "tensor *%s_data, int %s_len" arg_name arg_name
|
Printf.sprintf "tensor *%s_data, int %s_len" arg_name arg_name
|
||||||
| TensorOptions -> Printf.sprintf "int %s_kind, int %s_device" arg_name arg_name
|
| TensorOptions -> Printf.sprintf "int %s_kind, int %s_device" arg_name arg_name
|
||||||
| String -> Printf.sprintf "char* %s_ptr, int %s_len" arg_name arg_name
|
| String -> Printf.sprintf "char* %s_ptr, int %s_len" arg_name arg_name
|
||||||
| Int64Option -> Printf.sprintf "int64_t %s_v, uint8_t %s_null" arg_name arg_name
|
| Int64Option -> Printf.sprintf "int64_t %s_v, uint8_t %s_null" arg_name arg_name
|
||||||
| DoubleOption -> Printf.sprintf "double %s_v, uint8_t %s_null" arg_name arg_name
|
| DoubleOption -> Printf.sprintf "double %s_v, uint8_t %s_null" arg_name arg_name
|
||||||
| otherwise ->
|
| otherwise ->
|
||||||
let simple_type_cstring =
|
let simple_type_cstring =
|
||||||
match otherwise with
|
match otherwise with
|
||||||
| Bool -> "int"
|
| Bool -> "int"
|
||||||
| Int64 -> "int64_t"
|
| Int64 -> "int64_t"
|
||||||
| Double -> "double"
|
| Double -> "double"
|
||||||
| Tensor -> "tensor"
|
| Tensor -> "tensor"
|
||||||
| TensorOption -> "tensor"
|
| TensorOption -> "tensor"
|
||||||
| ScalarType -> "int"
|
| ScalarType -> "int"
|
||||||
| Device -> "int"
|
| ScalarTypeOption -> "int"
|
||||||
| Scalar -> "scalar"
|
| Device -> "int"
|
||||||
| Layout | LayoutOption -> "int8_t"
|
| Scalar -> "scalar"
|
||||||
| Int64Option
|
| Layout | LayoutOption -> "int8_t"
|
||||||
| DoubleOption
|
| Int64Option
|
||||||
| String
|
| DoubleOption
|
||||||
| IntList
|
| String
|
||||||
| IntListOption
|
| IntList
|
||||||
| DoubleList
|
| IntListOption
|
||||||
| TensorOptList
|
| DoubleList
|
||||||
| TensorList
|
| TensorOptList
|
||||||
| TensorOptions -> assert false
|
| TensorList
|
||||||
in
|
| TensorOptions -> assert false
|
||||||
Printf.sprintf "%s %s" simple_type_cstring arg_name)
|
in
|
||||||
|
Printf.sprintf "%s %s" simple_type_cstring arg_name)
|
||||||
|> String.concat ~sep:", "
|
|> String.concat ~sep:", "
|
||||||
|
|
||||||
let c_args_list args =
|
let c_args_list args =
|
||||||
List.map args ~f:(fun { arg_name; arg_type; _ } ->
|
List.map args ~f:(fun { arg_name; arg_type; _ } ->
|
||||||
match arg_type with
|
match arg_type with
|
||||||
| Scalar | Tensor -> "*" ^ arg_name
|
| Scalar | Tensor -> "*" ^ arg_name
|
||||||
| Layout -> Printf.sprintf "static_cast<at::Layout>(%s)" arg_name
|
| Layout -> Printf.sprintf "static_cast<at::Layout>(%s)" arg_name
|
||||||
| LayoutOption ->
|
| LayoutOption ->
|
||||||
Printf.sprintf
|
Printf.sprintf
|
||||||
"(%s == -1 ? c10::nullopt : \
|
"(%s == -1 ? c10::nullopt : \
|
||||||
c10::optional<at::Layout>(static_cast<at::Layout>(%s)))"
|
c10::optional<at::Layout>(static_cast<at::Layout>(%s)))"
|
||||||
arg_name
|
arg_name
|
||||||
arg_name
|
arg_name
|
||||||
| TensorOption -> Printf.sprintf "(%s ? *%s : torch::Tensor())" arg_name arg_name
|
| TensorOption -> Printf.sprintf "(%s ? *%s : torch::Tensor())" arg_name arg_name
|
||||||
| Bool -> "(bool)" ^ arg_name
|
| Bool -> "(bool)" ^ arg_name
|
||||||
| IntList ->
|
| IntList -> Printf.sprintf "torch::IntArrayRef(%s_data, %s_len)" arg_name arg_name
|
||||||
Printf.sprintf "torch::IntArrayRef(%s_data, %s_len)" arg_name arg_name
|
| IntListOption ->
|
||||||
| IntListOption ->
|
Printf.sprintf
|
||||||
Printf.sprintf
|
"%s_data == nullptr ? c10::nullopt : \
|
||||||
"%s_data == nullptr ? c10::nullopt : \
|
c10::optional<torch::IntArrayRef>(torch::IntArrayRef(%s_data, %s_len))"
|
||||||
c10::optional<torch::IntArrayRef>(torch::IntArrayRef(%s_data, %s_len))"
|
arg_name
|
||||||
arg_name
|
arg_name
|
||||||
arg_name
|
arg_name
|
||||||
arg_name
|
| DoubleList ->
|
||||||
| DoubleList ->
|
Printf.sprintf "at::ArrayRef<double>(%s_data, %s_len)" arg_name arg_name
|
||||||
Printf.sprintf "at::ArrayRef<double>(%s_data, %s_len)" arg_name arg_name
|
| String -> Printf.sprintf "std::string(%s_ptr, %s_len)" arg_name arg_name
|
||||||
| String -> Printf.sprintf "std::string(%s_ptr, %s_len)" arg_name arg_name
|
| TensorOptList ->
|
||||||
| TensorOptList ->
|
Printf.sprintf "of_carray_tensor_opt(%s_data, %s_len)" arg_name arg_name
|
||||||
Printf.sprintf "of_carray_tensor_opt(%s_data, %s_len)" arg_name arg_name
|
| TensorList -> Printf.sprintf "of_carray_tensor(%s_data, %s_len)" arg_name arg_name
|
||||||
| TensorList ->
|
| TensorOptions ->
|
||||||
Printf.sprintf "of_carray_tensor(%s_data, %s_len)" arg_name arg_name
|
Printf.sprintf
|
||||||
| TensorOptions ->
|
"at::device(device_of_int(%s_device)).dtype(at::ScalarType(%s_kind))"
|
||||||
Printf.sprintf
|
arg_name
|
||||||
"at::device(device_of_int(%s_device)).dtype(at::ScalarType(%s_kind))"
|
arg_name
|
||||||
arg_name
|
| Int64Option ->
|
||||||
arg_name
|
Printf.sprintf
|
||||||
| Int64Option ->
|
"%s_null ? c10::nullopt : c10::optional<int64_t>(%s_v)"
|
||||||
Printf.sprintf
|
arg_name
|
||||||
"%s_null ? c10::nullopt : c10::optional<int64_t>(%s_v)"
|
arg_name
|
||||||
arg_name
|
| DoubleOption ->
|
||||||
arg_name
|
Printf.sprintf
|
||||||
| DoubleOption ->
|
"%s_null ? c10::nullopt : c10::optional<double>(%s_v)"
|
||||||
Printf.sprintf
|
arg_name
|
||||||
"%s_null ? c10::nullopt : c10::optional<double>(%s_v)"
|
arg_name
|
||||||
arg_name
|
| ScalarType -> Printf.sprintf "at::ScalarType(%s)" arg_name
|
||||||
arg_name
|
| ScalarTypeOption ->
|
||||||
| ScalarType -> Printf.sprintf "at::ScalarType(%s)" arg_name
|
Printf.sprintf
|
||||||
| Device -> Printf.sprintf "device_of_int(%s)" arg_name
|
"%s < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(%s))"
|
||||||
| _ -> arg_name)
|
arg_name
|
||||||
|
arg_name
|
||||||
|
| Device -> Printf.sprintf "device_of_int(%s)" arg_name
|
||||||
|
| _ -> arg_name)
|
||||||
|> String.concat ~sep:", "
|
|> String.concat ~sep:", "
|
||||||
|
|
||||||
|
|
||||||
let c_call t =
|
let c_call t =
|
||||||
match t.kind with
|
match t.kind with
|
||||||
| `function_ -> Printf.sprintf "torch::%s(%s)" t.name (c_args_list t.args)
|
| `function_ -> Printf.sprintf "torch::%s(%s)" t.name (c_args_list t.args)
|
||||||
| `method_ -> (
|
| `method_ ->
|
||||||
match t.args with
|
(match t.args with
|
||||||
| head :: tail ->
|
| head :: tail ->
|
||||||
Printf.sprintf "%s->%s(%s)" head.arg_name t.name (c_args_list tail)
|
Printf.sprintf "%s->%s(%s)" head.arg_name t.name (c_args_list tail)
|
||||||
| [] ->
|
| [] ->
|
||||||
Printf.failwithf "Method calls should have at least one argument %s"
|
Printf.failwithf "Method calls should have at least one argument %s" t.name ())
|
||||||
t.name () )
|
|
||||||
|
|
||||||
(*
|
(*
|
||||||
let replace_map =
|
let replace_map =
|
||||||
|
@ -289,6 +303,15 @@ module Func = struct
|
||||||
; ("to_device", "to_device_") ]
|
; ("to_device", "to_device_") ]
|
||||||
*)
|
*)
|
||||||
|
|
||||||
|
let operator_name t =
|
||||||
|
match String.lowercase t.operator_name with
|
||||||
|
| "scatter_reduce" ->
|
||||||
|
(* scatter_reduce is both an operator name and also obtained from the
|
||||||
|
scatter operator when using the reduce overload. *)
|
||||||
|
"_scatter_reduce"
|
||||||
|
| "scatter_reduce_" -> "_scatter_reduce_"
|
||||||
|
| other -> other
|
||||||
|
|
||||||
let is_method t =
|
let is_method t =
|
||||||
List.exists t.args ~f:(fun arg ->
|
List.exists t.args ~f:(fun arg ->
|
||||||
match arg.arg_name with "self" -> true | _ -> false )
|
match arg.arg_name with "self" -> true | _ -> false )
|
||||||
|
@ -321,18 +344,16 @@ module Func = struct
|
||||||
let single_param = Printf.sprintf "%s %s" an in
|
let single_param = Printf.sprintf "%s %s" an in
|
||||||
match arg.arg_type with
|
match arg.arg_type with
|
||||||
| Bool -> single_param "int32"
|
| Bool -> single_param "int32"
|
||||||
| Layout -> single_param "int8"
|
| Layout | LayoutOption -> single_param "int8"
|
||||||
| LayoutOption -> single_param "int8"
|
|
||||||
| Int64 -> single_param "int64"
|
| Int64 -> single_param "int64"
|
||||||
| Double -> single_param "float64"
|
| Double -> single_param "float64"
|
||||||
| Tensor -> single_param "Ctensor"
|
| Tensor -> single_param "Ctensor"
|
||||||
| TensorOption -> single_param "Ctensor"
|
| TensorOption -> single_param "Ctensor"
|
||||||
| Scalar -> single_param "Cscalar"
|
| Scalar -> single_param "Cscalar"
|
||||||
| ScalarType -> single_param "int32"
|
| ScalarType | ScalarTypeOption -> single_param "int32"
|
||||||
| Device -> single_param "int32"
|
| Device -> single_param "int32"
|
||||||
| String -> single_param "string"
|
| String -> single_param "string"
|
||||||
| IntList -> Printf.sprintf "%sData []int64, %sLen int" an an
|
| IntList | IntListOption -> Printf.sprintf "%sData []int64, %sLen int" an an
|
||||||
| IntListOption -> Printf.sprintf "%sData []int64, %sLen int" an an
|
|
||||||
| DoubleList -> Printf.sprintf "%sData []float64, %sLen int" an an
|
| DoubleList -> Printf.sprintf "%sData []float64, %sLen int" an an
|
||||||
| TensorOptList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an
|
| TensorOptList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an
|
||||||
| TensorList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an
|
| TensorList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an
|
||||||
|
@ -353,14 +374,12 @@ module Func = struct
|
||||||
| Double -> Printf.sprintf "c%s" an
|
| Double -> Printf.sprintf "c%s" an
|
||||||
| Tensor -> Printf.sprintf "%s" an
|
| Tensor -> Printf.sprintf "%s" an
|
||||||
| TensorOption -> Printf.sprintf "%s" an
|
| TensorOption -> Printf.sprintf "%s" an
|
||||||
| Layout -> Printf.sprintf "c%s" an
|
| Layout | LayoutOption -> Printf.sprintf "c%s" an
|
||||||
| LayoutOption -> Printf.sprintf "c%s" an
|
|
||||||
| Scalar -> single_param ""
|
| Scalar -> single_param ""
|
||||||
| ScalarType -> Printf.sprintf "c%s" an
|
| ScalarType | ScalarTypeOption -> Printf.sprintf "c%s" an
|
||||||
| Device -> Printf.sprintf "c%s" an
|
| Device -> Printf.sprintf "c%s" an
|
||||||
| String -> Printf.sprintf "c%s, c%sLen" an an
|
| String -> Printf.sprintf "c%s, c%sLen" an an
|
||||||
| IntList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
| IntList | IntListOption -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
||||||
| IntListOption -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
|
||||||
| DoubleList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
| DoubleList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
||||||
| TensorOptList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
| TensorOptList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
||||||
| TensorList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
| TensorList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
||||||
|
@ -383,12 +402,10 @@ module Func = struct
|
||||||
Printf.sprintf "\nc%s := *(*C.double)(unsafe.Pointer(&%s))" an an
|
Printf.sprintf "\nc%s := *(*C.double)(unsafe.Pointer(&%s))" an an
|
||||||
| Tensor -> ""
|
| Tensor -> ""
|
||||||
| TensorOption -> ""
|
| TensorOption -> ""
|
||||||
| Layout ->
|
| Layout | LayoutOption ->
|
||||||
Printf.sprintf "\nc%s := *(*C.int8_t)(unsafe.Pointer(&%s))" an an
|
|
||||||
| LayoutOption ->
|
|
||||||
Printf.sprintf "\nc%s := *(*C.int8_t)(unsafe.Pointer(&%s))" an an
|
Printf.sprintf "\nc%s := *(*C.int8_t)(unsafe.Pointer(&%s))" an an
|
||||||
| Scalar -> ""
|
| Scalar -> ""
|
||||||
| ScalarType ->
|
| ScalarType | ScalarTypeOption ->
|
||||||
Printf.sprintf "\nc%s := *(*C.int)(unsafe.Pointer(&%s))" an an
|
Printf.sprintf "\nc%s := *(*C.int)(unsafe.Pointer(&%s))" an an
|
||||||
| Device ->
|
| Device ->
|
||||||
Printf.sprintf "\nc%s := *(*C.int)(unsafe.Pointer(&%s))" an an
|
Printf.sprintf "\nc%s := *(*C.int)(unsafe.Pointer(&%s))" an an
|
||||||
|
@ -399,13 +416,7 @@ module Func = struct
|
||||||
%sLen := len(%s)\n\
|
%sLen := len(%s)\n\
|
||||||
c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))"
|
c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))"
|
||||||
an an an an an an
|
an an an an an an
|
||||||
| IntList ->
|
| IntList | IntListOption ->
|
||||||
Printf.sprintf
|
|
||||||
"\n\
|
|
||||||
c%sDataPtr := (*C.int64_t)(unsafe.Pointer(&%sData[0]))\n\
|
|
||||||
c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))"
|
|
||||||
an an an an
|
|
||||||
| IntListOption ->
|
|
||||||
Printf.sprintf
|
Printf.sprintf
|
||||||
"\n\
|
"\n\
|
||||||
c%sDataPtr := (*C.int64_t)(unsafe.Pointer(&%sData[0]))\n\
|
c%sDataPtr := (*C.int64_t)(unsafe.Pointer(&%sData[0]))\n\
|
||||||
|
@ -494,14 +505,12 @@ module Func = struct
|
||||||
let go_arg_type =
|
let go_arg_type =
|
||||||
match arg.arg_type with
|
match arg.arg_type with
|
||||||
| Bool -> "bool"
|
| Bool -> "bool"
|
||||||
| Layout -> "Layout"
|
| Layout | LayoutOption -> "Layout"
|
||||||
| LayoutOption -> "Layout"
|
|
||||||
| Int64 -> "int64"
|
| Int64 -> "int64"
|
||||||
| Double -> "float64"
|
| Double -> "float64"
|
||||||
| Tensor -> "*Tensor"
|
| Tensor -> "*Tensor"
|
||||||
| TensorOption -> "*Tensor"
|
| TensorOption -> "*Tensor"
|
||||||
| IntList -> "[]int64"
|
| IntList | IntListOption -> "[]int64"
|
||||||
| IntListOption -> "[]int64"
|
|
||||||
| DoubleList -> "[]float64"
|
| DoubleList -> "[]float64"
|
||||||
| TensorOptList -> "[]*Tensor"
|
| TensorOptList -> "[]*Tensor"
|
||||||
| TensorList -> "[]*Tensor"
|
| TensorList -> "[]*Tensor"
|
||||||
|
@ -510,7 +519,7 @@ module Func = struct
|
||||||
(* E.g. `type KindDevice struct{}` *)
|
(* E.g. `type KindDevice struct{}` *)
|
||||||
| TensorOptions -> "gotch.KindDevice"
|
| TensorOptions -> "gotch.KindDevice"
|
||||||
| Scalar -> "*Scalar"
|
| Scalar -> "*Scalar"
|
||||||
| ScalarType -> "gotch.DType"
|
| ScalarType | ScalarTypeOption -> "gotch.DType"
|
||||||
| Int64Option -> "[]int64"
|
| Int64Option -> "[]int64"
|
||||||
| DoubleOption -> "[]float64"
|
| DoubleOption -> "[]float64"
|
||||||
| Device -> "gotch.Device"
|
| Device -> "gotch.Device"
|
||||||
|
@ -603,7 +612,7 @@ module Func = struct
|
||||||
else Printf.sprintf "%s.ctensor" name
|
else Printf.sprintf "%s.ctensor" name
|
||||||
| Scalar -> Printf.sprintf "%s.cscalar" name
|
| Scalar -> Printf.sprintf "%s.cscalar" name
|
||||||
| Bool -> Printf.sprintf "c%s" name
|
| Bool -> Printf.sprintf "c%s" name
|
||||||
| ScalarType -> Printf.sprintf "%s.CInt()" name
|
| ScalarType | ScalarTypeOption -> Printf.sprintf "%s.CInt()" name
|
||||||
| Device -> Printf.sprintf "%s.CInt()" name
|
| Device -> Printf.sprintf "%s.CInt()" name
|
||||||
| TensorOptions ->
|
| TensorOptions ->
|
||||||
Printf.sprintf "%sKind.CInt(), %sDevice.CInt()" name name
|
Printf.sprintf "%sKind.CInt(), %sDevice.CInt()" name name
|
||||||
|
@ -633,11 +642,10 @@ module Func = struct
|
||||||
| Tensor -> ""
|
| Tensor -> ""
|
||||||
| TensorOption -> ""
|
| TensorOption -> ""
|
||||||
| Scalar -> ""
|
| Scalar -> ""
|
||||||
| ScalarType -> ""
|
| ScalarType | ScalarTypeOption -> ""
|
||||||
| Device -> ""
|
| Device -> ""
|
||||||
| String -> ""
|
| String -> ""
|
||||||
| IntList -> Printf.sprintf "%sLen := len(%s)\n" an an
|
| IntList | IntListOption -> Printf.sprintf "%sLen := len(%s)\n" an an
|
||||||
| IntListOption -> Printf.sprintf "%sLen := len(%s)\n" an an
|
|
||||||
| DoubleList -> Printf.sprintf "%sLen := len(%s)\n" an an
|
| DoubleList -> Printf.sprintf "%sLen := len(%s)\n" an an
|
||||||
| Int64Option ->
|
| Int64Option ->
|
||||||
Printf.sprintf
|
Printf.sprintf
|
||||||
|
@ -667,8 +675,7 @@ module Func = struct
|
||||||
"var c%s []lib.Ctensor\n\
|
"var c%s []lib.Ctensor\n\
|
||||||
\ for _, t := range %s {c%s = append(c%s, t.ctensor)}\n"
|
\ for _, t := range %s {c%s = append(c%s, t.ctensor)}\n"
|
||||||
an an an an
|
an an an an
|
||||||
| Layout -> ""
|
| Layout | LayoutOption -> ""
|
||||||
| LayoutOption -> ""
|
|
||||||
| TensorOptions -> "" )
|
| TensorOptions -> "" )
|
||||||
|> String.concat ~sep:""
|
|> String.concat ~sep:""
|
||||||
end
|
end
|
||||||
|
@ -679,117 +686,103 @@ let read_yaml filename =
|
||||||
let funcs =
|
let funcs =
|
||||||
(* Split the file to avoid Yaml.of_string_exn segfaulting. *)
|
(* Split the file to avoid Yaml.of_string_exn segfaulting. *)
|
||||||
In_channel.with_file filename ~f:In_channel.input_lines
|
In_channel.with_file filename ~f:In_channel.input_lines
|
||||||
|> List.group ~break:(fun _ l ->
|
|> List.group ~break:(fun _ l -> String.length l > 0 && Char.( = ) l.[0] '-')
|
||||||
String.length l > 0 && Char.( = ) l.[0] '-' )
|
|
||||||
|> List.concat_map ~f:(fun lines ->
|
|> List.concat_map ~f:(fun lines ->
|
||||||
Yaml.of_string_exn (String.concat lines ~sep:"\n") |> extract_list
|
Yaml.of_string_exn (String.concat lines ~sep:"\n") |> extract_list)
|
||||||
)
|
|
||||||
in
|
in
|
||||||
printf "Read %s, got %d functions.\n%!" filename (List.length funcs) ;
|
printf "Read %s, got %d functions.\n%!" filename (List.length funcs);
|
||||||
List.filter_map funcs ~f:(fun yaml ->
|
List.filter_map funcs ~f:(fun yaml ->
|
||||||
let map = extract_map yaml in
|
let map = extract_map yaml in
|
||||||
let name = Map.find_exn map "name" |> extract_string in
|
let name = Map.find_exn map "name" |> extract_string in
|
||||||
let operator_name = Map.find_exn map "operator_name" |> extract_string in
|
let operator_name = Map.find_exn map "operator_name" |> extract_string in
|
||||||
let overload_name = Map.find_exn map "overload_name" |> extract_string in
|
let overload_name = Map.find_exn map "overload_name" |> extract_string in
|
||||||
let deprecated = Map.find_exn map "deprecated" |> extract_bool in
|
let deprecated = Map.find_exn map "deprecated" |> extract_bool in
|
||||||
let method_of =
|
let method_of =
|
||||||
Map.find_exn map "method_of"
|
Map.find_exn map "method_of" |> extract_list |> List.map ~f:extract_string
|
||||||
|> extract_list |> List.map ~f:extract_string
|
in
|
||||||
|
let arguments = Map.find_exn map "arguments" |> extract_list in
|
||||||
|
let returns =
|
||||||
|
let is_tensor returns =
|
||||||
|
let returns = extract_map returns in
|
||||||
|
let return_type = Map.find_exn returns "dynamic_type" |> extract_string in
|
||||||
|
String.( = ) return_type "at::Tensor"
|
||||||
in
|
in
|
||||||
let arguments = Map.find_exn map "arguments" |> extract_list in
|
let returns = Map.find_exn map "returns" |> extract_list in
|
||||||
let returns =
|
if List.is_empty returns
|
||||||
let is_tensor returns =
|
then Some `nothing
|
||||||
let returns = extract_map returns in
|
else if List.for_all returns ~f:is_tensor
|
||||||
|
then Some (`fixed (List.length returns))
|
||||||
|
else (
|
||||||
|
match returns with
|
||||||
|
| [ returns ] ->
|
||||||
let return_type =
|
let return_type =
|
||||||
Map.find_exn returns "dynamic_type" |> extract_string
|
Map.find_exn (extract_map returns) "dynamic_type" |> extract_string
|
||||||
in
|
in
|
||||||
String.( = ) return_type "at::Tensor"
|
(match return_type with
|
||||||
in
|
| "bool" -> Some `bool
|
||||||
let returns = Map.find_exn map "returns" |> extract_list in
|
| "int64_t" -> Some `int64_t
|
||||||
if List.for_all returns ~f:is_tensor then
|
| "double" -> Some `double
|
||||||
Some (`fixed (List.length returns))
|
| "at::TensorList" | "dynamic_type: const c10::List<c10::optional<Tensor>> &"
|
||||||
else
|
-> Some `dynamic
|
||||||
match returns with
|
| _ -> None)
|
||||||
| [returns] -> (
|
| [] | _ :: _ :: _ -> None)
|
||||||
let return_type =
|
in
|
||||||
Map.find_exn (extract_map returns) "dynamic_type"
|
let kind =
|
||||||
|> extract_string
|
if List.exists method_of ~f:(String.( = ) "namespace")
|
||||||
in
|
then Some `function_
|
||||||
match return_type with
|
else if List.exists method_of ~f:(String.( = ) "Tensor")
|
||||||
| "bool" -> Some `bool
|
then Some `method_
|
||||||
| "int64_t" -> Some `int64_t
|
else None
|
||||||
| "double" -> Some `double
|
in
|
||||||
| "at::TensorList"
|
if (not deprecated)
|
||||||
|"dynamic_type: const c10::List<c10::optional<Tensor>> &" ->
|
&& (not
|
||||||
Some `dynamic
|
(List.exists excluded_prefixes ~f:(fun prefix ->
|
||||||
| _ -> None )
|
String.is_prefix name ~prefix)))
|
||||||
| [] | _ :: _ :: _ -> None
|
&& (not
|
||||||
in
|
(List.exists excluded_suffixes ~f:(fun suffix ->
|
||||||
let kind =
|
String.is_suffix name ~suffix)))
|
||||||
if List.exists method_of ~f:(String.( = ) "namespace") then
|
&& not (Set.mem excluded_functions name)
|
||||||
Some `function_
|
then
|
||||||
else if List.exists method_of ~f:(String.( = ) "Tensor") then
|
Option.both returns kind
|
||||||
Some `method_
|
|> Option.bind ~f:(fun (returns, kind) ->
|
||||||
else None
|
try
|
||||||
in
|
let args ~with_optional_scalar_args =
|
||||||
if
|
List.filter_map arguments ~f:(fun arg ->
|
||||||
(not deprecated)
|
let arg = extract_map arg in
|
||||||
&& (not
|
let arg_name = Map.find_exn arg "name" |> extract_string in
|
||||||
(List.exists excluded_prefixes ~f:(fun prefix ->
|
let arg_type = Map.find_exn arg "dynamic_type" |> extract_string in
|
||||||
String.is_prefix name ~prefix )))
|
let is_nullable =
|
||||||
&& (not
|
Map.find arg "is_nullable"
|
||||||
(List.exists excluded_suffixes ~f:(fun suffix ->
|
|> Option.value_map ~default:false ~f:extract_bool
|
||||||
String.is_suffix name ~suffix )))
|
|
||||||
&& not (Set.mem excluded_functions name)
|
|
||||||
then
|
|
||||||
Option.both returns kind
|
|
||||||
|> Option.bind ~f:(fun (returns, kind) ->
|
|
||||||
try
|
|
||||||
let args =
|
|
||||||
List.filter_map arguments ~f:(fun arg ->
|
|
||||||
let arg = extract_map arg in
|
|
||||||
let arg_name =
|
|
||||||
Map.find_exn arg "name" |> extract_string
|
|
||||||
in
|
|
||||||
let arg_type =
|
|
||||||
Map.find_exn arg "dynamic_type" |> extract_string
|
|
||||||
in
|
|
||||||
let is_nullable =
|
|
||||||
Map.find arg "is_nullable"
|
|
||||||
|> Option.value_map ~default:false ~f:extract_bool
|
|
||||||
in
|
|
||||||
let default_value =
|
|
||||||
Map.find arg "default" |> Option.map ~f:extract_string
|
|
||||||
in
|
|
||||||
match Func.arg_type_of_string arg_type ~is_nullable with
|
|
||||||
| Some Scalar
|
|
||||||
when Option.is_some default_value && not is_nullable
|
|
||||||
->
|
|
||||||
None
|
|
||||||
| Some TensorOptions
|
|
||||||
when Option.is_some default_value
|
|
||||||
&& Set.mem no_tensor_options name ->
|
|
||||||
None
|
|
||||||
| Some arg_type ->
|
|
||||||
let arg_name =
|
|
||||||
match (arg_name, arg_type) with
|
|
||||||
| "self", Scalar -> "self_scalar"
|
|
||||||
| _, _ -> arg_name
|
|
||||||
in
|
|
||||||
Some {Func.arg_name; arg_type; default_value}
|
|
||||||
| None ->
|
|
||||||
if Option.is_some default_value then None
|
|
||||||
else raise Not_a_simple_arg )
|
|
||||||
in
|
in
|
||||||
Some
|
let default_value =
|
||||||
{ Func.name
|
Map.find arg "default" |> Option.map ~f:extract_string
|
||||||
; operator_name
|
in
|
||||||
; overload_name
|
match Func.arg_type_of_string arg_type ~is_nullable with
|
||||||
; args
|
| Some Scalar when Option.is_some default_value && not is_nullable ->
|
||||||
; returns
|
if with_optional_scalar_args
|
||||||
; kind }
|
then Some { Func.arg_name; arg_type = Scalar; default_value }
|
||||||
with Not_a_simple_arg -> None )
|
else None
|
||||||
else None )
|
| Some TensorOptions
|
||||||
|
when Option.is_some default_value && Set.mem no_tensor_options name ->
|
||||||
|
None
|
||||||
|
| Some arg_type ->
|
||||||
|
let arg_name =
|
||||||
|
match arg_name, arg_type with
|
||||||
|
| "self", Scalar -> "self_scalar"
|
||||||
|
| _, _ -> arg_name
|
||||||
|
in
|
||||||
|
Some { Func.arg_name; arg_type; default_value }
|
||||||
|
| None ->
|
||||||
|
if Option.is_some default_value then None else raise Not_a_simple_arg)
|
||||||
|
in
|
||||||
|
let args =
|
||||||
|
args ~with_optional_scalar_args:(Set.mem with_optional_scalar_args name)
|
||||||
|
in
|
||||||
|
Some [ { Func.name; operator_name; overload_name; args; returns; kind } ]
|
||||||
|
with
|
||||||
|
| Not_a_simple_arg -> None)
|
||||||
|
else None)
|
||||||
|
|
||||||
let p out_channel s =
|
let p out_channel s =
|
||||||
Printf.ksprintf
|
Printf.ksprintf
|
||||||
|
@ -803,72 +796,71 @@ let print_inline out_channel s =
|
||||||
|
|
||||||
let write_cpp funcs filename =
|
let write_cpp funcs filename =
|
||||||
Out_channel.with_file (filename ^ ".cpp.h") ~f:(fun out_cpp ->
|
Out_channel.with_file (filename ^ ".cpp.h") ~f:(fun out_cpp ->
|
||||||
Out_channel.with_file (filename ^ ".h") ~f:(fun out_h ->
|
Out_channel.with_file (filename ^ ".h") ~f:(fun out_h ->
|
||||||
let pc s = p out_cpp s in
|
let pc s = p out_cpp s in
|
||||||
let ph s = p out_h s in
|
let ph s = p out_h s in
|
||||||
pc "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
|
pc "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
|
||||||
pc "";
|
pc "";
|
||||||
ph "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
|
ph "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!";
|
||||||
ph "";
|
ph "";
|
||||||
Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
|
Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
|
||||||
let c_typed_args_list = Func.c_typed_args_list func in
|
let c_typed_args_list = Func.c_typed_args_list func in
|
||||||
match func.returns with
|
match func.returns with
|
||||||
| `nothing ->
|
| `nothing ->
|
||||||
pc "void atg_%s(%s) {" exported_name c_typed_args_list;
|
pc "void atg_%s(%s) {" exported_name c_typed_args_list;
|
||||||
pc " PROTECT(";
|
pc " PROTECT(";
|
||||||
pc " %s;" (Func.c_call func);
|
pc " %s;" (Func.c_call func);
|
||||||
pc " )";
|
pc " )";
|
||||||
pc "}";
|
pc "}";
|
||||||
pc "";
|
pc "";
|
||||||
ph "void atg_%s(%s);" exported_name c_typed_args_list
|
ph "void atg_%s(%s);" exported_name c_typed_args_list
|
||||||
| `fixed ntensors ->
|
| `fixed ntensors ->
|
||||||
pc "void atg_%s(tensor *out__, %s) {" exported_name c_typed_args_list;
|
pc "void atg_%s(tensor *out__, %s) {" exported_name c_typed_args_list;
|
||||||
pc " PROTECT(";
|
pc " PROTECT(";
|
||||||
pc " auto outputs__ = %s;" (Func.c_call func);
|
pc " auto outputs__ = %s;" (Func.c_call func);
|
||||||
if ntensors = 1
|
if ntensors = 1
|
||||||
then pc " out__[0] = new torch::Tensor(outputs__);"
|
then pc " out__[0] = new torch::Tensor(outputs__);"
|
||||||
else
|
else
|
||||||
for i = 0 to ntensors - 1 do
|
for i = 0 to ntensors - 1 do
|
||||||
pc " out__[%d] = new torch::Tensor(std::get<%d>(outputs__));" i i
|
pc " out__[%d] = new torch::Tensor(std::get<%d>(outputs__));" i i
|
||||||
done;
|
done;
|
||||||
pc " )";
|
pc " )";
|
||||||
pc "}";
|
pc "}";
|
||||||
pc "";
|
pc "";
|
||||||
ph "void atg_%s(tensor *, %s);" exported_name c_typed_args_list
|
ph "void atg_%s(tensor *, %s);" exported_name c_typed_args_list
|
||||||
| `dynamic ->
|
| `dynamic ->
|
||||||
pc "tensor *atg_%s(%s) {" exported_name c_typed_args_list;
|
pc "tensor *atg_%s(%s) {" exported_name c_typed_args_list;
|
||||||
pc " PROTECT(";
|
pc " PROTECT(";
|
||||||
pc " auto outputs__ = %s;" (Func.c_call func);
|
pc " auto outputs__ = %s;" (Func.c_call func);
|
||||||
(* the returned type is a C++ vector of tensors *)
|
(* the returned type is a C++ vector of tensors *)
|
||||||
pc " int sz = outputs__.size();";
|
pc " int sz = outputs__.size();";
|
||||||
pc
|
pc
|
||||||
" torch::Tensor **out__ = (torch::Tensor**)malloc((sz + 1) * \
|
" torch::Tensor **out__ = (torch::Tensor**)malloc((sz + 1) * \
|
||||||
sizeof(torch::Tensor*));";
|
sizeof(torch::Tensor*));";
|
||||||
pc " for (int i = 0; i < sz; ++i)";
|
pc " for (int i = 0; i < sz; ++i)";
|
||||||
pc " out__[i] = new torch::Tensor(outputs__[i]);";
|
pc " out__[i] = new torch::Tensor(outputs__[i]);";
|
||||||
pc " out__[sz] = nullptr;";
|
pc " out__[sz] = nullptr;";
|
||||||
pc " return out__;";
|
pc " return out__;";
|
||||||
pc " )";
|
pc " )";
|
||||||
pc " return nullptr;";
|
pc " return nullptr;";
|
||||||
pc "}";
|
pc "}";
|
||||||
pc "";
|
pc "";
|
||||||
ph "tensor *atg_%s(%s);" exported_name c_typed_args_list
|
ph "tensor *atg_%s(%s);" exported_name c_typed_args_list
|
||||||
| (`bool | `int64_t | `double) as returns ->
|
| (`bool | `int64_t | `double) as returns ->
|
||||||
let c_type =
|
let c_type =
|
||||||
match returns with
|
match returns with
|
||||||
| `bool -> "int"
|
| `bool -> "int"
|
||||||
| `int64_t -> "int64_t"
|
| `int64_t -> "int64_t"
|
||||||
| `double -> "double"
|
| `double -> "double"
|
||||||
in
|
in
|
||||||
pc "%s atg_%s(%s) {" c_type exported_name c_typed_args_list;
|
pc "%s atg_%s(%s) {" c_type exported_name c_typed_args_list;
|
||||||
pc " PROTECT(";
|
pc " PROTECT(";
|
||||||
pc " return %s;" (Func.c_call func);
|
pc " return %s;" (Func.c_call func);
|
||||||
pc " )";
|
pc " )";
|
||||||
pc " return 0;";
|
pc " return 0;";
|
||||||
pc "}";
|
pc "}";
|
||||||
pc "";
|
pc "";
|
||||||
ph "%s atg_%s(%s);" c_type exported_name c_typed_args_list)))
|
ph "%s atg_%s(%s);" c_type exported_name c_typed_args_list)))
|
||||||
|
|
||||||
|
|
||||||
let write_wrapper funcs filename =
|
let write_wrapper funcs filename =
|
||||||
Out_channel.with_file filename ~f:(fun out_ml ->
|
Out_channel.with_file filename ~f:(fun out_ml ->
|
||||||
|
@ -1402,54 +1394,43 @@ let methods =
|
||||||
; c "to" [ ca "self" Tensor; ca "device" Device ]
|
; c "to" [ ca "self" Tensor; ca "device" Device ]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
let run ~yaml_filename ~cpp_filename ~ffi_filename ~must_wrapper_filename
|
let run ~yaml_filename ~cpp_filename ~ffi_filename ~must_wrapper_filename
|
||||||
~wrapper_filename =
|
~wrapper_filename =
|
||||||
let funcs = read_yaml yaml_filename in
|
let funcs = read_yaml yaml_filename |> List.concat in
|
||||||
let funcs = methods @ funcs in
|
let funcs = methods @ funcs in
|
||||||
printf "Generating code for %d functions.\n%!" (List.length funcs) ;
|
printf "Generating code for %d functions.\n%!" (List.length funcs);
|
||||||
(* Generate some unique names for overloaded functions. *)
|
(* Generate some unique names for overloaded functions. *)
|
||||||
let funcs =
|
let funcs =
|
||||||
List.map funcs ~f:(fun func -> (String.lowercase func.operator_name, func))
|
List.map funcs ~f:(fun func -> Func.operator_name func, func)
|
||||||
|> Map.of_alist_multi (module String)
|
|> Map.of_alist_multi (module String)
|
||||||
|> Map.to_alist
|
|> Map.to_alist
|
||||||
|> List.concat_map ~f:(fun (name, funcs) ->
|
|> List.concat_map ~f:(fun (name, funcs) ->
|
||||||
match funcs with
|
match funcs with
|
||||||
| [] -> assert false
|
| [] -> assert false
|
||||||
| [func] -> [(name, func)]
|
| [ func ] -> [ name, func ]
|
||||||
| funcs ->
|
| funcs ->
|
||||||
let has_empty_overload =
|
let has_empty_overload =
|
||||||
List.exists funcs ~f:(fun (func : Func.t) ->
|
List.exists funcs ~f:(fun (func : Func.t) ->
|
||||||
String.is_empty func.overload_name )
|
String.is_empty func.overload_name)
|
||||||
in
|
in
|
||||||
List.sort funcs ~compare:(fun (f1 : Func.t) (f2 : Func.t) ->
|
List.sort funcs ~compare:(fun (f1 : Func.t) (f2 : Func.t) ->
|
||||||
match
|
match Int.compare (String.length f1.name) (String.length f2.name) with
|
||||||
Int.compare (String.length f1.name)
|
| 0 -> Int.compare (List.length f1.args) (List.length f2.args)
|
||||||
(String.length f2.name)
|
| cmp -> cmp)
|
||||||
with
|
|> List.mapi ~f:(fun index (func : Func.t) ->
|
||||||
| 0 ->
|
let operator_name = Func.operator_name func in
|
||||||
Int.compare (List.length f1.args) (List.length f2.args)
|
let overload_name = String.lowercase func.overload_name in
|
||||||
| cmp -> cmp )
|
let name =
|
||||||
|> List.mapi ~f:(fun index (func : Func.t) ->
|
if String.is_empty overload_name || (index = 0 && not has_empty_overload)
|
||||||
let operator_name =
|
then operator_name
|
||||||
String.lowercase func.operator_name
|
else if String.is_suffix operator_name ~suffix:"_"
|
||||||
in
|
then operator_name ^ overload_name ^ "_"
|
||||||
let overload_name =
|
else operator_name ^ "_" ^ overload_name
|
||||||
String.lowercase func.overload_name
|
in
|
||||||
in
|
name, func))
|
||||||
let name =
|
|
||||||
if
|
|
||||||
String.is_empty overload_name
|
|
||||||
|| (index = 0 && not has_empty_overload)
|
|
||||||
then operator_name
|
|
||||||
else if String.is_suffix operator_name ~suffix:"_" then
|
|
||||||
operator_name ^ overload_name ^ "_"
|
|
||||||
else operator_name ^ "_" ^ overload_name
|
|
||||||
in
|
|
||||||
(name, func) ) )
|
|
||||||
|> Map.of_alist_exn (module String)
|
|> Map.of_alist_exn (module String)
|
||||||
in
|
in
|
||||||
write_cpp funcs cpp_filename ;
|
write_cpp funcs cpp_filename;
|
||||||
write_ffi funcs ffi_filename ;
|
write_ffi funcs ffi_filename ;
|
||||||
write_must_wrapper funcs must_wrapper_filename ;
|
write_must_wrapper funcs must_wrapper_filename ;
|
||||||
write_wrapper funcs wrapper_filename
|
write_wrapper funcs wrapper_filename
|
||||||
|
|
|
@ -2666,10 +2666,10 @@ coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
|
||||||
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
|
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
|
||||||
C.atg_arange_start(ptr, start , end , coptionsKind, coptionsDevice)
|
C.atg_arange_start(ptr, start , end , coptionsKind, coptionsDevice)
|
||||||
}
|
}
|
||||||
func AtgArangeStartStep(ptr *Ctensor, start Cscalar, end Cscalar, optionsKind int32, optionsDevice int32){
|
func AtgArangeStartStep(ptr *Ctensor, start Cscalar, end Cscalar, step Cscalar, optionsKind int32, optionsDevice int32){
|
||||||
coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
|
coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind))
|
||||||
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
|
coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice))
|
||||||
C.atg_arange_start_step(ptr, start , end , coptionsKind, coptionsDevice)
|
C.atg_arange_start_step(ptr, start , end , step , coptionsKind, coptionsDevice)
|
||||||
}
|
}
|
||||||
func AtgArccos(ptr *Ctensor, self Ctensor){
|
func AtgArccos(ptr *Ctensor, self Ctensor){
|
||||||
C.atg_arccos(ptr, self)
|
C.atg_arccos(ptr, self)
|
||||||
|
@ -3004,8 +3004,8 @@ cdivisorOverrideVal := *(*C.int64_t)(unsafe.Pointer(&divisorOverrideVal))
|
||||||
cdivisorOverrideNull := *(*C.uint8_t)(unsafe.Pointer(&divisorOverrideNull))
|
cdivisorOverrideNull := *(*C.uint8_t)(unsafe.Pointer(&divisorOverrideNull))
|
||||||
C.atg_avg_pool3d_out(ptr, out, self, ckernelSizeDataPtr, ckernelSizeLen, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cceilMode, ccountIncludePad, cdivisorOverrideVal, cdivisorOverrideNull)
|
C.atg_avg_pool3d_out(ptr, out, self, ckernelSizeDataPtr, ckernelSizeLen, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cceilMode, ccountIncludePad, cdivisorOverrideVal, cdivisorOverrideNull)
|
||||||
}
|
}
|
||||||
func AtgBaddbmm(ptr *Ctensor, self Ctensor, batch1 Ctensor, batch2 Ctensor){
|
func AtgBaddbmm(ptr *Ctensor, self Ctensor, batch1 Ctensor, batch2 Ctensor, beta Cscalar, alpha Cscalar){
|
||||||
C.atg_baddbmm(ptr, self, batch1, batch2)
|
C.atg_baddbmm(ptr, self, batch1, batch2, beta , alpha )
|
||||||
}
|
}
|
||||||
func AtgBaddbmm_(ptr *Ctensor, self Ctensor, batch1 Ctensor, batch2 Ctensor){
|
func AtgBaddbmm_(ptr *Ctensor, self Ctensor, batch1 Ctensor, batch2 Ctensor){
|
||||||
C.atg_baddbmm_(ptr, self, batch1, batch2)
|
C.atg_baddbmm_(ptr, self, batch1, batch2)
|
||||||
|
|
|
@ -2068,7 +2068,7 @@ void atg__slow_conv2d_backward(tensor *out__, tensor grad_input, tensor grad_wei
|
||||||
|
|
||||||
void atg__sobol_engine_draw(tensor *out__, tensor quasi, int64_t n, tensor sobolstate, int64_t dimension, int64_t num_generated, int dtype) {
|
void atg__sobol_engine_draw(tensor *out__, tensor quasi, int64_t n, tensor sobolstate, int64_t dimension, int64_t num_generated, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::_sobol_engine_draw(*quasi, n, *sobolstate, dimension, num_generated, at::ScalarType(dtype));
|
auto outputs__ = torch::_sobol_engine_draw(*quasi, n, *sobolstate, dimension, num_generated, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(std::get<0>(outputs__));
|
out__[0] = new torch::Tensor(std::get<0>(outputs__));
|
||||||
out__[1] = new torch::Tensor(std::get<1>(outputs__));
|
out__[1] = new torch::Tensor(std::get<1>(outputs__));
|
||||||
)
|
)
|
||||||
|
@ -2223,28 +2223,28 @@ void atg__sparse_csc_tensor_unsafe(tensor *out__, tensor ccol_indices, tensor ro
|
||||||
|
|
||||||
void atg__sparse_csr_prod(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
void atg__sparse_csr_prod(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::_sparse_csr_prod(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::_sparse_csr_prod(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg__sparse_csr_prod_dim_dtype_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
void atg__sparse_csr_prod_dim_dtype_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::_sparse_csr_prod_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::_sparse_csr_prod_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg__sparse_csr_sum(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
void atg__sparse_csr_sum(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::_sparse_csr_sum(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::_sparse_csr_sum(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg__sparse_csr_sum_dim_dtype_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
void atg__sparse_csr_sum_dim_dtype_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::_sparse_csr_sum_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::_sparse_csr_sum_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -2279,7 +2279,7 @@ void atg__sparse_log_softmax_backward_data_out(tensor *out__, tensor out, tensor
|
||||||
|
|
||||||
void atg__sparse_log_softmax_int(tensor *out__, tensor self, int64_t dim, int dtype) {
|
void atg__sparse_log_softmax_int(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::_sparse_log_softmax(*self, dim, at::ScalarType(dtype));
|
auto outputs__ = torch::_sparse_log_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -2336,7 +2336,7 @@ void atg__sparse_softmax_backward_data_out(tensor *out__, tensor out, tensor gra
|
||||||
|
|
||||||
void atg__sparse_softmax_int(tensor *out__, tensor self, int64_t dim, int dtype) {
|
void atg__sparse_softmax_int(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::_sparse_softmax(*self, dim, at::ScalarType(dtype));
|
auto outputs__ = torch::_sparse_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -2629,14 +2629,14 @@ tensor *atg__to_cpu(tensor *tensors_data, int tensors_len) {
|
||||||
|
|
||||||
void atg__to_dense(tensor *out__, tensor self, int dtype) {
|
void atg__to_dense(tensor *out__, tensor self, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = self->_to_dense(at::ScalarType(dtype));
|
auto outputs__ = self->_to_dense(dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg__to_dense_out(tensor *out__, tensor out, tensor self, int dtype) {
|
void atg__to_dense_out(tensor *out__, tensor out, tensor self, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::_to_dense_out(*out, *self, at::ScalarType(dtype));
|
auto outputs__ = torch::_to_dense_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -3640,9 +3640,9 @@ void atg_arange_start(tensor *out__, scalar start, scalar end, int options_kind,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_arange_start_step(tensor *out__, scalar start, scalar end, int options_kind, int options_device) {
|
void atg_arange_start_step(tensor *out__, scalar start, scalar end, scalar step, int options_kind, int options_device) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::arange(*start, *end, at::device(device_of_int(options_device)).dtype(at::ScalarType(options_kind)));
|
auto outputs__ = torch::arange(*start, *end, *step, at::device(device_of_int(options_device)).dtype(at::ScalarType(options_kind)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -4120,9 +4120,9 @@ void atg_avg_pool3d_out(tensor *out__, tensor out, tensor self, int64_t *kernel_
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_baddbmm(tensor *out__, tensor self, tensor batch1, tensor batch2) {
|
void atg_baddbmm(tensor *out__, tensor self, tensor batch1, tensor batch2, scalar beta, scalar alpha) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::baddbmm(*self, *batch1, *batch2);
|
auto outputs__ = torch::baddbmm(*self, *batch1, *batch2, *beta, *alpha);
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -5910,14 +5910,14 @@ void atg_cummin_out(tensor *out__, tensor values, tensor indices, tensor self, i
|
||||||
|
|
||||||
void atg_cumprod(tensor *out__, tensor self, int64_t dim, int dtype) {
|
void atg_cumprod(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::cumprod(*self, dim, at::ScalarType(dtype));
|
auto outputs__ = torch::cumprod(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_cumprod_(tensor *out__, tensor self, int64_t dim, int dtype) {
|
void atg_cumprod_(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = self->cumprod_(dim, at::ScalarType(dtype));
|
auto outputs__ = self->cumprod_(dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -5931,28 +5931,28 @@ void atg_cumprod_backward(tensor *out__, tensor grad, tensor input, int64_t dim,
|
||||||
|
|
||||||
void atg_cumprod_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
|
void atg_cumprod_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::cumprod_out(*out, *self, dim, at::ScalarType(dtype));
|
auto outputs__ = torch::cumprod_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_cumsum(tensor *out__, tensor self, int64_t dim, int dtype) {
|
void atg_cumsum(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::cumsum(*self, dim, at::ScalarType(dtype));
|
auto outputs__ = torch::cumsum(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_cumsum_(tensor *out__, tensor self, int64_t dim, int dtype) {
|
void atg_cumsum_(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = self->cumsum_(dim, at::ScalarType(dtype));
|
auto outputs__ = self->cumsum_(dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_cumsum_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
|
void atg_cumsum_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::cumsum_out(*out, *self, dim, at::ScalarType(dtype));
|
auto outputs__ = torch::cumsum_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -9912,28 +9912,28 @@ void atg_linalg_multi_dot_out(tensor *out__, tensor out, tensor *tensors_data, i
|
||||||
|
|
||||||
void atg_linalg_norm(tensor *out__, tensor self, scalar ord, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
void atg_linalg_norm(tensor *out__, tensor self, scalar ord, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::linalg_norm(*self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::linalg_norm(*self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_linalg_norm_ord_str(tensor *out__, tensor self, char* ord_ptr, int ord_len, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
void atg_linalg_norm_ord_str(tensor *out__, tensor self, char* ord_ptr, int ord_len, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::linalg_norm(*self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::linalg_norm(*self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_linalg_norm_ord_str_out(tensor *out__, tensor out, tensor self, char* ord_ptr, int ord_len, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
void atg_linalg_norm_ord_str_out(tensor *out__, tensor out, tensor self, char* ord_ptr, int ord_len, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::linalg_norm_out(*out, *self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::linalg_norm_out(*out, *self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_linalg_norm_out(tensor *out__, tensor out, tensor self, scalar ord, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
void atg_linalg_norm_out(tensor *out__, tensor out, tensor self, scalar ord, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::linalg_norm_out(*out, *self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::linalg_norm_out(*out, *self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -10314,14 +10314,14 @@ void atg_log_sigmoid_out(tensor *out__, tensor out, tensor self) {
|
||||||
|
|
||||||
void atg_log_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
|
void atg_log_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::log_softmax(*self, dim, at::ScalarType(dtype));
|
auto outputs__ = torch::log_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_log_softmax_int_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
|
void atg_log_softmax_int_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::log_softmax_out(*out, *self, dim, at::ScalarType(dtype));
|
auto outputs__ = torch::log_softmax_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -10953,21 +10953,21 @@ void atg_maximum_out(tensor *out__, tensor out, tensor self, tensor other) {
|
||||||
|
|
||||||
void atg_mean(tensor *out__, tensor self, int dtype) {
|
void atg_mean(tensor *out__, tensor self, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::mean(*self, at::ScalarType(dtype));
|
auto outputs__ = torch::mean(*self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_mean_dim(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
void atg_mean_dim(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::mean(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::mean(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_mean_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
void atg_mean_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::mean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::mean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -11742,14 +11742,14 @@ void atg_nan_to_num_out(tensor *out__, tensor out, tensor self, double nan_v, ui
|
||||||
|
|
||||||
void atg_nanmean(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
void atg_nanmean(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::nanmean(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::nanmean(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_nanmean_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
void atg_nanmean_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::nanmean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::nanmean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -11814,14 +11814,14 @@ void atg_nanquantile_scalar_out(tensor *out__, tensor out, tensor self, double q
|
||||||
|
|
||||||
void atg_nansum(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
void atg_nansum(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::nansum(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::nansum(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_nansum_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
void atg_nansum_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::nansum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::nansum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -11961,14 +11961,14 @@ void atg_native_norm_out(tensor *out__, tensor out, tensor self) {
|
||||||
|
|
||||||
void atg_native_norm_scalaropt_dim_dtype(tensor *out__, tensor self, scalar p, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
void atg_native_norm_scalaropt_dim_dtype(tensor *out__, tensor self, scalar p, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::native_norm(*self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::native_norm(*self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_native_norm_scalaropt_dim_dtype_out(tensor *out__, tensor out, tensor self, scalar p, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
void atg_native_norm_scalaropt_dim_dtype_out(tensor *out__, tensor out, tensor self, scalar p, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::native_norm_out(*out, *self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::native_norm_out(*out, *self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -12702,28 +12702,28 @@ void atg_prelu(tensor *out__, tensor self, tensor weight) {
|
||||||
|
|
||||||
void atg_prod(tensor *out__, tensor self, int dtype) {
|
void atg_prod(tensor *out__, tensor self, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::prod(*self, at::ScalarType(dtype));
|
auto outputs__ = torch::prod(*self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_prod_dim_int(tensor *out__, tensor self, int64_t dim, int keepdim, int dtype) {
|
void atg_prod_dim_int(tensor *out__, tensor self, int64_t dim, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::prod(*self, dim, (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::prod(*self, dim, (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_prod_int_out(tensor *out__, tensor out, tensor self, int64_t dim, int keepdim, int dtype) {
|
void atg_prod_int_out(tensor *out__, tensor out, tensor self, int64_t dim, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::prod_out(*out, *self, dim, (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::prod_out(*out, *self, dim, (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_prod_out(tensor *out__, tensor out, tensor self, int dtype) {
|
void atg_prod_out(tensor *out__, tensor out, tensor self, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::prod_out(*out, *self, at::ScalarType(dtype));
|
auto outputs__ = torch::prod_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -14593,14 +14593,14 @@ void atg_soft_margin_loss_out(tensor *out__, tensor out, tensor self, tensor tar
|
||||||
|
|
||||||
void atg_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
|
void atg_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::softmax(*self, dim, at::ScalarType(dtype));
|
auto outputs__ = torch::softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_softmax_int_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
|
void atg_softmax_int_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::softmax_out(*out, *self, dim, at::ScalarType(dtype));
|
auto outputs__ = torch::softmax_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -15528,7 +15528,7 @@ void atg_special_log_ndtr_out(tensor *out__, tensor out, tensor self) {
|
||||||
|
|
||||||
void atg_special_log_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
|
void atg_special_log_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::special_log_softmax(*self, dim, at::ScalarType(dtype));
|
auto outputs__ = torch::special_log_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -15913,7 +15913,7 @@ void atg_special_sinc_out(tensor *out__, tensor out, tensor self) {
|
||||||
|
|
||||||
void atg_special_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
|
void atg_special_softmax(tensor *out__, tensor self, int64_t dim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::special_softmax(*self, dim, at::ScalarType(dtype));
|
auto outputs__ = torch::special_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -16437,28 +16437,28 @@ void atg_subtract_scalar_(tensor *out__, tensor self, scalar other) {
|
||||||
|
|
||||||
void atg_sum(tensor *out__, tensor self, int dtype) {
|
void atg_sum(tensor *out__, tensor self, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::sum(*self, at::ScalarType(dtype));
|
auto outputs__ = torch::sum(*self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_sum_dim_intlist(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
void atg_sum_dim_intlist(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::sum(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::sum(*self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_sum_intlist_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
void atg_sum_intlist_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::sum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype));
|
auto outputs__ = torch::sum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional<torch::IntArrayRef>(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
void atg_sum_out(tensor *out__, tensor out, tensor self, int dtype) {
|
void atg_sum_out(tensor *out__, tensor out, tensor self, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::sum_out(*out, *self, at::ScalarType(dtype));
|
auto outputs__ = torch::sum_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -16732,7 +16732,7 @@ void atg_to(tensor *out__, tensor self, int device) {
|
||||||
|
|
||||||
void atg_to_dense(tensor *out__, tensor self, int dtype) {
|
void atg_to_dense(tensor *out__, tensor self, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = self->to_dense(at::ScalarType(dtype));
|
auto outputs__ = self->to_dense(dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -16767,7 +16767,7 @@ void atg_to_dtype_layout(tensor *out__, tensor self, int options_kind, int optio
|
||||||
|
|
||||||
void atg_to_mkldnn(tensor *out__, tensor self, int dtype) {
|
void atg_to_mkldnn(tensor *out__, tensor self, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = self->to_mkldnn(at::ScalarType(dtype));
|
auto outputs__ = self->to_mkldnn(dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -16781,7 +16781,7 @@ void atg_to_mkldnn_backward(tensor *out__, tensor grad, tensor input) {
|
||||||
|
|
||||||
void atg_to_mkldnn_out(tensor *out__, tensor out, tensor self, int dtype) {
|
void atg_to_mkldnn_out(tensor *out__, tensor out, tensor self, int dtype) {
|
||||||
PROTECT(
|
PROTECT(
|
||||||
auto outputs__ = torch::to_mkldnn_out(*out, *self, at::ScalarType(dtype));
|
auto outputs__ = torch::to_mkldnn_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional<at::ScalarType>(at::ScalarType(dtype)));
|
||||||
out__[0] = new torch::Tensor(outputs__);
|
out__[0] = new torch::Tensor(outputs__);
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -497,7 +497,7 @@ void atg_any_dim(tensor *, tensor self, int64_t dim, int keepdim);
|
||||||
void atg_any_out(tensor *, tensor out, tensor self, int64_t dim, int keepdim);
|
void atg_any_out(tensor *, tensor out, tensor self, int64_t dim, int keepdim);
|
||||||
void atg_arange(tensor *, scalar end, int options_kind, int options_device);
|
void atg_arange(tensor *, scalar end, int options_kind, int options_device);
|
||||||
void atg_arange_start(tensor *, scalar start, scalar end, int options_kind, int options_device);
|
void atg_arange_start(tensor *, scalar start, scalar end, int options_kind, int options_device);
|
||||||
void atg_arange_start_step(tensor *, scalar start, scalar end, int options_kind, int options_device);
|
void atg_arange_start_step(tensor *, scalar start, scalar end, scalar step, int options_kind, int options_device);
|
||||||
void atg_arccos(tensor *, tensor self);
|
void atg_arccos(tensor *, tensor self);
|
||||||
void atg_arccos_(tensor *, tensor self);
|
void atg_arccos_(tensor *, tensor self);
|
||||||
void atg_arccos_out(tensor *, tensor out, tensor self);
|
void atg_arccos_out(tensor *, tensor out, tensor self);
|
||||||
|
@ -563,7 +563,7 @@ void atg_avg_pool3d(tensor *, tensor self, int64_t *kernel_size_data, int kernel
|
||||||
void atg_avg_pool3d_backward(tensor *, tensor grad_output, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override_v, uint8_t divisor_override_null);
|
void atg_avg_pool3d_backward(tensor *, tensor grad_output, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override_v, uint8_t divisor_override_null);
|
||||||
void atg_avg_pool3d_backward_grad_input(tensor *, tensor grad_input, tensor grad_output, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override_v, uint8_t divisor_override_null);
|
void atg_avg_pool3d_backward_grad_input(tensor *, tensor grad_input, tensor grad_output, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override_v, uint8_t divisor_override_null);
|
||||||
void atg_avg_pool3d_out(tensor *, tensor out, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override_v, uint8_t divisor_override_null);
|
void atg_avg_pool3d_out(tensor *, tensor out, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override_v, uint8_t divisor_override_null);
|
||||||
void atg_baddbmm(tensor *, tensor self, tensor batch1, tensor batch2);
|
void atg_baddbmm(tensor *, tensor self, tensor batch1, tensor batch2, scalar beta, scalar alpha);
|
||||||
void atg_baddbmm_(tensor *, tensor self, tensor batch1, tensor batch2);
|
void atg_baddbmm_(tensor *, tensor self, tensor batch1, tensor batch2);
|
||||||
void atg_baddbmm_out(tensor *, tensor out, tensor self, tensor batch1, tensor batch2);
|
void atg_baddbmm_out(tensor *, tensor out, tensor self, tensor batch1, tensor batch2);
|
||||||
void atg_bartlett_window(tensor *, int64_t window_length, int options_kind, int options_device);
|
void atg_bartlett_window(tensor *, int64_t window_length, int options_kind, int options_device);
|
||||||
|
|
|
@ -3961,9 +3961,9 @@ func MustArangeStart(start *Scalar, end *Scalar, optionsKind gotch.DType, option
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
|
||||||
func MustArangeStartStep(start *Scalar, end *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor) {
|
func MustArangeStartStep(start *Scalar, end *Scalar, step *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor) {
|
||||||
|
|
||||||
retVal, err := ArangeStartStep(start, end, optionsKind, optionsDevice)
|
retVal, err := ArangeStartStep(start, end, step, optionsKind, optionsDevice)
|
||||||
if err != nil { log.Fatal(err) }
|
if err != nil { log.Fatal(err) }
|
||||||
|
|
||||||
return retVal
|
return retVal
|
||||||
|
@ -4465,9 +4465,9 @@ func(ts *Tensor) MustAvgPool3dOut(out *Tensor, kernelSize []int64, stride []int6
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
|
||||||
func(ts *Tensor) MustBaddbmm(batch1 *Tensor, batch2 *Tensor, del bool)(retVal *Tensor) {
|
func(ts *Tensor) MustBaddbmm(batch1 *Tensor, batch2 *Tensor, beta *Scalar, alpha *Scalar, del bool)(retVal *Tensor) {
|
||||||
|
|
||||||
retVal, err := ts.Baddbmm(batch1, batch2, del)
|
retVal, err := ts.Baddbmm(batch1, batch2, beta, alpha, del)
|
||||||
if err != nil { log.Fatal(err) }
|
if err != nil { log.Fatal(err) }
|
||||||
|
|
||||||
return retVal
|
return retVal
|
||||||
|
|
|
@ -9327,10 +9327,10 @@ func ArangeStart(start *Scalar, end *Scalar, optionsKind gotch.DType, optionsDev
|
||||||
// func.returns = `fixed 1`:
|
// func.returns = `fixed 1`:
|
||||||
// --------------------------
|
// --------------------------
|
||||||
|
|
||||||
func ArangeStartStep(start *Scalar, end *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor, err error) {
|
func ArangeStartStep(start *Scalar, end *Scalar, step *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor, err error) {
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
|
||||||
lib.AtgArangeStartStep(ptr, start.cscalar, end.cscalar, optionsKind.CInt(), optionsDevice.CInt())
|
lib.AtgArangeStartStep(ptr, start.cscalar, end.cscalar, step.cscalar, optionsKind.CInt(), optionsDevice.CInt())
|
||||||
if err = TorchErr(); err != nil {
|
if err = TorchErr(); err != nil {
|
||||||
err = fmt.Errorf("ArangeStartStep() failed: %w", err)
|
err = fmt.Errorf("ArangeStartStep() failed: %w", err)
|
||||||
return retVal, err
|
return retVal, err
|
||||||
|
@ -10585,11 +10585,11 @@ var cdivisorOverrideVal int64 = 0
|
||||||
// func.returns = `fixed 1`:
|
// func.returns = `fixed 1`:
|
||||||
// --------------------------
|
// --------------------------
|
||||||
|
|
||||||
func(ts *Tensor) Baddbmm(batch1 *Tensor, batch2 *Tensor, del bool)(retVal *Tensor, err error) {
|
func(ts *Tensor) Baddbmm(batch1 *Tensor, batch2 *Tensor, beta *Scalar, alpha *Scalar, del bool)(retVal *Tensor, err error) {
|
||||||
if del { defer ts.MustDrop() }
|
if del { defer ts.MustDrop() }
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
|
||||||
lib.AtgBaddbmm(ptr, ts.ctensor, batch1.ctensor, batch2.ctensor)
|
lib.AtgBaddbmm(ptr, ts.ctensor, batch1.ctensor, batch2.ctensor, beta.cscalar, alpha.cscalar)
|
||||||
if err = TorchErr(); err != nil {
|
if err = TorchErr(); err != nil {
|
||||||
err = fmt.Errorf("Baddbmm() failed: %w", err)
|
err = fmt.Errorf("Baddbmm() failed: %w", err)
|
||||||
return retVal, err
|
return retVal, err
|
||||||
|
|
|
@ -255,9 +255,9 @@ func rgb2Gray(x *ts.Tensor, outChanOpt ...int64) *ts.Tensor {
|
||||||
}
|
}
|
||||||
|
|
||||||
rgbTs := x.MustUnbind(-3, false)
|
rgbTs := x.MustUnbind(-3, false)
|
||||||
r := &rgbTs[0]
|
r := rgbTs[0]
|
||||||
g := &rgbTs[1]
|
g := rgbTs[1]
|
||||||
b := &rgbTs[2]
|
b := rgbTs[2]
|
||||||
|
|
||||||
// This implementation closely follows the TF one:
|
// This implementation closely follows the TF one:
|
||||||
// https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
|
// https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138
|
||||||
|
@ -453,7 +453,7 @@ func hsv2RGB(x *ts.Tensor) *ts.Tensor {
|
||||||
a3 := ts.MustStack([]*ts.Tensor{p, p, t, v, v, q}, -3)
|
a3 := ts.MustStack([]*ts.Tensor{p, p, t, v, v, q}, -3)
|
||||||
a4 := ts.MustStack([]*ts.Tensor{a1, a2, a3}, -4)
|
a4 := ts.MustStack([]*ts.Tensor{a1, a2, a3}, -4)
|
||||||
|
|
||||||
out := ts.MustEinsum("...ijk, ...xijk -> ...xjk", []*ts.Tensor{mask, a4})
|
out := ts.MustEinsum("...ijk, ...xijk -> ...xjk", []*ts.Tensor{mask, a4}, []int64{0, 1})
|
||||||
|
|
||||||
// Delete intermediate tensors
|
// Delete intermediate tensors
|
||||||
h.MustDrop()
|
h.MustDrop()
|
||||||
|
@ -579,7 +579,7 @@ func crop(x *ts.Tensor, top, left, height, width int64) *ts.Tensor {
|
||||||
x2 := x1T.Idx(wNar)
|
x2 := x1T.Idx(wNar)
|
||||||
x1T.MustDrop()
|
x1T.MustDrop()
|
||||||
out := x2.MustT(true)
|
out := x2.MustT(true)
|
||||||
chans[i] = *out
|
chans[i] = out
|
||||||
}
|
}
|
||||||
|
|
||||||
cropTs := ts.MustStack(chans, 0)
|
cropTs := ts.MustStack(chans, 0)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user