From e9278816b2674a7cd4cd33e6814a32515200e07b Mon Sep 17 00:00:00 2001 From: sugarme Date: Sat, 5 Aug 2023 13:02:56 +1000 Subject: [PATCH] fixed incorrect APIs generation --- gen/gen.ml | 647 +++++++++++++++---------------- libtch/c-generated.go | 8 +- libtch/torch_api_generated.cpp.h | 98 ++--- libtch/torch_api_generated.h | 4 +- ts/must-tensor-generated.go | 8 +- ts/tensor-generated.go | 8 +- vision/aug/function.go | 10 +- 7 files changed, 382 insertions(+), 401 deletions(-) diff --git a/gen/gen.ml b/gen/gen.ml index bc72301..b5ed3ab 100644 --- a/gen/gen.ml +++ b/gen/gen.ml @@ -82,6 +82,14 @@ let no_tensor_options = ; "randint_like" ; "randn_like" ] +(* By default, scalar argument that have a default value are not available on + the Rust side, this is to preserve the Rust api simplicity assuming that + these scalars arguments are not often overriden. + Adding function name [foo] in [with_optional_scalar_args] results in having + explicit scalar arguments even if a default is present. *) +let with_optional_scalar_args = Set.of_list (module String) [ "arange"; "baddbmm" ] + + (* * let prefixed_functions = * Set.of_list @@ -133,24 +141,26 @@ module Func = struct | Double | DoubleOption | Tensor - | TensorOption - (* Tensor.t option *) + | TensorOption (* Tensor.t option *) | IntList | IntListOption | DoubleList | TensorOptList | TensorList - | TensorOptions - (* Tensor kind and device *) + | TensorOptions (* Tensor kind and device *) | Scalar | ScalarType + | ScalarTypeOption | Device | String | Layout | LayoutOption type arg = - {arg_name: string; arg_type: arg_type; default_value: string option} + { arg_name: string + ; arg_type: arg_type + ; default_value: string option + } (* `Func` type *) type t = @@ -160,7 +170,8 @@ module Func = struct ; args: arg list (* ; returns: [`fixed of int | `dynamic] *) ; returns: [`fixed of int | `dynamic | `bool | `int64_t | `double | `nothing] ; (* number of tensors that are returned *) - kind: [`function_ | `method_] } + kind: [`function_ | `method_] + } let arg_type_of_string str ~is_nullable = match String.lowercase str with @@ -175,108 +186,111 @@ module Func = struct | "const at::itensorlistref &" | "at::tensorlist" -> Some TensorList | "at::device" -> Some Device | "const at::scalar &" | "at::scalar" -> Some Scalar - | "at::scalartype" -> Some ScalarType + | "at::scalartype" -> if is_nullable then Some ScalarTypeOption else Some ScalarType | "c10::string_view" -> Some String | "at::layout" -> Some (if is_nullable then LayoutOption else Layout) | _ -> None + let c_typed_args_list t = List.map t.args ~f:(fun { arg_name; arg_type; _ } -> - match arg_type with - | IntList | IntListOption -> - Printf.sprintf "int64_t *%s_data, int %s_len" arg_name arg_name - | DoubleList -> Printf.sprintf "double *%s_data, int %s_len" arg_name arg_name - | TensorOptList | TensorList -> - Printf.sprintf "tensor *%s_data, int %s_len" arg_name arg_name - | TensorOptions -> Printf.sprintf "int %s_kind, int %s_device" arg_name arg_name - | String -> Printf.sprintf "char* %s_ptr, int %s_len" arg_name arg_name - | Int64Option -> Printf.sprintf "int64_t %s_v, uint8_t %s_null" arg_name arg_name - | DoubleOption -> Printf.sprintf "double %s_v, uint8_t %s_null" arg_name arg_name - | otherwise -> - let simple_type_cstring = - match otherwise with - | Bool -> "int" - | Int64 -> "int64_t" - | Double -> "double" - | Tensor -> "tensor" - | TensorOption -> "tensor" - | ScalarType -> "int" - | Device -> "int" - | Scalar -> "scalar" - | Layout | LayoutOption -> "int8_t" - | Int64Option - | DoubleOption - | String - | IntList - | IntListOption - | DoubleList - | TensorOptList - | TensorList - | TensorOptions -> assert false - in - Printf.sprintf "%s %s" simple_type_cstring arg_name) + match arg_type with + | IntList | IntListOption -> + Printf.sprintf "int64_t *%s_data, int %s_len" arg_name arg_name + | DoubleList -> Printf.sprintf "double *%s_data, int %s_len" arg_name arg_name + | TensorOptList | TensorList -> + Printf.sprintf "tensor *%s_data, int %s_len" arg_name arg_name + | TensorOptions -> Printf.sprintf "int %s_kind, int %s_device" arg_name arg_name + | String -> Printf.sprintf "char* %s_ptr, int %s_len" arg_name arg_name + | Int64Option -> Printf.sprintf "int64_t %s_v, uint8_t %s_null" arg_name arg_name + | DoubleOption -> Printf.sprintf "double %s_v, uint8_t %s_null" arg_name arg_name + | otherwise -> + let simple_type_cstring = + match otherwise with + | Bool -> "int" + | Int64 -> "int64_t" + | Double -> "double" + | Tensor -> "tensor" + | TensorOption -> "tensor" + | ScalarType -> "int" + | ScalarTypeOption -> "int" + | Device -> "int" + | Scalar -> "scalar" + | Layout | LayoutOption -> "int8_t" + | Int64Option + | DoubleOption + | String + | IntList + | IntListOption + | DoubleList + | TensorOptList + | TensorList + | TensorOptions -> assert false + in + Printf.sprintf "%s %s" simple_type_cstring arg_name) |> String.concat ~sep:", " let c_args_list args = List.map args ~f:(fun { arg_name; arg_type; _ } -> - match arg_type with - | Scalar | Tensor -> "*" ^ arg_name - | Layout -> Printf.sprintf "static_cast(%s)" arg_name - | LayoutOption -> - Printf.sprintf - "(%s == -1 ? c10::nullopt : \ - c10::optional(static_cast(%s)))" - arg_name - arg_name - | TensorOption -> Printf.sprintf "(%s ? *%s : torch::Tensor())" arg_name arg_name - | Bool -> "(bool)" ^ arg_name - | IntList -> - Printf.sprintf "torch::IntArrayRef(%s_data, %s_len)" arg_name arg_name - | IntListOption -> - Printf.sprintf - "%s_data == nullptr ? c10::nullopt : \ - c10::optional(torch::IntArrayRef(%s_data, %s_len))" - arg_name - arg_name - arg_name - | DoubleList -> - Printf.sprintf "at::ArrayRef(%s_data, %s_len)" arg_name arg_name - | String -> Printf.sprintf "std::string(%s_ptr, %s_len)" arg_name arg_name - | TensorOptList -> - Printf.sprintf "of_carray_tensor_opt(%s_data, %s_len)" arg_name arg_name - | TensorList -> - Printf.sprintf "of_carray_tensor(%s_data, %s_len)" arg_name arg_name - | TensorOptions -> - Printf.sprintf - "at::device(device_of_int(%s_device)).dtype(at::ScalarType(%s_kind))" - arg_name - arg_name - | Int64Option -> - Printf.sprintf - "%s_null ? c10::nullopt : c10::optional(%s_v)" - arg_name - arg_name - | DoubleOption -> - Printf.sprintf - "%s_null ? c10::nullopt : c10::optional(%s_v)" - arg_name - arg_name - | ScalarType -> Printf.sprintf "at::ScalarType(%s)" arg_name - | Device -> Printf.sprintf "device_of_int(%s)" arg_name - | _ -> arg_name) + match arg_type with + | Scalar | Tensor -> "*" ^ arg_name + | Layout -> Printf.sprintf "static_cast(%s)" arg_name + | LayoutOption -> + Printf.sprintf + "(%s == -1 ? c10::nullopt : \ + c10::optional(static_cast(%s)))" + arg_name + arg_name + | TensorOption -> Printf.sprintf "(%s ? *%s : torch::Tensor())" arg_name arg_name + | Bool -> "(bool)" ^ arg_name + | IntList -> Printf.sprintf "torch::IntArrayRef(%s_data, %s_len)" arg_name arg_name + | IntListOption -> + Printf.sprintf + "%s_data == nullptr ? c10::nullopt : \ + c10::optional(torch::IntArrayRef(%s_data, %s_len))" + arg_name + arg_name + arg_name + | DoubleList -> + Printf.sprintf "at::ArrayRef(%s_data, %s_len)" arg_name arg_name + | String -> Printf.sprintf "std::string(%s_ptr, %s_len)" arg_name arg_name + | TensorOptList -> + Printf.sprintf "of_carray_tensor_opt(%s_data, %s_len)" arg_name arg_name + | TensorList -> Printf.sprintf "of_carray_tensor(%s_data, %s_len)" arg_name arg_name + | TensorOptions -> + Printf.sprintf + "at::device(device_of_int(%s_device)).dtype(at::ScalarType(%s_kind))" + arg_name + arg_name + | Int64Option -> + Printf.sprintf + "%s_null ? c10::nullopt : c10::optional(%s_v)" + arg_name + arg_name + | DoubleOption -> + Printf.sprintf + "%s_null ? c10::nullopt : c10::optional(%s_v)" + arg_name + arg_name + | ScalarType -> Printf.sprintf "at::ScalarType(%s)" arg_name + | ScalarTypeOption -> + Printf.sprintf + "%s < 0 ? c10::nullopt : c10::optional(at::ScalarType(%s))" + arg_name + arg_name + | Device -> Printf.sprintf "device_of_int(%s)" arg_name + | _ -> arg_name) |> String.concat ~sep:", " - let c_call t = match t.kind with | `function_ -> Printf.sprintf "torch::%s(%s)" t.name (c_args_list t.args) - | `method_ -> ( - match t.args with - | head :: tail -> - Printf.sprintf "%s->%s(%s)" head.arg_name t.name (c_args_list tail) - | [] -> - Printf.failwithf "Method calls should have at least one argument %s" - t.name () ) + | `method_ -> + (match t.args with + | head :: tail -> + Printf.sprintf "%s->%s(%s)" head.arg_name t.name (c_args_list tail) + | [] -> + Printf.failwithf "Method calls should have at least one argument %s" t.name ()) (* let replace_map = @@ -289,6 +303,15 @@ module Func = struct ; ("to_device", "to_device_") ] *) + let operator_name t = + match String.lowercase t.operator_name with + | "scatter_reduce" -> + (* scatter_reduce is both an operator name and also obtained from the + scatter operator when using the reduce overload. *) + "_scatter_reduce" + | "scatter_reduce_" -> "_scatter_reduce_" + | other -> other + let is_method t = List.exists t.args ~f:(fun arg -> match arg.arg_name with "self" -> true | _ -> false ) @@ -321,18 +344,16 @@ module Func = struct let single_param = Printf.sprintf "%s %s" an in match arg.arg_type with | Bool -> single_param "int32" - | Layout -> single_param "int8" - | LayoutOption -> single_param "int8" + | Layout | LayoutOption -> single_param "int8" | Int64 -> single_param "int64" | Double -> single_param "float64" | Tensor -> single_param "Ctensor" | TensorOption -> single_param "Ctensor" | Scalar -> single_param "Cscalar" - | ScalarType -> single_param "int32" + | ScalarType | ScalarTypeOption -> single_param "int32" | Device -> single_param "int32" | String -> single_param "string" - | IntList -> Printf.sprintf "%sData []int64, %sLen int" an an - | IntListOption -> Printf.sprintf "%sData []int64, %sLen int" an an + | IntList | IntListOption -> Printf.sprintf "%sData []int64, %sLen int" an an | DoubleList -> Printf.sprintf "%sData []float64, %sLen int" an an | TensorOptList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an | TensorList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an @@ -353,14 +374,12 @@ module Func = struct | Double -> Printf.sprintf "c%s" an | Tensor -> Printf.sprintf "%s" an | TensorOption -> Printf.sprintf "%s" an - | Layout -> Printf.sprintf "c%s" an - | LayoutOption -> Printf.sprintf "c%s" an + | Layout | LayoutOption -> Printf.sprintf "c%s" an | Scalar -> single_param "" - | ScalarType -> Printf.sprintf "c%s" an + | ScalarType | ScalarTypeOption -> Printf.sprintf "c%s" an | Device -> Printf.sprintf "c%s" an | String -> Printf.sprintf "c%s, c%sLen" an an - | IntList -> Printf.sprintf "c%sDataPtr, c%sLen" an an - | IntListOption -> Printf.sprintf "c%sDataPtr, c%sLen" an an + | IntList | IntListOption -> Printf.sprintf "c%sDataPtr, c%sLen" an an | DoubleList -> Printf.sprintf "c%sDataPtr, c%sLen" an an | TensorOptList -> Printf.sprintf "c%sDataPtr, c%sLen" an an | TensorList -> Printf.sprintf "c%sDataPtr, c%sLen" an an @@ -383,12 +402,10 @@ module Func = struct Printf.sprintf "\nc%s := *(*C.double)(unsafe.Pointer(&%s))" an an | Tensor -> "" | TensorOption -> "" - | Layout -> - Printf.sprintf "\nc%s := *(*C.int8_t)(unsafe.Pointer(&%s))" an an - | LayoutOption -> + | Layout | LayoutOption -> Printf.sprintf "\nc%s := *(*C.int8_t)(unsafe.Pointer(&%s))" an an | Scalar -> "" - | ScalarType -> + | ScalarType | ScalarTypeOption -> Printf.sprintf "\nc%s := *(*C.int)(unsafe.Pointer(&%s))" an an | Device -> Printf.sprintf "\nc%s := *(*C.int)(unsafe.Pointer(&%s))" an an @@ -399,13 +416,7 @@ module Func = struct %sLen := len(%s)\n\ c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))" an an an an an an - | IntList -> - Printf.sprintf - "\n\ - c%sDataPtr := (*C.int64_t)(unsafe.Pointer(&%sData[0]))\n\ - c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))" - an an an an - | IntListOption -> + | IntList | IntListOption -> Printf.sprintf "\n\ c%sDataPtr := (*C.int64_t)(unsafe.Pointer(&%sData[0]))\n\ @@ -494,14 +505,12 @@ module Func = struct let go_arg_type = match arg.arg_type with | Bool -> "bool" - | Layout -> "Layout" - | LayoutOption -> "Layout" + | Layout | LayoutOption -> "Layout" | Int64 -> "int64" | Double -> "float64" | Tensor -> "*Tensor" | TensorOption -> "*Tensor" - | IntList -> "[]int64" - | IntListOption -> "[]int64" + | IntList | IntListOption -> "[]int64" | DoubleList -> "[]float64" | TensorOptList -> "[]*Tensor" | TensorList -> "[]*Tensor" @@ -510,7 +519,7 @@ module Func = struct (* E.g. `type KindDevice struct{}` *) | TensorOptions -> "gotch.KindDevice" | Scalar -> "*Scalar" - | ScalarType -> "gotch.DType" + | ScalarType | ScalarTypeOption -> "gotch.DType" | Int64Option -> "[]int64" | DoubleOption -> "[]float64" | Device -> "gotch.Device" @@ -603,7 +612,7 @@ module Func = struct else Printf.sprintf "%s.ctensor" name | Scalar -> Printf.sprintf "%s.cscalar" name | Bool -> Printf.sprintf "c%s" name - | ScalarType -> Printf.sprintf "%s.CInt()" name + | ScalarType | ScalarTypeOption -> Printf.sprintf "%s.CInt()" name | Device -> Printf.sprintf "%s.CInt()" name | TensorOptions -> Printf.sprintf "%sKind.CInt(), %sDevice.CInt()" name name @@ -633,11 +642,10 @@ module Func = struct | Tensor -> "" | TensorOption -> "" | Scalar -> "" - | ScalarType -> "" + | ScalarType | ScalarTypeOption -> "" | Device -> "" | String -> "" - | IntList -> Printf.sprintf "%sLen := len(%s)\n" an an - | IntListOption -> Printf.sprintf "%sLen := len(%s)\n" an an + | IntList | IntListOption -> Printf.sprintf "%sLen := len(%s)\n" an an | DoubleList -> Printf.sprintf "%sLen := len(%s)\n" an an | Int64Option -> Printf.sprintf @@ -667,8 +675,7 @@ module Func = struct "var c%s []lib.Ctensor\n\ \ for _, t := range %s {c%s = append(c%s, t.ctensor)}\n" an an an an - | Layout -> "" - | LayoutOption -> "" + | Layout | LayoutOption -> "" | TensorOptions -> "" ) |> String.concat ~sep:"" end @@ -679,117 +686,103 @@ let read_yaml filename = let funcs = (* Split the file to avoid Yaml.of_string_exn segfaulting. *) In_channel.with_file filename ~f:In_channel.input_lines - |> List.group ~break:(fun _ l -> - String.length l > 0 && Char.( = ) l.[0] '-' ) + |> List.group ~break:(fun _ l -> String.length l > 0 && Char.( = ) l.[0] '-') |> List.concat_map ~f:(fun lines -> - Yaml.of_string_exn (String.concat lines ~sep:"\n") |> extract_list - ) + Yaml.of_string_exn (String.concat lines ~sep:"\n") |> extract_list) in - printf "Read %s, got %d functions.\n%!" filename (List.length funcs) ; + printf "Read %s, got %d functions.\n%!" filename (List.length funcs); List.filter_map funcs ~f:(fun yaml -> - let map = extract_map yaml in - let name = Map.find_exn map "name" |> extract_string in - let operator_name = Map.find_exn map "operator_name" |> extract_string in - let overload_name = Map.find_exn map "overload_name" |> extract_string in - let deprecated = Map.find_exn map "deprecated" |> extract_bool in - let method_of = - Map.find_exn map "method_of" - |> extract_list |> List.map ~f:extract_string + let map = extract_map yaml in + let name = Map.find_exn map "name" |> extract_string in + let operator_name = Map.find_exn map "operator_name" |> extract_string in + let overload_name = Map.find_exn map "overload_name" |> extract_string in + let deprecated = Map.find_exn map "deprecated" |> extract_bool in + let method_of = + Map.find_exn map "method_of" |> extract_list |> List.map ~f:extract_string + in + let arguments = Map.find_exn map "arguments" |> extract_list in + let returns = + let is_tensor returns = + let returns = extract_map returns in + let return_type = Map.find_exn returns "dynamic_type" |> extract_string in + String.( = ) return_type "at::Tensor" in - let arguments = Map.find_exn map "arguments" |> extract_list in - let returns = - let is_tensor returns = - let returns = extract_map returns in + let returns = Map.find_exn map "returns" |> extract_list in + if List.is_empty returns + then Some `nothing + else if List.for_all returns ~f:is_tensor + then Some (`fixed (List.length returns)) + else ( + match returns with + | [ returns ] -> let return_type = - Map.find_exn returns "dynamic_type" |> extract_string + Map.find_exn (extract_map returns) "dynamic_type" |> extract_string in - String.( = ) return_type "at::Tensor" - in - let returns = Map.find_exn map "returns" |> extract_list in - if List.for_all returns ~f:is_tensor then - Some (`fixed (List.length returns)) - else - match returns with - | [returns] -> ( - let return_type = - Map.find_exn (extract_map returns) "dynamic_type" - |> extract_string - in - match return_type with - | "bool" -> Some `bool - | "int64_t" -> Some `int64_t - | "double" -> Some `double - | "at::TensorList" - |"dynamic_type: const c10::List> &" -> - Some `dynamic - | _ -> None ) - | [] | _ :: _ :: _ -> None - in - let kind = - if List.exists method_of ~f:(String.( = ) "namespace") then - Some `function_ - else if List.exists method_of ~f:(String.( = ) "Tensor") then - Some `method_ - else None - in - if - (not deprecated) - && (not - (List.exists excluded_prefixes ~f:(fun prefix -> - String.is_prefix name ~prefix ))) - && (not - (List.exists excluded_suffixes ~f:(fun suffix -> - String.is_suffix name ~suffix ))) - && not (Set.mem excluded_functions name) - then - Option.both returns kind - |> Option.bind ~f:(fun (returns, kind) -> - try - let args = - List.filter_map arguments ~f:(fun arg -> - let arg = extract_map arg in - let arg_name = - Map.find_exn arg "name" |> extract_string - in - let arg_type = - Map.find_exn arg "dynamic_type" |> extract_string - in - let is_nullable = - Map.find arg "is_nullable" - |> Option.value_map ~default:false ~f:extract_bool - in - let default_value = - Map.find arg "default" |> Option.map ~f:extract_string - in - match Func.arg_type_of_string arg_type ~is_nullable with - | Some Scalar - when Option.is_some default_value && not is_nullable - -> - None - | Some TensorOptions - when Option.is_some default_value - && Set.mem no_tensor_options name -> - None - | Some arg_type -> - let arg_name = - match (arg_name, arg_type) with - | "self", Scalar -> "self_scalar" - | _, _ -> arg_name - in - Some {Func.arg_name; arg_type; default_value} - | None -> - if Option.is_some default_value then None - else raise Not_a_simple_arg ) + (match return_type with + | "bool" -> Some `bool + | "int64_t" -> Some `int64_t + | "double" -> Some `double + | "at::TensorList" | "dynamic_type: const c10::List> &" + -> Some `dynamic + | _ -> None) + | [] | _ :: _ :: _ -> None) + in + let kind = + if List.exists method_of ~f:(String.( = ) "namespace") + then Some `function_ + else if List.exists method_of ~f:(String.( = ) "Tensor") + then Some `method_ + else None + in + if (not deprecated) + && (not + (List.exists excluded_prefixes ~f:(fun prefix -> + String.is_prefix name ~prefix))) + && (not + (List.exists excluded_suffixes ~f:(fun suffix -> + String.is_suffix name ~suffix))) + && not (Set.mem excluded_functions name) + then + Option.both returns kind + |> Option.bind ~f:(fun (returns, kind) -> + try + let args ~with_optional_scalar_args = + List.filter_map arguments ~f:(fun arg -> + let arg = extract_map arg in + let arg_name = Map.find_exn arg "name" |> extract_string in + let arg_type = Map.find_exn arg "dynamic_type" |> extract_string in + let is_nullable = + Map.find arg "is_nullable" + |> Option.value_map ~default:false ~f:extract_bool in - Some - { Func.name - ; operator_name - ; overload_name - ; args - ; returns - ; kind } - with Not_a_simple_arg -> None ) - else None ) + let default_value = + Map.find arg "default" |> Option.map ~f:extract_string + in + match Func.arg_type_of_string arg_type ~is_nullable with + | Some Scalar when Option.is_some default_value && not is_nullable -> + if with_optional_scalar_args + then Some { Func.arg_name; arg_type = Scalar; default_value } + else None + | Some TensorOptions + when Option.is_some default_value && Set.mem no_tensor_options name -> + None + | Some arg_type -> + let arg_name = + match arg_name, arg_type with + | "self", Scalar -> "self_scalar" + | _, _ -> arg_name + in + Some { Func.arg_name; arg_type; default_value } + | None -> + if Option.is_some default_value then None else raise Not_a_simple_arg) + in + let args = + args ~with_optional_scalar_args:(Set.mem with_optional_scalar_args name) + in + Some [ { Func.name; operator_name; overload_name; args; returns; kind } ] + with + | Not_a_simple_arg -> None) + else None) let p out_channel s = Printf.ksprintf @@ -803,72 +796,71 @@ let print_inline out_channel s = let write_cpp funcs filename = Out_channel.with_file (filename ^ ".cpp.h") ~f:(fun out_cpp -> - Out_channel.with_file (filename ^ ".h") ~f:(fun out_h -> - let pc s = p out_cpp s in - let ph s = p out_h s in - pc "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!"; - pc ""; - ph "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!"; - ph ""; - Map.iteri funcs ~f:(fun ~key:exported_name ~data:func -> - let c_typed_args_list = Func.c_typed_args_list func in - match func.returns with - | `nothing -> - pc "void atg_%s(%s) {" exported_name c_typed_args_list; - pc " PROTECT("; - pc " %s;" (Func.c_call func); - pc " )"; - pc "}"; - pc ""; - ph "void atg_%s(%s);" exported_name c_typed_args_list - | `fixed ntensors -> - pc "void atg_%s(tensor *out__, %s) {" exported_name c_typed_args_list; - pc " PROTECT("; - pc " auto outputs__ = %s;" (Func.c_call func); - if ntensors = 1 - then pc " out__[0] = new torch::Tensor(outputs__);" - else - for i = 0 to ntensors - 1 do - pc " out__[%d] = new torch::Tensor(std::get<%d>(outputs__));" i i - done; - pc " )"; - pc "}"; - pc ""; - ph "void atg_%s(tensor *, %s);" exported_name c_typed_args_list - | `dynamic -> - pc "tensor *atg_%s(%s) {" exported_name c_typed_args_list; - pc " PROTECT("; - pc " auto outputs__ = %s;" (Func.c_call func); - (* the returned type is a C++ vector of tensors *) - pc " int sz = outputs__.size();"; - pc - " torch::Tensor **out__ = (torch::Tensor**)malloc((sz + 1) * \ - sizeof(torch::Tensor*));"; - pc " for (int i = 0; i < sz; ++i)"; - pc " out__[i] = new torch::Tensor(outputs__[i]);"; - pc " out__[sz] = nullptr;"; - pc " return out__;"; - pc " )"; - pc " return nullptr;"; - pc "}"; - pc ""; - ph "tensor *atg_%s(%s);" exported_name c_typed_args_list - | (`bool | `int64_t | `double) as returns -> - let c_type = - match returns with - | `bool -> "int" - | `int64_t -> "int64_t" - | `double -> "double" - in - pc "%s atg_%s(%s) {" c_type exported_name c_typed_args_list; - pc " PROTECT("; - pc " return %s;" (Func.c_call func); - pc " )"; - pc " return 0;"; - pc "}"; - pc ""; - ph "%s atg_%s(%s);" c_type exported_name c_typed_args_list))) - + Out_channel.with_file (filename ^ ".h") ~f:(fun out_h -> + let pc s = p out_cpp s in + let ph s = p out_h s in + pc "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!"; + pc ""; + ph "// THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND!"; + ph ""; + Map.iteri funcs ~f:(fun ~key:exported_name ~data:func -> + let c_typed_args_list = Func.c_typed_args_list func in + match func.returns with + | `nothing -> + pc "void atg_%s(%s) {" exported_name c_typed_args_list; + pc " PROTECT("; + pc " %s;" (Func.c_call func); + pc " )"; + pc "}"; + pc ""; + ph "void atg_%s(%s);" exported_name c_typed_args_list + | `fixed ntensors -> + pc "void atg_%s(tensor *out__, %s) {" exported_name c_typed_args_list; + pc " PROTECT("; + pc " auto outputs__ = %s;" (Func.c_call func); + if ntensors = 1 + then pc " out__[0] = new torch::Tensor(outputs__);" + else + for i = 0 to ntensors - 1 do + pc " out__[%d] = new torch::Tensor(std::get<%d>(outputs__));" i i + done; + pc " )"; + pc "}"; + pc ""; + ph "void atg_%s(tensor *, %s);" exported_name c_typed_args_list + | `dynamic -> + pc "tensor *atg_%s(%s) {" exported_name c_typed_args_list; + pc " PROTECT("; + pc " auto outputs__ = %s;" (Func.c_call func); + (* the returned type is a C++ vector of tensors *) + pc " int sz = outputs__.size();"; + pc + " torch::Tensor **out__ = (torch::Tensor**)malloc((sz + 1) * \ + sizeof(torch::Tensor*));"; + pc " for (int i = 0; i < sz; ++i)"; + pc " out__[i] = new torch::Tensor(outputs__[i]);"; + pc " out__[sz] = nullptr;"; + pc " return out__;"; + pc " )"; + pc " return nullptr;"; + pc "}"; + pc ""; + ph "tensor *atg_%s(%s);" exported_name c_typed_args_list + | (`bool | `int64_t | `double) as returns -> + let c_type = + match returns with + | `bool -> "int" + | `int64_t -> "int64_t" + | `double -> "double" + in + pc "%s atg_%s(%s) {" c_type exported_name c_typed_args_list; + pc " PROTECT("; + pc " return %s;" (Func.c_call func); + pc " )"; + pc " return 0;"; + pc "}"; + pc ""; + ph "%s atg_%s(%s);" c_type exported_name c_typed_args_list))) let write_wrapper funcs filename = Out_channel.with_file filename ~f:(fun out_ml -> @@ -1402,54 +1394,43 @@ let methods = ; c "to" [ ca "self" Tensor; ca "device" Device ] ] - let run ~yaml_filename ~cpp_filename ~ffi_filename ~must_wrapper_filename ~wrapper_filename = - let funcs = read_yaml yaml_filename in + let funcs = read_yaml yaml_filename |> List.concat in let funcs = methods @ funcs in - printf "Generating code for %d functions.\n%!" (List.length funcs) ; + printf "Generating code for %d functions.\n%!" (List.length funcs); (* Generate some unique names for overloaded functions. *) let funcs = - List.map funcs ~f:(fun func -> (String.lowercase func.operator_name, func)) + List.map funcs ~f:(fun func -> Func.operator_name func, func) |> Map.of_alist_multi (module String) |> Map.to_alist |> List.concat_map ~f:(fun (name, funcs) -> - match funcs with - | [] -> assert false - | [func] -> [(name, func)] - | funcs -> - let has_empty_overload = - List.exists funcs ~f:(fun (func : Func.t) -> - String.is_empty func.overload_name ) - in - List.sort funcs ~compare:(fun (f1 : Func.t) (f2 : Func.t) -> - match - Int.compare (String.length f1.name) - (String.length f2.name) - with - | 0 -> - Int.compare (List.length f1.args) (List.length f2.args) - | cmp -> cmp ) - |> List.mapi ~f:(fun index (func : Func.t) -> - let operator_name = - String.lowercase func.operator_name - in - let overload_name = - String.lowercase func.overload_name - in - let name = - if - String.is_empty overload_name - || (index = 0 && not has_empty_overload) - then operator_name - else if String.is_suffix operator_name ~suffix:"_" then - operator_name ^ overload_name ^ "_" - else operator_name ^ "_" ^ overload_name - in - (name, func) ) ) + match funcs with + | [] -> assert false + | [ func ] -> [ name, func ] + | funcs -> + let has_empty_overload = + List.exists funcs ~f:(fun (func : Func.t) -> + String.is_empty func.overload_name) + in + List.sort funcs ~compare:(fun (f1 : Func.t) (f2 : Func.t) -> + match Int.compare (String.length f1.name) (String.length f2.name) with + | 0 -> Int.compare (List.length f1.args) (List.length f2.args) + | cmp -> cmp) + |> List.mapi ~f:(fun index (func : Func.t) -> + let operator_name = Func.operator_name func in + let overload_name = String.lowercase func.overload_name in + let name = + if String.is_empty overload_name || (index = 0 && not has_empty_overload) + then operator_name + else if String.is_suffix operator_name ~suffix:"_" + then operator_name ^ overload_name ^ "_" + else operator_name ^ "_" ^ overload_name + in + name, func)) |> Map.of_alist_exn (module String) in - write_cpp funcs cpp_filename ; + write_cpp funcs cpp_filename; write_ffi funcs ffi_filename ; write_must_wrapper funcs must_wrapper_filename ; write_wrapper funcs wrapper_filename diff --git a/libtch/c-generated.go b/libtch/c-generated.go index d298afd..5b6a4fd 100644 --- a/libtch/c-generated.go +++ b/libtch/c-generated.go @@ -2666,10 +2666,10 @@ coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind)) coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice)) C.atg_arange_start(ptr, start , end , coptionsKind, coptionsDevice) } -func AtgArangeStartStep(ptr *Ctensor, start Cscalar, end Cscalar, optionsKind int32, optionsDevice int32){ +func AtgArangeStartStep(ptr *Ctensor, start Cscalar, end Cscalar, step Cscalar, optionsKind int32, optionsDevice int32){ coptionsKind := *(*C.int)(unsafe.Pointer(&optionsKind)) coptionsDevice := *(*C.int)(unsafe.Pointer(&optionsDevice)) - C.atg_arange_start_step(ptr, start , end , coptionsKind, coptionsDevice) + C.atg_arange_start_step(ptr, start , end , step , coptionsKind, coptionsDevice) } func AtgArccos(ptr *Ctensor, self Ctensor){ C.atg_arccos(ptr, self) @@ -3004,8 +3004,8 @@ cdivisorOverrideVal := *(*C.int64_t)(unsafe.Pointer(&divisorOverrideVal)) cdivisorOverrideNull := *(*C.uint8_t)(unsafe.Pointer(&divisorOverrideNull)) C.atg_avg_pool3d_out(ptr, out, self, ckernelSizeDataPtr, ckernelSizeLen, cstrideDataPtr, cstrideLen, cpaddingDataPtr, cpaddingLen, cceilMode, ccountIncludePad, cdivisorOverrideVal, cdivisorOverrideNull) } -func AtgBaddbmm(ptr *Ctensor, self Ctensor, batch1 Ctensor, batch2 Ctensor){ - C.atg_baddbmm(ptr, self, batch1, batch2) +func AtgBaddbmm(ptr *Ctensor, self Ctensor, batch1 Ctensor, batch2 Ctensor, beta Cscalar, alpha Cscalar){ + C.atg_baddbmm(ptr, self, batch1, batch2, beta , alpha ) } func AtgBaddbmm_(ptr *Ctensor, self Ctensor, batch1 Ctensor, batch2 Ctensor){ C.atg_baddbmm_(ptr, self, batch1, batch2) diff --git a/libtch/torch_api_generated.cpp.h b/libtch/torch_api_generated.cpp.h index a4de779..f79121f 100644 --- a/libtch/torch_api_generated.cpp.h +++ b/libtch/torch_api_generated.cpp.h @@ -2068,7 +2068,7 @@ void atg__slow_conv2d_backward(tensor *out__, tensor grad_input, tensor grad_wei void atg__sobol_engine_draw(tensor *out__, tensor quasi, int64_t n, tensor sobolstate, int64_t dimension, int64_t num_generated, int dtype) { PROTECT( - auto outputs__ = torch::_sobol_engine_draw(*quasi, n, *sobolstate, dimension, num_generated, at::ScalarType(dtype)); + auto outputs__ = torch::_sobol_engine_draw(*quasi, n, *sobolstate, dimension, num_generated, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(std::get<0>(outputs__)); out__[1] = new torch::Tensor(std::get<1>(outputs__)); ) @@ -2223,28 +2223,28 @@ void atg__sparse_csc_tensor_unsafe(tensor *out__, tensor ccol_indices, tensor ro void atg__sparse_csr_prod(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::_sparse_csr_prod(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::_sparse_csr_prod(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg__sparse_csr_prod_dim_dtype_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::_sparse_csr_prod_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::_sparse_csr_prod_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg__sparse_csr_sum(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::_sparse_csr_sum(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::_sparse_csr_sum(*self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg__sparse_csr_sum_dim_dtype_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::_sparse_csr_sum_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::_sparse_csr_sum_out(*out, *self, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -2279,7 +2279,7 @@ void atg__sparse_log_softmax_backward_data_out(tensor *out__, tensor out, tensor void atg__sparse_log_softmax_int(tensor *out__, tensor self, int64_t dim, int dtype) { PROTECT( - auto outputs__ = torch::_sparse_log_softmax(*self, dim, at::ScalarType(dtype)); + auto outputs__ = torch::_sparse_log_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -2336,7 +2336,7 @@ void atg__sparse_softmax_backward_data_out(tensor *out__, tensor out, tensor gra void atg__sparse_softmax_int(tensor *out__, tensor self, int64_t dim, int dtype) { PROTECT( - auto outputs__ = torch::_sparse_softmax(*self, dim, at::ScalarType(dtype)); + auto outputs__ = torch::_sparse_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -2629,14 +2629,14 @@ tensor *atg__to_cpu(tensor *tensors_data, int tensors_len) { void atg__to_dense(tensor *out__, tensor self, int dtype) { PROTECT( - auto outputs__ = self->_to_dense(at::ScalarType(dtype)); + auto outputs__ = self->_to_dense(dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg__to_dense_out(tensor *out__, tensor out, tensor self, int dtype) { PROTECT( - auto outputs__ = torch::_to_dense_out(*out, *self, at::ScalarType(dtype)); + auto outputs__ = torch::_to_dense_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -3640,9 +3640,9 @@ void atg_arange_start(tensor *out__, scalar start, scalar end, int options_kind, ) } -void atg_arange_start_step(tensor *out__, scalar start, scalar end, int options_kind, int options_device) { +void atg_arange_start_step(tensor *out__, scalar start, scalar end, scalar step, int options_kind, int options_device) { PROTECT( - auto outputs__ = torch::arange(*start, *end, at::device(device_of_int(options_device)).dtype(at::ScalarType(options_kind))); + auto outputs__ = torch::arange(*start, *end, *step, at::device(device_of_int(options_device)).dtype(at::ScalarType(options_kind))); out__[0] = new torch::Tensor(outputs__); ) } @@ -4120,9 +4120,9 @@ void atg_avg_pool3d_out(tensor *out__, tensor out, tensor self, int64_t *kernel_ ) } -void atg_baddbmm(tensor *out__, tensor self, tensor batch1, tensor batch2) { +void atg_baddbmm(tensor *out__, tensor self, tensor batch1, tensor batch2, scalar beta, scalar alpha) { PROTECT( - auto outputs__ = torch::baddbmm(*self, *batch1, *batch2); + auto outputs__ = torch::baddbmm(*self, *batch1, *batch2, *beta, *alpha); out__[0] = new torch::Tensor(outputs__); ) } @@ -5910,14 +5910,14 @@ void atg_cummin_out(tensor *out__, tensor values, tensor indices, tensor self, i void atg_cumprod(tensor *out__, tensor self, int64_t dim, int dtype) { PROTECT( - auto outputs__ = torch::cumprod(*self, dim, at::ScalarType(dtype)); + auto outputs__ = torch::cumprod(*self, dim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_cumprod_(tensor *out__, tensor self, int64_t dim, int dtype) { PROTECT( - auto outputs__ = self->cumprod_(dim, at::ScalarType(dtype)); + auto outputs__ = self->cumprod_(dim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -5931,28 +5931,28 @@ void atg_cumprod_backward(tensor *out__, tensor grad, tensor input, int64_t dim, void atg_cumprod_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) { PROTECT( - auto outputs__ = torch::cumprod_out(*out, *self, dim, at::ScalarType(dtype)); + auto outputs__ = torch::cumprod_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_cumsum(tensor *out__, tensor self, int64_t dim, int dtype) { PROTECT( - auto outputs__ = torch::cumsum(*self, dim, at::ScalarType(dtype)); + auto outputs__ = torch::cumsum(*self, dim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_cumsum_(tensor *out__, tensor self, int64_t dim, int dtype) { PROTECT( - auto outputs__ = self->cumsum_(dim, at::ScalarType(dtype)); + auto outputs__ = self->cumsum_(dim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_cumsum_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) { PROTECT( - auto outputs__ = torch::cumsum_out(*out, *self, dim, at::ScalarType(dtype)); + auto outputs__ = torch::cumsum_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -9912,28 +9912,28 @@ void atg_linalg_multi_dot_out(tensor *out__, tensor out, tensor *tensors_data, i void atg_linalg_norm(tensor *out__, tensor self, scalar ord, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::linalg_norm(*self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::linalg_norm(*self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_linalg_norm_ord_str(tensor *out__, tensor self, char* ord_ptr, int ord_len, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::linalg_norm(*self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::linalg_norm(*self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_linalg_norm_ord_str_out(tensor *out__, tensor out, tensor self, char* ord_ptr, int ord_len, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::linalg_norm_out(*out, *self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::linalg_norm_out(*out, *self, std::string(ord_ptr, ord_len), dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_linalg_norm_out(tensor *out__, tensor out, tensor self, scalar ord, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::linalg_norm_out(*out, *self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::linalg_norm_out(*out, *self, *ord, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -10314,14 +10314,14 @@ void atg_log_sigmoid_out(tensor *out__, tensor out, tensor self) { void atg_log_softmax(tensor *out__, tensor self, int64_t dim, int dtype) { PROTECT( - auto outputs__ = torch::log_softmax(*self, dim, at::ScalarType(dtype)); + auto outputs__ = torch::log_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_log_softmax_int_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) { PROTECT( - auto outputs__ = torch::log_softmax_out(*out, *self, dim, at::ScalarType(dtype)); + auto outputs__ = torch::log_softmax_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -10953,21 +10953,21 @@ void atg_maximum_out(tensor *out__, tensor out, tensor self, tensor other) { void atg_mean(tensor *out__, tensor self, int dtype) { PROTECT( - auto outputs__ = torch::mean(*self, at::ScalarType(dtype)); + auto outputs__ = torch::mean(*self, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_mean_dim(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::mean(*self, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::mean(*self, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_mean_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::mean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::mean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -11742,14 +11742,14 @@ void atg_nan_to_num_out(tensor *out__, tensor out, tensor self, double nan_v, ui void atg_nanmean(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::nanmean(*self, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::nanmean(*self, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_nanmean_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::nanmean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::nanmean_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -11814,14 +11814,14 @@ void atg_nanquantile_scalar_out(tensor *out__, tensor out, tensor self, double q void atg_nansum(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::nansum(*self, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::nansum(*self, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_nansum_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::nansum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::nansum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -11961,14 +11961,14 @@ void atg_native_norm_out(tensor *out__, tensor out, tensor self) { void atg_native_norm_scalaropt_dim_dtype(tensor *out__, tensor self, scalar p, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::native_norm(*self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::native_norm(*self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_native_norm_scalaropt_dim_dtype_out(tensor *out__, tensor out, tensor self, scalar p, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::native_norm_out(*out, *self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::native_norm_out(*out, *self, *p, torch::IntArrayRef(dim_data, dim_len), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -12702,28 +12702,28 @@ void atg_prelu(tensor *out__, tensor self, tensor weight) { void atg_prod(tensor *out__, tensor self, int dtype) { PROTECT( - auto outputs__ = torch::prod(*self, at::ScalarType(dtype)); + auto outputs__ = torch::prod(*self, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_prod_dim_int(tensor *out__, tensor self, int64_t dim, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::prod(*self, dim, (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::prod(*self, dim, (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_prod_int_out(tensor *out__, tensor out, tensor self, int64_t dim, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::prod_out(*out, *self, dim, (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::prod_out(*out, *self, dim, (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_prod_out(tensor *out__, tensor out, tensor self, int dtype) { PROTECT( - auto outputs__ = torch::prod_out(*out, *self, at::ScalarType(dtype)); + auto outputs__ = torch::prod_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -14593,14 +14593,14 @@ void atg_soft_margin_loss_out(tensor *out__, tensor out, tensor self, tensor tar void atg_softmax(tensor *out__, tensor self, int64_t dim, int dtype) { PROTECT( - auto outputs__ = torch::softmax(*self, dim, at::ScalarType(dtype)); + auto outputs__ = torch::softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_softmax_int_out(tensor *out__, tensor out, tensor self, int64_t dim, int dtype) { PROTECT( - auto outputs__ = torch::softmax_out(*out, *self, dim, at::ScalarType(dtype)); + auto outputs__ = torch::softmax_out(*out, *self, dim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -15528,7 +15528,7 @@ void atg_special_log_ndtr_out(tensor *out__, tensor out, tensor self) { void atg_special_log_softmax(tensor *out__, tensor self, int64_t dim, int dtype) { PROTECT( - auto outputs__ = torch::special_log_softmax(*self, dim, at::ScalarType(dtype)); + auto outputs__ = torch::special_log_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -15913,7 +15913,7 @@ void atg_special_sinc_out(tensor *out__, tensor out, tensor self) { void atg_special_softmax(tensor *out__, tensor self, int64_t dim, int dtype) { PROTECT( - auto outputs__ = torch::special_softmax(*self, dim, at::ScalarType(dtype)); + auto outputs__ = torch::special_softmax(*self, dim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -16437,28 +16437,28 @@ void atg_subtract_scalar_(tensor *out__, tensor self, scalar other) { void atg_sum(tensor *out__, tensor self, int dtype) { PROTECT( - auto outputs__ = torch::sum(*self, at::ScalarType(dtype)); + auto outputs__ = torch::sum(*self, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_sum_dim_intlist(tensor *out__, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::sum(*self, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::sum(*self, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_sum_intlist_out(tensor *out__, tensor out, tensor self, int64_t *dim_data, int dim_len, int keepdim, int dtype) { PROTECT( - auto outputs__ = torch::sum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, at::ScalarType(dtype)); + auto outputs__ = torch::sum_out(*out, *self, dim_data == nullptr ? c10::nullopt : c10::optional(torch::IntArrayRef(dim_data, dim_len)), (bool)keepdim, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } void atg_sum_out(tensor *out__, tensor out, tensor self, int dtype) { PROTECT( - auto outputs__ = torch::sum_out(*out, *self, at::ScalarType(dtype)); + auto outputs__ = torch::sum_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -16732,7 +16732,7 @@ void atg_to(tensor *out__, tensor self, int device) { void atg_to_dense(tensor *out__, tensor self, int dtype) { PROTECT( - auto outputs__ = self->to_dense(at::ScalarType(dtype)); + auto outputs__ = self->to_dense(dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -16767,7 +16767,7 @@ void atg_to_dtype_layout(tensor *out__, tensor self, int options_kind, int optio void atg_to_mkldnn(tensor *out__, tensor self, int dtype) { PROTECT( - auto outputs__ = self->to_mkldnn(at::ScalarType(dtype)); + auto outputs__ = self->to_mkldnn(dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } @@ -16781,7 +16781,7 @@ void atg_to_mkldnn_backward(tensor *out__, tensor grad, tensor input) { void atg_to_mkldnn_out(tensor *out__, tensor out, tensor self, int dtype) { PROTECT( - auto outputs__ = torch::to_mkldnn_out(*out, *self, at::ScalarType(dtype)); + auto outputs__ = torch::to_mkldnn_out(*out, *self, dtype < 0 ? c10::nullopt : c10::optional(at::ScalarType(dtype))); out__[0] = new torch::Tensor(outputs__); ) } diff --git a/libtch/torch_api_generated.h b/libtch/torch_api_generated.h index 5c92e14..92829c1 100644 --- a/libtch/torch_api_generated.h +++ b/libtch/torch_api_generated.h @@ -497,7 +497,7 @@ void atg_any_dim(tensor *, tensor self, int64_t dim, int keepdim); void atg_any_out(tensor *, tensor out, tensor self, int64_t dim, int keepdim); void atg_arange(tensor *, scalar end, int options_kind, int options_device); void atg_arange_start(tensor *, scalar start, scalar end, int options_kind, int options_device); -void atg_arange_start_step(tensor *, scalar start, scalar end, int options_kind, int options_device); +void atg_arange_start_step(tensor *, scalar start, scalar end, scalar step, int options_kind, int options_device); void atg_arccos(tensor *, tensor self); void atg_arccos_(tensor *, tensor self); void atg_arccos_out(tensor *, tensor out, tensor self); @@ -563,7 +563,7 @@ void atg_avg_pool3d(tensor *, tensor self, int64_t *kernel_size_data, int kernel void atg_avg_pool3d_backward(tensor *, tensor grad_output, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override_v, uint8_t divisor_override_null); void atg_avg_pool3d_backward_grad_input(tensor *, tensor grad_input, tensor grad_output, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override_v, uint8_t divisor_override_null); void atg_avg_pool3d_out(tensor *, tensor out, tensor self, int64_t *kernel_size_data, int kernel_size_len, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int ceil_mode, int count_include_pad, int64_t divisor_override_v, uint8_t divisor_override_null); -void atg_baddbmm(tensor *, tensor self, tensor batch1, tensor batch2); +void atg_baddbmm(tensor *, tensor self, tensor batch1, tensor batch2, scalar beta, scalar alpha); void atg_baddbmm_(tensor *, tensor self, tensor batch1, tensor batch2); void atg_baddbmm_out(tensor *, tensor out, tensor self, tensor batch1, tensor batch2); void atg_bartlett_window(tensor *, int64_t window_length, int options_kind, int options_device); diff --git a/ts/must-tensor-generated.go b/ts/must-tensor-generated.go index 0b1bed3..9120973 100644 --- a/ts/must-tensor-generated.go +++ b/ts/must-tensor-generated.go @@ -3961,9 +3961,9 @@ func MustArangeStart(start *Scalar, end *Scalar, optionsKind gotch.DType, option return retVal } -func MustArangeStartStep(start *Scalar, end *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor) { +func MustArangeStartStep(start *Scalar, end *Scalar, step *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor) { - retVal, err := ArangeStartStep(start, end, optionsKind, optionsDevice) + retVal, err := ArangeStartStep(start, end, step, optionsKind, optionsDevice) if err != nil { log.Fatal(err) } return retVal @@ -4465,9 +4465,9 @@ func(ts *Tensor) MustAvgPool3dOut(out *Tensor, kernelSize []int64, stride []int6 return retVal } -func(ts *Tensor) MustBaddbmm(batch1 *Tensor, batch2 *Tensor, del bool)(retVal *Tensor) { +func(ts *Tensor) MustBaddbmm(batch1 *Tensor, batch2 *Tensor, beta *Scalar, alpha *Scalar, del bool)(retVal *Tensor) { - retVal, err := ts.Baddbmm(batch1, batch2, del) + retVal, err := ts.Baddbmm(batch1, batch2, beta, alpha, del) if err != nil { log.Fatal(err) } return retVal diff --git a/ts/tensor-generated.go b/ts/tensor-generated.go index 1cf7c7c..d13aea0 100644 --- a/ts/tensor-generated.go +++ b/ts/tensor-generated.go @@ -9327,10 +9327,10 @@ func ArangeStart(start *Scalar, end *Scalar, optionsKind gotch.DType, optionsDev // func.returns = `fixed 1`: // -------------------------- -func ArangeStartStep(start *Scalar, end *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor, err error) { +func ArangeStartStep(start *Scalar, end *Scalar, step *Scalar, optionsKind gotch.DType, optionsDevice gotch.Device)(retVal *Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - lib.AtgArangeStartStep(ptr, start.cscalar, end.cscalar, optionsKind.CInt(), optionsDevice.CInt()) + lib.AtgArangeStartStep(ptr, start.cscalar, end.cscalar, step.cscalar, optionsKind.CInt(), optionsDevice.CInt()) if err = TorchErr(); err != nil { err = fmt.Errorf("ArangeStartStep() failed: %w", err) return retVal, err @@ -10585,11 +10585,11 @@ var cdivisorOverrideVal int64 = 0 // func.returns = `fixed 1`: // -------------------------- -func(ts *Tensor) Baddbmm(batch1 *Tensor, batch2 *Tensor, del bool)(retVal *Tensor, err error) { +func(ts *Tensor) Baddbmm(batch1 *Tensor, batch2 *Tensor, beta *Scalar, alpha *Scalar, del bool)(retVal *Tensor, err error) { if del { defer ts.MustDrop() } ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) - lib.AtgBaddbmm(ptr, ts.ctensor, batch1.ctensor, batch2.ctensor) + lib.AtgBaddbmm(ptr, ts.ctensor, batch1.ctensor, batch2.ctensor, beta.cscalar, alpha.cscalar) if err = TorchErr(); err != nil { err = fmt.Errorf("Baddbmm() failed: %w", err) return retVal, err diff --git a/vision/aug/function.go b/vision/aug/function.go index 974bce0..344c846 100644 --- a/vision/aug/function.go +++ b/vision/aug/function.go @@ -255,9 +255,9 @@ func rgb2Gray(x *ts.Tensor, outChanOpt ...int64) *ts.Tensor { } rgbTs := x.MustUnbind(-3, false) - r := &rgbTs[0] - g := &rgbTs[1] - b := &rgbTs[2] + r := rgbTs[0] + g := rgbTs[1] + b := rgbTs[2] // This implementation closely follows the TF one: // https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/ops/image_ops_impl.py#L2105-L2138 @@ -453,7 +453,7 @@ func hsv2RGB(x *ts.Tensor) *ts.Tensor { a3 := ts.MustStack([]*ts.Tensor{p, p, t, v, v, q}, -3) a4 := ts.MustStack([]*ts.Tensor{a1, a2, a3}, -4) - out := ts.MustEinsum("...ijk, ...xijk -> ...xjk", []*ts.Tensor{mask, a4}) + out := ts.MustEinsum("...ijk, ...xijk -> ...xjk", []*ts.Tensor{mask, a4}, []int64{0, 1}) // Delete intermediate tensors h.MustDrop() @@ -579,7 +579,7 @@ func crop(x *ts.Tensor, top, left, height, width int64) *ts.Tensor { x2 := x1T.Idx(wNar) x1T.MustDrop() out := x2.MustT(true) - chans[i] = *out + chans[i] = out } cropTs := ts.MustStack(chans, 0)