generated newTensor() for GC collection

This commit is contained in:
sugarme 2023-07-26 23:19:38 +10:00
parent aa23a1e59b
commit 3cd8d8560f
2 changed files with 4983 additions and 2519 deletions

View File

@ -883,6 +883,7 @@ let write_wrapper funcs filename =
pm "\n\n" ;
pm "import(\n" ;
pm " \"unsafe\"\n" ;
pm " \"fmt\"\n" ;
pm "\n" ;
pm " \"github.com/sugarme/gotch\"\n" ;
pm " lib \"github.com/sugarme/gotch/libtch\"\n" ;
@ -982,12 +983,13 @@ let write_wrapper funcs filename =
pm " %s" (Func.go_binding_body func) ;
pm " %s(ptr, %s)\n" cfunc_name (Func.go_binding_args func) ;
pm " if err = TorchErr(); err != nil {\n" ;
pm " err = fmt.Errorf(\"%s() failed: %%w\", err)\n" gofunc_name;
pm " return %s\n"
(Func.go_return_notype func ~fallible:true) ;
pm " }\n" ;
(* NOTE. if in_place method, no retVal return *)
if not (Func.is_inplace func) then
pm " retVal = &Tensor{ctensor: *ptr}\n"
pm " retVal = newTensor(*ptr, \"%s\")\n" gofunc_name
else pm " ts.ctensor = *ptr\n" ;
pm " \n" ;
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
@ -1024,13 +1026,14 @@ let write_wrapper funcs filename =
pm " %s(ctensorPtr0, %s)\n" cfunc_name
(Func.go_binding_args func) ;
pm " if err = TorchErr(); err != nil {\n" ;
pm " err = fmt.Errorf(\"%s() failed: %%w\", err)\n" gofunc_name;
pm " return %s\n"
(Func.go_return_notype func ~fallible:true) ;
pm " }\n" ;
(* NOTE. if in_place method, no retVal return *)
if not (Func.is_inplace func) then
for i = 0 to ntensors - 1 do
pm " retVal%d = &Tensor{ctensor: *ctensorPtr%d}\n" i i
pm " retVal%d = newTensor(*ctensorPtr%d, \"%s_%d\")\n" i i gofunc_name i
done
else pm " ts.ctensor = *ptr\n" ;
pm " \n" ;
@ -1052,6 +1055,7 @@ let write_wrapper funcs filename =
pm " %s" (Func.go_binding_body func) ;
pm " retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
pm " if err = TorchErr(); err != nil {\n" ;
pm " err = fmt.Errorf(\"%s() failed: %%w\", err)\n" gofunc_name;
pm " return %s\n"
(Func.go_return_notype func ~fallible:true) ;
pm " }\n" ;
@ -1073,6 +1077,7 @@ let write_wrapper funcs filename =
pm " %s" (Func.go_binding_body func) ;
pm " retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
pm " if err = TorchErr(); err != nil {\n" ;
pm " err = fmt.Errorf(\"%s() failed: %%w\", err)\n" gofunc_name;
pm " return %s\n"
(Func.go_return_notype func ~fallible:true) ;
pm " }\n" ;
@ -1094,6 +1099,7 @@ let write_wrapper funcs filename =
pm " %s" (Func.go_binding_body func) ;
pm " retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
pm " if err = TorchErr(); err != nil {\n" ;
pm " err = fmt.Errorf(\"%s() failed: %%w\", err)\n" gofunc_name;
pm " return %s\n"
(Func.go_return_notype func ~fallible:true) ;
pm " }\n" ;

File diff suppressed because it is too large Load Diff