diff --git a/example/augmentation/bb-transformed.jpg b/example/augmentation/bb-transformed.jpg index 41dbe3d..4fd978d 100644 Binary files a/example/augmentation/bb-transformed.jpg and b/example/augmentation/bb-transformed.jpg differ diff --git a/example/augmentation/bb.jpg b/example/augmentation/bb.jpg deleted file mode 100644 index e6c7d16..0000000 Binary files a/example/augmentation/bb.jpg and /dev/null differ diff --git a/example/augmentation/main.go b/example/augmentation/main.go index 88d1b38..99b6a08 100644 --- a/example/augmentation/main.go +++ b/example/augmentation/main.go @@ -42,8 +42,8 @@ func tOne() { panic(err) } - // device := gotch.CudaIfAvailable() - device := gotch.CPU + device := gotch.CudaIfAvailable() + // device := gotch.CPU imgTs := img.MustTo(device, true) // t, err := aug.Compose(aug.WithRandomAutocontrast(1.0)) diff --git a/vision/aug/function.go b/vision/aug/function.go index 06f7c5b..c19713b 100644 --- a/vision/aug/function.go +++ b/vision/aug/function.go @@ -1383,140 +1383,51 @@ func equalizeSingleImage(img *ts.Tensor) *ts.Tensor { return out } -func scaleChannelOld(imgChan *ts.Tensor) *ts.Tensor { - // # TODO: we should expect bincount to always be faster than histc, but this - // # isn't always the case. Once - // # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if - // # block and only use bincount. - // if img_chan.is_cuda: +func scaleChannel(imgChan *ts.Tensor) *ts.Tensor { // hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255) - // else: - // hist = torch.bincount(img_chan.view(-1), minlength=256) - - // hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255) - hist := imgChan.MustTotype(gotch.Float, false).MustHistc(256, true) - - // nonzero_hist = hist[hist != 0] - nonZeroHist := hist.MustNonzero(false) // [n, 1] - - // step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode='floor') - nonZeroHistDim := nonZeroHist.MustSize() - nonZeroHistSum := nonZeroHist.MustNarrow(0, 0, nonZeroHistDim[0]-1, true).MustSum(gotch.Int64, true) - step := nonZeroHistSum.MustDiv1(ts.IntScalar(255), true) - stepVal := step.Int64Values()[0] - - var out *ts.Tensor - // if step == 0: - // return img_chan - if stepVal == 0 { - out = imgChan.MustShallowClone() - - hist.MustDrop() - step.MustDrop() - - return out + // NOTE. Use `Bincount` so that result similar to Pytorch. If use `Histc`, results are different!!! + device := imgChan.MustDevice() + var histo *ts.Tensor + if device == gotch.CPU { + histo = imgChan.MustFlatten(0, -1, false).MustBincount(ts.NewTensor(), 256, true) + } else { + histo = imgChan.MustTotype(gotch.Float, false).MustHistc(256, true) } - // lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode='floor'),step, rounding_mode='floor') - dtype := gotch.Float - halfStep := step.MustDiv1(ts.IntScalar(2), false) - histCumSum := hist.MustCumsum(0, dtype, false) - histStep := histCumSum.MustAdd(halfStep, true) // delete histCumSum - halfStep.MustDrop() - lut := histStep.MustDiv(step, true) // deleted histStep - step.MustDrop() - hist.MustDrop() + // nonzero_hist = hist[hist != 0] + nonzeroHistoIdx := histo.MustNonzero(false).MustFlatten(0, -1, true) + nonzeroHisto := histo.MustIndexSelect(0, nonzeroHistoIdx, false) + nonzeroHistoIdx.MustDrop() - // lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255) - lut1 := lut.MustConstantPadNd([]int64{1, 0}, true) // deleted lut - lut1Dim := lut1.MustSize() + // step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode='floor') + histoLen := nonzeroHisto.MustSize()[0] + step := nonzeroHisto.MustNarrow(0, 0, histoLen-1, true).MustSum(gotch.Float, true).MustFloorDivide1(ts.FloatScalar(255.0), true) - // lut2 composes of 256 elements with value in range [0, 255] - lut2 := lut1.MustNarrow(0, 0, lut1Dim[0]-1, true).MustClamp(ts.IntScalar(0), ts.IntScalar(255), true) // deleted lut1 - - /* - - // return lut[img_chan.to(torch.int64)].to(torch.uint8) - // NOTE: haven't supported multi-dimentional tensor index yet. So we do a in a loop - // imgChan is individual channel of image with [h, w] - h := imgChan.MustSize()[0] - var xs []ts.Tensor = make([]ts.Tensor, h) - imgC := imgChan.MustTotype(gotch.Int64, false) - // Select row-by-row (width) of imgChan as index values, then indexing lut2 for the values. - for i := 0; i < int(h); i++ { - // NOTE: there a KNOWN mem leak here at `ts.MustSelect`. Don't know why. Need more investigation and fix! - idx := imgC.MustSelect(0, int64(i), false) - x := lut2.MustIndexSelect(0, idx, false) - xs[i] = *x - idx.MustDrop() - } - - imgC.MustDrop() - lut2.MustDrop() - - out = ts.MustStack(xs, 0).MustTotype(gotch.Uint8, true) - // delete intermediate tensors - for _, x := range xs { - x.MustDrop() - } - - - */ - - // Ref. https://github.com/pytorch/vision/issues/1049#issuecomment-519232179 - flattenImg := imgChan.MustFlatten(0, -1, false).MustTotype(gotch.Int64, true) - out = lut2.MustIndexSelect(0, flattenImg, true).MustReshapeAs(imgChan, true) - - return out -} - -// Ref. https://github.com/pytorch/vision/issues/1049#issuecomment-519232179 -func scaleChannel(imgChan *ts.Tensor) *ts.Tensor { - // Compute the histogram of the image channel. - // histo = torch.histc(im, bins=256, min=0, max=255)#.type(torch.int32) - histo := imgChan.MustTotype(gotch.Float, false).MustHistc(256, true) - // histo := imgChan.MustDiv1(ts.FloatScalar(256.0), false).MustHistc(256, true) - - // For the purposes of computing the step, filter out the nonzeros. - // nonzero_histo = torch.reshape(histo[histo != 0], [-1]) - nonzeroHisto := histo.MustNonzero(false).MustReshape([]int64{-1}, true) - - // step = (torch.sum(nonzero_histo) - nonzero_histo[-1]) // 255 - sum := nonzeroHisto.MustSum(gotch.Int64, false) - sub := nonzeroHisto.MustSelect(0, -1, true) - step := sum.MustSub(sub, true).MustDiv1(ts.IntScalar(255), true) - sub.MustDrop() stepVal := step.Float64Values()[0] if stepVal == 0 { histo.MustDrop() + step.MustDrop() out := imgChan.MustShallowClone() return out } - halfStep := step.MustDiv1(ts.IntScalar(2), false) + // lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode='floor'), step, rounding_mode='floor') + halfStep := step.MustFloorDivide1(ts.FloatScalar(2.0), false) + lut := histo.Must_Cumsum(0, true).MustAdd(halfStep, true).MustFloorDivide(step, true) + step.MustDrop() + halfStep.MustDrop() - // Build lut - // Compute the cumulative sum, shifting by step // 2 - // and then normalization by step. - // lut = (torch.cumsum(histo, 0) + (step // 2)) // step - lut := histo.MustCumsum(0, gotch.Int64, true).MustAdd(halfStep, true).MustDiv(step, true) // size = 256 - - // Shift lut, prepending with 0. - // lut = torch.cat([torch.zeros(1), lut[:-1]]) - l := lut.MustSize()[0] // 256 - lut1 := lut.MustNarrow(0, 0, l-1, true) // size = 255 - zeroTs := ts.MustZeros([]int64{1}, gotch.Float, imgChan.MustDevice()) - lut2 := ts.MustCat([]ts.Tensor{*zeroTs, *lut1}, 0) - - // Clip the counts to be in range. - // torch.clamp(lut, 0, 255) - lut3 := lut2.MustClamp(ts.IntScalar(0), ts.IntScalar(255), true) + // lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255) + lutLen := lut.MustSize()[0] + lut = lut.MustConstantPadNd([]int64{1, 0}, true).MustNarrow(0, 0, lutLen, true).MustClamp(ts.FloatScalar(0), ts.FloatScalar(255.0), true) + // return lut[img_chan.to(torch.int64)].to(torch.uint8) // can't index using 2d index. Have to flatten and then reshape // result = torch.gather(build_lut(histo, step), 0, im.flatten().long()) // result = result.reshape_as(im) flattenImg := imgChan.MustFlatten(0, -1, false).MustTotype(gotch.Int64, true) - out := lut3.MustIndexSelect(0, flattenImg, true).MustReshapeAs(imgChan, true) + out := lut.MustIndexSelect(0, flattenImg, true).MustReshapeAs(imgChan, true).MustTotype(gotch.Uint8, true) + flattenImg.MustDrop() return out }