temporarily fixed aug/equalize

This commit is contained in:
sugarme 2021-06-24 18:41:40 +10:00
parent d6bfc5cd39
commit b49afeb322
4 changed files with 29 additions and 118 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 23 KiB

After

Width:  |  Height:  |  Size: 61 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 23 KiB

View File

@ -42,8 +42,8 @@ func tOne() {
panic(err)
}
// device := gotch.CudaIfAvailable()
device := gotch.CPU
device := gotch.CudaIfAvailable()
// device := gotch.CPU
imgTs := img.MustTo(device, true)
// t, err := aug.Compose(aug.WithRandomAutocontrast(1.0))

View File

@ -1383,140 +1383,51 @@ func equalizeSingleImage(img *ts.Tensor) *ts.Tensor {
return out
}
func scaleChannelOld(imgChan *ts.Tensor) *ts.Tensor {
// # TODO: we should expect bincount to always be faster than histc, but this
// # isn't always the case. Once
// # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
// # block and only use bincount.
// if img_chan.is_cuda:
func scaleChannel(imgChan *ts.Tensor) *ts.Tensor {
// hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
// else:
// hist = torch.bincount(img_chan.view(-1), minlength=256)
// hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
hist := imgChan.MustTotype(gotch.Float, false).MustHistc(256, true)
// nonzero_hist = hist[hist != 0]
nonZeroHist := hist.MustNonzero(false) // [n, 1]
// step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode='floor')
nonZeroHistDim := nonZeroHist.MustSize()
nonZeroHistSum := nonZeroHist.MustNarrow(0, 0, nonZeroHistDim[0]-1, true).MustSum(gotch.Int64, true)
step := nonZeroHistSum.MustDiv1(ts.IntScalar(255), true)
stepVal := step.Int64Values()[0]
var out *ts.Tensor
// if step == 0:
// return img_chan
if stepVal == 0 {
out = imgChan.MustShallowClone()
hist.MustDrop()
step.MustDrop()
return out
// NOTE. Use `Bincount` so that result similar to Pytorch. If use `Histc`, results are different!!!
device := imgChan.MustDevice()
var histo *ts.Tensor
if device == gotch.CPU {
histo = imgChan.MustFlatten(0, -1, false).MustBincount(ts.NewTensor(), 256, true)
} else {
histo = imgChan.MustTotype(gotch.Float, false).MustHistc(256, true)
}
// lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode='floor'),step, rounding_mode='floor')
dtype := gotch.Float
halfStep := step.MustDiv1(ts.IntScalar(2), false)
histCumSum := hist.MustCumsum(0, dtype, false)
histStep := histCumSum.MustAdd(halfStep, true) // delete histCumSum
halfStep.MustDrop()
lut := histStep.MustDiv(step, true) // deleted histStep
step.MustDrop()
hist.MustDrop()
// nonzero_hist = hist[hist != 0]
nonzeroHistoIdx := histo.MustNonzero(false).MustFlatten(0, -1, true)
nonzeroHisto := histo.MustIndexSelect(0, nonzeroHistoIdx, false)
nonzeroHistoIdx.MustDrop()
// lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
lut1 := lut.MustConstantPadNd([]int64{1, 0}, true) // deleted lut
lut1Dim := lut1.MustSize()
// step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode='floor')
histoLen := nonzeroHisto.MustSize()[0]
step := nonzeroHisto.MustNarrow(0, 0, histoLen-1, true).MustSum(gotch.Float, true).MustFloorDivide1(ts.FloatScalar(255.0), true)
// lut2 composes of 256 elements with value in range [0, 255]
lut2 := lut1.MustNarrow(0, 0, lut1Dim[0]-1, true).MustClamp(ts.IntScalar(0), ts.IntScalar(255), true) // deleted lut1
/*
// return lut[img_chan.to(torch.int64)].to(torch.uint8)
// NOTE: haven't supported multi-dimentional tensor index yet. So we do a in a loop
// imgChan is individual channel of image with [h, w]
h := imgChan.MustSize()[0]
var xs []ts.Tensor = make([]ts.Tensor, h)
imgC := imgChan.MustTotype(gotch.Int64, false)
// Select row-by-row (width) of imgChan as index values, then indexing lut2 for the values.
for i := 0; i < int(h); i++ {
// NOTE: there a KNOWN mem leak here at `ts.MustSelect`. Don't know why. Need more investigation and fix!
idx := imgC.MustSelect(0, int64(i), false)
x := lut2.MustIndexSelect(0, idx, false)
xs[i] = *x
idx.MustDrop()
}
imgC.MustDrop()
lut2.MustDrop()
out = ts.MustStack(xs, 0).MustTotype(gotch.Uint8, true)
// delete intermediate tensors
for _, x := range xs {
x.MustDrop()
}
*/
// Ref. https://github.com/pytorch/vision/issues/1049#issuecomment-519232179
flattenImg := imgChan.MustFlatten(0, -1, false).MustTotype(gotch.Int64, true)
out = lut2.MustIndexSelect(0, flattenImg, true).MustReshapeAs(imgChan, true)
return out
}
// Ref. https://github.com/pytorch/vision/issues/1049#issuecomment-519232179
func scaleChannel(imgChan *ts.Tensor) *ts.Tensor {
// Compute the histogram of the image channel.
// histo = torch.histc(im, bins=256, min=0, max=255)#.type(torch.int32)
histo := imgChan.MustTotype(gotch.Float, false).MustHistc(256, true)
// histo := imgChan.MustDiv1(ts.FloatScalar(256.0), false).MustHistc(256, true)
// For the purposes of computing the step, filter out the nonzeros.
// nonzero_histo = torch.reshape(histo[histo != 0], [-1])
nonzeroHisto := histo.MustNonzero(false).MustReshape([]int64{-1}, true)
// step = (torch.sum(nonzero_histo) - nonzero_histo[-1]) // 255
sum := nonzeroHisto.MustSum(gotch.Int64, false)
sub := nonzeroHisto.MustSelect(0, -1, true)
step := sum.MustSub(sub, true).MustDiv1(ts.IntScalar(255), true)
sub.MustDrop()
stepVal := step.Float64Values()[0]
if stepVal == 0 {
histo.MustDrop()
step.MustDrop()
out := imgChan.MustShallowClone()
return out
}
halfStep := step.MustDiv1(ts.IntScalar(2), false)
// lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode='floor'), step, rounding_mode='floor')
halfStep := step.MustFloorDivide1(ts.FloatScalar(2.0), false)
lut := histo.Must_Cumsum(0, true).MustAdd(halfStep, true).MustFloorDivide(step, true)
step.MustDrop()
halfStep.MustDrop()
// Build lut
// Compute the cumulative sum, shifting by step // 2
// and then normalization by step.
// lut = (torch.cumsum(histo, 0) + (step // 2)) // step
lut := histo.MustCumsum(0, gotch.Int64, true).MustAdd(halfStep, true).MustDiv(step, true) // size = 256
// Shift lut, prepending with 0.
// lut = torch.cat([torch.zeros(1), lut[:-1]])
l := lut.MustSize()[0] // 256
lut1 := lut.MustNarrow(0, 0, l-1, true) // size = 255
zeroTs := ts.MustZeros([]int64{1}, gotch.Float, imgChan.MustDevice())
lut2 := ts.MustCat([]ts.Tensor{*zeroTs, *lut1}, 0)
// Clip the counts to be in range.
// torch.clamp(lut, 0, 255)
lut3 := lut2.MustClamp(ts.IntScalar(0), ts.IntScalar(255), true)
// lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
lutLen := lut.MustSize()[0]
lut = lut.MustConstantPadNd([]int64{1, 0}, true).MustNarrow(0, 0, lutLen, true).MustClamp(ts.FloatScalar(0), ts.FloatScalar(255.0), true)
// return lut[img_chan.to(torch.int64)].to(torch.uint8)
// can't index using 2d index. Have to flatten and then reshape
// result = torch.gather(build_lut(histo, step), 0, im.flatten().long())
// result = result.reshape_as(im)
flattenImg := imgChan.MustFlatten(0, -1, false).MustTotype(gotch.Int64, true)
out := lut3.MustIndexSelect(0, flattenImg, true).MustReshapeAs(imgChan, true)
out := lut.MustIndexSelect(0, flattenImg, true).MustReshapeAs(imgChan, true).MustTotype(gotch.Uint8, true)
flattenImg.MustDrop()
return out
}