96 lines
2.1 KiB
Go
96 lines
2.1 KiB
Go
package aug
|
|
|
|
import (
|
|
"git.andr3h3nriqu3s.com/andr3/gotch"
|
|
"git.andr3h3nriqu3s.com/andr3/gotch/ts"
|
|
)
|
|
|
|
// RandomHorizontalFlip is an augmentation step that randomly applies a
// horizontal flip to an image tensor.
//
// The input tensor is expected to have shape [..., H, W], where ... stands
// for any number of leading dimensions.
type RandomHorizontalFlip struct {
	// pvalue is the flip threshold: Forward flips the image when its random
	// draw is below this value.
	pvalue float64
}
|
|
|
|
func newRandomHorizontalFlip(pvalue float64) *RandomHorizontalFlip {
|
|
return &RandomHorizontalFlip{
|
|
pvalue: pvalue,
|
|
}
|
|
}
|
|
|
|
func (hf *RandomHorizontalFlip) Forward(x *ts.Tensor) *ts.Tensor {
|
|
fx := Byte2FloatImage(x)
|
|
|
|
randTs := ts.MustRandn([]int64{1}, gotch.Float, gotch.CPU)
|
|
randVal := randTs.Float64Values()[0]
|
|
randTs.MustDrop()
|
|
var out *ts.Tensor
|
|
switch {
|
|
case randVal < hf.pvalue:
|
|
out = hflip(fx)
|
|
default:
|
|
out = fx.MustShallowClone()
|
|
}
|
|
|
|
bx := Float2ByteImage(out)
|
|
fx.MustDrop()
|
|
out.MustDrop()
|
|
|
|
return bx
|
|
}
|
|
|
|
func WithRandomHFlip(pvalue float64) Option {
|
|
return func(o *Options) {
|
|
hf := newRandomHorizontalFlip(pvalue)
|
|
o.randomHFlip = hf
|
|
}
|
|
}
|
|
|
|
// RandomVerticalFlip is an augmentation step that randomly applies a vertical
// flip to an image tensor.
//
// The input tensor is expected to have shape [..., H, W], where ... stands
// for any number of leading dimensions.
type RandomVerticalFlip struct {
	// pvalue is the flip threshold: Forward flips the image when its random
	// draw is below this value.
	pvalue float64
}
|
|
|
|
func newRandomVerticalFlip(pvalue float64) *RandomVerticalFlip {
|
|
return &RandomVerticalFlip{
|
|
pvalue: pvalue,
|
|
}
|
|
}
|
|
|
|
func (vf *RandomVerticalFlip) Forward(x *ts.Tensor) *ts.Tensor {
|
|
fx := Byte2FloatImage(x)
|
|
|
|
randTs := ts.MustRandn([]int64{1}, gotch.Float, gotch.CPU)
|
|
randVal := randTs.Float64Values()[0]
|
|
randTs.MustDrop()
|
|
|
|
var out *ts.Tensor
|
|
switch {
|
|
case randVal < vf.pvalue:
|
|
out = vflip(fx)
|
|
default:
|
|
out = fx.MustShallowClone()
|
|
}
|
|
|
|
bx := Float2ByteImage(out)
|
|
fx.MustDrop()
|
|
out.MustDrop()
|
|
|
|
return bx
|
|
}
|
|
|
|
func WithRandomVFlip(pvalue float64) Option {
|
|
return func(o *Options) {
|
|
vf := newRandomVerticalFlip(pvalue)
|
|
o.randomVFlip = vf
|
|
}
|
|
}
|