temporarily fixed aug/equalize
This commit is contained in:
parent
d6bfc5cd39
commit
b49afeb322
Binary file not shown.
Before Width: | Height: | Size: 23 KiB After Width: | Height: | Size: 61 KiB |
Binary file not shown.
Before Width: | Height: | Size: 23 KiB |
|
@ -42,8 +42,8 @@ func tOne() {
|
|||
panic(err)
|
||||
}
|
||||
|
||||
// device := gotch.CudaIfAvailable()
|
||||
device := gotch.CPU
|
||||
device := gotch.CudaIfAvailable()
|
||||
// device := gotch.CPU
|
||||
imgTs := img.MustTo(device, true)
|
||||
|
||||
// t, err := aug.Compose(aug.WithRandomAutocontrast(1.0))
|
||||
|
|
|
@ -1383,140 +1383,51 @@ func equalizeSingleImage(img *ts.Tensor) *ts.Tensor {
|
|||
return out
|
||||
}
|
||||
|
||||
func scaleChannelOld(imgChan *ts.Tensor) *ts.Tensor {
|
||||
// # TODO: we should expect bincount to always be faster than histc, but this
|
||||
// # isn't always the case. Once
|
||||
// # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
|
||||
// # block and only use bincount.
|
||||
// if img_chan.is_cuda:
|
||||
func scaleChannel(imgChan *ts.Tensor) *ts.Tensor {
|
||||
// hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
|
||||
// else:
|
||||
// hist = torch.bincount(img_chan.view(-1), minlength=256)
|
||||
|
||||
// hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
|
||||
hist := imgChan.MustTotype(gotch.Float, false).MustHistc(256, true)
|
||||
|
||||
// nonzero_hist = hist[hist != 0]
|
||||
nonZeroHist := hist.MustNonzero(false) // [n, 1]
|
||||
|
||||
// step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode='floor')
|
||||
nonZeroHistDim := nonZeroHist.MustSize()
|
||||
nonZeroHistSum := nonZeroHist.MustNarrow(0, 0, nonZeroHistDim[0]-1, true).MustSum(gotch.Int64, true)
|
||||
step := nonZeroHistSum.MustDiv1(ts.IntScalar(255), true)
|
||||
stepVal := step.Int64Values()[0]
|
||||
|
||||
var out *ts.Tensor
|
||||
// if step == 0:
|
||||
// return img_chan
|
||||
if stepVal == 0 {
|
||||
out = imgChan.MustShallowClone()
|
||||
|
||||
hist.MustDrop()
|
||||
step.MustDrop()
|
||||
|
||||
return out
|
||||
// NOTE. Use `Bincount` so that result similar to Pytorch. If use `Histc`, results are different!!!
|
||||
device := imgChan.MustDevice()
|
||||
var histo *ts.Tensor
|
||||
if device == gotch.CPU {
|
||||
histo = imgChan.MustFlatten(0, -1, false).MustBincount(ts.NewTensor(), 256, true)
|
||||
} else {
|
||||
histo = imgChan.MustTotype(gotch.Float, false).MustHistc(256, true)
|
||||
}
|
||||
|
||||
// lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode='floor'),step, rounding_mode='floor')
|
||||
dtype := gotch.Float
|
||||
halfStep := step.MustDiv1(ts.IntScalar(2), false)
|
||||
histCumSum := hist.MustCumsum(0, dtype, false)
|
||||
histStep := histCumSum.MustAdd(halfStep, true) // delete histCumSum
|
||||
halfStep.MustDrop()
|
||||
lut := histStep.MustDiv(step, true) // deleted histStep
|
||||
step.MustDrop()
|
||||
hist.MustDrop()
|
||||
// nonzero_hist = hist[hist != 0]
|
||||
nonzeroHistoIdx := histo.MustNonzero(false).MustFlatten(0, -1, true)
|
||||
nonzeroHisto := histo.MustIndexSelect(0, nonzeroHistoIdx, false)
|
||||
nonzeroHistoIdx.MustDrop()
|
||||
|
||||
// lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
|
||||
lut1 := lut.MustConstantPadNd([]int64{1, 0}, true) // deleted lut
|
||||
lut1Dim := lut1.MustSize()
|
||||
// step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode='floor')
|
||||
histoLen := nonzeroHisto.MustSize()[0]
|
||||
step := nonzeroHisto.MustNarrow(0, 0, histoLen-1, true).MustSum(gotch.Float, true).MustFloorDivide1(ts.FloatScalar(255.0), true)
|
||||
|
||||
// lut2 composes of 256 elements with value in range [0, 255]
|
||||
lut2 := lut1.MustNarrow(0, 0, lut1Dim[0]-1, true).MustClamp(ts.IntScalar(0), ts.IntScalar(255), true) // deleted lut1
|
||||
|
||||
/*
|
||||
|
||||
// return lut[img_chan.to(torch.int64)].to(torch.uint8)
|
||||
// NOTE: haven't supported multi-dimentional tensor index yet. So we do a in a loop
|
||||
// imgChan is individual channel of image with [h, w]
|
||||
h := imgChan.MustSize()[0]
|
||||
var xs []ts.Tensor = make([]ts.Tensor, h)
|
||||
imgC := imgChan.MustTotype(gotch.Int64, false)
|
||||
// Select row-by-row (width) of imgChan as index values, then indexing lut2 for the values.
|
||||
for i := 0; i < int(h); i++ {
|
||||
// NOTE: there a KNOWN mem leak here at `ts.MustSelect`. Don't know why. Need more investigation and fix!
|
||||
idx := imgC.MustSelect(0, int64(i), false)
|
||||
x := lut2.MustIndexSelect(0, idx, false)
|
||||
xs[i] = *x
|
||||
idx.MustDrop()
|
||||
}
|
||||
|
||||
imgC.MustDrop()
|
||||
lut2.MustDrop()
|
||||
|
||||
out = ts.MustStack(xs, 0).MustTotype(gotch.Uint8, true)
|
||||
// delete intermediate tensors
|
||||
for _, x := range xs {
|
||||
x.MustDrop()
|
||||
}
|
||||
|
||||
|
||||
*/
|
||||
|
||||
// Ref. https://github.com/pytorch/vision/issues/1049#issuecomment-519232179
|
||||
flattenImg := imgChan.MustFlatten(0, -1, false).MustTotype(gotch.Int64, true)
|
||||
out = lut2.MustIndexSelect(0, flattenImg, true).MustReshapeAs(imgChan, true)
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
// Ref. https://github.com/pytorch/vision/issues/1049#issuecomment-519232179
|
||||
func scaleChannel(imgChan *ts.Tensor) *ts.Tensor {
|
||||
// Compute the histogram of the image channel.
|
||||
// histo = torch.histc(im, bins=256, min=0, max=255)#.type(torch.int32)
|
||||
histo := imgChan.MustTotype(gotch.Float, false).MustHistc(256, true)
|
||||
// histo := imgChan.MustDiv1(ts.FloatScalar(256.0), false).MustHistc(256, true)
|
||||
|
||||
// For the purposes of computing the step, filter out the nonzeros.
|
||||
// nonzero_histo = torch.reshape(histo[histo != 0], [-1])
|
||||
nonzeroHisto := histo.MustNonzero(false).MustReshape([]int64{-1}, true)
|
||||
|
||||
// step = (torch.sum(nonzero_histo) - nonzero_histo[-1]) // 255
|
||||
sum := nonzeroHisto.MustSum(gotch.Int64, false)
|
||||
sub := nonzeroHisto.MustSelect(0, -1, true)
|
||||
step := sum.MustSub(sub, true).MustDiv1(ts.IntScalar(255), true)
|
||||
sub.MustDrop()
|
||||
stepVal := step.Float64Values()[0]
|
||||
if stepVal == 0 {
|
||||
histo.MustDrop()
|
||||
step.MustDrop()
|
||||
out := imgChan.MustShallowClone()
|
||||
return out
|
||||
}
|
||||
|
||||
halfStep := step.MustDiv1(ts.IntScalar(2), false)
|
||||
// lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode='floor'), step, rounding_mode='floor')
|
||||
halfStep := step.MustFloorDivide1(ts.FloatScalar(2.0), false)
|
||||
lut := histo.Must_Cumsum(0, true).MustAdd(halfStep, true).MustFloorDivide(step, true)
|
||||
step.MustDrop()
|
||||
halfStep.MustDrop()
|
||||
|
||||
// Build lut
|
||||
// Compute the cumulative sum, shifting by step // 2
|
||||
// and then normalization by step.
|
||||
// lut = (torch.cumsum(histo, 0) + (step // 2)) // step
|
||||
lut := histo.MustCumsum(0, gotch.Int64, true).MustAdd(halfStep, true).MustDiv(step, true) // size = 256
|
||||
|
||||
// Shift lut, prepending with 0.
|
||||
// lut = torch.cat([torch.zeros(1), lut[:-1]])
|
||||
l := lut.MustSize()[0] // 256
|
||||
lut1 := lut.MustNarrow(0, 0, l-1, true) // size = 255
|
||||
zeroTs := ts.MustZeros([]int64{1}, gotch.Float, imgChan.MustDevice())
|
||||
lut2 := ts.MustCat([]ts.Tensor{*zeroTs, *lut1}, 0)
|
||||
|
||||
// Clip the counts to be in range.
|
||||
// torch.clamp(lut, 0, 255)
|
||||
lut3 := lut2.MustClamp(ts.IntScalar(0), ts.IntScalar(255), true)
|
||||
// lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
|
||||
lutLen := lut.MustSize()[0]
|
||||
lut = lut.MustConstantPadNd([]int64{1, 0}, true).MustNarrow(0, 0, lutLen, true).MustClamp(ts.FloatScalar(0), ts.FloatScalar(255.0), true)
|
||||
|
||||
// return lut[img_chan.to(torch.int64)].to(torch.uint8)
|
||||
// can't index using 2d index. Have to flatten and then reshape
|
||||
// result = torch.gather(build_lut(histo, step), 0, im.flatten().long())
|
||||
// result = result.reshape_as(im)
|
||||
flattenImg := imgChan.MustFlatten(0, -1, false).MustTotype(gotch.Int64, true)
|
||||
out := lut3.MustIndexSelect(0, flattenImg, true).MustReshapeAs(imgChan, true)
|
||||
out := lut.MustIndexSelect(0, flattenImg, true).MustReshapeAs(imgChan, true).MustTotype(gotch.Uint8, true)
|
||||
flattenImg.MustDrop()
|
||||
|
||||
return out
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user