WIP: aug equalize

This commit is contained in:
sugarme 2021-06-24 07:22:30 +10:00
parent b6f5a89f7b
commit d6bfc5cd39
6 changed files with 97 additions and 28 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 23 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 278 KiB

BIN
example/augmentation/bb.jpg Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 23 KiB

View File

@ -37,7 +37,7 @@ func roundTrip() {
}
func tOne() {
img, err := vision.Load("./bb.png")
img, err := vision.Load("bb.png")
if err != nil {
panic(err)
}
@ -46,7 +46,7 @@ func tOne() {
device := gotch.CPU
imgTs := img.MustTo(device, true)
t, err := aug.Compose(aug.WithRandomAutocontrast(1.0))
// t, err := aug.Compose(aug.WithRandomAutocontrast(1.0))
// t, err := aug.Compose(aug.WithRandomSolarize(aug.WithSolarizeThreshold(125), aug.WithSolarizePvalue(1.0)))
// t, err := aug.Compose(aug.WithRandomAdjustSharpness(aug.WithSharpnessPvalue(1.0), aug.WithSharpnessFactor(10)))
// t, err := aug.Compose(aug.WithRandRotate(0, 360))
@ -58,7 +58,7 @@ func tOne() {
// t, err := aug.Compose(aug.WithRandomGrayscale(1.0))
// t, err := aug.Compose(aug.WithRandomVFlip(1.0))
// t, err := aug.Compose(aug.WithRandomHFlip(1.0))
// t, err := aug.Compose(aug.WithRandomEqualize(1.0))
t, err := aug.Compose(aug.WithRandomEqualize(1.0))
// t, err := aug.Compose(aug.WithRandomCutout(aug.WithCutoutValue([]int64{124, 96, 255}), aug.WithCutoutScale([]float64{0.01, 0.1}), aug.WithCutoutRatio([]float64{0.5, 0.5})))
// t, err := aug.Compose(aug.WithCenterCrop([]int64{320, 320}))
// t, err := aug.Compose(aug.WithRandomAutocontrast())
@ -67,7 +67,7 @@ func tOne() {
// t, err := aug.Compose(aug.WithRandomAffine(aug.WithAffineDegree([]int64{0, 15}), aug.WithAffineShear([]float64{0, 15})))
out := t.Transform(imgTs)
fname := fmt.Sprintf("./bb-transformed.png")
fname := fmt.Sprintf("./bb-transformed.jpg")
err = vision.Save(out, fname)
if err != nil {
panic(err)

View File

@ -1357,7 +1357,6 @@ func equalize(img *ts.Tensor) *ts.Tensor {
}
out := ts.MustStack(images, 0)
for _, x := range images {
x.MustDrop()
}
@ -1367,12 +1366,12 @@ func equalize(img *ts.Tensor) *ts.Tensor {
func equalizeSingleImage(img *ts.Tensor) *ts.Tensor {
dim := img.MustSize()
var scaledChans []ts.Tensor
var scaledChans []ts.Tensor = make([]ts.Tensor, int(dim[0]))
for i := 0; i < int(dim[0]); i++ {
cTs := img.MustSelect(0, int64(i), false)
scaledChan := scaleChannel(cTs)
cTs.MustDrop()
scaledChans = append(scaledChans, *scaledChan)
scaledChans[i] = *scaledChan
}
out := ts.MustStack(scaledChans, 0)
@ -1384,7 +1383,7 @@ func equalizeSingleImage(img *ts.Tensor) *ts.Tensor {
return out
}
func scaleChannel(imgChan *ts.Tensor) *ts.Tensor {
func scaleChannelOld(imgChan *ts.Tensor) *ts.Tensor {
// # TODO: we should expect bincount to always be faster than histc, but this
// # isn't always the case. Once
// # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
@ -1411,6 +1410,10 @@ func scaleChannel(imgChan *ts.Tensor) *ts.Tensor {
// return img_chan
if stepVal == 0 {
out = imgChan.MustShallowClone()
hist.MustDrop()
step.MustDrop()
return out
}
@ -1418,36 +1421,102 @@ func scaleChannel(imgChan *ts.Tensor) *ts.Tensor {
dtype := gotch.Float
halfStep := step.MustDiv1(ts.IntScalar(2), false)
histCumSum := hist.MustCumsum(0, dtype, false)
histStep := histCumSum.MustAdd(halfStep, false)
histStep := histCumSum.MustAdd(halfStep, true) // delete histCumSum
halfStep.MustDrop()
lut := histStep.MustDiv(step, true) // deleted histStep
step.MustDrop()
hist.MustDrop()
// lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
lut1 := lut.MustConstantPadNd([]int64{1, 0}, true) // deleted lut
lut1Dim := lut1.MustSize()
// lut2 composes of 256 elements with value in range [0, 255]
lut2 := lut1.MustNarrow(0, 0, lut1Dim[0]-1, true).MustClamp(ts.IntScalar(0), ts.IntScalar(255), true) // deleted lut1
// return lut[img_chan.to(torch.int64)].to(torch.uint8)
// NOTE: haven't supported multi-dimentional tensor index yet. So we do a in a loop
// channel[h, w]
h := imgChan.MustSize()[0]
// w := imgChan.MustSize()[1]
var xs []ts.Tensor
for i := 0; i < int(h); i++ {
idx := imgChan.MustSelect(0, int64(i), false).MustTotype(gotch.Int64, true)
x := lut2.MustIndexSelect(0, idx, false).MustTotype(gotch.Uint8, true)
xs = append(xs, *x)
idx.MustDrop()
}
out = ts.MustStack(xs, 0)
// delete intermediate tensors
for _, x := range xs {
x.MustDrop()
/*
// return lut[img_chan.to(torch.int64)].to(torch.uint8)
// NOTE: haven't supported multi-dimentional tensor index yet. So we do a in a loop
// imgChan is individual channel of image with [h, w]
h := imgChan.MustSize()[0]
var xs []ts.Tensor = make([]ts.Tensor, h)
imgC := imgChan.MustTotype(gotch.Int64, false)
// Select row-by-row (width) of imgChan as index values, then indexing lut2 for the values.
for i := 0; i < int(h); i++ {
// NOTE: there a KNOWN mem leak here at `ts.MustSelect`. Don't know why. Need more investigation and fix!
idx := imgC.MustSelect(0, int64(i), false)
x := lut2.MustIndexSelect(0, idx, false)
xs[i] = *x
idx.MustDrop()
}
imgC.MustDrop()
lut2.MustDrop()
out = ts.MustStack(xs, 0).MustTotype(gotch.Uint8, true)
// delete intermediate tensors
for _, x := range xs {
x.MustDrop()
}
*/
// Ref. https://github.com/pytorch/vision/issues/1049#issuecomment-519232179
flattenImg := imgChan.MustFlatten(0, -1, false).MustTotype(gotch.Int64, true)
out = lut2.MustIndexSelect(0, flattenImg, true).MustReshapeAs(imgChan, true)
return out
}
// Ref. https://github.com/pytorch/vision/issues/1049#issuecomment-519232179
// Ref. https://github.com/pytorch/vision/issues/1049#issuecomment-519232179
//
// scaleChannel equalizes the histogram of a single image channel so its
// intensity values spread over the full [0, 255] range.
//
// imgChan is one channel of an image (shape [h, w]); the returned tensor has
// the same shape and device. When the channel is constant (step == 0),
// equalization is a no-op and a shallow clone of the input is returned.
// imgChan itself is never consumed; the caller keeps ownership of it.
func scaleChannel(imgChan *ts.Tensor) *ts.Tensor {
	// Compute the histogram of the image channel.
	// histo = torch.histc(im, bins=256, min=0, max=255)
	histo := imgChan.MustTotype(gotch.Float, false).MustHistc(256, true)
	// For the purposes of computing the step, filter out the nonzeros.
	// nonzero_histo = torch.reshape(histo[histo != 0], [-1])
	// NOTE(review): MustNonzero returns the *indices* of nonzero bins, while
	// the Python reference uses the bin *counts* (histo[histo != 0]) — confirm
	// this is intended, or replace with a masked select of the counts.
	nonzeroHisto := histo.MustNonzero(false).MustReshape([]int64{-1}, true)
	// step = (torch.sum(nonzero_histo) - nonzero_histo[-1]) // 255
	sum := nonzeroHisto.MustSum(gotch.Int64, false)
	sub := nonzeroHisto.MustSelect(0, -1, true) // deletes nonzeroHisto
	step := sum.MustSub(sub, true).MustDiv1(ts.IntScalar(255), true)
	sub.MustDrop()

	stepVal := step.Float64Values()[0]
	if stepVal == 0 {
		// Constant channel: nothing to equalize. Free intermediates
		// (step was previously leaked on this path) and return a clone.
		histo.MustDrop()
		step.MustDrop()
		return imgChan.MustShallowClone()
	}

	// Build lut.
	// Compute the cumulative sum, shifting by step // 2
	// and then normalization by step.
	// lut = (torch.cumsum(histo, 0) + (step // 2)) // step
	halfStep := step.MustDiv1(ts.IntScalar(2), false)
	lut := histo.MustCumsum(0, gotch.Int64, true).MustAdd(halfStep, true).MustDiv(step, true) // size = 256; deletes histo + intermediates
	halfStep.MustDrop()
	step.MustDrop()

	// Shift lut, prepending with 0.
	// lut = torch.cat([torch.zeros(1), lut[:-1]])
	l := lut.MustSize()[0]                  // 256
	lut1 := lut.MustNarrow(0, 0, l-1, true) // size = 255; deletes lut
	zeroTs := ts.MustZeros([]int64{1}, gotch.Float, imgChan.MustDevice())
	lut2 := ts.MustCat([]ts.Tensor{*zeroTs, *lut1}, 0)
	zeroTs.MustDrop()
	lut1.MustDrop()

	// Clip the counts to be in range.
	// torch.clamp(lut, 0, 255)
	lut3 := lut2.MustClamp(ts.IntScalar(0), ts.IntScalar(255), true) // deletes lut2

	// result = torch.gather(build_lut(histo, step), 0, im.flatten().long())
	// result = result.reshape_as(im)
	// Can't index with a 2-D index tensor, so flatten, index-select, reshape back.
	flattenImg := imgChan.MustFlatten(0, -1, false).MustTotype(gotch.Int64, true)
	outFlat := lut3.MustIndexSelect(0, flattenImg, true) // deletes lut3
	flattenImg.MustDrop()
	out := outFlat.MustReshapeAs(imgChan, true) // deletes outFlat
	return out
}

BIN
~/Downloads/bb.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 265 KiB