added missing ts.Lstsq API

This commit is contained in:
sugarme 2021-05-19 13:48:29 +10:00
parent c151197cc8
commit 0d86ab1cf3

View File

@ -560,3 +560,31 @@ func MustWhere(condition Tensor, del bool) (retVal []Tensor) {
return retVal
}
// NOTE. patches for APIs `agt_` missing in tensor/ but existing in lib
// ====================================================================
// void atg_lstsq(tensor *, tensor self, tensor A);
func (ts *Tensor) Lstsq(a *ts.Tensor, del bool) (retVal *Tensor, err error) {
if del {
defer ts.MustDrop()
}
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgLstsq(ptr, ts.ctensor, a.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = &Tensor{ctensor: *ptr}
return retVal, err
}
func (ts *Tensor) MustLstsq(a *ts.Tensor, del bool) (retVal *Tensor) {
retVal, err := ts.Lstsq(del)
if err != nil {
log.Fatal(err)
}
return retVal
}