added missing
This commit is contained in:
parent
dbbb4b97c1
commit
d1b9267c77
28
ts/tensor.go
28
ts/tensor.go
|
@ -212,6 +212,34 @@ func (ts *Tensor) MustSize() []int64 {
|
|||
return shape
|
||||
}
|
||||
|
||||
func (ts *Tensor) Stride() ([]int64, error) {
|
||||
dim := lib.AtDim(ts.ctensor)
|
||||
sz := make([]int64, dim)
|
||||
szPtr, err := DataAsPtr(sz)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer C.free(unsafe.Pointer(szPtr))
|
||||
|
||||
lib.AtStride(ts.ctensor, szPtr)
|
||||
if err = TorchErr(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
strides := decodeSize(szPtr, dim)
|
||||
|
||||
return strides, nil
|
||||
}
|
||||
|
||||
func (ts *Tensor) MustStride() []int64 {
|
||||
strides, err := ts.Stride()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return strides
|
||||
}
|
||||
|
||||
// Size1 returns the tensor size for 1D tensors.
|
||||
func (ts *Tensor) Size1() (int64, error) {
|
||||
shape, err := ts.Size()
|
||||
|
|
|
@ -156,3 +156,14 @@ func TestCudaCurrentDevice(t *testing.T) {
|
|||
}
|
||||
t.Logf("Cuda index AFTER set: %v\n", cudaIdxAfter) // 0
|
||||
}
|
||||
|
||||
func TestTensor_Stride(t *testing.T) {
|
||||
shape := []int64{2, 3, 4}
|
||||
x := ts.MustRand(shape, gotch.Float, gotch.CPU)
|
||||
|
||||
got := x.MustStride()
|
||||
want := []int64{12, 4, 1}
|
||||
if !reflect.DeepEqual(want, got) {
|
||||
t.Errorf("want %v, got %v\n", want, got)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user