WIP(wrapper/tensor): added more Tensor methods

This commit is contained in:
sugarme 2020-06-06 18:12:42 +10:00
parent 52643de1da
commit 1c1122c4ea
4 changed files with 139 additions and 12 deletions

43
example/tensor1/main.go Normal file
View File

@ -0,0 +1,43 @@
package main
import (
"fmt"
"log"
wrapper "github.com/sugarme/gotch/wrapper"
)
func main() {
ts, err := wrapper.OfSlice([]float64{1.3, 29.7})
if err != nil {
log.Fatal(err)
}
res, err := ts.Float64Value([]int64{1})
if err != nil {
log.Fatal(err)
}
fmt.Println(res)
resInt64, err := ts.Int64Value([]int64{1})
if err != nil {
log.Fatal(err)
}
fmt.Println(resInt64)
grad, err := ts.RequiresGrad()
if err != nil {
log.Fatal(err)
}
fmt.Printf("Requires Grad: %v\n", grad)
ele1, err := ts.DataPtr()
if err != nil {
log.Fatal(err)
}
fmt.Printf("First element address: %v\n", ele1)
}

View File

@ -89,3 +89,27 @@ func AtcSetBenchmarkCudnn(b int) {
cb := *(*C.int)(unsafe.Pointer(&b))
C.atc_set_benchmark_cudnn(cb)
}
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
func AtDoubleValueAtIndexes(ts Ctensor, indexes unsafe.Pointer, indexesLen int) float64 {
ctensor := (C.tensor)(ts)
cindexes := (*C.long)(indexes)
cindexesLen := *(*C.int)(unsafe.Pointer(&indexesLen))
retVal := C.at_double_value_at_indexes(ctensor, cindexes, cindexesLen)
return *(*float64)(unsafe.Pointer(&retVal))
}
// int64_t at_int64_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
func AtInt64ValueAtIndexes(ts Ctensor, indexes unsafe.Pointer, indexesLen int) int64 {
ctensor := (C.tensor)(ts)
cindexes := (*C.long)(indexes)
cindexesLen := *(*C.int)(unsafe.Pointer(&indexesLen))
retVal := C.at_int64_value_at_indexes(ctensor, cindexes, cindexesLen)
return *(*int64)(unsafe.Pointer(&retVal))
}
// int at_requires_grad(tensor);
func AtRequiresGrad(ts Ctensor) bool {
retVal := C.at_requires_grad((C.tensor)(ts))
return *(*bool)(unsafe.Pointer(&retVal))
}

View File

@ -27,18 +27,6 @@ func (ts Tensor) To(device gt.Device) (retVal Tensor, err error) {
return Tensor{ctensor: *ptr}, nil
}
func (ts Tensor) Device() (retVal gt.Device, err error) {
cInt := lib.AtDevice(ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
var device gt.Device
return device.OfCInt(int32(cInt)), nil
}
func (ts Tensor) Matmul(other Tensor) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))

View File

@ -253,6 +253,18 @@ func (ts Tensor) DType() gotch.DType {
return dtype
}
func (ts Tensor) Device() (retVal gotch.Device, err error) {
cInt := lib.AtDevice(ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
var device gotch.Device
return device.OfCInt(int32(cInt)), nil
}
func (ts Tensor) Eq1(other Tensor) (retVal Tensor, err error) {
// Get a C null pointer
@ -268,3 +280,63 @@ func (ts Tensor) Eq1(other Tensor) (retVal Tensor, err error) {
return Tensor{ctensor: *ptr}, nil
}
// DoubleValue returns a float value on tensors holding a single element.
// An error is returned otherwise.
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
func (ts Tensor) Float64Value(idx []int64) (retVal float64, err error) {
idxPtr, err := DataAsPtr(idx)
if err != nil {
return retVal, err
}
defer C.free(unsafe.Pointer(idxPtr))
retVal = lib.AtDoubleValueAtIndexes(ts.ctensor, idxPtr, len(idx))
if err = TorchErr(); err != nil {
return retVal, err
}
return retVal, err
}
// Int64Value returns an int value on tensors holding a single element. An error is
// returned otherwise.
func (ts Tensor) Int64Value(idx []int64) (retVal int64, err error) {
idxPtr, err := DataAsPtr(idx)
if err != nil {
return retVal, err
}
defer C.free(unsafe.Pointer(idxPtr))
retVal = lib.AtInt64ValueAtIndexes(ts.ctensor, idxPtr, len(idx))
if err = TorchErr(); err != nil {
return retVal, err
}
return retVal, err
}
// RequiresGrad returns true if gradient are currently tracked for this tensor.
func (ts Tensor) RequiresGrad() (retVal bool, err error) {
retVal = lib.AtRequiresGrad(ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
return retVal, nil
}
// DataPtr returns the address of the first element of this tensor.
func (ts Tensor) DataPtr() (retVal unsafe.Pointer, err error) {
retVal = lib.AtDataPtr(ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
return retVal, nil
}