fix(augmentation): more fix

This commit is contained in:
sugarme 2020-07-08 23:50:35 +10:00
parent 8e1ef76d20
commit ca74b1a949
2 changed files with 6 additions and 5 deletions

View File

@ -140,10 +140,10 @@ func main() {
devicedData := item.Data.MustTo(vs.Device(), true)
devicedLabel := item.Label.MustTo(vs.Device(), true)
// bimages := vision.Augmentation(devicedData, true, 0, 0)
bimages := vision.Augmentation(devicedData, true, 4, 8)
// logits := net.ForwardT(bimages, true)
logits := net.ForwardT(devicedData, true)
logits := net.ForwardT(bimages, true)
// logits := net.ForwardT(devicedData, true)
// logits := net.ForwardT(item.Data.MustTo(vs.Device(), true), false)
loss := logits.CrossEntropyForLogits(devicedLabel)
@ -156,9 +156,8 @@ func main() {
// item.Label.MustDrop()
devicedData.MustDrop()
devicedLabel.MustDrop()
// bimages.MustDrop()
bimages.MustDrop()
loss.MustDrop()
}
si = gotch.GetSysInfo()

View File

@ -57,6 +57,7 @@ func RandomFlip(t ts.Tensor) (retVal ts.Tensor) {
src = tView
} else {
src = tView.MustFlip([]int64{2})
tView.MustDrop()
}
outputView.Copy_(src)
@ -100,6 +101,7 @@ func RandomCrop(t ts.Tensor, pad int64) (retVal ts.Tensor) {
wIdx := ts.NewNarrow(int64(startW), int64(startW)+szW)
srcIdx = append(srcIdx, nIdx, cIdx, hIdx, wIdx)
src := padded.Idx(srcIdx)
padded.MustDrop()
outputView.Copy_(src)
src.MustDrop()
}