2022-03-12 07:20:20 +00:00
|
|
|
|
package ts
|
2020-07-22 06:56:30 +01:00
|
|
|
|
|
|
|
|
|
// #include "stdlib.h"
|
|
|
|
|
import "C"
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"log"
|
|
|
|
|
"unsafe"
|
|
|
|
|
|
|
|
|
|
lib "github.com/sugarme/gotch/libtch"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// NOTE: this code is temporarily patched to make it run.
|
2023-07-05 14:56:48 +01:00
|
|
|
|
// TODO. make change at generator for []*Tensor input
|
2020-07-22 06:56:30 +01:00
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func (ts *Tensor) Lstm(hxData []*Tensor, paramsData []*Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c *Tensor, err error) {
|
2020-07-22 06:56:30 +01:00
|
|
|
|
|
|
|
|
|
// NOTE: `atg_lstm` will create 3 consecutive Ctensors in memory of C land. The first
|
|
|
|
|
// Ctensor will have address given by `ctensorPtr1` here.
|
|
|
|
|
// The next pointers can be calculated based on `ctensorPtr1`
|
|
|
|
|
ctensorPtr1 := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
|
|
|
ctensorPtr2 := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(ctensorPtr1)) + unsafe.Sizeof(ctensorPtr1)))
|
|
|
|
|
ctensorPtr3 := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(ctensorPtr2)) + unsafe.Sizeof(ctensorPtr1)))
|
|
|
|
|
|
|
|
|
|
var chxData []lib.Ctensor
|
|
|
|
|
for _, t := range hxData {
|
|
|
|
|
chxData = append(chxData, t.ctensor)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var cparamsData []lib.Ctensor
|
|
|
|
|
for _, t := range paramsData {
|
|
|
|
|
cparamsData = append(cparamsData, t.ctensor)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var chasBiases int32 = 0
|
|
|
|
|
if hasBiases {
|
|
|
|
|
chasBiases = 1
|
|
|
|
|
}
|
|
|
|
|
var ctrain int32 = 0
|
|
|
|
|
if train {
|
|
|
|
|
ctrain = 1
|
|
|
|
|
}
|
|
|
|
|
var cbidirectional int32 = 0
|
|
|
|
|
if bidirectional {
|
|
|
|
|
cbidirectional = 1
|
|
|
|
|
}
|
|
|
|
|
var cbatchFirst int32 = 0
|
|
|
|
|
if batchFirst {
|
|
|
|
|
cbatchFirst = 1
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
lib.AtgLstm(ctensorPtr1, ts.ctensor, chxData, len(hxData), cparamsData, len(paramsData), chasBiases, numLayers, dropout, ctrain, cbidirectional, cbatchFirst)
|
|
|
|
|
err = TorchErr()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return output, h, c, err
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-04 14:26:20 +01:00
|
|
|
|
return newTensor(*ctensorPtr1), newTensor(*ctensorPtr2), newTensor(*ctensorPtr3), nil
|
2020-07-22 06:56:30 +01:00
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func (ts *Tensor) MustLstm(hxData []*Tensor, paramsData []*Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h, c *Tensor) {
|
2020-07-22 06:56:30 +01:00
|
|
|
|
output, h, c, err := ts.Lstm(hxData, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst)
|
|
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return output, h, c
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func (ts *Tensor) Gru(hx *Tensor, paramsData []*Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h *Tensor, err error) {
|
2020-07-22 06:56:30 +01:00
|
|
|
|
|
|
|
|
|
// NOTE: `atg_gru` will create 2 consecutive Ctensors in memory of C land.
|
|
|
|
|
// The first Ctensor will have address given by `ctensorPtr1` here.
|
|
|
|
|
// The next pointer can be calculated based on `ctensorPtr1`
|
|
|
|
|
ctensorPtr1 := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
|
|
|
ctensorPtr2 := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(ctensorPtr1)) + unsafe.Sizeof(ctensorPtr1)))
|
|
|
|
|
|
|
|
|
|
var cparamsData []lib.Ctensor
|
|
|
|
|
for _, t := range paramsData {
|
|
|
|
|
cparamsData = append(cparamsData, t.ctensor)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var chasBiases int32 = 0
|
|
|
|
|
if hasBiases {
|
|
|
|
|
chasBiases = 1
|
|
|
|
|
}
|
|
|
|
|
var ctrain int32 = 0
|
|
|
|
|
if train {
|
|
|
|
|
ctrain = 1
|
|
|
|
|
}
|
|
|
|
|
var cbidirectional int32 = 0
|
|
|
|
|
if bidirectional {
|
|
|
|
|
cbidirectional = 1
|
|
|
|
|
}
|
|
|
|
|
var cbatchFirst int32 = 0
|
|
|
|
|
if batchFirst {
|
|
|
|
|
cbatchFirst = 1
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
lib.AtgGru(ctensorPtr1, ts.ctensor, hx.ctensor, cparamsData, len(paramsData), chasBiases, numLayers, dropout, ctrain, cbidirectional, cbatchFirst)
|
|
|
|
|
err = TorchErr()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return output, h, err
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-04 14:26:20 +01:00
|
|
|
|
return newTensor(*ctensorPtr1), newTensor(*ctensorPtr2), nil
|
2020-07-22 06:56:30 +01:00
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func (ts *Tensor) MustGru(hx *Tensor, paramsData []*Tensor, hasBiases bool, numLayers int64, dropout float64, train bool, bidirectional bool, batchFirst bool) (output, h *Tensor) {
|
2020-07-22 06:56:30 +01:00
|
|
|
|
output, h, err := ts.Gru(hx, paramsData, hasBiases, numLayers, dropout, train, bidirectional, batchFirst)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return output, h
|
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) TopK(k int64, dim int64, largest bool, sorted bool) (ts1, ts2 *Tensor, err error) {
|
2020-07-22 06:56:30 +01:00
|
|
|
|
|
|
|
|
|
// NOTE: `lib.AtgTopk` will return 2 tensors in C memory. First tensor pointer
|
|
|
|
|
// is given by ctensorPtr1
|
|
|
|
|
ctensorPtr1 := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
|
|
|
ctensorPtr2 := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(ctensorPtr1)) + unsafe.Sizeof(ctensorPtr1)))
|
|
|
|
|
var clargest int32 = 0
|
|
|
|
|
if largest {
|
|
|
|
|
clargest = 1
|
|
|
|
|
}
|
|
|
|
|
var csorted int32 = 0
|
|
|
|
|
if sorted {
|
|
|
|
|
csorted = 1
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
lib.AtgTopk(ctensorPtr1, ts.ctensor, k, dim, clargest, csorted)
|
|
|
|
|
err = TorchErr()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return ts1, ts2, err
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-04 14:26:20 +01:00
|
|
|
|
return newTensor(*ctensorPtr1), newTensor(*ctensorPtr2), nil
|
2020-07-22 06:56:30 +01:00
|
|
|
|
}
|
|
|
|
|
|
2020-10-31 08:25:32 +00:00
|
|
|
|
func (ts *Tensor) MustTopK(k int64, dim int64, largest bool, sorted bool) (ts1, ts2 *Tensor) {
|
2020-07-22 06:56:30 +01:00
|
|
|
|
|
|
|
|
|
ts1, ts2, err := ts.TopK(k, dim, largest, sorted)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ts1, ts2
|
|
|
|
|
}
|
2020-08-02 07:17:49 +01:00
|
|
|
|
|
|
|
|
|
// NOTE. `NLLLoss` is a version of `NllLoss` in tensor-generated
|
|
|
|
|
// with default weight, reduction and ignoreIndex
|
2020-11-01 00:59:08 +00:00
|
|
|
|
func (ts *Tensor) NLLLoss(target *Tensor, del bool) (retVal *Tensor, err error) {
|
2020-08-02 07:17:49 +01:00
|
|
|
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
|
|
|
if del {
|
|
|
|
|
defer ts.MustDrop()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
reduction := int64(1) // Mean of loss
|
|
|
|
|
ignoreIndex := int64(-100)
|
2020-08-03 00:56:59 +01:00
|
|
|
|
// defer C.free(unsafe.Pointer(ptr))
|
2020-08-02 07:17:49 +01:00
|
|
|
|
|
2020-08-03 00:56:59 +01:00
|
|
|
|
lib.AtgNllLoss(ptr, ts.ctensor, target.ctensor, nil, reduction, ignoreIndex)
|
2020-08-02 07:17:49 +01:00
|
|
|
|
if err = TorchErr(); err != nil {
|
|
|
|
|
return retVal, err
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-04 14:26:20 +01:00
|
|
|
|
retVal = newTensor(*ptr)
|
2020-08-02 07:17:49 +01:00
|
|
|
|
|
|
|
|
|
return retVal, nil
|
|
|
|
|
}
|
|
|
|
|
|
2020-11-01 00:59:08 +00:00
|
|
|
|
func (ts *Tensor) MustNLLLoss(target *Tensor, del bool) (retVal *Tensor) {
|
2020-08-03 00:56:59 +01:00
|
|
|
|
retVal, err := ts.NLLLoss(target, del)
|
2020-08-02 07:17:49 +01:00
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal
|
|
|
|
|
}
|
2020-09-24 00:52:31 +01:00
|
|
|
|
|
|
|
|
|
// NOTE: the following 9 APIs are missing from `tensor-generated.go` with
|
|
|
|
|
// pattern of **return tensor pointer**: `tensor *atg_FUNCTION_NAME()`.
|
|
|
|
|
// The returning tensor pointer actually is the FIRST element of a vector
|
|
|
|
|
// of C tensor pointers. Next pointer will be calculated from the first.
|
|
|
|
|
// In C land, the end of the vector is marked by an element whose value is **NULL**.
|
|
|
|
|
//
|
|
|
|
|
// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
|
|
|
|
|
// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
|
|
|
|
|
// tensor *atg_chunk(tensor self, int64_t chunks, int64_t dim);
|
|
|
|
|
// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
|
|
|
|
|
// tensor *atg_nonzero_numpy(tensor self);
|
|
|
|
|
// tensor *atg_split(tensor self, int64_t split_size, int64_t dim);
|
|
|
|
|
// tensor *atg_split_with_sizes(tensor self, int64_t *split_sizes_data, int split_sizes_len, int64_t dim);
|
|
|
|
|
// tensor *atg_unbind(tensor self, int64_t dim);
|
|
|
|
|
// tensor *atg_where(tensor condition);
|
|
|
|
|
|
2020-09-24 02:03:21 +01:00
|
|
|
|
// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func AlignTensors(tensors []*Tensor) (retVal []*Tensor, err error) {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
|
|
|
|
|
var ctensors []lib.Ctensor
|
|
|
|
|
for _, t := range tensors {
|
|
|
|
|
ctensors = append(ctensors, t.ctensor)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctensorsPtr := lib.AtgAlignTensors(ctensors, len(ctensors))
|
|
|
|
|
if err = TorchErr(); err != nil {
|
|
|
|
|
return retVal, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
currentPtr := ctensorsPtr
|
2023-07-05 14:56:48 +01:00
|
|
|
|
retVal = append(retVal, newTensor(*currentPtr))
|
2020-09-24 02:03:21 +01:00
|
|
|
|
for {
|
|
|
|
|
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
2023-07-05 14:56:48 +01:00
|
|
|
|
if nextPtr == nil {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
retVal = append(retVal, newTensor(*nextPtr))
|
2020-09-24 02:03:21 +01:00
|
|
|
|
currentPtr = nextPtr
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal, nil
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func MustAlignTensors(tensors []*Tensor, del bool) (retVal []*Tensor) {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
if del {
|
|
|
|
|
for _, t := range tensors {
|
|
|
|
|
defer t.MustDrop()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
retVal, err := AlignTensors(tensors)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func BroadcastTensors(tensors []*Tensor) (retVal []*Tensor, err error) {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
|
|
|
|
|
var ctensors []lib.Ctensor
|
|
|
|
|
for _, t := range tensors {
|
|
|
|
|
ctensors = append(ctensors, t.ctensor)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctensorsPtr := lib.AtgBroadcastTensors(ctensors, len(ctensors))
|
|
|
|
|
if err = TorchErr(); err != nil {
|
|
|
|
|
return retVal, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
currentPtr := ctensorsPtr
|
2023-07-05 14:56:48 +01:00
|
|
|
|
retVal = append(retVal, newTensor(*currentPtr))
|
2020-09-24 02:03:21 +01:00
|
|
|
|
for {
|
|
|
|
|
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
2023-07-05 14:56:48 +01:00
|
|
|
|
if nextPtr == nil {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
retVal = append(retVal, newTensor(*nextPtr))
|
2020-09-24 02:03:21 +01:00
|
|
|
|
currentPtr = nextPtr
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal, nil
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func MustBroadcastTensors(tensors []*Tensor, del bool) (retVal []*Tensor) {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
if del {
|
|
|
|
|
for _, t := range tensors {
|
|
|
|
|
defer t.MustDrop()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
retVal, err := BroadcastTensors(tensors)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// tensor *atg_chunk(tensor self, int64_t chunks, int64_t dim);
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func (ts *Tensor) Chunk(chunks int64, dim int64) (retVal []*Tensor, err error) {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
ctensorsPtr := lib.AtgChunk(ts.ctensor, chunks, dim)
|
|
|
|
|
if err = TorchErr(); err != nil {
|
|
|
|
|
return retVal, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
currentPtr := ctensorsPtr
|
2023-07-05 14:56:48 +01:00
|
|
|
|
retVal = append(retVal, newTensor(*currentPtr))
|
2020-09-24 02:03:21 +01:00
|
|
|
|
for {
|
|
|
|
|
// calculate the next pointer value
|
|
|
|
|
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
2023-07-05 14:56:48 +01:00
|
|
|
|
if nextPtr == nil {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
retVal = append(retVal, newTensor(*nextPtr))
|
2020-09-24 02:03:21 +01:00
|
|
|
|
currentPtr = nextPtr
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal, nil
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func (ts *Tensor) MustChunk(chunks int64, dim int64, del bool) (retVal []*Tensor) {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
if del {
|
|
|
|
|
defer ts.MustDrop()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
retVal, err := ts.Chunk(chunks, dim)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func Meshgrid(tensors []*Tensor) (retVal []*Tensor, err error) {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
|
|
|
|
|
var ctensors []lib.Ctensor
|
|
|
|
|
for _, t := range tensors {
|
|
|
|
|
ctensors = append(ctensors, t.ctensor)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctensorsPtr := lib.AtgMeshgrid(ctensors, len(ctensors))
|
|
|
|
|
if err = TorchErr(); err != nil {
|
|
|
|
|
return retVal, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
currentPtr := ctensorsPtr
|
2023-07-05 14:56:48 +01:00
|
|
|
|
retVal = append(retVal, newTensor(*currentPtr))
|
2020-09-24 02:03:21 +01:00
|
|
|
|
for {
|
|
|
|
|
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
2023-07-05 14:56:48 +01:00
|
|
|
|
if nextPtr == nil {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
retVal = append(retVal, newTensor(*nextPtr))
|
2020-09-24 02:03:21 +01:00
|
|
|
|
currentPtr = nextPtr
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal, nil
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func MustMeshgrid(tensors []*Tensor) (retVal []*Tensor) {
|
2022-01-17 06:18:54 +00:00
|
|
|
|
retVal, err := Meshgrid(tensors)
|
2020-09-24 02:03:21 +01:00
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// tensor *atg_nonzero_numpy(tensor self);
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func (ts *Tensor) NonzeroNumpy() (retVal []*Tensor, err error) {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
ctensorsPtr := lib.AtgNonzeroNumpy(ts.ctensor)
|
|
|
|
|
if err = TorchErr(); err != nil {
|
|
|
|
|
return retVal, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
currentPtr := ctensorsPtr
|
2023-07-05 14:56:48 +01:00
|
|
|
|
retVal = append(retVal, newTensor(*currentPtr))
|
2020-09-24 02:03:21 +01:00
|
|
|
|
for {
|
|
|
|
|
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
2023-07-05 14:56:48 +01:00
|
|
|
|
if nextPtr == nil {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
retVal = append(retVal, newTensor(*nextPtr))
|
2020-09-24 02:03:21 +01:00
|
|
|
|
currentPtr = nextPtr
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal, nil
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func (ts *Tensor) MustNonzeroNumpy(del bool) (retVal []*Tensor) {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
if del {
|
|
|
|
|
defer ts.MustDrop()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
retVal, err := ts.NonzeroNumpy()
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal
|
|
|
|
|
}
|
|
|
|
|
|
2020-09-24 00:52:31 +01:00
|
|
|
|
// Split splits tensor into chunks
|
|
|
|
|
//
|
|
|
|
|
// Parameters:
|
2023-07-04 14:26:20 +01:00
|
|
|
|
// - splitSize – size of a single chunk
|
|
|
|
|
// - dim – dimension along which to split the tensor.
|
|
|
|
|
//
|
2020-09-24 00:52:31 +01:00
|
|
|
|
// Ref. https://pytorch.org/docs/stable/generated/torch.split.html
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func (ts *Tensor) Split(splitSize, dim int64) (retVal []*Tensor, err error) {
|
2020-09-24 00:52:31 +01:00
|
|
|
|
|
|
|
|
|
ctensorsPtr := lib.AtgSplit(ts.ctensor, splitSize, dim)
|
|
|
|
|
if err = TorchErr(); err != nil {
|
|
|
|
|
return retVal, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NOTE: ctensorsPtr is a c-pointer to a vector of tensors. The first
|
|
|
|
|
// C tensor is the `ctensorsPtr` value. The next pointer will be
|
|
|
|
|
// calculated from there. The vector of tensors will end if the calculated
|
|
|
|
|
// pointer value is `null`.
|
|
|
|
|
currentPtr := ctensorsPtr
|
2023-07-05 14:56:48 +01:00
|
|
|
|
retVal = append(retVal, newTensor(*currentPtr))
|
2020-09-24 00:52:31 +01:00
|
|
|
|
for {
|
|
|
|
|
// calculate the next pointer value
|
|
|
|
|
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
2023-07-05 14:56:48 +01:00
|
|
|
|
if nextPtr == nil {
|
2020-09-24 00:52:31 +01:00
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
retVal = append(retVal, newTensor(*nextPtr))
|
2020-09-24 00:52:31 +01:00
|
|
|
|
currentPtr = nextPtr
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal, nil
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func (ts *Tensor) MustSplit(splitSize, dim int64, del bool) (retVal []*Tensor) {
|
2020-09-24 00:52:31 +01:00
|
|
|
|
if del {
|
|
|
|
|
defer ts.MustDrop()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
retVal, err := ts.Split(splitSize, dim)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal
|
|
|
|
|
}
|
2020-09-24 02:03:21 +01:00
|
|
|
|
|
|
|
|
|
// SplitWithSizes splits tensor into chunks
|
|
|
|
|
//
|
|
|
|
|
// Parameters:
|
2023-07-04 14:26:20 +01:00
|
|
|
|
// - splitSizes – slice of sizes for each chunk
|
|
|
|
|
// - dim – dimension along which to split the tensor.
|
|
|
|
|
//
|
2020-09-24 02:03:21 +01:00
|
|
|
|
// Ref. https://pytorch.org/docs/stable/generated/torch.split.html
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func (ts *Tensor) SplitWithSizes(splitSizes []int64, dim int64) (retVal []*Tensor, err error) {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
|
|
|
|
|
ctensorsPtr := lib.AtgSplitWithSizes(ts.ctensor, splitSizes, len(splitSizes), dim)
|
|
|
|
|
if err = TorchErr(); err != nil {
|
|
|
|
|
return retVal, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// NOTE: ctensorsPtr is a c-pointer to a vector of tensors. The first
|
|
|
|
|
// C tensor is the `ctensorsPtr` value. The next pointer will be
|
|
|
|
|
// calculated from there. The vector of tensors will end if the calculated
|
|
|
|
|
// pointer value is `null`.
|
|
|
|
|
currentPtr := ctensorsPtr
|
2023-07-05 14:56:48 +01:00
|
|
|
|
retVal = append(retVal, newTensor(*currentPtr))
|
2020-09-24 02:03:21 +01:00
|
|
|
|
for {
|
|
|
|
|
// calculate the next pointer value
|
|
|
|
|
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
2023-07-05 14:56:48 +01:00
|
|
|
|
if nextPtr == nil {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
retVal = append(retVal, newTensor(*nextPtr))
|
2020-09-24 02:03:21 +01:00
|
|
|
|
currentPtr = nextPtr
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal, nil
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func (ts *Tensor) MustSplitWithSizes(splitSizes []int64, dim int64, del bool) (retVal []*Tensor) {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
if del {
|
|
|
|
|
defer ts.MustDrop()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
retVal, err := ts.SplitWithSizes(splitSizes, dim)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// tensor *atg_unbind(tensor self, int64_t dim);
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func (ts *Tensor) Unbind(dim int64) (retVal []*Tensor, err error) {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
|
|
|
|
|
ctensorsPtr := lib.AtgUnbind(ts.ctensor, dim)
|
|
|
|
|
if err = TorchErr(); err != nil {
|
|
|
|
|
return retVal, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
currentPtr := ctensorsPtr
|
2023-07-05 14:56:48 +01:00
|
|
|
|
retVal = append(retVal, newTensor(*currentPtr))
|
2020-09-24 02:03:21 +01:00
|
|
|
|
for {
|
|
|
|
|
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
2023-07-05 14:56:48 +01:00
|
|
|
|
if nextPtr == nil {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
retVal = append(retVal, newTensor(*nextPtr))
|
2020-09-24 02:03:21 +01:00
|
|
|
|
currentPtr = nextPtr
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal, nil
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func (ts *Tensor) MustUnbind(dim int64, del bool) (retVal []*Tensor) {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
if del {
|
|
|
|
|
defer ts.MustDrop()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
retVal, err := ts.Unbind(dim)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// tensor *atg_where(tensor condition);
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func Where(condition Tensor) (retVal []*Tensor, err error) {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
ctensorsPtr := lib.AtgWhere(condition.ctensor)
|
|
|
|
|
if err = TorchErr(); err != nil {
|
|
|
|
|
return retVal, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
currentPtr := ctensorsPtr
|
2023-07-05 14:56:48 +01:00
|
|
|
|
retVal = append(retVal, newTensor(*currentPtr))
|
2020-09-24 02:03:21 +01:00
|
|
|
|
for {
|
|
|
|
|
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
2023-07-05 14:56:48 +01:00
|
|
|
|
if nextPtr == nil {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
retVal = append(retVal, newTensor(*nextPtr))
|
2020-09-24 02:03:21 +01:00
|
|
|
|
currentPtr = nextPtr
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal, nil
|
|
|
|
|
}
|
|
|
|
|
|
2023-07-05 14:56:48 +01:00
|
|
|
|
func MustWhere(condition Tensor, del bool) (retVal []*Tensor) {
|
2020-09-24 02:03:21 +01:00
|
|
|
|
if del {
|
|
|
|
|
defer condition.MustDrop()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
retVal, err := Where(condition)
|
|
|
|
|
if err != nil {
|
|
|
|
|
log.Fatal(err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return retVal
|
|
|
|
|
}
|
2021-05-19 04:48:29 +01:00
|
|
|
|
|
|
|
|
|
// NOTE. patches for `atg_` APIs missing in tensor/ but existing in lib
|
|
|
|
|
// ====================================================================
|
|
|
|
|
|
2021-08-15 12:59:10 +01:00
|
|
|
|
// // void atg_lstsq(tensor *, tensor self, tensor A);
|
|
|
|
|
// func (ts *Tensor) Lstsq(a *Tensor, del bool) (retVal *Tensor, err error) {
|
|
|
|
|
// if del {
|
|
|
|
|
// defer ts.MustDrop()
|
|
|
|
|
// }
|
|
|
|
|
// ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
|
|
|
|
//
|
|
|
|
|
// lib.AtgLstsq(ptr, ts.ctensor, a.ctensor)
|
|
|
|
|
// if err = TorchErr(); err != nil {
|
|
|
|
|
// return retVal, err
|
|
|
|
|
// }
|
|
|
|
|
// retVal = &Tensor{ctensor: *ptr}
|
|
|
|
|
//
|
|
|
|
|
// return retVal, err
|
|
|
|
|
// }
|
|
|
|
|
//
|
|
|
|
|
// func (ts *Tensor) MustLstsq(a *Tensor, del bool) (retVal *Tensor) {
|
|
|
|
|
// retVal, err := ts.Lstsq(a, del)
|
|
|
|
|
// if err != nil {
|
|
|
|
|
// log.Fatal(err)
|
|
|
|
|
// }
|
|
|
|
|
//
|
|
|
|
|
// return retVal
|
|
|
|
|
// }
|