WIP(wrapper/error): handle error message from C; WIP(libtch/tensor_generated.go): at_eq1

This commit is contained in:
sugarme 2020-06-04 13:36:20 +10:00
parent f6c22b4df9
commit ae5f26d567
6 changed files with 110 additions and 11 deletions

View File

@ -12,7 +12,8 @@ func main() {
// Try to compare 2 tensor with incompatible dimensions
// and check this returns an error
dx := []int32{1, 2, 3}
dy := []int32{1, 2, 3, 4}
// dy := []int32{1, 2, 3, 4}
dy := []int32{1, 2, 5}
xs, err := wrapper.OfSlice(dx)
if err != nil {
@ -26,7 +27,11 @@ func main() {
xs.Print()
ys.Print()
fmt.Printf("xs dim: %v\n", xs.Dim())
fmt.Printf("ys dim: %v\n", ys.Dim())
fmt.Printf("xs num of dimensions: %v\n", xs.Dim())
fmt.Printf("ys num of dimensions: %v\n", ys.Dim())
fmt.Printf("xs shape: %v\n", xs.Size())
fmt.Printf("ys shape: %v\n", ys.Size())
xs.Eq1(*ys)
}

View File

@ -1,3 +0,0 @@
// NOTE: this file would be automatically generated by executing `gen` OCaml
// folder.
package libtch

View File

@ -17,6 +17,11 @@ func AtNewTensor() *C_tensor {
return &C_tensor{private: unsafe.Pointer(t)}
}
func NewTensor() unsafe.Pointer {
t := C.at_new_tensor()
return unsafe.Pointer(t)
}
func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_bytes uint, kind int) *C_tensor {
// just get pointer of the first element of shape

View File

@ -0,0 +1,27 @@
// NOTE: this is a sample for OCaml generated code for `tensor_generated.go`
package libtch
//#include "stdbool.h"
//#include "torch_api.h"
import "C"
import (
"unsafe"
)
// void atg_eq1(tensor *, tensor self, tensor other);
func Atg_eq1(ptr unsafe.Pointer, self *C_tensor, other *C_tensor) {
// // func Atg_eq1(ptr unsafe.Pointer, self *C_tensor, other *C_tensor) {
//
// // t := C.malloc(C.size_t(1) * C.size_t(unsafe.Sizeof(uintptr(C.tensor{}))))
// var ctensor C.tensor
// t := C.malloc(C.size_t(3) * C.size_t(unsafe.Sizeof(uintptr(ctensor))))
// // t := C.malloc(1000)
// // t := C.at_new_tensor()
c_self := (C.tensor)((*self).private)
c_other := (C.tensor)((*other).private)
C.atg_eq1((*C.tensor)(ptr), c_self, c_other)
// cptr := (*C.tensor)(ptr)
// C.atg_eq1(cptr, c_self, c_other)
}

View File

@ -21,8 +21,10 @@ func ptrToString(cptr *C.char) string {
var str string = ""
if cptr != nil {
str = *(*string)(unsafe.Pointer(&cptr))
fmt.Printf("Err Msg from C: %v\n", str)
// strPtr := (*string)(unsafe.Pointer(cptr))
// fmt.Printf("Error: string at err pointer: %v\n", cptr)
// TODO: get error string from pointer
fmt.Printf("Err Msg from C (TODO: Will show err message here)...\n")
C.free(unsafe.Pointer(cptr))
}

View File

@ -1,6 +1,8 @@
package wrapper
// #include <stdlib.h>
//#include "stdbool.h"
// #include "../libtch/torch_api.h"
import "C"
import (
@ -53,9 +55,7 @@ func (ts Tensor) Size() []int64 {
retVal := decodeSize(szPtr, dim)
return retVal
}
// Size1 returns the tensor size for 1D tensors.
} // Size1 returns the tensor size for 1D tensors.
func (ts Tensor) Size1() (retVal int64, err error) {
shape := ts.Size()
if len(shape) != 1 {
@ -232,3 +232,66 @@ func (ts Tensor) DType() gotch.DType {
return dtype
}
func (ts Tensor) Eq1(other Tensor) {
// var ptr unsafe.Pointer
// NOTE:
// This will cause panic: runtime error: cgo argument has Go pointer to Go pointer
// ptr = NewTensor()
// lib.Atg_eq1(unsafe.Pointer(&ptr), ts.ctensor, other.ctensor)
// C pointer to [1]uintptr (Go pointer)
// ctensorsPtr := C.malloc(C.size_t(1) * C.size_t(unsafe.Sizeof(uintptr(0))))
// TODO: create C pointer to a slice of tensors [1]C.tensor using C.malloc
// Slice with 1 element type C.tensor
// nbytes := C.size_t(1) * C.size_t(unsafe.Sizeof(C.tensor))
// ctensorsPtr := C.malloc(nbytes)
// ctensorsPtr := C.malloc(C.size_t(1) * C.size_t(unsafe.Sizeof(C.tensor)))
// C null pointer C.tensor * = null
ctensorPtr := lib.NewTensor()
fmt.Printf("Out tensor BEFORE: %v\n", &ctensorPtr)
fmt.Printf("Out tensor address: %v\n", *(*int)(unsafe.Pointer(&ctensorPtr)))
ctensorAddr := *(*int64)(unsafe.Pointer(&ctensorPtr))
var data []int64
data = append(data, ctensorAddr)
// lib.AtPrint((*lib.C_tensor)(unsafe.Pointer(ctensorPtr)))
// nullPtr := (*C.tensor)(unsafe.Pointer(uintptr(0)))
// fmt.Printf("Null pointer: %v\n", &nullPtr)
//
// data := []*C.tensor{nullPtr}
// fmt.Printf("data: %v\n", data)
// // Calculate number of bytes for a slice of one element of C null pointer
// nbytes := 1 * unsafe.Sizeof(uintptr(0))
// fmt.Printf("Nbytes: %v\n", nbytes)
//
// cptr := C.malloc(C.size_t(nbytes))
// ctensorsPtr := (*[1 << 30]byte)(cptr)[:nbytes:nbytes]
// buf := bytes.NewBuffer(ctensorsPtr[:0:nbytes])
// // ctensorsPtr := (*[1 << 30]C.tensor)(unsafe.Pointer(uintptr(0)))[:nbytes:nbytes]
// fmt.Printf("ctensorsPtr 1: %v\n", &ctensorsPtr[0])
// fmt.Printf("Type of ctensorsPtr: %v\n", reflect.TypeOf(ctensorsPtr))
// // buff := bytes.NewBuffer(dataSlice[:0:nbytes])
// // Write to memory
// err := binary.Write(buf, nativeEndian, data)
// if err != nil {
// log.Fatal(err)
// }
// lib.Atg_eq1(unsafe.Pointer(cptr), ts.ctensor, other.ctensor)
lib.Atg_eq1(unsafe.Pointer(&ctensorPtr), ts.ctensor, other.ctensor)
if err := TorchErr(); err != nil {
log.Fatal(err)
}
fmt.Printf("Out tensor AFTER: %v\n", &ctensorPtr)
lib.AtPrint((*lib.C_tensor)(unsafe.Pointer(&ctensorPtr)))
}