diff --git a/example/augmentation/bb-transformed.jpg b/example/augmentation/bb-transformed.jpg new file mode 100644 index 0000000..41dbe3d Binary files /dev/null and b/example/augmentation/bb-transformed.jpg differ diff --git a/example/augmentation/bb-transformed.png b/example/augmentation/bb-transformed.png deleted file mode 100644 index 3c3a1a5..0000000 Binary files a/example/augmentation/bb-transformed.png and /dev/null differ diff --git a/example/augmentation/bb.jpg b/example/augmentation/bb.jpg new file mode 100644 index 0000000..e6c7d16 Binary files /dev/null and b/example/augmentation/bb.jpg differ diff --git a/example/augmentation/main.go b/example/augmentation/main.go index 7046fc2..88d1b38 100644 --- a/example/augmentation/main.go +++ b/example/augmentation/main.go @@ -37,7 +37,7 @@ func roundTrip() { } func tOne() { - img, err := vision.Load("./bb.png") + img, err := vision.Load("bb.png") if err != nil { panic(err) } @@ -46,7 +46,7 @@ func tOne() { device := gotch.CPU imgTs := img.MustTo(device, true) - t, err := aug.Compose(aug.WithRandomAutocontrast(1.0)) + // t, err := aug.Compose(aug.WithRandomAutocontrast(1.0)) // t, err := aug.Compose(aug.WithRandomSolarize(aug.WithSolarizeThreshold(125), aug.WithSolarizePvalue(1.0))) // t, err := aug.Compose(aug.WithRandomAdjustSharpness(aug.WithSharpnessPvalue(1.0), aug.WithSharpnessFactor(10))) // t, err := aug.Compose(aug.WithRandRotate(0, 360)) @@ -58,7 +58,7 @@ func tOne() { // t, err := aug.Compose(aug.WithRandomGrayscale(1.0)) // t, err := aug.Compose(aug.WithRandomVFlip(1.0)) // t, err := aug.Compose(aug.WithRandomHFlip(1.0)) - // t, err := aug.Compose(aug.WithRandomEqualize(1.0)) + t, err := aug.Compose(aug.WithRandomEqualize(1.0)) // t, err := aug.Compose(aug.WithRandomCutout(aug.WithCutoutValue([]int64{124, 96, 255}), aug.WithCutoutScale([]float64{0.01, 0.1}), aug.WithCutoutRatio([]float64{0.5, 0.5}))) // t, err := aug.Compose(aug.WithCenterCrop([]int64{320, 320})) // t, err := aug.Compose(aug.WithRandomAutocontrast()) @@ -67,7 +67,7 @@ func tOne() { // t, err := aug.Compose(aug.WithRandomAffine(aug.WithAffineDegree([]int64{0, 15}), aug.WithAffineShear([]float64{0, 15}))) out := t.Transform(imgTs) - fname := fmt.Sprintf("./bb-transformed.png") + fname := fmt.Sprintf("./bb-transformed.jpg") err = vision.Save(out, fname) if err != nil { panic(err) diff --git a/vision/aug/function.go b/vision/aug/function.go index 34f4f69..06f7c5b 100644 --- a/vision/aug/function.go +++ b/vision/aug/function.go @@ -1357,7 +1357,6 @@ func equalize(img *ts.Tensor) *ts.Tensor { } out := ts.MustStack(images, 0) - for _, x := range images { x.MustDrop() } @@ -1367,12 +1366,12 @@ func equalize(img *ts.Tensor) *ts.Tensor { func equalizeSingleImage(img *ts.Tensor) *ts.Tensor { dim := img.MustSize() - var scaledChans []ts.Tensor + var scaledChans []ts.Tensor = make([]ts.Tensor, int(dim[0])) for i := 0; i < int(dim[0]); i++ { cTs := img.MustSelect(0, int64(i), false) scaledChan := scaleChannel(cTs) cTs.MustDrop() - scaledChans = append(scaledChans, *scaledChan) + scaledChans[i] = *scaledChan } out := ts.MustStack(scaledChans, 0) @@ -1384,7 +1383,7 @@ func equalizeSingleImage(img *ts.Tensor) *ts.Tensor { return out } -func scaleChannel(imgChan *ts.Tensor) *ts.Tensor { +func scaleChannelOld(imgChan *ts.Tensor) *ts.Tensor { // # TODO: we should expect bincount to always be faster than histc, but this // # isn't always the case. Once // # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if @@ -1411,6 +1410,10 @@ func scaleChannel(imgChan *ts.Tensor) *ts.Tensor { // return img_chan if stepVal == 0 { out = imgChan.MustShallowClone() + + hist.MustDrop() + step.MustDrop() + return out } @@ -1418,36 +1421,102 @@ func scaleChannel(imgChan *ts.Tensor) *ts.Tensor { dtype := gotch.Float halfStep := step.MustDiv1(ts.IntScalar(2), false) histCumSum := hist.MustCumsum(0, dtype, false) - histStep := histCumSum.MustAdd(halfStep, false) + histStep := histCumSum.MustAdd(halfStep, true) // delete histCumSum halfStep.MustDrop() lut := histStep.MustDiv(step, true) // deleted histStep + step.MustDrop() + hist.MustDrop() // lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255) lut1 := lut.MustConstantPadNd([]int64{1, 0}, true) // deleted lut lut1Dim := lut1.MustSize() + // lut2 composes of 256 elements with value in range [0, 255] lut2 := lut1.MustNarrow(0, 0, lut1Dim[0]-1, true).MustClamp(ts.IntScalar(0), ts.IntScalar(255), true) // deleted lut1 - // return lut[img_chan.to(torch.int64)].to(torch.uint8) - // NOTE: haven't supported multi-dimentional tensor index yet. So we do a in a loop - // channel[h, w] - h := imgChan.MustSize()[0] - // w := imgChan.MustSize()[1] - var xs []ts.Tensor - for i := 0; i < int(h); i++ { - idx := imgChan.MustSelect(0, int64(i), false).MustTotype(gotch.Int64, true) - x := lut2.MustIndexSelect(0, idx, false).MustTotype(gotch.Uint8, true) - xs = append(xs, *x) - idx.MustDrop() - } - out = ts.MustStack(xs, 0) - // delete intermediate tensors - for _, x := range xs { - x.MustDrop() + /* + + // return lut[img_chan.to(torch.int64)].to(torch.uint8) + // NOTE: haven't supported multi-dimentional tensor index yet. So we do a in a loop + // imgChan is individual channel of image with [h, w] + h := imgChan.MustSize()[0] + var xs []ts.Tensor = make([]ts.Tensor, h) + imgC := imgChan.MustTotype(gotch.Int64, false) + // Select row-by-row (width) of imgChan as index values, then indexing lut2 for the values. + for i := 0; i < int(h); i++ { + // NOTE: there a KNOWN mem leak here at `ts.MustSelect`. Don't know why. Need more investigation and fix! + idx := imgC.MustSelect(0, int64(i), false) + x := lut2.MustIndexSelect(0, idx, false) + xs[i] = *x + idx.MustDrop() + } + + imgC.MustDrop() + lut2.MustDrop() + + out = ts.MustStack(xs, 0).MustTotype(gotch.Uint8, true) + // delete intermediate tensors + for _, x := range xs { + x.MustDrop() + } + + + */ + + // Ref. https://github.com/pytorch/vision/issues/1049#issuecomment-519232179 + flattenImg := imgChan.MustFlatten(0, -1, false).MustTotype(gotch.Int64, true) + out = lut2.MustIndexSelect(0, flattenImg, true).MustReshapeAs(imgChan, true) + + return out +} + +// Ref. https://github.com/pytorch/vision/issues/1049#issuecomment-519232179 +func scaleChannel(imgChan *ts.Tensor) *ts.Tensor { + // Compute the histogram of the image channel. + // histo = torch.histc(im, bins=256, min=0, max=255)#.type(torch.int32) + histo := imgChan.MustTotype(gotch.Float, false).MustHistc(256, true) + // histo := imgChan.MustDiv1(ts.FloatScalar(256.0), false).MustHistc(256, true) + + // For the purposes of computing the step, filter out the nonzeros. + // nonzero_histo = torch.reshape(histo[histo != 0], [-1]) + nonzeroHisto := histo.MustNonzero(false).MustReshape([]int64{-1}, true) + + // step = (torch.sum(nonzero_histo) - nonzero_histo[-1]) // 255 + sum := nonzeroHisto.MustSum(gotch.Int64, false) + sub := nonzeroHisto.MustSelect(0, -1, true) + step := sum.MustSub(sub, true).MustDiv1(ts.IntScalar(255), true) + sub.MustDrop() + stepVal := step.Float64Values()[0] + if stepVal == 0 { + histo.MustDrop() + out := imgChan.MustShallowClone() + return out } - lut2.MustDrop() - hist.MustDrop() - step.MustDrop() + + halfStep := step.MustDiv1(ts.IntScalar(2), false) + + // Build lut + // Compute the cumulative sum, shifting by step // 2 + // and then normalization by step. + // lut = (torch.cumsum(histo, 0) + (step // 2)) // step + lut := histo.MustCumsum(0, gotch.Int64, true).MustAdd(halfStep, true).MustDiv(step, true) // size = 256 + + // Shift lut, prepending with 0. + // lut = torch.cat([torch.zeros(1), lut[:-1]]) + l := lut.MustSize()[0] // 256 + lut1 := lut.MustNarrow(0, 0, l-1, true) // size = 255 + zeroTs := ts.MustZeros([]int64{1}, gotch.Float, imgChan.MustDevice()) + lut2 := ts.MustCat([]ts.Tensor{*zeroTs, *lut1}, 0) + + // Clip the counts to be in range. + // torch.clamp(lut, 0, 255) + lut3 := lut2.MustClamp(ts.IntScalar(0), ts.IntScalar(255), true) + + // can't index using 2d index. Have to flatten and then reshape + // result = torch.gather(build_lut(histo, step), 0, im.flatten().long()) + // result = result.reshape_as(im) + flattenImg := imgChan.MustFlatten(0, -1, false).MustTotype(gotch.Int64, true) + out := lut3.MustIndexSelect(0, flattenImg, true).MustReshapeAs(imgChan, true) return out } diff --git a/~/Downloads/bb.png b/~/Downloads/bb.png new file mode 100644 index 0000000..6b13541 Binary files /dev/null and b/~/Downloads/bb.png differ