added image augmentation and minor fixed on ts.Lstsq
This commit is contained in:
parent
fe6454c0ca
commit
7292c3575e
|
@ -6,9 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
|||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
## [0.3.9]
|
||||
- [#24], [#26]: fixed memory leak.
|
||||
- [#30]: fixed varstore.Save() randomly panicking with a segmentation fault
|
||||
- [#32]: nn.Seq Forward return nil tensor if length of layers = 1
|
||||
- [#36]: resolved image augmentation
|
||||
|
||||
## [0.3.8]
|
||||
|
||||
|
|
31
example/augmentation/README.md
Normal file
31
example/augmentation/README.md
Normal file
|
@ -0,0 +1,31 @@
|
|||
# Image Augmentation Example
|
||||
|
||||
This example demonstrates how to use the image augmentation functions. The API is kept as close as possible to the original [PyTorch torchvision transforms](https://pytorch.org/vision/stable/transforms.html#).
|
||||
|
||||
There are 2 APIs (`aug.Compose` and `aug.OneOf`) to compose augmentation methods as shown in the example:
|
||||
|
||||
```go
|
||||
t, err := aug.Compose(
|
||||
aug.WithRandomVFlip(0.5),
|
||||
aug.WithRandomHFlip(0.5),
|
||||
aug.WithRandomCutout(),
|
||||
aug.OneOf(
|
||||
0.3,
|
||||
aug.WithColorJitter(0.3, 0.3, 0.3, 0.4),
|
||||
aug.WithRandomGrayscale(1.0),
|
||||
),
|
||||
aug.OneOf(
|
||||
0.3,
|
||||
aug.WithGaussianBlur([]int64{5, 5}, []float64{1.0, 2.0}),
|
||||
aug.WithRandomAffine(),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
out := t.Transform(imgTs)
|
||||
```
|
||||
|
||||
|
||||
|
BIN
example/augmentation/bb.png
Normal file
BIN
example/augmentation/bb.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 265 KiB |
69
example/augmentation/main.go
Normal file
69
example/augmentation/main.go
Normal file
|
@ -0,0 +1,69 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
"github.com/sugarme/gotch/vision/aug"
|
||||
)
|
||||
|
||||
func main() {
|
||||
n := 360
|
||||
for i := 1; i <= n; i++ {
|
||||
img, err := vision.Load("./bb.png")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
device := gotch.CudaIfAvailable()
|
||||
// device := gotch.CPU
|
||||
imgTs := img.MustTo(device, true)
|
||||
// t, err := aug.Compose(aug.WithResize(512, 512)) // NOTE. WithResize just works on CPU.
|
||||
// t, err := aug.Compose(aug.WithRandRotate(0, 360), aug.WithColorJitter(0.3, 0.3, 0.3, 0.4))
|
||||
// t, err := aug.Compose(aug.WithGaussianBlur([]int64{5, 5}, []float64{1.0, 2.0}), aug.WithRandRotate(0, 360), aug.WithColorJitter(0.3, 0.3, 0.3, 0.3))
|
||||
// t, err := aug.Compose(aug.WithRandomCrop([]int64{320, 320}, []int64{10, 10}, true, "constant"))
|
||||
// t, err := aug.Compose(aug.WithCenterCrop([]int64{320, 320}))
|
||||
// t, err := aug.Compose(aug.WithRandomCutout(aug.WithCutoutValue([]int64{124, 96, 255}), aug.WithCutoutScale([]float64{0.01, 0.1}), aug.WithCutoutRatio([]float64{0.5, 0.5})))
|
||||
// t, err := aug.Compose(aug.WithRandomPerspective(aug.WithPerspectiveScale(0.6), aug.WithPerspectivePvalue(0.8)))
|
||||
// t, err := aug.Compose(aug.WithRandomAffine(aug.WithAffineDegree([]int64{0, 15}), aug.WithAffineShear([]float64{0, 15})))
|
||||
// t, err := aug.Compose(aug.WithRandomGrayscale(0.5))
|
||||
// t, err := aug.Compose(aug.WithRandomSolarize(aug.WithSolarizeThreshold(125), aug.WithSolarizePvalue(0.5)))
|
||||
// t, err := aug.Compose(aug.WithRandomInvert(0.5))
|
||||
// t, err := aug.Compose(aug.WithRandomPosterize(aug.WithPosterizeBits(2), aug.WithPosterizePvalue(1.0)))
|
||||
// t, err := aug.Compose(aug.WithRandomAutocontrast())
|
||||
// t, err := aug.Compose(aug.WithRandomAdjustSharpness(aug.WithSharpnessPvalue(0.3), aug.WithSharpnessFactor(10)))
|
||||
// t, err := aug.Compose(aug.WithRandomEqualize(1.0))
|
||||
// t, err := aug.Compose(aug.WithNormalize(aug.WithNormalizeMean([]float64{0.485, 0.456, 0.406}), aug.WithNormalizeStd([]float64{0.229, 0.224, 0.225})))
|
||||
|
||||
t, err := aug.Compose(
|
||||
aug.WithRandomVFlip(0.5),
|
||||
aug.WithRandomHFlip(0.5),
|
||||
aug.WithRandomCutout(),
|
||||
aug.OneOf(
|
||||
0.3,
|
||||
aug.WithColorJitter(0.3, 0.3, 0.3, 0.4),
|
||||
aug.WithRandomGrayscale(1.0),
|
||||
),
|
||||
aug.OneOf(
|
||||
0.3,
|
||||
aug.WithGaussianBlur([]int64{5, 5}, []float64{1.0, 2.0}),
|
||||
aug.WithRandomAffine(),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
out := t.Transform(imgTs)
|
||||
fname := fmt.Sprintf("./output/bb-%03d.png", i)
|
||||
err = vision.Save(out, fname)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
imgTs.MustDrop()
|
||||
out.MustDrop()
|
||||
|
||||
fmt.Printf("%03d/%v completed.\n", i, n)
|
||||
}
|
||||
}
|
3
example/augmentation/output/.gitignore
vendored
Normal file
3
example/augmentation/output/.gitignore
vendored
Normal file
|
@ -0,0 +1,3 @@
|
|||
*
|
||||
!.gitignore
|
||||
!README.md
|
1
example/augmentation/output/README.md
Normal file
1
example/augmentation/output/README.md
Normal file
|
@ -0,0 +1 @@
|
|||
Output images will be here.
|
|
@ -581,7 +581,7 @@ func (ts *Tensor) Lstsq(a *Tensor, del bool) (retVal *Tensor, err error) {
|
|||
}
|
||||
|
||||
func (ts *Tensor) MustLstsq(a *Tensor, del bool) (retVal *Tensor) {
|
||||
retVal, err := ts.Lstsq(del)
|
||||
retVal, err := ts.Lstsq(a, del)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
|
185
vision/aug/affine.go
Normal file
185
vision/aug/affine.go
Normal file
|
@ -0,0 +1,185 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// RandomAffine is transformation of the image keeping center invariant.
|
||||
// If the image is torch Tensor, it is expected
|
||||
// to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
|
||||
// Args:
|
||||
// - degrees (sequence or number): Range of degrees to select from.
|
||||
// If degrees is a number instead of sequence like (min, max), the range of degrees
|
||||
// will be (-degrees, +degrees). Set to 0 to deactivate rotations.
|
||||
// - translate (tuple, optional): tuple of maximum absolute fraction for horizontal
|
||||
// and vertical translations. For example translate=(a, b), then horizontal shift
|
||||
// is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
|
||||
// randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
|
||||
// - scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
|
||||
// randomly sampled from the range a <= scale <= b. Will keep original scale by default.
|
||||
// - shear (sequence or number, optional): Range of degrees to select from.
|
||||
// If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
|
||||
// will be applied. Else if shear is a sequence of 2 values a shear parallel to the x axis in the
|
||||
// range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
|
||||
// a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
|
||||
// Will not apply shear by default.
|
||||
// - interpolation (InterpolationMode): Desired interpolation enum defined by
|
||||
// :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
|
||||
// If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
|
||||
// For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
|
||||
// - fill (sequence or number): Pixel fill value for the area outside the transformed
|
||||
// image. Default is ``0``. If given a number, the value is used for all bands respectively.
|
||||
// Please use the ``interpolation`` parameter instead.
|
||||
// .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
|
||||
type RandomAffine struct {
	degree            []int64   // degree range; getParams samples the angle between degree[0] and degree[1]
	translate         []float64 // maximum shift fractions (horizontal, vertical); nil disables translation
	scale             []float64 // scale range; nil keeps the scale at 1.0
	shear             []float64 // shear range: [0],[1] for x, optional [2],[3] for y; nil disables shear
	interpolationMode string    // passed through unchanged to affine
	fillValue         []float64 // passed through unchanged to affine
}
|
||||
|
||||
func (ra *RandomAffine) getParams(imageSize []int64) (float64, []int64, float64, []float64) {
|
||||
angleTs := ts.MustEmpty([]int64{1}, gotch.Float, gotch.CPU)
|
||||
angleTs.MustUniform_(float64(ra.degree[0]), float64(ra.degree[1]))
|
||||
angle := angleTs.Float64Values()[0]
|
||||
angleTs.MustDrop()
|
||||
|
||||
var translations []int64 = []int64{0, 0}
|
||||
if ra.translate != nil {
|
||||
maxDX := ra.translate[0] * float64(imageSize[0])
|
||||
maxDY := ra.translate[1] * float64(imageSize[1])
|
||||
dx := ts.MustEmpty([]int64{1}, gotch.Float, gotch.CPU)
|
||||
dx.MustUniform_(-maxDX, maxDX)
|
||||
tx := dx.Float64Values()[0]
|
||||
dx.MustDrop()
|
||||
|
||||
dy := ts.MustEmpty([]int64{1}, gotch.Float, gotch.CPU)
|
||||
dy.MustUniform_(-maxDY, maxDY)
|
||||
ty := dx.Float64Values()[0]
|
||||
dy.MustDrop()
|
||||
|
||||
translations = []int64{int64(tx), int64(ty)} // should we use math.Round here???
|
||||
}
|
||||
|
||||
scale := 1.0
|
||||
if ra.scale != nil {
|
||||
scaleTs := ts.MustEmpty([]int64{1}, gotch.Float, gotch.CPU)
|
||||
scaleTs.MustUniform_(ra.scale[0], ra.scale[1])
|
||||
scale = scaleTs.Float64Values()[0]
|
||||
scaleTs.MustDrop()
|
||||
}
|
||||
|
||||
var (
|
||||
shearX, shearY float64 = 0.0, 0.0
|
||||
)
|
||||
if ra.shear != nil {
|
||||
shearXTs := ts.MustEmpty([]int64{1}, gotch.Float, gotch.CPU)
|
||||
shearXTs.MustUniform_(ra.shear[0], ra.shear[1])
|
||||
shearX = shearXTs.Float64Values()[0]
|
||||
shearXTs.MustDrop()
|
||||
|
||||
if len(ra.shear) == 4 {
|
||||
shearYTs := ts.MustEmpty([]int64{1}, gotch.Float, gotch.CPU)
|
||||
shearYTs.MustUniform_(ra.shear[2], ra.shear[3])
|
||||
shearY = shearYTs.Float64Values()[0]
|
||||
shearYTs.MustDrop()
|
||||
}
|
||||
}
|
||||
|
||||
var shear []float64 = []float64{shearX, shearY}
|
||||
|
||||
return angle, translations, scale, shear
|
||||
}
|
||||
|
||||
func (ra *RandomAffine) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
w, h := getImageSize(x)
|
||||
angle, translations, scale, shear := ra.getParams([]int64{w, h})
|
||||
|
||||
out := affine(x, angle, translations, scale, shear, ra.interpolationMode, ra.fillValue)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func newRandomAffine(opts ...affineOption) *RandomAffine {
|
||||
p := defaultAffineOptions()
|
||||
for _, o := range opts {
|
||||
o(p)
|
||||
}
|
||||
|
||||
return &RandomAffine{
|
||||
degree: p.degree,
|
||||
translate: p.translate,
|
||||
scale: p.scale,
|
||||
shear: p.shear,
|
||||
interpolationMode: p.interpolationMode,
|
||||
fillValue: p.fillValue,
|
||||
}
|
||||
}
|
||||
|
||||
// affineOptions collects the configurable settings of RandomAffine.
type affineOptions struct {
	degree            []int64
	translate         []float64
	scale             []float64
	shear             []float64
	interpolationMode string
	fillValue         []float64
}

// affineOption mutates an affineOptions value.
type affineOption func(*affineOptions)

// defaultAffineOptions returns the settings used when no option is supplied:
// degree [-180, 180], no translate, no scale, shear [-180, 180],
// "bilinear" interpolation and a fill value of [0, 0, 0].
func defaultAffineOptions() *affineOptions {
	opts := new(affineOptions) // translate and scale stay nil
	opts.degree = []int64{-180, 180}
	opts.shear = []float64{-180.0, 180.0}
	opts.interpolationMode = "bilinear"
	opts.fillValue = []float64{0.0, 0.0, 0.0}
	return opts
}

// WithAffineDegree sets the degree range.
func WithAffineDegree(degree []int64) affineOption {
	return func(cfg *affineOptions) { cfg.degree = degree }
}

// WithAffineTranslate sets the translate fractions.
func WithAffineTranslate(translate []float64) affineOption {
	return func(cfg *affineOptions) { cfg.translate = translate }
}

// WithAffineScale sets the scale range.
func WithAffineScale(scale []float64) affineOption {
	return func(cfg *affineOptions) { cfg.scale = scale }
}

// WithAffineShear sets the shear range.
func WithAffineShear(shear []float64) affineOption {
	return func(cfg *affineOptions) { cfg.shear = shear }
}

// WithAffineMode sets the interpolation mode.
func WithAffineMode(mode string) affineOption {
	return func(cfg *affineOptions) { cfg.interpolationMode = mode }
}

// WithAffineFillValue sets the fill value.
func WithAffineFillValue(fillValue []float64) affineOption {
	return func(cfg *affineOptions) { cfg.fillValue = fillValue }
}
|
||||
|
||||
func WithRandomAffine(opts ...affineOption) Option {
|
||||
ra := newRandomAffine(opts...)
|
||||
return func(o *Options) {
|
||||
o.randomAffine = ra
|
||||
}
|
||||
}
|
89
vision/aug/blur.go
Normal file
89
vision/aug/blur.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// GaussianBlur holds the settings for a random Gaussian blur.
type GaussianBlur struct {
	kernelSize []int64   // two positive, odd kernel sizes
	sigma      []float64 // (min, max) range the blur sigma is drawn from
}

// newGaussianBlur validates its arguments and builds a GaussianBlur.
//
// ks is the kernel size as a 1- or 2-element slice of positive odd numbers;
// a single element is used for both entries.
// sig is the sigma range as a 1- or 2-element slice of positive numbers;
// a single element gives a fixed sigma, and two elements are stored as
// (min, max) regardless of the order supplied.
//
// Invalid arguments terminate the program through log.Fatal.
func newGaussianBlur(ks []int64, sig []float64) *GaussianBlur {
	if n := len(ks); n < 1 || n > 2 {
		log.Fatal(fmt.Errorf("Kernel size should have 1-2 elements. Got %v\n", n))
	}
	for i := range ks {
		if ks[i] <= 0 || ks[i]%2 == 0 {
			log.Fatal(fmt.Errorf("Kernel size should be an odd and positive number."))
		}
	}

	if n := len(sig); n < 1 || n > 2 {
		log.Fatal(fmt.Errorf("Sigma should have 1-2 elements. Got %v\n", n))
	}
	for i := range sig {
		if sig[i] <= 0 {
			log.Fatal(fmt.Errorf("Sigma should be a positive number."))
		}
	}

	// The lengths are known to be 1 or 2 from here on.
	kernel := ks
	if len(ks) == 1 {
		kernel = []int64{ks[0], ks[0]}
	}

	lo, hi := sig[0], sig[len(sig)-1]
	if lo > hi {
		lo, hi = hi, lo
	}

	return &GaussianBlur{
		kernelSize: kernel,
		sigma:      []float64{lo, hi},
	}
}
|
||||
|
||||
func (b *GaussianBlur) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
sigmaTs := ts.MustEmpty([]int64{1}, gotch.Float, gotch.CPU)
|
||||
sigmaTs.MustUniform_(b.sigma[0], b.sigma[1])
|
||||
sigmaVal := sigmaTs.Float64Values()[0]
|
||||
sigmaTs.MustDrop()
|
||||
|
||||
return gaussianBlur(x, b.kernelSize, []float64{sigmaVal, sigmaVal})
|
||||
}
|
||||
|
||||
func WithGaussianBlur(ks []int64, sig []float64) Option {
|
||||
return func(o *Options) {
|
||||
gb := newGaussianBlur(ks, sig)
|
||||
o.gaussianBlur = gb
|
||||
}
|
||||
}
|
77
vision/aug/color.go
Normal file
77
vision/aug/color.go
Normal file
|
@ -0,0 +1,77 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// Ref. https://github.com/pytorch/vision/blob/f1d734213af65dc06e777877d315973ba8386080/torchvision/transforms/functional_tensor.py
|
||||
|
||||
// ColorJitter holds the strength of each colour adjustment.
type ColorJitter struct {
	brightness float64
	contrast   float64
	saturation float64
	hue        float64
}

// defaultColorJitter returns a ColorJitter with brightness, contrast and
// saturation set to 1.0 and hue set to 0.0.
func defaultColorJitter() *ColorJitter {
	cj := new(ColorJitter) // hue keeps its zero value
	cj.brightness, cj.contrast, cj.saturation = 1.0, 1.0, 1.0
	return cj
}

// setBrightness overwrites the brightness setting.
func (c *ColorJitter) setBrightness(brightness float64) { c.brightness = brightness }

// setContrast overwrites the contrast setting.
func (c *ColorJitter) setContrast(contrast float64) { c.contrast = contrast }

// setSaturation overwrites the saturation setting.
func (c *ColorJitter) setSaturation(sat float64) { c.saturation = sat }

// setHue overwrites the hue setting.
func (c *ColorJitter) setHue(hue float64) { c.hue = hue }
|
||||
|
||||
// Forward implement ts.Module by randomly picking one of brightness, contrast,
|
||||
// staturation or hue function to transform input image tensor.
|
||||
func (c *ColorJitter) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
idx := rand.Intn(4)
|
||||
switch idx {
|
||||
case 0:
|
||||
v := randVal(getMinMax(c.brightness))
|
||||
return adjustBrightness(x, v)
|
||||
case 1:
|
||||
v := randVal(getMinMax(c.contrast))
|
||||
return adjustContrast(x, v)
|
||||
case 2:
|
||||
v := randVal(getMinMax(c.saturation))
|
||||
return adjustSaturation(x, v)
|
||||
case 3:
|
||||
v := randVal(0, c.hue)
|
||||
return adjustHue(x, v)
|
||||
default:
|
||||
panic("Shouldn't reach here.")
|
||||
}
|
||||
}
|
||||
|
||||
func WithColorJitter(brightness, contrast, sat, hue float64) Option {
|
||||
c := defaultColorJitter()
|
||||
c.setBrightness(brightness)
|
||||
c.setContrast(contrast)
|
||||
c.setSaturation(sat)
|
||||
c.setHue(hue)
|
||||
|
||||
return func(o *Options) {
|
||||
o.colorJitter = c
|
||||
}
|
||||
}
|
43
vision/aug/contrast.go
Normal file
43
vision/aug/contrast.go
Normal file
|
@ -0,0 +1,43 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// RandomAutocontrast applies autocontrast to an image with probability pvalue.
type RandomAutocontrast struct {
	pvalue float64
}

// newRandomAutocontrast builds a RandomAutocontrast. The probability defaults
// to 0.5; when pOpt is given, its first element is used and the rest ignored.
func newRandomAutocontrast(pOpt ...float64) *RandomAutocontrast {
	rac := &RandomAutocontrast{pvalue: 0.5}
	if len(pOpt) != 0 {
		rac.pvalue = pOpt[0]
	}
	return rac
}
|
||||
|
||||
func (rac *RandomAutocontrast) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
r := randPvalue()
|
||||
var out *ts.Tensor
|
||||
switch {
|
||||
case r < rac.pvalue:
|
||||
out = autocontrast(x)
|
||||
default:
|
||||
out = x.MustShallowClone()
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func WithRandomAutocontrast(p ...float64) Option {
|
||||
rac := newRandomAutocontrast(p...)
|
||||
return func(o *Options) {
|
||||
o.randomAutocontrast = rac
|
||||
}
|
||||
}
|
124
vision/aug/crop.go
Normal file
124
vision/aug/crop.go
Normal file
|
@ -0,0 +1,124 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
// "math"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// RandomCrop holds the settings for cropping an image at a random location.
type RandomCrop struct {
	size            []int64 // target size, read as (height, width)
	padding         []int64 // padding applied before cropping; nil for none
	paddingIfNeeded bool    // pad images that are smaller than size
	paddingMode     string  // mode handed to pad
}

// newRandomCrop stores the given settings in a new RandomCrop without
// validating them.
func newRandomCrop(size, padding []int64, paddingIfNeeded bool, paddingMode string) *RandomCrop {
	rc := new(RandomCrop)
	rc.size = size
	rc.padding = padding
	rc.paddingIfNeeded = paddingIfNeeded
	rc.paddingMode = paddingMode
	return rc
}
|
||||
|
||||
// get parameters for crop
|
||||
func (c *RandomCrop) params(x *ts.Tensor) (int64, int64, int64, int64) {
|
||||
w, h := getImageSize(x)
|
||||
th, tw := c.size[0], c.size[1]
|
||||
if h+1 < th || w+1 < tw {
|
||||
err := fmt.Errorf("Required crop size %v is larger then input image size %v", c.size, []int64{h, w})
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if w == tw && h == th {
|
||||
return 0, 0, h, w
|
||||
}
|
||||
|
||||
iTs := ts.MustRandint1(0, h-th+1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
i := iTs.Int64Values()[0]
|
||||
iTs.MustDrop()
|
||||
|
||||
jTs := ts.MustRandint1(0, w-tw+1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
j := jTs.Int64Values()[0]
|
||||
jTs.MustDrop()
|
||||
|
||||
return i, j, th, tw
|
||||
}
|
||||
|
||||
func (c *RandomCrop) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
var img *ts.Tensor
|
||||
if c.padding != nil {
|
||||
img = pad(x, c.padding, c.paddingMode)
|
||||
} else {
|
||||
img = x.MustShallowClone()
|
||||
}
|
||||
|
||||
w, h := getImageSize(x)
|
||||
|
||||
var (
|
||||
paddedW *ts.Tensor
|
||||
paddedWH *ts.Tensor
|
||||
)
|
||||
// pad width if needed
|
||||
if c.paddingIfNeeded && w < c.size[1] {
|
||||
padding := []int64{c.size[1] - w, 0}
|
||||
paddedW = pad(img, padding, c.paddingMode)
|
||||
} else {
|
||||
paddedW = img.MustShallowClone()
|
||||
}
|
||||
img.MustDrop()
|
||||
|
||||
// pad height if needed
|
||||
if c.paddingIfNeeded && h < c.size[0] {
|
||||
padding := []int64{0, c.size[0] - h}
|
||||
paddedWH = pad(paddedW, padding, c.paddingMode)
|
||||
} else {
|
||||
paddedWH = paddedW.MustShallowClone()
|
||||
}
|
||||
|
||||
paddedW.MustDrop()
|
||||
|
||||
// i, j, h, w = self.get_params(img, self.size)
|
||||
i, j, h, w := c.params(x)
|
||||
out := crop(paddedWH, i, j, h, w)
|
||||
paddedWH.MustDrop()
|
||||
return out
|
||||
}
|
||||
|
||||
func WithRandomCrop(size []int64, padding []int64, paddingIfNeeded bool, paddingMode string) Option {
|
||||
return func(o *Options) {
|
||||
c := newRandomCrop(size, padding, paddingIfNeeded, paddingMode)
|
||||
o.randomCrop = c
|
||||
}
|
||||
}
|
||||
|
||||
// CenterCrop crops an image around its center to a fixed size.
// The cropping itself is delegated to centerCrop.
type CenterCrop struct {
	size []int64 // target size, exactly 2 elements
}

// newCenterCrop builds a CenterCrop. size must have exactly 2 elements;
// otherwise the program is terminated through log.Fatal.
func newCenterCrop(size []int64) *CenterCrop {
	if n := len(size); n != 2 {
		log.Fatal(fmt.Errorf("Expected size of 2 elements. Got %v\n", n))
	}
	return &CenterCrop{size: size}
}
|
||||
|
||||
func (cc *CenterCrop) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
return centerCrop(x, cc.size)
|
||||
}
|
||||
|
||||
func WithCenterCrop(size []int64) Option {
|
||||
return func(o *Options) {
|
||||
cc := newCenterCrop(size)
|
||||
o.centerCrop = cc
|
||||
}
|
||||
}
|
177
vision/aug/cutout.go
Normal file
177
vision/aug/cutout.go
Normal file
|
@ -0,0 +1,177 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// RandomCutout erases a randomly chosen rectangle of an image.
// See 'Random Erasing Data Augmentation' by Zhong et al.,
// https://arxiv.org/abs/1708.04896
type RandomCutout struct {
	pvalue float64   // probability that the erasing is performed
	scale  []float64 // range of the erased area as a proportion of the image area
	ratio  []float64 // range of the aspect ratio of the erased area
	rgbVal []int64   // RGB value
}

// cutoutOptions collects the configurable settings of RandomCutout.
type cutoutOptions struct {
	pvalue float64
	scale  []float64
	ratio  []float64
	rgbVal []int64 // RGB value
}

// cutoutOption mutates a cutoutOptions value.
type cutoutOption func(o *cutoutOptions)

// defaultCutoutOptions returns the defaults: pvalue 0.5, scale (0.02, 0.33),
// ratio (0.3, 3.3) and an erase value of (0, 0, 0).
func defaultCutoutOptions() *cutoutOptions {
	d := new(cutoutOptions)
	d.pvalue = 0.5
	d.scale = []float64{0.02, 0.33}
	d.ratio = []float64{0.3, 3.3}
	d.rgbVal = []int64{0, 0, 0}
	return d
}

// newRandomCutout stores the given settings in a new RandomCutout.
func newRandomCutout(pvalue float64, scale, ratio []float64, rgbVal []int64) *RandomCutout {
	rc := new(RandomCutout)
	rc.pvalue, rc.scale, rc.ratio, rc.rgbVal = pvalue, scale, ratio, rgbVal
	return rc
}

// WithCutoutPvalue sets the erase probability. p must lie in [0, 1];
// otherwise the program is terminated immediately through log.Fatalf.
func WithCutoutPvalue(p float64) cutoutOption {
	if p < 0 || p > 1 {
		log.Fatalf("Cutout p-value must be in range from 0 to 1. Got %v\n", p)
	}
	return func(cfg *cutoutOptions) { cfg.pvalue = p }
}

// WithCutoutScale sets the area scale range. scale must have 2 elements;
// otherwise the program is terminated immediately through log.Fatalf.
func WithCutoutScale(scale []float64) cutoutOption {
	if n := len(scale); n != 2 {
		log.Fatalf("Cutout scale should be in a range of 2 elments. Got %v elements\n", n)
	}
	return func(cfg *cutoutOptions) { cfg.scale = scale }
}

// WithCutoutRatio sets the aspect ratio range. ratio must have 2 elements;
// otherwise the program is terminated immediately through log.Fatalf.
func WithCutoutRatio(ratio []float64) cutoutOption {
	if n := len(ratio); n != 2 {
		log.Fatalf("Cutout ratio should be in a range of 2 elments. Got %v elements\n", n)
	}
	return func(cfg *cutoutOptions) { cfg.ratio = ratio }
}

// WithCutoutValue sets the erase value. A single value is repeated for all
// three channels; three values are used as given. Any other length
// terminates the program immediately through log.Fatal.
func WithCutoutValue(rgb []int64) cutoutOption {
	channels := rgb
	if n := len(rgb); n == 1 {
		channels = []int64{rgb[0], rgb[0], rgb[0]}
	} else if n != 3 {
		log.Fatal(fmt.Errorf("Cutout values can be single value or 3-element (RGB) value. Got %v values.", n))
	}
	return func(cfg *cutoutOptions) { cfg.rgbVal = channels }
}
|
||||
|
||||
func (rc *RandomCutout) cutoutParams(x *ts.Tensor) (int64, int64, int64, int64, *ts.Tensor) {
|
||||
dim := x.MustSize()
|
||||
|
||||
imgH, imgW := dim[len(dim)-2], dim[len(dim)-1]
|
||||
area := float64(imgH * imgW)
|
||||
logRatio := ts.MustOfSlice(rc.ratio).MustLog(true).Float64Values()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
scaleTs := ts.MustEmpty([]int64{1}, gotch.Float, gotch.CPU)
|
||||
scaleTs.MustUniform_(rc.scale[0], rc.scale[1])
|
||||
scaleVal := scaleTs.Float64Values()[0]
|
||||
scaleTs.MustDrop()
|
||||
eraseArea := area * scaleVal
|
||||
|
||||
ratioTs := ts.MustEmpty([]int64{1}, gotch.Float, gotch.CPU)
|
||||
ratioTs.MustUniform_(logRatio[0], logRatio[1])
|
||||
asTs := ratioTs.MustExp(true)
|
||||
asVal := asTs.Float64Values()[0] // aspect ratio
|
||||
asTs.MustDrop()
|
||||
|
||||
// h = int(round(math.sqrt(erase_area * aspect_ratio)))
|
||||
// w = int(round(math.sqrt(erase_area / aspect_ratio)))
|
||||
h := int64(math.Round(math.Sqrt(eraseArea * asVal)))
|
||||
w := int64(math.Round(math.Sqrt(eraseArea / asVal)))
|
||||
if !(h < imgH && w < imgW) {
|
||||
continue
|
||||
}
|
||||
|
||||
// v = torch.tensor(value)[:, None, None]
|
||||
v := ts.MustOfSlice(rc.rgbVal).MustUnsqueeze(1, true).MustUnsqueeze(1, true)
|
||||
|
||||
// i = torch.randint(0, img_h - h + 1, size=(1, )).item()
|
||||
iTs := ts.MustRandint1(0, imgH-h+1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
i := iTs.Int64Values()[0]
|
||||
iTs.MustDrop()
|
||||
// j = torch.randint(0, img_w - w + 1, size=(1, )).item()
|
||||
jTs := ts.MustRandint1(0, imgW-w+1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
j := jTs.Int64Values()[0]
|
||||
jTs.MustDrop()
|
||||
return i, j, h, w, v
|
||||
}
|
||||
|
||||
// return original image
|
||||
img := x.MustShallowClone()
|
||||
return 0, 0, imgH, imgW, img
|
||||
}
|
||||
|
||||
func (rc *RandomCutout) Forward(img *ts.Tensor) *ts.Tensor {
|
||||
randTs := ts.MustRandn([]int64{1}, gotch.Float, gotch.CPU)
|
||||
randVal := randTs.Float64Values()[0]
|
||||
randTs.MustDrop()
|
||||
|
||||
switch randVal < rc.pvalue {
|
||||
case true:
|
||||
x, y, h, w, v := rc.cutoutParams(img)
|
||||
out := cutout(img, x, y, h, w, rc.rgbVal)
|
||||
v.MustDrop()
|
||||
return out
|
||||
case false:
|
||||
out := img.MustShallowClone()
|
||||
return out
|
||||
}
|
||||
|
||||
panic("Shouldn't reach here")
|
||||
}
|
||||
|
||||
func WithRandomCutout(opts ...cutoutOption) Option {
|
||||
params := defaultCutoutOptions()
|
||||
for _, o := range opts {
|
||||
o(params)
|
||||
}
|
||||
|
||||
return func(o *Options) {
|
||||
rc := newRandomCutout(params.pvalue, params.scale, params.ratio, params.rgbVal)
|
||||
o.randomCutout = rc
|
||||
}
|
||||
}
|
46
vision/aug/equalize.go
Normal file
46
vision/aug/equalize.go
Normal file
|
@ -0,0 +1,46 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// RandomEqualize equalizes the histogram of an image with probability pvalue.
// Ref. https://en.wikipedia.org/wiki/Histogram_equalization
type RandomEqualize struct {
	pvalue float64
}

// newRandomEqualize builds a RandomEqualize. The probability defaults to 0.5;
// when pOpt is given, its first element is used and the rest ignored.
func newRandomEqualize(pOpt ...float64) *RandomEqualize {
	re := &RandomEqualize{pvalue: 0.5}
	if len(pOpt) != 0 {
		re.pvalue = pOpt[0]
	}
	return re
}
|
||||
|
||||
func (re *RandomEqualize) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
r := randPvalue()
|
||||
|
||||
var out *ts.Tensor
|
||||
switch {
|
||||
case r < re.pvalue:
|
||||
out = equalize(x)
|
||||
default:
|
||||
out = x.MustShallowClone()
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func WithRandomEqualize(p ...float64) Option {
|
||||
re := newRandomEqualize(p...)
|
||||
return func(o *Options) {
|
||||
o.randomEqualize = re
|
||||
}
|
||||
}
|
78
vision/aug/flip.go
Normal file
78
vision/aug/flip.go
Normal file
|
@ -0,0 +1,78 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// RandomHorizontalFlip horizontally flips an image with probability pvalue.
type RandomHorizontalFlip struct {
	pvalue float64
}

// newRandomHorizontalFlip builds a RandomHorizontalFlip with the given
// probability.
func newRandomHorizontalFlip(pvalue float64) *RandomHorizontalFlip {
	hf := new(RandomHorizontalFlip)
	hf.pvalue = pvalue
	return hf
}
|
||||
|
||||
func (hf *RandomHorizontalFlip) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
randTs := ts.MustRandn([]int64{1}, gotch.Float, gotch.CPU)
|
||||
randVal := randTs.Float64Values()[0]
|
||||
randTs.MustDrop()
|
||||
switch {
|
||||
case randVal < hf.pvalue:
|
||||
return hflip(x)
|
||||
default:
|
||||
out := x.MustShallowClone()
|
||||
return out
|
||||
}
|
||||
}
|
||||
|
||||
func WithRandomHFlip(pvalue float64) Option {
|
||||
return func(o *Options) {
|
||||
hf := newRandomHorizontalFlip(pvalue)
|
||||
o.randomHFlip = hf
|
||||
}
|
||||
}
|
||||
|
||||
// RandomVerticalFlip vertically flips an image with probability pvalue.
type RandomVerticalFlip struct {
	pvalue float64
}

// newRandomVerticalFlip builds a RandomVerticalFlip with the given
// probability.
func newRandomVerticalFlip(pvalue float64) *RandomVerticalFlip {
	vf := new(RandomVerticalFlip)
	vf.pvalue = pvalue
	return vf
}
|
||||
|
||||
func (vf *RandomVerticalFlip) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
randTs := ts.MustRandn([]int64{1}, gotch.Float, gotch.CPU)
|
||||
randVal := randTs.Float64Values()[0]
|
||||
randTs.MustDrop()
|
||||
switch {
|
||||
case randVal < vf.pvalue:
|
||||
return vflip(x)
|
||||
default:
|
||||
out := x.MustShallowClone()
|
||||
return out
|
||||
}
|
||||
}
|
||||
|
||||
func WithRandomVFlip(pvalue float64) Option {
|
||||
return func(o *Options) {
|
||||
vf := newRandomVerticalFlip(pvalue)
|
||||
o.randomVFlip = vf
|
||||
}
|
||||
}
|
1514
vision/aug/function.go
Normal file
1514
vision/aug/function.go
Normal file
File diff suppressed because it is too large
Load Diff
81
vision/aug/grayscale.go
Normal file
81
vision/aug/grayscale.go
Normal file
|
@ -0,0 +1,81 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
// "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// Grayscale converts an image tensor to grayscale.
//
// The input is expected to have shape [..., 3, H, W], where ... stands for
// any number of leading dimensions.
//
// Fields:
//   - outChan: number of channels requested for the output image; the
//     constructor only accepts 1 or 3.
type Grayscale struct {
	outChan int64
}
|
||||
|
||||
func (gs *Grayscale) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
out := rgb2Gray(x, gs.outChan)
|
||||
return out
|
||||
}
|
||||
|
||||
func newGrayscale(outChanOpt ...int64) *Grayscale {
|
||||
var outChan int64 = 3
|
||||
if len(outChanOpt) > 0 {
|
||||
c := outChanOpt[0]
|
||||
switch c {
|
||||
case 1:
|
||||
outChan = 1
|
||||
case 3:
|
||||
outChan = 3
|
||||
default:
|
||||
log.Fatalf("Out channels should be either 1 or 3. Got %v\n", c)
|
||||
}
|
||||
}
|
||||
return &Grayscale{outChan}
|
||||
}
|
||||
|
||||
// RandomGrayscale converts an image tensor to grayscale at random.
//
// The input is expected to have shape [..., 3, H, W], where ... stands for
// any number of leading dimensions.
//
// Fields:
//   - pvalue: threshold compared against the random draw in Forward; the
//     conversion is applied when the draw is below it. The constructor
//     defaults it to 0.1.
type RandomGrayscale struct {
	pvalue float64
}
|
||||
|
||||
func newRandomGrayscale(pvalueOpt ...float64) *RandomGrayscale {
|
||||
pvalue := 0.1
|
||||
if len(pvalueOpt) > 0 {
|
||||
pvalue = pvalueOpt[0]
|
||||
}
|
||||
return &RandomGrayscale{pvalue}
|
||||
}
|
||||
|
||||
func (rgs *RandomGrayscale) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
c := getImageChanNum(x)
|
||||
r := randPvalue()
|
||||
var out *ts.Tensor
|
||||
switch {
|
||||
case r < rgs.pvalue:
|
||||
out = rgb2Gray(x, c)
|
||||
default:
|
||||
out = x.MustShallowClone()
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func WithRandomGrayscale(pvalueOpt ...float64) Option {
|
||||
var p float64 = 0.1
|
||||
if len(pvalueOpt) > 0 {
|
||||
p = pvalueOpt[0]
|
||||
}
|
||||
|
||||
rgs := newRandomGrayscale(p)
|
||||
return func(o *Options) {
|
||||
o.randomGrayscale = rgs
|
||||
}
|
||||
}
|
39
vision/aug/invert.go
Normal file
39
vision/aug/invert.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// RandomInvert applies invert to an image tensor at random.
type RandomInvert struct {
	// pvalue is the threshold compared against the random draw in Forward.
	pvalue float64
}
|
||||
|
||||
func newRandomInvert(pOpt ...float64) *RandomInvert {
|
||||
p := 0.5
|
||||
if len(pOpt) > 0 {
|
||||
p = pOpt[0]
|
||||
}
|
||||
return &RandomInvert{p}
|
||||
}
|
||||
|
||||
func (ri *RandomInvert) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
r := randPvalue()
|
||||
|
||||
var out *ts.Tensor
|
||||
switch {
|
||||
case r < ri.pvalue:
|
||||
out = invert(x)
|
||||
default:
|
||||
out = x.MustShallowClone()
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func WithRandomInvert(pvalueOpt ...float64) Option {
|
||||
ri := newRandomInvert(pvalueOpt...)
|
||||
|
||||
return func(o *Options) {
|
||||
o.randomInvert = ri
|
||||
}
|
||||
}
|
91
vision/aug/normalize.go
Normal file
91
vision/aug/normalize.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// Normalize normalizes a tensor image with a per-channel mean and standard
// deviation.
//
// Given mean (mean[1], ..., mean[n]) and std (std[1], ..., std[n]) for n
// channels, the intended result is, for each channel:
//
//	output[channel] = (input[channel] - mean[channel]) / std[channel]
//
// The transform is meant to act out of place, i.e. not to mutate the input
// tensor; the actual computation is delegated to the normalize helper.
//
// Fields:
//   - mean: sequence of means, one per channel (expected in [0, 1]).
//   - std: sequence of standard deviations, one per channel (expected > 0
//     and <= 1).
type Normalize struct {
	mean []float64
	std  []float64
}
|
||||
|
||||
// normalizeOptions collects the configurable parameters of Normalize while
// functional options are being applied.
type normalizeOptions struct {
	mean []float64 // per-channel means
	std  []float64 // per-channel standard deviations
}
|
||||
|
||||
// normalizeOption is a functional option that mutates a normalizeOptions
// value.
type normalizeOption func(*normalizeOptions)
|
||||
|
||||
// Mean and SD can be calculated for specific dataset as follow:
|
||||
/*
|
||||
mean = 0.0
|
||||
meansq = 0.0
|
||||
count = 0
|
||||
|
||||
for index, data in enumerate(train_loader):
|
||||
mean = mean + data.sum()
|
||||
meansq = meansq + (data**2).sum()
|
||||
count += np.prod(data.shape)
|
||||
|
||||
total_mean = mean/count
|
||||
total_var = (meansq/count) - (total_mean**2)
|
||||
total_std = torch.sqrt(total_var)
|
||||
print("mean: " + str(total_mean))
|
||||
print("std: " + str(total_std))
|
||||
*/
|
||||
|
||||
// For example. ImageNet dataset has RGB mean and standard error:
|
||||
// meanVals := []float64{0.485, 0.456, 0.406}
|
||||
// sdVals := []float64{0.229, 0.224, 0.225}
|
||||
func defaultNormalizeOptions() *normalizeOptions {
|
||||
return &normalizeOptions{
|
||||
mean: []float64{0, 0, 0},
|
||||
std: []float64{1, 1, 1},
|
||||
}
|
||||
}
|
||||
|
||||
func WithNormalizeStd(std []float64) normalizeOption {
|
||||
return func(o *normalizeOptions) {
|
||||
o.std = std
|
||||
}
|
||||
}
|
||||
|
||||
func WithNormalizeMean(mean []float64) normalizeOption {
|
||||
return func(o *normalizeOptions) {
|
||||
o.mean = mean
|
||||
}
|
||||
}
|
||||
|
||||
func newNormalize(opts ...normalizeOption) *Normalize {
|
||||
p := defaultNormalizeOptions()
|
||||
for _, o := range opts {
|
||||
o(p)
|
||||
}
|
||||
|
||||
return &Normalize{
|
||||
mean: p.mean,
|
||||
std: p.std,
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Normalize) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
out := normalize(x, n.mean, n.std)
|
||||
return out
|
||||
}
|
||||
|
||||
func WithNormalize(opts ...normalizeOption) Option {
|
||||
n := newNormalize(opts...)
|
||||
return func(o *Options) {
|
||||
o.normalize = n
|
||||
}
|
||||
}
|
1
vision/aug/pad.go
Normal file
1
vision/aug/pad.go
Normal file
|
@ -0,0 +1 @@
|
|||
package aug
|
190
vision/aug/perspective.go
Normal file
190
vision/aug/perspective.go
Normal file
|
@ -0,0 +1,190 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
// "fmt"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// RandomPerspective performs a random perspective transformation of an image
// tensor.
//
// The input is expected to have shape [..., H, W], where ... stands for any
// number of leading dimensions.
//
// Fields:
//   - distortionScale: degree of distortion, in [0, 1]. Default 0.5.
//   - pvalue: probability of the image being transformed, in [0, 1].
//     Default 0.5.
//   - interpolationMode: interpolation passed to the perspective helper.
//     Default "bilinear"; torchvision supports only nearest and bilinear
//     for tensor input.
//   - fillValue: pixel fill value for the area outside the transformed
//     image. Default {0, 0, 0}.
type RandomPerspective struct {
	distortionScale   float64
	pvalue            float64
	interpolationMode string
	fillValue         []float64
}
|
||||
|
||||
// perspectiveOptions collects the configurable parameters of
// RandomPerspective while functional options are being applied.
type perspectiveOptions struct {
	distortionScale   float64 // expected in [0, 1]
	pvalue            float64 // expected in [0, 1]
	interpolationMode string
	fillValue         []float64
}
|
||||
|
||||
func defaultPerspectiveOptions() *perspectiveOptions {
|
||||
return &perspectiveOptions{
|
||||
distortionScale: 0.5,
|
||||
pvalue: 0.5,
|
||||
interpolationMode: "bilinear",
|
||||
fillValue: []float64{0.0, 0.0, 0.0},
|
||||
}
|
||||
}
|
||||
|
||||
// perspectiveOption is a functional option that mutates a perspectiveOptions
// value.
type perspectiveOption func(*perspectiveOptions)
|
||||
|
||||
func WithPerspectivePvalue(p float64) perspectiveOption {
|
||||
return func(o *perspectiveOptions) {
|
||||
o.pvalue = p
|
||||
}
|
||||
}
|
||||
|
||||
func WithPerspectiveScale(s float64) perspectiveOption {
|
||||
return func(o *perspectiveOptions) {
|
||||
o.distortionScale = s
|
||||
}
|
||||
}
|
||||
|
||||
func WithPerspectiveMode(m string) perspectiveOption {
|
||||
return func(o *perspectiveOptions) {
|
||||
o.interpolationMode = m
|
||||
}
|
||||
}
|
||||
|
||||
func WithPerspectiveValue(v []float64) perspectiveOption {
|
||||
return func(o *perspectiveOptions) {
|
||||
o.fillValue = v
|
||||
}
|
||||
}
|
||||
|
||||
func newRandomPerspective(opts ...perspectiveOption) *RandomPerspective {
|
||||
params := defaultPerspectiveOptions()
|
||||
for _, opt := range opts {
|
||||
opt(params)
|
||||
}
|
||||
|
||||
return &RandomPerspective{
|
||||
distortionScale: params.distortionScale,
|
||||
pvalue: params.pvalue,
|
||||
interpolationMode: params.interpolationMode,
|
||||
fillValue: params.fillValue,
|
||||
}
|
||||
}
|
||||
|
||||
// Get parameters for ``perspective`` for a random perspective transform.
|
||||
//
|
||||
// Args:
|
||||
// - width (int): width of the image.
|
||||
// - height (int): height of the image.
|
||||
// Returns:
|
||||
// - List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
|
||||
// - List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
|
||||
func (rp *RandomPerspective) getParams(w, h int64) ([][]int64, [][]int64) {
|
||||
halfH := h / 2
|
||||
halfW := w / 2
|
||||
|
||||
var (
|
||||
topLeft []int64
|
||||
topRight []int64
|
||||
bottomRight []int64
|
||||
bottomLeft []int64
|
||||
)
|
||||
|
||||
// topleft = [
|
||||
// int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
|
||||
// int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
|
||||
// ]
|
||||
tlVal1 := int64(rp.distortionScale*float64(halfW)) + 1
|
||||
tlTs1 := ts.MustRandint1(0, tlVal1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
tl1 := tlTs1.Int64Values()[0]
|
||||
tlTs1.MustDrop()
|
||||
tlVal2 := int64(rp.distortionScale*float64(halfH)) + 1
|
||||
tlTs2 := ts.MustRandint1(0, tlVal2, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
tl2 := tlTs2.Int64Values()[0]
|
||||
tlTs2.MustDrop()
|
||||
topLeft = []int64{tl1, tl2}
|
||||
|
||||
// topright = [
|
||||
// int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
|
||||
// int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
|
||||
// ]
|
||||
trVal1 := w - int64(rp.distortionScale*float64(halfW)) - 1
|
||||
trTs1 := ts.MustRandint1(trVal1, w, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
tr1 := trTs1.Int64Values()[0]
|
||||
trTs1.MustDrop()
|
||||
trVal2 := int64(rp.distortionScale*float64(halfH)) + 1
|
||||
trTs2 := ts.MustRandint1(0, trVal2, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
tr2 := trTs2.Int64Values()[0]
|
||||
trTs2.MustDrop()
|
||||
topRight = []int64{tr1, tr2}
|
||||
|
||||
// botright = [
|
||||
// int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
|
||||
// int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
|
||||
// ]
|
||||
brVal1 := w - int64(rp.distortionScale*float64(halfW)) - 1
|
||||
brTs1 := ts.MustRandint1(brVal1, w, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
br1 := brTs1.Int64Values()[0]
|
||||
brTs1.MustDrop()
|
||||
brVal2 := h - int64(rp.distortionScale*float64(halfH)) - 1
|
||||
brTs2 := ts.MustRandint1(brVal2, h, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
br2 := brTs2.Int64Values()[0]
|
||||
brTs2.MustDrop()
|
||||
bottomRight = []int64{br1, br2}
|
||||
|
||||
// botleft = [
|
||||
// int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
|
||||
// int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
|
||||
// ]
|
||||
blVal1 := int64(rp.distortionScale*float64(halfW)) + 1
|
||||
blTs1 := ts.MustRandint1(0, blVal1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
bl1 := blTs1.Int64Values()[0]
|
||||
blTs1.MustDrop()
|
||||
blVal2 := h - int64(rp.distortionScale*float64(halfH)) - 1
|
||||
blTs2 := ts.MustRandint1(blVal2, h, []int64{1}, gotch.Int64, gotch.CPU)
|
||||
bl2 := blTs2.Int64Values()[0]
|
||||
blTs2.MustDrop()
|
||||
bottomLeft = []int64{bl1, bl2}
|
||||
|
||||
startPoints := [][]int64{
|
||||
{0, 0},
|
||||
{w - 1, 0},
|
||||
{w - 1, h - 1},
|
||||
{0, h - 1},
|
||||
}
|
||||
|
||||
endPoints := [][]int64{
|
||||
topLeft,
|
||||
topRight,
|
||||
bottomRight,
|
||||
bottomLeft,
|
||||
}
|
||||
|
||||
return startPoints, endPoints
|
||||
}
|
||||
|
||||
func (rp *RandomPerspective) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
height, width := getImageSize(x)
|
||||
startPoints, endPoints := rp.getParams(height, width)
|
||||
out := perspective(x, startPoints, endPoints, rp.interpolationMode, rp.fillValue)
|
||||
return out
|
||||
}
|
||||
|
||||
func WithRandomPerspective(opts ...perspectiveOption) Option {
|
||||
rp := newRandomPerspective(opts...)
|
||||
return func(o *Options) {
|
||||
o.randomPerspective = rp
|
||||
}
|
||||
}
|
77
vision/aug/posterize.go
Normal file
77
vision/aug/posterize.go
Normal file
|
@ -0,0 +1,77 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// RandomPosterize posterizes an image tensor at random by reducing the
// number of bits for each color channel.
//
// Following torchvision, the input should be of type uint8 and have shape
// [..., 1 or 3, H, W], where ... stands for any number of leading
// dimensions.
//
// Fields:
//   - pvalue: threshold compared against the random draw in Forward; the
//     image is posterized when the draw is below it. Default 0.5.
//   - bits: number of bits to keep for each channel (0-8). Default 4.
//
// Ref. https://en.wikipedia.org/wiki/Posterization
type RandomPosterize struct {
	pvalue float64
	bits   uint8
}
|
||||
|
||||
// posterizeOptions collects the configurable parameters of RandomPosterize
// while functional options are being applied.
type posterizeOptions struct {
	pvalue float64 // threshold for applying the transform
	bits   uint8   // bits kept per channel
}
|
||||
|
||||
// posterizeOption is a functional option that mutates a posterizeOptions
// value.
type posterizeOption func(*posterizeOptions)
|
||||
|
||||
func defaultPosterizeOptions() *posterizeOptions {
|
||||
return &posterizeOptions{
|
||||
pvalue: 0.5,
|
||||
bits: 4,
|
||||
}
|
||||
}
|
||||
|
||||
func WithPosterizePvalue(p float64) posterizeOption {
|
||||
return func(o *posterizeOptions) {
|
||||
o.pvalue = p
|
||||
}
|
||||
}
|
||||
|
||||
func WithPosterizeBits(bits uint8) posterizeOption {
|
||||
return func(o *posterizeOptions) {
|
||||
o.bits = bits
|
||||
}
|
||||
}
|
||||
|
||||
func newRandomPosterize(opts ...posterizeOption) *RandomPosterize {
|
||||
p := defaultPosterizeOptions()
|
||||
for _, o := range opts {
|
||||
o(p)
|
||||
}
|
||||
|
||||
return &RandomPosterize{
|
||||
pvalue: p.pvalue,
|
||||
bits: p.bits,
|
||||
}
|
||||
}
|
||||
|
||||
func (rp *RandomPosterize) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
|
||||
r := randPvalue()
|
||||
var out *ts.Tensor
|
||||
switch {
|
||||
case r < rp.pvalue:
|
||||
out = posterize(x, rp.bits)
|
||||
default:
|
||||
out = x.MustShallowClone()
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func WithRandomPosterize(opts ...posterizeOption) Option {
|
||||
rp := newRandomPosterize(opts...)
|
||||
|
||||
return func(o *Options) {
|
||||
o.randomPosterize = rp
|
||||
}
|
||||
}
|
39
vision/aug/resize.go
Normal file
39
vision/aug/resize.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
// ResizeModule resizes an image tensor to a fixed height and width.
type ResizeModule struct {
	height int64 // target height
	width  int64 // target width
}
|
||||
|
||||
func newResizeModule(h, w int64) *ResizeModule {
|
||||
return &ResizeModule{h, w}
|
||||
}
|
||||
|
||||
// Forward implements ts.Module for RandRotateModule
|
||||
func (rs *ResizeModule) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
imgTs := x.MustTotype(gotch.Uint8, false)
|
||||
out, err := vision.Resize(imgTs, rs.width, rs.height)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
imgTs.MustDrop()
|
||||
return out
|
||||
}
|
||||
|
||||
func WithResize(h, w int64) Option {
|
||||
return func(o *Options) {
|
||||
rs := newResizeModule(h, w)
|
||||
o.resize = rs
|
||||
}
|
||||
}
|
||||
|
||||
// TODO.
|
||||
// RandomResizedCrop is a placeholder type; it has no fields and no behavior
// is implemented for it yet.
type RandomResizedCrop struct{}
|
109
vision/aug/rotate.go
Normal file
109
vision/aug/rotate.go
Normal file
|
@ -0,0 +1,109 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// RandomRotate randomly rotates a tensor image within a specifed angle range (degree).
|
||||
func RandomRotate(img *ts.Tensor, min, max float64) (*ts.Tensor, error) {
|
||||
if min > max {
|
||||
tmp := min
|
||||
min = max
|
||||
max = tmp
|
||||
}
|
||||
if min < -360 || min > 360 || max < -360 || max > 360 {
|
||||
err := fmt.Errorf("min and max should be in range from -360 to 360. Got %v and %v\n", min, max)
|
||||
return nil, err
|
||||
}
|
||||
// device := img.MustDevice()
|
||||
dtype := gotch.Double
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
angle := min + rand.Float64()*(max-min)
|
||||
|
||||
theta := float64(angle) * (math.Pi / 180)
|
||||
input := img.MustUnsqueeze(0, false).MustTotype(dtype, true)
|
||||
r, err := rotImg(input, theta, dtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
input.MustDrop()
|
||||
rotatedImg := r.MustSqueeze(true)
|
||||
return rotatedImg, nil
|
||||
}
|
||||
|
||||
func Rotate(img *ts.Tensor, angle float64) (*ts.Tensor, error) {
|
||||
if angle < -360 || angle > 360 {
|
||||
err := fmt.Errorf("angle must be in range (-360, 360)")
|
||||
return nil, err
|
||||
}
|
||||
dtype := gotch.Double
|
||||
theta := float64(angle) * (math.Pi / 180)
|
||||
input := img.MustUnsqueeze(0, false).MustTotype(dtype, true)
|
||||
r, err := rotImg(input, theta, dtype)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
input.MustDrop()
|
||||
rotatedImg := r.MustSqueeze(true)
|
||||
return rotatedImg, nil
|
||||
}
|
||||
|
||||
// RotateModule rotates an image tensor by a fixed angle.
type RotateModule struct {
	angle float64 // rotation angle in degrees
}
|
||||
|
||||
func newRotate(angle float64) *RotateModule {
|
||||
return &RotateModule{angle}
|
||||
}
|
||||
|
||||
// Forward implements ts.Module for RotateModule
|
||||
func (r *RotateModule) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
out, err := Rotate(x, r.angle)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func WithRotate(angle float64) Option {
|
||||
return func(o *Options) {
|
||||
r := newRotate(angle)
|
||||
o.rotate = r
|
||||
}
|
||||
}
|
||||
|
||||
// RandRotateModule rotates an image tensor by a random angle drawn between
// minAngle and maxAngle (degrees).
type RandRotateModule struct {
	minAngle float64
	maxAngle float64
}
|
||||
|
||||
func newRandRotate(min, max float64) *RandRotateModule {
|
||||
return &RandRotateModule{min, max}
|
||||
}
|
||||
|
||||
// Forward implements ts.Module for RandRotateModule
|
||||
func (rr *RandRotateModule) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
out, err := RandomRotate(x, rr.minAngle, rr.maxAngle)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func WithRandRotate(minAngle, maxAngle float64) Option {
|
||||
return func(o *Options) {
|
||||
r := newRandRotate(minAngle, maxAngle)
|
||||
o.randRotate = r
|
||||
}
|
||||
}
|
74
vision/aug/sharpness.go
Normal file
74
vision/aug/sharpness.go
Normal file
|
@ -0,0 +1,74 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// RandomAdjustSharpness adjusts the sharpness of an image tensor at random.
//
// The input is expected to have shape [..., 1 or 3, H, W], where ... stands
// for any number of leading dimensions.
//
// Fields:
//   - sharpnessFactor: how much to adjust the sharpness. Can be any
//     non-negative number: 0 gives a blurred image, 1 gives the original
//     image and 2 increases the sharpness by a factor of 2. Default 1.0.
//   - pvalue: threshold compared against the random draw in Forward; the
//     sharpness is adjusted when the draw is below it. Default 0.5.
type RandomAdjustSharpness struct {
	sharpnessFactor float64
	pvalue          float64
}
|
||||
|
||||
// sharpnessOptions collects the configurable parameters of
// RandomAdjustSharpness while functional options are being applied.
type sharpnessOptions struct {
	sharpnessFactor float64 // sharpness adjustment factor
	pvalue          float64 // threshold for applying the transform
}
|
||||
|
||||
// sharpnessOption is a functional option that mutates a sharpnessOptions
// value.
type sharpnessOption func(*sharpnessOptions)
|
||||
|
||||
func defaultSharpnessOptions() *sharpnessOptions {
|
||||
return &sharpnessOptions{
|
||||
sharpnessFactor: 1.0,
|
||||
pvalue: 0.5,
|
||||
}
|
||||
}
|
||||
|
||||
func WithSharpnessPvalue(p float64) sharpnessOption {
|
||||
return func(o *sharpnessOptions) {
|
||||
o.pvalue = p
|
||||
}
|
||||
}
|
||||
|
||||
func WithSharpnessFactor(f float64) sharpnessOption {
|
||||
return func(o *sharpnessOptions) {
|
||||
o.sharpnessFactor = f
|
||||
}
|
||||
}
|
||||
|
||||
func newRandomAdjustSharpness(opts ...sharpnessOption) *RandomAdjustSharpness {
|
||||
p := defaultSharpnessOptions()
|
||||
for _, o := range opts {
|
||||
o(p)
|
||||
}
|
||||
return &RandomAdjustSharpness{
|
||||
sharpnessFactor: p.sharpnessFactor,
|
||||
pvalue: p.pvalue,
|
||||
}
|
||||
}
|
||||
|
||||
func (ras *RandomAdjustSharpness) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
r := randPvalue()
|
||||
var out *ts.Tensor
|
||||
switch {
|
||||
case r < ras.pvalue:
|
||||
out = adjustSharpness(x, ras.sharpnessFactor)
|
||||
default:
|
||||
out = x.MustShallowClone()
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func WithRandomAdjustSharpness(opts ...sharpnessOption) Option {
|
||||
ras := newRandomAdjustSharpness(opts...)
|
||||
return func(o *Options) {
|
||||
o.randomAdjustSharpness = ras
|
||||
}
|
||||
}
|
79
vision/aug/solarize.go
Normal file
79
vision/aug/solarize.go
Normal file
|
@ -0,0 +1,79 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// RandomSolarize solarizes an image tensor at random by inverting all pixel
// values above a threshold.
//
// The input is expected to be in [..., 1 or 3, H, W] format, where ...
// stands for any number of leading dimensions.
//
// Fields:
//   - threshold: all pixels equal to or above this value are inverted.
//     Default 128.
//   - pvalue: threshold compared against the random draw in Forward; the
//     image is solarized when the draw is below it. Default 0.5.
//
// Ref. https://en.wikipedia.org/wiki/Solarization_(photography)
type RandomSolarize struct {
	threshold float64
	pvalue    float64
}
|
||||
|
||||
// solarizeOptions collects the configurable parameters of RandomSolarize
// while functional options are being applied.
type solarizeOptions struct {
	threshold float64 // pixel value at or above which inversion applies
	pvalue    float64 // threshold for applying the transform
}
|
||||
|
||||
// solarizeOption is a functional option that mutates a solarizeOptions
// value.
type solarizeOption func(*solarizeOptions)
|
||||
|
||||
func defaultSolarizeOptions() *solarizeOptions {
|
||||
return &solarizeOptions{
|
||||
threshold: 128,
|
||||
pvalue: 0.5,
|
||||
}
|
||||
}
|
||||
|
||||
func WithSolarizePvalue(p float64) solarizeOption {
|
||||
return func(o *solarizeOptions) {
|
||||
o.pvalue = p
|
||||
}
|
||||
}
|
||||
|
||||
func WithSolarizeThreshold(th float64) solarizeOption {
|
||||
return func(o *solarizeOptions) {
|
||||
o.threshold = th
|
||||
}
|
||||
}
|
||||
|
||||
func newRandomSolarize(opts ...solarizeOption) *RandomSolarize {
|
||||
params := defaultSolarizeOptions()
|
||||
|
||||
for _, o := range opts {
|
||||
o(params)
|
||||
}
|
||||
|
||||
return &RandomSolarize{
|
||||
threshold: params.threshold,
|
||||
pvalue: params.pvalue,
|
||||
}
|
||||
}
|
||||
|
||||
func (rs *RandomSolarize) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
r := randPvalue()
|
||||
|
||||
var out *ts.Tensor
|
||||
switch {
|
||||
case r < rs.pvalue:
|
||||
out = solarize(x, rs.threshold)
|
||||
default:
|
||||
out = x.MustShallowClone()
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func WithRandomSolarize(opts ...solarizeOption) Option {
|
||||
rs := newRandomSolarize(opts...)
|
||||
|
||||
return func(o *Options) {
|
||||
o.randomSolarize = rs
|
||||
}
|
||||
}
|
188
vision/aug/transform.go
Normal file
188
vision/aug/transform.go
Normal file
|
@ -0,0 +1,188 @@
|
|||
package aug
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/sugarme/gotch/nn"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// Transformer is an interface that can transform an image tensor.
|
||||
type Transformer interface {
|
||||
Transform(x *ts.Tensor) *ts.Tensor
|
||||
}
|
||||
|
||||
// Augment is a struct composes of augmentation functions to implement Transformer interface.
|
||||
type Augment struct {
|
||||
augments *nn.Sequential
|
||||
}
|
||||
|
||||
// Transform implements Transformer interface for Augment struct.
|
||||
func (a *Augment) Transform(image *ts.Tensor) *ts.Tensor {
|
||||
out := a.augments.Forward(image)
|
||||
return out
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
rotate *RotateModule
|
||||
randRotate *RandRotateModule
|
||||
resize *ResizeModule
|
||||
colorJitter *ColorJitter
|
||||
gaussianBlur *GaussianBlur
|
||||
randomHFlip *RandomHorizontalFlip
|
||||
randomVFlip *RandomVerticalFlip
|
||||
randomCrop *RandomCrop
|
||||
centerCrop *CenterCrop
|
||||
randomCutout *RandomCutout
|
||||
randomPerspective *RandomPerspective
|
||||
randomAffine *RandomAffine
|
||||
randomGrayscale *RandomGrayscale
|
||||
randomSolarize *RandomSolarize
|
||||
randomPosterize *RandomPosterize
|
||||
randomInvert *RandomInvert
|
||||
randomAutocontrast *RandomAutocontrast
|
||||
randomAdjustSharpness *RandomAdjustSharpness
|
||||
randomEqualize *RandomEqualize
|
||||
normalize *Normalize
|
||||
}
|
||||
|
||||
func defaultOption() *Options {
|
||||
return &Options{
|
||||
rotate: nil,
|
||||
randRotate: nil,
|
||||
resize: nil,
|
||||
colorJitter: nil,
|
||||
gaussianBlur: nil,
|
||||
randomHFlip: nil,
|
||||
randomVFlip: nil,
|
||||
randomCrop: nil,
|
||||
centerCrop: nil,
|
||||
randomCutout: nil,
|
||||
randomPerspective: nil,
|
||||
randomAffine: nil,
|
||||
randomGrayscale: nil,
|
||||
randomSolarize: nil,
|
||||
randomPosterize: nil,
|
||||
randomInvert: nil,
|
||||
randomAutocontrast: nil,
|
||||
randomAdjustSharpness: nil,
|
||||
randomEqualize: nil,
|
||||
normalize: nil,
|
||||
}
|
||||
}
|
||||
|
||||
// Option is a functional option that configures an Options value, typically
// by filling one augmentation slot.
type Option func(o *Options)
|
||||
|
||||
// Compose creates a new Augment struct by adding augmentation methods.
|
||||
func Compose(opts ...Option) (Transformer, error) {
|
||||
augOpts := defaultOption()
|
||||
for _, opt := range opts {
|
||||
if opt != nil {
|
||||
opt(augOpts)
|
||||
}
|
||||
}
|
||||
|
||||
var augs *nn.Sequential = nn.Seq()
|
||||
|
||||
if augOpts.rotate != nil {
|
||||
augs.Add(augOpts.rotate)
|
||||
}
|
||||
|
||||
if augOpts.randRotate != nil {
|
||||
augs.Add(augOpts.randRotate)
|
||||
}
|
||||
|
||||
if augOpts.resize != nil {
|
||||
augs.Add(augOpts.resize)
|
||||
}
|
||||
|
||||
if augOpts.colorJitter != nil {
|
||||
augs.Add(augOpts.colorJitter)
|
||||
}
|
||||
|
||||
if augOpts.gaussianBlur != nil {
|
||||
augs.Add(augOpts.gaussianBlur)
|
||||
}
|
||||
|
||||
if augOpts.randomHFlip != nil {
|
||||
augs.Add(augOpts.randomHFlip)
|
||||
}
|
||||
|
||||
if augOpts.randomVFlip != nil {
|
||||
augs.Add(augOpts.randomVFlip)
|
||||
}
|
||||
|
||||
if augOpts.randomCrop != nil {
|
||||
augs.Add(augOpts.randomCrop)
|
||||
}
|
||||
|
||||
if augOpts.centerCrop != nil {
|
||||
augs.Add(augOpts.centerCrop)
|
||||
}
|
||||
|
||||
if augOpts.randomCutout != nil {
|
||||
augs.Add(augOpts.randomCutout)
|
||||
}
|
||||
|
||||
if augOpts.randomPerspective != nil {
|
||||
augs.Add(augOpts.randomPerspective)
|
||||
}
|
||||
|
||||
if augOpts.randomAffine != nil {
|
||||
augs.Add(augOpts.randomAffine)
|
||||
}
|
||||
|
||||
if augOpts.randomGrayscale != nil {
|
||||
augs.Add(augOpts.randomGrayscale)
|
||||
}
|
||||
|
||||
if augOpts.randomSolarize != nil {
|
||||
augs.Add(augOpts.randomSolarize)
|
||||
}
|
||||
|
||||
if augOpts.randomPosterize != nil {
|
||||
augs.Add(augOpts.randomPosterize)
|
||||
}
|
||||
|
||||
if augOpts.randomInvert != nil {
|
||||
augs.Add(augOpts.randomInvert)
|
||||
}
|
||||
|
||||
if augOpts.randomAutocontrast != nil {
|
||||
augs.Add(augOpts.randomAutocontrast)
|
||||
}
|
||||
|
||||
if augOpts.randomAdjustSharpness != nil {
|
||||
augs.Add(augOpts.randomAdjustSharpness)
|
||||
}
|
||||
|
||||
if augOpts.randomEqualize != nil {
|
||||
augs.Add(augOpts.randomEqualize)
|
||||
}
|
||||
|
||||
if augOpts.normalize != nil {
|
||||
augs.Add(augOpts.normalize)
|
||||
}
|
||||
|
||||
return &Augment{augs}, nil
|
||||
}
|
||||
|
||||
// OneOf randomly return one transformer from list of transformers
|
||||
// with a specific p value.
|
||||
func OneOf(pvalue float64, tfOpts ...Option) Option {
|
||||
tfsNum := len(tfOpts)
|
||||
if tfsNum < 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
randP := randPvalue()
|
||||
if randP >= pvalue {
|
||||
return nil
|
||||
}
|
||||
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
idx := rand.Intn(tfsNum)
|
||||
|
||||
return tfOpts[idx]
|
||||
}
|
Loading…
Reference in New Issue
Block a user