feat(error result): added Go return pattern of (result, error)

This commit is contained in:
sugarme 2020-06-06 13:20:00 +10:00
parent ff9ae65229
commit 52643de1da
4 changed files with 110 additions and 73 deletions

View File

@ -12,8 +12,8 @@ func main() {
// Try to compare 2 tensor with incompatible dimensions
// and check this returns an error
dx := []int32{1, 2, 3}
// dy := []int32{1, 2, 3, 4}
dy := []int32{1, 2, 5}
dy := []int32{1, 2, 3, 4}
// dy := []int32{1, 2, 5}
xs, err := wrapper.OfSlice(dx)
if err != nil {
@ -30,10 +30,24 @@ func main() {
fmt.Printf("xs num of dimensions: %v\n", xs.Dim())
fmt.Printf("ys num of dimensions: %v\n", ys.Dim())
fmt.Printf("xs shape: %v\n", xs.Size())
fmt.Printf("ys shape: %v\n", ys.Size())
xsize, err := xs.Size()
if err != nil {
log.Fatal(err)
}
xs.Eq1(*ys)
ysize, err := ys.Size()
if err != nil {
log.Fatal(err)
}
fmt.Printf("xs shape: %v\n", xsize)
fmt.Printf("ys shape: %v\n", ysize)
res, err := xs.Eq1(ys)
if err != nil {
log.Fatal(err)
}
res.Print()
// xs.Matmul(*ys)
}

View File

@ -74,8 +74,14 @@ func main() {
device := gotch.NewCuda()
startGPUTime := time.Now()
for i := 1; i < 100000; i++ {
cx := xs.To(device)
cy := ys.To(device)
cx, err := xs.To(device)
if err != nil {
log.Fatal(err)
}
cy, err := ys.To(device)
if err != nil {
log.Fatal(err)
}
cx.Matmul(cy)
}

View File

@ -5,15 +5,13 @@ package wrapper
import "C"
import (
// "fmt"
"log"
"unsafe"
gt "github.com/sugarme/gotch"
lib "github.com/sugarme/gotch/libtch"
)
func (ts Tensor) To(device gt.Device) Tensor {
func (ts Tensor) To(device gt.Device) (retVal Tensor, err error) {
// TODO: how to get pointer to CUDA memory???
// C.cuMemAlloc((*C.ulonglong)(cudaPtr), 1) // 0 byte is invalid
@ -22,21 +20,33 @@ func (ts Tensor) To(device gt.Device) Tensor {
lib.AtgTo((*lib.Ctensor)(ptr), ts.ctensor, int(device.CInt()))
if err := TorchErr(); err != nil {
log.Fatal(err)
if err = TorchErr(); err != nil {
return retVal, err
}
return Tensor{ctensor: *ptr}
return Tensor{ctensor: *ptr}, nil
}
func (ts Tensor) Device() gt.Device {
func (ts Tensor) Device() (retVal gt.Device, err error) {
cInt := lib.AtDevice(ts.ctensor)
if err := TorchErr(); err != nil {
log.Fatal(err)
if err = TorchErr(); err != nil {
return retVal, err
}
var device gt.Device
return device.OfCInt(int32(cInt))
return device.OfCInt(int32(cInt)), nil
}
func (ts Tensor) Matmul(other Tensor) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgMatmul(ptr, ts.ctensor, other.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
return Tensor{ctensor: *ptr}, nil
}

View File

@ -1,8 +1,6 @@
package wrapper
// #include <stdlib.h>
//#include "stdbool.h"
// #include "../libtch/torch_api.h"
// #include "stdlib.h"
import "C"
import (
@ -40,27 +38,34 @@ func (ts Tensor) Dim() uint64 {
// NOTE: C++ libtorch calls at_shape() -> t.sizes()
// And returns a slice of sizes or shape using given pointer
// to that slice.
func (ts Tensor) Size() []int64 {
func (ts Tensor) Size() (retVal []int64, err error) {
dim := lib.AtDim(ts.ctensor)
sz := make([]int64, dim)
szPtr, err := DataAsPtr(sz)
if err != nil {
log.Fatal(err)
return retVal, err
}
// TODO: should we free C memory here or at `DataAsPtr` func
defer C.free(unsafe.Pointer(szPtr))
lib.AtShape(ts.ctensor, szPtr)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal := decodeSize(szPtr, dim)
return retVal
} // Size1 returns the tensor size for 1D tensors.
retVal = decodeSize(szPtr, dim)
return retVal, nil
}
// Size1 returns the tensor size for 1D tensors.
func (ts Tensor) Size1() (retVal int64, err error) {
shape := ts.Size()
shape, err := ts.Size()
if err != nil {
return retVal, err
}
if len(shape) != 1 {
err = fmt.Errorf("Expected one dim, got %v\n", len(shape))
return 0, err
return retVal, err
}
return shape[0], nil
@ -68,10 +73,14 @@ func (ts Tensor) Size1() (retVal int64, err error) {
// Size2 returns the tensor size for 2D tensors.
func (ts Tensor) Size2() (retVal []int64, err error) {
shape := ts.Size()
shape, err := ts.Size()
if err != nil {
return retVal, err
}
if len(shape) != 2 {
err = fmt.Errorf("Expected two dims, got %v\n", len(shape))
return nil, err
return retVal, err
}
return shape, nil
@ -79,10 +88,14 @@ func (ts Tensor) Size2() (retVal []int64, err error) {
// Size3 returns the tensor size for 3D tensors.
func (ts Tensor) Size3() (retVal []int64, err error) {
shape := ts.Size()
shape, err := ts.Size()
if err != nil {
return retVal, err
}
if len(shape) != 3 {
err = fmt.Errorf("Expected three dims, got %v\n", len(shape))
return nil, err
return retVal, err
}
return shape, nil
@ -90,10 +103,14 @@ func (ts Tensor) Size3() (retVal []int64, err error) {
// Size4 returns the tensor size for 4D tensors.
func (ts Tensor) Size4() (retVal []int64, err error) {
shape := ts.Size()
shape, err := ts.Size()
if err != nil {
return retVal, err
}
if len(shape) != 4 {
err = fmt.Errorf("Expected four dims, got %v\n", len(shape))
return nil, err
return retVal, err
}
return shape, nil
@ -119,26 +136,17 @@ func decodeSize(ptr unsafe.Pointer, nsize uint64) []int64 {
return dataIn
}
// Size1 returns the tensor size for single dimension tensor
// func (ts Tensor) Size1() {
//
// shape := ts.Size()
//
// fmt.Printf("shape: %v\n", shape)
//
// }
// OfSlice creates tensor from a slice data
func OfSlice(data interface{}) (retVal *Tensor, err error) {
func OfSlice(data interface{}) (retVal Tensor, err error) {
typ, dataLen, err := DataCheck(data)
if err != nil {
return nil, err
return retVal, err
}
dtype, err := gotch.ToDType(typ)
if err != nil {
return nil, err
return retVal, err
}
shape := []int64{int64(dataLen)}
@ -146,25 +154,29 @@ func OfSlice(data interface{}) (retVal *Tensor, err error) {
eltSizeInBytes, err := gotch.DTypeSize(dtype)
if err != nil {
return nil, err
return retVal, err
}
nbytes := int(eltSizeInBytes) * int(elementNum)
dataPtr, buff := CMalloc(nbytes)
defer C.free(unsafe.Pointer(dataPtr))
if err = EncodeTensor(buff, reflect.ValueOf(data), shape); err != nil {
return nil, err
return retVal, err
}
cint, err := gotch.DType2CInt(dtype)
if err != nil {
return nil, err
return retVal, err
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = &Tensor{ctensor}
retVal = Tensor{ctensor}
return retVal, nil
}
@ -175,6 +187,9 @@ func OfSlice(data interface{}) (retVal *Tensor, err error) {
// with no truncation at all.
func (ts Tensor) Print() {
lib.AtPrint(ts.ctensor)
if err := TorchErr(); err != nil {
log.Fatal(err)
}
}
// NewTensorFromData creates tensor from given data and shape
@ -194,6 +209,7 @@ func NewTensorFromData(data interface{}, shape []int64) (retVal Tensor, err erro
// 2. Write raw data to C memory and get C pointer
dataPtr, err := DataAsPtr(data)
defer C.free(unsafe.Pointer(dataPtr))
if err != nil {
return retVal, err
}
@ -215,6 +231,10 @@ func NewTensorFromData(data interface{}, shape []int64) (retVal Tensor, err erro
}
ctensor := lib.AtTensorOfData(dataPtr, shape, uint(len(shape)), uint(eltSizeInBytes), int(cint))
// defer C.free(unsafe.Pointer(ctensor))
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor}
@ -233,31 +253,18 @@ func (ts Tensor) DType() gotch.DType {
return dtype
}
func (ts Tensor) Eq1(other Tensor) {
func (ts Tensor) Eq1(other Tensor) (retVal Tensor, err error) {
// Get a C null pointer
// https://stackoverflow.com/a/2022369
// ctensorPtr := C.malloc(0)
ctensorPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
// lib.Atg_eq1(unsafe.Pointer(ctensorPtr), ts.ctensor, other.ctensor)
lib.AtgEq1(ctensorPtr, ts.ctensor, other.ctensor)
if err := TorchErr(); err != nil {
log.Fatal(err)
lib.AtgEq1(ptr, ts.ctensor, other.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
lib.AtPrint(*ctensorPtr)
}
func (ts Tensor) Matmul(other Tensor) Tensor {
ctensorPtr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ctensorPtr))
lib.AtgMatmul(ctensorPtr, ts.ctensor, other.ctensor)
if err := TorchErr(); err != nil {
log.Fatal(err)
}
return Tensor{ctensor: *ctensorPtr}
return Tensor{ctensor: *ptr}, nil
}