223 lines
4.8 KiB
Go
223 lines
4.8 KiB
Go
package aug
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
|
|
"git.andr3h3nriqu3s.com/andr3/gotch"
|
|
"git.andr3h3nriqu3s.com/andr3/gotch/ts"
|
|
"git.andr3h3nriqu3s.com/andr3/gotch/vision"
|
|
)
|
|
|
|
// ResizeModule holds the target size that its Forward method passes to
// vision.Resize.
type ResizeModule struct {
	height, width int64 // target output height and width
}
|
|
|
|
func newResizeModule(h, w int64) *ResizeModule {
|
|
return &ResizeModule{h, w}
|
|
}
|
|
|
|
// Forward implements ts.Module for RandRotateModule
|
|
// NOTE. input tensor must be uint8 (Byte) dtype otherwise panic!
|
|
func (rs *ResizeModule) Forward(x *ts.Tensor) *ts.Tensor {
|
|
dtype := x.DType()
|
|
if dtype != gotch.Uint8 {
|
|
err := fmt.Errorf("Invalid dtype. Expect uint8 (Byte) dtype. Got %v\n", dtype)
|
|
panic(err)
|
|
}
|
|
|
|
device := x.MustDevice()
|
|
var xCPU *ts.Tensor
|
|
if device != gotch.CPU {
|
|
xCPU = x.MustTo(device, false)
|
|
} else {
|
|
xCPU = x.MustShallowClone()
|
|
}
|
|
|
|
out, err := vision.Resize(xCPU, rs.width, rs.height)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
xCPU.MustDrop()
|
|
|
|
return out.MustTo(device, true)
|
|
}
|
|
|
|
func WithResize(h, w int64) Option {
|
|
return func(o *Options) {
|
|
rs := newResizeModule(h, w)
|
|
o.resize = rs
|
|
}
|
|
}
|
|
|
|
// TODO.
|
|
// RandomResizedCrop is declared as an empty struct; it carries no
// configuration yet (see the TODO marker on this declaration).
type RandomResizedCrop struct{}
|
|
|
|
// DownSample is a field-less module type; its Forward method resizes an image
// tensor to half of its width and height.
type DownSample struct{}
|
|
|
|
func newDownSample(p float64) *DownSample {
|
|
return &DownSample{}
|
|
}
|
|
|
|
// Forward implements ts.Module for RandRotateModule
|
|
// NOTE. input tensor must be uint8 (Byte) dtype otherwise panic!
|
|
func (rs *DownSample) Forward(x *ts.Tensor) *ts.Tensor {
|
|
dtype := x.DType()
|
|
if dtype != gotch.Uint8 {
|
|
err := fmt.Errorf("Invalid dtype. Expect uint8 (Byte) dtype. Got %v\n", dtype)
|
|
panic(err)
|
|
}
|
|
|
|
device := x.MustDevice()
|
|
h := x.MustSize()[1]
|
|
w := x.MustSize()[2]
|
|
var xCPU *ts.Tensor
|
|
if device != gotch.CPU {
|
|
xCPU = x.MustTo(device, false)
|
|
} else {
|
|
xCPU = x.MustShallowClone()
|
|
}
|
|
|
|
out, err := vision.Resize(xCPU, w/2, h/2)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
xCPU.MustDrop()
|
|
return out.MustTo(device, true)
|
|
}
|
|
|
|
// ZoomIn is a module type configured by a single factor v.
type ZoomIn struct {
	// v should be within [0, 0.5].
	v float64
}
|
|
|
|
func newZoomIn(v float64) *ZoomIn {
|
|
return &ZoomIn{v: v}
|
|
}
|
|
|
|
func WithZoomIn(v float64) Option {
|
|
if v < 0 || v > 0.5 {
|
|
err := fmt.Errorf("Invalid input value. Expect value in range [0, 0.5]. Got %v\n", v)
|
|
panic(err)
|
|
}
|
|
return func(o *Options) {
|
|
ds := newZoomIn(v)
|
|
o.zoomIn = ds
|
|
}
|
|
}
|
|
|
|
// Forward implements ts.Module for RandRotateModule
|
|
// NOTE. input tensor must be uint8 (Byte) dtype otherwise panic!
|
|
func (rs *ZoomIn) Forward(x *ts.Tensor) *ts.Tensor {
|
|
dtype := x.DType()
|
|
if dtype != gotch.Uint8 {
|
|
err := fmt.Errorf("Invalid dtype. Expect uint8 (Byte) dtype. Got %v\n", dtype)
|
|
panic(err)
|
|
}
|
|
|
|
device := x.MustDevice()
|
|
h := x.MustSize()[1]
|
|
w := x.MustSize()[2]
|
|
var xCPU *ts.Tensor
|
|
if device != gotch.CPU {
|
|
xCPU = x.MustTo(device, false)
|
|
} else {
|
|
xCPU = x.MustShallowClone()
|
|
}
|
|
|
|
var out *ts.Tensor
|
|
var err error
|
|
r := randPvalue()
|
|
switch {
|
|
case r < rs.v:
|
|
cropW := int64(rs.v) * w
|
|
cropH := int64(rs.v) * h
|
|
newW := w - cropW
|
|
newH := h - cropH
|
|
// img = PIL.ImageOps.fit(img, size=(new_w,new_h), bleed=v/2, method=Image.BILINEAR)
|
|
fitImg := fitImg(xCPU, newW, newH)
|
|
xCPU.MustDrop()
|
|
// return img.resize((w,h), resample=Image.BILINEAR)
|
|
out, err = vision.Resize(fitImg, w, h)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
fitImg.MustDrop()
|
|
default:
|
|
out = x.MustShallowClone()
|
|
}
|
|
|
|
return out.MustTo(device, true)
|
|
}
|
|
|
|
// TODO.
|
|
func fitImg(x *ts.Tensor, w, h int64) *ts.Tensor {
|
|
|
|
panic("Not implemented")
|
|
}
|
|
|
|
// ZoomOut is a module type configured by a single factor v.
type ZoomOut struct {
	// v should be within [0, 0.5].
	v float64
}
|
|
|
|
func newZoomOut(v float64) *ZoomOut {
|
|
return &ZoomOut{v: v}
|
|
}
|
|
|
|
func WithZoomOut(v float64) Option {
|
|
if v < 0 || v > 0.5 {
|
|
err := fmt.Errorf("Invalid input value. Expect value in range [0, 0.5]. Got %v\n", v)
|
|
panic(err)
|
|
}
|
|
return func(o *Options) {
|
|
ds := newZoomOut(v)
|
|
o.zoomOut = ds
|
|
}
|
|
}
|
|
|
|
// Forward implements ts.Module for RandRotateModule
|
|
// NOTE. input tensor must be uint8 (Byte) dtype otherwise panic!
|
|
func (rs *ZoomOut) Forward(x *ts.Tensor) *ts.Tensor {
|
|
dtype := x.DType()
|
|
if dtype != gotch.Uint8 {
|
|
err := fmt.Errorf("Invalid dtype. Expect uint8 (Byte) dtype. Got %v\n", dtype)
|
|
panic(err)
|
|
}
|
|
|
|
device := x.MustDevice()
|
|
var xCPU *ts.Tensor
|
|
if device != gotch.CPU {
|
|
xCPU = x.MustTo(device, false)
|
|
} else {
|
|
xCPU = x.MustShallowClone()
|
|
}
|
|
|
|
Fimg := Byte2FloatImage(xCPU)
|
|
|
|
fmt.Printf("Fimg size: %v\n", Fimg.MustSize())
|
|
h := float64(Fimg.MustSize()[1])
|
|
w := float64(Fimg.MustSize()[2])
|
|
padW := int64(rs.v * w)
|
|
padH := int64(rs.v * h)
|
|
fmt.Printf("padH: %v - padW: %v\n", padH, padW)
|
|
|
|
// img = np.pad(img, [(pad_h//2,pad_h//2), (pad_w//2,pad_w//2), (0,0)], mode='reflect')
|
|
padding := []int64{padH / 2, padH / 2, padW / 2, padW / 2, 0, 0}
|
|
fmt.Printf("padding: %+v\n", padding)
|
|
padImg := pad(Fimg, padding, "reflection")
|
|
xCPU.MustDrop()
|
|
Fimg.MustDrop()
|
|
// return img.resize((w,h), resample=Image.BILINEAR)
|
|
out, err := vision.Resize(padImg, int64(w), int64(h))
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
padImg.MustDrop()
|
|
|
|
return out.MustTo(device, true)
|
|
}
|