WIP(gen/gen.ml): c_generated
This commit is contained in:
parent
df1c0b34ff
commit
12f5eaa9d7
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -15,6 +15,7 @@
|
|||
target/
|
||||
_build/
|
||||
data/
|
||||
tmp/
|
||||
gen/.merlin
|
||||
**/*.rs.bk
|
||||
Cargo.lock
|
||||
|
|
|
@ -34,4 +34,6 @@ func main() {
|
|||
fmt.Printf("ys shape: %v\n", ys.Size())
|
||||
|
||||
xs.Eq1(*ys)
|
||||
|
||||
// xs.Matmul(*ys)
|
||||
}
|
||||
|
|
|
@ -55,4 +55,25 @@ func main() {
|
|||
|
||||
fmt.Printf("DType: %v\n", ts.DType())
|
||||
|
||||
dx := [][]int32{
|
||||
{1, 1},
|
||||
{1, 1},
|
||||
}
|
||||
|
||||
dy := [][]int32{
|
||||
{1, 2, 3},
|
||||
{1, 1, 1},
|
||||
}
|
||||
|
||||
xs, err := wrapper.NewTensorFromData(dx, []int64{2, 2})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
ys, err := wrapper.NewTensorFromData(dy, []int64{2, 3})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
xs.Matmul(*ys)
|
||||
|
||||
}
|
||||
|
|
82
gen/gen.ml
82
gen/gen.ml
|
@ -1,6 +1,6 @@
|
|||
(* Automatically generate the C++ -> C -> rust bindings.
|
||||
This takes as input the Descriptions.yaml file that gets generated when
|
||||
building PyTorch from source.
|
||||
(Func.c_go_args_list func) building PyTorch from source.
|
||||
|
||||
Run with: dune exec gen/gen.exe
|
||||
*)
|
||||
|
@ -189,29 +189,50 @@ module Func = struct
|
|||
let go_name name =
|
||||
let name =
|
||||
Map.find replace_map name |> Option.value ~default:name
|
||||
|> String.lowercase
|
||||
|> String.substr_replace_all ~pattern:"__" ~with_:"_"
|
||||
|> String.capitalize
|
||||
|> String.substr_replace_all ~pattern:"__" ~with_:""
|
||||
in
|
||||
if String.is_prefix name ~prefix:"_" then "internal" ^ name else name
|
||||
if String.is_prefix name ~prefix:"_" then
|
||||
"Internal" ^ (name |> String.capitalize)
|
||||
else name |> String.capitalize
|
||||
|
||||
let c_go_args_list t =
|
||||
List.map t.args ~f:(fun arg ->
|
||||
let an = arg.arg_name in
|
||||
let single_param = Printf.sprintf "%s_: %s" an in
|
||||
let single_param = Printf.sprintf "%s %s" an in
|
||||
match arg.arg_type with
|
||||
| Bool -> single_param "c_int"
|
||||
| Int64 -> single_param "int64"
|
||||
| Double -> single_param "float64"
|
||||
| Tensor -> single_param "*C_tensor"
|
||||
| TensorOption -> single_param "*C_tensor"
|
||||
| Scalar -> single_param "*C_scalar"
|
||||
| ScalarType -> single_param "c_int"
|
||||
| Device -> single_param "c_int"
|
||||
| String -> Printf.sprintf "%s_ptr int, %s_len c_int" an an
|
||||
| IntList -> Printf.sprintf "%s_data int64, %s_len c_int" an an
|
||||
| TensorList -> Printf.sprintf "%s_data *C_tensor, %s_len c_int" an an
|
||||
| Bool -> single_param "C.int"
|
||||
| Int64 -> single_param "C.long"
|
||||
| Double -> single_param "C.double"
|
||||
| Tensor -> single_param "Ctensor"
|
||||
| TensorOption -> single_param "Ctensor"
|
||||
| Scalar -> single_param "Cscalar"
|
||||
| ScalarType -> single_param "C.int"
|
||||
| Device -> single_param "C.int"
|
||||
| String -> Printf.sprintf "%s_ptr C.int, %s_len C.int" an an
|
||||
| IntList -> Printf.sprintf "%s_data C.long, %s_len C.int" an an
|
||||
| TensorList -> Printf.sprintf "%s_data Ctensor, %s_len C.int" an an
|
||||
| TensorOptions ->
|
||||
Printf.sprintf "%s_kind c_int, %s_device c_int" an an )
|
||||
Printf.sprintf "%s_kind C.int, %s_device C.int" an an )
|
||||
|> String.concat ~sep:", "
|
||||
|
||||
let c_go_args_list_notype t =
|
||||
List.map t.args ~f:(fun arg ->
|
||||
let an = arg.arg_name in
|
||||
let single_param = Printf.sprintf "%s %s" an in
|
||||
match arg.arg_type with
|
||||
| Bool -> single_param ""
|
||||
| Int64 -> single_param ""
|
||||
| Double -> single_param ""
|
||||
| Tensor -> single_param ""
|
||||
| TensorOption -> single_param ""
|
||||
| Scalar -> single_param ""
|
||||
| ScalarType -> single_param ""
|
||||
| Device -> single_param ""
|
||||
| String -> Printf.sprintf "%s_ptr, %s_len" an an
|
||||
| IntList -> Printf.sprintf "%s_data, %s_len" an an
|
||||
| TensorList -> Printf.sprintf "%s_data, %s_len" an an
|
||||
| TensorOptions -> Printf.sprintf "%s_kind, %s_device" an an )
|
||||
|> String.concat ~sep:", "
|
||||
|
||||
let self_name = "self"
|
||||
|
@ -561,20 +582,22 @@ let write_ffi funcs filename =
|
|||
Out_channel.with_file filename ~f:(fun out_ml ->
|
||||
let pm s = p out_ml s in
|
||||
pm "/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND! */" ;
|
||||
pm "#[allow(clippy::all)]" ;
|
||||
pm "use crate::{C_scalar, C_tensor};" ;
|
||||
pm "use libc::c_int;" ;
|
||||
pm "package libtch" ;
|
||||
pm "" ;
|
||||
pm "// #include \"stdbool.h\" " ;
|
||||
pm "// #include \"torch_api.h\" " ;
|
||||
pm "import \"C\"" ;
|
||||
pm "" ;
|
||||
pm "extern \"C\" {" ;
|
||||
Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
|
||||
match func.Func.returns with
|
||||
| `fixed _ ->
|
||||
pm " func atg_%s(out__: *C_tensor, %s);" exported_name
|
||||
(Func.c_go_args_list func)
|
||||
pm "func Atg_%s(ptr *Ctensor, %s){C.atg_%s(ptr, %s)}"
|
||||
(Func.go_name exported_name)
|
||||
(Func.c_go_args_list func) exported_name
|
||||
(Func.c_go_args_list_notype func)
|
||||
| `dynamic ->
|
||||
pm " func atg_%s(%s) -> *C_tensor;" exported_name
|
||||
(Func.c_go_args_list func) ) ;
|
||||
pm "}" )
|
||||
pm "func Atg_%s(%s)(*Ctensor)" exported_name
|
||||
(Func.c_go_args_list func) ) )
|
||||
|
||||
let methods =
|
||||
let c name args = {Func.name; args; returns= `fixed 1; kind= `method_} in
|
||||
|
@ -613,7 +636,6 @@ let run ~yaml_filename ~cpp_filename ~ffi_filename ~wrapper_filename
|
|||
|
||||
let () =
|
||||
run ~yaml_filename:"third_party/pytorch/Declarations-v1.5.0.yaml"
|
||||
~cpp_filename:"libtch/torch_api_generated"
|
||||
~ffi_filename:"libtch/c_generated.go"
|
||||
~wrapper_filename:"libtch/tensor_generated.go"
|
||||
~fallible_wrapper_filename:"libtch/tensor_fallible_generated.go"
|
||||
~cpp_filename:"tmp/torch_api_generated" ~ffi_filename:"tmp/c_generated.go"
|
||||
~wrapper_filename:"tmp/tensor_generated.go"
|
||||
~fallible_wrapper_filename:"tmp/tensor_fallible_generated.go"
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
// NOTE: C.tensor is a C pointer to torch::Tensor
|
||||
type Ctensor = C.tensor
|
||||
type Cscalar = C.scalar
|
||||
|
||||
func AtNewTensor() Ctensor {
|
||||
return C.at_new_tensor()
|
||||
|
|
|
@ -9,3 +9,8 @@ import "C"
|
|||
func AtgEq1(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||
C.atg_eq1(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_matmul(tensor *, tensor self, tensor other);
|
||||
func AtgMatmul(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||
C.atg_matmul(ptr, self, other)
|
||||
}
|
||||
|
|
|
@ -249,3 +249,14 @@ func (ts Tensor) Eq1(other Tensor) {
|
|||
|
||||
lib.AtPrint(*ctensorPtr)
|
||||
}
|
||||
|
||||
func (ts Tensor) Matmul(other Tensor) {
|
||||
ctensorPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
lib.AtgMatmul(ctensorPtr, ts.ctensor, other.ctensor)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
lib.AtPrint(*ctensorPtr)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user