diff --git a/example/augmentation/bb-transformed.png b/example/augmentation/bb-transformed.png index 48a397e..3c3a1a5 100644 Binary files a/example/augmentation/bb-transformed.png and b/example/augmentation/bb-transformed.png differ diff --git a/example/augmentation/main.go b/example/augmentation/main.go index ff81f91..7046fc2 100644 --- a/example/augmentation/main.go +++ b/example/augmentation/main.go @@ -46,7 +46,8 @@ func tOne() { device := gotch.CPU imgTs := img.MustTo(device, true) - t, err := aug.Compose(aug.WithRandomSolarize(aug.WithSolarizeThreshold(125), aug.WithSolarizePvalue(1.0))) + t, err := aug.Compose(aug.WithRandomAutocontrast(1.0)) + // t, err := aug.Compose(aug.WithRandomSolarize(aug.WithSolarizeThreshold(125), aug.WithSolarizePvalue(1.0))) // t, err := aug.Compose(aug.WithRandomAdjustSharpness(aug.WithSharpnessPvalue(1.0), aug.WithSharpnessFactor(10))) // t, err := aug.Compose(aug.WithRandRotate(0, 360)) // t, err := aug.Compose(aug.WithResize(320, 320)) // NOTE. WithResize just works on CPU. diff --git a/vision/aug/function.go b/vision/aug/function.go index 95605df..34f4f69 100644 --- a/vision/aug/function.go +++ b/vision/aug/function.go @@ -1526,7 +1526,7 @@ func Byte2FloatImage(x *ts.Tensor) *ts.Tensor { // It's panic if input is not float dtype tensor. func Float2ByteImage(x *ts.Tensor) *ts.Tensor { dtype := x.DType() - if dtype != gotch.Float { + if dtype != gotch.Float && dtype != gotch.Double { err := fmt.Errorf("Input tensor is not float dtype (%v)", dtype) panic(err) } diff --git a/vision/aug/rotate.go b/vision/aug/rotate.go index 09d447f..948a608 100644 --- a/vision/aug/rotate.go +++ b/vision/aug/rotate.go @@ -99,12 +99,18 @@ func newRandRotate(min, max float64) *RandRotateModule { // Forward implements ts.Module for RandRotateModule func (rr *RandRotateModule) Forward(x *ts.Tensor) *ts.Tensor { - out, err := RandomRotate(x, rr.minAngle, rr.maxAngle) + fx := Byte2FloatImage(x) + + out, err := RandomRotate(fx, rr.minAngle, rr.maxAngle) if err != nil { log.Fatal(err) } - return out + bx := Float2ByteImage(out) + fx.MustDrop() + out.MustDrop() + + return bx } func WithRandRotate(minAngle, maxAngle float64) Option {