feat(wrapper/tensor): tensor.DType
This commit is contained in:
parent
45bb5a5907
commit
549e5d1313
1
dtype.go
1
dtype.go
|
@ -106,6 +106,7 @@ func CInt2DType(v CInt) (dtype DType, err error) {
|
|||
for key, val := range dtypeCInt {
|
||||
if val == v {
|
||||
dtype = key
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
|
@ -53,4 +53,6 @@ func main() {
|
|||
// fmt.Printf("typ: %v\n", typ)
|
||||
// fmt.Printf("Count: %v\n", count)
|
||||
|
||||
fmt.Printf("DType: %v\n", ts.DType())
|
||||
|
||||
}
|
||||
|
|
|
@ -52,3 +52,9 @@ func AtShape(t *C_tensor, ptr unsafe.Pointer) {
|
|||
c_ptr := (*C.long)(ptr)
|
||||
C.at_shape(cTensor, c_ptr)
|
||||
}
|
||||
|
||||
func AtScalarType(t *C_tensor) int32 {
|
||||
c_tensor := (C.tensor)((*t).private)
|
||||
c_result := C.at_scalar_type(c_tensor)
|
||||
return *(*int32)(unsafe.Pointer(&c_result))
|
||||
}
|
||||
|
|
39
wrapper/error.go
Normal file
39
wrapper/error.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
package wrapper
|
||||
|
||||
/*
|
||||
* import "C"
|
||||
*
|
||||
* import (
|
||||
* "fmt"
|
||||
* )
|
||||
*
|
||||
* // ptrToString returns nil on the null pointer. If not null,
|
||||
* // the pointer gets freed.
|
||||
* // NOTE: C does not have exception design. C++ throws exception
|
||||
* // to stderr. This code to check stderr for any err message,
|
||||
* // if it exists, takes it and frees up C pointer.
|
||||
* func ptrToString(ptr *C.c_char) string {
|
||||
* var str string
|
||||
* if !ptr.is_null() {
|
||||
* // TODO: implement this
|
||||
* // str := GET_ERROR_FROM C std::err
|
||||
* C.free(ptr)
|
||||
* return str
|
||||
* } else {
|
||||
* return ""
|
||||
* }
|
||||
* }
|
||||
*
|
||||
* // readAndCleanError wraps error handling and C memory free up
|
||||
* func UnsafeTorch(f func()) (retF func(), err error) {
|
||||
*
|
||||
* var str string
|
||||
* // TODO: implement this
|
||||
* // str := ptrToString(torch_sys.get_and_reset_last_err())
|
||||
* if str != "" {
|
||||
* err = fmt.Errorf("Unsafe error: %v\n", err.Error())
|
||||
* return nil, err
|
||||
* } else {
|
||||
* return f, nil
|
||||
* }
|
||||
* } */
|
|
@ -213,3 +213,14 @@ func NewTensorFromData(data interface{}, shape []int64) (retVal *Tensor, err err
|
|||
return retVal, nil
|
||||
|
||||
}
|
||||
|
||||
func (ts Tensor) DType() gotch.DType {
|
||||
cint := lib.AtScalarType(ts.ctensor)
|
||||
|
||||
dtype, err := gotch.CInt2DType(cint)
|
||||
if err != nil {
|
||||
log.Fatalf("Tensor DType error: %v\n", err)
|
||||
}
|
||||
|
||||
return dtype
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user