diff --git a/.gitignore b/.gitignore
index 35e0622..931dedc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,6 +15,7 @@
 target/
 _build/
 data/
+tmp/
 gen/.merlin
 **/*.rs.bk
 Cargo.lock
diff --git a/example/error/main.go b/example/error/main.go
index a4ec930..11981b9 100644
--- a/example/error/main.go
+++ b/example/error/main.go
@@ -34,4 +34,6 @@ func main() {
 	fmt.Printf("ys shape: %v\n", ys.Size())
 
 	xs.Eq1(*ys)
+
+	// xs.Matmul(*ys)
 }
diff --git a/example/tensor/main.go b/example/tensor/main.go
index f6a51cc..7530906 100644
--- a/example/tensor/main.go
+++ b/example/tensor/main.go
@@ -55,4 +55,25 @@ func main() {
 
 	fmt.Printf("DType: %v\n", ts.DType())
 
+	dx := [][]int32{
+		{1, 1},
+		{1, 1},
+	}
+
+	dy := [][]int32{
+		{1, 2, 3},
+		{1, 1, 1},
+	}
+
+	xs, err := wrapper.NewTensorFromData(dx, []int64{2, 2})
+	if err != nil {
+		log.Fatal(err)
+	}
+	ys, err := wrapper.NewTensorFromData(dy, []int64{2, 3})
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	xs.Matmul(*ys)
+
 }
diff --git a/gen/gen.ml b/gen/gen.ml
index 9065a16..fd4a99e 100644
--- a/gen/gen.ml
+++ b/gen/gen.ml
@@ -1,6 +1,6 @@
 (* Automatically generate the C++ -> C -> rust bindings.
    This takes as input the Descriptions.yaml file that gets generated when
-   building PyTorch from source.
+   (Func.c_go_args_list func) building PyTorch from source.
 
    Run with: dune exec gen/gen.exe
 *)
@@ -189,29 +189,50 @@ module Func = struct
   let go_name name =
     let name =
       Map.find replace_map name |> Option.value ~default:name
-      |> String.lowercase
-      |> String.substr_replace_all ~pattern:"__" ~with_:"_"
+      |> String.capitalize
+      |> String.substr_replace_all ~pattern:"__" ~with_:""
     in
-    if String.is_prefix name ~prefix:"_" then "internal" ^ name else name
+    if String.is_prefix name ~prefix:"_" then
+      "Internal" ^ (name |> String.capitalize)
+    else name |> String.capitalize
 
   let c_go_args_list t =
     List.map t.args ~f:(fun arg ->
         let an = arg.arg_name in
-        let single_param = Printf.sprintf "%s_: %s" an in
+        let single_param = Printf.sprintf "%s %s" an in
         match arg.arg_type with
-        | Bool -> single_param "c_int"
-        | Int64 -> single_param "int64"
-        | Double -> single_param "float64"
-        | Tensor -> single_param "*C_tensor"
-        | TensorOption -> single_param "*C_tensor"
-        | Scalar -> single_param "*C_scalar"
-        | ScalarType -> single_param "c_int"
-        | Device -> single_param "c_int"
-        | String -> Printf.sprintf "%s_ptr int, %s_len c_int" an an
-        | IntList -> Printf.sprintf "%s_data int64, %s_len c_int" an an
-        | TensorList -> Printf.sprintf "%s_data *C_tensor, %s_len c_int" an an
+        | Bool -> single_param "C.int"
+        | Int64 -> single_param "C.long"
+        | Double -> single_param "C.double"
+        | Tensor -> single_param "Ctensor"
+        | TensorOption -> single_param "Ctensor"
+        | Scalar -> single_param "Cscalar"
+        | ScalarType -> single_param "C.int"
+        | Device -> single_param "C.int"
+        | String -> Printf.sprintf "%s_ptr C.int, %s_len C.int" an an
+        | IntList -> Printf.sprintf "%s_data C.long, %s_len C.int" an an
+        | TensorList -> Printf.sprintf "%s_data Ctensor, %s_len C.int" an an
         | TensorOptions ->
-            Printf.sprintf "%s_kind c_int, %s_device c_int" an an )
+            Printf.sprintf "%s_kind C.int, %s_device C.int" an an )
+    |> String.concat ~sep:", "
+
+  let c_go_args_list_notype t =
+    List.map t.args ~f:(fun arg ->
+        let an = arg.arg_name in
+        let single_param = Printf.sprintf "%s %s" an in
+        match arg.arg_type with
+        | Bool -> single_param ""
+        | Int64 -> single_param ""
+        | Double -> single_param ""
+        | Tensor -> single_param ""
+        | TensorOption -> single_param ""
+        | Scalar -> single_param ""
+        | ScalarType -> single_param ""
+        | Device -> single_param ""
+        | String -> Printf.sprintf "%s_ptr, %s_len" an an
+        | IntList -> Printf.sprintf "%s_data, %s_len" an an
+        | TensorList -> Printf.sprintf "%s_data, %s_len" an an
+        | TensorOptions -> Printf.sprintf "%s_kind, %s_device" an an )
     |> String.concat ~sep:", "
 
   let self_name = "self"
@@ -561,20 +582,22 @@ let write_ffi funcs filename =
   Out_channel.with_file filename ~f:(fun out_ml ->
       let pm s = p out_ml s in
       pm "/* THIS FILE IS AUTOMATICALLY GENERATED, DO NOT EDIT BY HAND! */" ;
-      pm "#[allow(clippy::all)]" ;
-      pm "use crate::{C_scalar, C_tensor};" ;
-      pm "use libc::c_int;" ;
+      pm "package libtch" ;
+      pm "" ;
+      pm "// #include \"stdbool.h\" " ;
+      pm "// #include \"torch_api.h\" " ;
+      pm "import \"C\"" ;
       pm "" ;
-      pm "extern \"C\" {" ;
       Map.iteri funcs ~f:(fun ~key:exported_name ~data:func ->
           match func.Func.returns with
          | `fixed _ ->
-              pm " func atg_%s(out__: *C_tensor, %s);" exported_name
-                (Func.c_go_args_list func)
+              pm "func Atg_%s(ptr *Ctensor, %s){C.atg_%s(ptr, %s)}"
+                (Func.go_name exported_name)
+                (Func.c_go_args_list func) exported_name
+                (Func.c_go_args_list_notype func)
          | `dynamic ->
-              pm " func atg_%s(%s) -> *C_tensor;" exported_name
-                (Func.c_go_args_list func) ) ;
-      pm "}" )
+              pm "func Atg_%s(%s)(*Ctensor)" exported_name
+                (Func.c_go_args_list func) ) )
 
 let methods =
   let c name args = {Func.name; args; returns= `fixed 1; kind= `method_} in
@@ -613,7 +636,6 @@ let run ~yaml_filename ~cpp_filename ~ffi_filename ~wrapper_filename
 
 let () =
   run ~yaml_filename:"third_party/pytorch/Declarations-v1.5.0.yaml"
-    ~cpp_filename:"libtch/torch_api_generated"
-    ~ffi_filename:"libtch/c_generated.go"
-    ~wrapper_filename:"libtch/tensor_generated.go"
-    ~fallible_wrapper_filename:"libtch/tensor_fallible_generated.go"
+    ~cpp_filename:"tmp/torch_api_generated" ~ffi_filename:"tmp/c_generated.go"
+    ~wrapper_filename:"tmp/tensor_generated.go"
+    ~fallible_wrapper_filename:"tmp/tensor_fallible_generated.go"
diff --git a/libtch/tensor.go b/libtch/tensor.go
index da227d4..a31da25 100644
--- a/libtch/tensor.go
+++ b/libtch/tensor.go
@@ -10,6 +10,7 @@ import (
 
 // NOTE: C.tensor is a C pointer to torch::Tensor
 type Ctensor = C.tensor
+type Cscalar = C.scalar
 
 func AtNewTensor() Ctensor {
 	return C.at_new_tensor()
diff --git a/libtch/tensor_generated_sample.go b/libtch/tensor_generated_sample.go
index 5aa60ec..441404c 100644
--- a/libtch/tensor_generated_sample.go
+++ b/libtch/tensor_generated_sample.go
@@ -9,3 +9,8 @@ import "C"
 func AtgEq1(ptr *Ctensor, self Ctensor, other Ctensor) {
 	C.atg_eq1(ptr, self, other)
 }
+
+// void atg_matmul(tensor *, tensor self, tensor other);
+func AtgMatmul(ptr *Ctensor, self Ctensor, other Ctensor) {
+	C.atg_matmul(ptr, self, other)
+}
diff --git a/wrapper/tensor.go b/wrapper/tensor.go
index 729c574..a61c952 100644
--- a/wrapper/tensor.go
+++ b/wrapper/tensor.go
@@ -249,3 +249,14 @@ func (ts Tensor) Eq1(other Tensor) {
 
 	lib.AtPrint(*ctensorPtr)
 }
+
+func (ts Tensor) Matmul(other Tensor) {
+	ctensorPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
+	lib.AtgMatmul(ctensorPtr, ts.ctensor, other.ctensor)
+
+	if err := TorchErr(); err != nil {
+		log.Fatal(err)
+	}
+
+	lib.AtPrint(*ctensorPtr)
+}
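For reference, the updated write_ffi template should produce Go stubs of the same shape as the hand-written AtgMatmul sample in libtch/tensor_generated_sample.go. The snippet below is an illustrative sketch of the expected output for the "matmul" entry, derived from the format string together with the go_name, c_go_args_list, and c_go_args_list_notype changes above; it is not actual generator output.

package libtch

// #include "stdbool.h"
// #include "torch_api.h"
import "C"

// Sketch only: what the new template "func Atg_%s(ptr *Ctensor, %s){C.atg_%s(ptr, %s)}"
// would emit for "matmul". go_name yields Atg_Matmul, c_go_args_list supplies the typed
// parameter list, and c_go_args_list_notype supplies the bare argument names for the C call.
func Atg_Matmul(ptr *Ctensor, self Ctensor, other Ctensor) { C.atg_matmul(ptr, self, other) }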