fix(wrapper/error): fixed checking C pointer for null. WIP(example/error): testing TorchErr
This commit is contained in:
parent
3963bea16d
commit
f6c22b4df9
36
dtype.go
36
dtype.go
|
@ -292,23 +292,25 @@ func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
// TypeCheck checks whether data Go type matching DType
|
||||
func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) {
|
||||
dataValue := reflect.ValueOf(data)
|
||||
var dataType reflect.Type
|
||||
var err error
|
||||
dataType, err = elementType(dataValue)
|
||||
if err != nil {
|
||||
msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
|
||||
msg += err.Error()
|
||||
return false, msg
|
||||
}
|
||||
|
||||
matched = dataType == dtype.Type
|
||||
msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
|
||||
|
||||
return matched, msg
|
||||
}
|
||||
/*
|
||||
* // TypeCheck checks whether data Go type matching DType
|
||||
* func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) {
|
||||
* dataValue := reflect.ValueOf(data)
|
||||
* var dataType reflect.Type
|
||||
* var err error
|
||||
* dataType, err = elementType(dataValue)
|
||||
* if err != nil {
|
||||
* msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
|
||||
* msg += err.Error()
|
||||
* return false, msg
|
||||
* }
|
||||
*
|
||||
* matched = dataType == dtype.Type
|
||||
* msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
|
||||
*
|
||||
* return matched, msg
|
||||
* }
|
||||
* */
|
||||
|
||||
var supportedTypes = map[reflect.Kind]bool{
|
||||
reflect.Uint8: true,
|
||||
|
|
32
example/error/main.go
Normal file
32
example/error/main.go
Normal file
|
@ -0,0 +1,32 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
wrapper "github.com/sugarme/gotch/wrapper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
// Try to compare 2 tensor with incompatible dimensions
|
||||
// and check this returns an error
|
||||
dx := []int32{1, 2, 3}
|
||||
dy := []int32{1, 2, 3, 4}
|
||||
|
||||
xs, err := wrapper.OfSlice(dx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
ys, err := wrapper.OfSlice(dy)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
xs.Print()
|
||||
ys.Print()
|
||||
|
||||
fmt.Printf("xs dim: %v\n", xs.Dim())
|
||||
fmt.Printf("ys dim: %v\n", ys.Dim())
|
||||
|
||||
}
|
|
@ -10,23 +10,23 @@ import (
|
|||
lib "github.com/sugarme/gotch/libtch"
|
||||
)
|
||||
|
||||
// ptrToString returns nil on the null pointer. If not null,
|
||||
// the pointer gets freed.
|
||||
// ptrToString check C pointer for null. If not null, get value
|
||||
// the pointer points to and frees up C memory. It is used for
|
||||
// getting error message C pointer points to and clean up C memory.
|
||||
//
|
||||
// NOTE: C does not have exception design. C++ throws exception
|
||||
// to stderr. This code to check stderr for any err message,
|
||||
// if it exists, takes it and frees up C pointer.
|
||||
// if it exists, takes it and frees up C memory.
|
||||
func ptrToString(cptr *C.char) string {
|
||||
var str string
|
||||
var str string = ""
|
||||
|
||||
str = *(*string)(unsafe.Pointer(&cptr))
|
||||
fmt.Printf("Err Msg from C: %v\n", str)
|
||||
if str != "" {
|
||||
// Free up C memory
|
||||
if cptr != nil {
|
||||
str = *(*string)(unsafe.Pointer(&cptr))
|
||||
fmt.Printf("Err Msg from C: %v\n", str)
|
||||
C.free(unsafe.Pointer(cptr))
|
||||
return str
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
|
||||
return str
|
||||
}
|
||||
|
||||
// TorchErr checks and retrieves last error message from
|
||||
|
|
|
@ -128,15 +128,19 @@ func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
|
|||
//
|
||||
// }
|
||||
|
||||
// FOfSlice creates tensor from a slice data
|
||||
func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor, err error) {
|
||||
// OfSlice creates tensor from a slice data
|
||||
func OfSlice(data interface{}) (retVal *Tensor, err error) {
|
||||
|
||||
if ok, msg := gotch.TypeCheck(data, dtype); !ok {
|
||||
err = fmt.Errorf("data type and DType are mismatched: %v\n", msg)
|
||||
typ, dataLen, err := DataCheck(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dtype, err := gotch.ToDType(typ)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dataLen := reflect.ValueOf(data).Len()
|
||||
shape := []int64{int64(dataLen)}
|
||||
elementNum := ElementCount(shape)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user