From e5735c77dc11c44fe7efeffb23d35e0da5f37796 Mon Sep 17 00:00:00 2001 From: sugarme Date: Thu, 24 Sep 2020 09:52:31 +1000 Subject: [PATCH] WIP(patch): AtgSplit func --- libtch/README.md | 65 ++++++++++++++++++++++++++++++++++- libtch/patch.go | 88 ++++++++++++++++++++++++++++++++++++++++++++++++ tensor/patch.go | 63 ++++++++++++++++++++++++++++++++++ 3 files changed, 215 insertions(+), 1 deletion(-) create mode 100644 libtch/patch.go diff --git a/libtch/README.md b/libtch/README.md index 76bf3d1..b8194ab 100644 --- a/libtch/README.md +++ b/libtch/README.md @@ -60,7 +60,7 @@ func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_byt ## Function Return -### `void *` +### `void *CFUNC(...)` ```c void *at_data_ptr(tensor); @@ -84,6 +84,69 @@ then in the return of function body return &C_tensor{private: unsafe.Pointer(t)} ``` +### `tensor *CFUNC(...)` + +The pattern of **return tensor pointer**: `tensor *atg_FUNCTION_NAME()`. +The returned pointer addresses the FIRST element of a C array of tensor pointers. +Each following element is found by advancing from the first. The array carries no +length: it ends at the first element that is **NULL**. + +```c + +tensor *atg_split(tensor self, int64_t split_size, int64_t dim); + +``` + +```go + +// Wrapper +func AtgSplit(self Ctensor, splitSize int64, dim int64) *Ctensor { + + csplitSize := *(*C.int64_t)(unsafe.Pointer(&splitSize)) + cdim := *(*C.int64_t)(unsafe.Pointer(&dim)) + + return C.atg_split(self, csplitSize, cdim) +} + +// API + +// Split splits tensor into chunks +// +// Parameters: +// - splitSize – size of a single chunk or list of sizes for each chunk +// - dim – dimension along which to split the tensor. +// Ref. 
https://pytorch.org/docs/stable/generated/torch.split.html +func (ts Tensor) Split(splitSize, dim int64) (retVal []Tensor, err error) { + + ctensorsPtr := lib.AtgSplit(ts.ctensor, splitSize, dim) + if err = TorchErr(); err != nil { + return retVal, err + } + + // NOTE: ctensorsPtr is a c-pointer to a vector of tensors. The first + // C tensor is the `ctensorsPtr` value. The next pointer will be + // calculated from there. The vector of tensors will end if the calculated + // pointer value is `null`. + currentPtr := ctensorsPtr + retVal = append(retVal, Tensor{ctensor: *currentPtr}) + for { + // calculate the next pointer value + nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr))) + if *nextPtr == nil { + break + } + + retVal = append(retVal, Tensor{ctensor: *nextPtr}) + currentPtr = nextPtr + } + + return retVal, nil +} + + +``` + + ### C types e.g. `C_ulong` -> Go equivalent types `uint64` then in the return of function body diff --git a/libtch/patch.go b/libtch/patch.go new file mode 100644 index 0000000..83b2f54 --- /dev/null +++ b/libtch/patch.go @@ -0,0 +1,88 @@ +package libtch + +// NOTE. 
This file is a patch of missing auto-generated APIs in `c-generated.go`
+
+//#include "stdbool.h"
+//#include "torch_api.h"
+import "C"
+
+import "unsafe"
+
+// NOTE: 9 patches for pattern of **return tensor pointer**: `tensor *atg_FUNCTION_NAME()`:
+// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
+// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
+// tensor *atg_chunk(tensor self, int64_t chunks, int64_t dim);
+// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
+// tensor *atg_nonzero_numpy(tensor self);
+// tensor *atg_split(tensor self, int64_t split_size, int64_t dim);
+// tensor *atg_split_with_sizes(tensor self, int64_t *split_sizes_data, int split_sizes_len, int64_t dim);
+// tensor *atg_unbind(tensor self, int64_t dim);
+// tensor *atg_where(tensor condition);
+
+// firstCtensor returns the address of the first element of ts for handing to
+// C as a `tensor *`. It returns nil for an empty slice, where `&ts[0]` would
+// panic; callers are expected to pass a matching length of 0 in that case.
+func firstCtensor(ts []Ctensor) *Ctensor {
+	if len(ts) == 0 {
+		return nil
+	}
+	return &ts[0]
+}
+
+// AtgAlignTensors calls the C function below; tensorsLen should equal len(tensorsData).
+// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
+func AtgAlignTensors(tensorsData []Ctensor, tensorsLen int) *Ctensor {
+	return C.atg_align_tensors(firstCtensor(tensorsData), C.int(tensorsLen))
+}
+
+// AtgBroadcastTensors calls the C function below; tensorsLen should equal len(tensorsData).
+// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
+func AtgBroadcastTensors(tensorsData []Ctensor, tensorsLen int) *Ctensor {
+	return C.atg_broadcast_tensors(firstCtensor(tensorsData), C.int(tensorsLen))
+}
+
+// AtgChunk calls the C function below, converting both integers by value.
+// tensor *atg_chunk(tensor self, int64_t chunks, int64_t dim);
+func AtgChunk(self Ctensor, chunks int64, dim int64) *Ctensor {
+	return C.atg_chunk(self, C.int64_t(chunks), C.int64_t(dim))
+}
+
+// AtgMeshgrid calls the C function below; tensorsLen should equal len(tensorsData).
+// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
+func AtgMeshgrid(tensorsData []Ctensor, tensorsLen int) *Ctensor {
+	return C.atg_meshgrid(firstCtensor(tensorsData), C.int(tensorsLen))
+}
+
+// AtgNonzeroNumpy calls the C function below and returns its result unchanged.
+// tensor *atg_nonzero_numpy(tensor self);
+func AtgNonzeroNumpy(self Ctensor) *Ctensor {
+	return C.atg_nonzero_numpy(self)
+}
+
+// AtgSplit calls the C function below, converting both integers by value.
+// tensor *atg_split(tensor self, int64_t split_size, int64_t dim);
+func AtgSplit(self Ctensor, splitSize int64, dim int64) *Ctensor {
+	return C.atg_split(self, C.int64_t(splitSize), C.int64_t(dim))
+}
+
+// AtgSplitWithSizes calls the C function below; splitSizesLen should equal len(splitSizesData).
+// tensor *atg_split_with_sizes(tensor self, int64_t *split_sizes_data, int split_sizes_len, int64_t dim);
+func AtgSplitWithSizes(self Ctensor, splitSizesData []int64, splitSizesLen int, dim int64) *Ctensor {
+	var csplitSizesDataPtr *C.int64_t // stays nil for an empty slice; &splitSizesData[0] would panic
+	if len(splitSizesData) > 0 {
+		csplitSizesDataPtr = (*C.int64_t)(unsafe.Pointer(&splitSizesData[0]))
+	}
+	return C.atg_split_with_sizes(self, csplitSizesDataPtr, C.int(splitSizesLen), C.int64_t(dim))
+}
+
+// AtgUnbind calls the C function below, converting dim by value.
+// tensor *atg_unbind(tensor self, int64_t dim);
+func AtgUnbind(self Ctensor, dim int64) *Ctensor {
+	return C.atg_unbind(self, C.int64_t(dim))
+}
+
+// AtgWhere calls the C function below and returns its result unchanged.
+// tensor *atg_where(tensor condition);
+func AtgWhere(condition Ctensor) *Ctensor {
+	return C.atg_where(condition)
+}
diff --git a/tensor/patch.go b/tensor/patch.go
index 0dcac81..ef2a6c2 100644
--- a/tensor/patch.go
+++ b/tensor/patch.go
@@ -4,6 +4,7 @@ package tensor
 import "C"
 
 import (
+	"fmt"
 	"log"
 	"unsafe"
 
@@ -183,3 +184,65 @@ func (ts Tensor) MustNLLLoss(target Tensor, del bool) (retVal Tensor) {
 
 	return retVal
 }
+
+// NOTE: the following 9 APIs are missing from `tensor-generated.go` with
+// pattern of **return tensor pointer**: `tensor *atg_FUNCTION_NAME()`.
+// The returning tensor pointer actually is the FIRST element of a vector
+// of C tensor pointers. Next pointer will be calculated from the first.
+// In C land, verifying a valid pointer is to check whether it points to **NULL**.
+//
+// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
+// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
+// tensor *atg_chunk(tensor self, int64_t chunks, int64_t dim);
+// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
+// tensor *atg_nonzero_numpy(tensor self);
+// tensor *atg_split(tensor self, int64_t split_size, int64_t dim);
+// tensor *atg_split_with_sizes(tensor self, int64_t *split_sizes_data, int split_sizes_len, int64_t dim);
+// tensor *atg_unbind(tensor self, int64_t dim);
+// tensor *atg_where(tensor condition);
+
+// Split splits the tensor into chunks along dimension dim.
+//
+// Parameters:
+//  - splitSize – size of a single chunk
+//  - dim – dimension along which to split the tensor.
+// Ref. https://pytorch.org/docs/stable/generated/torch.split.html
+//
+// An error is returned when TorchErr reports one or when no array comes back.
+func (ts Tensor) Split(splitSize, dim int64) (retVal []Tensor, err error) {
+	ctensorsPtr := lib.AtgSplit(ts.ctensor, splitSize, dim)
+	if err = TorchErr(); err != nil {
+		return nil, err
+	}
+	if ctensorsPtr == nil {
+		return nil, fmt.Errorf("split: libtorch returned a nil tensor array")
+	}
+
+	// ctensorsPtr addresses the first slot of a C array of tensor handles whose
+	// end is marked by a NULL slot. Test every slot before using it (the first
+	// one included) and step by the size of one slot to reach the next.
+	// NOTE(review): the C array itself is not released here - confirm who owns
+	// it and free it once the handles have been copied out.
+	slot := ctensorsPtr
+	for *slot != nil {
+		retVal = append(retVal, Tensor{ctensor: *slot})
+		slot = (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(slot)) + unsafe.Sizeof(*slot)))
+	}
+
+	return retVal, nil
+}
+
+// MustSplit is Split that ends the program through log.Fatal when Split fails.
+// When del is true, ts.MustDrop is deferred before Split is called.
+func (ts Tensor) MustSplit(splitSize, dim int64, del bool) (retVal []Tensor) {
+	if del {
+		defer ts.MustDrop()
+	}
+
+	retVal, err := ts.Split(splitSize, dim)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	return retVal
+}