added missing

This commit is contained in:
sugarme 2023-07-23 12:17:27 +10:00
parent dbbb4b97c1
commit d1b9267c77
2 changed files with 39 additions and 0 deletions

View File

@ -212,6 +212,34 @@ func (ts *Tensor) MustSize() []int64 {
return shape
}
func (ts *Tensor) Stride() ([]int64, error) {
dim := lib.AtDim(ts.ctensor)
sz := make([]int64, dim)
szPtr, err := DataAsPtr(sz)
if err != nil {
return nil, err
}
defer C.free(unsafe.Pointer(szPtr))
lib.AtStride(ts.ctensor, szPtr)
if err = TorchErr(); err != nil {
return nil, err
}
strides := decodeSize(szPtr, dim)
return strides, nil
}
func (ts *Tensor) MustStride() []int64 {
strides, err := ts.Stride()
if err != nil {
log.Fatal(err)
}
return strides
}
// Size1 returns the tensor size for 1D tensors.
func (ts *Tensor) Size1() (int64, error) {
shape, err := ts.Size()

View File

@ -156,3 +156,14 @@ func TestCudaCurrentDevice(t *testing.T) {
}
t.Logf("Cuda index AFTER set: %v\n", cudaIdxAfter) // 0
}
func TestTensor_Stride(t *testing.T) {
shape := []int64{2, 3, 4}
x := ts.MustRand(shape, gotch.Float, gotch.CPU)
got := x.MustStride()
want := []int64{12, 4, 1}
if !reflect.DeepEqual(want, got) {
t.Errorf("want %v, got %v\n", want, got)
}
}