fix(wrapper/error): fixed checking C pointer for null. WIP(example/error): testing TorchErr

This commit is contained in:
sugarme 2020-06-03 12:07:08 +10:00
parent 3963bea16d
commit f6c22b4df9
4 changed files with 71 additions and 33 deletions

View File

@ -292,23 +292,25 @@ func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) {
}
}
// TypeCheck checks whether data Go type matching DType
func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) {
dataValue := reflect.ValueOf(data)
var dataType reflect.Type
var err error
dataType, err = elementType(dataValue)
if err != nil {
msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
msg += err.Error()
return false, msg
}
matched = dataType == dtype.Type
msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
return matched, msg
}
/*
* // TypeCheck checks whether data Go type matching DType
* func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) {
* dataValue := reflect.ValueOf(data)
* var dataType reflect.Type
* var err error
* dataType, err = elementType(dataValue)
* if err != nil {
* msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
* msg += err.Error()
* return false, msg
* }
*
* matched = dataType == dtype.Type
* msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
*
* return matched, msg
* }
* */
var supportedTypes = map[reflect.Kind]bool{
reflect.Uint8: true,

32
example/error/main.go Normal file
View File

@ -0,0 +1,32 @@
package main
import (
"fmt"
"log"
wrapper "github.com/sugarme/gotch/wrapper"
)
func main() {
// Try to compare 2 tensor with incompatible dimensions
// and check this returns an error
dx := []int32{1, 2, 3}
dy := []int32{1, 2, 3, 4}
xs, err := wrapper.OfSlice(dx)
if err != nil {
log.Fatal(err)
}
ys, err := wrapper.OfSlice(dy)
if err != nil {
log.Fatal(err)
}
xs.Print()
ys.Print()
fmt.Printf("xs dim: %v\n", xs.Dim())
fmt.Printf("ys dim: %v\n", ys.Dim())
}

View File

@ -10,23 +10,23 @@ import (
lib "github.com/sugarme/gotch/libtch"
)
// ptrToString returns nil on the null pointer. If not null,
// the pointer gets freed.
// ptrToString check C pointer for null. If not null, get value
// the pointer points to and frees up C memory. It is used for
// getting error message C pointer points to and clean up C memory.
//
// NOTE: C does not have exception design. C++ throws exception
// to stderr. This code to check stderr for any err message,
// if it exists, takes it and frees up C pointer.
// if it exists, takes it and frees up C memory.
func ptrToString(cptr *C.char) string {
var str string
var str string = ""
str = *(*string)(unsafe.Pointer(&cptr))
fmt.Printf("Err Msg from C: %v\n", str)
if str != "" {
// Free up C memory
if cptr != nil {
str = *(*string)(unsafe.Pointer(&cptr))
fmt.Printf("Err Msg from C: %v\n", str)
C.free(unsafe.Pointer(cptr))
return str
} else {
return ""
}
return str
}
// TorchErr checks and retrieves last error message from

View File

@ -128,15 +128,19 @@ func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
//
// }
// FOfSlice creates tensor from a slice data
func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor, err error) {
// OfSlice creates tensor from a slice data
func OfSlice(data interface{}) (retVal *Tensor, err error) {
if ok, msg := gotch.TypeCheck(data, dtype); !ok {
err = fmt.Errorf("data type and DType are mismatched: %v\n", msg)
typ, dataLen, err := DataCheck(data)
if err != nil {
return nil, err
}
dtype, err := gotch.ToDType(typ)
if err != nil {
return nil, err
}
dataLen := reflect.ValueOf(data).Len()
shape := []int64{int64(dataLen)}
elementNum := ElementCount(shape)