fix(wrapper/error): fixed checking C pointer for null. WIP(example/error): testing TorchErr
This commit is contained in:
parent
3963bea16d
commit
f6c22b4df9
36
dtype.go
36
dtype.go
|
@ -292,23 +292,25 @@ func TypeOf(dt DType, shape []int64) (retVal reflect.Type, err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TypeCheck checks whether data Go type matching DType
|
/*
|
||||||
func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) {
|
* // TypeCheck checks whether data Go type matching DType
|
||||||
dataValue := reflect.ValueOf(data)
|
* func TypeCheck(data interface{}, dtype DType) (matched bool, msg string) {
|
||||||
var dataType reflect.Type
|
* dataValue := reflect.ValueOf(data)
|
||||||
var err error
|
* var dataType reflect.Type
|
||||||
dataType, err = elementType(dataValue)
|
* var err error
|
||||||
if err != nil {
|
* dataType, err = elementType(dataValue)
|
||||||
msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
|
* if err != nil {
|
||||||
msg += err.Error()
|
* msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
|
||||||
return false, msg
|
* msg += err.Error()
|
||||||
}
|
* return false, msg
|
||||||
|
* }
|
||||||
matched = dataType == dtype.Type
|
*
|
||||||
msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
|
* matched = dataType == dtype.Type
|
||||||
|
* msg = fmt.Sprintf("data type: %v, DType: %v\n", dataType, dtype.Kind())
|
||||||
return matched, msg
|
*
|
||||||
}
|
* return matched, msg
|
||||||
|
* }
|
||||||
|
* */
|
||||||
|
|
||||||
var supportedTypes = map[reflect.Kind]bool{
|
var supportedTypes = map[reflect.Kind]bool{
|
||||||
reflect.Uint8: true,
|
reflect.Uint8: true,
|
||||||
|
|
32
example/error/main.go
Normal file
32
example/error/main.go
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
wrapper "github.com/sugarme/gotch/wrapper"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
|
||||||
|
// Try to compare 2 tensor with incompatible dimensions
|
||||||
|
// and check this returns an error
|
||||||
|
dx := []int32{1, 2, 3}
|
||||||
|
dy := []int32{1, 2, 3, 4}
|
||||||
|
|
||||||
|
xs, err := wrapper.OfSlice(dx)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
ys, err := wrapper.OfSlice(dy)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
xs.Print()
|
||||||
|
ys.Print()
|
||||||
|
|
||||||
|
fmt.Printf("xs dim: %v\n", xs.Dim())
|
||||||
|
fmt.Printf("ys dim: %v\n", ys.Dim())
|
||||||
|
|
||||||
|
}
|
|
@ -10,23 +10,23 @@ import (
|
||||||
lib "github.com/sugarme/gotch/libtch"
|
lib "github.com/sugarme/gotch/libtch"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ptrToString returns nil on the null pointer. If not null,
|
// ptrToString check C pointer for null. If not null, get value
|
||||||
// the pointer gets freed.
|
// the pointer points to and frees up C memory. It is used for
|
||||||
|
// getting error message C pointer points to and clean up C memory.
|
||||||
|
//
|
||||||
// NOTE: C does not have exception design. C++ throws exception
|
// NOTE: C does not have exception design. C++ throws exception
|
||||||
// to stderr. This code to check stderr for any err message,
|
// to stderr. This code to check stderr for any err message,
|
||||||
// if it exists, takes it and frees up C pointer.
|
// if it exists, takes it and frees up C memory.
|
||||||
func ptrToString(cptr *C.char) string {
|
func ptrToString(cptr *C.char) string {
|
||||||
var str string
|
var str string = ""
|
||||||
|
|
||||||
str = *(*string)(unsafe.Pointer(&cptr))
|
if cptr != nil {
|
||||||
fmt.Printf("Err Msg from C: %v\n", str)
|
str = *(*string)(unsafe.Pointer(&cptr))
|
||||||
if str != "" {
|
fmt.Printf("Err Msg from C: %v\n", str)
|
||||||
// Free up C memory
|
|
||||||
C.free(unsafe.Pointer(cptr))
|
C.free(unsafe.Pointer(cptr))
|
||||||
return str
|
|
||||||
} else {
|
|
||||||
return ""
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return str
|
||||||
}
|
}
|
||||||
|
|
||||||
// TorchErr checks and retrieves last error message from
|
// TorchErr checks and retrieves last error message from
|
||||||
|
|
|
@ -128,15 +128,19 @@ func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
|
||||||
//
|
//
|
||||||
// }
|
// }
|
||||||
|
|
||||||
// FOfSlice creates tensor from a slice data
|
// OfSlice creates tensor from a slice data
|
||||||
func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor, err error) {
|
func OfSlice(data interface{}) (retVal *Tensor, err error) {
|
||||||
|
|
||||||
if ok, msg := gotch.TypeCheck(data, dtype); !ok {
|
typ, dataLen, err := DataCheck(data)
|
||||||
err = fmt.Errorf("data type and DType are mismatched: %v\n", msg)
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dtype, err := gotch.ToDType(typ)
|
||||||
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
dataLen := reflect.ValueOf(data).Len()
|
|
||||||
shape := []int64{int64(dataLen)}
|
shape := []int64{int64(dataLen)}
|
||||||
elementNum := ElementCount(shape)
|
elementNum := ElementCount(shape)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user