feat(tensor/index): completed indexing unit tests

This commit is contained in:
sugarme 2020-07-10 16:50:57 +10:00
parent 853cd73d3a
commit 8704adc867
2 changed files with 113 additions and 1 deletions

View File

@ -101,6 +101,12 @@ func NewInsertNewAxis() InsertNewAxis {
return InsertNewAxis{}
}
func NewSliceIndex(sl []int64) IndexSelect {
ts := MustOfSlice(sl)
return IndexSelect{Index: ts}
}
// type SelectFn func(int64)
// type NarrowFn func(from int64, to int64)
// type IndexSelectFn func(ts Tensor)

View File

@ -189,5 +189,111 @@ func TestRangeIndex(t *testing.T) {
}
func TestSliceIndex(t *testing.T) {
// TODO
tensor1 := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(6*2), gotch.Int64, gotch.CPU).MustView([]int64{6, 2}, true)
idx1 := []ts.TensorIndexer{
ts.NewSliceIndex([]int64{1, 3, 5}),
}
result1 := tensor1.Idx(idx1)
want1 := []int64{2, 3, 6, 7, 10, 11}
want1Shape := []int64{3, 2}
got1 := result1.Vals()
got1Shape := result1.MustSize()
if !reflect.DeepEqual(want1, got1) {
t.Errorf("Expected tensor values: %v\n", want1)
t.Errorf("Got tensor values: %v\n", got1)
}
if !reflect.DeepEqual(want1Shape, got1Shape) {
t.Errorf("Expected tensor values: %v\n", want1Shape)
t.Errorf("Got tensor values: %v\n", got1Shape)
}
tensor2 := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(3*4), gotch.Int64, gotch.CPU).MustView([]int64{3, 4}, true)
idx2 := []ts.TensorIndexer{
ts.NewNarrow(0, tensor2.MustSize()[0]),
ts.NewSliceIndex([]int64{3, 0}),
}
result2 := tensor2.Idx(idx2)
want2 := []int64{3, 0, 7, 4, 11, 8}
want2Shape := []int64{3, 2}
got2 := result2.Vals()
got2Shape := result2.MustSize()
if !reflect.DeepEqual(want2, got2) {
t.Errorf("Expected tensor values: %v\n", want2)
t.Errorf("Got tensor values: %v\n", got2)
}
if !reflect.DeepEqual(want2Shape, got2Shape) {
t.Errorf("Expected tensor values: %v\n", want2Shape)
t.Errorf("Got tensor values: %v\n", got2Shape)
}
}
func TestComplexIndex(t *testing.T) {
tensor := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(2*3*5*7), gotch.Int64, gotch.CPU).MustView([]int64{2, 3, 5, 7}, true)
idx := []ts.TensorIndexer{
ts.NewSelect(1),
ts.NewNarrow(1, 2),
ts.NewSliceIndex([]int64{2, 3, 0}),
ts.NewInsertNewAxis(),
ts.NewNarrow(3, tensor.MustSize()[3]),
}
result := tensor.Idx(idx)
want := []int64{157, 158, 159, 160, 164, 165, 166, 167, 143, 144, 145, 146}
wantShape := []int64{1, 3, 1, 4}
got := result.Vals()
gotShape := result.MustSize()
if !reflect.DeepEqual(want, got) {
t.Errorf("Expected tensor values: %v\n", want)
t.Errorf("Got tensor values: %v\n", got)
}
if !reflect.DeepEqual(wantShape, gotShape) {
t.Errorf("Expected tensor values: %v\n", wantShape)
t.Errorf("Got tensor values: %v\n", gotShape)
}
}
func TestIndex3D(t *testing.T) {
tensor := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(24), gotch.Int64, gotch.CPU).MustView([]int64{2, 3, 4}, true)
idx1 := []ts.TensorIndexer{
ts.NewSelect(0),
ts.NewSelect(0),
ts.NewSelect(0),
}
result1 := tensor.Idx(idx1).MustView([]int64{1}, true)
want1 := []int64{0}
got1 := result1.Vals()
if !reflect.DeepEqual(want1, got1) {
t.Errorf("Expected tensor values: %v\n", want1)
t.Errorf("Got tensor values: %v\n", got1)
}
idx2 := []ts.TensorIndexer{
ts.NewSelect(1),
ts.NewSelect(0),
ts.NewSelect(0),
}
result2 := tensor.Idx(idx2).MustView([]int64{1}, true)
want2 := []int64{12}
got2 := result2.Vals()
if !reflect.DeepEqual(want2, got2) {
t.Errorf("Expected tensor values: %v\n", want2)
t.Errorf("Got tensor values: %v\n", got2)
}
idx3 := []ts.TensorIndexer{
ts.NewNarrow(0, 2),
ts.NewSelect(0),
ts.NewSelect(0),
}
result3 := tensor.Idx(idx3)
want3 := []int64{0, 12}
got3 := result3.Vals()
if !reflect.DeepEqual(want3, got3) {
t.Errorf("Expected tensor values: %v\n", want3)
t.Errorf("Got tensor values: %v\n", got3)
}
}