WIP(wrapper/error): handle error message from C; WIP(libtch/tensor_generated.go): at_eq1
This commit is contained in:
parent
f6c22b4df9
commit
ae5f26d567
|
@ -12,7 +12,8 @@ func main() {
|
|||
// Try to compare 2 tensor with incompatible dimensions
|
||||
// and check this returns an error
|
||||
dx := []int32{1, 2, 3}
|
||||
dy := []int32{1, 2, 3, 4}
|
||||
// dy := []int32{1, 2, 3, 4}
|
||||
dy := []int32{1, 2, 5}
|
||||
|
||||
xs, err := wrapper.OfSlice(dx)
|
||||
if err != nil {
|
||||
|
@ -26,7 +27,11 @@ func main() {
|
|||
xs.Print()
|
||||
ys.Print()
|
||||
|
||||
fmt.Printf("xs dim: %v\n", xs.Dim())
|
||||
fmt.Printf("ys dim: %v\n", ys.Dim())
|
||||
fmt.Printf("xs num of dimensions: %v\n", xs.Dim())
|
||||
fmt.Printf("ys num of dimensions: %v\n", ys.Dim())
|
||||
|
||||
fmt.Printf("xs shape: %v\n", xs.Size())
|
||||
fmt.Printf("ys shape: %v\n", ys.Size())
|
||||
|
||||
xs.Eq1(*ys)
|
||||
}
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
// NOTE: this file would be automatically generated by executing `gen` OCaml
|
||||
// folder.
|
||||
package libtch
|
|
@ -17,6 +17,11 @@ func AtNewTensor() *C_tensor {
|
|||
return &C_tensor{private: unsafe.Pointer(t)}
|
||||
}
|
||||
|
||||
func NewTensor() unsafe.Pointer {
|
||||
t := C.at_new_tensor()
|
||||
return unsafe.Pointer(t)
|
||||
}
|
||||
|
||||
func AtTensorOfData(vs unsafe.Pointer, dims []int64, ndims uint, elt_size_in_bytes uint, kind int) *C_tensor {
|
||||
|
||||
// just get pointer of the first element of shape
|
||||
|
|
27
libtch/tensor_generated_sample.go
Normal file
27
libtch/tensor_generated_sample.go
Normal file
|
@ -0,0 +1,27 @@
|
|||
// NOTE: this is a sample for OCaml generated code for `tensor_generated.go`
|
||||
package libtch
|
||||
|
||||
//#include "stdbool.h"
|
||||
//#include "torch_api.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// void atg_eq1(tensor *, tensor self, tensor other);
|
||||
func Atg_eq1(ptr unsafe.Pointer, self *C_tensor, other *C_tensor) {
|
||||
// // func Atg_eq1(ptr unsafe.Pointer, self *C_tensor, other *C_tensor) {
|
||||
//
|
||||
// // t := C.malloc(C.size_t(1) * C.size_t(unsafe.Sizeof(uintptr(C.tensor{}))))
|
||||
// var ctensor C.tensor
|
||||
// t := C.malloc(C.size_t(3) * C.size_t(unsafe.Sizeof(uintptr(ctensor))))
|
||||
// // t := C.malloc(1000)
|
||||
// // t := C.at_new_tensor()
|
||||
c_self := (C.tensor)((*self).private)
|
||||
c_other := (C.tensor)((*other).private)
|
||||
|
||||
C.atg_eq1((*C.tensor)(ptr), c_self, c_other)
|
||||
// cptr := (*C.tensor)(ptr)
|
||||
// C.atg_eq1(cptr, c_self, c_other)
|
||||
}
|
|
@ -21,8 +21,10 @@ func ptrToString(cptr *C.char) string {
|
|||
var str string = ""
|
||||
|
||||
if cptr != nil {
|
||||
str = *(*string)(unsafe.Pointer(&cptr))
|
||||
fmt.Printf("Err Msg from C: %v\n", str)
|
||||
// strPtr := (*string)(unsafe.Pointer(cptr))
|
||||
// fmt.Printf("Error: string at err pointer: %v\n", cptr)
|
||||
// TODO: get error string from pointer
|
||||
fmt.Printf("Err Msg from C (TODO: Will show err message here)...\n")
|
||||
C.free(unsafe.Pointer(cptr))
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
package wrapper
|
||||
|
||||
// #include <stdlib.h>
|
||||
//#include "stdbool.h"
|
||||
// #include "../libtch/torch_api.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
|
@ -53,9 +55,7 @@ func (ts Tensor) Size() []int64 {
|
|||
|
||||
retVal := decodeSize(szPtr, dim)
|
||||
return retVal
|
||||
}
|
||||
|
||||
// Size1 returns the tensor size for 1D tensors.
|
||||
} // Size1 returns the tensor size for 1D tensors.
|
||||
func (ts Tensor) Size1() (retVal int64, err error) {
|
||||
shape := ts.Size()
|
||||
if len(shape) != 1 {
|
||||
|
@ -232,3 +232,66 @@ func (ts Tensor) DType() gotch.DType {
|
|||
|
||||
return dtype
|
||||
}
|
||||
|
||||
func (ts Tensor) Eq1(other Tensor) {
|
||||
|
||||
// var ptr unsafe.Pointer
|
||||
// NOTE:
|
||||
// This will cause panic: runtime error: cgo argument has Go pointer to Go pointer
|
||||
// ptr = NewTensor()
|
||||
// lib.Atg_eq1(unsafe.Pointer(&ptr), ts.ctensor, other.ctensor)
|
||||
|
||||
// C pointer to [1]uintptr (Go pointer)
|
||||
// ctensorsPtr := C.malloc(C.size_t(1) * C.size_t(unsafe.Sizeof(uintptr(0))))
|
||||
|
||||
// TODO: create C pointer to a slice of tensors [1]C.tensor using C.malloc
|
||||
// Slice with 1 element type C.tensor
|
||||
// nbytes := C.size_t(1) * C.size_t(unsafe.Sizeof(C.tensor))
|
||||
// ctensorsPtr := C.malloc(nbytes)
|
||||
// ctensorsPtr := C.malloc(C.size_t(1) * C.size_t(unsafe.Sizeof(C.tensor)))
|
||||
|
||||
// C null pointer C.tensor * = null
|
||||
ctensorPtr := lib.NewTensor()
|
||||
fmt.Printf("Out tensor BEFORE: %v\n", &ctensorPtr)
|
||||
fmt.Printf("Out tensor address: %v\n", *(*int)(unsafe.Pointer(&ctensorPtr)))
|
||||
|
||||
ctensorAddr := *(*int64)(unsafe.Pointer(&ctensorPtr))
|
||||
var data []int64
|
||||
data = append(data, ctensorAddr)
|
||||
|
||||
// lib.AtPrint((*lib.C_tensor)(unsafe.Pointer(ctensorPtr)))
|
||||
|
||||
// nullPtr := (*C.tensor)(unsafe.Pointer(uintptr(0)))
|
||||
// fmt.Printf("Null pointer: %v\n", &nullPtr)
|
||||
//
|
||||
// data := []*C.tensor{nullPtr}
|
||||
// fmt.Printf("data: %v\n", data)
|
||||
// // Calculate number of bytes for a slice of one element of C null pointer
|
||||
// nbytes := 1 * unsafe.Sizeof(uintptr(0))
|
||||
// fmt.Printf("Nbytes: %v\n", nbytes)
|
||||
//
|
||||
// cptr := C.malloc(C.size_t(nbytes))
|
||||
// ctensorsPtr := (*[1 << 30]byte)(cptr)[:nbytes:nbytes]
|
||||
// buf := bytes.NewBuffer(ctensorsPtr[:0:nbytes])
|
||||
// // ctensorsPtr := (*[1 << 30]C.tensor)(unsafe.Pointer(uintptr(0)))[:nbytes:nbytes]
|
||||
// fmt.Printf("ctensorsPtr 1: %v\n", &ctensorsPtr[0])
|
||||
// fmt.Printf("Type of ctensorsPtr: %v\n", reflect.TypeOf(ctensorsPtr))
|
||||
// // buff := bytes.NewBuffer(dataSlice[:0:nbytes])
|
||||
// // Write to memory
|
||||
// err := binary.Write(buf, nativeEndian, data)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
|
||||
// lib.Atg_eq1(unsafe.Pointer(cptr), ts.ctensor, other.ctensor)
|
||||
lib.Atg_eq1(unsafe.Pointer(&ctensorPtr), ts.ctensor, other.ctensor)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Out tensor AFTER: %v\n", &ctensorPtr)
|
||||
|
||||
lib.AtPrint((*lib.C_tensor)(unsafe.Pointer(&ctensorPtr)))
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user