feat(libtorch/tensor): added AtDim(), AtSize(); feat(libtorch/README): notes on generating FFI; cleanup
This commit is contained in:
parent
3b219ec1e0
commit
45bb5a5907
|
@ -14,11 +14,11 @@ func main() {
|
|||
// For. if data is []int and DType is Bool
|
||||
// It is still running but get wrong result.
|
||||
data := [][]int64{
|
||||
{1, 1, 1, 2, 2, 2, 1},
|
||||
{1, 1, 1, 2, 2, 2, 1},
|
||||
{1, 1, 1, 2, 2, 2, 3, 3},
|
||||
{1, 1, 1, 2, 2, 2, 4, 4},
|
||||
}
|
||||
// shape := []int64{2, 7}
|
||||
shape := []int64{2, 7}
|
||||
shape := []int64{2, 8}
|
||||
// shape := []int64{2, 2, 4}
|
||||
|
||||
// dtype := gotch.Int
|
||||
// ts := wrapper.NewTensor()
|
||||
|
@ -34,9 +34,16 @@ func main() {
|
|||
|
||||
ts.Print()
|
||||
|
||||
fmt.Println(ts.Dim())
|
||||
// fmt.Printf("Dim: %v\n", ts.Dim())
|
||||
|
||||
ts.Size()
|
||||
// ts.Size()
|
||||
// fmt.Println(ts.Size())
|
||||
|
||||
sz, err := ts.Size2()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("Shape: %v\n", sz)
|
||||
|
||||
// typ, count, err := wrapper.DataCheck(data)
|
||||
// if err != nil {
|
||||
|
|
99
libtch/README.md
Normal file
99
libtch/README.md
Normal file
|
@ -0,0 +1,99 @@
|
|||
# NOTES ON WRITING WRAPPER FUNCTIONS
|
||||
|
||||
|
||||
## Function Input Arguments
|
||||
|
||||
### `tensor` -> `t *C_tensor`
|
||||
|
||||
```c
|
||||
void at_print(tensor);
|
||||
```
|
||||
|
||||
```go
|
||||
func AtPrint(t *C_tensor) {
|
||||
c_tensor := (C.tensor)((*t).private)
|
||||
C.at_print(c_tensor)
|
||||
}
|
||||
```
|
||||
|
||||
### C pointer e.g `int64_t *` -> `ptr unsafe.Pointer`
|
||||
|
||||
In function body, `cPtr := (*C.long)(ptr)`
|
||||
|
||||
```c
|
||||
void at_shape(tensor, int64_t *);
|
||||
```
|
||||
|
||||
```go
|
||||
func AtShape(t *C_tensor, ptr unsafe.Pointer) {
|
||||
c_tensor := (C.tensor)((*t).private)
|
||||
c_ptr := (*C.long)(ptr)
|
||||
C.at_shape(c_tensor, c_ptr)
|
||||
}
|
||||
```
|
||||
|
||||
### C types e.g `size_t ndims` -> equivalent Go types `ndims uint`
|
||||
|
||||
In function body, `c_ndims := *(*C.size_t)(unsafe.Pointer(&ndims))`
|
||||
|
||||
```c
|
||||
tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims, size_t element_size_in_bytes, int type);
|
||||
```
|
||||
|
||||
```go
|
||||
func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_bytes uint, kind int) *C_tensor {
|
||||
|
||||
// 1. Unsafe pointer
|
||||
c_dims := (*C.int64_t)(unsafe.Pointer(&dims[0]))
|
||||
c_ndims := *(*C.size_t)(unsafe.Pointer(&ndims))
|
||||
c_elt_size_in_bytes := *(*C.size_t)(unsafe.Pointer(&elt_size_in_bytes))
|
||||
c_kind := *(*C.int)(unsafe.Pointer(&kind))
|
||||
|
||||
// 2. Call C function
|
||||
t := C.at_tensor_of_data(vs, c_dims, c_ndims, c_elt_size_in_bytes, c_kind)
|
||||
|
||||
// 3. Form return value
|
||||
return &C_tensor{private: unsafe.Pointer(t)}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## Function Return
|
||||
|
||||
### `void *`
|
||||
|
||||
```c
|
||||
void *at_data_ptr(tensor);
|
||||
```
|
||||
|
||||
```go
|
||||
func AtDataPtr(t *C_tensor) unsafe.Pointer {
|
||||
c_tensor := (C.tensor)((*t).private)
|
||||
return C.at_data_ptr(c_tensor)
|
||||
}
|
||||
```
|
||||
|
||||
### `tensor` -> `*C_tensor`
|
||||
|
||||
then in the return of function body
|
||||
|
||||
```go
|
||||
// Call C function
|
||||
t := C.FUNCTION_TO_CALL(...)
|
||||
// Return
|
||||
return &C_tensor{private: unsafe.Pointer(t)}
|
||||
```
|
||||
|
||||
### C types e.g. `C_ulong` -> Go equivalent types `uint64`
|
||||
|
||||
then in the return of function body
|
||||
|
||||
```go
|
||||
|
||||
c_result := C.FUNCTION_CALL(...)
|
||||
return *(*uint64)(unsafe.Pointer(&c_result))
|
||||
|
||||
```
|
||||
|
||||
|
||||
|
|
@ -5,15 +5,9 @@ package libtch
|
|||
import "C"
|
||||
|
||||
import (
|
||||
// "fmt"
|
||||
// "reflect"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// type c_void unsafe.Pointer
|
||||
// type size_t uint
|
||||
// type c_int int32
|
||||
|
||||
type C_tensor struct {
|
||||
private unsafe.Pointer
|
||||
}
|
||||
|
@ -33,45 +27,28 @@ func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_byt
|
|||
|
||||
// t is of type `unsafe.Pointer` in Go and `*void` in C
|
||||
t := C.at_tensor_of_data(vs, c_dims, c_ndims, c_elt_size_in_bytes, c_kind)
|
||||
// fmt.Printf("t type: %v\n", reflect.TypeOf(t).Kind())
|
||||
// fmt.Printf("1. C.tensor AtTensorOfData returned from C call: %v\n", t)
|
||||
// Keep C pointer value tin Go struct
|
||||
cTensorPtrVal := unsafe.Pointer(t)
|
||||
// fmt.Printf("2. cTensorPtrVal: %v\n", cTensorPtrVal)
|
||||
|
||||
var retVal *C_tensor
|
||||
retVal = &C_tensor{private: cTensorPtrVal}
|
||||
// fmt.Printf("3. C_tensor.private: %v\n", (*retVal).private)
|
||||
|
||||
// test call C.at_print to print out tensor
|
||||
// C.at_print(*(*C.tensor)(unsafe.Pointer(&t)))
|
||||
// AtPrint(retVal)
|
||||
|
||||
return retVal
|
||||
return &C_tensor{private: unsafe.Pointer(t)}
|
||||
}
|
||||
|
||||
func AtPrint(t *C_tensor) {
|
||||
// fmt.Printf("4. C_tensor.private AtPrint: %v\n", (*t).private)
|
||||
cTensor := (C.tensor)((*t).private)
|
||||
// fmt.Printf("5. C.tensor AtPrint: %v\n", cTensor)
|
||||
|
||||
C.at_print(cTensor)
|
||||
c_tensor := (C.tensor)((*t).private)
|
||||
C.at_print(c_tensor)
|
||||
}
|
||||
|
||||
func AtDataPtr(t *C_tensor) unsafe.Pointer {
|
||||
cTensor := (C.tensor)((*t).private)
|
||||
return C.at_data_ptr(cTensor)
|
||||
c_tensor := (C.tensor)((*t).private)
|
||||
return C.at_data_ptr(c_tensor)
|
||||
}
|
||||
|
||||
func AtDim(t *C_tensor) uint64 {
|
||||
cTensor := (C.tensor)((*t).private)
|
||||
cdim := C.at_dim(cTensor)
|
||||
return *(*uint64)(unsafe.Pointer(&cdim))
|
||||
c_tensor := (C.tensor)((*t).private)
|
||||
c_result := C.at_dim(c_tensor)
|
||||
return *(*uint64)(unsafe.Pointer(&c_result))
|
||||
}
|
||||
|
||||
func AtShape(t *C_tensor, sz []int64) {
|
||||
func AtShape(t *C_tensor, ptr unsafe.Pointer) {
|
||||
cTensor := (C.tensor)((*t).private)
|
||||
// just get pointer of the first element
|
||||
csz := (*C.int64_t)(unsafe.Pointer(&sz[0]))
|
||||
C.at_shape(cTensor, csz)
|
||||
c_ptr := (*C.long)(ptr)
|
||||
C.at_shape(cTensor, c_ptr)
|
||||
}
|
||||
|
|
|
@ -4,8 +4,12 @@ package wrapper
|
|||
import "C"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
"unsafe"
|
||||
|
||||
gotch "github.com/sugarme/gotch"
|
||||
lib "github.com/sugarme/gotch/libtch"
|
||||
|
@ -25,14 +29,101 @@ func (ts Tensor) Dim() uint64 {
|
|||
return lib.AtDim(ts.ctensor)
|
||||
}
|
||||
|
||||
func (ts Tensor) Size() {
|
||||
// Size return shape of the tensor
|
||||
//
|
||||
// NOTE: C++ libtorch calls at_shape() -> t.sizes()
|
||||
// And returns a slice of sizes or shape using given pointer
|
||||
// to that slice.
|
||||
func (ts Tensor) Size() []int64 {
|
||||
dim := lib.AtDim(ts.ctensor)
|
||||
sz := []int64{int64(dim)}
|
||||
lib.AtShape(ts.ctensor, sz)
|
||||
fmt.Printf("sz val:%v", sz)
|
||||
// return lib.AtShape(ts.ctensor, sz)
|
||||
sz := make([]int64, dim)
|
||||
szPtr, err := DataAsPtr(sz)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// TODO: should we free C memory here or at `DataAsPtr` func
|
||||
defer C.free(unsafe.Pointer(szPtr))
|
||||
|
||||
lib.AtShape(ts.ctensor, szPtr)
|
||||
|
||||
retVal := decodeSize(szPtr, dim)
|
||||
return retVal
|
||||
}
|
||||
|
||||
// Size1 returns the tensor size for 1D tensors.
|
||||
func (ts Tensor) Size1() (retVal int64, err error) {
|
||||
shape := ts.Size()
|
||||
if len(shape) != 1 {
|
||||
err = fmt.Errorf("Expected one dim, got %v\n", len(shape))
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return shape[0], nil
|
||||
}
|
||||
|
||||
// Size2 returns the tensor size for 2D tensors.
|
||||
func (ts Tensor) Size2() (retVal []int64, err error) {
|
||||
shape := ts.Size()
|
||||
if len(shape) != 2 {
|
||||
err = fmt.Errorf("Expected two dims, got %v\n", len(shape))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return shape, nil
|
||||
}
|
||||
|
||||
// Size3 returns the tensor size for 3D tensors.
|
||||
func (ts Tensor) Size3() (retVal []int64, err error) {
|
||||
shape := ts.Size()
|
||||
if len(shape) != 3 {
|
||||
err = fmt.Errorf("Expected three dims, got %v\n", len(shape))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return shape, nil
|
||||
}
|
||||
|
||||
// Size4 returns the tensor size for 4D tensors.
|
||||
func (ts Tensor) Size4() (retVal []int64, err error) {
|
||||
shape := ts.Size()
|
||||
if len(shape) != 4 {
|
||||
err = fmt.Errorf("Expected four dims, got %v\n", len(shape))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return shape, nil
|
||||
}
|
||||
|
||||
func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
|
||||
// Decode sz
|
||||
// 1. Count number of elements in data
|
||||
elementNum := nsize
|
||||
// 2. Element size in bytes
|
||||
eltSizeInBytes, err := gotch.DTypeSize(gotch.Int64)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
nbytes := int(eltSizeInBytes) * int(elementNum)
|
||||
dataSlice := (*[1 << 30]byte)(ptr)[:nbytes:nbytes]
|
||||
r := bytes.NewReader(dataSlice)
|
||||
dataIn := make([]int64, nsize)
|
||||
if err := binary.Read(r, nativeEndian, dataIn); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return dataIn
|
||||
}
|
||||
|
||||
// Size1 returns the tensor size for single dimension tensor
|
||||
// func (ts Tensor) Size1() {
|
||||
//
|
||||
// shape := ts.Size()
|
||||
//
|
||||
// fmt.Printf("shape: %v\n", shape)
|
||||
//
|
||||
// }
|
||||
|
||||
// FOfSlice creates tensor from a slice data
|
||||
func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor, err error) {
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user