feat(API redesign): remove C_tensor and introduce Ctensor which is alias of C.tensor

This commit is contained in:
sugarme 2020-06-04 16:23:53 +10:00
parent 2ef7f06e4a
commit df1c0b34ff
4 changed files with 40 additions and 79 deletions

View File

@ -12,8 +12,8 @@ func main() {
// Try to compare 2 tensor with incompatible dimensions
// and check this returns an error
dx := []int32{1, 2, 3}
dy := []int32{1, 2, 3, 4}
// dy := []int32{1, 2, 5}
// dy := []int32{1, 2, 3, 4}
dy := []int32{1, 2, 5}
xs, err := wrapper.OfSlice(dx)
if err != nil {

View File

@ -8,21 +8,20 @@ import (
"unsafe"
)
type C_tensor struct {
private unsafe.Pointer
// NOTE: C.tensor is a C pointer to torch::Tensor
type Ctensor = C.tensor
func AtNewTensor() Ctensor {
return C.at_new_tensor()
}
func AtNewTensor() *C_tensor {
t := C.at_new_tensor()
return &C_tensor{private: unsafe.Pointer(t)}
// tensor at_new_tensor();
func NewTensor() Ctensor {
return C.at_new_tensor()
}
func NewTensor() unsafe.Pointer {
t := C.at_new_tensor()
return unsafe.Pointer(t)
}
func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_bytes uint, kind int) *C_tensor {
// tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims, size_t element_size_in_bytes, int type);
func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_bytes uint, kind int) Ctensor {
// just get pointer of the first element of shape
c_dims := (*C.int64_t)(unsafe.Pointer(&dims[0]))
@ -30,38 +29,36 @@ func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_byt
c_elt_size_in_bytes := *(*C.size_t)(unsafe.Pointer(&elt_size_in_bytes))
c_kind := *(*C.int)(unsafe.Pointer(&kind))
// t is of type `unsafe.Pointer` in Go and `*void` in C
t := C.at_tensor_of_data(vs, c_dims, c_ndims, c_elt_size_in_bytes, c_kind)
return C.at_tensor_of_data(vs, c_dims, c_ndims, c_elt_size_in_bytes, c_kind)
return &C_tensor{private: unsafe.Pointer(t)}
}
func AtPrint(t *C_tensor) {
c_tensor := (C.tensor)((*t).private)
C.at_print(c_tensor)
// void at_print(tensor);
func AtPrint(t Ctensor) {
C.at_print(t)
}
func AtDataPtr(t *C_tensor) unsafe.Pointer {
c_tensor := (C.tensor)((*t).private)
return C.at_data_ptr(c_tensor)
// void *at_data_ptr(tensor);
func AtDataPtr(t Ctensor) unsafe.Pointer {
return C.at_data_ptr(t)
}
func AtDim(t *C_tensor) uint64 {
c_tensor := (C.tensor)((*t).private)
c_result := C.at_dim(c_tensor)
return *(*uint64)(unsafe.Pointer(&c_result))
// size_t at_dim(tensor);
func AtDim(t Ctensor) uint64 {
result := C.at_dim(t)
return *(*uint64)(unsafe.Pointer(&result))
}
func AtShape(t *C_tensor, ptr unsafe.Pointer) {
cTensor := (C.tensor)((*t).private)
// void at_shape(tensor, int64_t *);
func AtShape(t Ctensor, ptr unsafe.Pointer) {
c_ptr := (*C.long)(ptr)
C.at_shape(cTensor, c_ptr)
C.at_shape(t, c_ptr)
}
func AtScalarType(t *C_tensor) int32 {
c_tensor := (C.tensor)((*t).private)
c_result := C.at_scalar_type(c_tensor)
return *(*int32)(unsafe.Pointer(&c_result))
// int at_scalar_type(tensor);
func AtScalarType(t Ctensor) int32 {
result := C.at_scalar_type(t)
return *(*int32)(unsafe.Pointer(&result))
}
func GetAndResetLastErr() *C.char {

View File

@ -5,23 +5,7 @@ package libtch
//#include "torch_api.h"
import "C"
import (
"unsafe"
)
// void atg_eq1(tensor *, tensor self, tensor other);
func Atg_eq1(ptr unsafe.Pointer, self *C_tensor, other *C_tensor) {
// // func Atg_eq1(ptr unsafe.Pointer, self *C_tensor, other *C_tensor) {
//
// // t := C.malloc(C.size_t(1) * C.size_t(unsafe.Sizeof(uintptr(C.tensor{}))))
// var ctensor C.tensor
// t := C.malloc(C.size_t(3) * C.size_t(unsafe.Sizeof(uintptr(ctensor))))
// // t := C.malloc(1000)
// // t := C.at_new_tensor()
c_self := (C.tensor)((*self).private)
c_other := (C.tensor)((*other).private)
C.atg_eq1((*C.tensor)(ptr), c_self, c_other)
// cptr := (*C.tensor)(ptr)
// C.atg_eq1(cptr, c_self, c_other)
func AtgEq1(ptr *Ctensor, self Ctensor, other Ctensor) {
C.atg_eq1(ptr, self, other)
}

View File

@ -18,7 +18,7 @@ import (
)
type Tensor struct {
ctensor *lib.C_tensor
ctensor lib.Ctensor
}
// NewTensor creates a new tensor
@ -235,37 +235,17 @@ func (ts Tensor) DType() gotch.DType {
func (ts Tensor) Eq1(other Tensor) {
// var ptr unsafe.Pointer
// NOTE:
// This will cause panic: runtime error: cgo argument has Go pointer to Go pointer
// ptr = NewTensor()
// lib.Atg_eq1(unsafe.Pointer(&ptr), ts.ctensor, other.ctensor)
// Get a C null pointer
// https://stackoverflow.com/a/2022369
// ctensorPtr := C.malloc(0)
ctensorPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
// C pointer to [1]uintptr (Go pointer)
// ctensorsPtr := C.malloc(C.size_t(1) * C.size_t(unsafe.Sizeof(uintptr(0))))
// TODO: create C pointer to a slice of tensors [1]C.tensor using C.malloc
// Slice with 1 element type C.tensor
// nbytes := C.size_t(1) * C.size_t(unsafe.Sizeof(C.tensor))
// ctensorsPtr := C.malloc(nbytes)
// ctensorsPtr := C.malloc(C.size_t(1) * C.size_t(unsafe.Sizeof(C.tensor)))
// C null pointer C.tensor * = null
// ctensorPtr := lib.NewTensor()
// nbytes := C.size_t(1) * C.size_t(unsafe.Sizeof(C.tensor))
// Get a pointer in C memory
ctensorPtr := C.malloc(0)
fmt.Printf("ctensorPtr: %v\n", ctensorPtr)
lib.Atg_eq1(unsafe.Pointer(ctensorPtr), ts.ctensor, other.ctensor)
// lib.Atg_eq1(unsafe.Pointer(ctensorPtr), ts.ctensor, other.ctensor)
lib.AtgEq1(ctensorPtr, ts.ctensor, other.ctensor)
if err := TorchErr(); err != nil {
log.Fatal(err)
}
lib.AtPrint((*lib.C_tensor)(unsafe.Pointer(ctensorPtr)))
// fmt.Printf("Out tensor AFTER: %v\n", &ctensorPtr)
lib.AtPrint(*ctensorPtr)
}