feat(error result): added Go return pattern of (result, error)
This commit is contained in:
parent
ff9ae65229
commit
52643de1da
|
@ -12,8 +12,8 @@ func main() {
|
|||
// Try to compare 2 tensor with incompatible dimensions
|
||||
// and check this returns an error
|
||||
dx := []int32{1, 2, 3}
|
||||
// dy := []int32{1, 2, 3, 4}
|
||||
dy := []int32{1, 2, 5}
|
||||
dy := []int32{1, 2, 3, 4}
|
||||
// dy := []int32{1, 2, 5}
|
||||
|
||||
xs, err := wrapper.OfSlice(dx)
|
||||
if err != nil {
|
||||
|
@ -30,10 +30,24 @@ func main() {
|
|||
fmt.Printf("xs num of dimensions: %v\n", xs.Dim())
|
||||
fmt.Printf("ys num of dimensions: %v\n", ys.Dim())
|
||||
|
||||
fmt.Printf("xs shape: %v\n", xs.Size())
|
||||
fmt.Printf("ys shape: %v\n", ys.Size())
|
||||
xsize, err := xs.Size()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
xs.Eq1(*ys)
|
||||
ysize, err := ys.Size()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("xs shape: %v\n", xsize)
|
||||
fmt.Printf("ys shape: %v\n", ysize)
|
||||
|
||||
res, err := xs.Eq1(ys)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
res.Print()
|
||||
|
||||
// xs.Matmul(*ys)
|
||||
}
|
||||
|
|
|
@ -74,8 +74,14 @@ func main() {
|
|||
device := gotch.NewCuda()
|
||||
startGPUTime := time.Now()
|
||||
for i := 1; i < 100000; i++ {
|
||||
cx := xs.To(device)
|
||||
cy := ys.To(device)
|
||||
cx, err := xs.To(device)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
cy, err := ys.To(device)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
cx.Matmul(cy)
|
||||
}
|
||||
|
||||
|
|
|
@ -5,15 +5,13 @@ package wrapper
|
|||
import "C"
|
||||
|
||||
import (
|
||||
// "fmt"
|
||||
"log"
|
||||
"unsafe"
|
||||
|
||||
gt "github.com/sugarme/gotch"
|
||||
lib "github.com/sugarme/gotch/libtch"
|
||||
)
|
||||
|
||||
func (ts Tensor) To(device gt.Device) Tensor {
|
||||
func (ts Tensor) To(device gt.Device) (retVal Tensor, err error) {
|
||||
|
||||
// TODO: how to get pointer to CUDA memory???
|
||||
// C.cuMemAlloc((*C.ulonglong)(cudaPtr), 1) // 0 byte is invalid
|
||||
|
@ -22,21 +20,33 @@ func (ts Tensor) To(device gt.Device) Tensor {
|
|||
|
||||
lib.AtgTo((*lib.Ctensor)(ptr), ts.ctensor, int(device.CInt()))
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return Tensor{ctensor: *ptr}
|
||||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) Device() gt.Device {
|
||||
func (ts Tensor) Device() (retVal gt.Device, err error) {
|
||||
cInt := lib.AtDevice(ts.ctensor)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
var device gt.Device
|
||||
|
||||
return device.OfCInt(int32(cInt))
|
||||
return device.OfCInt(int32(cInt)), nil
|
||||
}
|
||||
|
||||
func (ts Tensor) Matmul(other Tensor) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgMatmul(ptr, ts.ctensor, other.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
package wrapper
|
||||
|
||||
// #include <stdlib.h>
|
||||
//#include "stdbool.h"
|
||||
// #include "../libtch/torch_api.h"
|
||||
// #include "stdlib.h"
|
||||
import "C"
|
||||
|
||||
import (
|
||||
|
@ -40,27 +38,34 @@ func (ts Tensor) Dim() uint64 {
|
|||
// NOTE: C++ libtorch calls at_shape() -> t.sizes()
|
||||
// And returns a slice of sizes or shape using given pointer
|
||||
// to that slice.
|
||||
func (ts Tensor) Size() []int64 {
|
||||
func (ts Tensor) Size() (retVal []int64, err error) {
|
||||
dim := lib.AtDim(ts.ctensor)
|
||||
sz := make([]int64, dim)
|
||||
szPtr, err := DataAsPtr(sz)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
// TODO: should we free C memory here or at `DataAsPtr` func
|
||||
defer C.free(unsafe.Pointer(szPtr))
|
||||
|
||||
lib.AtShape(ts.ctensor, szPtr)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal := decodeSize(szPtr, dim)
|
||||
return retVal
|
||||
} // Size1 returns the tensor size for 1D tensors.
|
||||
retVal = decodeSize(szPtr, dim)
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// Size1 returns the tensor size for 1D tensors.
|
||||
func (ts Tensor) Size1() (retVal int64, err error) {
|
||||
shape := ts.Size()
|
||||
shape, err := ts.Size()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
if len(shape) != 1 {
|
||||
err = fmt.Errorf("Expected one dim, got %v\n", len(shape))
|
||||
return 0, err
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return shape[0], nil
|
||||
|
@ -68,10 +73,14 @@ func (ts Tensor) Size1() (retVal int64, err error) {
|
|||
|
||||
// Size2 returns the tensor size for 2D tensors.
|
||||
func (ts Tensor) Size2() (retVal []int64, err error) {
|
||||
shape := ts.Size()
|
||||
shape, err := ts.Size()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
if len(shape) != 2 {
|
||||
err = fmt.Errorf("Expected two dims, got %v\n", len(shape))
|
||||
return nil, err
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return shape, nil
|
||||
|
@ -79,10 +88,14 @@ func (ts Tensor) Size2() (retVal []int64, err error) {
|
|||
|
||||
// Size3 returns the tensor size for 3D tensors.
|
||||
func (ts Tensor) Size3() (retVal []int64, err error) {
|
||||
shape := ts.Size()
|
||||
shape, err := ts.Size()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
if len(shape) != 3 {
|
||||
err = fmt.Errorf("Expected three dims, got %v\n", len(shape))
|
||||
return nil, err
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return shape, nil
|
||||
|
@ -90,10 +103,14 @@ func (ts Tensor) Size3() (retVal []int64, err error) {
|
|||
|
||||
// Size4 returns the tensor size for 4D tensors.
|
||||
func (ts Tensor) Size4() (retVal []int64, err error) {
|
||||
shape := ts.Size()
|
||||
shape, err := ts.Size()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
if len(shape) != 4 {
|
||||
err = fmt.Errorf("Expected four dims, got %v\n", len(shape))
|
||||
return nil, err
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return shape, nil
|
||||
|
@ -119,26 +136,17 @@ func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
|
|||
return dataIn
|
||||
}
|
||||
|
||||
// Size1 returns the tensor size for single dimension tensor
|
||||
// func (ts Tensor) Size1() {
|
||||
//
|
||||
// shape := ts.Size()
|
||||
//
|
||||
// fmt.Printf("shape: %v\n", shape)
|
||||
//
|
||||
// }
|
||||
|
||||
// OfSlice creates tensor from a slice data
|
||||
func OfSlice(data interface{}) (retVal *Tensor, err error) {
|
||||
func OfSlice(data interface{}) (retVal Tensor, err error) {
|
||||
|
||||
typ, dataLen, err := DataCheck(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
dtype, err := gotch.ToDType(typ)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
shape := []int64{int64(dataLen)}
|
||||
|
@ -146,25 +154,29 @@ func OfSlice(data interface{}) (retVal *Tensor, err error) {
|
|||
|
||||
eltSizeInBytes, err := gotch.DTypeSize(dtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
nbytes := int(eltSizeInBytes) * int(elementNum)
|
||||
|
||||
dataPtr, buff := CMalloc(nbytes)
|
||||
defer C.free(unsafe.Pointer(dataPtr))
|
||||
|
||||
if err = EncodeTensor(buff, reflect.ValueOf(data), shape); err != nil {
|
||||
return nil, err
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
cint, err := gotch.DType2CInt(dtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = &Tensor{ctensor}
|
||||
retVal = Tensor{ctensor}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
@ -175,6 +187,9 @@ func OfSlice(data interface{}) (retVal *Tensor, err error) {
|
|||
// with no truncation at all.
|
||||
func (ts Tensor) Print() {
|
||||
lib.AtPrint(ts.ctensor)
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// NewTensorFromData creates tensor from given data and shape
|
||||
|
@ -194,6 +209,7 @@ func NewTensorFromData(data interface{}, shape []int64) (retVal Tensor, err erro
|
|||
|
||||
// 2. Write raw data to C memory and get C pointer
|
||||
dataPtr, err := DataAsPtr(data)
|
||||
defer C.free(unsafe.Pointer(dataPtr))
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
@ -215,6 +231,10 @@ func NewTensorFromData(data interface{}, shape []int64) (retVal Tensor, err erro
|
|||
}
|
||||
|
||||
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
|
||||
// defer C.free(unsafe.Pointer(ctensor))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor}
|
||||
|
||||
|
@ -233,31 +253,18 @@ func (ts Tensor) DType() gotch.DType {
|
|||
return dtype
|
||||
}
|
||||
|
||||
func (ts Tensor) Eq1(other Tensor) {
|
||||
func (ts Tensor) Eq1(other Tensor) (retVal Tensor, err error) {
|
||||
|
||||
// Get a C null pointer
|
||||
// https://stackoverflow.com/a/2022369
|
||||
// ctensorPtr := C.malloc(0)
|
||||
ctensorPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
// lib.Atg_eq1(unsafe.Pointer(ctensorPtr), ts.ctensor, other.ctensor)
|
||||
lib.AtgEq1(ctensorPtr, ts.ctensor, other.ctensor)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
lib.AtgEq1(ptr, ts.ctensor, other.ctensor)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
lib.AtPrint(*ctensorPtr)
|
||||
}
|
||||
|
||||
func (ts Tensor) Matmul(other Tensor) Tensor {
|
||||
ctensorPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ctensorPtr))
|
||||
lib.AtgMatmul(ctensorPtr, ts.ctensor, other.ctensor)
|
||||
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return Tensor{ctensor: *ctensorPtr}
|
||||
return Tensor{ctensor: *ptr}, nil
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user