feat(API redesign): remove C_tensor and introduce Ctensor which is alias of C.tensor
This commit is contained in:
parent
2ef7f06e4a
commit
df1c0b34ff
|
@ -12,8 +12,8 @@ func main() {
|
|||
// Try to compare 2 tensor with incompatible dimensions
|
||||
// and check this returns an error
|
||||
dx := []int32{1, 2, 3}
|
||||
dy := []int32{1, 2, 3, 4}
|
||||
// dy := []int32{1, 2, 5}
|
||||
// dy := []int32{1, 2, 3, 4}
|
||||
dy := []int32{1, 2, 5}
|
||||
|
||||
xs, err := wrapper.OfSlice(dx)
|
||||
if err != nil {
|
||||
|
|
|
@ -8,21 +8,20 @@ import (
|
|||
"unsafe"
|
||||
)
|
||||
|
||||
type C_tensor struct {
|
||||
private unsafe.Pointer
|
||||
// NOTE: C.tensor is a C pointer to torch::Tensor
|
||||
type Ctensor = C.tensor
|
||||
|
||||
func AtNewTensor() Ctensor {
|
||||
return C.at_new_tensor()
|
||||
}
|
||||
|
||||
func AtNewTensor() *C_tensor {
|
||||
t := C.at_new_tensor()
|
||||
return &C_tensor{private: unsafe.Pointer(t)}
|
||||
// tensor at_new_tensor();
|
||||
func NewTensor() Ctensor {
|
||||
return C.at_new_tensor()
|
||||
}
|
||||
|
||||
func NewTensor() unsafe.Pointer {
|
||||
t := C.at_new_tensor()
|
||||
return unsafe.Pointer(t)
|
||||
}
|
||||
|
||||
func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_bytes uint, kind int) *C_tensor {
|
||||
// tensor at_tensor_of_data(void *vs, int64_t *dims, size_t ndims, size_t element_size_in_bytes, int type);
|
||||
func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_bytes uint, kind int) Ctensor {
|
||||
|
||||
// just get pointer of the first element of shape
|
||||
c_dims := (*C.int64_t)(unsafe.Pointer(&dims[0]))
|
||||
|
@ -30,38 +29,36 @@ func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_byt
|
|||
c_elt_size_in_bytes := *(*C.size_t)(unsafe.Pointer(&elt_size_in_bytes))
|
||||
c_kind := *(*C.int)(unsafe.Pointer(&kind))
|
||||
|
||||
// t is of type `unsafe.Pointer` in Go and `*void` in C
|
||||
t := C.at_tensor_of_data(vs, c_dims, c_ndims, c_elt_size_in_bytes, c_kind)
|
||||
return C.at_tensor_of_data(vs, c_dims, c_ndims, c_elt_size_in_bytes, c_kind)
|
||||
|
||||
return &C_tensor{private: unsafe.Pointer(t)}
|
||||
}
|
||||
|
||||
func AtPrint(t *C_tensor) {
|
||||
c_tensor := (C.tensor)((*t).private)
|
||||
C.at_print(c_tensor)
|
||||
// void at_print(tensor);
|
||||
func AtPrint(t Ctensor) {
|
||||
C.at_print(t)
|
||||
}
|
||||
|
||||
func AtDataPtr(t *C_tensor) unsafe.Pointer {
|
||||
c_tensor := (C.tensor)((*t).private)
|
||||
return C.at_data_ptr(c_tensor)
|
||||
// void *at_data_ptr(tensor);
|
||||
func AtDataPtr(t Ctensor) unsafe.Pointer {
|
||||
return C.at_data_ptr(t)
|
||||
}
|
||||
|
||||
func AtDim(t *C_tensor) uint64 {
|
||||
c_tensor := (C.tensor)((*t).private)
|
||||
c_result := C.at_dim(c_tensor)
|
||||
return *(*uint64)(unsafe.Pointer(&c_result))
|
||||
// size_t at_dim(tensor);
|
||||
func AtDim(t Ctensor) uint64 {
|
||||
result := C.at_dim(t)
|
||||
return *(*uint64)(unsafe.Pointer(&result))
|
||||
}
|
||||
|
||||
func AtShape(t *C_tensor, ptr unsafe.Pointer) {
|
||||
cTensor := (C.tensor)((*t).private)
|
||||
// void at_shape(tensor, int64_t *);
|
||||
func AtShape(t Ctensor, ptr unsafe.Pointer) {
|
||||
c_ptr := (*C.long)(ptr)
|
||||
C.at_shape(cTensor, c_ptr)
|
||||
C.at_shape(t, c_ptr)
|
||||
}
|
||||
|
||||
func AtScalarType(t *C_tensor) int32 {
|
||||
c_tensor := (C.tensor)((*t).private)
|
||||
c_result := C.at_scalar_type(c_tensor)
|
||||
return *(*int32)(unsafe.Pointer(&c_result))
|
||||
// int at_scalar_type(tensor);
|
||||
func AtScalarType(t Ctensor) int32 {
|
||||
result := C.at_scalar_type(t)
|
||||
return *(*int32)(unsafe.Pointer(&result))
|
||||
}
|
||||
|
||||
func GetAndResetLastErr() *C.char {
|
||||
|
|
|
@ -5,23 +5,7 @@ package libtch
|
|||
//#include "torch_api.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// void atg_eq1(tensor *, tensor self, tensor other);
|
||||
func Atg_eq1(ptr unsafe.Pointer, self *C_tensor, other *C_tensor) {
|
||||
// // func Atg_eq1(ptr unsafe.Pointer, self *C_tensor, other *C_tensor) {
|
||||
//
|
||||
// // t := C.malloc(C.size_t(1) * C.size_t(unsafe.Sizeof(uintptr(C.tensor{}))))
|
||||
// var ctensor C.tensor
|
||||
// t := C.malloc(C.size_t(3) * C.size_t(unsafe.Sizeof(uintptr(ctensor))))
|
||||
// // t := C.malloc(1000)
|
||||
// // t := C.at_new_tensor()
|
||||
c_self := (C.tensor)((*self).private)
|
||||
c_other := (C.tensor)((*other).private)
|
||||
|
||||
C.atg_eq1((*C.tensor)(ptr), c_self, c_other)
|
||||
// cptr := (*C.tensor)(ptr)
|
||||
// C.atg_eq1(cptr, c_self, c_other)
|
||||
func AtgEq1(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||
C.atg_eq1(ptr, self, other)
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ import (
|
|||
)
|
||||
|
||||
type Tensor struct {
|
||||
ctensor *lib.C_tensor
|
||||
ctensor lib.Ctensor
|
||||
}
|
||||
|
||||
// NewTensor creates a new tensor
|
||||
|
@ -235,37 +235,17 @@ func (ts Tensor) DType() gotch.DType {
|
|||
|
||||
func (ts Tensor) Eq1(other Tensor) {
|
||||
|
||||
// var ptr unsafe.Pointer
|
||||
// NOTE:
|
||||
// This will cause panic: runtime error: cgo argument has Go pointer to Go pointer
|
||||
// ptr = NewTensor()
|
||||
// lib.Atg_eq1(unsafe.Pointer(&ptr), ts.ctensor, other.ctensor)
|
||||
// Get a C null pointer
|
||||
// https://stackoverflow.com/a/2022369
|
||||
// ctensorPtr := C.malloc(0)
|
||||
ctensorPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
// C pointer to [1]uintptr (Go pointer)
|
||||
// ctensorsPtr := C.malloc(C.size_t(1) * C.size_t(unsafe.Sizeof(uintptr(0))))
|
||||
|
||||
// TODO: create C pointer to a slice of tensors [1]C.tensor using C.malloc
|
||||
// Slice with 1 element type C.tensor
|
||||
// nbytes := C.size_t(1) * C.size_t(unsafe.Sizeof(C.tensor))
|
||||
// ctensorsPtr := C.malloc(nbytes)
|
||||
// ctensorsPtr := C.malloc(C.size_t(1) * C.size_t(unsafe.Sizeof(C.tensor)))
|
||||
|
||||
// C null pointer C.tensor * = null
|
||||
// ctensorPtr := lib.NewTensor()
|
||||
// nbytes := C.size_t(1) * C.size_t(unsafe.Sizeof(C.tensor))
|
||||
|
||||
// Get a pointer in C memory
|
||||
ctensorPtr := C.malloc(0)
|
||||
fmt.Printf("ctensorPtr: %v\n", ctensorPtr)
|
||||
|
||||
lib.Atg_eq1(unsafe.Pointer(ctensorPtr), ts.ctensor, other.ctensor)
|
||||
// lib.Atg_eq1(unsafe.Pointer(ctensorPtr), ts.ctensor, other.ctensor)
|
||||
lib.AtgEq1(ctensorPtr, ts.ctensor, other.ctensor)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
lib.AtPrint((*lib.C_tensor)(unsafe.Pointer(ctensorPtr)))
|
||||
|
||||
// fmt.Printf("Out tensor AFTER: %v\n", &ctensorPtr)
|
||||
|
||||
lib.AtPrint(*ctensorPtr)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user