fix(vision/cifar): copy on indexing

This commit is contained in:
sugarme 2020-07-07 19:26:35 +10:00
parent ab0ebc21eb
commit f33ca7edf1
3 changed files with 18 additions and 5 deletions

View File

@ -1,6 +1,7 @@
package main
import (
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/tensor"
)
@ -46,4 +47,14 @@ func main() {
combinedTs.Print()
// Copy to index
desTs := tensor.MustZeros([]int64{5}, gotch.Float.CInt(), gotch.CPU.CInt())
srcTs := tensor.MustOnes([]int64{1}, gotch.Float.CInt(), gotch.CPU.CInt())
idx := tensor.NewNarrow(0, 3)
// NOTE: indexing operations return view on the same memory
desTs.Print()
desTs.Idx(idx).MustView([]int64{-1}, false).Copy_(srcTs)
desTs.Print()
}

View File

@ -58,6 +58,9 @@ package tensor
// is that `i` guarantees the input and result tensor shares the same
// underlying storage, while NumPy may copy the tensor in certain scenarios.
// NOTE: select, narrow and indexing operations (except when using a LongTensor index) return views onto the same memory.
// https://discuss.pytorch.org/t/does-select-and-narrow-return-a-view-or-copy/289
import (
"fmt"
"log"

View File

@ -59,12 +59,11 @@ func readFile(filename string) (imagesTs ts.Tensor, labelsTs ts.Tensor) {
tmp1 := content.MustNarrow(0, int64(1+contentOffset), int64(bytesPerImage-1), false)
tmp2 := tmp1.MustView([]int64{cfC, cfH, cfW}, true)
tmp3 := tmp2.MustTo(gotch.CPU, true)
selectImageTs := images.Idx(ts.NewSelect(int64(idx)))
selectImageTs.Copy_(tmp3)
tmp3.MustDrop()
// TODO: concat all selectLabelTs and selectImageTs to single labelsTs and
// imagesTs
// NOTE: tensor indexing operations return view on the same memory
// images.Idx(ts.NewSelect(int64(idx))).Copy_(tmp3)
images.Idx(ts.NewSelect(int64(idx))).MustView([]int64{cfC, cfH, cfW}, false).Copy_(tmp3)
tmp3.MustDrop()
}
tmp1 := images.MustTotype(gotch.Float, true)