WIP(gen/gen.ml): c_generated
This commit is contained in:
parent
df1c0b34ff
commit
12f5eaa9d7
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -15,6 +15,7 @@
|
||||||
target/
|
target/
|
||||||
_build/
|
_build/
|
||||||
data/
|
data/
|
||||||
|
tmp/
|
||||||
gen/.merlin
|
gen/.merlin
|
||||||
**/*.rs.bk
|
**/*.rs.bk
|
||||||
Cargo.lock
|
Cargo.lock
|
||||||
|
|
|
@ -34,4 +34,6 @@ func main() {
|
||||||
fmt.Printf("ys shape: %v\n", ys.Size())
|
fmt.Printf("ys shape: %v\n", ys.Size())
|
||||||
|
|
||||||
xs.Eq1(*ys)
|
xs.Eq1(*ys)
|
||||||
|
|
||||||
|
// xs.Matmul(*ys)
|
||||||
}
|
}
|
||||||
|
|
|
@ -55,4 +55,25 @@ func main() {
|
||||||
|
|
||||||
fmt.Printf("DType: %v\n", ts.DType())
|
fmt.Printf("DType: %v\n", ts.DType())
|
||||||
|
|
||||||
|
dx := [][]int32{
|
||||||
|
{1, 1},
|
||||||
|
{1, 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
dy := [][]int32{
|
||||||
|
{1, 2, 3},
|
||||||
|
{1, 1, 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
xs, err := wrapper.NewTensorFromData(dx, []int64{2, 2})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
ys, err := wrapper.NewTensorFromData(dy, []int64{2, 3})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
xs.Matmul(*ys)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
82
gen/gen.ml
82
gen/gen.ml
|
@ -1,6 +1,6 @@
|
||||||
(* Automatically generate the C++ -> C -> rust bindings.
|
(* Automatically generate the C++ -> C -> rust bindings.
|
||||||
This takes as input the Descriptions.yaml file that gets generated when
|
This takes as input the Descriptions.yaml file that gets generated when
|
||||||
building PyTorch from source.
|
(Func.c_go_args_list func) building PyTorch from source.
|
||||||
|
|
||||||
Run with: dune exec gen/gen.exe
|
Run with: dune exec gen/gen.exe
|
||||||
*)
|
*)
|
||||||
|
@ -189,29 +189,50 @@ module Func = struct
|
||||||
let go_name name =
|
let go_name name =
|
||||||
let name =
|
let name =
|
||||||
Map.find replace_map name |> Option.value ~default:name
|
Map.find replace_map name |> Option.value ~default:name
|
||||||
|> String.lowercase
|
|> String.capitalize
|
||||||
|> String.substr_replace_all ~pattern:"__" ~with_:"_"
|
|> String.substr_replace_all ~pattern:"__" ~with_:""
|
||||||
in
|
in
|
||||||
if String.is_prefix name ~prefix:"_" then "internal" ^ name else name
|
if String.is_prefix name ~prefix:"_" then
|
||||||
|
"Internal" ^ (name |> String.capitalize)
|
||||||
|
else name |> String.capitalize
|
||||||
|
|
||||||
let c_go_args_list t =
|
let c_go_args_list t =
|
||||||
List.map t.args ~f:(fun arg ->
|
List.map t.args ~f:(fun arg ->
|
||||||
let an = arg.arg_name in
|
let an = arg.arg_name in
|
||||||
let single_param = Printf.sprintf "%s_: %s" an in
|
let single_param = Printf.sprintf "%s %s" an in
|
||||||
match arg.arg_type with
|
match arg.arg_type with
|
||||||
| Bool -> single_param "c_int"
|
| Bool -> single_param "C.int"
|
||||||
| Int64 -> single_param "int64"
|
| Int64 -> single_param "C.long"
|
||||||
| Double -> single_param "float64"
|
| Double -> single_param "C.double"
|
||||||
| Tensor -> single_param "*C_tensor"
|
| Tensor -> single_param "Ctensor"
|
||||||
| TensorOption -> single_param "*C_tensor"
|
| TensorOption -> single_param "Ctensor"
|
||||||
| Scalar -> single_param "*C_scalar"
|
| Scalar -> single_param "Cscalar"
|
||||||
| ScalarType -> single_param "c_int"
|
| ScalarType -> single_param "C.int"
|
||||||
| Device -> single_param "c_int"
|
| Device -> single_param "C.int"
|
||||||
| String -> Printf.sprintf "%s_ptr int, %s_len c_int" an an
|
| String -> Printf.sprintf "%s_ptr C.int, %s_len C.int" an an
|
||||||
| IntList -> Printf.sprintf "%s_data int64, %s_len c_int" an an
|
| IntList -> Printf.sprintf "%s_data C.long, %s_len C.int" an an
|
||||||
| TensorList -> Printf.sprintf "%s_data *C_tensor, %s_len c_int" an an
|
| TensorList -> Printf.sprintf "%s_data Ctensor, %s_len C.int" an an
|
||||||
| TensorOptions ->
|
| TensorOptions ->
|
||||||
Printf.sprintf "%s_kind c_int, %s_device c_int" an an )
|
Printf.sprintf "%s_kind C.int, %s_device C.int" an an )
|
||||||
|
|> String.concat ~sep:", "
|
||||||
|
|
||||||
|
let c_go_args_list_notype t =
|
||||||
|
List.map t.args ~f:(fun arg ->
|
||||||
|
let an = arg.arg_name in
|
||||||
|
let single_param = Printf.sprintf "%s %s" an in
|
||||||
|
match arg.arg_type with
|
||||||
|
| Bool -> single_param ""
|
||||||
|
| Int64 -> single_param ""
|
||||||
|
| Double -> single_param ""
|
||||||
|
| Tensor -> single_param ""
|
||||||
|
| TensorOption -> single_param ""
|
||||||
|
| Scalar -> single_param ""
|
||||||
|
| ScalarType -> single_param ""
|
||||||
|
| Device -> single_param ""
|
||||||
|
| String -> Printf.sprintf "%s_ptr, %s_len" an an
|
||||||
|
| IntList -> Printf.sprintf "%s_data, %s_len" an an
|
||||||
|
| TensorList -> Printf.sprintf "%s_data, %s_len" an an
|
||||||
|
| TensorOptions -> Printf.sprintf "%s_kind, %s_device" an an )
|
||||||
|> String.concat ~sep:", "
|
|> String.concat ~sep:", "
|
||||||
|
|
||||||
let self_name = "self"
|
let self_name = "self"
|
||||||
|
@ -561,20 +582,22 @@ let write_ffi funcs filename =
|
||||||
Out_channel.with_file filename ~f:(fun out_ml ->
|
Out_channel.with_file filename ~f:(fun out_ml ->
|
||||||
let pm s = p out_ml s in
|
let pm s = p out_ml s in
|
||||||
pm "/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND! */" ;
|
pm "/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND! */" ;
|
||||||
pm "#[allow(clippy::all)]" ;
|
pm "package libtch" ;
|
||||||
pm "use crate::{C_scalar, C_tensor};" ;
|
pm "" ;
|
||||||
pm "use libc::c_int;" ;
|
pm "// #include \"stdbool.h\" " ;
|
||||||
|
pm "// #include \"torch_api.h\" " ;
|
||||||
|
pm "import \"C\"" ;
|
||||||
pm "" ;
|
pm "" ;
|
||||||
pm "extern \"C\" {" ;
|
|
||||||
Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
|
Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
|
||||||
match func.Func.returns with
|
match func.Func.returns with
|
||||||
| `fixed _ ->
|
| `fixed _ ->
|
||||||
pm " func atg_%s(out__: *C_tensor, %s);" exported_name
|
pm "func Atg_%s(ptr *Ctensor, %s){C.atg_%s(ptr, %s)}"
|
||||||
(Func.c_go_args_list func)
|
(Func.go_name exported_name)
|
||||||
|
(Func.c_go_args_list func) exported_name
|
||||||
|
(Func.c_go_args_list_notype func)
|
||||||
| `dynamic ->
|
| `dynamic ->
|
||||||
pm " func atg_%s(%s) -> *C_tensor;" exported_name
|
pm "func Atg_%s(%s)(*Ctensor)" exported_name
|
||||||
(Func.c_go_args_list func) ) ;
|
(Func.c_go_args_list func) ) )
|
||||||
pm "}" )
|
|
||||||
|
|
||||||
let methods =
|
let methods =
|
||||||
let c name args = {Func.name; args; returns= `fixed 1; kind= `method_} in
|
let c name args = {Func.name; args; returns= `fixed 1; kind= `method_} in
|
||||||
|
@ -613,7 +636,6 @@ let run ~yaml_filename ~cpp_filename ~ffi_filename ~wrapper_filename
|
||||||
|
|
||||||
let () =
|
let () =
|
||||||
run ~yaml_filename:"third_party/pytorch/Declarations-v1.5.0.yaml"
|
run ~yaml_filename:"third_party/pytorch/Declarations-v1.5.0.yaml"
|
||||||
~cpp_filename:"libtch/torch_api_generated"
|
~cpp_filename:"tmp/torch_api_generated" ~ffi_filename:"tmp/c_generated.go"
|
||||||
~ffi_filename:"libtch/c_generated.go"
|
~wrapper_filename:"tmp/tensor_generated.go"
|
||||||
~wrapper_filename:"libtch/tensor_generated.go"
|
~fallible_wrapper_filename:"tmp/tensor_fallible_generated.go"
|
||||||
~fallible_wrapper_filename:"libtch/tensor_fallible_generated.go"
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
|
|
||||||
// NOTE: C.tensor is a C pointer to torch::Tensor
|
// NOTE: C.tensor is a C pointer to torch::Tensor
|
||||||
type Ctensor = C.tensor
|
type Ctensor = C.tensor
|
||||||
|
type Cscalar = C.scalar
|
||||||
|
|
||||||
func AtNewTensor() Ctensor {
|
func AtNewTensor() Ctensor {
|
||||||
return C.at_new_tensor()
|
return C.at_new_tensor()
|
||||||
|
|
|
@ -9,3 +9,8 @@ import "C"
|
||||||
func AtgEq1(ptr *Ctensor, self Ctensor, other Ctensor) {
|
func AtgEq1(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||||
C.atg_eq1(ptr, self, other)
|
C.atg_eq1(ptr, self, other)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// void atg_matmul(tensor *, tensor self, tensor other);
|
||||||
|
func AtgMatmul(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||||
|
C.atg_matmul(ptr, self, other)
|
||||||
|
}
|
||||||
|
|
|
@ -249,3 +249,14 @@ func (ts Tensor) Eq1(other Tensor) {
|
||||||
|
|
||||||
lib.AtPrint(*ctensorPtr)
|
lib.AtPrint(*ctensorPtr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Matmul(other Tensor) {
|
||||||
|
ctensorPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
lib.AtgMatmul(ctensorPtr, ts.ctensor, other.ctensor)
|
||||||
|
|
||||||
|
if err := TorchErr(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lib.AtPrint(*ctensorPtr)
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user