From f6c22b4df92df1abc190d27427ba509cb92c1698 Mon Sep 17 00:00:00 2001 From: sugarme Date: Wed, 3 Jun 2020 12:07:08 +1000 Subject: [PATCH] fix(wrapper/error): fixed checking C pointer for null. WIP(example/error): testing TorchErr --- dtype.go | 36 +++++++++++++++++++----------------- example/error/main.go | 32 ++++++++++++++++++++++++++++++++ wrapper/error.go | 22 +++++++++++----------- wrapper/tensor.go | 14 +++++++++----- 4 files changed, 71 insertions(+), 33 deletions(-) create mode 100644 example/error/main.go diff --git a/dtype.go b/dtype.go index 2761936..ef124b4 100644 --- a/dtype.go +++ b/dtype.go @@ -292,23 +292,25 @@ func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) { } } -// TypeCheck checks whether data Go type matching DType -func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) { - dataValue := reflect.ValueOf(data) - var dataType reflect.Type - var err error - dataType, err = elementType(dataValue) - if err != nil { - msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind()) - msg += err.Error() - return false, msg - } - - matched = dataType == dtype.Type - msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind()) - - return matched, msg -} +/* + * // TypeCheck checks whether data Go type matching DType + * func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) { + * dataValue := reflect.ValueOf(data) + * var dataType reflect.Type + * var err error + * dataType, err = elementType(dataValue) + * if err != nil { + * msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind()) + * msg += err.Error() + * return false, msg + * } + * + * matched = dataType == dtype.Type + * msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind()) + * + * return matched, msg + * } + * */ var supportedTypes = map[reflect.Kind]bool{ reflect.Uint8: true, diff --git a/example/error/main.go b/example/error/main.go new file mode 100644 index 0000000..34f3839 --- /dev/null +++ b/example/error/main.go @@ -0,0 +1,32 @@ +package main + +import ( + "fmt" + "log" + + wrapper "github.com/sugarme/gotch/wrapper" +) + +func main() { + + // Try to compare 2 tensor with incompatible dimensions + // and check this returns an error + dx := []int32{1, 2, 3} + dy := []int32{1, 2, 3, 4} + + xs, err := wrapper.OfSlice(dx) + if err != nil { + log.Fatal(err) + } + ys, err := wrapper.OfSlice(dy) + if err != nil { + log.Fatal(err) + } + + xs.Print() + ys.Print() + + fmt.Printf("xs dim: %v\n", xs.Dim()) + fmt.Printf("ys dim: %v\n", ys.Dim()) + +} diff --git a/wrapper/error.go b/wrapper/error.go index 2a59523..99104e6 100644 --- a/wrapper/error.go +++ b/wrapper/error.go @@ -10,23 +10,23 @@ import ( lib "github.com/sugarme/gotch/libtch" ) -// ptrToString returns nil on the null pointer. If not null, -// the pointer gets freed. +// ptrToString check C pointer for null. If not null, get value +// the pointer points to and frees up C memory. It is used for +// getting error message C pointer points to and clean up C memory. +// // NOTE: C does not have exception design. C++ throws exception // to stderr. This code to check stderr for any err message, -// if it exists, takes it and frees up C pointer. +// if it exists, takes it and frees up C memory. func ptrToString(cptr *C.char) string { - var str string + var str string = "" - str = *(*string)(unsafe.Pointer(&cptr)) - fmt.Printf("Err Msg from C: %v\n", str) - if str != "" { - // Free up C memory + if cptr != nil { + str = *(*string)(unsafe.Pointer(&cptr)) + fmt.Printf("Err Msg from C: %v\n", str) C.free(unsafe.Pointer(cptr)) - return str - } else { - return "" } + + return str } // TorchErr checks and retrieves last error message from diff --git a/wrapper/tensor.go b/wrapper/tensor.go index 163cc6d..5520a9a 100644 --- a/wrapper/tensor.go +++ b/wrapper/tensor.go @@ -128,15 +128,19 @@ func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 { // // } -// FOfSlice creates tensor from a slice data -func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor, err error) { +// OfSlice creates tensor from a slice data +func OfSlice(data interface{}) (retVal *Tensor, err error) { - if ok, msg := gotch.TypeCheck(data, dtype); !ok { - err = fmt.Errorf("data type and DType are mismatched: %v\n", msg) + typ, dataLen, err := DataCheck(data) + if err != nil { + return nil, err + } + + dtype, err := gotch.ToDType(typ) + if err != nil { return nil, err } - dataLen := reflect.ValueOf(data).Len() shape := []int64{int64(dataLen)} elementNum := ElementCount(shape)