temporarily fixed aug/equalize
This commit is contained in:
parent
d6bfc5cd39
commit
b49afeb322
Binary file not shown.
Before Width: | Height: | Size: 23 KiB After Width: | Height: | Size: 61 KiB |
Binary file not shown.
Before Width: | Height: | Size: 23 KiB |
|
@ -42,8 +42,8 @@ func tOne() {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// device := gotch.CudaIfAvailable()
|
device := gotch.CudaIfAvailable()
|
||||||
device := gotch.CPU
|
// device := gotch.CPU
|
||||||
imgTs := img.MustTo(device, true)
|
imgTs := img.MustTo(device, true)
|
||||||
|
|
||||||
// t, err := aug.Compose(aug.WithRandomAutocontrast(1.0))
|
// t, err := aug.Compose(aug.WithRandomAutocontrast(1.0))
|
||||||
|
|
|
@ -1383,140 +1383,51 @@ func equalizeSingleImage(img *ts.Tensor) *ts.Tensor {
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func scaleChannelOld(imgChan *ts.Tensor) *ts.Tensor {
|
func scaleChannel(imgChan *ts.Tensor) *ts.Tensor {
|
||||||
// # TODO: we should expect bincount to always be faster than histc, but this
|
|
||||||
// # isn't always the case. Once
|
|
||||||
// # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
|
|
||||||
// # block and only use bincount.
|
|
||||||
// if img_chan.is_cuda:
|
|
||||||
// hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
|
// hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
|
||||||
// else:
|
// NOTE. Use `Bincount` so that result similar to Pytorch. If use `Histc`, results are different!!!
|
||||||
// hist = torch.bincount(img_chan.view(-1), minlength=256)
|
device := imgChan.MustDevice()
|
||||||
|
var histo *ts.Tensor
|
||||||
// hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
|
if device == gotch.CPU {
|
||||||
hist := imgChan.MustTotype(gotch.Float, false).MustHistc(256, true)
|
histo = imgChan.MustFlatten(0, -1, false).MustBincount(ts.NewTensor(), 256, true)
|
||||||
|
} else {
|
||||||
// nonzero_hist = hist[hist != 0]
|
histo = imgChan.MustTotype(gotch.Float, false).MustHistc(256, true)
|
||||||
nonZeroHist := hist.MustNonzero(false) // [n, 1]
|
|
||||||
|
|
||||||
// step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode='floor')
|
|
||||||
nonZeroHistDim := nonZeroHist.MustSize()
|
|
||||||
nonZeroHistSum := nonZeroHist.MustNarrow(0, 0, nonZeroHistDim[0]-1, true).MustSum(gotch.Int64, true)
|
|
||||||
step := nonZeroHistSum.MustDiv1(ts.IntScalar(255), true)
|
|
||||||
stepVal := step.Int64Values()[0]
|
|
||||||
|
|
||||||
var out *ts.Tensor
|
|
||||||
// if step == 0:
|
|
||||||
// return img_chan
|
|
||||||
if stepVal == 0 {
|
|
||||||
out = imgChan.MustShallowClone()
|
|
||||||
|
|
||||||
hist.MustDrop()
|
|
||||||
step.MustDrop()
|
|
||||||
|
|
||||||
return out
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode='floor'),step, rounding_mode='floor')
|
// nonzero_hist = hist[hist != 0]
|
||||||
dtype := gotch.Float
|
nonzeroHistoIdx := histo.MustNonzero(false).MustFlatten(0, -1, true)
|
||||||
halfStep := step.MustDiv1(ts.IntScalar(2), false)
|
nonzeroHisto := histo.MustIndexSelect(0, nonzeroHistoIdx, false)
|
||||||
histCumSum := hist.MustCumsum(0, dtype, false)
|
nonzeroHistoIdx.MustDrop()
|
||||||
histStep := histCumSum.MustAdd(halfStep, true) // delete histCumSum
|
|
||||||
halfStep.MustDrop()
|
|
||||||
lut := histStep.MustDiv(step, true) // deleted histStep
|
|
||||||
step.MustDrop()
|
|
||||||
hist.MustDrop()
|
|
||||||
|
|
||||||
// lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
|
// step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode='floor')
|
||||||
lut1 := lut.MustConstantPadNd([]int64{1, 0}, true) // deleted lut
|
histoLen := nonzeroHisto.MustSize()[0]
|
||||||
lut1Dim := lut1.MustSize()
|
step := nonzeroHisto.MustNarrow(0, 0, histoLen-1, true).MustSum(gotch.Float, true).MustFloorDivide1(ts.FloatScalar(255.0), true)
|
||||||
|
|
||||||
// lut2 composes of 256 elements with value in range [0, 255]
|
|
||||||
lut2 := lut1.MustNarrow(0, 0, lut1Dim[0]-1, true).MustClamp(ts.IntScalar(0), ts.IntScalar(255), true) // deleted lut1
|
|
||||||
|
|
||||||
/*
|
|
||||||
|
|
||||||
// return lut[img_chan.to(torch.int64)].to(torch.uint8)
|
|
||||||
// NOTE: haven't supported multi-dimentional tensor index yet. So we do a in a loop
|
|
||||||
// imgChan is individual channel of image with [h, w]
|
|
||||||
h := imgChan.MustSize()[0]
|
|
||||||
var xs []ts.Tensor = make([]ts.Tensor, h)
|
|
||||||
imgC := imgChan.MustTotype(gotch.Int64, false)
|
|
||||||
// Select row-by-row (width) of imgChan as index values, then indexing lut2 for the values.
|
|
||||||
for i := 0; i < int(h); i++ {
|
|
||||||
// NOTE: there a KNOWN mem leak here at `ts.MustSelect`. Don't know why. Need more investigation and fix!
|
|
||||||
idx := imgC.MustSelect(0, int64(i), false)
|
|
||||||
x := lut2.MustIndexSelect(0, idx, false)
|
|
||||||
xs[i] = *x
|
|
||||||
idx.MustDrop()
|
|
||||||
}
|
|
||||||
|
|
||||||
imgC.MustDrop()
|
|
||||||
lut2.MustDrop()
|
|
||||||
|
|
||||||
out = ts.MustStack(xs, 0).MustTotype(gotch.Uint8, true)
|
|
||||||
// delete intermediate tensors
|
|
||||||
for _, x := range xs {
|
|
||||||
x.MustDrop()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Ref. https://github.com/pytorch/vision/issues/1049#issuecomment-519232179
|
|
||||||
flattenImg := imgChan.MustFlatten(0, -1, false).MustTotype(gotch.Int64, true)
|
|
||||||
out = lut2.MustIndexSelect(0, flattenImg, true).MustReshapeAs(imgChan, true)
|
|
||||||
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ref. https://github.com/pytorch/vision/issues/1049#issuecomment-519232179
|
|
||||||
func scaleChannel(imgChan *ts.Tensor) *ts.Tensor {
|
|
||||||
// Compute the histogram of the image channel.
|
|
||||||
// histo = torch.histc(im, bins=256, min=0, max=255)#.type(torch.int32)
|
|
||||||
histo := imgChan.MustTotype(gotch.Float, false).MustHistc(256, true)
|
|
||||||
// histo := imgChan.MustDiv1(ts.FloatScalar(256.0), false).MustHistc(256, true)
|
|
||||||
|
|
||||||
// For the purposes of computing the step, filter out the nonzeros.
|
|
||||||
// nonzero_histo = torch.reshape(histo[histo != 0], [-1])
|
|
||||||
nonzeroHisto := histo.MustNonzero(false).MustReshape([]int64{-1}, true)
|
|
||||||
|
|
||||||
// step = (torch.sum(nonzero_histo) - nonzero_histo[-1]) // 255
|
|
||||||
sum := nonzeroHisto.MustSum(gotch.Int64, false)
|
|
||||||
sub := nonzeroHisto.MustSelect(0, -1, true)
|
|
||||||
step := sum.MustSub(sub, true).MustDiv1(ts.IntScalar(255), true)
|
|
||||||
sub.MustDrop()
|
|
||||||
stepVal := step.Float64Values()[0]
|
stepVal := step.Float64Values()[0]
|
||||||
if stepVal == 0 {
|
if stepVal == 0 {
|
||||||
histo.MustDrop()
|
histo.MustDrop()
|
||||||
|
step.MustDrop()
|
||||||
out := imgChan.MustShallowClone()
|
out := imgChan.MustShallowClone()
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
halfStep := step.MustDiv1(ts.IntScalar(2), false)
|
// lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode='floor'), step, rounding_mode='floor')
|
||||||
|
halfStep := step.MustFloorDivide1(ts.FloatScalar(2.0), false)
|
||||||
|
lut := histo.Must_Cumsum(0, true).MustAdd(halfStep, true).MustFloorDivide(step, true)
|
||||||
|
step.MustDrop()
|
||||||
|
halfStep.MustDrop()
|
||||||
|
|
||||||
// Build lut
|
// lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
|
||||||
// Compute the cumulative sum, shifting by step // 2
|
lutLen := lut.MustSize()[0]
|
||||||
// and then normalization by step.
|
lut = lut.MustConstantPadNd([]int64{1, 0}, true).MustNarrow(0, 0, lutLen, true).MustClamp(ts.FloatScalar(0), ts.FloatScalar(255.0), true)
|
||||||
// lut = (torch.cumsum(histo, 0) + (step // 2)) // step
|
|
||||||
lut := histo.MustCumsum(0, gotch.Int64, true).MustAdd(halfStep, true).MustDiv(step, true) // size = 256
|
|
||||||
|
|
||||||
// Shift lut, prepending with 0.
|
|
||||||
// lut = torch.cat([torch.zeros(1), lut[:-1]])
|
|
||||||
l := lut.MustSize()[0] // 256
|
|
||||||
lut1 := lut.MustNarrow(0, 0, l-1, true) // size = 255
|
|
||||||
zeroTs := ts.MustZeros([]int64{1}, gotch.Float, imgChan.MustDevice())
|
|
||||||
lut2 := ts.MustCat([]ts.Tensor{*zeroTs, *lut1}, 0)
|
|
||||||
|
|
||||||
// Clip the counts to be in range.
|
|
||||||
// torch.clamp(lut, 0, 255)
|
|
||||||
lut3 := lut2.MustClamp(ts.IntScalar(0), ts.IntScalar(255), true)
|
|
||||||
|
|
||||||
|
// return lut[img_chan.to(torch.int64)].to(torch.uint8)
|
||||||
// can't index using 2d index. Have to flatten and then reshape
|
// can't index using 2d index. Have to flatten and then reshape
|
||||||
// result = torch.gather(build_lut(histo, step), 0, im.flatten().long())
|
// result = torch.gather(build_lut(histo, step), 0, im.flatten().long())
|
||||||
// result = result.reshape_as(im)
|
// result = result.reshape_as(im)
|
||||||
flattenImg := imgChan.MustFlatten(0, -1, false).MustTotype(gotch.Int64, true)
|
flattenImg := imgChan.MustFlatten(0, -1, false).MustTotype(gotch.Int64, true)
|
||||||
out := lut3.MustIndexSelect(0, flattenImg, true).MustReshapeAs(imgChan, true)
|
out := lut.MustIndexSelect(0, flattenImg, true).MustReshapeAs(imgChan, true).MustTotype(gotch.Uint8, true)
|
||||||
|
flattenImg.MustDrop()
|
||||||
|
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user