From 0d86ab1cf3990be51fe6948fece90b02d885b3e8 Mon Sep 17 00:00:00 2001 From: sugarme Date: Wed, 19 May 2021 13:48:29 +1000 Subject: [PATCH] added missing ts.Lstsq API --- tensor/patch.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tensor/patch.go b/tensor/patch.go index bf81d46..8140b0a 100644 --- a/tensor/patch.go +++ b/tensor/patch.go @@ -560,3 +560,31 @@ func MustWhere(condition Tensor, del bool) (retVal []Tensor) { return retVal } + +// NOTE. patches for APIs `agt_` missing in tensor/ but existing in lib +// ==================================================================== + +// void atg_lstsq(tensor *, tensor self, tensor A); +func (ts *Tensor) Lstsq(a *ts.Tensor, del bool) (retVal *Tensor, err error) { + if del { + defer ts.MustDrop() + } + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + + lib.AtgLstsq(ptr, ts.ctensor, a.ctensor) + if err = TorchErr(); err != nil { + return retVal, err + } + retVal = &Tensor{ctensor: *ptr} + + return retVal, err +} + +func (ts *Tensor) MustLstsq(a *ts.Tensor, del bool) (retVal *Tensor) { + retVal, err := ts.Lstsq(del) + if err != nil { + log.Fatal(err) + } + + return retVal +}