From 549e5d13137804b7ac018e93d26b24eae0f96dcf Mon Sep 17 00:00:00 2001 From: sugarme Date: Tue, 2 Jun 2020 19:29:24 +1000 Subject: [PATCH] feat(wrapper/tensor): tensor.DType --- dtype.go | 1 + example/tensor/main.go | 2 ++ libtch/tensor.go | 6 ++++++ wrapper/error.go | 39 +++++++++++++++++++++++++++++++++++++++ wrapper/tensor.go | 11 +++++++++++ 5 files changed, 59 insertions(+) create mode 100644 wrapper/error.go diff --git a/dtype.go b/dtype.go index 7e877f7..2761936 100644 --- a/dtype.go +++ b/dtype.go @@ -106,6 +106,7 @@ func CInt2DType(v CInt) (dtype DType, err error) { for key, val := range dtypeCInt { if val == v { dtype = key + found = true break } } diff --git a/example/tensor/main.go b/example/tensor/main.go index ae2e191..f6a51cc 100644 --- a/example/tensor/main.go +++ b/example/tensor/main.go @@ -53,4 +53,6 @@ func main() { // fmt.Printf("typ: %v\n", typ) // fmt.Printf("Count: %v\n", count) + fmt.Printf("DType: %v\n", ts.DType()) + } diff --git a/libtch/tensor.go b/libtch/tensor.go index 69e7222..c8a66de 100644 --- a/libtch/tensor.go +++ b/libtch/tensor.go @@ -52,3 +52,9 @@ func AtShape(t *C_tensor, ptr unsafe.Pointer) { c_ptr := (*C.long)(ptr) C.at_shape(cTensor, c_ptr) } + +func AtScalarType(t *C_tensor) int32 { + c_tensor := (C.tensor)((*t).private) + c_result := C.at_scalar_type(c_tensor) + return *(*int32)(unsafe.Pointer(&c_result)) +} diff --git a/wrapper/error.go b/wrapper/error.go new file mode 100644 index 0000000..01c1459 --- /dev/null +++ b/wrapper/error.go @@ -0,0 +1,39 @@ +package wrapper + +/* + * import "C" + * + * import ( + * "fmt" + * ) + * + * // ptrToString returns nil on the null pointer. If not null, + * // the pointer gets freed. + * // NOTE: C does not have exception design. C++ throws exception + * // to stderr. This code to check stderr for any err message, + * // if it exists, takes it and frees up C pointer. + * func ptrToString(ptr *C.c_char) string { + * var str string + * if !ptr.is_null() { + * // TODO: implement this + * // str := GET_ERROR_FROM C std::err + * C.free(ptr) + * return str + * } else { + * return "" + * } + * } + * + * // readAndCleanError wraps error handling and C memory free up + * func UnsafeTorch(f func()) (retF func(), err error) { + * + * var str string + * // TODO: implement this + * // str := ptrToString(torch_sys.get_and_reset_last_err()) + * if str != "" { + * err = fmt.Errorf("Unsafe error: %v\n", err.Error()) + * return nil, err + * } else { + * return f, nil + * } + * } */ diff --git a/wrapper/tensor.go b/wrapper/tensor.go index b57d64d..2a2817c 100644 --- a/wrapper/tensor.go +++ b/wrapper/tensor.go @@ -213,3 +213,14 @@ func NewTensorFromData(data interface{}, shape []int64) (retVal *Tensor, err err return retVal, nil } + +func (ts Tensor) DType() gotch.DType { + cint := lib.AtScalarType(ts.ctensor) + + dtype, err := gotch.CInt2DType(cint) + if err != nil { + log.Fatalf("Tensor DType error: %v\n", err) + } + + return dtype +}