feat(wrapper/util): TorchErr

This commit is contained in:
sugarme 2020-06-03 11:03:38 +10:00
parent 549e5d1313
commit 3963bea16d
4 changed files with 69 additions and 38 deletions

View File

@ -95,5 +95,18 @@ then in the return of function body
``` ```
### C type pointers e.g. `char *FUNCTION()` --> `*C.char`
then just return the C function call.
```c
char *get_and_reset_last_err(); // thread-local
```
```go
func GetAndResetLastErr() *C.char{
return C.get_and_reset_last_err()
}
```

View File

@ -58,3 +58,7 @@ func AtScalarType(t *C_tensor) int32 {
c_result := C.at_scalar_type(c_tensor) c_result := C.at_scalar_type(c_tensor)
return *(*int32)(unsafe.Pointer(&c_result)) return *(*int32)(unsafe.Pointer(&c_result))
} }
func GetAndResetLastErr() *C.char {
return C.get_and_reset_last_err()
}

View File

@ -1,39 +1,49 @@
package wrapper package wrapper
/* // #include <stdlib.h>
* import "C" import "C"
*
* import ( import (
* "fmt" "fmt"
* ) "unsafe"
*
* // ptrToString returns nil on the null pointer. If not null, lib "github.com/sugarme/gotch/libtch"
* // the pointer gets freed. )
* // NOTE: C does not have exception design. C++ throws exception
* // to stderr. This code to check stderr for any err message, // ptrToString returns nil on the null pointer. If not null,
* // if it exists, takes it and frees up C pointer. // the pointer gets freed.
* func ptrToString(ptr *C.c_char) string { // NOTE: C does not have exception design. C++ throws exception
* var str string // to stderr. This code to check stderr for any err message,
* if !ptr.is_null() { // if it exists, takes it and frees up C pointer.
* // TODO: implement this func ptrToString(cptr *C.char) string {
* // str := GET_ERROR_FROM C std::err var str string
* C.free(ptr)
* return str str = *(*string)(unsafe.Pointer(&cptr))
* } else { fmt.Printf("Err Msg from C: %v\n", str)
* return "" if str != "" {
* } // Free up C memory
* } C.free(unsafe.Pointer(cptr))
* return str
* // readAndCleanError wraps error handling and C memory free up } else {
* func UnsafeTorch(f func()) (retF func(), err error) { return ""
* }
* var str string }
* // TODO: implement this
* // str := ptrToString(torch_sys.get_and_reset_last_err()) // TorchErr checks and retrieves last error message from
* if str != "" { // C `thread_local` if existing and frees up C memory the C pointer
* err = fmt.Errorf("Unsafe error: %v\n", err.Error()) // points to.
* return nil, err //
* } else { // NOTE: Go language atm does not have generic function something
* return f, nil // similar to `macro` in Rust language, does it? So we have to
* } // wrap this function to any Libtorch C function call to check error
* } */ // instead of doing the other way around.
// See Go2 proposal: https://github.com/golang/go/issues/32620
func TorchErr() error {
cptr := (*C.char)(lib.GetAndResetLastErr())
errStr := ptrToString(cptr)
if errStr != "" {
return fmt.Errorf("Libtorch API Err: %v\n", errStr)
}
return nil
}

View File

@ -26,7 +26,11 @@ func NewTensor() Tensor {
} }
func (ts Tensor) Dim() uint64 { func (ts Tensor) Dim() uint64 {
return lib.AtDim(ts.ctensor) retVal := lib.AtDim(ts.ctensor)
if err := TorchErr(); err != nil {
log.Fatal(err)
}
return retVal
} }
// Size return shape of the tensor // Size return shape of the tensor