tensor/dim
This commit is contained in:
parent
d4beb985e0
commit
3b219ec1e0
|
@ -1,7 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
// "fmt"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
// gotch "github.com/sugarme/gotch"
|
||||
|
@ -34,6 +34,10 @@ func main() {
|
|||
|
||||
ts.Print()
|
||||
|
||||
fmt.Println(ts.Dim())
|
||||
|
||||
ts.Size()
|
||||
|
||||
// typ, count, err := wrapper.DataCheck(data)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
|
|
|
@ -62,3 +62,16 @@ func AtDataPtr(t *C_tensor) unsafe.Pointer {
|
|||
cTensor := (C.tensor)((*t).private)
|
||||
return C.at_data_ptr(cTensor)
|
||||
}
|
||||
|
||||
func AtDim(t *C_tensor) uint64 {
|
||||
cTensor := (C.tensor)((*t).private)
|
||||
cdim := C.at_dim(cTensor)
|
||||
return *(*uint64)(unsafe.Pointer(&cdim))
|
||||
}
|
||||
|
||||
func AtShape(t *C_tensor, sz []int64) {
|
||||
cTensor := (C.tensor)((*t).private)
|
||||
// just get pointer of the first element
|
||||
csz := (*C.int64_t)(unsafe.Pointer(&sz[0]))
|
||||
C.at_shape(cTensor, csz)
|
||||
}
|
||||
|
|
|
@ -21,6 +21,18 @@ func NewTensor() Tensor {
|
|||
return Tensor{ctensor}
|
||||
}
|
||||
|
||||
func (ts Tensor) Dim() uint64 {
|
||||
return lib.AtDim(ts.ctensor)
|
||||
}
|
||||
|
||||
func (ts Tensor) Size() {
|
||||
dim := lib.AtDim(ts.ctensor)
|
||||
sz := []int64{int64(dim)}
|
||||
lib.AtShape(ts.ctensor, sz)
|
||||
fmt.Printf("sz val:%v", sz)
|
||||
// return lib.AtShape(ts.ctensor, sz)
|
||||
}
|
||||
|
||||
// FOfSlice creates tensor from a slice data
|
||||
func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor, err error) {
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user