feat(patch): completed 9 missing APIs for 'tensor *atg_' pattern

This commit is contained in:
sugarme 2020-09-24 11:03:21 +10:00
parent e5735c77dc
commit 34efa9e3da
2 changed files with 362 additions and 3 deletions

View File

@ -0,0 +1,45 @@
package tensor_test
import (
"testing"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
func ExampleTensor_Split(t *testing.T) {
tensor := ts.MustArange1(ts.FloatScalar(0), ts.FloatScalar(10), gotch.Float, gotch.CPU).MustView([]int64{5, 2}, true)
splitTensors := tensor.MustSplit(2, 0, false)
for _, t := range splitTensors {
t.Print()
}
//Output:
// 0 1
// 2 3
// [ CPUFloatType{2,2} ]
// 4 5
// 6 7
// [ CPUFloatType{2,2} ]
// 8 9
// [ CPUFloatType{1,2} ]
}
func ExampleTensorSplitWithSizes(t *testing.T) {
tensor := ts.MustArange1(ts.FloatScalar(0), ts.FloatScalar(10), gotch.Float, gotch.CPU).MustView([]int64{5, 2}, true)
splitTensors := tensor.MustSplitWithSizes([]int64{1, 4}, 0, false)
for _, t := range splitTensors {
t.Print()
}
//Output:
// 0 1
// [ CPUFloatType{1,2} ]
// 2 3
// 4 5
// 6 7
// 8 9
// [ CPUFloatType{4,2} ]
}

View File

@ -4,11 +4,9 @@ package tensor
import "C"
import (
"fmt"
"log"
"unsafe"
// "github.com/sugarme/gotch"
lib "github.com/sugarme/gotch/libtch"
)
@ -201,10 +199,208 @@ func (ts Tensor) MustNLLLoss(target Tensor, del bool) (retVal Tensor) {
// tensor *atg_unbind(tensor self, int64_t dim);
// tensor *atg_where(tensor condition);
// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
func AlignTensors(tensors []Tensor) (retVal []Tensor, err error) {
var ctensors []lib.Ctensor
for _, t := range tensors {
ctensors = append(ctensors, t.ctensor)
}
ctensorsPtr := lib.AtgAlignTensors(ctensors, len(ctensors))
if err = TorchErr(); err != nil {
return retVal, err
}
currentPtr := ctensorsPtr
retVal = append(retVal, Tensor{ctensor: *currentPtr})
for {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, Tensor{ctensor: *nextPtr})
currentPtr = nextPtr
}
return retVal, nil
}
func MustAlignTensors(tensors []Tensor, del bool) (retVal []Tensor) {
if del {
for _, t := range tensors {
defer t.MustDrop()
}
}
retVal, err := AlignTensors(tensors)
if err != nil {
log.Fatal(err)
}
return retVal
}
// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
func BroadcastTensors(tensors []Tensor) (retVal []Tensor, err error) {
var ctensors []lib.Ctensor
for _, t := range tensors {
ctensors = append(ctensors, t.ctensor)
}
ctensorsPtr := lib.AtgBroadcastTensors(ctensors, len(ctensors))
if err = TorchErr(); err != nil {
return retVal, err
}
currentPtr := ctensorsPtr
retVal = append(retVal, Tensor{ctensor: *currentPtr})
for {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, Tensor{ctensor: *nextPtr})
currentPtr = nextPtr
}
return retVal, nil
}
func MustBroadcastTensors(tensors []Tensor, del bool) (retVal []Tensor) {
if del {
for _, t := range tensors {
defer t.MustDrop()
}
}
retVal, err := BroadcastTensors(tensors)
if err != nil {
log.Fatal(err)
}
return retVal
}
// tensor *atg_chunk(tensor self, int64_t chunks, int64_t dim);
func (ts Tensor) Chunk(chunks int64, dim int64) (retVal []Tensor, err error) {
ctensorsPtr := lib.AtgChunk(ts.ctensor, chunks, dim)
if err = TorchErr(); err != nil {
return retVal, err
}
currentPtr := ctensorsPtr
retVal = append(retVal, Tensor{ctensor: *currentPtr})
for {
// calculate the next pointer value
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, Tensor{ctensor: *nextPtr})
currentPtr = nextPtr
}
return retVal, nil
}
func (ts Tensor) MustChunk(chunks int64, dim int64, del bool) (retVal []Tensor) {
if del {
defer ts.MustDrop()
}
retVal, err := ts.Chunk(chunks, dim)
if err != nil {
log.Fatal(err)
}
return retVal
}
// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
func (ts Tensor) Meshgrid(tensors []Tensor) (retVal []Tensor, err error) {
var ctensors []lib.Ctensor
for _, t := range tensors {
ctensors = append(ctensors, t.ctensor)
}
ctensorsPtr := lib.AtgMeshgrid(ctensors, len(ctensors))
if err = TorchErr(); err != nil {
return retVal, err
}
currentPtr := ctensorsPtr
retVal = append(retVal, Tensor{ctensor: *currentPtr})
for {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, Tensor{ctensor: *nextPtr})
currentPtr = nextPtr
}
return retVal, nil
}
func (ts Tensor) MustMeshgrid(tensors []Tensor, del bool) (retVal []Tensor) {
if del {
defer ts.MustDrop()
}
retVal, err := ts.Meshgrid(tensors)
if err != nil {
log.Fatal(err)
}
return retVal
}
// tensor *atg_nonzero_numpy(tensor self);
func (ts Tensor) NonzeroNumpy() (retVal []Tensor, err error) {
ctensorsPtr := lib.AtgNonzeroNumpy(ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
currentPtr := ctensorsPtr
retVal = append(retVal, Tensor{ctensor: *currentPtr})
for {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, Tensor{ctensor: *nextPtr})
currentPtr = nextPtr
}
return retVal, nil
}
func (ts Tensor) MustNonzeroNumpy(del bool) (retVal []Tensor) {
if del {
defer ts.MustDrop()
}
retVal, err := ts.NonzeroNumpy()
if err != nil {
log.Fatal(err)
}
return retVal
}
// Split splits tensor into chunks
//
// Parameters:
// - splitSize size of a single chunk or list of sizes for each chunk
// - splitSize size of a single chunk
// - dim dimension along which to split the tensor.
// Ref. https://pytorch.org/docs/stable/generated/torch.split.html
func (ts Tensor) Split(splitSize, dim int64) (retVal []Tensor, err error) {
@ -246,3 +442,121 @@ func (ts Tensor) MustSplit(splitSize, dim int64, del bool) (retVal []Tensor) {
return retVal
}
// SplitWithSizes splits tensor into chunks
//
// Parameters:
// - splitSizes slice of sizes for each chunk
// - dim dimension along which to split the tensor.
// Ref. https://pytorch.org/docs/stable/generated/torch.split.html
func (ts Tensor) SplitWithSizes(splitSizes []int64, dim int64) (retVal []Tensor, err error) {
ctensorsPtr := lib.AtgSplitWithSizes(ts.ctensor, splitSizes, len(splitSizes), dim)
if err = TorchErr(); err != nil {
return retVal, err
}
// NOTE: ctensorsPtr is a c-pointer to a vector of tensors. The first
// C tensor is the `ctensorsPtr` value. The next pointer will be
// calculated from there. The vector of tensors will end if the calculated
// pointer value is `null`.
currentPtr := ctensorsPtr
retVal = append(retVal, Tensor{ctensor: *currentPtr})
for {
// calculate the next pointer value
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, Tensor{ctensor: *nextPtr})
currentPtr = nextPtr
}
return retVal, nil
}
func (ts Tensor) MustSplitWithSizes(splitSizes []int64, dim int64, del bool) (retVal []Tensor) {
if del {
defer ts.MustDrop()
}
retVal, err := ts.SplitWithSizes(splitSizes, dim)
if err != nil {
log.Fatal(err)
}
return retVal
}
// tensor *atg_unbind(tensor self, int64_t dim);
func (ts Tensor) Unbind(dim int64) (retVal []Tensor, err error) {
ctensorsPtr := lib.AtgUnbind(ts.ctensor, dim)
if err = TorchErr(); err != nil {
return retVal, err
}
currentPtr := ctensorsPtr
retVal = append(retVal, Tensor{ctensor: *currentPtr})
for {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, Tensor{ctensor: *nextPtr})
currentPtr = nextPtr
}
return retVal, nil
}
func (ts Tensor) MustUnbind(dim int64, del bool) (retVal []Tensor) {
if del {
defer ts.MustDrop()
}
retVal, err := ts.Unbind(dim)
if err != nil {
log.Fatal(err)
}
return retVal
}
// tensor *atg_where(tensor condition);
func Where(condition Tensor) (retVal []Tensor, err error) {
ctensorsPtr := lib.AtgWhere(condition.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
currentPtr := ctensorsPtr
retVal = append(retVal, Tensor{ctensor: *currentPtr})
for {
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
if *nextPtr == nil {
break
}
retVal = append(retVal, Tensor{ctensor: *nextPtr})
currentPtr = nextPtr
}
return retVal, nil
}
func MustWhere(condition Tensor, del bool) (retVal []Tensor) {
if del {
defer condition.MustDrop()
}
retVal, err := Where(condition)
if err != nil {
log.Fatal(err)
}
return retVal
}