added missing ts.Lstsq API
This commit is contained in:
parent
c151197cc8
commit
0d86ab1cf3
|
@ -560,3 +560,31 @@ func MustWhere(condition Tensor, del bool) (retVal []Tensor) {
|
|||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// NOTE. patches for APIs `agt_` missing in tensor/ but existing in lib
|
||||
// ====================================================================
|
||||
|
||||
// void atg_lstsq(tensor *, tensor self, tensor A);
|
||||
func (ts *Tensor) Lstsq(a *ts.Tensor, del bool) (retVal *Tensor, err error) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
lib.AtgLstsq(ptr, ts.ctensor, a.ctensor)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
retVal = &Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
func (ts *Tensor) MustLstsq(a *ts.Tensor, del bool) (retVal *Tensor) {
|
||||
retVal, err := ts.Lstsq(del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user