feat(tensor/index_test.go): added
This commit is contained in:
parent
1843dd5b90
commit
853cd73d3a
|
@ -5,6 +5,7 @@
|
|||
- [Quick start guide](how-quickstart.md)
|
||||
- [Installation](how-installation.md)
|
||||
- [Cloud training](how-cloud.md)
|
||||
- [Use pre-trained model](how-pretrained.md)
|
||||
|
||||
## General
|
||||
|
||||
|
@ -26,6 +27,7 @@
|
|||
- [Linear](nn-linear.md)
|
||||
- [Convolutional Neural Network](nn-cnn.md)
|
||||
- [Recurrent Neural Network](nn-rnn.md)
|
||||
- [JIT](nn-jit.md)
|
||||
|
||||
## Vision
|
||||
|
||||
|
|
|
@ -254,7 +254,6 @@ func (ts Tensor) indexer(indexSpec []TensorIndexer) (retVal Tensor, err error) {
|
|||
|
||||
switch reflect.TypeOf(spec).Name() {
|
||||
case "InsertNewAxis":
|
||||
fmt.Println(currIdx)
|
||||
nextTensor, err = currTensor.Unsqueeze(currIdx, true)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
|
|
|
@ -8,19 +8,76 @@ import (
|
|||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
func TestNewInsertAxis(t *testing.T) {
|
||||
|
||||
func TestIntegerIndex(t *testing.T) {
|
||||
// [ 0 1 2
|
||||
// 3 4 5 ]
|
||||
tensor := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(2*3), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
|
||||
// tensor, err := ts.NewTensorFromData([]bool{true, false, false, false, false, false}, []int64{2, 3})
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
idx1 := []ts.TensorIndexer{
|
||||
ts.NewSelect(1),
|
||||
}
|
||||
result1 := tensor.Idx(idx1)
|
||||
want1 := []int64{3, 4, 5}
|
||||
want1Shape := []int64{3}
|
||||
got1 := result1.Vals()
|
||||
got1Shape := result1.MustSize()
|
||||
if !reflect.DeepEqual(want1, got1) {
|
||||
t.Errorf("Expected tensor values: %v\n", want1)
|
||||
t.Errorf("Got tensor values: %v\n", got1)
|
||||
}
|
||||
if !reflect.DeepEqual(want1Shape, got1Shape) {
|
||||
t.Errorf("Expected tensor values: %v\n", want1Shape)
|
||||
t.Errorf("Got tensor values: %v\n", got1Shape)
|
||||
}
|
||||
|
||||
idx2 := []ts.TensorIndexer{
|
||||
ts.NewNarrow(0, tensor.MustSize()[0]),
|
||||
ts.NewSelect(2),
|
||||
}
|
||||
result2 := tensor.Idx(idx2)
|
||||
want2 := []int64{2, 5}
|
||||
want2Shape := []int64{2}
|
||||
got2 := result2.Vals()
|
||||
got2Shape := result2.MustSize()
|
||||
if !reflect.DeepEqual(want2, got2) {
|
||||
t.Errorf("Expected tensor values: %v\n", want2)
|
||||
t.Errorf("Got tensor values: %v\n", got2)
|
||||
}
|
||||
if !reflect.DeepEqual(want2Shape, got2Shape) {
|
||||
t.Errorf("Expected tensor values: %v\n", want2Shape)
|
||||
t.Errorf("Got tensor values: %v\n", got2Shape)
|
||||
}
|
||||
|
||||
idx3 := []ts.TensorIndexer{
|
||||
ts.NewNarrow(0, tensor.MustSize()[0]),
|
||||
ts.NewSelect(-2),
|
||||
}
|
||||
result3 := tensor.Idx(idx3)
|
||||
want3 := []int64{1, 4}
|
||||
want3Shape := []int64{2}
|
||||
got3 := result3.Vals()
|
||||
got3Shape := result3.MustSize()
|
||||
if !reflect.DeepEqual(want3, got3) {
|
||||
t.Errorf("Expected tensor values: %v\n", want3)
|
||||
t.Errorf("Got tensor values: %v\n", got3)
|
||||
}
|
||||
if !reflect.DeepEqual(want3Shape, got3Shape) {
|
||||
t.Errorf("Expected tensor values: %v\n", want3Shape)
|
||||
t.Errorf("Got tensor values: %v\n", got3Shape)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewInsertAxis(t *testing.T) {
|
||||
tensor := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(2*3), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
|
||||
var idxs1 []ts.TensorIndexer = []ts.TensorIndexer{
|
||||
ts.NewInsertNewAxis(),
|
||||
}
|
||||
|
||||
result1 := tensor.Idx(idxs1)
|
||||
|
||||
want1 := []int64{1, 2, 3}
|
||||
got1 := result1.MustSize()
|
||||
|
||||
if !reflect.DeepEqual(want1, got1) {
|
||||
t.Errorf("Expected a tensor shape: %v\n", want1)
|
||||
t.Errorf("Got a tensor shape: %v\n", got1)
|
||||
|
@ -30,12 +87,9 @@ func TestNewInsertAxis(t *testing.T) {
|
|||
ts.NewNarrow(0, tensor.MustSize()[0]),
|
||||
ts.NewInsertNewAxis(),
|
||||
}
|
||||
|
||||
result2 := tensor.Idx(idxs2)
|
||||
|
||||
want2 := []int64{2, 1, 3}
|
||||
got2 := result2.MustSize()
|
||||
|
||||
if !reflect.DeepEqual(want2, got2) {
|
||||
t.Errorf("Expected a tensor shape: %v\n", want2)
|
||||
t.Errorf("Got a tensor shape: %v\n", got2)
|
||||
|
@ -46,14 +100,94 @@ func TestNewInsertAxis(t *testing.T) {
|
|||
ts.NewNarrow(0, tensor.MustSize()[1]),
|
||||
ts.NewInsertNewAxis(),
|
||||
}
|
||||
|
||||
result3 := tensor.Idx(idxs3)
|
||||
|
||||
want3 := []int64{2, 3, 1}
|
||||
got3 := result3.MustSize()
|
||||
|
||||
if !reflect.DeepEqual(want3, got3) {
|
||||
t.Errorf("Expected a tensor shape: %v\n", want3)
|
||||
t.Errorf("Got a tensor shape: %v\n", got3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRangeIndex(t *testing.T) {
|
||||
|
||||
// Range
|
||||
tensor1 := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(4*3), gotch.Int64, gotch.CPU).MustView([]int64{4, 3}, true)
|
||||
idx1 := []ts.TensorIndexer{
|
||||
ts.NewNarrow(1, 3),
|
||||
}
|
||||
result1 := tensor1.Idx(idx1)
|
||||
want1 := []int64{3, 4, 5, 6, 7, 8}
|
||||
want1Shape := []int64{2, 3}
|
||||
got1 := result1.Vals()
|
||||
got1Shape := result1.MustSize()
|
||||
if !reflect.DeepEqual(want1, got1) {
|
||||
t.Errorf("Expected tensor values: %v\n", want1)
|
||||
t.Errorf("Got tensor values: %v\n", got1)
|
||||
}
|
||||
if !reflect.DeepEqual(want1Shape, got1Shape) {
|
||||
t.Errorf("Expected tensor values: %v\n", want1Shape)
|
||||
t.Errorf("Got tensor values: %v\n", got1Shape)
|
||||
}
|
||||
|
||||
// Full range
|
||||
tensor2 := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(2*3), gotch.Int64, gotch.CPU).MustView([]int64{2, 3}, true)
|
||||
idx2 := []ts.TensorIndexer{
|
||||
ts.NewNarrow(0, tensor2.MustSize()[0]),
|
||||
}
|
||||
result2 := tensor2.Idx(idx2)
|
||||
want2 := []int64{0, 1, 2, 3, 4, 5}
|
||||
want2Shape := []int64{2, 3}
|
||||
got2 := result2.Vals()
|
||||
got2Shape := result2.MustSize()
|
||||
if !reflect.DeepEqual(want2, got2) {
|
||||
t.Errorf("Expected tensor values: %v\n", want2)
|
||||
t.Errorf("Got tensor values: %v\n", got2)
|
||||
}
|
||||
if !reflect.DeepEqual(want2Shape, got2Shape) {
|
||||
t.Errorf("Expected tensor values: %v\n", want2Shape)
|
||||
t.Errorf("Got tensor values: %v\n", got2Shape)
|
||||
}
|
||||
|
||||
// Range from
|
||||
tensor3 := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(4*3), gotch.Int64, gotch.CPU).MustView([]int64{4, 3}, true)
|
||||
idx3 := []ts.TensorIndexer{
|
||||
ts.NewNarrow(2, tensor3.MustSize()[0]),
|
||||
}
|
||||
result3 := tensor3.Idx(idx3)
|
||||
want3 := []int64{6, 7, 8, 9, 10, 11}
|
||||
want3Shape := []int64{2, 3}
|
||||
got3 := result3.Vals()
|
||||
got3Shape := result3.MustSize()
|
||||
if !reflect.DeepEqual(want3, got3) {
|
||||
t.Errorf("Expected tensor values: %v\n", want3)
|
||||
t.Errorf("Got tensor values: %v\n", got3)
|
||||
}
|
||||
if !reflect.DeepEqual(want3Shape, got3Shape) {
|
||||
t.Errorf("Expected tensor values: %v\n", want3Shape)
|
||||
t.Errorf("Got tensor values: %v\n", got3Shape)
|
||||
}
|
||||
|
||||
// Range to
|
||||
tensor4 := ts.MustArange1(ts.IntScalar(0), ts.IntScalar(4*3), gotch.Int64, gotch.CPU).MustView([]int64{4, 3}, true)
|
||||
idx4 := []ts.TensorIndexer{
|
||||
ts.NewNarrow(0, 2),
|
||||
}
|
||||
result4 := tensor4.Idx(idx4)
|
||||
want4 := []int64{0, 1, 2, 3, 4, 5}
|
||||
want4Shape := []int64{2, 3}
|
||||
got4 := result4.Vals()
|
||||
got4Shape := result4.MustSize()
|
||||
if !reflect.DeepEqual(want4, got4) {
|
||||
t.Errorf("Expected tensor values: %v\n", want4)
|
||||
t.Errorf("Got tensor values: %v\n", got4)
|
||||
}
|
||||
if !reflect.DeepEqual(want4Shape, got4Shape) {
|
||||
t.Errorf("Expected tensor values: %v\n", want4Shape)
|
||||
t.Errorf("Got tensor values: %v\n", got4Shape)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSliceIndex(t *testing.T) {
|
||||
// TODO
|
||||
}
|
||||
|
|
|
@ -1029,6 +1029,38 @@ func (ts Tensor) Values() []float64 {
|
|||
return values
|
||||
}
|
||||
|
||||
// Vals returns tensor values in a slice
|
||||
// NOTE: need a type insersion to get runtime type
|
||||
// E.g. res := xs.Vals().([]int64)
|
||||
func (ts Tensor) Vals() (retVal interface{}) {
|
||||
dtype := ts.DType()
|
||||
numel := ts.Numel()
|
||||
|
||||
switch dtype.Name() {
|
||||
case "uint8":
|
||||
retVal = make([]uint8, numel)
|
||||
case "int8":
|
||||
retVal = make([]int8, numel)
|
||||
case "int16":
|
||||
retVal = make([]int16, numel)
|
||||
case "int32":
|
||||
retVal = make([]int32, numel)
|
||||
case "int64":
|
||||
retVal = make([]int64, numel)
|
||||
case "float32":
|
||||
retVal = make([]float32, numel)
|
||||
case "float64":
|
||||
retVal = make([]float64, numel)
|
||||
case "bool":
|
||||
retVal = make([]bool, numel)
|
||||
default:
|
||||
log.Fatalf("Unsupported dtype (%v)", dtype)
|
||||
}
|
||||
|
||||
ts.CopyData(retVal, numel)
|
||||
return retVal
|
||||
}
|
||||
|
||||
// FlatView flattens a tensor.
|
||||
//
|
||||
// This returns a flattened version of the given tensor. The first dimension
|
||||
|
|
Loading…
Reference in New Issue
Block a user