feat(tensor): unit test

This commit is contained in:
sugarme 2020-07-11 16:42:27 +10:00
parent 8b05753eb4
commit 600adf506a
4 changed files with 145 additions and 11 deletions

View File

@ -270,6 +270,16 @@ func AtgDiv(ptr *Ctensor, self Ctensor, other Ctensor) {
C.atg_div(ptr, self, other)
}
// void atg_div_(tensor *, tensor self, tensor other);
func AtgDiv_(ptr *Ctensor, self Ctensor, other Ctensor) {
C.atg_div_(ptr, self, other)
}
// void atg_div_1(tensor *, tensor self, scalar other);
func AtgDiv1_(ptr *Ctensor, self Ctensor, other Cscalar) {
C.atg_div_1(ptr, self, other)
}
// void atg_randperm(tensor *, int64_t n, int options_kind, int options_device);
func AtgRandperm(ptr *Ctensor, n int64, optionKind int32, optionDevice int32) {
cn := *(*C.int64_t)(unsafe.Pointer(&n))
@ -358,6 +368,11 @@ func AtgSub_(ptr *Ctensor, self Ctensor, other Ctensor) {
C.atg_sub_(ptr, self, other)
}
// void atg_sub_1(tensor *, tensor self, scalar other);
func AtgSub1_(ptr *Ctensor, self Ctensor, other Cscalar) {
C.atg_sub_1(ptr, self, other)
}
// void atg_conv1d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int64_t groups);
func AtgConv1d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, dilationData []int64, dilationLen int, groups int64) {
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))

View File

@ -3,31 +3,37 @@ package tensor
import (
"fmt"
"log"
"reflect"
"github.com/sugarme/gotch"
)
type Iterator interface {
Next() interface{}
Next() (item interface{}, ok bool)
}
type Iterable struct {
Index int64
Len int64
Content Tensor
ItemKind reflect.Kind
ItemKind gotch.DType
}
// Next implements Iterator interface
func (it *Iterable) Next() (retVal interface{}) {
func (it *Iterable) Next() (retVal interface{}, ok bool) {
if it.Index == it.Len {
return retVal, false
}
var err error
switch it.ItemKind {
case reflect.Int64:
switch it.ItemKind.Kind().String() {
case "int64":
retVal, err = it.Content.Int64Value([]int64{it.Index})
if err != nil {
log.Fatal(err)
}
it.Index += 1
case reflect.Float64:
case "float64":
retVal, err = it.Content.Float64Value([]int64{it.Index})
if err != nil {
log.Fatal(err)
@ -38,24 +44,25 @@ func (it *Iterable) Next() (retVal interface{}) {
log.Fatal(err)
}
return retVal
return retVal, true
}
// Iter creates an iterable object with specified item type.
func (ts Tensor) Iter(kind reflect.Kind) (retVal Iterable, err error) {
func (ts Tensor) Iter(dtype gotch.DType) (retVal Iterable, err error) {
num, err := ts.Size1() // size for 1D tensor
if err != nil {
return retVal, err
}
content, err := ts.ShallowClone()
tmp, err := ts.ShallowClone()
if err != nil {
return retVal, err
}
content := tmp.MustTotype(dtype, true)
return Iterable{
Index: 0,
Len: num,
Content: content,
ItemKind: kind,
ItemKind: dtype,
}, nil
}

View File

@ -883,6 +883,22 @@ func (ts Tensor) MustDiv(other Tensor, del bool) (retVal Tensor) {
return retVal
}
func (ts Tensor) Div_(other Tensor) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgDiv_(ptr, ts.ctensor, other.ctensor)
if err := TorchErr(); err != nil {
log.Fatal(err)
}
}
func (ts Tensor) Div1_(other Scalar) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgDiv1_(ptr, ts.ctensor, other.cscalar)
if err := TorchErr(); err != nil {
log.Fatal(err)
}
}
func Randperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
@ -1221,6 +1237,15 @@ func (ts Tensor) Sub_(other Tensor) {
}
}
func (ts Tensor) Sub1_(other Scalar) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
lib.AtgSub1_(ptr, ts.ctensor, other.cscalar)
err := TorchErr()
if err != nil {
log.Fatal(err)
}
}
func Conv1D(input, weight, bias Tensor, stride, padding, dilation []int64, groups int64) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))

87
tensor/tensor_test.go Normal file
View File

@ -0,0 +1,87 @@
package tensor_test
import (
"reflect"
"testing"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
func TestInplaceAssign(t *testing.T) {
tensor := ts.MustOfSlice([]int64{3, 1, 4, 1, 5})
tensor.Add1_(ts.IntScalar(1))
tensor.Mul1_(ts.IntScalar(2))
tensor.Sub1_(ts.IntScalar(1))
want := []int64{7, 3, 9, 3, 11}
got := tensor.Vals()
if !reflect.DeepEqual(want, got) {
t.Errorf("Expected tensor values: %v\n", want)
t.Errorf("Got tensor values: %v\n", got)
}
}
func TestConstantOp(t *testing.T) {
tensor := ts.MustOfSlice([]int64{3, 9, 3, 11})
resTs1 := tensor.MustMul1(ts.IntScalar(-1), true)
want1 := []int64{-3, -9, -3, -11}
got1 := resTs1.Vals()
if !reflect.DeepEqual(want1, got1) {
t.Errorf("Expected tensor values: %v\n", want1)
t.Errorf("Got tensor values: %v\n", got1)
}
// TODO: more ops
}
func TestIter(t *testing.T) {
tensor := ts.MustOfSlice([]int64{3, 9, 3, 11})
iter, err := tensor.Iter(gotch.Int64)
if err != nil {
panic(err)
}
want := []int64{3, 9, 3, 11}
var got []int64
for {
item, ok := iter.Next()
if !ok {
break
}
got = append(got, item.(int64))
}
if !reflect.DeepEqual(want, got) {
t.Errorf("Expected tensor values: %v\n", want)
t.Errorf("Got tensor values: %v\n", got)
}
tensor1 := ts.MustOfSlice([]float64{3.14, 15.926, 5.3589, 79.0})
iter1, err := tensor1.Iter(gotch.Double)
want1 := []float64{3.14, 15.926, 5.3589, 79.0}
var got1 []float64
for {
item, ok := iter1.Next()
if !ok {
break
}
got1 = append(got1, item.(float64))
}
if !reflect.DeepEqual(want1, got1) {
t.Errorf("Expected tensor values: %v\n", want1)
t.Errorf("Got tensor values: %v\n", got1)
}
}
// TODO: more tests