feat(tensor/index): completed indexing unit tests
This commit is contained in:
parent
853cd73d3a
commit
8704adc867
|
@ -101,6 +101,12 @@ func NewInsertNewAxis() InsertNewAxis {
|
|||
return InsertNewAxis{}
|
||||
}
|
||||
|
||||
func NewSliceIndex(sl []int64) IndexSelect {
|
||||
ts := MustOfSlice(sl)
|
||||
|
||||
return IndexSelect{Index: ts}
|
||||
}
|
||||
|
||||
// type SelectFn func(int64)
|
||||
// type NarrowFn func(from int64, to int64)
|
||||
// type IndexSelectFn func(ts Tensor)
|
||||
|
|
|
@ -189,5 +189,111 @@ func TestRangeIndex(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestSliceIndex(t *testing.T) {
|
||||
// TODO
|
||||
tensor1 := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(6*2), gotch.Int64, gotch.CPU).MustView([]int64{6, 2}, true)
|
||||
idx1 := []ts.TensorIndexer{
|
||||
ts.NewSliceIndex([]int64{1, 3, 5}),
|
||||
}
|
||||
result1 := tensor1.Idx(idx1)
|
||||
want1 := []int64{2, 3, 6, 7, 10, 11}
|
||||
want1Shape := []int64{3, 2}
|
||||
got1 := result1.Vals()
|
||||
got1Shape := result1.MustSize()
|
||||
if !reflect.DeepEqual(want1, got1) {
|
||||
t.Errorf("Expected tensor values: %v\n", want1)
|
||||
t.Errorf("Got tensor values: %v\n", got1)
|
||||
}
|
||||
if !reflect.DeepEqual(want1Shape, got1Shape) {
|
||||
t.Errorf("Expected tensor values: %v\n", want1Shape)
|
||||
t.Errorf("Got tensor values: %v\n", got1Shape)
|
||||
}
|
||||
|
||||
tensor2 := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(3*4), gotch.Int64, gotch.CPU).MustView([]int64{3, 4}, true)
|
||||
idx2 := []ts.TensorIndexer{
|
||||
ts.NewNarrow(0, tensor2.MustSize()[0]),
|
||||
ts.NewSliceIndex([]int64{3, 0}),
|
||||
}
|
||||
result2 := tensor2.Idx(idx2)
|
||||
want2 := []int64{3, 0, 7, 4, 11, 8}
|
||||
want2Shape := []int64{3, 2}
|
||||
got2 := result2.Vals()
|
||||
got2Shape := result2.MustSize()
|
||||
if !reflect.DeepEqual(want2, got2) {
|
||||
t.Errorf("Expected tensor values: %v\n", want2)
|
||||
t.Errorf("Got tensor values: %v\n", got2)
|
||||
}
|
||||
if !reflect.DeepEqual(want2Shape, got2Shape) {
|
||||
t.Errorf("Expected tensor values: %v\n", want2Shape)
|
||||
t.Errorf("Got tensor values: %v\n", got2Shape)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestComplexIndex(t *testing.T) {
|
||||
tensor := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(2*3*5*7), gotch.Int64, gotch.CPU).MustView([]int64{2, 3, 5, 7}, true)
|
||||
idx := []ts.TensorIndexer{
|
||||
ts.NewSelect(1),
|
||||
ts.NewNarrow(1, 2),
|
||||
ts.NewSliceIndex([]int64{2, 3, 0}),
|
||||
ts.NewInsertNewAxis(),
|
||||
ts.NewNarrow(3, tensor.MustSize()[3]),
|
||||
}
|
||||
result := tensor.Idx(idx)
|
||||
want := []int64{157, 158, 159, 160, 164, 165, 166, 167, 143, 144, 145, 146}
|
||||
wantShape := []int64{1, 3, 1, 4}
|
||||
got := result.Vals()
|
||||
gotShape := result.MustSize()
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("Expected tensor values: %v\n", want)
|
||||
t.Errorf("Got tensor values: %v\n", got)
|
||||
}
|
||||
if !reflect.DeepEqual(wantShape, gotShape) {
|
||||
t.Errorf("Expected tensor values: %v\n", wantShape)
|
||||
t.Errorf("Got tensor values: %v\n", gotShape)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIndex3D(t *testing.T) {
|
||||
tensor := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(24), gotch.Int64, gotch.CPU).MustView([]int64{2, 3, 4}, true)
|
||||
|
||||
idx1 := []ts.TensorIndexer{
|
||||
ts.NewSelect(0),
|
||||
ts.NewSelect(0),
|
||||
ts.NewSelect(0),
|
||||
}
|
||||
result1 := tensor.Idx(idx1).MustView([]int64{1}, true)
|
||||
|
||||
want1 := []int64{0}
|
||||
got1 := result1.Vals()
|
||||
if !reflect.DeepEqual(want1, got1) {
|
||||
t.Errorf("Expected tensor values: %v\n", want1)
|
||||
t.Errorf("Got tensor values: %v\n", got1)
|
||||
}
|
||||
|
||||
idx2 := []ts.TensorIndexer{
|
||||
ts.NewSelect(1),
|
||||
ts.NewSelect(0),
|
||||
ts.NewSelect(0),
|
||||
}
|
||||
result2 := tensor.Idx(idx2).MustView([]int64{1}, true)
|
||||
|
||||
want2 := []int64{12}
|
||||
got2 := result2.Vals()
|
||||
if !reflect.DeepEqual(want2, got2) {
|
||||
t.Errorf("Expected tensor values: %v\n", want2)
|
||||
t.Errorf("Got tensor values: %v\n", got2)
|
||||
}
|
||||
|
||||
idx3 := []ts.TensorIndexer{
|
||||
ts.NewNarrow(0, 2),
|
||||
ts.NewSelect(0),
|
||||
ts.NewSelect(0),
|
||||
}
|
||||
result3 := tensor.Idx(idx3)
|
||||
|
||||
want3 := []int64{0, 12}
|
||||
got3 := result3.Vals()
|
||||
if !reflect.DeepEqual(want3, got3) {
|
||||
t.Errorf("Expected tensor values: %v\n", want3)
|
||||
t.Errorf("Got tensor values: %v\n", got3)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user