feat(wrapper/tensor): tensor.DType

This commit is contained in:
sugarme 2020-06-02 19:29:24 +10:00
parent 45bb5a5907
commit 549e5d1313
5 changed files with 59 additions and 0 deletions

View File

@ -106,6 +106,7 @@ func CInt2DType(v CInt) (dtype DType, err error) {
for key, val := range dtypeCInt {
if val == v {
dtype = key
found = true
break
}
}

View File

@ -53,4 +53,6 @@ func main() {
// fmt.Printf("typ: %v\n", typ)
// fmt.Printf("Count: %v\n", count)
fmt.Printf("DType: %v\n", ts.DType())
}

View File

@ -52,3 +52,9 @@ func AtShape(t *C_tensor, ptr unsafe.Pointer) {
c_ptr := (*C.long)(ptr)
C.at_shape(cTensor, c_ptr)
}
func AtScalarType(t *C_tensor) int32 {
c_tensor := (C.tensor)((*t).private)
c_result := C.at_scalar_type(c_tensor)
return *(*int32)(unsafe.Pointer(&c_result))
}

39
wrapper/error.go Normal file
View File

@ -0,0 +1,39 @@
package wrapper
/*
* import "C"
*
* import (
* "fmt"
* )
*
* // ptrToString returns nil on the null pointer. If not null,
* // the pointer gets freed.
* // NOTE: C does not have exception design. C++ throws exception
* // to stderr. This code to check stderr for any err message,
* // if it exists, takes it and frees up C pointer.
* func ptrToString(ptr *C.c_char) string {
* var str string
* if !ptr.is_null() {
* // TODO: implement this
* // str := GET_ERROR_FROM C std::err
* C.free(ptr)
* return str
* } else {
* return ""
* }
* }
*
* // readAndCleanError wraps error handling and C memory free up
* func UnsafeTorch(f func()) (retF func(), err error) {
*
* var str string
* // TODO: implement this
* // str := ptrToString(torch_sys.get_and_reset_last_err())
* if str != "" {
* err = fmt.Errorf("Unsafe error: %v\n", err.Error())
* return nil, err
* } else {
* return f, nil
* }
* } */

View File

@ -213,3 +213,14 @@ func NewTensorFromData(data interface{}, shape []int64) (retVal *Tensor, err err
return retVal, nil
}
func (ts Tensor) DType() gotch.DType {
cint := lib.AtScalarType(ts.ctensor)
dtype, err := gotch.CInt2DType(cint)
if err != nil {
log.Fatalf("Tensor DType error: %v\n", err)
}
return dtype
}