From 3963bea16d6c74acdba9366ced3fbcfb37911f0b Mon Sep 17 00:00:00 2001 From: sugarme Date: Wed, 3 Jun 2020 11:03:38 +1000 Subject: [PATCH] feat(wrapper/util): TorchErr --- libtch/README.md | 13 ++++++++ libtch/tensor.go | 4 +++ wrapper/error.go | 84 ++++++++++++++++++++++++++--------------------- wrapper/tensor.go | 6 +++- 4 files changed, 69 insertions(+), 38 deletions(-) diff --git a/libtch/README.md b/libtch/README.md index 345372e..cf16a09 100644 --- a/libtch/README.md +++ b/libtch/README.md @@ -95,5 +95,18 @@ then in the return of function body ``` +### C type pointers e.g. `char *FUNCTION()` --> `*C.char` + +then just return the C function call. + +```c +char *get_and_reset_last_err(); // thread-local +``` + +```go +func GetAndResetLastErr() *C.char{ + return C.get_and_reset_last_err() +} +``` diff --git a/libtch/tensor.go b/libtch/tensor.go index c8a66de..f227977 100644 --- a/libtch/tensor.go +++ b/libtch/tensor.go @@ -58,3 +58,7 @@ func AtScalarType(t *C_tensor) int32 { c_result := C.at_scalar_type(c_tensor) return *(*int32)(unsafe.Pointer(&c_result)) } + +func GetAndResetLastErr() *C.char { + return C.get_and_reset_last_err() +} diff --git a/wrapper/error.go b/wrapper/error.go index 01c1459..2a59523 100644 --- a/wrapper/error.go +++ b/wrapper/error.go @@ -1,39 +1,49 @@ package wrapper -/* - * import "C" - * - * import ( - * "fmt" - * ) - * - * // ptrToString returns nil on the null pointer. If not null, - * // the pointer gets freed. - * // NOTE: C does not have exception design. C++ throws exception - * // to stderr. This code to check stderr for any err message, - * // if it exists, takes it and frees up C pointer. - * func ptrToString(ptr *C.c_char) string { - * var str string - * if !ptr.is_null() { - * // TODO: implement this - * // str := GET_ERROR_FROM C std::err - * C.free(ptr) - * return str - * } else { - * return "" - * } - * } - * - * // readAndCleanError wraps error handling and C memory free up - * func UnsafeTorch(f func()) (retF func(), err error) { - * - * var str string - * // TODO: implement this - * // str := ptrToString(torch_sys.get_and_reset_last_err()) - * if str != "" { - * err = fmt.Errorf("Unsafe error: %v\n", err.Error()) - * return nil, err - * } else { - * return f, nil - * } - * } */ +// #include +import "C" + +import ( + "fmt" + "unsafe" + + lib "github.com/sugarme/gotch/libtch" +) + +// ptrToString returns nil on the null pointer. If not null, +// the pointer gets freed. +// NOTE: C does not have exception design. C++ throws exception +// to stderr. This code to check stderr for any err message, +// if it exists, takes it and frees up C pointer. +func ptrToString(cptr *C.char) string { + var str string + + str = *(*string)(unsafe.Pointer(&cptr)) + fmt.Printf("Err Msg from C: %v\n", str) + if str != "" { + // Free up C memory + C.free(unsafe.Pointer(cptr)) + return str + } else { + return "" + } +} + +// TorchErr checks and retrieves last error message from +// C `thread_local` if existing and frees up C memory the C pointer +// points to. +// +// NOTE: Go language atm does not have generic function something +// similar to `macro` in Rust language, does it? So we have to +// wrap this function to any Libtorch C function call to check error +// instead of doing the other way around. +// See Go2 proposal: https://github.com/golang/go/issues/32620 +func TorchErr() error { + cptr := (*C.char)(lib.GetAndResetLastErr()) + errStr := ptrToString(cptr) + if errStr != "" { + return fmt.Errorf("Libtorch API Err: %v\n", errStr) + } + + return nil +} diff --git a/wrapper/tensor.go b/wrapper/tensor.go index 2a2817c..163cc6d 100644 --- a/wrapper/tensor.go +++ b/wrapper/tensor.go @@ -26,7 +26,11 @@ func NewTensor() Tensor { } func (ts Tensor) Dim() uint64 { - return lib.AtDim(ts.ctensor) + retVal := lib.AtDim(ts.ctensor) + if err := TorchErr(); err != nil { + log.Fatal(err) + } + return retVal } // Size return shape of the tensor