WIP(patch): AtgSplit func
This commit is contained in:
parent
5edded6ca1
commit
e5735c77dc
|
@ -60,7 +60,7 @@ func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_byt
|
||||||
|
|
||||||
## Function Return
|
## Function Return
|
||||||
|
|
||||||
### `void *`
|
### `void *CFUNC(...)`
|
||||||
|
|
||||||
```c
|
```c
|
||||||
void *at_data_ptr(tensor);
|
void *at_data_ptr(tensor);
|
||||||
|
@ -84,6 +84,69 @@ then in the return of function body
|
||||||
return &C_tensor{private: unsafe.Pointer(t)}
|
return &C_tensor{private: unsafe.Pointer(t)}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### `tensor *CFUNC(...)`
|
||||||
|
|
||||||
|
The pattern of **return tensor pointer**: `tensor *atg_FUNCTION_NAME()`.
|
||||||
|
The returning tensor pointer actually is the FIRST element of a vector of C tensor pointers.
|
||||||
|
Next pointer will be calculated from the first. In C land, verifying a valid pointer is
|
||||||
|
to check whether it points to **NULL**.
|
||||||
|
|
||||||
|
```c
|
||||||
|
|
||||||
|
tensor *atg_split(tensor self, int64_t split_size, int64_t dim);
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
```go
|
||||||
|
|
||||||
|
// Wrapper
|
||||||
|
func AtgSplit(self Ctensor, splitSize int64, dim int64) *Ctensor {
|
||||||
|
|
||||||
|
csplitSize := *(*C.int64_t)(unsafe.Pointer(&splitSize))
|
||||||
|
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||||
|
|
||||||
|
return C.atg_split(self, csplitSize, cdim)
|
||||||
|
}
|
||||||
|
|
||||||
|
// API
|
||||||
|
|
||||||
|
// Split splits tensor into chunks
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - splitSize – size of a single chunk or list of sizes for each chunk
|
||||||
|
// - dim – dimension along which to split the tensor.
|
||||||
|
// Ref. https://pytorch.org/docs/stable/generated/torch.split.html
|
||||||
|
func (ts Tensor) Split(splitSize, dim int64) (retVal []Tensor, err error) {
|
||||||
|
|
||||||
|
ctensorsPtr := lib.AtgSplit(ts.ctensor, splitSize, dim)
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: ctensorsPtr is a c-pointer to a vector of tensors. The first
|
||||||
|
// C tensor is the `ctensorsPtr` value. The next pointer will be
|
||||||
|
// calculated from there. The vector of tensors will end if the calculated
|
||||||
|
// pointer value is `null`.
|
||||||
|
currentPtr := ctensorsPtr
|
||||||
|
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||||
|
for {
|
||||||
|
// calculate the next pointer value
|
||||||
|
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||||
|
if *nextPtr == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||||
|
currentPtr = nextPtr
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
### C types e.g. `C_ulong` -> Go equivalent types `uint64`
|
### C types e.g. `C_ulong` -> Go equivalent types `uint64`
|
||||||
|
|
||||||
then in the return of function body
|
then in the return of function body
|
||||||
|
|
88
libtch/patch.go
Normal file
88
libtch/patch.go
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
package libtch
|
||||||
|
|
||||||
|
// NOTE. This file is a patch of missing auto-generated APIs in `c-generated.go`
|
||||||
|
|
||||||
|
//#include "stdbool.h"
|
||||||
|
//#include "torch_api.h"
|
||||||
|
import "C"
|
||||||
|
|
||||||
|
import "unsafe"
|
||||||
|
|
||||||
|
// NOTE: 9 patches for pattern of **return tensor pointer**: `tensor *atg_FUNCTION_NAME()`:
|
||||||
|
// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
|
||||||
|
// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
|
||||||
|
// tensor *atg_chunk(tensor self, int64_t chunks, int64_t dim);
|
||||||
|
// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
|
||||||
|
// tensor *atg_nonzero_numpy(tensor self);
|
||||||
|
// tensor *atg_split(tensor self, int64_t split_size, int64_t dim);
|
||||||
|
// tensor *atg_split_with_sizes(tensor self, int64_t *split_sizes_data, int split_sizes_len, int64_t dim);
|
||||||
|
// tensor *atg_unbind(tensor self, int64_t dim);
|
||||||
|
// tensor *atg_where(tensor condition);
|
||||||
|
|
||||||
|
// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
|
||||||
|
func AtgAlignTensors(tensorsData []Ctensor, tensorsLen int) *Ctensor {
|
||||||
|
|
||||||
|
ctensorsDataPtr := (*Ctensor)(unsafe.Pointer(&tensorsData[0]))
|
||||||
|
ctensorsLen := *(*C.int)(unsafe.Pointer(&tensorsLen))
|
||||||
|
return C.atg_align_tensors(ctensorsDataPtr, ctensorsLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
|
||||||
|
func AtgBroadcastTensors(tensorsData []Ctensor, tensorsLen int) *Ctensor {
|
||||||
|
|
||||||
|
ctensorsDataPtr := (*Ctensor)(unsafe.Pointer(&tensorsData[0]))
|
||||||
|
ctensorsLen := *(*C.int)(unsafe.Pointer(&tensorsLen))
|
||||||
|
return C.atg_broadcast_tensors(ctensorsDataPtr, ctensorsLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tensor *atg_chunk(tensor self, int64_t chunks, int64_t dim);
|
||||||
|
func AtgChunk(self Ctensor, chunks int64, dim int64) *Ctensor {
|
||||||
|
|
||||||
|
cchunks := *(*C.int64_t)(unsafe.Pointer(&chunks))
|
||||||
|
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||||
|
return C.atg_chunk(self, cchunks, cdim)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
|
||||||
|
func AtgMeshgrid(tensorsData []Ctensor, tensorsLen int) *Ctensor {
|
||||||
|
|
||||||
|
ctensorsDataPtr := (*Ctensor)(unsafe.Pointer(&tensorsData[0]))
|
||||||
|
ctensorsLen := *(*C.int)(unsafe.Pointer(&tensorsLen))
|
||||||
|
return C.atg_meshgrid(ctensorsDataPtr, ctensorsLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tensor *atg_nonzero_numpy(tensor self);
|
||||||
|
func AtgNonzeroNumpy(self Ctensor) *Ctensor {
|
||||||
|
return C.atg_nonzero_numpy(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tensor *atg_split(tensor self, int64_t split_size, int64_t dim);
|
||||||
|
func AtgSplit(self Ctensor, splitSize int64, dim int64) *Ctensor {
|
||||||
|
|
||||||
|
csplitSize := *(*C.int64_t)(unsafe.Pointer(&splitSize))
|
||||||
|
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||||
|
|
||||||
|
return C.atg_split(self, csplitSize, cdim)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tensor *atg_split_with_sizes(tensor self, int64_t *split_sizes_data, int split_sizes_len, int64_t dim);
|
||||||
|
func AtgSplitWithSizes(self Ctensor, splitSizesData []int64, splitSizesLen int, dim int64) *Ctensor {
|
||||||
|
|
||||||
|
csplitSizesDataPtr := (*C.int64_t)(unsafe.Pointer(&splitSizesData[0]))
|
||||||
|
csplitSizesLen := *(*C.int)(unsafe.Pointer(&splitSizesLen))
|
||||||
|
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||||
|
|
||||||
|
return C.atg_split_with_sizes(self, csplitSizesDataPtr, csplitSizesLen, cdim)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tensor *atg_unbind(tensor self, int64_t dim);
|
||||||
|
func AtgUnbind(self Ctensor, dim int64) *Ctensor {
|
||||||
|
|
||||||
|
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||||
|
return C.atg_unbind(self, cdim)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tensor *atg_where(tensor condition);
|
||||||
|
func AtgWhere(condition Ctensor) *Ctensor {
|
||||||
|
return C.atg_where(condition)
|
||||||
|
}
|
|
@ -4,6 +4,7 @@ package tensor
|
||||||
import "C"
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
@ -183,3 +184,65 @@ func (ts Tensor) MustNLLLoss(target Tensor, del bool) (retVal Tensor) {
|
||||||
|
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE: the following 9 APIs are missing from `tensor-generated.go` with
|
||||||
|
// pattern of **return tensor pointer**: `tensor *atg_FUNCTION_NAME()`.
|
||||||
|
// The returning tensor pointer actually is the FIRST element of a vector
|
||||||
|
// of C tensor pointers. Next pointer will be calculated from the first.
|
||||||
|
// In C land, verifying a valid pointer is to check whether it points to **NULL**.
|
||||||
|
//
|
||||||
|
// tensor *atg_align_tensors(tensor *tensors_data, int tensors_len);
|
||||||
|
// tensor *atg_broadcast_tensors(tensor *tensors_data, int tensors_len);
|
||||||
|
// tensor *atg_chunk(tensor self, int64_t chunks, int64_t dim);
|
||||||
|
// tensor *atg_meshgrid(tensor *tensors_data, int tensors_len);
|
||||||
|
// tensor *atg_nonzero_numpy(tensor self);
|
||||||
|
// tensor *atg_split(tensor self, int64_t split_size, int64_t dim);
|
||||||
|
// tensor *atg_split_with_sizes(tensor self, int64_t *split_sizes_data, int split_sizes_len, int64_t dim);
|
||||||
|
// tensor *atg_unbind(tensor self, int64_t dim);
|
||||||
|
// tensor *atg_where(tensor condition);
|
||||||
|
|
||||||
|
// Split splits tensor into chunks
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - splitSize – size of a single chunk or list of sizes for each chunk
|
||||||
|
// - dim – dimension along which to split the tensor.
|
||||||
|
// Ref. https://pytorch.org/docs/stable/generated/torch.split.html
|
||||||
|
func (ts Tensor) Split(splitSize, dim int64) (retVal []Tensor, err error) {
|
||||||
|
|
||||||
|
ctensorsPtr := lib.AtgSplit(ts.ctensor, splitSize, dim)
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOTE: ctensorsPtr is a c-pointer to a vector of tensors. The first
|
||||||
|
// C tensor is the `ctensorsPtr` value. The next pointer will be
|
||||||
|
// calculated from there. The vector of tensors will end if the calculated
|
||||||
|
// pointer value is `null`.
|
||||||
|
currentPtr := ctensorsPtr
|
||||||
|
retVal = append(retVal, Tensor{ctensor: *currentPtr})
|
||||||
|
for {
|
||||||
|
// calculate the next pointer value
|
||||||
|
nextPtr := (*lib.Ctensor)(unsafe.Pointer(uintptr(unsafe.Pointer(currentPtr)) + unsafe.Sizeof(currentPtr)))
|
||||||
|
if *nextPtr == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = append(retVal, Tensor{ctensor: *nextPtr})
|
||||||
|
currentPtr = nextPtr
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustSplit(splitSize, dim int64, del bool) (retVal []Tensor) {
|
||||||
|
if del {
|
||||||
|
defer ts.MustDrop()
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal, err := ts.Split(splitSize, dim)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user