WIP: aug equalize
This commit is contained in:
parent
b6f5a89f7b
commit
d6bfc5cd39
BIN
example/augmentation/bb-transformed.jpg
Normal file
BIN
example/augmentation/bb-transformed.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 23 KiB |
Binary file not shown.
Before Width: | Height: | Size: 278 KiB |
BIN
example/augmentation/bb.jpg
Normal file
BIN
example/augmentation/bb.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 23 KiB |
|
@ -37,7 +37,7 @@ func roundTrip() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func tOne() {
|
func tOne() {
|
||||||
img, err := vision.Load("./bb.png")
|
img, err := vision.Load("bb.png")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -46,7 +46,7 @@ func tOne() {
|
||||||
device := gotch.CPU
|
device := gotch.CPU
|
||||||
imgTs := img.MustTo(device, true)
|
imgTs := img.MustTo(device, true)
|
||||||
|
|
||||||
t, err := aug.Compose(aug.WithRandomAutocontrast(1.0))
|
// t, err := aug.Compose(aug.WithRandomAutocontrast(1.0))
|
||||||
// t, err := aug.Compose(aug.WithRandomSolarize(aug.WithSolarizeThreshold(125), aug.WithSolarizePvalue(1.0)))
|
// t, err := aug.Compose(aug.WithRandomSolarize(aug.WithSolarizeThreshold(125), aug.WithSolarizePvalue(1.0)))
|
||||||
// t, err := aug.Compose(aug.WithRandomAdjustSharpness(aug.WithSharpnessPvalue(1.0), aug.WithSharpnessFactor(10)))
|
// t, err := aug.Compose(aug.WithRandomAdjustSharpness(aug.WithSharpnessPvalue(1.0), aug.WithSharpnessFactor(10)))
|
||||||
// t, err := aug.Compose(aug.WithRandRotate(0, 360))
|
// t, err := aug.Compose(aug.WithRandRotate(0, 360))
|
||||||
|
@ -58,7 +58,7 @@ func tOne() {
|
||||||
// t, err := aug.Compose(aug.WithRandomGrayscale(1.0))
|
// t, err := aug.Compose(aug.WithRandomGrayscale(1.0))
|
||||||
// t, err := aug.Compose(aug.WithRandomVFlip(1.0))
|
// t, err := aug.Compose(aug.WithRandomVFlip(1.0))
|
||||||
// t, err := aug.Compose(aug.WithRandomHFlip(1.0))
|
// t, err := aug.Compose(aug.WithRandomHFlip(1.0))
|
||||||
// t, err := aug.Compose(aug.WithRandomEqualize(1.0))
|
t, err := aug.Compose(aug.WithRandomEqualize(1.0))
|
||||||
// t, err := aug.Compose(aug.WithRandomCutout(aug.WithCutoutValue([]int64{124, 96, 255}), aug.WithCutoutScale([]float64{0.01, 0.1}), aug.WithCutoutRatio([]float64{0.5, 0.5})))
|
// t, err := aug.Compose(aug.WithRandomCutout(aug.WithCutoutValue([]int64{124, 96, 255}), aug.WithCutoutScale([]float64{0.01, 0.1}), aug.WithCutoutRatio([]float64{0.5, 0.5})))
|
||||||
// t, err := aug.Compose(aug.WithCenterCrop([]int64{320, 320}))
|
// t, err := aug.Compose(aug.WithCenterCrop([]int64{320, 320}))
|
||||||
// t, err := aug.Compose(aug.WithRandomAutocontrast())
|
// t, err := aug.Compose(aug.WithRandomAutocontrast())
|
||||||
|
@ -67,7 +67,7 @@ func tOne() {
|
||||||
// t, err := aug.Compose(aug.WithRandomAffine(aug.WithAffineDegree([]int64{0, 15}), aug.WithAffineShear([]float64{0, 15})))
|
// t, err := aug.Compose(aug.WithRandomAffine(aug.WithAffineDegree([]int64{0, 15}), aug.WithAffineShear([]float64{0, 15})))
|
||||||
|
|
||||||
out := t.Transform(imgTs)
|
out := t.Transform(imgTs)
|
||||||
fname := fmt.Sprintf("./bb-transformed.png")
|
fname := fmt.Sprintf("./bb-transformed.jpg")
|
||||||
err = vision.Save(out, fname)
|
err = vision.Save(out, fname)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
|
|
@ -1357,7 +1357,6 @@ func equalize(img *ts.Tensor) *ts.Tensor {
|
||||||
}
|
}
|
||||||
|
|
||||||
out := ts.MustStack(images, 0)
|
out := ts.MustStack(images, 0)
|
||||||
|
|
||||||
for _, x := range images {
|
for _, x := range images {
|
||||||
x.MustDrop()
|
x.MustDrop()
|
||||||
}
|
}
|
||||||
|
@ -1367,12 +1366,12 @@ func equalize(img *ts.Tensor) *ts.Tensor {
|
||||||
|
|
||||||
func equalizeSingleImage(img *ts.Tensor) *ts.Tensor {
|
func equalizeSingleImage(img *ts.Tensor) *ts.Tensor {
|
||||||
dim := img.MustSize()
|
dim := img.MustSize()
|
||||||
var scaledChans []ts.Tensor
|
var scaledChans []ts.Tensor = make([]ts.Tensor, int(dim[0]))
|
||||||
for i := 0; i < int(dim[0]); i++ {
|
for i := 0; i < int(dim[0]); i++ {
|
||||||
cTs := img.MustSelect(0, int64(i), false)
|
cTs := img.MustSelect(0, int64(i), false)
|
||||||
scaledChan := scaleChannel(cTs)
|
scaledChan := scaleChannel(cTs)
|
||||||
cTs.MustDrop()
|
cTs.MustDrop()
|
||||||
scaledChans = append(scaledChans, *scaledChan)
|
scaledChans[i] = *scaledChan
|
||||||
}
|
}
|
||||||
|
|
||||||
out := ts.MustStack(scaledChans, 0)
|
out := ts.MustStack(scaledChans, 0)
|
||||||
|
@ -1384,7 +1383,7 @@ func equalizeSingleImage(img *ts.Tensor) *ts.Tensor {
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func scaleChannel(imgChan *ts.Tensor) *ts.Tensor {
|
func scaleChannelOld(imgChan *ts.Tensor) *ts.Tensor {
|
||||||
// # TODO: we should expect bincount to always be faster than histc, but this
|
// # TODO: we should expect bincount to always be faster than histc, but this
|
||||||
// # isn't always the case. Once
|
// # isn't always the case. Once
|
||||||
// # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
|
// # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
|
||||||
|
@ -1411,6 +1410,10 @@ func scaleChannel(imgChan *ts.Tensor) *ts.Tensor {
|
||||||
// return img_chan
|
// return img_chan
|
||||||
if stepVal == 0 {
|
if stepVal == 0 {
|
||||||
out = imgChan.MustShallowClone()
|
out = imgChan.MustShallowClone()
|
||||||
|
|
||||||
|
hist.MustDrop()
|
||||||
|
step.MustDrop()
|
||||||
|
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1418,36 +1421,102 @@ func scaleChannel(imgChan *ts.Tensor) *ts.Tensor {
|
||||||
dtype := gotch.Float
|
dtype := gotch.Float
|
||||||
halfStep := step.MustDiv1(ts.IntScalar(2), false)
|
halfStep := step.MustDiv1(ts.IntScalar(2), false)
|
||||||
histCumSum := hist.MustCumsum(0, dtype, false)
|
histCumSum := hist.MustCumsum(0, dtype, false)
|
||||||
histStep := histCumSum.MustAdd(halfStep, false)
|
histStep := histCumSum.MustAdd(halfStep, true) // delete histCumSum
|
||||||
halfStep.MustDrop()
|
halfStep.MustDrop()
|
||||||
lut := histStep.MustDiv(step, true) // deleted histStep
|
lut := histStep.MustDiv(step, true) // deleted histStep
|
||||||
|
step.MustDrop()
|
||||||
|
hist.MustDrop()
|
||||||
|
|
||||||
// lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
|
// lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
|
||||||
lut1 := lut.MustConstantPadNd([]int64{1, 0}, true) // deleted lut
|
lut1 := lut.MustConstantPadNd([]int64{1, 0}, true) // deleted lut
|
||||||
lut1Dim := lut1.MustSize()
|
lut1Dim := lut1.MustSize()
|
||||||
|
|
||||||
|
// lut2 composes of 256 elements with value in range [0, 255]
|
||||||
lut2 := lut1.MustNarrow(0, 0, lut1Dim[0]-1, true).MustClamp(ts.IntScalar(0), ts.IntScalar(255), true) // deleted lut1
|
lut2 := lut1.MustNarrow(0, 0, lut1Dim[0]-1, true).MustClamp(ts.IntScalar(0), ts.IntScalar(255), true) // deleted lut1
|
||||||
// return lut[img_chan.to(torch.int64)].to(torch.uint8)
|
|
||||||
// NOTE: haven't supported multi-dimentional tensor index yet. So we do a in a loop
|
|
||||||
// channel[h, w]
|
|
||||||
h := imgChan.MustSize()[0]
|
|
||||||
// w := imgChan.MustSize()[1]
|
|
||||||
var xs []ts.Tensor
|
|
||||||
for i := 0; i < int(h); i++ {
|
|
||||||
idx := imgChan.MustSelect(0, int64(i), false).MustTotype(gotch.Int64, true)
|
|
||||||
x := lut2.MustIndexSelect(0, idx, false).MustTotype(gotch.Uint8, true)
|
|
||||||
xs = append(xs, *x)
|
|
||||||
idx.MustDrop()
|
|
||||||
}
|
|
||||||
out = ts.MustStack(xs, 0)
|
|
||||||
|
|
||||||
// delete intermediate tensors
|
/*
|
||||||
for _, x := range xs {
|
|
||||||
x.MustDrop()
|
// return lut[img_chan.to(torch.int64)].to(torch.uint8)
|
||||||
|
// NOTE: haven't supported multi-dimentional tensor index yet. So we do a in a loop
|
||||||
|
// imgChan is individual channel of image with [h, w]
|
||||||
|
h := imgChan.MustSize()[0]
|
||||||
|
var xs []ts.Tensor = make([]ts.Tensor, h)
|
||||||
|
imgC := imgChan.MustTotype(gotch.Int64, false)
|
||||||
|
// Select row-by-row (width) of imgChan as index values, then indexing lut2 for the values.
|
||||||
|
for i := 0; i < int(h); i++ {
|
||||||
|
// NOTE: there a KNOWN mem leak here at `ts.MustSelect`. Don't know why. Need more investigation and fix!
|
||||||
|
idx := imgC.MustSelect(0, int64(i), false)
|
||||||
|
x := lut2.MustIndexSelect(0, idx, false)
|
||||||
|
xs[i] = *x
|
||||||
|
idx.MustDrop()
|
||||||
|
}
|
||||||
|
|
||||||
|
imgC.MustDrop()
|
||||||
|
lut2.MustDrop()
|
||||||
|
|
||||||
|
out = ts.MustStack(xs, 0).MustTotype(gotch.Uint8, true)
|
||||||
|
// delete intermediate tensors
|
||||||
|
for _, x := range xs {
|
||||||
|
x.MustDrop()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Ref. https://github.com/pytorch/vision/issues/1049#issuecomment-519232179
|
||||||
|
flattenImg := imgChan.MustFlatten(0, -1, false).MustTotype(gotch.Int64, true)
|
||||||
|
out = lut2.MustIndexSelect(0, flattenImg, true).MustReshapeAs(imgChan, true)
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ref. https://github.com/pytorch/vision/issues/1049#issuecomment-519232179
|
||||||
|
func scaleChannel(imgChan *ts.Tensor) *ts.Tensor {
|
||||||
|
// Compute the histogram of the image channel.
|
||||||
|
// histo = torch.histc(im, bins=256, min=0, max=255)#.type(torch.int32)
|
||||||
|
histo := imgChan.MustTotype(gotch.Float, false).MustHistc(256, true)
|
||||||
|
// histo := imgChan.MustDiv1(ts.FloatScalar(256.0), false).MustHistc(256, true)
|
||||||
|
|
||||||
|
// For the purposes of computing the step, filter out the nonzeros.
|
||||||
|
// nonzero_histo = torch.reshape(histo[histo != 0], [-1])
|
||||||
|
nonzeroHisto := histo.MustNonzero(false).MustReshape([]int64{-1}, true)
|
||||||
|
|
||||||
|
// step = (torch.sum(nonzero_histo) - nonzero_histo[-1]) // 255
|
||||||
|
sum := nonzeroHisto.MustSum(gotch.Int64, false)
|
||||||
|
sub := nonzeroHisto.MustSelect(0, -1, true)
|
||||||
|
step := sum.MustSub(sub, true).MustDiv1(ts.IntScalar(255), true)
|
||||||
|
sub.MustDrop()
|
||||||
|
stepVal := step.Float64Values()[0]
|
||||||
|
if stepVal == 0 {
|
||||||
|
histo.MustDrop()
|
||||||
|
out := imgChan.MustShallowClone()
|
||||||
|
return out
|
||||||
}
|
}
|
||||||
lut2.MustDrop()
|
|
||||||
hist.MustDrop()
|
halfStep := step.MustDiv1(ts.IntScalar(2), false)
|
||||||
step.MustDrop()
|
|
||||||
|
// Build lut
|
||||||
|
// Compute the cumulative sum, shifting by step // 2
|
||||||
|
// and then normalization by step.
|
||||||
|
// lut = (torch.cumsum(histo, 0) + (step // 2)) // step
|
||||||
|
lut := histo.MustCumsum(0, gotch.Int64, true).MustAdd(halfStep, true).MustDiv(step, true) // size = 256
|
||||||
|
|
||||||
|
// Shift lut, prepending with 0.
|
||||||
|
// lut = torch.cat([torch.zeros(1), lut[:-1]])
|
||||||
|
l := lut.MustSize()[0] // 256
|
||||||
|
lut1 := lut.MustNarrow(0, 0, l-1, true) // size = 255
|
||||||
|
zeroTs := ts.MustZeros([]int64{1}, gotch.Float, imgChan.MustDevice())
|
||||||
|
lut2 := ts.MustCat([]ts.Tensor{*zeroTs, *lut1}, 0)
|
||||||
|
|
||||||
|
// Clip the counts to be in range.
|
||||||
|
// torch.clamp(lut, 0, 255)
|
||||||
|
lut3 := lut2.MustClamp(ts.IntScalar(0), ts.IntScalar(255), true)
|
||||||
|
|
||||||
|
// can't index using 2d index. Have to flatten and then reshape
|
||||||
|
// result = torch.gather(build_lut(histo, step), 0, im.flatten().long())
|
||||||
|
// result = result.reshape_as(im)
|
||||||
|
flattenImg := imgChan.MustFlatten(0, -1, false).MustTotype(gotch.Int64, true)
|
||||||
|
out := lut3.MustIndexSelect(0, flattenImg, true).MustReshapeAs(imgChan, true)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
BIN
~/Downloads/bb.png
Normal file
BIN
~/Downloads/bb.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 265 KiB |
Loading…
Reference in New Issue
Block a user