fix(wrapper/error): fixed checking C pointer for null. WIP(example/error): testing TorchErr

This commit is contained in:
sugarme 2020-06-03 12:07:08 +10:00
parent 3963bea16d
commit f6c22b4df9
4 changed files with 71 additions and 33 deletions

View File

@ -292,23 +292,25 @@ func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) {
} }
} }
// TypeCheck checks whether data Go type matching DType /*
func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) { * // TypeCheck checks whether data Go type matching DType
dataValue := reflect.ValueOf(data) * func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) {
var dataType reflect.Type * dataValue := reflect.ValueOf(data)
var err error * var dataType reflect.Type
dataType, err = elementType(dataValue) * var err error
if err != nil { * dataType, err = elementType(dataValue)
msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind()) * if err != nil {
msg += err.Error() * msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
return false, msg * msg += err.Error()
} * return false, msg
* }
matched = dataType == dtype.Type *
msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind()) * matched = dataType == dtype.Type
* msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
return matched, msg *
} * return matched, msg
* }
* */
var supportedTypes = map[reflect.Kind]bool{ var supportedTypes = map[reflect.Kind]bool{
reflect.Uint8: true, reflect.Uint8: true,

32
example/error/main.go Normal file
View File

@ -0,0 +1,32 @@
package main
import (
"fmt"
"log"
wrapper "github.com/sugarme/gotch/wrapper"
)
func main() {
// Try to compare 2 tensor with incompatible dimensions
// and check this returns an error
dx := []int32{1, 2, 3}
dy := []int32{1, 2, 3, 4}
xs, err := wrapper.OfSlice(dx)
if err != nil {
log.Fatal(err)
}
ys, err := wrapper.OfSlice(dy)
if err != nil {
log.Fatal(err)
}
xs.Print()
ys.Print()
fmt.Printf("xs dim: %v\n", xs.Dim())
fmt.Printf("ys dim: %v\n", ys.Dim())
}

View File

@ -10,23 +10,23 @@ import (
lib "github.com/sugarme/gotch/libtch" lib "github.com/sugarme/gotch/libtch"
) )
// ptrToString returns nil on the null pointer. If not null, // ptrToString check C pointer for null. If not null, get value
// the pointer gets freed. // the pointer points to and frees up C memory. It is used for
// getting error message C pointer points to and clean up C memory.
//
// NOTE: C does not have exception design. C++ throws exception // NOTE: C does not have exception design. C++ throws exception
// to stderr. This code to check stderr for any err message, // to stderr. This code to check stderr for any err message,
// if it exists, takes it and frees up C pointer. // if it exists, takes it and frees up C memory.
func ptrToString(cptr *C.char) string { func ptrToString(cptr *C.char) string {
var str string var str string = ""
str = *(*string)(unsafe.Pointer(&cptr)) if cptr != nil {
fmt.Printf("Err Msg from C: %v\n", str) str = *(*string)(unsafe.Pointer(&cptr))
if str != "" { fmt.Printf("Err Msg from C: %v\n", str)
// Free up C memory
C.free(unsafe.Pointer(cptr)) C.free(unsafe.Pointer(cptr))
return str
} else {
return ""
} }
return str
} }
// TorchErr checks and retrieves last error message from // TorchErr checks and retrieves last error message from

View File

@ -128,15 +128,19 @@ func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
// //
// } // }
// FOfSlice creates tensor from a slice data // OfSlice creates tensor from a slice data
func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor, err error) { func OfSlice(data interface{}) (retVal *Tensor, err error) {
if ok, msg := gotch.TypeCheck(data, dtype); !ok { typ, dataLen, err := DataCheck(data)
err = fmt.Errorf("data type and DType are mismatched: %v\n", msg) if err != nil {
return nil, err
}
dtype, err := gotch.ToDType(typ)
if err != nil {
return nil, err return nil, err
} }
dataLen := reflect.ValueOf(data).Len()
shape := []int64{int64(dataLen)} shape := []int64{int64(dataLen)}
elementNum := ElementCount(shape) elementNum := ElementCount(shape)