tensor/dim

This commit is contained in:
sugarme 2020-06-01 17:37:05 +10:00
parent d4beb985e0
commit 3b219ec1e0
3 changed files with 30 additions and 1 deletions

View File

@ -1,7 +1,7 @@
package main
import (
// "fmt"
"fmt"
"log"
// gotch "github.com/sugarme/gotch"
@ -34,6 +34,10 @@ func main() {
ts.Print()
fmt.Println(ts.Dim())
ts.Size()
// typ, count, err := wrapper.DataCheck(data)
// if err != nil {
// log.Fatal(err)

View File

@ -62,3 +62,16 @@ func AtDataPtr(t *C_tensor) unsafe.Pointer {
cTensor := (C.tensor)((*t).private)
return C.at_data_ptr(cTensor)
}
func AtDim(t *C_tensor) uint64 {
cTensor := (C.tensor)((*t).private)
cdim := C.at_dim(cTensor)
return *(*uint64)(unsafe.Pointer(&cdim))
}
func AtShape(t *C_tensor, sz []int64) {
cTensor := (C.tensor)((*t).private)
// just get pointer of the first element
csz := (*C.int64_t)(unsafe.Pointer(&sz[0]))
C.at_shape(cTensor, csz)
}

View File

@ -21,6 +21,18 @@ func NewTensor() Tensor {
return Tensor{ctensor}
}
func (ts Tensor) Dim() uint64 {
return lib.AtDim(ts.ctensor)
}
func (ts Tensor) Size() {
dim := lib.AtDim(ts.ctensor)
sz := []int64{int64(dim)}
lib.AtShape(ts.ctensor, sz)
fmt.Printf("sz val:%v", sz)
// return lib.AtShape(ts.ctensor, sz)
}
// FOfSlice creates tensor from a slice data
func (ts Tensor) FOfSlice(data interface{}, dtype gotch.DType) (retVal *Tensor, err error) {