feat(patch): completed 9 missing APIs for 'tensor *atg_' pattern
This commit is contained in:
parent
e5735c77dc
commit
34efa9e3da
45
tensor/patch-example_test.go
Normal file
45
tensor/patch-example_test.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package tensor_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func ExampleTensor_Split(t *testing.T) {
|
||||
tensor := ts.MustArange1(ts.FloatScalar(0), ts.FloatScalar(10), gotch.Float, gotch.CPU).MustView([]int64{5, 2}, true)
|
||||
splitTensors := tensor.MustSplit(2, 0, false)
|
||||
|
||||
for _, t := range splitTensors {
|
||||
t.Print()
|
||||
}
|
||||
|
||||
//Output:
|
||||
// 0 1
|
||||
// 2 3
|
||||
// [ CPUFloatType{2,2} ]
|
||||
// 4 5
|
||||
// 6 7
|
||||
// [ CPUFloatType{2,2} ]
|
||||
// 8 9
|
||||
// [ CPUFloatType{1,2} ]
|
||||
}
|
||||
|
||||
func ExampleTensorSplitWithSizes(t *testing.T) {
|
||||
tensor := ts.MustArange1(ts.FloatScalar(0), ts.FloatScalar(10), gotch.Float, gotch.CPU).MustView([]int64{5, 2}, true)
|
||||
splitTensors := tensor.MustSplitWithSizes([]int64{1, 4}, 0, false)
|
||||
|
||||
for _, t := range splitTensors {
|
||||
t.Print()
|
||||
}
|
||||
|
||||
//Output:
|
||||
// 0 1
|
||||
// [ CPUFloatType{1,2} ]
|
||||
// 2 3
|
||||
// 4 5
|
||||
// 6 7
|
||||
// 8 9
|
||||
// [ CPUFloatType{4,2} ]
|
||||
}
|
320
tensor/patch.go
320
tensor/patch.go
|
@ -4,11 +4,9 @@ package tensor
|
|||
import "C"
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"unsafe"
|
||||
|
||||
// "github.com/sugarme/gotch"
|
||||
lib "github.com/sugarme/gotch/libtch"
|
||||
)
|
||||
|
||||
|
@ -201,10 +199,208 @@ func (ts Tensor) MustNLLLoss(target Tensor, del bool) (retVal Tensor) {
|
|||
// tensor *atg_unbind(tensor self, int64_t dim);
|
||||
// tensor *atg_where(tensor condition);
|
||||
|
||||
// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
|
||||
func AlignTensors(tensors []Tensor) (retVal []Tensor, err error) {
|
||||
|
||||
var ctensors []lib.Ctensor
|
||||
for _, t := range tensors {
|
||||
ctensors = append(ctensors, t.ctensor)
|
||||
}
|
||||
|
||||
ctensorsPtr := lib.AtgAlignTensors(ctensors, len(ctensors))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||
for {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||
currentPtr = nextPtr
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustAlignTensors(tensors []Tensor, del bool) (retVal []Tensor) {
|
||||
if del {
|
||||
for _, t := range tensors {
|
||||
defer t.MustDrop()
|
||||
}
|
||||
}
|
||||
retVal, err := AlignTensors(tensors)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
|
||||
func BroadcastTensors(tensors []Tensor) (retVal []Tensor, err error) {
|
||||
|
||||
var ctensors []lib.Ctensor
|
||||
for _, t := range tensors {
|
||||
ctensors = append(ctensors, t.ctensor)
|
||||
}
|
||||
|
||||
ctensorsPtr := lib.AtgBroadcastTensors(ctensors, len(ctensors))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||
for {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||
currentPtr = nextPtr
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustBroadcastTensors(tensors []Tensor, del bool) (retVal []Tensor) {
|
||||
if del {
|
||||
for _, t := range tensors {
|
||||
defer t.MustDrop()
|
||||
}
|
||||
}
|
||||
|
||||
retVal, err := BroadcastTensors(tensors)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// tensor *atg_chunk(tensor self, int64_t chunks, int64_t dim);
|
||||
func (ts Tensor) Chunk(chunks int64, dim int64) (retVal []Tensor, err error) {
|
||||
ctensorsPtr := lib.AtgChunk(ts.ctensor, chunks, dim)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||
for {
|
||||
// calculate the next pointer value
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||
currentPtr = nextPtr
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustChunk(chunks int64, dim int64, del bool) (retVal []Tensor) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
retVal, err := ts.Chunk(chunks, dim)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
|
||||
func (ts Tensor) Meshgrid(tensors []Tensor) (retVal []Tensor, err error) {
|
||||
|
||||
var ctensors []lib.Ctensor
|
||||
for _, t := range tensors {
|
||||
ctensors = append(ctensors, t.ctensor)
|
||||
}
|
||||
|
||||
ctensorsPtr := lib.AtgMeshgrid(ctensors, len(ctensors))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||
for {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||
currentPtr = nextPtr
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustMeshgrid(tensors []Tensor, del bool) (retVal []Tensor) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
retVal, err := ts.Meshgrid(tensors)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// tensor *atg_nonzero_numpy(tensor self);
|
||||
func (ts Tensor) NonzeroNumpy() (retVal []Tensor, err error) {
|
||||
|
||||
ctensorsPtr := lib.AtgNonzeroNumpy(ts.ctensor)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||
for {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||
currentPtr = nextPtr
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustNonzeroNumpy(del bool) (retVal []Tensor) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
retVal, err := ts.NonzeroNumpy()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// Split splits tensor into chunks
|
||||
//
|
||||
// Parameters:
|
||||
// - splitSize – size of a single chunk or list of sizes for each chunk
|
||||
// - splitSize – size of a single chunk
|
||||
// - dim – dimension along which to split the tensor.
|
||||
// Ref. https://pytorch.org/docs/stable/generated/torch.split.html
|
||||
func (ts Tensor) Split(splitSize, dim int64) (retVal []Tensor, err error) {
|
||||
|
@ -246,3 +442,121 @@ func (ts Tensor) MustSplit(splitSize, dim int64, del bool) (retVal []Tensor) {
|
|||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// SplitWithSizes splits tensor into chunks
|
||||
//
|
||||
// Parameters:
|
||||
// - splitSizes – slice of sizes for each chunk
|
||||
// - dim – dimension along which to split the tensor.
|
||||
// Ref. https://pytorch.org/docs/stable/generated/torch.split.html
|
||||
func (ts Tensor) SplitWithSizes(splitSizes []int64, dim int64) (retVal []Tensor, err error) {
|
||||
|
||||
ctensorsPtr := lib.AtgSplitWithSizes(ts.ctensor, splitSizes, len(splitSizes), dim)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
// NOTE: ctensorsPtr is a c-pointer to a vector of tensors. The first
|
||||
// C tensor is the `ctensorsPtr` value. The next pointer will be
|
||||
// calculated from there. The vector of tensors will end if the calculated
|
||||
// pointer value is `null`.
|
||||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||
for {
|
||||
// calculate the next pointer value
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||
currentPtr = nextPtr
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustSplitWithSizes(splitSizes []int64, dim int64, del bool) (retVal []Tensor) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
retVal, err := ts.SplitWithSizes(splitSizes, dim)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// tensor *atg_unbind(tensor self, int64_t dim);
|
||||
func (ts Tensor) Unbind(dim int64) (retVal []Tensor, err error) {
|
||||
|
||||
ctensorsPtr := lib.AtgUnbind(ts.ctensor, dim)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||
for {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||
currentPtr = nextPtr
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustUnbind(dim int64, del bool) (retVal []Tensor) {
|
||||
if del {
|
||||
defer ts.MustDrop()
|
||||
}
|
||||
|
||||
retVal, err := ts.Unbind(dim)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// tensor *atg_where(tensor condition);
|
||||
func Where(condition Tensor) (retVal []Tensor, err error) {
|
||||
|
||||
ctensorsPtr := lib.AtgWhere(condition.ctensor)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
currentPtr := ctensorsPtr
|
||||
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||
for {
|
||||
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||
if *nextPtr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||
currentPtr = nextPtr
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustWhere(condition Tensor, del bool) (retVal []Tensor) {
|
||||
if del {
|
||||
defer condition.MustDrop()
|
||||
}
|
||||
|
||||
retVal, err := Where(condition)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user