fix(vision/augementation)
This commit is contained in:
parent
2fa7431c25
commit
f68668530d
|
@ -140,7 +140,6 @@ func main() {
|
|||
|
||||
devicedData := item.Data.MustTo(vs.Device(), true)
|
||||
devicedLabel := item.Label.MustTo(vs.Device(), true)
|
||||
|
||||
bimages := vision.Augmentation(devicedData, true, 4, 8)
|
||||
|
||||
logits := net.ForwardT(bimages, true)
|
||||
|
|
|
@ -157,12 +157,14 @@ func Augmentation(t ts.Tensor, flip bool, crop int64, cutout int64) (retVal ts.T
|
|||
var cropTs ts.Tensor
|
||||
if crop > 0 {
|
||||
cropTs = RandomCrop(flipTs, crop)
|
||||
flipTs.MustDrop()
|
||||
} else {
|
||||
cropTs = flipTs
|
||||
}
|
||||
|
||||
if cutout > 0 {
|
||||
retVal = RandomCutout(cropTs, cutout)
|
||||
cropTs.MustDrop()
|
||||
} else {
|
||||
retVal = cropTs
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user