added missing Tensor methods returning multiple tensor values
This commit is contained in:
parent
880a1b25df
commit
4060f00193
89
gen/gen.ml
89
gen/gen.ml
|
@ -848,11 +848,11 @@ let write_wrapper funcs filename =
|
||||||
pm "%s" go_args_list ;
|
pm "%s" go_args_list ;
|
||||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
||||||
if is_method && not is_inplace then
|
if is_method && not is_inplace then
|
||||||
pm "if del { defer ts.MustDrop() }\n" ;
|
pm " if del { defer ts.MustDrop() }\n" ;
|
||||||
pm " ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))\n" ;
|
pm " ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))\n" ;
|
||||||
pm " \n" ;
|
pm " \n" ;
|
||||||
pm " %s" (Func.go_binding_body func) ;
|
pm " %s" (Func.go_binding_body func) ;
|
||||||
pm "%s(ptr, %s)\n" cfunc_name (Func.go_binding_args func) ;
|
pm " %s(ptr, %s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||||
pm " if err = TorchErr(); err != nil {\n" ;
|
pm " if err = TorchErr(); err != nil {\n" ;
|
||||||
pm " return %s\n"
|
pm " return %s\n"
|
||||||
(Func.go_return_notype func ~fallible:true) ;
|
(Func.go_return_notype func ~fallible:true) ;
|
||||||
|
@ -871,11 +871,11 @@ let write_wrapper funcs filename =
|
||||||
pm "%s" go_args_list ;
|
pm "%s" go_args_list ;
|
||||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
||||||
if is_method && not is_inplace then
|
if is_method && not is_inplace then
|
||||||
pm "if del { defer ts.MustDrop() }\n" ;
|
pm " if del { defer ts.MustDrop() }\n" ;
|
||||||
pm " ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))\n" ;
|
pm " ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))\n" ;
|
||||||
pm " \n" ;
|
pm " \n" ;
|
||||||
pm " %s" (Func.go_binding_body func) ;
|
pm " %s" (Func.go_binding_body func) ;
|
||||||
pm "%s(ptr, %s)\n" cfunc_name (Func.go_binding_args func) ;
|
pm " %s(ptr, %s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||||
pm " if err = TorchErr(); err != nil {\n" ;
|
pm " if err = TorchErr(); err != nil {\n" ;
|
||||||
pm " return %s\n"
|
pm " return %s\n"
|
||||||
(Func.go_return_notype func ~fallible:true) ;
|
(Func.go_return_notype func ~fallible:true) ;
|
||||||
|
@ -887,6 +887,44 @@ let write_wrapper funcs filename =
|
||||||
pm " \n" ;
|
pm " \n" ;
|
||||||
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||||
pm "} \n"
|
pm "} \n"
|
||||||
|
| `fixed ntensors ->
|
||||||
|
pm "\n" ;
|
||||||
|
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
|
||||||
|
else pm "func %s(" gofunc_name ;
|
||||||
|
pm "%s" go_args_list ;
|
||||||
|
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
||||||
|
if is_method && not is_inplace then
|
||||||
|
pm " if del { defer ts.MustDrop() }\n" ;
|
||||||
|
for i = 0 to ntensors - 1 do
|
||||||
|
(* pc " out__[%d] = new torch::Tensor(std::get<%d>(outputs__));" i i *)
|
||||||
|
if i = 0 then
|
||||||
|
pm
|
||||||
|
" ctensorPtr0 := \
|
||||||
|
(*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))\n"
|
||||||
|
else
|
||||||
|
pm
|
||||||
|
" ctensorPtr%d := \
|
||||||
|
(*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(ctensorPtr%d)) \
|
||||||
|
+ unsafe.Sizeof(ctensorPtr0)))\n"
|
||||||
|
i (i - 1)
|
||||||
|
done ;
|
||||||
|
pm " \n" ;
|
||||||
|
pm " %s" (Func.go_binding_body func) ;
|
||||||
|
pm " %s(ctensorPtr0, %s)\n" cfunc_name
|
||||||
|
(Func.go_binding_args func) ;
|
||||||
|
pm " if err = TorchErr(); err != nil {\n" ;
|
||||||
|
pm " return %s\n"
|
||||||
|
(Func.go_return_notype func ~fallible:true) ;
|
||||||
|
pm " }\n" ;
|
||||||
|
(* NOTE. if in_place method, no retVal return *)
|
||||||
|
if not (Func.is_inplace func) then
|
||||||
|
for i = 0 to ntensors - 1 do
|
||||||
|
pm " retVal%d = &Tensor{ctensor: *ctensorPtr%d}\n" i i
|
||||||
|
done
|
||||||
|
else pm " ts.ctensor = *ptr\n" ;
|
||||||
|
pm " \n" ;
|
||||||
|
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||||
|
pm "} \n"
|
||||||
| `bool ->
|
| `bool ->
|
||||||
pm "\n" ;
|
pm "\n" ;
|
||||||
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
|
if is_method then pm "func(ts *Tensor) %s(" gofunc_name
|
||||||
|
@ -894,10 +932,10 @@ let write_wrapper funcs filename =
|
||||||
pm "%s" go_args_list ;
|
pm "%s" go_args_list ;
|
||||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
||||||
if is_method && not is_inplace then
|
if is_method && not is_inplace then
|
||||||
pm "if del { defer ts.MustDrop() }\n" ;
|
pm " if del { defer ts.MustDrop() }\n" ;
|
||||||
pm " \n" ;
|
pm " \n" ;
|
||||||
pm " %s" (Func.go_binding_body func) ;
|
pm " %s" (Func.go_binding_body func) ;
|
||||||
pm "retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
pm " retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||||
pm " if err = TorchErr(); err != nil {\n" ;
|
pm " if err = TorchErr(); err != nil {\n" ;
|
||||||
pm " return %s\n"
|
pm " return %s\n"
|
||||||
(Func.go_return_notype func ~fallible:true) ;
|
(Func.go_return_notype func ~fallible:true) ;
|
||||||
|
@ -911,10 +949,10 @@ let write_wrapper funcs filename =
|
||||||
pm "%s" go_args_list ;
|
pm "%s" go_args_list ;
|
||||||
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
pm ")(%s) { \n" (Func.go_return_type func ~fallible:true) ;
|
||||||
if is_method && not is_inplace then
|
if is_method && not is_inplace then
|
||||||
pm "if del { defer ts.MustDrop() }\n" ;
|
pm " if del { defer ts.MustDrop() }\n" ;
|
||||||
pm " \n" ;
|
pm " \n" ;
|
||||||
pm " %s" (Func.go_binding_body func) ;
|
pm " %s" (Func.go_binding_body func) ;
|
||||||
pm "retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
pm " retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||||
pm " if err = TorchErr(); err != nil {\n" ;
|
pm " if err = TorchErr(); err != nil {\n" ;
|
||||||
pm " return %s\n"
|
pm " return %s\n"
|
||||||
(Func.go_return_notype func ~fallible:true) ;
|
(Func.go_return_notype func ~fallible:true) ;
|
||||||
|
@ -931,15 +969,13 @@ let write_wrapper funcs filename =
|
||||||
pm "if del { defer ts.MustDrop() }\n" ;
|
pm "if del { defer ts.MustDrop() }\n" ;
|
||||||
pm " \n" ;
|
pm " \n" ;
|
||||||
pm " %s" (Func.go_binding_body func) ;
|
pm " %s" (Func.go_binding_body func) ;
|
||||||
pm "retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
pm " retVal = %s(%s)\n" cfunc_name (Func.go_binding_args func) ;
|
||||||
pm " if err = TorchErr(); err != nil {\n" ;
|
pm " if err = TorchErr(); err != nil {\n" ;
|
||||||
pm " return %s\n"
|
pm " return %s\n"
|
||||||
(Func.go_return_notype func ~fallible:true) ;
|
(Func.go_return_notype func ~fallible:true) ;
|
||||||
pm " }\n" ;
|
pm " }\n" ;
|
||||||
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
pm " return %s\n" (Func.go_return_notype func ~fallible:true) ;
|
||||||
pm "} \n"
|
pm "} \n" ) ;
|
||||||
| `fixed _ -> pm "" ) ;
|
|
||||||
(* TODO. implement for return multiple tensor - []Tensor *)
|
|
||||||
pm "// End of implementing Tensor ================================= \n"
|
pm "// End of implementing Tensor ================================= \n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -982,7 +1018,6 @@ let write_must_wrapper funcs filename =
|
||||||
let go_args_list = Func.go_typed_args_list func in
|
let go_args_list = Func.go_typed_args_list func in
|
||||||
let go_args_list_notype = Func.go_notype_args_list func in
|
let go_args_list_notype = Func.go_notype_args_list func in
|
||||||
(* NOTE. temporarily excluding these functions as not implemented at FFI *)
|
(* NOTE. temporarily excluding these functions as not implemented at FFI *)
|
||||||
(* TODO. implement multiple tensors return function []Tensor *)
|
|
||||||
let excluded_funcs =
|
let excluded_funcs =
|
||||||
[ "Chunk"
|
[ "Chunk"
|
||||||
; "AlignTensors"
|
; "AlignTensors"
|
||||||
|
@ -1067,6 +1102,30 @@ let write_must_wrapper funcs filename =
|
||||||
pm " \n" ;
|
pm " \n" ;
|
||||||
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
|
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
|
||||||
pm "} \n"
|
pm "} \n"
|
||||||
|
| `fixed _ ->
|
||||||
|
pm "\n" ;
|
||||||
|
if is_method then pm "func(ts *Tensor) Must%s(" gofunc_name
|
||||||
|
else pm "func Must%s(" gofunc_name ;
|
||||||
|
pm "%s" go_args_list ;
|
||||||
|
pm ")(%s) { \n" (Func.go_return_type func ~fallible:false) ;
|
||||||
|
pm " \n" ;
|
||||||
|
(* NOTE. No return retVal for in_place method *)
|
||||||
|
if Func.is_inplace func then
|
||||||
|
if is_method then
|
||||||
|
pm " err := ts.%s(%s)\n" gofunc_name go_args_list_notype
|
||||||
|
else pm " err := %s(%s)\n" gofunc_name go_args_list_notype
|
||||||
|
else if is_method then
|
||||||
|
pm " %s, err := ts.%s(%s)\n"
|
||||||
|
(Func.go_return_notype func ~fallible:false)
|
||||||
|
gofunc_name go_args_list_notype
|
||||||
|
else
|
||||||
|
pm " %s, err := %s(%s)\n"
|
||||||
|
(Func.go_return_notype func ~fallible:false)
|
||||||
|
gofunc_name go_args_list_notype ;
|
||||||
|
pm " if err != nil { log.Fatal(err) }\n" ;
|
||||||
|
pm " \n" ;
|
||||||
|
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
|
||||||
|
pm "} \n"
|
||||||
| `bool ->
|
| `bool ->
|
||||||
pm "\n" ;
|
pm "\n" ;
|
||||||
if is_method then pm "func(ts *Tensor) Must%s(" gofunc_name
|
if is_method then pm "func(ts *Tensor) Must%s(" gofunc_name
|
||||||
|
@ -1117,9 +1176,7 @@ let write_must_wrapper funcs filename =
|
||||||
pm " if err != nil { log.Fatal(err) }\n" ;
|
pm " if err != nil { log.Fatal(err) }\n" ;
|
||||||
pm " \n" ;
|
pm " \n" ;
|
||||||
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
|
pm " return %s\n" (Func.go_return_notype func ~fallible:false) ;
|
||||||
pm "} \n"
|
pm "} \n" ) ;
|
||||||
| `fixed _ -> pm "" ) ;
|
|
||||||
(* TODO. implement for return multiple tensor - []Tensor *)
|
|
||||||
pm "// End of implementing Tensor ================================= \n"
|
pm "// End of implementing Tensor ================================= \n"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -564,27 +564,27 @@ func MustWhere(condition Tensor, del bool) (retVal []Tensor) {
|
||||||
// NOTE. patches for APIs `agt_` missing in tensor/ but existing in lib
|
// NOTE. patches for APIs `agt_` missing in tensor/ but existing in lib
|
||||||
// ====================================================================
|
// ====================================================================
|
||||||
|
|
||||||
// void atg_lstsq(tensor *, tensor self, tensor A);
|
// // void atg_lstsq(tensor *, tensor self, tensor A);
|
||||||
func (ts *Tensor) Lstsq(a *Tensor, del bool) (retVal *Tensor, err error) {
|
// func (ts *Tensor) Lstsq(a *Tensor, del bool) (retVal *Tensor, err error) {
|
||||||
if del {
|
// if del {
|
||||||
defer ts.MustDrop()
|
// defer ts.MustDrop()
|
||||||
}
|
// }
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
// ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
//
|
||||||
lib.AtgLstsq(ptr, ts.ctensor, a.ctensor)
|
// lib.AtgLstsq(ptr, ts.ctensor, a.ctensor)
|
||||||
if err = TorchErr(); err != nil {
|
// if err = TorchErr(); err != nil {
|
||||||
return retVal, err
|
// return retVal, err
|
||||||
}
|
// }
|
||||||
retVal = &Tensor{ctensor: *ptr}
|
// retVal = &Tensor{ctensor: *ptr}
|
||||||
|
//
|
||||||
return retVal, err
|
// return retVal, err
|
||||||
}
|
// }
|
||||||
|
//
|
||||||
func (ts *Tensor) MustLstsq(a *Tensor, del bool) (retVal *Tensor) {
|
// func (ts *Tensor) MustLstsq(a *Tensor, del bool) (retVal *Tensor) {
|
||||||
retVal, err := ts.Lstsq(a, del)
|
// retVal, err := ts.Lstsq(a, del)
|
||||||
if err != nil {
|
// if err != nil {
|
||||||
log.Fatal(err)
|
// log.Fatal(err)
|
||||||
}
|
// }
|
||||||
|
//
|
||||||
return retVal
|
// return retVal
|
||||||
}
|
// }
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -814,10 +814,15 @@ func perspectiveCoeff(startPoints, endPoints [][]int64) []float64 {
|
||||||
// bMat := ts.MustOfSlice(startPoints).MustTotype(gotch.Float, true).MustView([]int64{8}, true)
|
// bMat := ts.MustOfSlice(startPoints).MustTotype(gotch.Float, true).MustView([]int64{8}, true)
|
||||||
bMat := ts.MustOfSlice(startData).MustTotype(gotch.Float, true).MustView([]int64{8}, true)
|
bMat := ts.MustOfSlice(startData).MustTotype(gotch.Float, true).MustView([]int64{8}, true)
|
||||||
|
|
||||||
res := bMat.MustLstsq(aMat, true)
|
// res := bMat.MustLstsq(aMat, true)
|
||||||
|
// Ref. https://github.com/pytorch/vision/blob/d7fa36f221cb2ff670cd4267b83a801cece52522/torchvision/transforms/functional.py#L572
|
||||||
|
solution, residuals, rank, singularValues := bMat.MustLinalgLstsq(aMat, nil, "gels", true)
|
||||||
|
residuals.MustDrop()
|
||||||
|
rank.MustDrop()
|
||||||
|
singularValues.MustDrop()
|
||||||
|
|
||||||
aMat.MustDrop()
|
aMat.MustDrop()
|
||||||
outputTs := res.MustSqueezeDim(1, true)
|
outputTs := solution.MustSqueezeDim(1, true)
|
||||||
output := outputTs.Float64Values()
|
output := outputTs.Float64Values()
|
||||||
outputTs.MustDrop()
|
outputTs.MustDrop()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user