added missing
This commit is contained in:
parent
dbbb4b97c1
commit
d1b9267c77
28
ts/tensor.go
28
ts/tensor.go
|
@ -212,6 +212,34 @@ func (ts *Tensor) MustSize() []int64 {
|
||||||
return shape
|
return shape
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts *Tensor) Stride() ([]int64, error) {
|
||||||
|
dim := lib.AtDim(ts.ctensor)
|
||||||
|
sz := make([]int64, dim)
|
||||||
|
szPtr, err := DataAsPtr(sz)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer C.free(unsafe.Pointer(szPtr))
|
||||||
|
|
||||||
|
lib.AtStride(ts.ctensor, szPtr)
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
strides := decodeSize(szPtr, dim)
|
||||||
|
|
||||||
|
return strides, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts *Tensor) MustStride() []int64 {
|
||||||
|
strides, err := ts.Stride()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return strides
|
||||||
|
}
|
||||||
|
|
||||||
// Size1 returns the tensor size for 1D tensors.
|
// Size1 returns the tensor size for 1D tensors.
|
||||||
func (ts *Tensor) Size1() (int64, error) {
|
func (ts *Tensor) Size1() (int64, error) {
|
||||||
shape, err := ts.Size()
|
shape, err := ts.Size()
|
||||||
|
|
|
@ -156,3 +156,14 @@ func TestCudaCurrentDevice(t *testing.T) {
|
||||||
}
|
}
|
||||||
t.Logf("Cuda index AFTER set: %v\n", cudaIdxAfter) // 0
|
t.Logf("Cuda index AFTER set: %v\n", cudaIdxAfter) // 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTensor_Stride(t *testing.T) {
|
||||||
|
shape := []int64{2, 3, 4}
|
||||||
|
x := ts.MustRand(shape, gotch.Float, gotch.CPU)
|
||||||
|
|
||||||
|
got := x.MustStride()
|
||||||
|
want := []int64{12, 4, 1}
|
||||||
|
if !reflect.DeepEqual(want, got) {
|
||||||
|
t.Errorf("want %v, got %v\n", want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user