fix(vision/augementation)

This commit is contained in:
sugarme 2020-07-08 22:49:39 +10:00
parent 2fa7431c25
commit f68668530d
2 changed files with 2 additions and 1 deletions

View File

@ -140,7 +140,6 @@ func main() {
devicedData := item.Data.MustTo(vs.Device(), true)
devicedLabel := item.Label.MustTo(vs.Device(), true)
bimages := vision.Augmentation(devicedData, true, 4, 8)
logits := net.ForwardT(bimages, true)

View File

@ -157,12 +157,14 @@ func Augmentation(t ts.Tensor, flip bool, crop int64, cutout int64) (retVal ts.T
var cropTs ts.Tensor
if crop > 0 {
cropTs = RandomCrop(flipTs, crop)
flipTs.MustDrop()
} else {
cropTs = flipTs
}
if cutout > 0 {
retVal = RandomCutout(cropTs, cutout)
cropTs.MustDrop()
} else {
retVal = cropTs
}