WIP(wrapper/tensor): added more Tensor methods
This commit is contained in:
parent
52643de1da
commit
1c1122c4ea
43
example/tensor1/main.go
Normal file
43
example/tensor1/main.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
wrapper "github.com/sugarme/gotch/wrapper"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
||||
ts, err := wrapper.OfSlice([]float64{1.3, 29.7})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
res, err := ts.Float64Value([]int64{1})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println(res)
|
||||
|
||||
resInt64, err := ts.Int64Value([]int64{1})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println(resInt64)
|
||||
|
||||
grad, err := ts.RequiresGrad()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Requires Grad: %v\n", grad)
|
||||
|
||||
ele1, err := ts.DataPtr()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
fmt.Printf("First element address: %v\n", ele1)
|
||||
}
|
|
@ -89,3 +89,27 @@ func AtcSetBenchmarkCudnn(b int) {
|
|||
cb := *(*C.int)(unsafe.Pointer(&b))
|
||||
C.atc_set_benchmark_cudnn(cb)
|
||||
}
|
||||
|
||||
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
|
||||
func AtDoubleValueAtIndexes(ts Ctensor, indexes unsafe.Pointer, indexesLen int) float64 {
|
||||
ctensor := (C.tensor)(ts)
|
||||
cindexes := (*C.long)(indexes)
|
||||
cindexesLen := *(*C.int)(unsafe.Pointer(&indexesLen))
|
||||
retVal := C.at_double_value_at_indexes(ctensor, cindexes, cindexesLen)
|
||||
return *(*float64)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// int64_t at_int64_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
|
||||
func AtInt64ValueAtIndexes(ts Ctensor, indexes unsafe.Pointer, indexesLen int) int64 {
|
||||
ctensor := (C.tensor)(ts)
|
||||
cindexes := (*C.long)(indexes)
|
||||
cindexesLen := *(*C.int)(unsafe.Pointer(&indexesLen))
|
||||
retVal := C.at_int64_value_at_indexes(ctensor, cindexes, cindexesLen)
|
||||
return *(*int64)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
||||
// int at_requires_grad(tensor);
|
||||
func AtRequiresGrad(ts Ctensor) bool {
|
||||
retVal := C.at_requires_grad((C.tensor)(ts))
|
||||
return *(*bool)(unsafe.Pointer(&retVal))
|
||||
}
|
||||
|
|
|
@ -27,18 +27,6 @@ func (ts Tensor) To(device gt.Device) (retVal Tensor, err error) {
|
|||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) Device() (retVal gt.Device, err error) {
|
||||
cInt := lib.AtDevice(ts.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
var device gt.Device
|
||||
|
||||
return device.OfCInt(int32(cInt)), nil
|
||||
}
|
||||
|
||||
func (ts Tensor) Matmul(other Tensor) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
|
|
@ -253,6 +253,18 @@ func (ts Tensor) DType() gotch.DType {
|
|||
return dtype
|
||||
}
|
||||
|
||||
func (ts Tensor) Device() (retVal gotch.Device, err error) {
|
||||
cInt := lib.AtDevice(ts.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
var device gotch.Device
|
||||
|
||||
return device.OfCInt(int32(cInt)), nil
|
||||
}
|
||||
|
||||
func (ts Tensor) Eq1(other Tensor) (retVal Tensor, err error) {
|
||||
|
||||
// Get a C null pointer
|
||||
|
@ -268,3 +280,63 @@ func (ts Tensor) Eq1(other Tensor) (retVal Tensor, err error) {
|
|||
return Tensor{ctensor: *ptr}, nil
|
||||
|
||||
}
|
||||
|
||||
// DoubleValue returns a float value on tensors holding a single element.
|
||||
// An error is returned otherwise.
|
||||
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
|
||||
func (ts Tensor) Float64Value(idx []int64) (retVal float64, err error) {
|
||||
|
||||
idxPtr, err := DataAsPtr(idx)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
defer C.free(unsafe.Pointer(idxPtr))
|
||||
|
||||
retVal = lib.AtDoubleValueAtIndexes(ts.ctensor, idxPtr, len(idx))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
// Int64Value returns an int value on tensors holding a single element. An error is
|
||||
// returned otherwise.
|
||||
func (ts Tensor) Int64Value(idx []int64) (retVal int64, err error) {
|
||||
|
||||
idxPtr, err := DataAsPtr(idx)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
defer C.free(unsafe.Pointer(idxPtr))
|
||||
|
||||
retVal = lib.AtInt64ValueAtIndexes(ts.ctensor, idxPtr, len(idx))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
// RequiresGrad returns true if gradient are currently tracked for this tensor.
|
||||
func (ts Tensor) RequiresGrad() (retVal bool, err error) {
|
||||
retVal = lib.AtRequiresGrad(ts.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// DataPtr returns the address of the first element of this tensor.
|
||||
func (ts Tensor) DataPtr() (retVal unsafe.Pointer, err error) {
|
||||
|
||||
retVal = lib.AtDataPtr(ts.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user