WIP: update to v1.9

This commit is contained in:
sugarme 2021-07-22 00:38:55 +10:00
parent 6a3bb63048
commit bc12fc8605
7 changed files with 146977 additions and 6186 deletions

View File

@ -1,8 +1,7 @@
(* Automatically generate the C++ -> C -> Go bindings.
This takes as input the Descriptions.yaml file that gets generated when
func (Func.c_go_args_list func) building PyTorch from source.
(* Automatically generated C++ -> C -> Go bindings.
Input: Declarations-VERSION.yaml artifact generated when building Pytorch from source.
Run with: dune exec gen/gen.exe
*)
*)
open Base
open Stdio
@ -29,7 +28,14 @@ let excluded_functions =
; "_cummin_helper"
; "_cummax_helper"
; "retain_grad"
; "_validate_sparse_coo_tensor_args" ]
; "_validate_sparse_coo_tensor_args"
; "_backward"
; "size"
; "stride"
; "_assert_async"
; "gradient"
; "linalg_vector_norm"
; "linalg_vector_norm_out" ]
let no_tensor_options =
Set.of_list
@ -85,9 +91,12 @@ module Func = struct
| DoubleOption
| Tensor
| TensorOption
(* Tensor.t option *)
| IntList
| TensorOptList
| TensorList
| TensorOptions
(* Tensor kind and device *)
| Scalar
| ScalarType
| Device
@ -99,8 +108,10 @@ module Func = struct
(* `Func` type *)
type t =
{ name: string
; args: arg list
; returns: [`fixed of int | `dynamic]
; operator_name: string
; overload_name: string
; args: arg list (* ; returns: [`fixed of int | `dynamic] *)
; returns: [`fixed of int | `dynamic | `bool | `int64_t | `double]
; (* number of tensors that are returned *)
kind: [`function_ | `method_] }
@ -109,14 +120,14 @@ module Func = struct
| "bool" -> Some Bool
| "int64_t" -> Some (if is_nullable then Int64Option else Int64)
| "double" -> Some (if is_nullable then DoubleOption else Double)
| "booltensor" | "indextensor" | "tensor" ->
Some (if is_nullable then TensorOption else Tensor)
| "tensoroptions" -> Some TensorOptions
| "intarrayref" | "intlist" -> Some IntList
| "tensorlist" -> Some TensorList
| "device" -> Some Device
| "scalar" -> Some Scalar
| "scalartype" -> Some ScalarType
| "at::tensor" -> Some (if is_nullable then TensorOption else Tensor)
| "at::tensoroptions" -> Some TensorOptions
| "at::intarrayref" -> Some IntList
| "const c10::list<c10::optional<at::tensor>> &" -> Some TensorOptList
| "at::tensorlist" -> Some TensorList
| "at::device" -> Some Device
| "const at::scalar &" | "at::scalar" -> Some Scalar
| "at::scalartype" -> Some ScalarType
| "std::string" -> Some String
| _ -> None
@ -125,7 +136,7 @@ module Func = struct
match arg_type with
| IntList ->
Printf.sprintf "int64_t *%s_data, int %s_len" arg_name arg_name
| TensorList ->
| TensorOptList | TensorList ->
Printf.sprintf "tensor *%s_data, int %s_len" arg_name arg_name
| TensorOptions ->
Printf.sprintf "int %s_kind, int %s_device" arg_name arg_name
@ -145,8 +156,8 @@ module Func = struct
| ScalarType -> "int"
| Device -> "int"
| Scalar -> "scalar"
| Int64Option | DoubleOption | String | IntList | TensorList
|TensorOptions ->
| Int64Option | DoubleOption | String | IntList | TensorOptList
|TensorList | TensorOptions ->
assert false
in
Printf.sprintf "%s %s" simple_type_cstring arg_name )
@ -164,6 +175,9 @@ module Func = struct
arg_name
| String ->
Printf.sprintf "std::string(%s_ptr, %s_len)" arg_name arg_name
| TensorOptList ->
Printf.sprintf "of_carray_tensor_opt(%s_data, %s_len)" arg_name
arg_name
| TensorList ->
Printf.sprintf "of_carray_tensor(%s_data, %s_len)" arg_name
arg_name
@ -196,14 +210,15 @@ module Func = struct
t.name () )
(*
* let replace_map =
* Map.of_alist_exn
* (module String)
* [ ("t", "tr")
* ; ("where", "where_")
* ; ("view", "view_")
* ; ("unsafe", "unsafe_") ]
* *)
let replace_map =
Map.of_alist_exn
(module String)
[ ("t", "tr")
; ("where", "where_")
; ("view", "view_")
; ("unsafe", "unsafe_")
; ("to_device", "to_device_") ]
*)
let is_method t =
List.exists t.args ~f:(fun arg ->
@ -245,6 +260,7 @@ module Func = struct
| Device -> single_param "int32"
| String -> single_param "string"
| IntList -> Printf.sprintf "%sData []int64, %sLen int" an an
| TensorOptList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an
| TensorList -> Printf.sprintf "%sData []Ctensor, %sLen int" an an
| Int64Option -> Printf.sprintf "%sVal int64, %sNull int" an an
| DoubleOption -> Printf.sprintf "%sVal float64, %sNull int" an an
@ -268,6 +284,7 @@ module Func = struct
| Device -> Printf.sprintf "c%s" an
| String -> Printf.sprintf "c%s, c%sLen" an an
| IntList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
| TensorOptList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
| TensorList -> Printf.sprintf "c%sDataPtr, c%sLen" an an
| Int64Option -> Printf.sprintf "c%sVal, c%sNull" an an
| DoubleOption -> Printf.sprintf "c%sVal, c%sNull" an an
@ -306,6 +323,12 @@ module Func = struct
c%sDataPtr := (*C.int64_t)(unsafe.Pointer(&%sData[0]))\n\
c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))"
an an an an
| TensorOptList ->
Printf.sprintf
"\n\
c%sDataPtr := (*Ctensor)(unsafe.Pointer(&%sData[0]))\n\
c%sLen := *(*C.int)(unsafe.Pointer(&%sLen))"
an an an an
| TensorList ->
Printf.sprintf
"\n\
@ -382,6 +405,7 @@ module Func = struct
| Tensor -> "*Tensor"
| TensorOption -> "*Tensor"
| IntList -> "[]int64"
| TensorOptList -> "[]Tensor"
| TensorList -> "[]Tensor"
| String -> "string"
(* TODO. Struct{Kind gotch.DType Device gotch.Device} *)
@ -435,6 +459,9 @@ module Func = struct
List.init v ~f:(fun i -> Printf.sprintf "retVal%d *Tensor" i)
|> String.concat ~sep:", " |> Printf.sprintf "%s"
| `dynamic -> "retVal []Tensor"
| `bool -> "retVal bool"
| `int64_t -> "retVal int64"
| `double -> "retVal float64"
in
if is_inplace t then
if fallible then Printf.sprintf "err error" else Printf.sprintf ""
@ -449,6 +476,9 @@ module Func = struct
List.init v ~f:(fun i -> Printf.sprintf "retVal%d" i)
|> String.concat ~sep:", " |> Printf.sprintf "%s"
| `dynamic -> "retVal"
| `bool -> "retVal"
| `int64_t -> "retVal"
| `double -> "retVal"
in
if is_inplace t then
if fallible then Printf.sprintf "err" else Printf.sprintf ""
@ -511,6 +541,11 @@ module Func = struct
\ c%sNull = 0\n\
\ }\n"
an an an an an an
| TensorOptList ->
Printf.sprintf
" var c%s []lib.Ctensor\n\
\ for _, t := range %s {c%s = append(c%s, t.ctensor)}\n"
an an an an
| TensorList ->
Printf.sprintf
" var c%s []lib.Ctensor\n\
@ -536,6 +571,8 @@ let read_yaml filename =
List.filter_map funcs ~f:(fun yaml ->
let map = extract_map yaml in
let name = Map.find_exn map "name" |> extract_string in
let operator_name = Map.find_exn map "operator_name" |> extract_string in
let overload_name = Map.find_exn map "overload_name" |> extract_string in
let deprecated = Map.find_exn map "deprecated" |> extract_bool in
let method_of =
Map.find_exn map "method_of"
@ -548,22 +585,26 @@ let read_yaml filename =
let return_type =
Map.find_exn returns "dynamic_type" |> extract_string
in
String.( = ) return_type "Tensor"
|| String.( = ) return_type "BoolTensor"
|| String.( = ) return_type "IndexTensor"
String.( = ) return_type "at::Tensor"
in
let returns = Map.find_exn map "returns" |> extract_list in
if List.for_all returns ~f:is_tensor then
Some (`fixed (List.length returns))
else
match returns with
| [returns] ->
| [returns] -> (
let return_type =
Map.find_exn (extract_map returns) "dynamic_type"
|> extract_string
in
if String.( = ) return_type "TensorList" then Some `dynamic
else None
match return_type with
| "bool" -> Some `bool
| "int64_t" -> Some `int64_t
| "double" -> Some `double
| "at::TensorList"
|"dynamic_type: const c10::List<c10::optional<Tensor>> &" ->
Some `dynamic
| _ -> None )
| [] | _ :: _ :: _ -> None
in
let kind =
@ -622,7 +663,13 @@ let read_yaml filename =
if Option.is_some default_value then None
else raise Not_a_simple_arg )
in
Some {Func.name; args; returns; kind}
Some
{ Func.name
; operator_name
; overload_name
; args
; returns
; kind }
with Not_a_simple_arg -> None )
else None )
@ -684,8 +731,23 @@ let write_cpp funcs filename =
pc " return nullptr;" ;
pc "}" ;
pc "" ;
ph "tensor *atg_%s(%s);" exported_name c_typed_args_list ) )
)
ph "tensor *atg_%s(%s);" exported_name c_typed_args_list
| (`bool | `int64_t | `double) as returns ->
let c_type =
match returns with
| `bool -> "int"
| `int64_t -> "int64_t"
| `double -> "double"
in
pc "%s atg_%s(%s) {" c_type exported_name c_typed_args_list ;
pc " PROTECT(" ;
pc " return %s;" (Func.c_call func) ;
pc " )" ;
pc " return 0;" ;
pc "}" ;
pc "" ;
ph "%s atg_%s(%s);" c_type exported_name c_typed_args_list )
) )
let write_wrapper funcs filename =
Out_channel.with_file filename ~f:(fun out_ml ->
@ -777,7 +839,8 @@ let write_wrapper funcs filename =
pm " }\n" ;
(* NOTE. if in_place method, no retVal return *)
if not (Func.is_inplace func) then
pm " retVal = &Tensor{ctensor: *ptr}\n" ;
pm " retVal = &Tensor{ctensor: *ptr}\n"
else pm " ts.ctensor = *ptr\n" ;
pm " \n" ;
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
pm "} \n"
@ -799,10 +862,62 @@ let write_wrapper funcs filename =
pm " }\n" ;
(* NOTE. if in_place method, no retVal return *)
if not (Func.is_inplace func) then
pm " retVal = &Tensor{ctensor: *ptr}\n" ;
pm " retVal = &Tensor{ctensor: *ptr}\n"
else pm " ts.ctensor = *ptr\n" ;
pm " \n" ;
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
pm "} \n"
| `bool ->
pm "\n" ;
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
else pm "func %s(" gofunc_name ;
pm "%s" go_args_list ;
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
if is_method && not is_inplace then
pm "if del { defer ts.MustDrop() }\n" ;
pm " \n" ;
pm " %s" (Func.go_binding_body func) ;
pm "retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
pm " if err = TorchErr(); err != nil {\n" ;
pm " return %s\n"
(Func.go_return_notype func ~fallible:true) ;
pm " }\n" ;
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
pm "} \n"
| `int64_t ->
pm "\n" ;
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
else pm "func %s(" gofunc_name ;
pm "%s" go_args_list ;
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
if is_method && not is_inplace then
pm "if del { defer ts.MustDrop() }\n" ;
pm " \n" ;
pm " %s" (Func.go_binding_body func) ;
pm "retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
pm " if err = TorchErr(); err != nil {\n" ;
pm " return %s\n"
(Func.go_return_notype func ~fallible:true) ;
pm " }\n" ;
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
pm "} \n"
| `double ->
pm "\n" ;
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
else pm "func %s(" gofunc_name ;
pm "%s" go_args_list ;
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
if is_method && not is_inplace then
pm "if del { defer ts.MustDrop() }\n" ;
pm " \n" ;
pm " %s" (Func.go_binding_body func) ;
pm "retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
pm " if err = TorchErr(); err != nil {\n" ;
pm " return %s\n"
(Func.go_return_notype func ~fallible:true) ;
pm " }\n" ;
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
pm "} \n"
| `fixed _ -> pm "" ) ;
(* TODO. implement for return multiple tensor - []Tensor *)
pm "// End of implementing Tensor ================================= \n"
@ -913,6 +1028,57 @@ let write_must_wrapper funcs filename =
pm " \n" ;
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
pm "} \n"
| `bool ->
pm "\n" ;
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
else pm "func Must%s(" gofunc_name ;
pm "%s" go_args_list ;
pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ;
pm " \n" ;
if is_method then
pm " retVal, err := ts.%s(%s)\n" gofunc_name
go_args_list_notype
else
pm " retVal, err := %s(%s)\n" gofunc_name
go_args_list_notype ;
pm " if err != nil { log.Fatal(err) }\n" ;
pm " \n" ;
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
pm "} \n"
| `int64_t ->
pm "\n" ;
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
else pm "func Must%s(" gofunc_name ;
pm "%s" go_args_list ;
pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ;
pm " \n" ;
if is_method then
pm " retVal, err := ts.%s(%s)\n" gofunc_name
go_args_list_notype
else
pm " retVal, err := %s(%s)\n" gofunc_name
go_args_list_notype ;
pm " if err != nil { log.Fatal(err) }\n" ;
pm " \n" ;
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
pm "} \n"
| `double ->
pm "\n" ;
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
else pm "func Must%s(" gofunc_name ;
pm "%s" go_args_list ;
pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ;
pm " \n" ;
if is_method then
pm " retVal, err := ts.%s(%s)\n" gofunc_name
go_args_list_notype
else
pm " retVal, err := %s(%s)\n" gofunc_name
go_args_list_notype ;
pm " if err != nil { log.Fatal(err) }\n" ;
pm " \n" ;
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
pm "} \n"
| `fixed _ -> pm "" ) ;
(* TODO. implement for return multiple tensor - []Tensor *)
pm "// End of implementing Tensor ================================= \n"
@ -969,19 +1135,29 @@ let write_ffi funcs filename =
in
match func.Func.returns with
| `fixed _ ->
pm "func Atg%s(ptr *Ctensor, %s){%s \nC.atg_%s(ptr, %s)\n}"
pm "func Atg%s(ptr *Ctensor, %s){%s \n\tC.atg_%s(ptr, %s)\n}"
ffifunc_name (Func.c_go_args_list func)
(Func.c_go_args_list_body func)
exported_name
(Func.c_go_args_list_notype func)
| `dynamic -> pm ""
| `bool -> pm ""
| `int64_t -> pm ""
| `double -> pm ""
(* TODO: need more implement here *)
(* pm "func Atg%s(%s)(retValPtr *Ctensor)" *)
(* (Func.go_name exported_name) *)
(* (Func.c_go_args_list func) *) ) )
let methods =
let c name args = {Func.name; args; returns= `fixed 1; kind= `method_} in
let c name args =
{ Func.name
; operator_name= name
; overload_name= ""
; args
; returns= `fixed 1
; kind= `method_ }
in
let ca arg_name arg_type = {Func.arg_name; arg_type; default_value= None} in
[ c "grad" [ca "self" Tensor]
; c "set_requires_grad" [ca "self" Tensor; ca "r" Bool]
@ -995,7 +1171,7 @@ let run ~yaml_filename ~cpp_filename ~ffi_filename ~must_wrapper_filename
printf "Generating code for %d functions.\n%!" (List.length funcs) ;
(* Generate some unique names for overloaded functions. *)
let funcs =
List.map funcs ~f:(fun func -> (String.lowercase func.name, func))
List.map funcs ~f:(fun func -> (String.lowercase func.operator_name, func))
|> Map.of_alist_multi (module String)
|> Map.to_alist
|> List.concat_map ~f:(fun (name, funcs) ->
@ -1003,11 +1179,35 @@ let run ~yaml_filename ~cpp_filename ~ffi_filename ~must_wrapper_filename
| [] -> assert false
| [func] -> [(name, func)]
| funcs ->
let has_empty_overload =
List.exists funcs ~f:(fun (func : Func.t) ->
String.is_empty func.overload_name )
in
List.sort funcs ~compare:(fun (f1 : Func.t) (f2 : Func.t) ->
Int.compare (List.length f1.args) (List.length f2.args) )
|> List.mapi ~f:(fun i func ->
( (if i = 0 then name else Printf.sprintf "%s%d" name i)
, func ) ) )
match
Int.compare (String.length f1.name)
(String.length f2.name)
with
| 0 ->
Int.compare (List.length f1.args) (List.length f2.args)
| cmp -> cmp )
|> List.mapi ~f:(fun index (func : Func.t) ->
let operator_name =
String.lowercase func.operator_name
in
let overload_name =
String.lowercase func.overload_name
in
let name =
if
String.is_empty overload_name
|| (index = 0 && not has_empty_overload)
then operator_name
else if String.is_suffix operator_name ~suffix:"_" then
operator_name ^ overload_name ^ "_"
else operator_name ^ "_" ^ overload_name
in
(name, func) ) )
|> Map.of_alist_exn (module String)
in
write_cpp funcs cpp_filename ;
@ -1016,7 +1216,7 @@ let run ~yaml_filename ~cpp_filename ~ffi_filename ~must_wrapper_filename
write_wrapper funcs wrapper_filename
let () =
run ~yaml_filename:"gen/pytorch/Declarations-v1.7.0.yaml"
run ~yaml_filename:"gen/pytorch/Declarations-v1.9.0.yaml"
~cpp_filename:"libtch/torch_api_generated"
~ffi_filename:"libtch/c-generated.go"
~must_wrapper_filename:"tensor/must-tensor-generated.go"

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff