feat(tensor): unit test
This commit is contained in:
parent
8b05753eb4
commit
600adf506a
|
@ -270,6 +270,16 @@ func AtgDiv(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||||
C.atg_div(ptr, self, other)
|
C.atg_div(ptr, self, other)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// void atg_div_(tensor *, tensor self, tensor other);
|
||||||
|
func AtgDiv_(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||||
|
C.atg_div_(ptr, self, other)
|
||||||
|
}
|
||||||
|
|
||||||
|
// void atg_div_1(tensor *, tensor self, scalar other);
|
||||||
|
func AtgDiv1_(ptr *Ctensor, self Ctensor, other Cscalar) {
|
||||||
|
C.atg_div_1(ptr, self, other)
|
||||||
|
}
|
||||||
|
|
||||||
// void atg_randperm(tensor *, int64_t n, int options_kind, int options_device);
|
// void atg_randperm(tensor *, int64_t n, int options_kind, int options_device);
|
||||||
func AtgRandperm(ptr *Ctensor, n int64, optionKind int32, optionDevice int32) {
|
func AtgRandperm(ptr *Ctensor, n int64, optionKind int32, optionDevice int32) {
|
||||||
cn := *(*C.int64_t)(unsafe.Pointer(&n))
|
cn := *(*C.int64_t)(unsafe.Pointer(&n))
|
||||||
|
@ -358,6 +368,11 @@ func AtgSub_(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||||
C.atg_sub_(ptr, self, other)
|
C.atg_sub_(ptr, self, other)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// void atg_sub_1(tensor *, tensor self, scalar other);
|
||||||
|
func AtgSub1_(ptr *Ctensor, self Ctensor, other Cscalar) {
|
||||||
|
C.atg_sub_1(ptr, self, other)
|
||||||
|
}
|
||||||
|
|
||||||
// void atg_conv1d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int64_t groups);
|
// void atg_conv1d(tensor *, tensor input, tensor weight, tensor bias, int64_t *stride_data, int stride_len, int64_t *padding_data, int padding_len, int64_t *dilation_data, int dilation_len, int64_t groups);
|
||||||
func AtgConv1d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, dilationData []int64, dilationLen int, groups int64) {
|
func AtgConv1d(ptr *Ctensor, input Ctensor, weight Ctensor, bias Ctensor, strideData []int64, strideLen int, paddingData []int64, paddingLen int, dilationData []int64, dilationLen int, groups int64) {
|
||||||
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
cstrideDataPtr := (*C.int64_t)(unsafe.Pointer(&strideData[0]))
|
||||||
|
|
|
@ -3,31 +3,37 @@ package tensor
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"reflect"
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Iterator interface {
|
type Iterator interface {
|
||||||
Next() interface{}
|
Next() (item interface{}, ok bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Iterable struct {
|
type Iterable struct {
|
||||||
Index int64
|
Index int64
|
||||||
Len int64
|
Len int64
|
||||||
Content Tensor
|
Content Tensor
|
||||||
ItemKind reflect.Kind
|
ItemKind gotch.DType
|
||||||
}
|
}
|
||||||
|
|
||||||
// Next implements Iterator interface
|
// Next implements Iterator interface
|
||||||
func (it *Iterable) Next() (retVal interface{}) {
|
func (it *Iterable) Next() (retVal interface{}, ok bool) {
|
||||||
|
|
||||||
|
if it.Index == it.Len {
|
||||||
|
return retVal, false
|
||||||
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
switch it.ItemKind {
|
switch it.ItemKind.Kind().String() {
|
||||||
case reflect.Int64:
|
case "int64":
|
||||||
retVal, err = it.Content.Int64Value([]int64{it.Index})
|
retVal, err = it.Content.Int64Value([]int64{it.Index})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
it.Index += 1
|
it.Index += 1
|
||||||
case reflect.Float64:
|
case "float64":
|
||||||
retVal, err = it.Content.Float64Value([]int64{it.Index})
|
retVal, err = it.Content.Float64Value([]int64{it.Index})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
|
@ -38,24 +44,25 @@ func (it *Iterable) Next() (retVal interface{}) {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return retVal
|
return retVal, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Iter creates an iterable object with specified item type.
|
// Iter creates an iterable object with specified item type.
|
||||||
func (ts Tensor) Iter(kind reflect.Kind) (retVal Iterable, err error) {
|
func (ts Tensor) Iter(dtype gotch.DType) (retVal Iterable, err error) {
|
||||||
num, err := ts.Size1() // size for 1D tensor
|
num, err := ts.Size1() // size for 1D tensor
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return retVal, err
|
return retVal, err
|
||||||
}
|
}
|
||||||
content, err := ts.ShallowClone()
|
tmp, err := ts.ShallowClone()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return retVal, err
|
return retVal, err
|
||||||
}
|
}
|
||||||
|
content := tmp.MustTotype(dtype, true)
|
||||||
|
|
||||||
return Iterable{
|
return Iterable{
|
||||||
Index: 0,
|
Index: 0,
|
||||||
Len: num,
|
Len: num,
|
||||||
Content: content,
|
Content: content,
|
||||||
ItemKind: kind,
|
ItemKind: dtype,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -883,6 +883,22 @@ func (ts Tensor) MustDiv(other Tensor, del bool) (retVal Tensor) {
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Div_(other Tensor) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
lib.AtgDiv_(ptr, ts.ctensor, other.ctensor)
|
||||||
|
if err := TorchErr(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Div1_(other Scalar) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
lib.AtgDiv1_(ptr, ts.ctensor, other.cscalar)
|
||||||
|
if err := TorchErr(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Randperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (retVal Tensor, err error) {
|
func Randperm(n int64, optionKind gotch.DType, optionDevice gotch.Device) (retVal Tensor, err error) {
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
|
||||||
|
@ -1221,6 +1237,15 @@ func (ts Tensor) Sub_(other Tensor) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Sub1_(other Scalar) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
lib.AtgSub1_(ptr, ts.ctensor, other.cscalar)
|
||||||
|
err := TorchErr()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Conv1D(input, weight, bias Tensor, stride, padding, dilation []int64, groups int64) (retVal Tensor, err error) {
|
func Conv1D(input, weight, bias Tensor, stride, padding, dilation []int64, groups int64) (retVal Tensor, err error) {
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
|
||||||
|
|
87
tensor/tensor_test.go
Normal file
87
tensor/tensor_test.go
Normal file
|
@ -0,0 +1,87 @@
|
||||||
|
package tensor_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInplaceAssign(t *testing.T) {
|
||||||
|
tensor := ts.MustOfSlice([]int64{3, 1, 4, 1, 5})
|
||||||
|
|
||||||
|
tensor.Add1_(ts.IntScalar(1))
|
||||||
|
tensor.Mul1_(ts.IntScalar(2))
|
||||||
|
tensor.Sub1_(ts.IntScalar(1))
|
||||||
|
|
||||||
|
want := []int64{7, 3, 9, 3, 11}
|
||||||
|
got := tensor.Vals()
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(want, got) {
|
||||||
|
t.Errorf("Expected tensor values: %v\n", want)
|
||||||
|
t.Errorf("Got tensor values: %v\n", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConstantOp(t *testing.T) {
|
||||||
|
tensor := ts.MustOfSlice([]int64{3, 9, 3, 11})
|
||||||
|
resTs1 := tensor.MustMul1(ts.IntScalar(-1), true)
|
||||||
|
|
||||||
|
want1 := []int64{-3, -9, -3, -11}
|
||||||
|
got1 := resTs1.Vals()
|
||||||
|
if !reflect.DeepEqual(want1, got1) {
|
||||||
|
t.Errorf("Expected tensor values: %v\n", want1)
|
||||||
|
t.Errorf("Got tensor values: %v\n", got1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: more ops
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIter(t *testing.T) {
|
||||||
|
|
||||||
|
tensor := ts.MustOfSlice([]int64{3, 9, 3, 11})
|
||||||
|
|
||||||
|
iter, err := tensor.Iter(gotch.Int64)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []int64{3, 9, 3, 11}
|
||||||
|
var got []int64
|
||||||
|
|
||||||
|
for {
|
||||||
|
item, ok := iter.Next()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
got = append(got, item.(int64))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(want, got) {
|
||||||
|
t.Errorf("Expected tensor values: %v\n", want)
|
||||||
|
t.Errorf("Got tensor values: %v\n", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
tensor1 := ts.MustOfSlice([]float64{3.14, 15.926, 5.3589, 79.0})
|
||||||
|
iter1, err := tensor1.Iter(gotch.Double)
|
||||||
|
|
||||||
|
want1 := []float64{3.14, 15.926, 5.3589, 79.0}
|
||||||
|
var got1 []float64
|
||||||
|
for {
|
||||||
|
item, ok := iter1.Next()
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
got1 = append(got1, item.(float64))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(want1, got1) {
|
||||||
|
t.Errorf("Expected tensor values: %v\n", want1)
|
||||||
|
t.Errorf("Got tensor values: %v\n", got1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: more tests
|
Loading…
Reference in New Issue
Block a user