added missing Tensor methods returning multiple tensor values
This commit is contained in:
parent
880a1b25df
commit
4060f00193
89
gen/gen.ml
89
gen/gen.ml
|
@ -848,11 +848,11 @@ let write_wrapper funcs filename =
|
|||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
||||
if is_method && not is_inplace then
|
||||
pm "if del { defer ts.MustDrop() }\n" ;
|
||||
pm " if del { defer ts.MustDrop() }\n" ;
|
||||
pm " ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))\n" ;
|
||||
pm " \n" ;
|
||||
pm " %s" (Func.go_binding_body func) ;
|
||||
pm "%s(ptr, %s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " %s(ptr, %s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " if err = TorchErr(); err != nil {\n" ;
|
||||
pm " return %s\n"
|
||||
(Func.go_return_notype func ~fallible:true) ;
|
||||
|
@ -871,11 +871,11 @@ let write_wrapper funcs filename =
|
|||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
||||
if is_method && not is_inplace then
|
||||
pm "if del { defer ts.MustDrop() }\n" ;
|
||||
pm " if del { defer ts.MustDrop() }\n" ;
|
||||
pm " ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))\n" ;
|
||||
pm " \n" ;
|
||||
pm " %s" (Func.go_binding_body func) ;
|
||||
pm "%s(ptr, %s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " %s(ptr, %s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " if err = TorchErr(); err != nil {\n" ;
|
||||
pm " return %s\n"
|
||||
(Func.go_return_notype func ~fallible:true) ;
|
||||
|
@ -887,6 +887,44 @@ let write_wrapper funcs filename =
|
|||
pm " \n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||
pm "} \n"
|
||||
| `fixed ntensors ->
|
||||
pm "\n" ;
|
||||
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
|
||||
else pm "func %s(" gofunc_name ;
|
||||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
||||
if is_method && not is_inplace then
|
||||
pm " if del { defer ts.MustDrop() }\n" ;
|
||||
for i = 0 to ntensors - 1 do
|
||||
(* pc " out__[%d] = new torch::Tensor(std::get<%d>(outputs__));" i i *)
|
||||
if i = 0 then
|
||||
pm
|
||||
" ctensorPtr0 := \
|
||||
(*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))\n"
|
||||
else
|
||||
pm
|
||||
" ctensorPtr%d := \
|
||||
(*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(ctensorPtr%d)) \
|
||||
+ unsafe.Sizeof(ctensorPtr0)))\n"
|
||||
i (i - 1)
|
||||
done ;
|
||||
pm " \n" ;
|
||||
pm " %s" (Func.go_binding_body func) ;
|
||||
pm " %s(ctensorPtr0, %s)\n" cfunc_name
|
||||
(Func.go_binding_args func) ;
|
||||
pm " if err = TorchErr(); err != nil {\n" ;
|
||||
pm " return %s\n"
|
||||
(Func.go_return_notype func ~fallible:true) ;
|
||||
pm " }\n" ;
|
||||
(* NOTE. if in_place method, no retVal return *)
|
||||
if not (Func.is_inplace func) then
|
||||
for i = 0 to ntensors - 1 do
|
||||
pm " retVal%d = &Tensor{ctensor: *ctensorPtr%d}\n" i i
|
||||
done
|
||||
else pm " ts.ctensor = *ptr\n" ;
|
||||
pm " \n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||
pm "} \n"
|
||||
| `bool ->
|
||||
pm "\n" ;
|
||||
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
|
||||
|
@ -894,10 +932,10 @@ let write_wrapper funcs filename =
|
|||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
||||
if is_method && not is_inplace then
|
||||
pm "if del { defer ts.MustDrop() }\n" ;
|
||||
pm " if del { defer ts.MustDrop() }\n" ;
|
||||
pm " \n" ;
|
||||
pm " %s" (Func.go_binding_body func) ;
|
||||
pm "retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " if err = TorchErr(); err != nil {\n" ;
|
||||
pm " return %s\n"
|
||||
(Func.go_return_notype func ~fallible:true) ;
|
||||
|
@ -911,10 +949,10 @@ let write_wrapper funcs filename =
|
|||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
||||
if is_method && not is_inplace then
|
||||
pm "if del { defer ts.MustDrop() }\n" ;
|
||||
pm " if del { defer ts.MustDrop() }\n" ;
|
||||
pm " \n" ;
|
||||
pm " %s" (Func.go_binding_body func) ;
|
||||
pm "retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " if err = TorchErr(); err != nil {\n" ;
|
||||
pm " return %s\n"
|
||||
(Func.go_return_notype func ~fallible:true) ;
|
||||
|
@ -931,15 +969,13 @@ let write_wrapper funcs filename =
|
|||
pm "if del { defer ts.MustDrop() }\n" ;
|
||||
pm " \n" ;
|
||||
pm " %s" (Func.go_binding_body func) ;
|
||||
pm "retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||
pm " if err = TorchErr(); err != nil {\n" ;
|
||||
pm " return %s\n"
|
||||
(Func.go_return_notype func ~fallible:true) ;
|
||||
pm " }\n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||
pm "} \n"
|
||||
| `fixed _ -> pm "" ) ;
|
||||
(* TODO. implement for return multiple tensor - []Tensor *)
|
||||
pm "} \n" ) ;
|
||||
pm "// End of implementing Tensor ================================= \n"
|
||||
)
|
||||
|
||||
|
@ -982,7 +1018,6 @@ let write_must_wrapper funcs filename =
|
|||
let go_args_list = Func.go_typed_args_list func in
|
||||
let go_args_list_notype = Func.go_notype_args_list func in
|
||||
(* NOTE. temporarily excluding these functions as not implemented at FFI *)
|
||||
(* TODO. implement multiple tensors return function []Tensor *)
|
||||
let excluded_funcs =
|
||||
[ "Chunk"
|
||||
; "AlignTensors"
|
||||
|
@ -1067,6 +1102,30 @@ let write_must_wrapper funcs filename =
|
|||
pm " \n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
|
||||
pm "} \n"
|
||||
| `fixed _ ->
|
||||
pm "\n" ;
|
||||
if is_method then pm "func(ts *Tensor) Must%s(" gofunc_name
|
||||
else pm "func Must%s(" gofunc_name ;
|
||||
pm "%s" go_args_list ;
|
||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ;
|
||||
pm " \n" ;
|
||||
(* NOTE. No return retVal for in_place method *)
|
||||
if Func.is_inplace func then
|
||||
if is_method then
|
||||
pm " err := ts.%s(%s)\n" gofunc_name go_args_list_notype
|
||||
else pm " err := %s(%s)\n" gofunc_name go_args_list_notype
|
||||
else if is_method then
|
||||
pm " %s, err := ts.%s(%s)\n"
|
||||
(Func.go_return_notype func ~fallible:false)
|
||||
gofunc_name go_args_list_notype
|
||||
else
|
||||
pm " %s, err := %s(%s)\n"
|
||||
(Func.go_return_notype func ~fallible:false)
|
||||
gofunc_name go_args_list_notype ;
|
||||
pm " if err != nil { log.Fatal(err) }\n" ;
|
||||
pm " \n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
|
||||
pm "} \n"
|
||||
| `bool ->
|
||||
pm "\n" ;
|
||||
if is_method then pm "func(ts *Tensor) Must%s(" gofunc_name
|
||||
|
@ -1117,9 +1176,7 @@ let write_must_wrapper funcs filename =
|
|||
pm " if err != nil { log.Fatal(err) }\n" ;
|
||||
pm " \n" ;
|
||||
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
|
||||
pm "} \n"
|
||||
| `fixed _ -> pm "" ) ;
|
||||
(* TODO. implement for return multiple tensor - []Tensor *)
|
||||
pm "} \n" ) ;
|
||||
pm "// End of implementing Tensor ================================= \n"
|
||||
)
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -564,27 +564,27 @@ func MustWhere(condition Tensor, del bool) (retVal []Tensor) {
|
|||
// NOTE. patches for APIs `agt_` missing in tensor/ but existing in lib
|
||||
// ====================================================================
|
||||
|
||||
// void atg_lstsq(tensor *, tensor self, tensor A);
|
||||
func (ts *Tensor) Lstsq(a *Tensor, del bool) (retVal *Tensor, err error) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgLstsq(ptr, ts.ctensor, a.ctensor)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
retVal = &Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
func (ts *Tensor) MustLstsq(a *Tensor, del bool) (retVal *Tensor) {
|
||||
retVal, err := ts.Lstsq(a, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
// // void atg_lstsq(tensor *, tensor self, tensor A);
|
||||
// func (ts *Tensor) Lstsq(a *Tensor, del bool) (retVal *Tensor, err error) {
|
||||
// if del {
|
||||
// defer ts.MustDrop()
|
||||
// }
|
||||
// ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
//
|
||||
// lib.AtgLstsq(ptr, ts.ctensor, a.ctensor)
|
||||
// if err = TorchErr(); err != nil {
|
||||
// return retVal, err
|
||||
// }
|
||||
// retVal = &Tensor{ctensor: *ptr}
|
||||
//
|
||||
// return retVal, err
|
||||
// }
|
||||
//
|
||||
// func (ts *Tensor) MustLstsq(a *Tensor, del bool) (retVal *Tensor) {
|
||||
// retVal, err := ts.Lstsq(a, del)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// return retVal
|
||||
// }
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -814,10 +814,15 @@ func perspectiveCoeff(startPoints, endPoints [][]int64) []float64 {
|
|||
// bMat := ts.MustOfSlice(startPoints).MustTotype(gotch.Float, true).MustView([]int64{8}, true)
|
||||
bMat := ts.MustOfSlice(startData).MustTotype(gotch.Float, true).MustView([]int64{8}, true)
|
||||
|
||||
res := bMat.MustLstsq(aMat, true)
|
||||
// res := bMat.MustLstsq(aMat, true)
|
||||
// Ref. https://github.com/pytorch/vision/blob/d7fa36f221cb2ff670cd4267b83a801cece52522/torchvision/transforms/functional.py#L572
|
||||
solution, residuals, rank, singularValues := bMat.MustLinalgLstsq(aMat, nil, "gels", true)
|
||||
residuals.MustDrop()
|
||||
rank.MustDrop()
|
||||
singularValues.MustDrop()
|
||||
|
||||
aMat.MustDrop()
|
||||
outputTs := res.MustSqueezeDim(1, true)
|
||||
outputTs := solution.MustSqueezeDim(1, true)
|
||||
output := outputTs.Float64Values()
|
||||
outputTs.MustDrop()
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user