WIP: update to v1.9
This commit is contained in:
parent
6a3bb63048
commit
bc12fc8605
290
gen/gen.ml
290
gen/gen.ml
|
@ -1,8 +1,7 @@
|
|||
(* Automatically generate the C++ -> C -> Go bindings.
|
||||
This takes as input the Descriptions.yaml file that gets generated when
|
||||
func (Func.c_go_args_list func) building PyTorch from source.
|
||||
(* Automatically generated C++ -> C -> Go bindings.
|
||||
Input: Declarations-VERSION.yaml artifact generated when building Pytorch from source.
|
||||
Run with: dune exec gen/gen.exe
|
||||
*)
|
||||
*)
|
||||
open Base
|
||||
open Stdio
|
||||
|
||||
|
@ -29,7 +28,14 @@ let excluded_functions =
|
|||
; "_cummin_helper"
|
||||
; "_cummax_helper"
|
||||
; "retain_grad"
|
||||
; "_validate_sparse_coo_tensor_args" ]
|
||||
; "_validate_sparse_coo_tensor_args"
|
||||
; "_backward"
|
||||
; "size"
|
||||
; "stride"
|
||||
; "_assert_async"
|
||||
; "gradient"
|
||||
; "linalg_vector_norm"
|
||||
; "linalg_vector_norm_out" ]
|
||||
|
||||
let no_tensor_options =
|
||||
Set.of_list
|
||||
|
@ -85,9 +91,12 @@ module Func = struct
|
|||
| DoubleOption
|
||||
| Tensor
|
||||
| TensorOption
|
||||
(* Tensor.t option *)
|
||||
| IntList
|
||||
| TensorOptList
|
||||
| TensorList
|
||||
| TensorOptions
|
||||
(* Tensor kind and device *)
|
||||
| Scalar
|
||||
| ScalarType
|
||||
| Device
|
||||
|
@ -99,8 +108,10 @@ module Func = struct
|
|||
(* `Func` type *)
|
||||
type t =
|
||||
{ name: string
|
||||
; args: arg list
|
||||
; returns: [`fixed of int | `dynamic]
|
||||
; operator_name: string
|
||||
; overload_name: string
|
||||
; args: arg list (* ; returns: [`fixed of int | `dynamic] *)
|
||||
; returns: [`fixed of int | `dynamic | `bool | `int64_t | `double]
|
||||
; (* number of tensors that are returned *)
|
||||
kind: [`function_ | `method_] }
|
||||
|
||||
|
@ -109,14 +120,14 @@ module Func = struct
|
|||
| "bool" -> Some Bool
|
||||
| "int64_t" -> Some (if is_nullable then Int64Option else Int64)
|
||||
| "double" -> Some (if is_nullable then DoubleOption else Double)
|
||||
| "booltensor" | "indextensor" | "tensor" ->
|
||||
Some (if is_nullable then TensorOption else Tensor)
|
||||
| "tensoroptions" -> Some TensorOptions
|
||||
| "intarrayref" | "intlist" -> Some IntList
|
||||
| "tensorlist" -> Some TensorList
|
||||
| "device" -> Some Device
|
||||
| "scalar" -> Some Scalar
|
||||
| "scalartype" -> Some ScalarType
|
||||
| "at::tensor" -> Some (if is_nullable then TensorOption else Tensor)
|
||||
| "at::tensoroptions" -> Some TensorOptions
|
||||
| "at::intarrayref" -> Some IntList
|
||||
| "const c10::list<c10::optional<at::tensor>> &" -> Some TensorOptList
|
||||
| "at::tensorlist" -> Some TensorList
|
||||
| "at::device" -> Some Device
|
||||
| "const at::scalar &" | "at::scalar" -> Some Scalar
|
||||
| "at::scalartype" -> Some ScalarType
|
||||
| "std::string" -> Some String
|
||||
| _ -> None
|
||||
|
||||
|
@ -125,7 +136,7 @@ module Func = struct
|
|||
match arg_type with
|
||||
| IntList ->
|
||||
Printf.sprintf "int64_t *%s_data, int %s_len" arg_name arg_name
|
||||
| TensorList ->
|
||||
| TensorOptList | TensorList ->
|
||||
Printf.sprintf "tensor *%s_data, int %s_len" arg_name arg_name
|
||||
| TensorOptions ->
|
||||
Printf.sprintf "int %s_kind, int %s_device" arg_name arg_name
|
||||
|
@ -145,8 +156,8 @@ module Func = struct
|
|||
| ScalarType -> "int"
|
||||
| Device -> "int"
|
||||
| Scalar -> "scalar"
|
||||
| Int64Option | DoubleOption | String | IntList | TensorList
|
||||
|TensorOptions ->
|
||||
| Int64Option | DoubleOption | String | IntList | TensorOptList
|
||||
|TensorList | TensorOptions ->
|
||||
assert false
|
||||
in
|
||||
Printf.sprintf "%s %s" simple_type_cstring arg_name )
|
||||
|
@ -164,6 +175,9 @@ module Func = struct
|
|||
arg_name
|
||||
| String ->
|
||||
Printf.sprintf "std::string(%s_ptr, %s_len)" arg_name arg_name
|
||||
| TensorOptList ->
|
||||
Printf.sprintf "of_carray_tensor_opt(%s_data, %s_len)" arg_name
|
||||
arg_name
|
||||
| TensorList ->
|
||||
Printf.sprintf "of_carray_tensor(%s_data, %s_len)" arg_name
|
||||
arg_name
|
||||
|
@ -196,14 +210,15 @@ module Func = struct
|
|||
t.name () )
|
||||
|
||||
(*
|
||||
* let replace_map =
|
||||
* Map.of_alist_exn
|
||||
* (module String)
|
||||
* [ ("t", "tr")
|
||||
* ; ("where", "where_")
|
||||
* ; ("view", "view_")
|
||||
* ; ("unsafe", "unsafe_") ]
|
||||
* *)
|
||||
let replace_map =
|
||||
Map.of_alist_exn
|
||||
(module String)
|
||||
[ ("t", "tr")
|
||||
; ("where", "where_")
|
||||
; ("view", "view_")
|
||||
; ("unsafe", "unsafe_")
|
||||
; ("to_device", "to_device_") ]
|
||||
*)
|
||||
|
||||
let is_method t =
|
||||
List.exists t.args ~f:(fun arg ->
|
||||
|
@ -245,6 +260,7 @@ module Func = struct
|
|||
| Device -> single_param "int32"
|
||||
| String -> single_param "string"
|
||||
| IntList -> Printf.sprintf "%sData []int64, %sLen int" an an
|
||||
| TensorOptList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an
|
||||
| TensorList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an
|
||||
| Int64Option -> Printf.sprintf "%sVal int64, %sNull int" an an
|
||||
| DoubleOption -> Printf.sprintf "%sVal float64, %sNull int" an an
|
||||
|
@ -268,6 +284,7 @@ module Func = struct
|
|||
| Device -> Printf.sprintf "c%s" an
|
||||
| String -> Printf.sprintf "c%s, c%sLen" an an
|
||||
| IntList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
||||
| TensorOptList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
||||
| TensorList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
|
||||
| Int64Option -> Printf.sprintf "c%sVal, c%sNull" an an
|
||||
| DoubleOption -> Printf.sprintf "c%sVal, c%sNull" an an
|
||||
|
@ -306,6 +323,12 @@ module Func = struct
|
|||
c%sDataPtr := (*C.int64_t)(unsafe.Pointer(&%sData[0]))\n\
|
||||
c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))"
|
||||
an an an an
|
||||
| TensorOptList ->
|
||||
Printf.sprintf
|
||||
"\n\
|
||||
c%sDataPtr := (*Ctensor)(unsafe.Pointer(&%sData[0]))\n\
|
||||
c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))"
|
||||
an an an an
|
||||
| TensorList ->
|
||||
Printf.sprintf
|
||||
"\n\
|
||||
|
@ -382,6 +405,7 @@ module Func = struct
|
|||
| Tensor -> "*Tensor"
|
||||
| TensorOption -> "*Tensor"
|
||||
| IntList -> "[]int64"
|
||||
| TensorOptList -> "[]Tensor"
|
||||
| TensorList -> "[]Tensor"
|
||||
| String -> "string"
|
||||
(* TODO. Struct{Kind gotch.DType Device gotch.Device} *)
|
||||
|
@ -435,6 +459,9 @@ module Func = struct
|
|||
List.init v ~f:(fun i -> Printf.sprintf "retVal%d *Tensor" i)
|
||||
|> String.concat ~sep:", " |> Printf.sprintf "%s"
|
||||
| `dynamic -> "retVal []Tensor"
|
||||
| `bool -> "retVal bool"
|
||||
| `int64_t -> "retVal int64"
|
||||
| `double -> "retVal float64"
|
||||
in
|
||||
if is_inplace t then
|
||||
if fallible then Printf.sprintf "err error" else Printf.sprintf ""
|
||||
|
@ -449,6 +476,9 @@ module Func = struct
|
|||
List.init v ~f:(fun i -> Printf.sprintf "retVal%d" i)
|
||||
|> String.concat ~sep:", " |> Printf.sprintf "%s"
|
||||
| `dynamic -> "retVal"
|
||||
| `bool -> "retVal"
|
||||
| `int64_t -> "retVal"
|
||||
| `double -> "retVal"
|
||||
in
|
||||
if is_inplace t then
|
||||
if fallible then Printf.sprintf "err" else Printf.sprintf ""
|
||||
|
@ -511,6 +541,11 @@ module Func = struct
|
|||
\ c%sNull = 0\n\
|
||||
\ }\n"
|
||||
an an an an an an
|
||||
| TensorOptList ->
|
||||
Printf.sprintf
|
||||
" var c%s []lib.Ctensor\n\
|
||||
\ for _, t := range %s {c%s = append(c%s, t.ctensor)}\n"
|
||||
an an an an
|
||||
| TensorList ->
|
||||
Printf.sprintf
|
||||
" var c%s []lib.Ctensor\n\
|
||||
|
@ -536,6 +571,8 @@ let read_yaml filename =
|
|||
List.filter_map funcs ~f:(fun yaml ->
|
||||
let map = extract_map yaml in
|
||||
let name = Map.find_exn map "name" |> extract_string in
|
||||
let operator_name = Map.find_exn map "operator_name" |> extract_string in
|
||||
let overload_name = Map.find_exn map "overload_name" |> extract_string in
|
||||
let deprecated = Map.find_exn map "deprecated" |> extract_bool in
|
||||
let method_of =
|
||||
Map.find_exn map "method_of"
|
||||
|
@ -548,22 +585,26 @@ let read_yaml filename =
|
|||
let return_type =
|
||||
Map.find_exn returns "dynamic_type" |> extract_string
|
||||
in
|
||||
String.( = ) return_type "Tensor"
|
||||
|| String.( = ) return_type "BoolTensor"
|
||||
|| String.( = ) return_type "IndexTensor"
|
||||
String.( = ) return_type "at::Tensor"
|
||||
in
|
||||
let returns = Map.find_exn map "returns" |> extract_list in
|
||||
if List.for_all returns ~f:is_tensor then
|
||||
Some (`fixed (List.length returns))
|
||||
else
|
||||
match returns with
|
||||
| [returns] ->
|
||||
| [returns] -> (
|
||||
let return_type =
|
||||
Map.find_exn (extract_map returns) "dynamic_type"
|
||||
|> extract_string
|
||||
in
|
||||
if String.( = ) return_type "TensorList" then Some `dynamic
|
||||
else None
|
||||
match return_type with
|
||||
| "bool" -> Some `bool
|
||||
| "int64_t" -> Some `int64_t
|
||||
| "double" -> Some `double
|
||||
| "at::TensorList"
|
||||
|"dynamic_type: const c10::List<c10::optional<Tensor>> &" ->
|
||||
Some `dynamic
|
||||
| _ -> None )
|
||||
| [] | _ :: _ :: _ -> None
|
||||
in
|
||||
let kind =
|
||||
|
@ -622,7 +663,13 @@ let read_yaml filename =
|
|||
if Option.is_some default_value then None
|
||||
else raise Not_a_simple_arg )
|
||||
in
|
||||
Some {Func.name; args; returns; kind}
|
||||
Some
|
||||
{ Func.name
|
||||
; operator_name
|
||||
; overload_name
|
||||
; args
|
||||
; returns
|
||||
; kind }
|
||||
with Not_a_simple_arg -> None )
|
||||
else None )
|
||||
|
||||
|
@ -684,8 +731,23 @@ let write_cpp funcs filename =
|
|||
pc " return nullptr;" ;
|
||||
pc "}" ;
|
||||
pc "" ;
|
||||
ph "tensor *atg_%s(%s);" exported_name c_typed_args_list ) )
|
||||
)
|
||||
ph "tensor *atg_%s(%s);" exported_name c_typed_args_list
|
||||
| (`bool | `int64_t | `double) as returns ->
|
||||
let c_type =
|
||||
match returns with
|
||||
| `bool -> "int"
|
||||
| `int64_t -> "int64_t"
|
||||
| `double -> "double"
|
||||
in
|
||||
pc "%s atg_%s(%s) {" c_type exported_name c_typed_args_list ;
|
||||
pc " PROTECT(" ;
|
||||
pc " return %s;" (Func.c_call func) ;
|
||||
pc " )" ;
|
||||
pc " return 0;" ;
|
||||
pc "}" ;
|
||||
pc "" ;
|
||||
ph "%s atg_%s(%s);" c_type exported_name c_typed_args_list )
|
||||
) )
|
||||
|
||||
let write_wrapper funcs filename =
|
||||
Out_channel.with_file filename ~f:(fun out_ml ->
|
||||
|
@ -777,7 +839,8 @@ let write_wrapper funcs filename =
|
|||
pm " }\n" ;
|
||||
(* NOTE. if in_place method, no retVal return *)
|
||||
if not (Func.is_inplace func) then
|
||||
pm " retVal = &Tensor{ctensor: *ptr}\n" ;
|
||||
pm " retVal = &Tensor{ctensor: *ptr}\n"
|
||||
else pm " ts.ctensor = *ptr\n" ;
|
||||
pm " \n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||
pm "} \n"
|
||||
|
@ -799,10 +862,62 @@ let write_wrapper funcs filename =
|
|||
pm " }\n" ;
|
||||
(* NOTE. if in_place method, no retVal return *)
|
||||
if not (Func.is_inplace func) then
|
||||
pm " retVal = &Tensor{ctensor: *ptr}\n" ;
|
||||
pm " retVal = &Tensor{ctensor: *ptr}\n"
|
||||
else pm " ts.ctensor = *ptr\n" ;
|
||||
pm " \n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||
pm "} \n"
|
||||
| `bool ->
|
||||
pm "\n" ;
|
||||
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
|
||||
else pm "func %s(" gofunc_name ;
|
||||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
||||
if is_method && not is_inplace then
|
||||
pm "if del { defer ts.MustDrop() }\n" ;
|
||||
pm " \n" ;
|
||||
pm " %s" (Func.go_binding_body func) ;
|
||||
pm "retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " if err = TorchErr(); err != nil {\n" ;
|
||||
pm " return %s\n"
|
||||
(Func.go_return_notype func ~fallible:true) ;
|
||||
pm " }\n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||
pm "} \n"
|
||||
| `int64_t ->
|
||||
pm "\n" ;
|
||||
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
|
||||
else pm "func %s(" gofunc_name ;
|
||||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
||||
if is_method && not is_inplace then
|
||||
pm "if del { defer ts.MustDrop() }\n" ;
|
||||
pm " \n" ;
|
||||
pm " %s" (Func.go_binding_body func) ;
|
||||
pm "retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " if err = TorchErr(); err != nil {\n" ;
|
||||
pm " return %s\n"
|
||||
(Func.go_return_notype func ~fallible:true) ;
|
||||
pm " }\n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||
pm "} \n"
|
||||
| `double ->
|
||||
pm "\n" ;
|
||||
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
|
||||
else pm "func %s(" gofunc_name ;
|
||||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
||||
if is_method && not is_inplace then
|
||||
pm "if del { defer ts.MustDrop() }\n" ;
|
||||
pm " \n" ;
|
||||
pm " %s" (Func.go_binding_body func) ;
|
||||
pm "retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " if err = TorchErr(); err != nil {\n" ;
|
||||
pm " return %s\n"
|
||||
(Func.go_return_notype func ~fallible:true) ;
|
||||
pm " }\n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||
pm "} \n"
|
||||
| `fixed _ -> pm "" ) ;
|
||||
(* TODO. implement for return multiple tensor - []Tensor *)
|
||||
pm "// End of implementing Tensor ================================= \n"
|
||||
|
@ -913,6 +1028,57 @@ let write_must_wrapper funcs filename =
|
|||
pm " \n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
|
||||
pm "} \n"
|
||||
| `bool ->
|
||||
pm "\n" ;
|
||||
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
|
||||
else pm "func Must%s(" gofunc_name ;
|
||||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ;
|
||||
pm " \n" ;
|
||||
if is_method then
|
||||
pm " retVal, err := ts.%s(%s)\n" gofunc_name
|
||||
go_args_list_notype
|
||||
else
|
||||
pm " retVal, err := %s(%s)\n" gofunc_name
|
||||
go_args_list_notype ;
|
||||
pm " if err != nil { log.Fatal(err) }\n" ;
|
||||
pm " \n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
|
||||
pm "} \n"
|
||||
| `int64_t ->
|
||||
pm "\n" ;
|
||||
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
|
||||
else pm "func Must%s(" gofunc_name ;
|
||||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ;
|
||||
pm " \n" ;
|
||||
if is_method then
|
||||
pm " retVal, err := ts.%s(%s)\n" gofunc_name
|
||||
go_args_list_notype
|
||||
else
|
||||
pm " retVal, err := %s(%s)\n" gofunc_name
|
||||
go_args_list_notype ;
|
||||
pm " if err != nil { log.Fatal(err) }\n" ;
|
||||
pm " \n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
|
||||
pm "} \n"
|
||||
| `double ->
|
||||
pm "\n" ;
|
||||
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
|
||||
else pm "func Must%s(" gofunc_name ;
|
||||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ;
|
||||
pm " \n" ;
|
||||
if is_method then
|
||||
pm " retVal, err := ts.%s(%s)\n" gofunc_name
|
||||
go_args_list_notype
|
||||
else
|
||||
pm " retVal, err := %s(%s)\n" gofunc_name
|
||||
go_args_list_notype ;
|
||||
pm " if err != nil { log.Fatal(err) }\n" ;
|
||||
pm " \n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
|
||||
pm "} \n"
|
||||
| `fixed _ -> pm "" ) ;
|
||||
(* TODO. implement for return multiple tensor - []Tensor *)
|
||||
pm "// End of implementing Tensor ================================= \n"
|
||||
|
@ -969,19 +1135,29 @@ let write_ffi funcs filename =
|
|||
in
|
||||
match func.Func.returns with
|
||||
| `fixed _ ->
|
||||
pm "func Atg%s(ptr *Ctensor, %s){%s \nC.atg_%s(ptr, %s)\n}"
|
||||
pm "func Atg%s(ptr *Ctensor, %s){%s \n\tC.atg_%s(ptr, %s)\n}"
|
||||
ffifunc_name (Func.c_go_args_list func)
|
||||
(Func.c_go_args_list_body func)
|
||||
exported_name
|
||||
(Func.c_go_args_list_notype func)
|
||||
| `dynamic -> pm ""
|
||||
| `bool -> pm ""
|
||||
| `int64_t -> pm ""
|
||||
| `double -> pm ""
|
||||
(* TODO: need more implement here *)
|
||||
(* pm "func Atg%s(%s)(retValPtr *Ctensor)" *)
|
||||
(* (Func.go_name exported_name) *)
|
||||
(* (Func.c_go_args_list func) *) ) )
|
||||
|
||||
let methods =
|
||||
let c name args = {Func.name; args; returns= `fixed 1; kind= `method_} in
|
||||
let c name args =
|
||||
{ Func.name
|
||||
; operator_name= name
|
||||
; overload_name= ""
|
||||
; args
|
||||
; returns= `fixed 1
|
||||
; kind= `method_ }
|
||||
in
|
||||
let ca arg_name arg_type = {Func.arg_name; arg_type; default_value= None} in
|
||||
[ c "grad" [ca "self" Tensor]
|
||||
; c "set_requires_grad" [ca "self" Tensor; ca "r" Bool]
|
||||
|
@ -995,7 +1171,7 @@ let run ~yaml_filename ~cpp_filename ~ffi_filename ~must_wrapper_filename
|
|||
printf "Generating code for %d functions.\n%!" (List.length funcs) ;
|
||||
(* Generate some unique names for overloaded functions. *)
|
||||
let funcs =
|
||||
List.map funcs ~f:(fun func -> (String.lowercase func.name, func))
|
||||
List.map funcs ~f:(fun func -> (String.lowercase func.operator_name, func))
|
||||
|> Map.of_alist_multi (module String)
|
||||
|> Map.to_alist
|
||||
|> List.concat_map ~f:(fun (name, funcs) ->
|
||||
|
@ -1003,11 +1179,35 @@ let run ~yaml_filename ~cpp_filename ~ffi_filename ~must_wrapper_filename
|
|||
| [] -> assert false
|
||||
| [func] -> [(name, func)]
|
||||
| funcs ->
|
||||
let has_empty_overload =
|
||||
List.exists funcs ~f:(fun (func : Func.t) ->
|
||||
String.is_empty func.overload_name )
|
||||
in
|
||||
List.sort funcs ~compare:(fun (f1 : Func.t) (f2 : Func.t) ->
|
||||
Int.compare (List.length f1.args) (List.length f2.args) )
|
||||
|> List.mapi ~f:(fun i func ->
|
||||
( (if i = 0 then name else Printf.sprintf "%s%d" name i)
|
||||
, func ) ) )
|
||||
match
|
||||
Int.compare (String.length f1.name)
|
||||
(String.length f2.name)
|
||||
with
|
||||
| 0 ->
|
||||
Int.compare (List.length f1.args) (List.length f2.args)
|
||||
| cmp -> cmp )
|
||||
|> List.mapi ~f:(fun index (func : Func.t) ->
|
||||
let operator_name =
|
||||
String.lowercase func.operator_name
|
||||
in
|
||||
let overload_name =
|
||||
String.lowercase func.overload_name
|
||||
in
|
||||
let name =
|
||||
if
|
||||
String.is_empty overload_name
|
||||
|| (index = 0 && not has_empty_overload)
|
||||
then operator_name
|
||||
else if String.is_suffix operator_name ~suffix:"_" then
|
||||
operator_name ^ overload_name ^ "_"
|
||||
else operator_name ^ "_" ^ overload_name
|
||||
in
|
||||
(name, func) ) )
|
||||
|> Map.of_alist_exn (module String)
|
||||
in
|
||||
write_cpp funcs cpp_filename ;
|
||||
|
@ -1016,7 +1216,7 @@ let run ~yaml_filename ~cpp_filename ~ffi_filename ~must_wrapper_filename
|
|||
write_wrapper funcs wrapper_filename
|
||||
|
||||
let () =
|
||||
run ~yaml_filename:"gen/pytorch/Declarations-v1.7.0.yaml"
|
||||
run ~yaml_filename:"gen/pytorch/Declarations-v1.9.0.yaml"
|
||||
~cpp_filename:"libtch/torch_api_generated"
|
||||
~ffi_filename:"libtch/c-generated.go"
|
||||
~must_wrapper_filename:"tensor/must-tensor-generated.go"
|
||||
|
|
129672
gen/pytorch/Declarations-v1.9.0.yaml
Normal file
129672
gen/pytorch/Declarations-v1.9.0.yaml
Normal file
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user