feat(tensor): unit test
This commit is contained in:
parent
8b05753eb4
commit
600adf506a
|
@ -270,6 +270,16 @@ func AtgDiv(ptr *Ctensor, self Ctensor, other Ctensor) {
|
|||
C.atg_div(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_div_(tensor *, tensor self, tensor other);
|
||||
func AtgDiv_(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||
C.atg_div_(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_div_1(tensor *, tensor self, scalar other);
|
||||
func AtgDiv1_(ptr *Ctensor, self Ctensor, other Cscalar) {
|
||||
C.atg_div_1(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_randperm(tensor *, int64_t n, int options_kind, int options_device);
|
||||
func AtgRandperm(ptr *Ctensor, n int64, optionKind int32, optionDevice int32) {
|
||||
cn := *(*C.int64_t)(unsafe.Pointer(&n))
|
||||
|
@ -358,6 +368,11 @@ func AtgSub_(ptr *Ctensor, self Ctensor, other Ctensor) {
|
|||
C.atg_sub_(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_sub_1(tensor *, tensor self, scalar other);
|
||||
func AtgSub1_(ptr *Ctensor, self Ctensor, other Cscalar) {
|
||||
C.atg_sub_1(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_conv1d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int64_t groups);
|
||||
func AtgConv1d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, dilationData []int64, dilationLen int, groups int64) {
|
||||
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
||||
|
|
|
@ -3,31 +3,37 @@ package tensor
|
|||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"reflect"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
)
|
||||
|
||||
type Iterator interface {
|
||||
Next() interface{}
|
||||
Next() (item interface{}, ok bool)
|
||||
}
|
||||
|
||||
type Iterable struct {
|
||||
Index int64
|
||||
Len int64
|
||||
Content Tensor
|
||||
ItemKind reflect.Kind
|
||||
ItemKind gotch.DType
|
||||
}
|
||||
|
||||
// Next implements Iterator interface
|
||||
func (it *Iterable) Next() (retVal interface{}) {
|
||||
func (it *Iterable) Next() (retVal interface{}, ok bool) {
|
||||
|
||||
if it.Index == it.Len {
|
||||
return retVal, false
|
||||
}
|
||||
|
||||
var err error
|
||||
switch it.ItemKind {
|
||||
case reflect.Int64:
|
||||
switch it.ItemKind.Kind().String() {
|
||||
case "int64":
|
||||
retVal, err = it.Content.Int64Value([]int64{it.Index})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
it.Index += 1
|
||||
case reflect.Float64:
|
||||
case "float64":
|
||||
retVal, err = it.Content.Float64Value([]int64{it.Index})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
|
@ -38,24 +44,25 @@ func (it *Iterable) Next() (retVal interface{}) {
|
|||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
return retVal, true
|
||||
}
|
||||
|
||||
// Iter creates an iterable object with specified item type.
|
||||
func (ts Tensor) Iter(kind reflect.Kind) (retVal Iterable, err error) {
|
||||
func (ts Tensor) Iter(dtype gotch.DType) (retVal Iterable, err error) {
|
||||
num, err := ts.Size1() // size for 1D tensor
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
content, err := ts.ShallowClone()
|
||||
tmp, err := ts.ShallowClone()
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
content := tmp.MustTotype(dtype, true)
|
||||
|
||||
return Iterable{
|
||||
Index: 0,
|
||||
Len: num,
|
||||
Content: content,
|
||||
ItemKind: kind,
|
||||
ItemKind: dtype,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -883,6 +883,22 @@ func (ts Tensor) MustDiv(other Tensor, del bool) (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Div_(other Tensor) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
lib.AtgDiv_(ptr, ts.ctensor, other.ctensor)
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (ts Tensor) Div1_(other Scalar) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
lib.AtgDiv1_(ptr, ts.ctensor, other.cscalar)
|
||||
if err := TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Randperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
|
@ -1221,6 +1237,15 @@ func (ts Tensor) Sub_(other Tensor) {
|
|||
}
|
||||
}
|
||||
|
||||
func (ts Tensor) Sub1_(other Scalar) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
lib.AtgSub1_(ptr, ts.ctensor, other.cscalar)
|
||||
err := TorchErr()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Conv1D(input, weight, bias Tensor, stride, padding, dilation []int64, groups int64) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
|
|
87
tensor/tensor_test.go
Normal file
87
tensor/tensor_test.go
Normal file
|
@ -0,0 +1,87 @@
|
|||
package tensor_test
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func TestInplaceAssign(t *testing.T) {
|
||||
tensor := ts.MustOfSlice([]int64{3, 1, 4, 1, 5})
|
||||
|
||||
tensor.Add1_(ts.IntScalar(1))
|
||||
tensor.Mul1_(ts.IntScalar(2))
|
||||
tensor.Sub1_(ts.IntScalar(1))
|
||||
|
||||
want := []int64{7, 3, 9, 3, 11}
|
||||
got := tensor.Vals()
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("Expected tensor values: %v\n", want)
|
||||
t.Errorf("Got tensor values: %v\n", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstantOp(t *testing.T) {
|
||||
tensor := ts.MustOfSlice([]int64{3, 9, 3, 11})
|
||||
resTs1 := tensor.MustMul1(ts.IntScalar(-1), true)
|
||||
|
||||
want1 := []int64{-3, -9, -3, -11}
|
||||
got1 := resTs1.Vals()
|
||||
if !reflect.DeepEqual(want1, got1) {
|
||||
t.Errorf("Expected tensor values: %v\n", want1)
|
||||
t.Errorf("Got tensor values: %v\n", got1)
|
||||
}
|
||||
|
||||
// TODO: more ops
|
||||
|
||||
}
|
||||
|
||||
func TestIter(t *testing.T) {
|
||||
|
||||
tensor := ts.MustOfSlice([]int64{3, 9, 3, 11})
|
||||
|
||||
iter, err := tensor.Iter(gotch.Int64)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
want := []int64{3, 9, 3, 11}
|
||||
var got []int64
|
||||
|
||||
for {
|
||||
item, ok := iter.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
got = append(got, item.(int64))
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("Expected tensor values: %v\n", want)
|
||||
t.Errorf("Got tensor values: %v\n", got)
|
||||
}
|
||||
|
||||
tensor1 := ts.MustOfSlice([]float64{3.14, 15.926, 5.3589, 79.0})
|
||||
iter1, err := tensor1.Iter(gotch.Double)
|
||||
|
||||
want1 := []float64{3.14, 15.926, 5.3589, 79.0}
|
||||
var got1 []float64
|
||||
for {
|
||||
item, ok := iter1.Next()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
got1 = append(got1, item.(float64))
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(want1, got1) {
|
||||
t.Errorf("Expected tensor values: %v\n", want1)
|
||||
t.Errorf("Got tensor values: %v\n", got1)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: more tests
|
Loading…
Reference in New Issue
Block a user