// Source: gotch/vision/aug/perspective.go (198 lines, 6.1 KiB, Go)

package aug
import (
// "fmt"
"git.andr3h3nriqu3s.com/andr3/gotch"
"git.andr3h3nriqu3s.com/andr3/gotch/ts"
)
// RandomPerspective holds the configuration of a random perspective
// (four-corner) warp for image tensors.
//
// The input image tensor is expected to have shape [..., H, W], where "..."
// stands for any number of leading dimensions.
//
// Settings (defaults come from defaultPerspectiveOptions):
//   - distortionScale: degree of distortion in [0, 1]; it scales how far each
//     image corner may be displaced. Default 0.5.
//   - pvalue: probability in [0, 1] with which the transform is meant to be
//     applied. Default 0.5.
//   - interpolationMode: name of the interpolation mode handed to perspective.
//     Default "bilinear". NOTE(review): the accepted names are presumably
//     "nearest" and "bilinear", mirroring torchvision; confirm against perspective.
//   - fillValue: fill value(s) for the area outside the transformed image.
//     Default {0, 0, 0}.
type RandomPerspective struct {
	// Both values are expected to lie in [0, 1].
	distortionScale, pvalue float64

	// Interpolation mode name forwarded to perspective.
	interpolationMode string

	// Fill value(s) for pixels that fall outside the warped image.
	fillValue []float64
}
// perspectiveOptions is the mutable settings bag that PerspectiveOption
// functions write into before a RandomPerspective is built from it.
// Its fields mirror those of RandomPerspective one-to-one.
type perspectiveOptions struct {
	// Both values are expected to lie in [0, 1].
	distortionScale, pvalue float64

	// Interpolation mode name, e.g. "bilinear".
	interpolationMode string

	// Fill value(s) for pixels outside the transformed image.
	fillValue []float64
}
// defaultPerspectiveOptions returns a freshly allocated options value holding
// the defaults: distortion scale 0.5, probability 0.5, "bilinear"
// interpolation and an all-zero three-element fill value.
func defaultPerspectiveOptions() *perspectiveOptions {
	defaults := new(perspectiveOptions)
	defaults.distortionScale = 0.5
	defaults.pvalue = 0.5
	defaults.interpolationMode = "bilinear"
	// A new slice per call, so callers never share the default fill value.
	defaults.fillValue = make([]float64, 3)
	return defaults
}
// PerspectiveOption is a functional option that mutates a perspectiveOptions
// value; options are applied in order by newRandomPerspective.
type PerspectiveOption func(*perspectiveOptions)
// WithPerspectivePvalue returns an option that sets the transform probability
// to p (expected range [0, 1]). The value is stored as given; no range check
// is performed here.
func WithPerspectivePvalue(p float64) PerspectiveOption {
	return func(cfg *perspectiveOptions) { cfg.pvalue = p }
}
// WithPerspectiveScale returns an option that sets the distortion scale to s
// (expected range [0, 1]). The value is stored as given; no range check is
// performed here.
func WithPerspectiveScale(s float64) PerspectiveOption {
	return func(cfg *perspectiveOptions) { cfg.distortionScale = s }
}
// WithPerspectiveMode returns an option that sets the interpolation mode name
// to m. The string is stored as given; it is not validated here.
func WithPerspectiveMode(m string) PerspectiveOption {
	return func(cfg *perspectiveOptions) { cfg.interpolationMode = m }
}
// WithPerspectiveValue returns an option that sets the fill value used for
// the area outside the transformed image to v.
//
// The slice is stored without copying, so later changes the caller makes to
// v's elements remain visible through the stored value.
func WithPerspectiveValue(v []float64) PerspectiveOption {
	return func(cfg *perspectiveOptions) { cfg.fillValue = v }
}
// newRandomPerspective builds a RandomPerspective by starting from
// defaultPerspectiveOptions and applying opts in the order given, so a later
// option overrides an earlier one that touches the same setting.
func newRandomPerspective(opts ...PerspectiveOption) *RandomPerspective {
	cfg := defaultPerspectiveOptions()
	for i := range opts {
		opts[i](cfg)
	}

	// Copy the resolved settings field by field into the transform.
	rp := new(RandomPerspective)
	rp.distortionScale = cfg.distortionScale
	rp.pvalue = cfg.pvalue
	rp.interpolationMode = cfg.interpolationMode
	rp.fillValue = cfg.fillValue
	return rp
}
// getParams draws the corner correspondences for one random perspective warp.
//
// w and h are the image width and height. Every point is an [x, y] pair and
// both returned lists are ordered top-left, top-right, bottom-right,
// bottom-left:
//   - the first list holds the corners of the original image,
//   - the second list holds the randomly displaced corners of the
//     transformed image.
//
// Each corner is sampled from a band along its own image edges. The band is
// int(distortionScale*(w/2)) wide horizontally and
// int(distortionScale*(h/2)) high vertically, with w/2 and h/2 being integer
// divisions. The sampling mirrors torch.randint, i.e. the upper bound is
// presumably exclusive.
func (rp *RandomPerspective) getParams(w, h int64) ([][]int64, [][]int64) {
	maxDx := int64(rp.distortionScale * float64(w/2))
	maxDy := int64(rp.distortionScale * float64(h/2))

	// randBelow samples one integer with upper bound high from a
	// single-element CPU tensor and drops the tensor afterwards.
	randBelow := func(high int64) int64 {
		sample := ts.MustRandint(high, []int64{1}, gotch.Int64, gotch.CPU)
		v := sample.Int64Values()[0]
		sample.MustDrop()
		return v
	}
	// randBetween does the same with an explicit lower bound low.
	randBetween := func(low, high int64) int64 {
		sample := ts.MustRandintLow(low, high, []int64{1}, gotch.Int64, gotch.CPU)
		v := sample.Int64Values()[0]
		sample.MustDrop()
		return v
	}

	// The draw order (x then y; TL, TR, BR, BL) is kept fixed so that the
	// sequence of random numbers consumed stays the same.

	// Top-left: x from 0 up to maxDx+1, y from 0 up to maxDy+1.
	tlX := randBelow(maxDx + 1)
	tlY := randBelow(maxDy + 1)

	// Top-right: x from w-maxDx-1 up to w, y from 0 up to maxDy+1.
	trX := randBetween(w-maxDx-1, w)
	trY := randBelow(maxDy + 1)

	// Bottom-right: x from w-maxDx-1 up to w, y from h-maxDy-1 up to h.
	brX := randBetween(w-maxDx-1, w)
	brY := randBetween(h-maxDy-1, h)

	// Bottom-left: x from 0 up to maxDx+1, y from h-maxDy-1 up to h.
	blX := randBelow(maxDx + 1)
	blY := randBetween(h-maxDy-1, h)

	startPoints := [][]int64{
		{0, 0},
		{w - 1, 0},
		{w - 1, h - 1},
		{0, h - 1},
	}
	endPoints := [][]int64{
		{tlX, tlY},
		{trX, trY},
		{brX, brY},
		{blX, blY},
	}
	return startPoints, endPoints
}
// Forward applies a random perspective warp to x with probability rp.pvalue
// and returns the result as a new tensor.
//
// A single uniform sample r in [0, 1) is drawn first. When r >= rp.pvalue the
// transform is skipped and a shallow clone of x is returned rather than x
// itself, so the result is always a distinct tensor handle. Otherwise x is
// converted with Byte2FloatImage, warped by perspective using corner points
// from getParams, and converted back with Float2ByteImage.
//
// Forward does not call MustDrop on x; the intermediate tensors it creates
// are dropped before returning.
func (rp *RandomPerspective) Forward(x *ts.Tensor) *ts.Tensor {
	// Honour the configured probability: pvalue 0 never transforms and
	// pvalue 1 always does, because the sample is strictly below 1.
	probe := ts.MustRand([]int64{1}, gotch.Double, gotch.CPU)
	r := probe.Float64Values()[0]
	probe.MustDrop()
	if r >= rp.pvalue {
		// NOTE(review): assumes x is already a byte image, so both paths
		// return the same dtype; confirm against Byte2FloatImage.
		return x.MustShallowClone()
	}

	fx := Byte2FloatImage(x)

	// NOTE(review): getParams is declared as (w, h). The two values are
	// forwarded in the order getImageSize returns them; confirm that order is
	// (width, height), otherwise the arguments below are swapped.
	height, width := getImageSize(fx)
	startPoints, endPoints := rp.getParams(height, width)

	out := perspective(fx, startPoints, endPoints, rp.interpolationMode, rp.fillValue)
	bx := Float2ByteImage(out)

	// Release the float intermediates; only the byte image is handed back.
	fx.MustDrop()
	out.MustDrop()
	return bx
}
// WithRandomPerspective returns an Option that installs a RandomPerspective
// transform, configured by opts, into the randomPerspective slot of Options.
//
// The transform is built once, when WithRandomPerspective is called; every
// application of the returned Option assigns that same instance.
func WithRandomPerspective(opts ...PerspectiveOption) Option {
	transform := newRandomPerspective(opts...)
	return func(target *Options) {
		target.randomPerspective = transform
	}
}