fixed RandomAdjustSharpness
This commit is contained in:
parent
8aaf69494b
commit
5923e0f2e2
Binary file not shown.
Before Width: | Height: | Size: 39 KiB After Width: | Height: | Size: 86 KiB |
|
@ -1274,8 +1274,8 @@ func blurredDegenerateImage(img *ts.Tensor) *ts.Tensor {
|
||||||
// kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
|
// kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
|
||||||
kernel := ts.MustOnes([]int64{3, 3}, dtype, device)
|
kernel := ts.MustOnes([]int64{3, 3}, dtype, device)
|
||||||
|
|
||||||
// kernel[1, 1] = 5.0 - Center kernel value
|
// kernel[1, 1] = 5.0
|
||||||
kernelView := kernel.MustNarrow(1, 1, 2, false)
|
kernelView := kernel.MustNarrow(1, 1, 1, false).MustNarrow(0, 1, 1, true)
|
||||||
centerVal := kernelView.MustOnesLike(false).MustMul1(ts.FloatScalar(5.0), true)
|
centerVal := kernelView.MustOnesLike(false).MustMul1(ts.FloatScalar(5.0), true)
|
||||||
kernelView.Copy_(centerVal) // center kernel value
|
kernelView.Copy_(centerVal) // center kernel value
|
||||||
centerVal.MustDrop()
|
centerVal.MustDrop()
|
||||||
|
@ -1299,13 +1299,16 @@ func blurredDegenerateImage(img *ts.Tensor) *ts.Tensor {
|
||||||
dilation := []int64{1, 1}
|
dilation := []int64{1, 1}
|
||||||
resTmpDim := resTmp.MustSize()
|
resTmpDim := resTmp.MustSize()
|
||||||
group := resTmpDim[len(resTmpDim)-3]
|
group := resTmpDim[len(resTmpDim)-3]
|
||||||
|
// resTmp1 shape: [1, 3, h, w]
|
||||||
resTmp1 := ts.MustConv2d(resTmp, kernelExp, ts.NewTensor(), stride, padding, dilation, group)
|
resTmp1 := ts.MustConv2d(resTmp, kernelExp, ts.NewTensor(), stride, padding, dilation, group)
|
||||||
|
|
||||||
// result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)
|
// result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)
|
||||||
|
// resTmp2 shape: [3, h, w]
|
||||||
resTmp2 := castSqueezeOut(resTmp1, needCast, needSqueeze, outDtype)
|
resTmp2 := castSqueezeOut(resTmp1, needCast, needSqueeze, outDtype)
|
||||||
|
|
||||||
// result = img.clone()
|
// result = img.clone()
|
||||||
out := img.MustShallowClone()
|
// NOTE. out := img.MustShallowClone() doesn't work!
|
||||||
|
out := img.MustZerosLike(false)
|
||||||
|
|
||||||
// result[..., 1:-1, 1:-1] = result_tmp
|
// result[..., 1:-1, 1:-1] = result_tmp
|
||||||
hDim := int64(len(dim) - 2) // second last dim
|
hDim := int64(len(dim) - 2) // second last dim
|
||||||
|
|
|
@ -53,23 +53,18 @@ func newRandomAdjustSharpness(opts ...sharpnessOption) *RandomAdjustSharpness {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NOTE. input img dtype shoule be `uint8` (Byte)
|
||||||
func (ras *RandomAdjustSharpness) Forward(x *ts.Tensor) *ts.Tensor {
|
func (ras *RandomAdjustSharpness) Forward(x *ts.Tensor) *ts.Tensor {
|
||||||
fx := Byte2FloatImage(x)
|
|
||||||
|
|
||||||
r := randPvalue()
|
r := randPvalue()
|
||||||
var out *ts.Tensor
|
var out *ts.Tensor
|
||||||
switch {
|
switch {
|
||||||
case r < ras.pvalue:
|
case r < ras.pvalue:
|
||||||
out = adjustSharpness(fx, ras.sharpnessFactor)
|
out = adjustSharpness(x, ras.sharpnessFactor)
|
||||||
default:
|
default:
|
||||||
out = fx.MustShallowClone()
|
out = x.MustShallowClone()
|
||||||
}
|
}
|
||||||
|
|
||||||
bx := Float2ByteImage(out)
|
return out
|
||||||
fx.MustDrop()
|
|
||||||
out.MustDrop()
|
|
||||||
|
|
||||||
return bx
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithRandomAdjustSharpness(opts ...sharpnessOption) Option {
|
func WithRandomAdjustSharpness(opts ...sharpnessOption) Option {
|
||||||
|
|
Loading…
Reference in New Issue
Block a user