fix(vision/cifar): copy on indexing
This commit is contained in:
parent
ab0ebc21eb
commit
f33ca7edf1
|
@ -1,6 +1,7 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
|
@ -46,4 +47,14 @@ func main() {
|
|||
|
||||
combinedTs.Print()
|
||||
|
||||
// Copy to index
|
||||
desTs := tensor.MustZeros([]int64{5}, gotch.Float.CInt(), gotch.CPU.CInt())
|
||||
srcTs := tensor.MustOnes([]int64{1}, gotch.Float.CInt(), gotch.CPU.CInt())
|
||||
idx := tensor.NewNarrow(0, 3)
|
||||
|
||||
// NOTE: indexing operations return view on the same memory
|
||||
desTs.Print()
|
||||
desTs.Idx(idx).MustView([]int64{-1}, false).Copy_(srcTs)
|
||||
desTs.Print()
|
||||
|
||||
}
|
||||
|
|
|
@ -58,6 +58,9 @@ package tensor
|
|||
// is that `i` guarantees the input and result tensor shares the same
|
||||
// underlying storage, while NumPy may copy the tensor in certain scenarios.
|
||||
|
||||
// NOTE: select, narrow and indexing operations (except when using a LongTensor index) return views onto the same memory.
|
||||
// https://discuss.pytorch.org/t/does-select-and-narrow-return-a-view-or-copy/289
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
|
|
@ -59,12 +59,11 @@ func readFile(filename string) (imagesTs ts.Tensor, labelsTs ts.Tensor) {
|
|||
tmp1 := content.MustNarrow(0, int64(1+contentOffset), int64(bytesPerImage-1), false)
|
||||
tmp2 := tmp1.MustView([]int64{cfC, cfH, cfW}, true)
|
||||
tmp3 := tmp2.MustTo(gotch.CPU, true)
|
||||
selectImageTs := images.Idx(ts.NewSelect(int64(idx)))
|
||||
selectImageTs.Copy_(tmp3)
|
||||
tmp3.MustDrop()
|
||||
|
||||
// TODO: concat all selectLabelTs and selectImageTs to single labelsTs and
|
||||
// imagesTs
|
||||
// NOTE: tensor indexing operations return view on the same memory
|
||||
// images.Idx(ts.NewSelect(int64(idx))).Copy_(tmp3)
|
||||
images.Idx(ts.NewSelect(int64(idx))).MustView([]int64{cfC, cfH, cfW}, false).Copy_(tmp3)
|
||||
tmp3.MustDrop()
|
||||
}
|
||||
|
||||
tmp1 := images.MustTotype(gotch.Float, true)
|
||||
|
|
Loading…
Reference in New Issue
Block a user