WIP(gen/gen.ml): c_generated

This commit is contained in:
sugarme 2020-06-05 00:04:12 +10:00
parent df1c0b34ff
commit 12f5eaa9d7
7 changed files with 93 additions and 30 deletions

1
.gitignore vendored
View File

@ -15,6 +15,7 @@
target/
_build/
data/
tmp/
gen/.merlin
**/*.rs.bk
Cargo.lock

View File

@ -34,4 +34,6 @@ func main() {
fmt.Printf("ys shape: %v\n", ys.Size())
xs.Eq1(*ys)
// xs.Matmul(*ys)
}

View File

@ -55,4 +55,25 @@ func main() {
fmt.Printf("DType: %v\n", ts.DType())
dx := [][]int32{
{1, 1},
{1, 1},
}
dy := [][]int32{
{1, 2, 3},
{1, 1, 1},
}
xs, err := wrapper.NewTensorFromData(dx, []int64{2, 2})
if err != nil {
log.Fatal(err)
}
ys, err := wrapper.NewTensorFromData(dy, []int64{2, 3})
if err != nil {
log.Fatal(err)
}
xs.Matmul(*ys)
}

View File

@ -1,6 +1,6 @@
(* Automatically generate the C++ -> C -> rust bindings.
This takes as input the Descriptions.yaml file that gets generated when
building PyTorch from source.
(Func.c_go_args_list func) building PyTorch from source.
Run with: dune exec gen/gen.exe
*)
@ -189,29 +189,50 @@ module Func = struct
let go_name name =
let name =
Map.find replace_map name |> Option.value ~default:name
|> String.lowercase
|> String.substr_replace_all ~pattern:"__" ~with_:"_"
|> String.capitalize
|> String.substr_replace_all ~pattern:"__" ~with_:""
in
if String.is_prefix name ~prefix:"_" then "internal" ^ name else name
if String.is_prefix name ~prefix:"_" then
"Internal" ^ (name |> String.capitalize)
else name |> String.capitalize
let c_go_args_list t =
List.map t.args ~f:(fun arg ->
let an = arg.arg_name in
let single_param = Printf.sprintf "%s_: %s" an in
let single_param = Printf.sprintf "%s %s" an in
match arg.arg_type with
| Bool -> single_param "c_int"
| Int64 -> single_param "int64"
| Double -> single_param "float64"
| Tensor -> single_param "*C_tensor"
| TensorOption -> single_param "*C_tensor"
| Scalar -> single_param "*C_scalar"
| ScalarType -> single_param "c_int"
| Device -> single_param "c_int"
| String -> Printf.sprintf "%s_ptr int, %s_len c_int" an an
| IntList -> Printf.sprintf "%s_data int64, %s_len c_int" an an
| TensorList -> Printf.sprintf "%s_data *C_tensor, %s_len c_int" an an
| Bool -> single_param "C.int"
| Int64 -> single_param "C.long"
| Double -> single_param "C.double"
| Tensor -> single_param "Ctensor"
| TensorOption -> single_param "Ctensor"
| Scalar -> single_param "Cscalar"
| ScalarType -> single_param "C.int"
| Device -> single_param "C.int"
| String -> Printf.sprintf "%s_ptr C.int, %s_len C.int" an an
| IntList -> Printf.sprintf "%s_data C.long, %s_len C.int" an an
| TensorList -> Printf.sprintf "%s_data Ctensor, %s_len C.int" an an
| TensorOptions ->
Printf.sprintf "%s_kind c_int, %s_device c_int" an an )
Printf.sprintf "%s_kind C.int, %s_device C.int" an an )
|> String.concat ~sep:", "
let c_go_args_list_notype t =
List.map t.args ~f:(fun arg ->
let an = arg.arg_name in
let single_param = Printf.sprintf "%s %s" an in
match arg.arg_type with
| Bool -> single_param ""
| Int64 -> single_param ""
| Double -> single_param ""
| Tensor -> single_param ""
| TensorOption -> single_param ""
| Scalar -> single_param ""
| ScalarType -> single_param ""
| Device -> single_param ""
| String -> Printf.sprintf "%s_ptr, %s_len" an an
| IntList -> Printf.sprintf "%s_data, %s_len" an an
| TensorList -> Printf.sprintf "%s_data, %s_len" an an
| TensorOptions -> Printf.sprintf "%s_kind, %s_device" an an )
|> String.concat ~sep:", "
let self_name = "self"
@ -561,20 +582,22 @@ let write_ffi funcs filename =
Out_channel.with_file filename ~f:(fun out_ml ->
let pm s = p out_ml s in
pm "/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND! */" ;
pm "#[allow(clippy::all)]" ;
pm "use crate::{C_scalar, C_tensor};" ;
pm "use libc::c_int;" ;
pm "package libtch" ;
pm "" ;
pm "// #include \"stdbool.h\" " ;
pm "// #include \"torch_api.h\" " ;
pm "import \"C\"" ;
pm "" ;
pm "extern \"C\" {" ;
Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
match func.Func.returns with
| `fixed _ ->
pm " func atg_%s(out__: *C_tensor, %s);" exported_name
(Func.c_go_args_list func)
pm "func Atg_%s(ptr *Ctensor, %s){C.atg_%s(ptr, %s)}"
(Func.go_name exported_name)
(Func.c_go_args_list func) exported_name
(Func.c_go_args_list_notype func)
| `dynamic ->
pm " func atg_%s(%s) -> *C_tensor;" exported_name
(Func.c_go_args_list func) ) ;
pm "}" )
pm "func Atg_%s(%s)(*Ctensor)" exported_name
(Func.c_go_args_list func) ) )
let methods =
let c name args = {Func.name; args; returns= `fixed 1; kind= `method_} in
@ -613,7 +636,6 @@ let run ~yaml_filename ~cpp_filename ~ffi_filename ~wrapper_filename
let () =
run ~yaml_filename:"third_party/pytorch/Declarations-v1.5.0.yaml"
~cpp_filename:"libtch/torch_api_generated"
~ffi_filename:"libtch/c_generated.go"
~wrapper_filename:"libtch/tensor_generated.go"
~fallible_wrapper_filename:"libtch/tensor_fallible_generated.go"
~cpp_filename:"tmp/torch_api_generated" ~ffi_filename:"tmp/c_generated.go"
~wrapper_filename:"tmp/tensor_generated.go"
~fallible_wrapper_filename:"tmp/tensor_fallible_generated.go"

View File

@ -10,6 +10,7 @@ import (
// NOTE: C.tensor is a C pointer to torch::Tensor
type Ctensor = C.tensor
type Cscalar = C.scalar
func AtNewTensor() Ctensor {
return C.at_new_tensor()

View File

@ -9,3 +9,8 @@ import "C"
func AtgEq1(ptr *Ctensor, self Ctensor, other Ctensor) {
C.atg_eq1(ptr, self, other)
}
// void atg_matmul(tensor *, tensor self, tensor other);
func AtgMatmul(ptr *Ctensor, self Ctensor, other Ctensor) {
C.atg_matmul(ptr, self, other)
}

View File

@ -249,3 +249,14 @@ func (ts Tensor) Eq1(other Tensor) {
lib.AtPrint(*ctensorPtr)
}
func (ts Tensor) Matmul(other Tensor) {
ctensorPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgMatmul(ctensorPtr, ts.ctensor, other.ctensor)
if err := TorchErr(); err != nil {
log.Fatal(err)
}
lib.AtPrint(*ctensorPtr)
}