added missing Tensor methods returning multiple tensor values

This commit is contained in:
sugarme 2021-08-15 21:59:10 +10:00
parent 880a1b25df
commit 4060f00193
5 changed files with 6394 additions and 2595 deletions

View File

@ -848,11 +848,11 @@ let write_wrapper funcs filename =
pm "%s" go_args_list ;
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
if is_method && not is_inplace then
pm "if del { defer ts.MustDrop() }\n" ;
pm " if del { defer ts.MustDrop() }\n" ;
pm " ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))\n" ;
pm " \n" ;
pm " %s" (Func.go_binding_body func) ;
pm "%s(ptr, %s)\n" cfunc_name (Func.go_binding_args func) ;
pm " %s(ptr, %s)\n" cfunc_name (Func.go_binding_args func) ;
pm " if err = TorchErr(); err != nil {\n" ;
pm " return %s\n"
(Func.go_return_notype func ~fallible:true) ;
@ -871,11 +871,11 @@ let write_wrapper funcs filename =
pm "%s" go_args_list ;
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
if is_method && not is_inplace then
pm "if del { defer ts.MustDrop() }\n" ;
pm " if del { defer ts.MustDrop() }\n" ;
pm " ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))\n" ;
pm " \n" ;
pm " %s" (Func.go_binding_body func) ;
pm "%s(ptr, %s)\n" cfunc_name (Func.go_binding_args func) ;
pm " %s(ptr, %s)\n" cfunc_name (Func.go_binding_args func) ;
pm " if err = TorchErr(); err != nil {\n" ;
pm " return %s\n"
(Func.go_return_notype func ~fallible:true) ;
@ -887,6 +887,44 @@ let write_wrapper funcs filename =
pm " \n" ;
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
pm "} \n"
| `fixed ntensors ->
pm "\n" ;
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
else pm "func %s(" gofunc_name ;
pm "%s" go_args_list ;
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
if is_method && not is_inplace then
pm " if del { defer ts.MustDrop() }\n" ;
for i = 0 to ntensors - 1 do
(* pc " out__[%d] = new torch::Tensor(std::get<%d>(outputs__));" i i *)
if i = 0 then
pm
" ctensorPtr0 := \
(*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))\n"
else
pm
" ctensorPtr%d := \
(*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(ctensorPtr%d)) \
+ unsafe.Sizeof(ctensorPtr0)))\n"
i (i - 1)
done ;
pm " \n" ;
pm " %s" (Func.go_binding_body func) ;
pm " %s(ctensorPtr0, %s)\n" cfunc_name
(Func.go_binding_args func) ;
pm " if err = TorchErr(); err != nil {\n" ;
pm " return %s\n"
(Func.go_return_notype func ~fallible:true) ;
pm " }\n" ;
(* NOTE. if in_place method, no retVal return *)
if not (Func.is_inplace func) then
for i = 0 to ntensors - 1 do
pm " retVal%d = &Tensor{ctensor: *ctensorPtr%d}\n" i i
done
else pm " ts.ctensor = *ptr\n" ;
pm " \n" ;
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
pm "} \n"
| `bool ->
pm "\n" ;
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
@ -894,10 +932,10 @@ let write_wrapper funcs filename =
pm "%s" go_args_list ;
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
if is_method && not is_inplace then
pm "if del { defer ts.MustDrop() }\n" ;
pm " if del { defer ts.MustDrop() }\n" ;
pm " \n" ;
pm " %s" (Func.go_binding_body func) ;
pm "retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
pm " retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
pm " if err = TorchErr(); err != nil {\n" ;
pm " return %s\n"
(Func.go_return_notype func ~fallible:true) ;
@ -911,10 +949,10 @@ let write_wrapper funcs filename =
pm "%s" go_args_list ;
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
if is_method && not is_inplace then
pm "if del { defer ts.MustDrop() }\n" ;
pm " if del { defer ts.MustDrop() }\n" ;
pm " \n" ;
pm " %s" (Func.go_binding_body func) ;
pm "retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
pm " retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
pm " if err = TorchErr(); err != nil {\n" ;
pm " return %s\n"
(Func.go_return_notype func ~fallible:true) ;
@ -931,15 +969,13 @@ let write_wrapper funcs filename =
pm "if del { defer ts.MustDrop() }\n" ;
pm " \n" ;
pm " %s" (Func.go_binding_body func) ;
pm "retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
pm " retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
pm " if err = TorchErr(); err != nil {\n" ;
pm " return %s\n"
(Func.go_return_notype func ~fallible:true) ;
pm " }\n" ;
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
pm "} \n"
| `fixed _ -> pm "" ) ;
(* TODO. implement for return multiple tensor - []Tensor *)
pm "} \n" ) ;
pm "// End of implementing Tensor ================================= \n"
)
@ -982,7 +1018,6 @@ let write_must_wrapper funcs filename =
let go_args_list = Func.go_typed_args_list func in
let go_args_list_notype = Func.go_notype_args_list func in
(* NOTE. temporarily excluding these functions as not implemented at FFI *)
(* TODO. implement multiple tensors return function []Tensor *)
let excluded_funcs =
[ "Chunk"
; "AlignTensors"
@ -1067,6 +1102,30 @@ let write_must_wrapper funcs filename =
pm " \n" ;
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
pm "} \n"
| `fixed _ ->
pm "\n" ;
if is_method then pm "func(ts *Tensor) Must%s(" gofunc_name
else pm "func Must%s(" gofunc_name ;
pm "%s" go_args_list ;
pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ;
pm " \n" ;
(* NOTE. No return retVal for in_place method *)
if Func.is_inplace func then
if is_method then
pm " err := ts.%s(%s)\n" gofunc_name go_args_list_notype
else pm " err := %s(%s)\n" gofunc_name go_args_list_notype
else if is_method then
pm " %s, err := ts.%s(%s)\n"
(Func.go_return_notype func ~fallible:false)
gofunc_name go_args_list_notype
else
pm " %s, err := %s(%s)\n"
(Func.go_return_notype func ~fallible:false)
gofunc_name go_args_list_notype ;
pm " if err != nil { log.Fatal(err) }\n" ;
pm " \n" ;
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
pm "} \n"
| `bool ->
pm "\n" ;
if is_method then pm "func(ts *Tensor) Must%s(" gofunc_name
@ -1117,9 +1176,7 @@ let write_must_wrapper funcs filename =
pm " if err != nil { log.Fatal(err) }\n" ;
pm " \n" ;
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
pm "} \n"
| `fixed _ -> pm "" ) ;
(* TODO. implement for return multiple tensor - []Tensor *)
pm "} \n" ) ;
pm "// End of implementing Tensor ================================= \n"
)

File diff suppressed because it is too large Load Diff

View File

@ -564,27 +564,27 @@ func MustWhere(condition Tensor, del bool) (retVal []Tensor) {
// NOTE. patches for APIs `agt_` missing in tensor/ but existing in lib
// ====================================================================
// void atg_lstsq(tensor *, tensor self, tensor A);
func (ts *Tensor) Lstsq(a *Tensor, del bool) (retVal *Tensor, err error) {
if del {
defer ts.MustDrop()
}
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgLstsq(ptr, ts.ctensor, a.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = &Tensor{ctensor: *ptr}
return retVal, err
}
func (ts *Tensor) MustLstsq(a *Tensor, del bool) (retVal *Tensor) {
retVal, err := ts.Lstsq(a, del)
if err != nil {
log.Fatal(err)
}
return retVal
}
// // void atg_lstsq(tensor *, tensor self, tensor A);
// func (ts *Tensor) Lstsq(a *Tensor, del bool) (retVal *Tensor, err error) {
// if del {
// defer ts.MustDrop()
// }
// ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
//
// lib.AtgLstsq(ptr, ts.ctensor, a.ctensor)
// if err = TorchErr(); err != nil {
// return retVal, err
// }
// retVal = &Tensor{ctensor: *ptr}
//
// return retVal, err
// }
//
// func (ts *Tensor) MustLstsq(a *Tensor, del bool) (retVal *Tensor) {
// retVal, err := ts.Lstsq(a, del)
// if err != nil {
// log.Fatal(err)
// }
//
// return retVal
// }

File diff suppressed because it is too large Load Diff

View File

@ -814,10 +814,15 @@ func perspectiveCoeff(startPoints, endPoints [][]int64) []float64 {
// bMat := ts.MustOfSlice(startPoints).MustTotype(gotch.Float, true).MustView([]int64{8}, true)
bMat := ts.MustOfSlice(startData).MustTotype(gotch.Float, true).MustView([]int64{8}, true)
res := bMat.MustLstsq(aMat, true)
// res := bMat.MustLstsq(aMat, true)
// Ref. https://github.com/pytorch/vision/blob/d7fa36f221cb2ff670cd4267b83a801cece52522/torchvision/transforms/functional.py#L572
solution, residuals, rank, singularValues := bMat.MustLinalgLstsq(aMat, nil, "gels", true)
residuals.MustDrop()
rank.MustDrop()
singularValues.MustDrop()
aMat.MustDrop()
outputTs := res.MustSqueezeDim(1, true)
outputTs := solution.MustSqueezeDim(1, true)
output := outputTs.Float64Values()
outputTs.MustDrop()