fixed RandomAdjustSharpness

This commit is contained in:
sugarme 2021-06-25 13:36:25 +10:00
parent 8aaf69494b
commit 5923e0f2e2
3 changed files with 10 additions and 12 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 39 KiB

After

Width:  |  Height:  |  Size: 86 KiB

View File

@ -1274,8 +1274,8 @@ func blurredDegenerateImage(img *ts.Tensor) *ts.Tensor {
// kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
kernel := ts.MustOnes([]int64{3, 3}, dtype, device)
// kernel[1, 1] = 5.0 - Center kernel value
kernelView := kernel.MustNarrow(1, 1, 2, false)
// kernel[1, 1] = 5.0
kernelView := kernel.MustNarrow(1, 1, 1, false).MustNarrow(0, 1, 1, true)
centerVal := kernelView.MustOnesLike(false).MustMul1(ts.FloatScalar(5.0), true)
kernelView.Copy_(centerVal) // center kernel value
centerVal.MustDrop()
@ -1299,13 +1299,16 @@ func blurredDegenerateImage(img *ts.Tensor) *ts.Tensor {
dilation := []int64{1, 1}
resTmpDim := resTmp.MustSize()
group := resTmpDim[len(resTmpDim)-3]
// resTmp1 shape: [1, 3, h, w]
resTmp1 := ts.MustConv2d(resTmp, kernelExp, ts.NewTensor(), stride, padding, dilation, group)
// result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)
// resTmp2 shape: [3, h, w]
resTmp2 := castSqueezeOut(resTmp1, needCast, needSqueeze, outDtype)
// result = img.clone()
out := img.MustShallowClone()
// NOTE. out := img.MustShallowClone() doesn't work!
out := img.MustZerosLike(false)
// result[..., 1:-1, 1:-1] = result_tmp
hDim := int64(len(dim) - 2) // second last dim

View File

@ -53,23 +53,18 @@ func newRandomAdjustSharpness(opts ...sharpnessOption) *RandomAdjustSharpness {
}
}
// NOTE. input img dtype shoule be `uint8` (Byte)
func (ras *RandomAdjustSharpness) Forward(x *ts.Tensor) *ts.Tensor {
fx := Byte2FloatImage(x)
r := randPvalue()
var out *ts.Tensor
switch {
case r < ras.pvalue:
out = adjustSharpness(fx, ras.sharpnessFactor)
out = adjustSharpness(x, ras.sharpnessFactor)
default:
out = fx.MustShallowClone()
out = x.MustShallowClone()
}
bx := Float2ByteImage(out)
fx.MustDrop()
out.MustDrop()
return bx
return out
}
func WithRandomAdjustSharpness(opts ...sharpnessOption) Option {