added image augmentation and minor fixed on ts.Lstsq
This commit is contained in:
parent
fe6454c0ca
commit
7292c3575e
|
@ -6,9 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
## [Unreleased]
|
## [Unreleased]
|
||||||
|
|
||||||
|
## [0.3.9]
|
||||||
- [#24], [#26]: fixed memory leak.
|
- [#24], [#26]: fixed memory leak.
|
||||||
- [#30]: fixed varstore.Save() randomly panic - segmentfault
|
- [#30]: fixed varstore.Save() randomly panic - segmentfault
|
||||||
- [#32]: nn.Seq Forward return nil tensor if length of layers = 1
|
- [#32]: nn.Seq Forward return nil tensor if length of layers = 1
|
||||||
|
- [#36]: resolved image augmentation
|
||||||
|
|
||||||
## [0.3.8]
|
## [0.3.8]
|
||||||
|
|
||||||
|
|
31
example/augmentation/README.md
Normal file
31
example/augmentation/README.md
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
# Image Augmentation Example
|
||||||
|
|
||||||
|
This example demonstrates how to use image augmentation functions. It is implemented as similar as possible to [original Pytorch vision/transform](https://pytorch.org/vision/stable/transforms.html#).
|
||||||
|
|
||||||
|
There are 2 APIs (`aug.Compose` and `aug.OneOf`) to compose augmentation methods as shown in the example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
t, err := aug.Compose(
|
||||||
|
aug.WithRandomVFlip(0.5),
|
||||||
|
aug.WithRandomHFlip(0.5),
|
||||||
|
aug.WithRandomCutout(),
|
||||||
|
aug.OneOf(
|
||||||
|
0.3,
|
||||||
|
aug.WithColorJitter(0.3, 0.3, 0.3, 0.4),
|
||||||
|
aug.WithRandomGrayscale(1.0),
|
||||||
|
),
|
||||||
|
aug.OneOf(
|
||||||
|
0.3,
|
||||||
|
aug.WithGaussianBlur([]int64{5, 5}, []float64{1.0, 2.0}),
|
||||||
|
aug.WithRandomAffine(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out := t.Transform(imgTs)
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
BIN
example/augmentation/bb.png
Normal file
BIN
example/augmentation/bb.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 265 KiB |
69
example/augmentation/main.go
Normal file
69
example/augmentation/main.go
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
"github.com/sugarme/gotch/vision"
|
||||||
|
"github.com/sugarme/gotch/vision/aug"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
n := 360
|
||||||
|
for i := 1; i <= n; i++ {
|
||||||
|
img, err := vision.Load("./bb.png")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
device := gotch.CudaIfAvailable()
|
||||||
|
// device := gotch.CPU
|
||||||
|
imgTs := img.MustTo(device, true)
|
||||||
|
// t, err := aug.Compose(aug.WithResize(512, 512)) // NOTE. WithResize just works on CPU.
|
||||||
|
// t, err := aug.Compose(aug.WithRandRotate(0, 360), aug.WithColorJitter(0.3, 0.3, 0.3, 0.4))
|
||||||
|
// t, err := aug.Compose(aug.WithGaussianBlur([]int64{5, 5}, []float64{1.0, 2.0}), aug.WithRandRotate(0, 360), aug.WithColorJitter(0.3, 0.3, 0.3, 0.3))
|
||||||
|
// t, err := aug.Compose(aug.WithRandomCrop([]int64{320, 320}, []int64{10, 10}, true, "constant"))
|
||||||
|
// t, err := aug.Compose(aug.WithCenterCrop([]int64{320, 320}))
|
||||||
|
// t, err := aug.Compose(aug.WithRandomCutout(aug.WithCutoutValue([]int64{124, 96, 255}), aug.WithCutoutScale([]float64{0.01, 0.1}), aug.WithCutoutRatio([]float64{0.5, 0.5})))
|
||||||
|
// t, err := aug.Compose(aug.WithRandomPerspective(aug.WithPerspectiveScale(0.6), aug.WithPerspectivePvalue(0.8)))
|
||||||
|
// t, err := aug.Compose(aug.WithRandomAffine(aug.WithAffineDegree([]int64{0, 15}), aug.WithAffineShear([]float64{0, 15})))
|
||||||
|
// t, err := aug.Compose(aug.WithRandomGrayscale(0.5))
|
||||||
|
// t, err := aug.Compose(aug.WithRandomSolarize(aug.WithSolarizeThreshold(125), aug.WithSolarizePvalue(0.5)))
|
||||||
|
// t, err := aug.Compose(aug.WithRandomInvert(0.5))
|
||||||
|
// t, err := aug.Compose(aug.WithRandomPosterize(aug.WithPosterizeBits(2), aug.WithPosterizePvalue(1.0)))
|
||||||
|
// t, err := aug.Compose(aug.WithRandomAutocontrast())
|
||||||
|
// t, err := aug.Compose(aug.WithRandomAdjustSharpness(aug.WithSharpnessPvalue(0.3), aug.WithSharpnessFactor(10)))
|
||||||
|
// t, err := aug.Compose(aug.WithRandomEqualize(1.0))
|
||||||
|
// t, err := aug.Compose(aug.WithNormalize(aug.WithNormalizeMean([]float64{0.485, 0.456, 0.406}), aug.WithNormalizeStd([]float64{0.229, 0.224, 0.225})))
|
||||||
|
|
||||||
|
t, err := aug.Compose(
|
||||||
|
aug.WithRandomVFlip(0.5),
|
||||||
|
aug.WithRandomHFlip(0.5),
|
||||||
|
aug.WithRandomCutout(),
|
||||||
|
aug.OneOf(
|
||||||
|
0.3,
|
||||||
|
aug.WithColorJitter(0.3, 0.3, 0.3, 0.4),
|
||||||
|
aug.WithRandomGrayscale(1.0),
|
||||||
|
),
|
||||||
|
aug.OneOf(
|
||||||
|
0.3,
|
||||||
|
aug.WithGaussianBlur([]int64{5, 5}, []float64{1.0, 2.0}),
|
||||||
|
aug.WithRandomAffine(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
out := t.Transform(imgTs)
|
||||||
|
fname := fmt.Sprintf("./output/bb-%03d.png", i)
|
||||||
|
err = vision.Save(out, fname)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
imgTs.MustDrop()
|
||||||
|
out.MustDrop()
|
||||||
|
|
||||||
|
fmt.Printf("%03d/%v completed.\n", i, n)
|
||||||
|
}
|
||||||
|
}
|
3
example/augmentation/output/.gitignore
vendored
Normal file
3
example/augmentation/output/.gitignore
vendored
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
*
|
||||||
|
!.gitignore
|
||||||
|
!README.md
|
1
example/augmentation/output/README.md
Normal file
1
example/augmentation/output/README.md
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Output images will be here.
|
|
@ -581,7 +581,7 @@ func (ts *Tensor) Lstsq(a *Tensor, del bool) (retVal *Tensor, err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts *Tensor) MustLstsq(a *Tensor, del bool) (retVal *Tensor) {
|
func (ts *Tensor) MustLstsq(a *Tensor, del bool) (retVal *Tensor) {
|
||||||
retVal, err := ts.Lstsq(del)
|
retVal, err := ts.Lstsq(a, del)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
185
vision/aug/affine.go
Normal file
185
vision/aug/affine.go
Normal file
|
@ -0,0 +1,185 @@
|
||||||
|
package aug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RandomAffine is transformation of the image keeping center invariant.
|
||||||
|
// If the image is torch Tensor, it is expected
|
||||||
|
// to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
|
||||||
|
// Args:
|
||||||
|
// - degrees (sequence or number): Range of degrees to select from.
|
||||||
|
// If degrees is a number instead of sequence like (min, max), the range of degrees
|
||||||
|
// will be (-degrees, +degrees). Set to 0 to deactivate rotations.
|
||||||
|
// - translate (tuple, optional): tuple of maximum absolute fraction for horizontal
|
||||||
|
// and vertical translations. For example translate=(a, b), then horizontal shift
|
||||||
|
// is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
|
||||||
|
// randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
|
||||||
|
// - scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
|
||||||
|
// randomly sampled from the range a <= scale <= b. Will keep original scale by default.
|
||||||
|
// - shear (sequence or number, optional): Range of degrees to select from.
|
||||||
|
// If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
|
||||||
|
// will be applied. Else if shear is a sequence of 2 values a shear parallel to the x axis in the
|
||||||
|
// range (shear[0], shear[1]) will be applied. Else if shear is a sequence of 4 values,
|
||||||
|
// a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
|
||||||
|
// Will not apply shear by default.
|
||||||
|
// - interpolation (InterpolationMode): Desired interpolation enum defined by
|
||||||
|
// :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
|
||||||
|
// If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
|
||||||
|
// For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
|
||||||
|
// - fill (sequence or number): Pixel fill value for the area outside the transformed
|
||||||
|
// image. Default is ``0``. If given a number, the value is used for all bands respectively.
|
||||||
|
// Please use the ``interpolation`` parameter instead.
|
||||||
|
// .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
|
||||||
|
type RandomAffine struct {
|
||||||
|
degree []int64 // degree range
|
||||||
|
translate []float64
|
||||||
|
scale []float64 // scale range
|
||||||
|
shear []float64
|
||||||
|
interpolationMode string
|
||||||
|
fillValue []float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ra *RandomAffine) getParams(imageSize []int64) (float64, []int64, float64, []float64) {
|
||||||
|
angleTs := ts.MustEmpty([]int64{1}, gotch.Float, gotch.CPU)
|
||||||
|
angleTs.MustUniform_(float64(ra.degree[0]), float64(ra.degree[1]))
|
||||||
|
angle := angleTs.Float64Values()[0]
|
||||||
|
angleTs.MustDrop()
|
||||||
|
|
||||||
|
var translations []int64 = []int64{0, 0}
|
||||||
|
if ra.translate != nil {
|
||||||
|
maxDX := ra.translate[0] * float64(imageSize[0])
|
||||||
|
maxDY := ra.translate[1] * float64(imageSize[1])
|
||||||
|
dx := ts.MustEmpty([]int64{1}, gotch.Float, gotch.CPU)
|
||||||
|
dx.MustUniform_(-maxDX, maxDX)
|
||||||
|
tx := dx.Float64Values()[0]
|
||||||
|
dx.MustDrop()
|
||||||
|
|
||||||
|
dy := ts.MustEmpty([]int64{1}, gotch.Float, gotch.CPU)
|
||||||
|
dy.MustUniform_(-maxDY, maxDY)
|
||||||
|
ty := dx.Float64Values()[0]
|
||||||
|
dy.MustDrop()
|
||||||
|
|
||||||
|
translations = []int64{int64(tx), int64(ty)} // should we use math.Round here???
|
||||||
|
}
|
||||||
|
|
||||||
|
scale := 1.0
|
||||||
|
if ra.scale != nil {
|
||||||
|
scaleTs := ts.MustEmpty([]int64{1}, gotch.Float, gotch.CPU)
|
||||||
|
scaleTs.MustUniform_(ra.scale[0], ra.scale[1])
|
||||||
|
scale = scaleTs.Float64Values()[0]
|
||||||
|
scaleTs.MustDrop()
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
shearX, shearY float64 = 0.0, 0.0
|
||||||
|
)
|
||||||
|
if ra.shear != nil {
|
||||||
|
shearXTs := ts.MustEmpty([]int64{1}, gotch.Float, gotch.CPU)
|
||||||
|
shearXTs.MustUniform_(ra.shear[0], ra.shear[1])
|
||||||
|
shearX = shearXTs.Float64Values()[0]
|
||||||
|
shearXTs.MustDrop()
|
||||||
|
|
||||||
|
if len(ra.shear) == 4 {
|
||||||
|
shearYTs := ts.MustEmpty([]int64{1}, gotch.Float, gotch.CPU)
|
||||||
|
shearYTs.MustUniform_(ra.shear[2], ra.shear[3])
|
||||||
|
shearY = shearYTs.Float64Values()[0]
|
||||||
|
shearYTs.MustDrop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var shear []float64 = []float64{shearX, shearY}
|
||||||
|
|
||||||
|
return angle, translations, scale, shear
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ra *RandomAffine) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
w, h := getImageSize(x)
|
||||||
|
angle, translations, scale, shear := ra.getParams([]int64{w, h})
|
||||||
|
|
||||||
|
out := affine(x, angle, translations, scale, shear, ra.interpolationMode, ra.fillValue)
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRandomAffine(opts ...affineOption) *RandomAffine {
|
||||||
|
p := defaultAffineOptions()
|
||||||
|
for _, o := range opts {
|
||||||
|
o(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RandomAffine{
|
||||||
|
degree: p.degree,
|
||||||
|
translate: p.translate,
|
||||||
|
scale: p.scale,
|
||||||
|
shear: p.shear,
|
||||||
|
interpolationMode: p.interpolationMode,
|
||||||
|
fillValue: p.fillValue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type affineOptions struct {
|
||||||
|
degree []int64
|
||||||
|
translate []float64
|
||||||
|
scale []float64
|
||||||
|
shear []float64
|
||||||
|
interpolationMode string
|
||||||
|
fillValue []float64
|
||||||
|
}
|
||||||
|
|
||||||
|
type affineOption func(*affineOptions)
|
||||||
|
|
||||||
|
func defaultAffineOptions() *affineOptions {
|
||||||
|
return &affineOptions{
|
||||||
|
degree: []int64{-180, 180},
|
||||||
|
translate: nil,
|
||||||
|
scale: nil,
|
||||||
|
shear: []float64{-180.0, 180.0},
|
||||||
|
interpolationMode: "bilinear",
|
||||||
|
fillValue: []float64{0.0, 0.0, 0.0},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithAffineDegree(degree []int64) affineOption {
|
||||||
|
return func(o *affineOptions) {
|
||||||
|
o.degree = degree
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithAffineTranslate(translate []float64) affineOption {
|
||||||
|
return func(o *affineOptions) {
|
||||||
|
o.translate = translate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithAffineScale(scale []float64) affineOption {
|
||||||
|
return func(o *affineOptions) {
|
||||||
|
o.scale = scale
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithAffineShear(shear []float64) affineOption {
|
||||||
|
return func(o *affineOptions) {
|
||||||
|
o.shear = shear
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithAffineMode(mode string) affineOption {
|
||||||
|
return func(o *affineOptions) {
|
||||||
|
o.interpolationMode = mode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithAffineFillValue(fillValue []float64) affineOption {
|
||||||
|
return func(o *affineOptions) {
|
||||||
|
o.fillValue = fillValue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRandomAffine(opts ...affineOption) Option {
|
||||||
|
ra := newRandomAffine(opts...)
|
||||||
|
return func(o *Options) {
|
||||||
|
o.randomAffine = ra
|
||||||
|
}
|
||||||
|
}
|
89
vision/aug/blur.go
Normal file
89
vision/aug/blur.go
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
package aug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GaussianBlur struct {
|
||||||
|
kernelSize []int64 // >= 0 && ks%2 != 0
|
||||||
|
sigma []float64 // [0.1, 2.0] range(min, max)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ks : kernal size. Can be 1-2 element slice
|
||||||
|
// sigma: minimal and maximal standard deviation that can be chosen for blurring kernel
|
||||||
|
// range (min, max). Can be 1-2 element slice
|
||||||
|
func newGaussianBlur(ks []int64, sig []float64) *GaussianBlur {
|
||||||
|
if len(ks) == 0 || len(ks) > 2 {
|
||||||
|
err := fmt.Errorf("Kernel size should have 1-2 elements. Got %v\n", len(ks))
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
for _, size := range ks {
|
||||||
|
if size <= 0 || size%2 == 0 {
|
||||||
|
err := fmt.Errorf("Kernel size should be an odd and positive number.")
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(sig) == 0 || len(sig) > 2 {
|
||||||
|
err := fmt.Errorf("Sigma should have 1-2 elements. Got %v\n", len(sig))
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, s := range sig {
|
||||||
|
if s <= 0 {
|
||||||
|
err := fmt.Errorf("Sigma should be a positive number.")
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var kernelSize []int64
|
||||||
|
switch len(ks) {
|
||||||
|
case 1:
|
||||||
|
kernelSize = []int64{ks[0], ks[0]}
|
||||||
|
case 2:
|
||||||
|
kernelSize = ks
|
||||||
|
default:
|
||||||
|
panic("Shouldn't reach here.")
|
||||||
|
}
|
||||||
|
|
||||||
|
var sigma []float64
|
||||||
|
switch len(sig) {
|
||||||
|
case 1:
|
||||||
|
sigma = []float64{sig[0], sig[0]}
|
||||||
|
case 2:
|
||||||
|
min := sig[0]
|
||||||
|
max := sig[1]
|
||||||
|
if min > max {
|
||||||
|
min = sig[1]
|
||||||
|
max = sig[0]
|
||||||
|
}
|
||||||
|
sigma = []float64{min, max}
|
||||||
|
default:
|
||||||
|
panic("Shouldn't reach here.")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &GaussianBlur{
|
||||||
|
kernelSize: kernelSize,
|
||||||
|
sigma: sigma,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *GaussianBlur) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
sigmaTs := ts.MustEmpty([]int64{1}, gotch.Float, gotch.CPU)
|
||||||
|
sigmaTs.MustUniform_(b.sigma[0], b.sigma[1])
|
||||||
|
sigmaVal := sigmaTs.Float64Values()[0]
|
||||||
|
sigmaTs.MustDrop()
|
||||||
|
|
||||||
|
return gaussianBlur(x, b.kernelSize, []float64{sigmaVal, sigmaVal})
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithGaussianBlur(ks []int64, sig []float64) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
gb := newGaussianBlur(ks, sig)
|
||||||
|
o.gaussianBlur = gb
|
||||||
|
}
|
||||||
|
}
|
77
vision/aug/color.go
Normal file
77
vision/aug/color.go
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
package aug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Ref. https://github.com/pytorch/vision/blob/f1d734213af65dc06e777877d315973ba8386080/torchvision/transforms/functional_tensor.py
|
||||||
|
|
||||||
|
type ColorJitter struct {
|
||||||
|
brightness float64
|
||||||
|
contrast float64
|
||||||
|
saturation float64
|
||||||
|
hue float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultColorJitter() *ColorJitter {
|
||||||
|
return &ColorJitter{
|
||||||
|
brightness: 1.0,
|
||||||
|
contrast: 1.0,
|
||||||
|
saturation: 1.0,
|
||||||
|
hue: 0.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ColorJitter) setBrightness(brightness float64) {
|
||||||
|
c.brightness = brightness
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ColorJitter) setContrast(contrast float64) {
|
||||||
|
c.contrast = contrast
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ColorJitter) setSaturation(sat float64) {
|
||||||
|
c.saturation = sat
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ColorJitter) setHue(hue float64) {
|
||||||
|
c.hue = hue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward implement ts.Module by randomly picking one of brightness, contrast,
|
||||||
|
// staturation or hue function to transform input image tensor.
|
||||||
|
func (c *ColorJitter) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
idx := rand.Intn(4)
|
||||||
|
switch idx {
|
||||||
|
case 0:
|
||||||
|
v := randVal(getMinMax(c.brightness))
|
||||||
|
return adjustBrightness(x, v)
|
||||||
|
case 1:
|
||||||
|
v := randVal(getMinMax(c.contrast))
|
||||||
|
return adjustContrast(x, v)
|
||||||
|
case 2:
|
||||||
|
v := randVal(getMinMax(c.saturation))
|
||||||
|
return adjustSaturation(x, v)
|
||||||
|
case 3:
|
||||||
|
v := randVal(0, c.hue)
|
||||||
|
return adjustHue(x, v)
|
||||||
|
default:
|
||||||
|
panic("Shouldn't reach here.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithColorJitter(brightness, contrast, sat, hue float64) Option {
|
||||||
|
c := defaultColorJitter()
|
||||||
|
c.setBrightness(brightness)
|
||||||
|
c.setContrast(contrast)
|
||||||
|
c.setSaturation(sat)
|
||||||
|
c.setHue(hue)
|
||||||
|
|
||||||
|
return func(o *Options) {
|
||||||
|
o.colorJitter = c
|
||||||
|
}
|
||||||
|
}
|
43
vision/aug/contrast.go
Normal file
43
vision/aug/contrast.go
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
package aug
|
||||||
|
|
||||||
|
import (
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RandomAutocontrast autocontrasts the pixels of the given image randomly with a given probability.
|
||||||
|
// If the image is torch Tensor, it is expected
|
||||||
|
// to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
|
||||||
|
// Args:
|
||||||
|
// - p (float): probability of the image being autocontrasted. Default value is 0.5
|
||||||
|
type RandomAutocontrast struct {
|
||||||
|
pvalue float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRandomAutocontrast(pOpt ...float64) *RandomAutocontrast {
|
||||||
|
p := 0.5
|
||||||
|
if len(pOpt) > 0 {
|
||||||
|
p = pOpt[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RandomAutocontrast{p}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rac *RandomAutocontrast) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
r := randPvalue()
|
||||||
|
var out *ts.Tensor
|
||||||
|
switch {
|
||||||
|
case r < rac.pvalue:
|
||||||
|
out = autocontrast(x)
|
||||||
|
default:
|
||||||
|
out = x.MustShallowClone()
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRandomAutocontrast(p ...float64) Option {
|
||||||
|
rac := newRandomAutocontrast(p...)
|
||||||
|
return func(o *Options) {
|
||||||
|
o.randomAutocontrast = rac
|
||||||
|
}
|
||||||
|
}
|
124
vision/aug/crop.go
Normal file
124
vision/aug/crop.go
Normal file
|
@ -0,0 +1,124 @@
|
||||||
|
package aug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
// "math"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RandomCrop struct {
|
||||||
|
size []int64
|
||||||
|
padding []int64
|
||||||
|
paddingIfNeeded bool
|
||||||
|
paddingMode string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRandomCrop(size, padding []int64, paddingIfNeeded bool, paddingMode string) *RandomCrop {
|
||||||
|
return &RandomCrop{
|
||||||
|
size: size,
|
||||||
|
padding: padding,
|
||||||
|
paddingIfNeeded: paddingIfNeeded,
|
||||||
|
paddingMode: paddingMode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// get parameters for crop
|
||||||
|
func (c *RandomCrop) params(x *ts.Tensor) (int64, int64, int64, int64) {
|
||||||
|
w, h := getImageSize(x)
|
||||||
|
th, tw := c.size[0], c.size[1]
|
||||||
|
if h+1 < th || w+1 < tw {
|
||||||
|
err := fmt.Errorf("Required crop size %v is larger then input image size %v", c.size, []int64{h, w})
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if w == tw && h == th {
|
||||||
|
return 0, 0, h, w
|
||||||
|
}
|
||||||
|
|
||||||
|
iTs := ts.MustRandint1(0, h-th+1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||||
|
i := iTs.Int64Values()[0]
|
||||||
|
iTs.MustDrop()
|
||||||
|
|
||||||
|
jTs := ts.MustRandint1(0, w-tw+1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||||
|
j := jTs.Int64Values()[0]
|
||||||
|
jTs.MustDrop()
|
||||||
|
|
||||||
|
return i, j, th, tw
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RandomCrop) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
var img *ts.Tensor
|
||||||
|
if c.padding != nil {
|
||||||
|
img = pad(x, c.padding, c.paddingMode)
|
||||||
|
} else {
|
||||||
|
img = x.MustShallowClone()
|
||||||
|
}
|
||||||
|
|
||||||
|
w, h := getImageSize(x)
|
||||||
|
|
||||||
|
var (
|
||||||
|
paddedW *ts.Tensor
|
||||||
|
paddedWH *ts.Tensor
|
||||||
|
)
|
||||||
|
// pad width if needed
|
||||||
|
if c.paddingIfNeeded && w < c.size[1] {
|
||||||
|
padding := []int64{c.size[1] - w, 0}
|
||||||
|
paddedW = pad(img, padding, c.paddingMode)
|
||||||
|
} else {
|
||||||
|
paddedW = img.MustShallowClone()
|
||||||
|
}
|
||||||
|
img.MustDrop()
|
||||||
|
|
||||||
|
// pad height if needed
|
||||||
|
if c.paddingIfNeeded && h < c.size[0] {
|
||||||
|
padding := []int64{0, c.size[0] - h}
|
||||||
|
paddedWH = pad(paddedW, padding, c.paddingMode)
|
||||||
|
} else {
|
||||||
|
paddedWH = paddedW.MustShallowClone()
|
||||||
|
}
|
||||||
|
|
||||||
|
paddedW.MustDrop()
|
||||||
|
|
||||||
|
// i, j, h, w = self.get_params(img, self.size)
|
||||||
|
i, j, h, w := c.params(x)
|
||||||
|
out := crop(paddedWH, i, j, h, w)
|
||||||
|
paddedWH.MustDrop()
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRandomCrop(size []int64, padding []int64, paddingIfNeeded bool, paddingMode string) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
c := newRandomCrop(size, padding, paddingIfNeeded, paddingMode)
|
||||||
|
o.randomCrop = c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CenterCrop crops the given image at the center.
|
||||||
|
// If the image is torch Tensor, it is expected
|
||||||
|
// to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
|
||||||
|
// If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
|
||||||
|
type CenterCrop struct {
|
||||||
|
size []int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newCenterCrop(size []int64) *CenterCrop {
|
||||||
|
if len(size) != 2 {
|
||||||
|
err := fmt.Errorf("Expected size of 2 elements. Got %v\n", len(size))
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
return &CenterCrop{size}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (cc *CenterCrop) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
return centerCrop(x, cc.size)
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithCenterCrop(size []int64) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
cc := newCenterCrop(size)
|
||||||
|
o.centerCrop = cc
|
||||||
|
}
|
||||||
|
}
|
177
vision/aug/cutout.go
Normal file
177
vision/aug/cutout.go
Normal file
|
@ -0,0 +1,177 @@
|
||||||
|
package aug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Randomly selects a rectangle region in an torch Tensor image and erases its pixels.
|
||||||
|
// This transform does not support PIL Image.
|
||||||
|
// 'Random Erasing Data Augmentation' by Zhong et al. See https://arxiv.org/abs/1708.04896
|
||||||
|
//
|
||||||
|
// Args:
|
||||||
|
// p: probability that the random erasing operation will be performed.
|
||||||
|
// scale: range of proportion of erased area against input image.
|
||||||
|
// ratio: range of aspect ratio of erased area.
|
||||||
|
// value: erasing value. Default is 0. If a single int, it is used to
|
||||||
|
// erase all pixels. If a tuple of length 3, it is used to erase
|
||||||
|
// R, G, B channels respectively.
|
||||||
|
// If a str of 'random', erasing each pixel with random values.
|
||||||
|
type RandomCutout struct {
|
||||||
|
pvalue float64
|
||||||
|
scale []float64
|
||||||
|
ratio []float64
|
||||||
|
rgbVal []int64 // RGB value
|
||||||
|
}
|
||||||
|
|
||||||
|
type cutoutOptions struct {
|
||||||
|
pvalue float64
|
||||||
|
scale []float64
|
||||||
|
ratio []float64
|
||||||
|
rgbVal []int64 // RGB value
|
||||||
|
}
|
||||||
|
|
||||||
|
type cutoutOption func(o *cutoutOptions)
|
||||||
|
|
||||||
|
func defaultCutoutOptions() *cutoutOptions {
|
||||||
|
return &cutoutOptions{
|
||||||
|
pvalue: 0.5,
|
||||||
|
scale: []float64{0.02, 0.33},
|
||||||
|
ratio: []float64{0.3, 3.3},
|
||||||
|
rgbVal: []int64{0, 0, 0},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRandomCutout(pvalue float64, scale, ratio []float64, rgbVal []int64) *RandomCutout {
|
||||||
|
return &RandomCutout{
|
||||||
|
pvalue: pvalue,
|
||||||
|
scale: scale,
|
||||||
|
ratio: ratio,
|
||||||
|
rgbVal: rgbVal,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithCutoutPvalue(p float64) cutoutOption {
|
||||||
|
if p < 0 || p > 1 {
|
||||||
|
log.Fatalf("Cutout p-value must be in range from 0 to 1. Got %v\n", p)
|
||||||
|
}
|
||||||
|
return func(o *cutoutOptions) {
|
||||||
|
o.pvalue = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithCutoutScale(scale []float64) cutoutOption {
|
||||||
|
if len(scale) != 2 {
|
||||||
|
log.Fatalf("Cutout scale should be in a range of 2 elments. Got %v elements\n", len(scale))
|
||||||
|
}
|
||||||
|
return func(o *cutoutOptions) {
|
||||||
|
o.scale = scale
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithCutoutRatio(ratio []float64) cutoutOption {
|
||||||
|
if len(ratio) != 2 {
|
||||||
|
log.Fatalf("Cutout ratio should be in a range of 2 elments. Got %v elements\n", len(ratio))
|
||||||
|
}
|
||||||
|
return func(o *cutoutOptions) {
|
||||||
|
o.ratio = ratio
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithCutoutValue(rgb []int64) cutoutOption {
|
||||||
|
var rgbVal []int64
|
||||||
|
switch len(rgb) {
|
||||||
|
case 1:
|
||||||
|
rgbVal = []int64{rgb[0], rgb[0], rgb[0]}
|
||||||
|
case 3:
|
||||||
|
rgbVal = rgb
|
||||||
|
default:
|
||||||
|
err := fmt.Errorf("Cutout values can be single value or 3-element (RGB) value. Got %v values.", len(rgb))
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
return func(o *cutoutOptions) {
|
||||||
|
o.rgbVal = rgbVal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *RandomCutout) cutoutParams(x *ts.Tensor) (int64, int64, int64, int64, *ts.Tensor) {
|
||||||
|
dim := x.MustSize()
|
||||||
|
|
||||||
|
imgH, imgW := dim[len(dim)-2], dim[len(dim)-1]
|
||||||
|
area := float64(imgH * imgW)
|
||||||
|
logRatio := ts.MustOfSlice(rc.ratio).MustLog(true).Float64Values()
|
||||||
|
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
scaleTs := ts.MustEmpty([]int64{1}, gotch.Float, gotch.CPU)
|
||||||
|
scaleTs.MustUniform_(rc.scale[0], rc.scale[1])
|
||||||
|
scaleVal := scaleTs.Float64Values()[0]
|
||||||
|
scaleTs.MustDrop()
|
||||||
|
eraseArea := area * scaleVal
|
||||||
|
|
||||||
|
ratioTs := ts.MustEmpty([]int64{1}, gotch.Float, gotch.CPU)
|
||||||
|
ratioTs.MustUniform_(logRatio[0], logRatio[1])
|
||||||
|
asTs := ratioTs.MustExp(true)
|
||||||
|
asVal := asTs.Float64Values()[0] // aspect ratio
|
||||||
|
asTs.MustDrop()
|
||||||
|
|
||||||
|
// h = int(round(math.sqrt(erase_area * aspect_ratio)))
|
||||||
|
// w = int(round(math.sqrt(erase_area / aspect_ratio)))
|
||||||
|
h := int64(math.Round(math.Sqrt(eraseArea * asVal)))
|
||||||
|
w := int64(math.Round(math.Sqrt(eraseArea / asVal)))
|
||||||
|
if !(h < imgH && w < imgW) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// v = torch.tensor(value)[:, None, None]
|
||||||
|
v := ts.MustOfSlice(rc.rgbVal).MustUnsqueeze(1, true).MustUnsqueeze(1, true)
|
||||||
|
|
||||||
|
// i = torch.randint(0, img_h - h + 1, size=(1, )).item()
|
||||||
|
iTs := ts.MustRandint1(0, imgH-h+1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||||
|
i := iTs.Int64Values()[0]
|
||||||
|
iTs.MustDrop()
|
||||||
|
// j = torch.randint(0, img_w - w + 1, size=(1, )).item()
|
||||||
|
jTs := ts.MustRandint1(0, imgW-w+1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||||
|
j := jTs.Int64Values()[0]
|
||||||
|
jTs.MustDrop()
|
||||||
|
return i, j, h, w, v
|
||||||
|
}
|
||||||
|
|
||||||
|
// return original image
|
||||||
|
img := x.MustShallowClone()
|
||||||
|
return 0, 0, imgH, imgW, img
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rc *RandomCutout) Forward(img *ts.Tensor) *ts.Tensor {
|
||||||
|
randTs := ts.MustRandn([]int64{1}, gotch.Float, gotch.CPU)
|
||||||
|
randVal := randTs.Float64Values()[0]
|
||||||
|
randTs.MustDrop()
|
||||||
|
|
||||||
|
switch randVal < rc.pvalue {
|
||||||
|
case true:
|
||||||
|
x, y, h, w, v := rc.cutoutParams(img)
|
||||||
|
out := cutout(img, x, y, h, w, rc.rgbVal)
|
||||||
|
v.MustDrop()
|
||||||
|
return out
|
||||||
|
case false:
|
||||||
|
out := img.MustShallowClone()
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
panic("Shouldn't reach here")
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRandomCutout(opts ...cutoutOption) Option {
|
||||||
|
params := defaultCutoutOptions()
|
||||||
|
for _, o := range opts {
|
||||||
|
o(params)
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(o *Options) {
|
||||||
|
rc := newRandomCutout(params.pvalue, params.scale, params.ratio, params.rgbVal)
|
||||||
|
o.randomCutout = rc
|
||||||
|
}
|
||||||
|
}
|
46
vision/aug/equalize.go
Normal file
46
vision/aug/equalize.go
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
package aug
|
||||||
|
|
||||||
|
import (
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RandomEqualize equalizes the histogram of the given image randomly with a given probability.
|
||||||
|
// If the image is torch Tensor, it is expected
|
||||||
|
// to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
|
||||||
|
// Args:
|
||||||
|
// - p (float): probability of the image being equalized. Default value is 0.5
|
||||||
|
// Histogram equalization
|
||||||
|
// Ref. https://en.wikipedia.org/wiki/Histogram_equalization
|
||||||
|
type RandomEqualize struct {
|
||||||
|
pvalue float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRandomEqualize(pOpt ...float64) *RandomEqualize {
|
||||||
|
p := 0.5
|
||||||
|
if len(pOpt) > 0 {
|
||||||
|
p = pOpt[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RandomEqualize{p}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (re *RandomEqualize) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
r := randPvalue()
|
||||||
|
|
||||||
|
var out *ts.Tensor
|
||||||
|
switch {
|
||||||
|
case r < re.pvalue:
|
||||||
|
out = equalize(x)
|
||||||
|
default:
|
||||||
|
out = x.MustShallowClone()
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRandomEqualize(p ...float64) Option {
|
||||||
|
re := newRandomEqualize(p...)
|
||||||
|
return func(o *Options) {
|
||||||
|
o.randomEqualize = re
|
||||||
|
}
|
||||||
|
}
|
78
vision/aug/flip.go
Normal file
78
vision/aug/flip.go
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
package aug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RandomHorizontalFlip horizontally flips the given image randomly with a given probability.
|
||||||
|
//
|
||||||
|
// If the image is torch Tensor, it is expected to have [..., H, W] shape,
|
||||||
|
// where ... means an arbitrary number of leading dimensions
|
||||||
|
// Args:
|
||||||
|
// p (float): probability of the image being flipped. Default value is 0.5
|
||||||
|
type RandomHorizontalFlip struct {
|
||||||
|
pvalue float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRandomHorizontalFlip(pvalue float64) *RandomHorizontalFlip {
|
||||||
|
return &RandomHorizontalFlip{
|
||||||
|
pvalue: pvalue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (hf *RandomHorizontalFlip) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
randTs := ts.MustRandn([]int64{1}, gotch.Float, gotch.CPU)
|
||||||
|
randVal := randTs.Float64Values()[0]
|
||||||
|
randTs.MustDrop()
|
||||||
|
switch {
|
||||||
|
case randVal < hf.pvalue:
|
||||||
|
return hflip(x)
|
||||||
|
default:
|
||||||
|
out := x.MustShallowClone()
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRandomHFlip(pvalue float64) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
hf := newRandomHorizontalFlip(pvalue)
|
||||||
|
o.randomHFlip = hf
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RandomVerticalFlip vertically flips the given image randomly with a given probability.
|
||||||
|
//
|
||||||
|
// If the image is torch Tensor, it is expected to have [..., H, W] shape,
|
||||||
|
// where ... means an arbitrary number of leading dimensions
|
||||||
|
// Args:
|
||||||
|
// p (float): probability of the image being flipped. Default value is 0.5
|
||||||
|
type RandomVerticalFlip struct {
|
||||||
|
pvalue float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRandomVerticalFlip(pvalue float64) *RandomVerticalFlip {
|
||||||
|
return &RandomVerticalFlip{
|
||||||
|
pvalue: pvalue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (vf *RandomVerticalFlip) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
randTs := ts.MustRandn([]int64{1}, gotch.Float, gotch.CPU)
|
||||||
|
randVal := randTs.Float64Values()[0]
|
||||||
|
randTs.MustDrop()
|
||||||
|
switch {
|
||||||
|
case randVal < vf.pvalue:
|
||||||
|
return vflip(x)
|
||||||
|
default:
|
||||||
|
out := x.MustShallowClone()
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRandomVFlip(pvalue float64) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
vf := newRandomVerticalFlip(pvalue)
|
||||||
|
o.randomVFlip = vf
|
||||||
|
}
|
||||||
|
}
|
1514
vision/aug/function.go
Normal file
1514
vision/aug/function.go
Normal file
File diff suppressed because it is too large
Load Diff
81
vision/aug/grayscale.go
Normal file
81
vision/aug/grayscale.go
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
package aug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
// "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GrayScale converts image to grayscale.
|
||||||
|
// If the image is torch Tensor, it is expected
|
||||||
|
// to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions
|
||||||
|
// Args:
|
||||||
|
// - num_output_channels (int): (1 or 3) number of channels desired for output image
|
||||||
|
type Grayscale struct {
|
||||||
|
outChan int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (gs *Grayscale) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
out := rgb2Gray(x, gs.outChan)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func newGrayscale(outChanOpt ...int64) *Grayscale {
|
||||||
|
var outChan int64 = 3
|
||||||
|
if len(outChanOpt) > 0 {
|
||||||
|
c := outChanOpt[0]
|
||||||
|
switch c {
|
||||||
|
case 1:
|
||||||
|
outChan = 1
|
||||||
|
case 3:
|
||||||
|
outChan = 3
|
||||||
|
default:
|
||||||
|
log.Fatalf("Out channels should be either 1 or 3. Got %v\n", c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &Grayscale{outChan}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RandomGrayscale randomly converts image to grayscale with a probability of p (default 0.1).
|
||||||
|
// If the image is torch Tensor, it is expected
|
||||||
|
// to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions
|
||||||
|
// Args:
|
||||||
|
// - p (float): probability that image should be converted to grayscale.
|
||||||
|
type RandomGrayscale struct {
|
||||||
|
pvalue float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRandomGrayscale(pvalueOpt ...float64) *RandomGrayscale {
|
||||||
|
pvalue := 0.1
|
||||||
|
if len(pvalueOpt) > 0 {
|
||||||
|
pvalue = pvalueOpt[0]
|
||||||
|
}
|
||||||
|
return &RandomGrayscale{pvalue}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rgs *RandomGrayscale) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
c := getImageChanNum(x)
|
||||||
|
r := randPvalue()
|
||||||
|
var out *ts.Tensor
|
||||||
|
switch {
|
||||||
|
case r < rgs.pvalue:
|
||||||
|
out = rgb2Gray(x, c)
|
||||||
|
default:
|
||||||
|
out = x.MustShallowClone()
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRandomGrayscale(pvalueOpt ...float64) Option {
|
||||||
|
var p float64 = 0.1
|
||||||
|
if len(pvalueOpt) > 0 {
|
||||||
|
p = pvalueOpt[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
rgs := newRandomGrayscale(p)
|
||||||
|
return func(o *Options) {
|
||||||
|
o.randomGrayscale = rgs
|
||||||
|
}
|
||||||
|
}
|
39
vision/aug/invert.go
Normal file
39
vision/aug/invert.go
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
package aug
|
||||||
|
|
||||||
|
import (
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RandomInvert struct {
|
||||||
|
pvalue float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRandomInvert(pOpt ...float64) *RandomInvert {
|
||||||
|
p := 0.5
|
||||||
|
if len(pOpt) > 0 {
|
||||||
|
p = pOpt[0]
|
||||||
|
}
|
||||||
|
return &RandomInvert{p}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ri *RandomInvert) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
r := randPvalue()
|
||||||
|
|
||||||
|
var out *ts.Tensor
|
||||||
|
switch {
|
||||||
|
case r < ri.pvalue:
|
||||||
|
out = invert(x)
|
||||||
|
default:
|
||||||
|
out = x.MustShallowClone()
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRandomInvert(pvalueOpt ...float64) Option {
|
||||||
|
ri := newRandomInvert(pvalueOpt...)
|
||||||
|
|
||||||
|
return func(o *Options) {
|
||||||
|
o.randomInvert = ri
|
||||||
|
}
|
||||||
|
}
|
91
vision/aug/normalize.go
Normal file
91
vision/aug/normalize.go
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
package aug
|
||||||
|
|
||||||
|
import (
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Normalize normalizes a tensor image with mean and standard deviation.
|
||||||
|
// Given mean: ``(mean[1],...,mean[n])`` and std: ``(std[1],..,std[n])`` for ``n``
|
||||||
|
// channels, this transform will normalize each channel of the input
|
||||||
|
// ``torch.*Tensor`` i.e.,
|
||||||
|
// ``output[channel] = (input[channel] - mean[channel]) / std[channel]``
|
||||||
|
// .. note::
|
||||||
|
// This transform acts out of place, i.e., it does not mutate the input tensor.
|
||||||
|
// Args:
|
||||||
|
// - mean (sequence): Sequence of means for each channel.
|
||||||
|
// - std (sequence): Sequence of standard deviations for each channel.
|
||||||
|
type Normalize struct {
|
||||||
|
mean []float64 // should be from 0 to 1
|
||||||
|
std []float64 // should be > 0 and <= 1
|
||||||
|
}
|
||||||
|
|
||||||
|
type normalizeOptions struct {
|
||||||
|
mean []float64
|
||||||
|
std []float64
|
||||||
|
}
|
||||||
|
|
||||||
|
type normalizeOption func(*normalizeOptions)
|
||||||
|
|
||||||
|
// Mean and SD can be calculated for specific dataset as follow:
|
||||||
|
/*
|
||||||
|
mean = 0.0
|
||||||
|
meansq = 0.0
|
||||||
|
count = 0
|
||||||
|
|
||||||
|
for index, data in enumerate(train_loader):
|
||||||
|
mean = data.sum()
|
||||||
|
meansq = meansq + (data**2).sum()
|
||||||
|
count += np.prod(data.shape)
|
||||||
|
|
||||||
|
total_mean = mean/count
|
||||||
|
total_var = (meansq/count) - (total_mean**2)
|
||||||
|
total_std = torch.sqrt(total_var)
|
||||||
|
print("mean: " + str(total_mean))
|
||||||
|
print("std: " + str(total_std))
|
||||||
|
*/
|
||||||
|
|
||||||
|
// For example. ImageNet dataset has RGB mean and standard error:
|
||||||
|
// meanVals := []float64{0.485, 0.456, 0.406}
|
||||||
|
// sdVals := []float64{0.229, 0.224, 0.225}
|
||||||
|
func defaultNormalizeOptions() *normalizeOptions {
|
||||||
|
return &normalizeOptions{
|
||||||
|
mean: []float64{0, 0, 0},
|
||||||
|
std: []float64{1, 1, 1},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithNormalizeStd(std []float64) normalizeOption {
|
||||||
|
return func(o *normalizeOptions) {
|
||||||
|
o.std = std
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithNormalizeMean(mean []float64) normalizeOption {
|
||||||
|
return func(o *normalizeOptions) {
|
||||||
|
o.mean = mean
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNormalize(opts ...normalizeOption) *Normalize {
|
||||||
|
p := defaultNormalizeOptions()
|
||||||
|
for _, o := range opts {
|
||||||
|
o(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Normalize{
|
||||||
|
mean: p.mean,
|
||||||
|
std: p.std,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *Normalize) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
out := normalize(x, n.mean, n.std)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithNormalize(opts ...normalizeOption) Option {
|
||||||
|
n := newNormalize(opts...)
|
||||||
|
return func(o *Options) {
|
||||||
|
o.normalize = n
|
||||||
|
}
|
||||||
|
}
|
1
vision/aug/pad.go
Normal file
1
vision/aug/pad.go
Normal file
|
@ -0,0 +1 @@
|
||||||
|
package aug
|
190
vision/aug/perspective.go
Normal file
190
vision/aug/perspective.go
Normal file
|
@ -0,0 +1,190 @@
|
||||||
|
package aug
|
||||||
|
|
||||||
|
import (
|
||||||
|
// "fmt"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RandomPerspective performs a random perspective transformation of the given image with a given probability.
|
||||||
|
// If the image is torch Tensor, it is expected
|
||||||
|
// to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
|
||||||
|
// Args:
|
||||||
|
// distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1.
|
||||||
|
// Default is 0.5.
|
||||||
|
// p (float): probability of the image being transformed. Default is 0.5.
|
||||||
|
// interpolation (InterpolationMode): Desired interpolation enum defined by
|
||||||
|
// :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
|
||||||
|
// If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
|
||||||
|
// For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
|
||||||
|
// fill (sequence or number): Pixel fill value for the area outside the transformed
|
||||||
|
// image. Default is ``0``. If given a number, the value is used for all bands respectively.
|
||||||
|
type RandomPerspective struct {
|
||||||
|
distortionScale float64 // range [0, 1]
|
||||||
|
pvalue float64 // range [0, 1]
|
||||||
|
interpolationMode string
|
||||||
|
fillValue []float64
|
||||||
|
}
|
||||||
|
|
||||||
|
type perspectiveOptions struct {
|
||||||
|
distortionScale float64 // range [0, 1]
|
||||||
|
pvalue float64 // range [0, 1]
|
||||||
|
interpolationMode string
|
||||||
|
fillValue []float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultPerspectiveOptions() *perspectiveOptions {
|
||||||
|
return &perspectiveOptions{
|
||||||
|
distortionScale: 0.5,
|
||||||
|
pvalue: 0.5,
|
||||||
|
interpolationMode: "bilinear",
|
||||||
|
fillValue: []float64{0.0, 0.0, 0.0},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type perspectiveOption func(*perspectiveOptions)
|
||||||
|
|
||||||
|
func WithPerspectivePvalue(p float64) perspectiveOption {
|
||||||
|
return func(o *perspectiveOptions) {
|
||||||
|
o.pvalue = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithPerspectiveScale(s float64) perspectiveOption {
|
||||||
|
return func(o *perspectiveOptions) {
|
||||||
|
o.distortionScale = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithPerspectiveMode(m string) perspectiveOption {
|
||||||
|
return func(o *perspectiveOptions) {
|
||||||
|
o.interpolationMode = m
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithPerspectiveValue(v []float64) perspectiveOption {
|
||||||
|
return func(o *perspectiveOptions) {
|
||||||
|
o.fillValue = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRandomPerspective(opts ...perspectiveOption) *RandomPerspective {
|
||||||
|
params := defaultPerspectiveOptions()
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(params)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RandomPerspective{
|
||||||
|
distortionScale: params.distortionScale,
|
||||||
|
pvalue: params.pvalue,
|
||||||
|
interpolationMode: params.interpolationMode,
|
||||||
|
fillValue: params.fillValue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get parameters for ``perspective`` for a random perspective transform.
|
||||||
|
//
|
||||||
|
// Args:
|
||||||
|
// - width (int): width of the image.
|
||||||
|
// - height (int): height of the image.
|
||||||
|
// Returns:
|
||||||
|
// - List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
|
||||||
|
// - List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image.
|
||||||
|
func (rp *RandomPerspective) getParams(w, h int64) ([][]int64, [][]int64) {
|
||||||
|
halfH := h / 2
|
||||||
|
halfW := w / 2
|
||||||
|
|
||||||
|
var (
|
||||||
|
topLeft []int64
|
||||||
|
topRight []int64
|
||||||
|
bottomRight []int64
|
||||||
|
bottomLeft []int64
|
||||||
|
)
|
||||||
|
|
||||||
|
// topleft = [
|
||||||
|
// int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
|
||||||
|
// int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
|
||||||
|
// ]
|
||||||
|
tlVal1 := int64(rp.distortionScale*float64(halfW)) + 1
|
||||||
|
tlTs1 := ts.MustRandint1(0, tlVal1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||||
|
tl1 := tlTs1.Int64Values()[0]
|
||||||
|
tlTs1.MustDrop()
|
||||||
|
tlVal2 := int64(rp.distortionScale*float64(halfH)) + 1
|
||||||
|
tlTs2 := ts.MustRandint1(0, tlVal2, []int64{1}, gotch.Int64, gotch.CPU)
|
||||||
|
tl2 := tlTs2.Int64Values()[0]
|
||||||
|
tlTs2.MustDrop()
|
||||||
|
topLeft = []int64{tl1, tl2}
|
||||||
|
|
||||||
|
// topright = [
|
||||||
|
// int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
|
||||||
|
// int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1, )).item())
|
||||||
|
// ]
|
||||||
|
trVal1 := w - int64(rp.distortionScale*float64(halfW)) - 1
|
||||||
|
trTs1 := ts.MustRandint1(trVal1, w, []int64{1}, gotch.Int64, gotch.CPU)
|
||||||
|
tr1 := trTs1.Int64Values()[0]
|
||||||
|
trTs1.MustDrop()
|
||||||
|
trVal2 := int64(rp.distortionScale*float64(halfH)) + 1
|
||||||
|
trTs2 := ts.MustRandint1(0, trVal2, []int64{1}, gotch.Int64, gotch.CPU)
|
||||||
|
tr2 := trTs2.Int64Values()[0]
|
||||||
|
trTs2.MustDrop()
|
||||||
|
topRight = []int64{tr1, tr2}
|
||||||
|
|
||||||
|
// botright = [
|
||||||
|
// int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1, )).item()),
|
||||||
|
// int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
|
||||||
|
// ]
|
||||||
|
brVal1 := w - int64(rp.distortionScale*float64(halfW)) - 1
|
||||||
|
brTs1 := ts.MustRandint1(brVal1, w, []int64{1}, gotch.Int64, gotch.CPU)
|
||||||
|
br1 := brTs1.Int64Values()[0]
|
||||||
|
brTs1.MustDrop()
|
||||||
|
brVal2 := h - int64(rp.distortionScale*float64(halfH)) - 1
|
||||||
|
brTs2 := ts.MustRandint1(brVal2, h, []int64{1}, gotch.Int64, gotch.CPU)
|
||||||
|
br2 := brTs2.Int64Values()[0]
|
||||||
|
brTs2.MustDrop()
|
||||||
|
bottomRight = []int64{br1, br2}
|
||||||
|
|
||||||
|
// botleft = [
|
||||||
|
// int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1, )).item()),
|
||||||
|
// int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1, )).item())
|
||||||
|
// ]
|
||||||
|
blVal1 := int64(rp.distortionScale*float64(halfW)) + 1
|
||||||
|
blTs1 := ts.MustRandint1(0, blVal1, []int64{1}, gotch.Int64, gotch.CPU)
|
||||||
|
bl1 := blTs1.Int64Values()[0]
|
||||||
|
blTs1.MustDrop()
|
||||||
|
blVal2 := h - int64(rp.distortionScale*float64(halfH)) - 1
|
||||||
|
blTs2 := ts.MustRandint1(blVal2, h, []int64{1}, gotch.Int64, gotch.CPU)
|
||||||
|
bl2 := blTs2.Int64Values()[0]
|
||||||
|
blTs2.MustDrop()
|
||||||
|
bottomLeft = []int64{bl1, bl2}
|
||||||
|
|
||||||
|
startPoints := [][]int64{
|
||||||
|
{0, 0},
|
||||||
|
{w - 1, 0},
|
||||||
|
{w - 1, h - 1},
|
||||||
|
{0, h - 1},
|
||||||
|
}
|
||||||
|
|
||||||
|
endPoints := [][]int64{
|
||||||
|
topLeft,
|
||||||
|
topRight,
|
||||||
|
bottomRight,
|
||||||
|
bottomLeft,
|
||||||
|
}
|
||||||
|
|
||||||
|
return startPoints, endPoints
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rp *RandomPerspective) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
height, width := getImageSize(x)
|
||||||
|
startPoints, endPoints := rp.getParams(height, width)
|
||||||
|
out := perspective(x, startPoints, endPoints, rp.interpolationMode, rp.fillValue)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRandomPerspective(opts ...perspectiveOption) Option {
|
||||||
|
rp := newRandomPerspective(opts...)
|
||||||
|
return func(o *Options) {
|
||||||
|
o.randomPerspective = rp
|
||||||
|
}
|
||||||
|
}
|
77
vision/aug/posterize.go
Normal file
77
vision/aug/posterize.go
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
package aug
|
||||||
|
|
||||||
|
import (
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RandomPosterize posterizes the image randomly with a given probability by reducing the
|
||||||
|
// number of bits for each color channel. If the image is torch Tensor, it should be of type torch.uint8,
|
||||||
|
// and it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
|
||||||
|
// Args:
|
||||||
|
// - bits (int): number of bits to keep for each channel (0-8)
|
||||||
|
// - p (float): probability of the image being color inverted. Default value is 0.5
|
||||||
|
// Ref. https://en.wikipedia.org/wiki/Posterization
|
||||||
|
type RandomPosterize struct {
|
||||||
|
pvalue float64
|
||||||
|
bits uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
type posterizeOptions struct {
|
||||||
|
pvalue float64
|
||||||
|
bits uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
type posterizeOption func(*posterizeOptions)
|
||||||
|
|
||||||
|
func defaultPosterizeOptions() *posterizeOptions {
|
||||||
|
return &posterizeOptions{
|
||||||
|
pvalue: 0.5,
|
||||||
|
bits: 4,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithPosterizePvalue(p float64) posterizeOption {
|
||||||
|
return func(o *posterizeOptions) {
|
||||||
|
o.pvalue = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithPosterizeBits(bits uint8) posterizeOption {
|
||||||
|
return func(o *posterizeOptions) {
|
||||||
|
o.bits = bits
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRandomPosterize(opts ...posterizeOption) *RandomPosterize {
|
||||||
|
p := defaultPosterizeOptions()
|
||||||
|
for _, o := range opts {
|
||||||
|
o(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RandomPosterize{
|
||||||
|
pvalue: p.pvalue,
|
||||||
|
bits: p.bits,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rp *RandomPosterize) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
|
||||||
|
r := randPvalue()
|
||||||
|
var out *ts.Tensor
|
||||||
|
switch {
|
||||||
|
case r < rp.pvalue:
|
||||||
|
out = posterize(x, rp.bits)
|
||||||
|
default:
|
||||||
|
out = x.MustShallowClone()
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRandomPosterize(opts ...posterizeOption) Option {
|
||||||
|
rp := newRandomPosterize(opts...)
|
||||||
|
|
||||||
|
return func(o *Options) {
|
||||||
|
o.randomPosterize = rp
|
||||||
|
}
|
||||||
|
}
|
39
vision/aug/resize.go
Normal file
39
vision/aug/resize.go
Normal file
|
@ -0,0 +1,39 @@
|
||||||
|
package aug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
"github.com/sugarme/gotch/vision"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ResizeModule struct {
|
||||||
|
height int64
|
||||||
|
width int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newResizeModule(h, w int64) *ResizeModule {
|
||||||
|
return &ResizeModule{h, w}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward implements ts.Module for RandRotateModule
|
||||||
|
func (rs *ResizeModule) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
imgTs := x.MustTotype(gotch.Uint8, false)
|
||||||
|
out, err := vision.Resize(imgTs, rs.width, rs.height)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
imgTs.MustDrop()
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithResize(h, w int64) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
rs := newResizeModule(h, w)
|
||||||
|
o.resize = rs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO.
|
||||||
|
type RandomResizedCrop struct{}
|
109
vision/aug/rotate.go
Normal file
109
vision/aug/rotate.go
Normal file
|
@ -0,0 +1,109 @@
|
||||||
|
package aug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"math"
|
||||||
|
"math/rand"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RandomRotate randomly rotates a tensor image within a specifed angle range (degree).
|
||||||
|
func RandomRotate(img *ts.Tensor, min, max float64) (*ts.Tensor, error) {
|
||||||
|
if min > max {
|
||||||
|
tmp := min
|
||||||
|
min = max
|
||||||
|
max = tmp
|
||||||
|
}
|
||||||
|
if min < -360 || min > 360 || max < -360 || max > 360 {
|
||||||
|
err := fmt.Errorf("min and max should be in range from -360 to 360. Got %v and %v\n", min, max)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// device := img.MustDevice()
|
||||||
|
dtype := gotch.Double
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
angle := min + rand.Float64()*(max-min)
|
||||||
|
|
||||||
|
theta := float64(angle) * (math.Pi / 180)
|
||||||
|
input := img.MustUnsqueeze(0, false).MustTotype(dtype, true)
|
||||||
|
r, err := rotImg(input, theta, dtype)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
input.MustDrop()
|
||||||
|
rotatedImg := r.MustSqueeze(true)
|
||||||
|
return rotatedImg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Rotate(img *ts.Tensor, angle float64) (*ts.Tensor, error) {
|
||||||
|
if angle < -360 || angle > 360 {
|
||||||
|
err := fmt.Errorf("angle must be in range (-360, 360)")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
dtype := gotch.Double
|
||||||
|
theta := float64(angle) * (math.Pi / 180)
|
||||||
|
input := img.MustUnsqueeze(0, false).MustTotype(dtype, true)
|
||||||
|
r, err := rotImg(input, theta, dtype)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
input.MustDrop()
|
||||||
|
rotatedImg := r.MustSqueeze(true)
|
||||||
|
return rotatedImg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RotateModule
|
||||||
|
type RotateModule struct {
|
||||||
|
angle float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRotate(angle float64) *RotateModule {
|
||||||
|
return &RotateModule{angle}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward implements ts.Module for RotateModule
|
||||||
|
func (r *RotateModule) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
out, err := Rotate(x, r.angle)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRotate(angle float64) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
r := newRotate(angle)
|
||||||
|
o.rotate = r
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RandomRotateModule
|
||||||
|
type RandRotateModule struct {
|
||||||
|
minAngle float64
|
||||||
|
maxAngle float64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRandRotate(min, max float64) *RandRotateModule {
|
||||||
|
return &RandRotateModule{min, max}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward implements ts.Module for RandRotateModule
|
||||||
|
func (rr *RandRotateModule) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
out, err := RandomRotate(x, rr.minAngle, rr.maxAngle)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRandRotate(minAngle, maxAngle float64) Option {
|
||||||
|
return func(o *Options) {
|
||||||
|
r := newRandRotate(minAngle, maxAngle)
|
||||||
|
o.randRotate = r
|
||||||
|
}
|
||||||
|
}
|
74
vision/aug/sharpness.go
Normal file
74
vision/aug/sharpness.go
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
package aug
|
||||||
|
|
||||||
|
import (
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Adjust the sharpness of the image randomly with a given probability. If the image is torch Tensor,
|
||||||
|
// it is expected to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
|
||||||
|
// Args:
|
||||||
|
// sharpness_factor (float): How much to adjust the sharpness. Can be
|
||||||
|
// any non negative number. 0 gives a blurred image, 1 gives the
|
||||||
|
// original image while 2 increases the sharpness by a factor of 2.
|
||||||
|
// p (float): probability of the image being color inverted. Default value is 0.5
|
||||||
|
type RandomAdjustSharpness struct {
|
||||||
|
sharpnessFactor float64
|
||||||
|
pvalue float64
|
||||||
|
}
|
||||||
|
|
||||||
|
type sharpnessOptions struct {
|
||||||
|
sharpnessFactor float64
|
||||||
|
pvalue float64
|
||||||
|
}
|
||||||
|
|
||||||
|
type sharpnessOption func(*sharpnessOptions)
|
||||||
|
|
||||||
|
func defaultSharpnessOptions() *sharpnessOptions {
|
||||||
|
return &sharpnessOptions{
|
||||||
|
sharpnessFactor: 1.0,
|
||||||
|
pvalue: 0.5,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithSharpnessPvalue(p float64) sharpnessOption {
|
||||||
|
return func(o *sharpnessOptions) {
|
||||||
|
o.pvalue = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithSharpnessFactor(f float64) sharpnessOption {
|
||||||
|
return func(o *sharpnessOptions) {
|
||||||
|
o.sharpnessFactor = f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRandomAdjustSharpness(opts ...sharpnessOption) *RandomAdjustSharpness {
|
||||||
|
p := defaultSharpnessOptions()
|
||||||
|
for _, o := range opts {
|
||||||
|
o(p)
|
||||||
|
}
|
||||||
|
return &RandomAdjustSharpness{
|
||||||
|
sharpnessFactor: p.sharpnessFactor,
|
||||||
|
pvalue: p.pvalue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ras *RandomAdjustSharpness) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
r := randPvalue()
|
||||||
|
var out *ts.Tensor
|
||||||
|
switch {
|
||||||
|
case r < ras.pvalue:
|
||||||
|
out = adjustSharpness(x, ras.sharpnessFactor)
|
||||||
|
default:
|
||||||
|
out = x.MustShallowClone()
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRandomAdjustSharpness(opts ...sharpnessOption) Option {
|
||||||
|
ras := newRandomAdjustSharpness(opts...)
|
||||||
|
return func(o *Options) {
|
||||||
|
o.randomAdjustSharpness = ras
|
||||||
|
}
|
||||||
|
}
|
79
vision/aug/solarize.go
Normal file
79
vision/aug/solarize.go
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
package aug
|
||||||
|
|
||||||
|
import (
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RandomSolarize solarizes the image randomly with a given probability by inverting all pixel
|
||||||
|
// values above a threshold. If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
|
||||||
|
// where ... means it can have an arbitrary number of leading dimensions.
|
||||||
|
// If img is PIL Image, it is expected to be in mode "L" or "RGB".
|
||||||
|
// Args:
|
||||||
|
// - threshold (float): all pixels equal or above this value are inverted.
|
||||||
|
// - p (float): probability of the image being color inverted. Default value is 0.5
|
||||||
|
// Ref. https://en.wikipedia.org/wiki/Solarization_(photography)
|
||||||
|
type RandomSolarize struct {
|
||||||
|
threshold float64
|
||||||
|
pvalue float64
|
||||||
|
}
|
||||||
|
|
||||||
|
type solarizeOptions struct {
|
||||||
|
threshold float64
|
||||||
|
pvalue float64
|
||||||
|
}
|
||||||
|
|
||||||
|
type solarizeOption func(*solarizeOptions)
|
||||||
|
|
||||||
|
func defaultSolarizeOptions() *solarizeOptions {
|
||||||
|
return &solarizeOptions{
|
||||||
|
threshold: 128,
|
||||||
|
pvalue: 0.5,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithSolarizePvalue(p float64) solarizeOption {
|
||||||
|
return func(o *solarizeOptions) {
|
||||||
|
o.pvalue = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithSolarizeThreshold(th float64) solarizeOption {
|
||||||
|
return func(o *solarizeOptions) {
|
||||||
|
o.threshold = th
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRandomSolarize(opts ...solarizeOption) *RandomSolarize {
|
||||||
|
params := defaultSolarizeOptions()
|
||||||
|
|
||||||
|
for _, o := range opts {
|
||||||
|
o(params)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RandomSolarize{
|
||||||
|
threshold: params.threshold,
|
||||||
|
pvalue: params.pvalue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rs *RandomSolarize) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
|
r := randPvalue()
|
||||||
|
|
||||||
|
var out *ts.Tensor
|
||||||
|
switch {
|
||||||
|
case r < rs.pvalue:
|
||||||
|
out = solarize(x, rs.threshold)
|
||||||
|
default:
|
||||||
|
out = x.MustShallowClone()
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRandomSolarize(opts ...solarizeOption) Option {
|
||||||
|
rs := newRandomSolarize(opts...)
|
||||||
|
|
||||||
|
return func(o *Options) {
|
||||||
|
o.randomSolarize = rs
|
||||||
|
}
|
||||||
|
}
|
188
vision/aug/transform.go
Normal file
188
vision/aug/transform.go
Normal file
|
@ -0,0 +1,188 @@
|
||||||
|
package aug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch/nn"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Transformer is an interface that can transform an image tensor.
|
||||||
|
type Transformer interface {
|
||||||
|
Transform(x *ts.Tensor) *ts.Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
// Augment is a struct composes of augmentation functions to implement Transformer interface.
|
||||||
|
type Augment struct {
|
||||||
|
augments *nn.Sequential
|
||||||
|
}
|
||||||
|
|
||||||
|
// Transform implements Transformer interface for Augment struct.
|
||||||
|
func (a *Augment) Transform(image *ts.Tensor) *ts.Tensor {
|
||||||
|
out := a.augments.Forward(image)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
rotate *RotateModule
|
||||||
|
randRotate *RandRotateModule
|
||||||
|
resize *ResizeModule
|
||||||
|
colorJitter *ColorJitter
|
||||||
|
gaussianBlur *GaussianBlur
|
||||||
|
randomHFlip *RandomHorizontalFlip
|
||||||
|
randomVFlip *RandomVerticalFlip
|
||||||
|
randomCrop *RandomCrop
|
||||||
|
centerCrop *CenterCrop
|
||||||
|
randomCutout *RandomCutout
|
||||||
|
randomPerspective *RandomPerspective
|
||||||
|
randomAffine *RandomAffine
|
||||||
|
randomGrayscale *RandomGrayscale
|
||||||
|
randomSolarize *RandomSolarize
|
||||||
|
randomPosterize *RandomPosterize
|
||||||
|
randomInvert *RandomInvert
|
||||||
|
randomAutocontrast *RandomAutocontrast
|
||||||
|
randomAdjustSharpness *RandomAdjustSharpness
|
||||||
|
randomEqualize *RandomEqualize
|
||||||
|
normalize *Normalize
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultOption() *Options {
|
||||||
|
return &Options{
|
||||||
|
rotate: nil,
|
||||||
|
randRotate: nil,
|
||||||
|
resize: nil,
|
||||||
|
colorJitter: nil,
|
||||||
|
gaussianBlur: nil,
|
||||||
|
randomHFlip: nil,
|
||||||
|
randomVFlip: nil,
|
||||||
|
randomCrop: nil,
|
||||||
|
centerCrop: nil,
|
||||||
|
randomCutout: nil,
|
||||||
|
randomPerspective: nil,
|
||||||
|
randomAffine: nil,
|
||||||
|
randomGrayscale: nil,
|
||||||
|
randomSolarize: nil,
|
||||||
|
randomPosterize: nil,
|
||||||
|
randomInvert: nil,
|
||||||
|
randomAutocontrast: nil,
|
||||||
|
randomAdjustSharpness: nil,
|
||||||
|
randomEqualize: nil,
|
||||||
|
normalize: nil,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Option func(o *Options)
|
||||||
|
|
||||||
|
// Compose creates a new Augment struct by adding augmentation methods.
|
||||||
|
func Compose(opts ...Option) (Transformer, error) {
|
||||||
|
augOpts := defaultOption()
|
||||||
|
for _, opt := range opts {
|
||||||
|
if opt != nil {
|
||||||
|
opt(augOpts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var augs *nn.Sequential = nn.Seq()
|
||||||
|
|
||||||
|
if augOpts.rotate != nil {
|
||||||
|
augs.Add(augOpts.rotate)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.randRotate != nil {
|
||||||
|
augs.Add(augOpts.randRotate)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.resize != nil {
|
||||||
|
augs.Add(augOpts.resize)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.colorJitter != nil {
|
||||||
|
augs.Add(augOpts.colorJitter)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.gaussianBlur != nil {
|
||||||
|
augs.Add(augOpts.gaussianBlur)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.randomHFlip != nil {
|
||||||
|
augs.Add(augOpts.randomHFlip)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.randomVFlip != nil {
|
||||||
|
augs.Add(augOpts.randomVFlip)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.randomCrop != nil {
|
||||||
|
augs.Add(augOpts.randomCrop)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.centerCrop != nil {
|
||||||
|
augs.Add(augOpts.centerCrop)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.randomCutout != nil {
|
||||||
|
augs.Add(augOpts.randomCutout)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.randomPerspective != nil {
|
||||||
|
augs.Add(augOpts.randomPerspective)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.randomAffine != nil {
|
||||||
|
augs.Add(augOpts.randomAffine)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.randomGrayscale != nil {
|
||||||
|
augs.Add(augOpts.randomGrayscale)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.randomSolarize != nil {
|
||||||
|
augs.Add(augOpts.randomSolarize)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.randomPosterize != nil {
|
||||||
|
augs.Add(augOpts.randomPosterize)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.randomInvert != nil {
|
||||||
|
augs.Add(augOpts.randomInvert)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.randomAutocontrast != nil {
|
||||||
|
augs.Add(augOpts.randomAutocontrast)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.randomAdjustSharpness != nil {
|
||||||
|
augs.Add(augOpts.randomAdjustSharpness)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.randomEqualize != nil {
|
||||||
|
augs.Add(augOpts.randomEqualize)
|
||||||
|
}
|
||||||
|
|
||||||
|
if augOpts.normalize != nil {
|
||||||
|
augs.Add(augOpts.normalize)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Augment{augs}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OneOf randomly return one transformer from list of transformers
|
||||||
|
// with a specific p value.
|
||||||
|
func OneOf(pvalue float64, tfOpts ...Option) Option {
|
||||||
|
tfsNum := len(tfOpts)
|
||||||
|
if tfsNum < 1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
randP := randPvalue()
|
||||||
|
if randP >= pvalue {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rand.Seed(time.Now().UnixNano())
|
||||||
|
idx := rand.Intn(tfsNum)
|
||||||
|
|
||||||
|
return tfOpts[idx]
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user