diff --git a/tensor/patch-example_test.go b/tensor/patch-example_test.go
new file mode 100644
index 0000000..7165245
--- /dev/null
+++ b/tensor/patch-example_test.go
@@ -0,0 +1,43 @@
+package tensor_test
+
+import (
+	"github.com/sugarme/gotch"
+	ts "github.com/sugarme/gotch/tensor"
+)
+
+func ExampleTensor_Split() {
+	tensor := ts.MustArange1(ts.FloatScalar(0), ts.FloatScalar(10), gotch.Float, gotch.CPU).MustView([]int64{5, 2}, true)
+	splitTensors := tensor.MustSplit(2, 0, false)
+
+	for _, t := range splitTensors {
+		t.Print()
+	}
+
+	// Output:
+	// 0 1
+	// 2 3
+	// [ CPUFloatType{2,2} ]
+	// 4 5
+	// 6 7
+	// [ CPUFloatType{2,2} ]
+	// 8 9
+	// [ CPUFloatType{1,2} ]
+}
+
+func ExampleTensor_SplitWithSizes() {
+	tensor := ts.MustArange1(ts.FloatScalar(0), ts.FloatScalar(10), gotch.Float, gotch.CPU).MustView([]int64{5, 2}, true)
+	splitTensors := tensor.MustSplitWithSizes([]int64{1, 4}, 0, false)
+
+	for _, t := range splitTensors {
+		t.Print()
+	}
+
+	// Output:
+	// 0 1
+	// [ CPUFloatType{1,2} ]
+	// 2 3
+	// 4 5
+	// 6 7
+	// 8 9
+	// [ CPUFloatType{4,2} ]
+}
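Usage note (outside the patch): the examples above pass del=false and never drop the returned tensors, which is fine for a short-lived test, but every tensor returned by these list-returning helpers owns C memory and should be dropped explicitly in long-running code. A minimal cleanup sketch using only APIs that appear in this patch:

	tensor := ts.MustArange1(ts.FloatScalar(0), ts.FloatScalar(10), gotch.Float, gotch.CPU).MustView([]int64{5, 2}, true)
	// del=true: the source tensor is dropped once MustSplit has consumed it.
	for _, chunk := range tensor.MustSplit(2, 0, true) {
		chunk.Print()
		chunk.MustDrop() // free each returned chunk's C memory
	}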
diff --git a/tensor/patch.go b/tensor/patch.go
index ef2a6c2..e2f8ead 100644
--- a/tensor/patch.go
+++ b/tensor/patch.go
@@ -4,11 +4,9 @@ package tensor
 import "C"
 
 import (
-	"fmt"
 	"log"
 	"unsafe"
 
-	// "github.com/sugarme/gotch"
 	lib "github.com/sugarme/gotch/libtch"
 )
 
@@ -201,10 +199,208 @@ func (ts Tensor) MustNLLLoss(target Tensor, del bool) (retVal Tensor) {
 // tensor *atg_unbind(tensor self, int64_t dim);
 // tensor *atg_where(tensor condition);
 
+// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
+func AlignTensors(tensors []Tensor) (retVal []Tensor, err error) {
+
+	var ctensors []lib.Ctensor
+	for _, t := range tensors {
+		ctensors = append(ctensors, t.ctensor)
+	}
+
+	ctensorsPtr := lib.AtgAlignTensors(ctensors, len(ctensors))
+	if err = TorchErr(); err != nil {
+		return retVal, err
+	}
+
+	currentPtr := ctensorsPtr
+	retVal = append(retVal, Tensor{ctensor: *currentPtr})
+	for {
+		nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
+		if *nextPtr == nil {
+			break
+		}
+
+		retVal = append(retVal, Tensor{ctensor: *nextPtr})
+		currentPtr = nextPtr
+	}
+
+	return retVal, nil
+}
+
+func MustAlignTensors(tensors []Tensor, del bool) (retVal []Tensor) {
+	if del {
+		for _, t := range tensors {
+			defer t.MustDrop()
+		}
+	}
+	retVal, err := AlignTensors(tensors)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}
+
+// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
+func BroadcastTensors(tensors []Tensor) (retVal []Tensor, err error) {
+
+	var ctensors []lib.Ctensor
+	for _, t := range tensors {
+		ctensors = append(ctensors, t.ctensor)
+	}
+
+	ctensorsPtr := lib.AtgBroadcastTensors(ctensors, len(ctensors))
+	if err = TorchErr(); err != nil {
+		return retVal, err
+	}
+
+	currentPtr := ctensorsPtr
+	retVal = append(retVal, Tensor{ctensor: *currentPtr})
+	for {
+		nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
+		if *nextPtr == nil {
+			break
+		}
+
+		retVal = append(retVal, Tensor{ctensor: *nextPtr})
+		currentPtr = nextPtr
+	}
+
+	return retVal, nil
+}
+
+func MustBroadcastTensors(tensors []Tensor, del bool) (retVal []Tensor) {
+	if del {
+		for _, t := range tensors {
+			defer t.MustDrop()
+		}
+	}
+
+	retVal, err := BroadcastTensors(tensors)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}
+
+// tensor *atg_chunk(tensor self, int64_t chunks, int64_t dim);
+func (ts Tensor) Chunk(chunks int64, dim int64) (retVal []Tensor, err error) {
+	ctensorsPtr := lib.AtgChunk(ts.ctensor, chunks, dim)
+	if err = TorchErr(); err != nil {
+		return retVal, err
+	}
+
+	currentPtr := ctensorsPtr
+	retVal = append(retVal, Tensor{ctensor: *currentPtr})
+	for {
+		// calculate the next pointer value
+		nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
+		if *nextPtr == nil {
+			break
+		}
+
+		retVal = append(retVal, Tensor{ctensor: *nextPtr})
+		currentPtr = nextPtr
+	}
+
+	return retVal, nil
+}
+
+func (ts Tensor) MustChunk(chunks int64, dim int64, del bool) (retVal []Tensor) {
+	if del {
+		defer ts.MustDrop()
+	}
+
+	retVal, err := ts.Chunk(chunks, dim)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}
+
+// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
+func (ts Tensor) Meshgrid(tensors []Tensor) (retVal []Tensor, err error) {
+
+	var ctensors []lib.Ctensor
+	for _, t := range tensors {
+		ctensors = append(ctensors, t.ctensor)
+	}
+
+	ctensorsPtr := lib.AtgMeshgrid(ctensors, len(ctensors))
+	if err = TorchErr(); err != nil {
+		return retVal, err
+	}
+
+	currentPtr := ctensorsPtr
+	retVal = append(retVal, Tensor{ctensor: *currentPtr})
+	for {
+		nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
+		if *nextPtr == nil {
+			break
+		}
+
+		retVal = append(retVal, Tensor{ctensor: *nextPtr})
+		currentPtr = nextPtr
+	}
+
+	return retVal, nil
+}
+
+func (ts Tensor) MustMeshgrid(tensors []Tensor, del bool) (retVal []Tensor) {
+	if del {
+		defer ts.MustDrop()
+	}
+
+	retVal, err := ts.Meshgrid(tensors)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}
+
+// tensor *atg_nonzero_numpy(tensor self);
+func (ts Tensor) NonzeroNumpy() (retVal []Tensor, err error) {
+
+	ctensorsPtr := lib.AtgNonzeroNumpy(ts.ctensor)
+	if err = TorchErr(); err != nil {
+		return retVal, err
+	}
+
+	currentPtr := ctensorsPtr
+	retVal = append(retVal, Tensor{ctensor: *currentPtr})
+	for {
+		nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
+		if *nextPtr == nil {
+			break
+		}
+
+		retVal = append(retVal, Tensor{ctensor: *nextPtr})
+		currentPtr = nextPtr
+	}
+
+	return retVal, nil
+}
+
+func (ts Tensor) MustNonzeroNumpy(del bool) (retVal []Tensor) {
+	if del {
+		defer ts.MustDrop()
+	}
+
+	retVal, err := ts.NonzeroNumpy()
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}
+
 // Split splits tensor into chunks
 //
 // Parameters:
-// - splitSize – size of a single chunk or list of sizes for each chunk
+// - splitSize – size of a single chunk
 // - dim – dimension along which to split the tensor.
 // Ref. https://pytorch.org/docs/stable/generated/torch.split.html
 func (ts Tensor) Split(splitSize, dim int64) (retVal []Tensor, err error) {
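Usage note (outside the patch): each helper above returns a Go slice built by walking a C array of tensor pointers, advancing one pointer width at a time until a NULL entry terminates the list. A hedged sketch of BroadcastTensors in action; `ts.MustOfSlice` is assumed from the existing gotch API, everything else is defined in this patch:

	// Broadcast a {3,1} tensor and a {2} tensor to a common {3,2} shape,
	// following torch.broadcast_tensors semantics.
	a := ts.MustOfSlice([]float64{1, 2, 3}).MustView([]int64{3, 1}, true)
	b := ts.MustOfSlice([]float64{10, 20})
	for _, t := range ts.MustBroadcastTensors([]ts.Tensor{a, b}, true) {
		t.Print() // each broadcast result has shape {3,2}
		t.MustDrop()
	}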
@@ -246,3 +442,121 @@ func (ts Tensor) MustSplit(splitSize, dim int64, del bool) (retVal []Tensor) {
 
 	return retVal
 }
+
+// SplitWithSizes splits the tensor into chunks of the given sizes
+//
+// Parameters:
+// - splitSizes – slice of sizes for each chunk
+// - dim – dimension along which to split the tensor.
+// Ref. https://pytorch.org/docs/stable/generated/torch.split.html
+func (ts Tensor) SplitWithSizes(splitSizes []int64, dim int64) (retVal []Tensor, err error) {
+
+	ctensorsPtr := lib.AtgSplitWithSizes(ts.ctensor, splitSizes, len(splitSizes), dim)
+	if err = TorchErr(); err != nil {
+		return retVal, err
+	}
+
+	// NOTE: ctensorsPtr is a C pointer to a vector of tensors. The first
+	// C tensor is the value at `ctensorsPtr`; each subsequent pointer is
+	// computed by offsetting one pointer width from the previous one, and
+	// the vector ends when the dereferenced pointer value is `null`.
+	currentPtr := ctensorsPtr
+	retVal = append(retVal, Tensor{ctensor: *currentPtr})
+	for {
+		// calculate the next pointer value
+		nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
+		if *nextPtr == nil {
+			break
+		}
+
+		retVal = append(retVal, Tensor{ctensor: *nextPtr})
+		currentPtr = nextPtr
+	}
+
+	return retVal, nil
+}
+
+func (ts Tensor) MustSplitWithSizes(splitSizes []int64, dim int64, del bool) (retVal []Tensor) {
+	if del {
+		defer ts.MustDrop()
+	}
+
+	retVal, err := ts.SplitWithSizes(splitSizes, dim)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}
+
+// tensor *atg_unbind(tensor self, int64_t dim);
+func (ts Tensor) Unbind(dim int64) (retVal []Tensor, err error) {
+
+	ctensorsPtr := lib.AtgUnbind(ts.ctensor, dim)
+	if err = TorchErr(); err != nil {
+		return retVal, err
+	}
+
+	currentPtr := ctensorsPtr
+	retVal = append(retVal, Tensor{ctensor: *currentPtr})
+	for {
+		nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
+		if *nextPtr == nil {
+			break
+		}
+
+		retVal = append(retVal, Tensor{ctensor: *nextPtr})
+		currentPtr = nextPtr
+	}
+
+	return retVal, nil
+}
+
+func (ts Tensor) MustUnbind(dim int64, del bool) (retVal []Tensor) {
+	if del {
+		defer ts.MustDrop()
+	}
+
+	retVal, err := ts.Unbind(dim)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}
+
+// tensor *atg_where(tensor condition);
+func Where(condition Tensor) (retVal []Tensor, err error) {
+
+	ctensorsPtr := lib.AtgWhere(condition.ctensor)
+	if err = TorchErr(); err != nil {
+		return retVal, err
+	}
+
+	currentPtr := ctensorsPtr
+	retVal = append(retVal, Tensor{ctensor: *currentPtr})
+	for {
+		nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
+		if *nextPtr == nil {
+			break
+		}
+
+		retVal = append(retVal, Tensor{ctensor: *nextPtr})
+		currentPtr = nextPtr
+	}
+
+	return retVal, nil
+}
+
+func MustWhere(condition Tensor, del bool) (retVal []Tensor) {
+	if del {
+		defer condition.MustDrop()
+	}
+
+	retVal, err := Where(condition)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}
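Usage note (outside the patch): Unbind removes the given dimension and returns one tensor per slice along it, while Where with a single condition argument mirrors torch.where(condition), i.e. the tuple form of nonzero with one index tensor per dimension of the condition. A short sketch for Unbind using only APIs from this patch:

	x := ts.MustArange1(ts.FloatScalar(0), ts.FloatScalar(6), gotch.Float, gotch.CPU).MustView([]int64{3, 2}, true)
	// Unbinding a {3,2} tensor along dim 0 yields three tensors of shape {2}.
	for _, row := range x.MustUnbind(0, true) {
		row.Print()
		row.MustDrop()
	}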