generated newTensor() for GC collection
This commit is contained in:
parent
aa23a1e59b
commit
3cd8d8560f
10
gen/gen.ml
10
gen/gen.ml
|
@ -883,6 +883,7 @@ let write_wrapper funcs filename =
|
|||
pm "\n\n" ;
|
||||
pm "import(\n" ;
|
||||
pm " \"unsafe\"\n" ;
|
||||
pm " \"fmt\"\n" ;
|
||||
pm "\n" ;
|
||||
pm " \"github.com/sugarme/gotch\"\n" ;
|
||||
pm " lib \"github.com/sugarme/gotch/libtch\"\n" ;
|
||||
|
@ -982,12 +983,13 @@ let write_wrapper funcs filename =
|
|||
pm " %s" (Func.go_binding_body func) ;
|
||||
pm " %s(ptr, %s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " if err = TorchErr(); err != nil {\n" ;
|
||||
pm " err = fmt.Errorf(\"%s() failed: %%w\", err)\n" gofunc_name;
|
||||
pm " return %s\n"
|
||||
(Func.go_return_notype func ~fallible:true) ;
|
||||
pm " }\n" ;
|
||||
(* NOTE. if in_place method, no retVal return *)
|
||||
if not (Func.is_inplace func) then
|
||||
pm " retVal = &Tensor{ctensor: *ptr}\n"
|
||||
pm " retVal = newTensor(*ptr, \"%s\")\n" gofunc_name
|
||||
else pm " ts.ctensor = *ptr\n" ;
|
||||
pm " \n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||
|
@ -1024,13 +1026,14 @@ let write_wrapper funcs filename =
|
|||
pm " %s(ctensorPtr0, %s)\n" cfunc_name
|
||||
(Func.go_binding_args func) ;
|
||||
pm " if err = TorchErr(); err != nil {\n" ;
|
||||
pm " err = fmt.Errorf(\"%s() failed: %%w\", err)\n" gofunc_name;
|
||||
pm " return %s\n"
|
||||
(Func.go_return_notype func ~fallible:true) ;
|
||||
pm " }\n" ;
|
||||
(* NOTE. if in_place method, no retVal return *)
|
||||
if not (Func.is_inplace func) then
|
||||
for i = 0 to ntensors - 1 do
|
||||
pm " retVal%d = &Tensor{ctensor: *ctensorPtr%d}\n" i i
|
||||
pm " retVal%d = newTensor(*ctensorPtr%d, \"%s_%d\")\n" i i gofunc_name i
|
||||
done
|
||||
else pm " ts.ctensor = *ptr\n" ;
|
||||
pm " \n" ;
|
||||
|
@ -1052,6 +1055,7 @@ let write_wrapper funcs filename =
|
|||
pm " %s" (Func.go_binding_body func) ;
|
||||
pm " retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " if err = TorchErr(); err != nil {\n" ;
|
||||
pm " err = fmt.Errorf(\"%s() failed: %%w\", err)\n" gofunc_name;
|
||||
pm " return %s\n"
|
||||
(Func.go_return_notype func ~fallible:true) ;
|
||||
pm " }\n" ;
|
||||
|
@ -1073,6 +1077,7 @@ let write_wrapper funcs filename =
|
|||
pm " %s" (Func.go_binding_body func) ;
|
||||
pm " retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " if err = TorchErr(); err != nil {\n" ;
|
||||
pm " err = fmt.Errorf(\"%s() failed: %%w\", err)\n" gofunc_name;
|
||||
pm " return %s\n"
|
||||
(Func.go_return_notype func ~fallible:true) ;
|
||||
pm " }\n" ;
|
||||
|
@ -1094,6 +1099,7 @@ let write_wrapper funcs filename =
|
|||
pm " %s" (Func.go_binding_body func) ;
|
||||
pm " retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " if err = TorchErr(); err != nil {\n" ;
|
||||
pm " err = fmt.Errorf(\"%s() failed: %%w\", err)\n" gofunc_name;
|
||||
pm " return %s\n"
|
||||
(Func.go_return_notype func ~fallible:true) ;
|
||||
pm " }\n" ;
|
||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user