diff --git a/tensor/patch.go b/tensor/patch.go index bf81d46..8140b0a 100644 --- a/tensor/patch.go +++ b/tensor/patch.go @@ -560,3 +560,31 @@ func MustWhere(condition Tensor, del bool) (retVal []Tensor) { return retVal } + +// NOTE. patches for APIs `agt_` missing in tensor/ but existing in lib +// ==================================================================== + +// void atg_lstsq(tensor *, tensor self, tensor A); +func (ts *Tensor) Lstsq(a *ts.Tensor, del bool) (retVal *Tensor, err error) { + if del { + defer ts.MustDrop() + } + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + + lib.AtgLstsq(ptr, ts.ctensor, a.ctensor) + if err = TorchErr(); err != nil { + return retVal, err + } + retVal = &Tensor{ctensor: *ptr} + + return retVal, err +} + +func (ts *Tensor) MustLstsq(a *ts.Tensor, del bool) (retVal *Tensor) { + retVal, err := ts.Lstsq(del) + if err != nil { + log.Fatal(err) + } + + return retVal +}