feat(tensor/index_test.go): added

This commit is contained in:
sugarme 2020-07-10 15:28:37 +10:00
parent 1843dd5b90
commit 853cd73d3a
4 changed files with 179 additions and 12 deletions

View File

@ -5,6 +5,7 @@
- [Quick start guide](how-quickstart.md)
- [Installation](how-installation.md)
- [Cloud training](how-cloud.md)
- [Use pre-trained model](how-pretrained.md)
## General
@ -26,6 +27,7 @@
- [Linear](nn-linear.md)
- [Convolutional Neural Network](nn-cnn.md)
- [Recurrent Neural Network](nn-rnn.md)
- [JIT](nn-jit.md)
## Vision

View File

@ -254,7 +254,6 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
switch reflect.TypeOf(spec).Name() {
case "InsertNewAxis":
fmt.Println(currIdx)
nextTensor, err = currTensor.Unsqueeze(currIdx, true)
if err != nil {
return retVal, err

View File

@ -8,19 +8,76 @@ import (
ts "github.com/sugarme/gotch/tensor"
)
func TestNewInsertAxis(t *testing.T) {
func TestIntegerIndex(t *testing.T) {
// [ 0 1 2
// 3 4 5 ]
tensor := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(2*3), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
// tensor, err := ts.NewTensorFromData([]bool{true, false, false, false, false, false}, []int64{2, 3})
// if err != nil {
// panic(err)
// }
idx1 := []ts.TensorIndexer{
ts.NewSelect(1),
}
result1 := tensor.Idx(idx1)
want1 := []int64{3, 4, 5}
want1Shape := []int64{3}
got1 := result1.Vals()
got1Shape := result1.MustSize()
if !reflect.DeepEqual(want1, got1) {
t.Errorf("Expected tensor values: %v\n", want1)
t.Errorf("Got tensor values: %v\n", got1)
}
if !reflect.DeepEqual(want1Shape, got1Shape) {
t.Errorf("Expected tensor values: %v\n", want1Shape)
t.Errorf("Got tensor values: %v\n", got1Shape)
}
idx2 := []ts.TensorIndexer{
ts.NewNarrow(0, tensor.MustSize()[0]),
ts.NewSelect(2),
}
result2 := tensor.Idx(idx2)
want2 := []int64{2, 5}
want2Shape := []int64{2}
got2 := result2.Vals()
got2Shape := result2.MustSize()
if !reflect.DeepEqual(want2, got2) {
t.Errorf("Expected tensor values: %v\n", want2)
t.Errorf("Got tensor values: %v\n", got2)
}
if !reflect.DeepEqual(want2Shape, got2Shape) {
t.Errorf("Expected tensor values: %v\n", want2Shape)
t.Errorf("Got tensor values: %v\n", got2Shape)
}
idx3 := []ts.TensorIndexer{
ts.NewNarrow(0, tensor.MustSize()[0]),
ts.NewSelect(-2),
}
result3 := tensor.Idx(idx3)
want3 := []int64{1, 4}
want3Shape := []int64{2}
got3 := result3.Vals()
got3Shape := result3.MustSize()
if !reflect.DeepEqual(want3, got3) {
t.Errorf("Expected tensor values: %v\n", want3)
t.Errorf("Got tensor values: %v\n", got3)
}
if !reflect.DeepEqual(want3Shape, got3Shape) {
t.Errorf("Expected tensor values: %v\n", want3Shape)
t.Errorf("Got tensor values: %v\n", got3Shape)
}
}
func TestNewInsertAxis(t *testing.T) {
tensor := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(2*3), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
var idxs1 []ts.TensorIndexer = []ts.TensorIndexer{
ts.NewInsertNewAxis(),
}
result1 := tensor.Idx(idxs1)
want1 := []int64{1, 2, 3}
got1 := result1.MustSize()
if !reflect.DeepEqual(want1, got1) {
t.Errorf("Expected a tensor shape: %v\n", want1)
t.Errorf("Got a tensor shape: %v\n", got1)
@ -30,12 +87,9 @@ func TestNewInsertAxis(t *testing.T) {
ts.NewNarrow(0, tensor.MustSize()[0]),
ts.NewInsertNewAxis(),
}
result2 := tensor.Idx(idxs2)
want2 := []int64{2, 1, 3}
got2 := result2.MustSize()
if !reflect.DeepEqual(want2, got2) {
t.Errorf("Expected a tensor shape: %v\n", want2)
t.Errorf("Got a tensor shape: %v\n", got2)
@ -46,14 +100,94 @@ func TestNewInsertAxis(t *testing.T) {
ts.NewNarrow(0, tensor.MustSize()[1]),
ts.NewInsertNewAxis(),
}
result3 := tensor.Idx(idxs3)
want3 := []int64{2, 3, 1}
got3 := result3.MustSize()
if !reflect.DeepEqual(want3, got3) {
t.Errorf("Expected a tensor shape: %v\n", want3)
t.Errorf("Got a tensor shape: %v\n", got3)
}
}
func TestRangeIndex(t *testing.T) {
// Range
tensor1 := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(4*3), gotch.Int64, gotch.CPU).MustView([]int64{4, 3}, true)
idx1 := []ts.TensorIndexer{
ts.NewNarrow(1, 3),
}
result1 := tensor1.Idx(idx1)
want1 := []int64{3, 4, 5, 6, 7, 8}
want1Shape := []int64{2, 3}
got1 := result1.Vals()
got1Shape := result1.MustSize()
if !reflect.DeepEqual(want1, got1) {
t.Errorf("Expected tensor values: %v\n", want1)
t.Errorf("Got tensor values: %v\n", got1)
}
if !reflect.DeepEqual(want1Shape, got1Shape) {
t.Errorf("Expected tensor values: %v\n", want1Shape)
t.Errorf("Got tensor values: %v\n", got1Shape)
}
// Full range
tensor2 := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(2*3), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
idx2 := []ts.TensorIndexer{
ts.NewNarrow(0, tensor2.MustSize()[0]),
}
result2 := tensor2.Idx(idx2)
want2 := []int64{0, 1, 2, 3, 4, 5}
want2Shape := []int64{2, 3}
got2 := result2.Vals()
got2Shape := result2.MustSize()
if !reflect.DeepEqual(want2, got2) {
t.Errorf("Expected tensor values: %v\n", want2)
t.Errorf("Got tensor values: %v\n", got2)
}
if !reflect.DeepEqual(want2Shape, got2Shape) {
t.Errorf("Expected tensor values: %v\n", want2Shape)
t.Errorf("Got tensor values: %v\n", got2Shape)
}
// Range from
tensor3 := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(4*3), gotch.Int64, gotch.CPU).MustView([]int64{4, 3}, true)
idx3 := []ts.TensorIndexer{
ts.NewNarrow(2, tensor3.MustSize()[0]),
}
result3 := tensor3.Idx(idx3)
want3 := []int64{6, 7, 8, 9, 10, 11}
want3Shape := []int64{2, 3}
got3 := result3.Vals()
got3Shape := result3.MustSize()
if !reflect.DeepEqual(want3, got3) {
t.Errorf("Expected tensor values: %v\n", want3)
t.Errorf("Got tensor values: %v\n", got3)
}
if !reflect.DeepEqual(want3Shape, got3Shape) {
t.Errorf("Expected tensor values: %v\n", want3Shape)
t.Errorf("Got tensor values: %v\n", got3Shape)
}
// Range to
tensor4 := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(4*3), gotch.Int64, gotch.CPU).MustView([]int64{4, 3}, true)
idx4 := []ts.TensorIndexer{
ts.NewNarrow(0, 2),
}
result4 := tensor4.Idx(idx4)
want4 := []int64{0, 1, 2, 3, 4, 5}
want4Shape := []int64{2, 3}
got4 := result4.Vals()
got4Shape := result4.MustSize()
if !reflect.DeepEqual(want4, got4) {
t.Errorf("Expected tensor values: %v\n", want4)
t.Errorf("Got tensor values: %v\n", got4)
}
if !reflect.DeepEqual(want4Shape, got4Shape) {
t.Errorf("Expected tensor values: %v\n", want4Shape)
t.Errorf("Got tensor values: %v\n", got4Shape)
}
}
func TestSliceIndex(t *testing.T) {
// TODO
}

View File

@ -1029,6 +1029,38 @@ func (ts Tensor) Values() []float64 {
return values
}
// Vals returns tensor values in a slice
// NOTE: need a type insersion to get runtime type
// E.g. res := xs.Vals().([]int64)
func (ts Tensor) Vals() (retVal interface{}) {
dtype := ts.DType()
numel := ts.Numel()
switch dtype.Name() {
case "uint8":
retVal = make([]uint8, numel)
case "int8":
retVal = make([]int8, numel)
case "int16":
retVal = make([]int16, numel)
case "int32":
retVal = make([]int32, numel)
case "int64":
retVal = make([]int64, numel)
case "float32":
retVal = make([]float32, numel)
case "float64":
retVal = make([]float64, numel)
case "bool":
retVal = make([]bool, numel)
default:
log.Fatalf("Unsupported dtype (%v)", dtype)
}
ts.CopyData(retVal, numel)
return retVal
}
// FlatView flattens a tensor.
//
// This returns a flattened version of the given tensor. The first dimension