feat(wrapper/tensor): tensor.DType
This commit is contained in:
parent
45bb5a5907
commit
549e5d1313
1
dtype.go
1
dtype.go
|
@ -106,6 +106,7 @@ func CInt2DType(v CInt) (dtype DType, err error) {
|
||||||
for key, val := range dtypeCInt {
|
for key, val := range dtypeCInt {
|
||||||
if val == v {
|
if val == v {
|
||||||
dtype = key
|
dtype = key
|
||||||
|
found = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,4 +53,6 @@ func main() {
|
||||||
// fmt.Printf("typ: %v\n", typ)
|
// fmt.Printf("typ: %v\n", typ)
|
||||||
// fmt.Printf("Count: %v\n", count)
|
// fmt.Printf("Count: %v\n", count)
|
||||||
|
|
||||||
|
fmt.Printf("DType: %v\n", ts.DType())
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,3 +52,9 @@ func AtShape(t *C_tensor, ptr unsafe.Pointer) {
|
||||||
c_ptr := (*C.long)(ptr)
|
c_ptr := (*C.long)(ptr)
|
||||||
C.at_shape(cTensor, c_ptr)
|
C.at_shape(cTensor, c_ptr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AtScalarType(t *C_tensor) int32 {
|
||||||
|
c_tensor := (C.tensor)((*t).private)
|
||||||
|
c_result := C.at_scalar_type(c_tensor)
|
||||||
|
return *(*int32)(unsafe.Pointer(&c_result))
|
||||||
|
}
|
||||||
|
|
39
wrapper/error.go
Normal file
39
wrapper/error.go
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
package wrapper
|
||||||
|
|
||||||
|
/*
|
||||||
|
* import "C"
|
||||||
|
*
|
||||||
|
* import (
|
||||||
|
* "fmt"
|
||||||
|
* )
|
||||||
|
*
|
||||||
|
* // ptrToString returns nil on the null pointer. If not null,
|
||||||
|
* // the pointer gets freed.
|
||||||
|
* // NOTE: C does not have exception design. C++ throws exception
|
||||||
|
* // to stderr. This code to check stderr for any err message,
|
||||||
|
* // if it exists, takes it and frees up C pointer.
|
||||||
|
* func ptrToString(ptr *C.c_char) string {
|
||||||
|
* var str string
|
||||||
|
* if !ptr.is_null() {
|
||||||
|
* // TODO: implement this
|
||||||
|
* // str := GET_ERROR_FROM C std::err
|
||||||
|
* C.free(ptr)
|
||||||
|
* return str
|
||||||
|
* } else {
|
||||||
|
* return ""
|
||||||
|
* }
|
||||||
|
* }
|
||||||
|
*
|
||||||
|
* // readAndCleanError wraps error handling and C memory free up
|
||||||
|
* func UnsafeTorch(f func()) (retF func(), err error) {
|
||||||
|
*
|
||||||
|
* var str string
|
||||||
|
* // TODO: implement this
|
||||||
|
* // str := ptrToString(torch_sys.get_and_reset_last_err())
|
||||||
|
* if str != "" {
|
||||||
|
* err = fmt.Errorf("Unsafe error: %v\n", err.Error())
|
||||||
|
* return nil, err
|
||||||
|
* } else {
|
||||||
|
* return f, nil
|
||||||
|
* }
|
||||||
|
* } */
|
|
@ -213,3 +213,14 @@ func NewTensorFromData(data interface{}, shape []int64) (retVal *Tensor, err err
|
||||||
return retVal, nil
|
return retVal, nil
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) DType() gotch.DType {
|
||||||
|
cint := lib.AtScalarType(ts.ctensor)
|
||||||
|
|
||||||
|
dtype, err := gotch.CInt2DType(cint)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Tensor DType error: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return dtype
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user