fixed RandomAdjustSharpness
This commit is contained in:
parent
8aaf69494b
commit
5923e0f2e2
Binary file not shown.
Before Width: | Height: | Size: 39 KiB After Width: | Height: | Size: 86 KiB |
|
@ -1274,8 +1274,8 @@ func blurredDegenerateImage(img *ts.Tensor) *ts.Tensor {
|
|||
// kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
|
||||
kernel := ts.MustOnes([]int64{3, 3}, dtype, device)
|
||||
|
||||
// kernel[1, 1] = 5.0 - Center kernel value
|
||||
kernelView := kernel.MustNarrow(1, 1, 2, false)
|
||||
// kernel[1, 1] = 5.0
|
||||
kernelView := kernel.MustNarrow(1, 1, 1, false).MustNarrow(0, 1, 1, true)
|
||||
centerVal := kernelView.MustOnesLike(false).MustMul1(ts.FloatScalar(5.0), true)
|
||||
kernelView.Copy_(centerVal) // center kernel value
|
||||
centerVal.MustDrop()
|
||||
|
@ -1299,13 +1299,16 @@ func blurredDegenerateImage(img *ts.Tensor) *ts.Tensor {
|
|||
dilation := []int64{1, 1}
|
||||
resTmpDim := resTmp.MustSize()
|
||||
group := resTmpDim[len(resTmpDim)-3]
|
||||
// resTmp1 shape: [1, 3, h, w]
|
||||
resTmp1 := ts.MustConv2d(resTmp, kernelExp, ts.NewTensor(), stride, padding, dilation, group)
|
||||
|
||||
// result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)
|
||||
// resTmp2 shape: [3, h, w]
|
||||
resTmp2 := castSqueezeOut(resTmp1, needCast, needSqueeze, outDtype)
|
||||
|
||||
// result = img.clone()
|
||||
out := img.MustShallowClone()
|
||||
// NOTE. out := img.MustShallowClone() doesn't work!
|
||||
out := img.MustZerosLike(false)
|
||||
|
||||
// result[..., 1:-1, 1:-1] = result_tmp
|
||||
hDim := int64(len(dim) - 2) // second last dim
|
||||
|
|
|
@ -53,23 +53,18 @@ func newRandomAdjustSharpness(opts ...sharpnessOption) *RandomAdjustSharpness {
|
|||
}
|
||||
}
|
||||
|
||||
// NOTE. input img dtype shoule be `uint8` (Byte)
|
||||
func (ras *RandomAdjustSharpness) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
fx := Byte2FloatImage(x)
|
||||
|
||||
r := randPvalue()
|
||||
var out *ts.Tensor
|
||||
switch {
|
||||
case r < ras.pvalue:
|
||||
out = adjustSharpness(fx, ras.sharpnessFactor)
|
||||
out = adjustSharpness(x, ras.sharpnessFactor)
|
||||
default:
|
||||
out = fx.MustShallowClone()
|
||||
out = x.MustShallowClone()
|
||||
}
|
||||
|
||||
bx := Float2ByteImage(out)
|
||||
fx.MustDrop()
|
||||
out.MustDrop()
|
||||
|
||||
return bx
|
||||
return out
|
||||
}
|
||||
|
||||
func WithRandomAdjustSharpness(opts ...sharpnessOption) Option {
|
||||
|
|
Loading…
Reference in New Issue
Block a user