diff --git a/ts/tensor.go b/ts/tensor.go index c97f4cb..2c49c9e 100644 --- a/ts/tensor.go +++ b/ts/tensor.go @@ -212,6 +212,34 @@ func (ts *Tensor) MustSize() []int64 { return shape } +func (ts *Tensor) Stride() ([]int64, error) { + dim := lib.AtDim(ts.ctensor) + sz := make([]int64, dim) + szPtr, err := DataAsPtr(sz) + if err != nil { + return nil, err + } + defer C.free(unsafe.Pointer(szPtr)) + + lib.AtStride(ts.ctensor, szPtr) + if err = TorchErr(); err != nil { + return nil, err + } + + strides := decodeSize(szPtr, dim) + + return strides, nil +} + +func (ts *Tensor) MustStride() []int64 { + strides, err := ts.Stride() + if err != nil { + log.Fatal(err) + } + + return strides +} + // Size1 returns the tensor size for 1D tensors. func (ts *Tensor) Size1() (int64, error) { shape, err := ts.Size() diff --git a/ts/tensor_test.go b/ts/tensor_test.go index 846a99e..3f89821 100644 --- a/ts/tensor_test.go +++ b/ts/tensor_test.go @@ -156,3 +156,14 @@ func TestCudaCurrentDevice(t *testing.T) { } t.Logf("Cuda index AFTER set: %v\n", cudaIdxAfter) // 0 } + +func TestTensor_Stride(t *testing.T) { + shape := []int64{2, 3, 4} + x := ts.MustRand(shape, gotch.Float, gotch.CPU) + + got := x.MustStride() + want := []int64{12, 4, 1} + if !reflect.DeepEqual(want, got) { + t.Errorf("want %v, got %v\n", want, got) + } +}