This commit is contained in:
sugarme 2020-07-08 22:29:04 +10:00
parent c1d53fc865
commit 2fa7431c25
2 changed files with 16 additions and 21 deletions

View File

@ -152,7 +152,7 @@ func main() {
lossVal = loss.Values()[0]
logits.MustDrop()
// logits.MustDrop()
// item.Data.MustDrop()
// item.Label.MustDrop()
devicedData.MustDrop()

View File

@ -43,23 +43,23 @@ func NewIter2(xs, ys Tensor, batchSize int64) (retVal Iter2, err error) {
return retVal, err
}
xsClone, err := xs.ZerosLike(false)
if err != nil {
log.Fatal(err)
}
xsClone.Copy_(xs)
ysClone, err := ys.ZerosLike(false)
if err != nil {
log.Fatal(err)
}
ysClone.Copy_(ys)
// xsClone, err := xs.ZerosLike(false)
// if err != nil {
// log.Fatal(err)
// }
// xsClone.Copy_(xs)
//
// ysClone, err := ys.ZerosLike(false)
// if err != nil {
// log.Fatal(err)
// }
// ysClone.Copy_(ys)
retVal = Iter2{
// xs: xs.MustShallowClone(),
// ys: ys.MustShallowClone(),
xs: xsClone,
ys: ysClone,
xs: xs.MustShallowClone(),
ys: ys.MustShallowClone(),
// xs: xsClone,
// ys: ysClone,
batchIndex: 0,
batchSize: batchSize,
totalSize: totalSize,
@ -138,14 +138,9 @@ func (it *Iter2) Next() (item Iter2Item, ok bool) {
// Indexing
narrowIndex := NewNarrow(start, start+size)
// data := it.xs.Idx(narrowIndex).MustTo(it.device, false)
// label := it.ys.Idx(narrowIndex).MustTo(it.device, false)
return Iter2Item{
Data: it.xs.Idx(narrowIndex),
Label: it.ys.Idx(narrowIndex),
// Data: data,
// Label: label,
}, true
}
}