vision/aug: added double dtype

This commit is contained in:
sugarme 2021-06-22 00:38:40 +10:00
parent 121908de21
commit b6f5a89f7b
4 changed files with 11 additions and 4 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 264 KiB

After

Width:  |  Height:  |  Size: 278 KiB

View File

@ -46,7 +46,8 @@ func tOne() {
device := gotch.CPU
imgTs := img.MustTo(device, true)
t, err := aug.Compose(aug.WithRandomSolarize(aug.WithSolarizeThreshold(125), aug.WithSolarizePvalue(1.0)))
t, err := aug.Compose(aug.WithRandomAutocontrast(1.0))
// t, err := aug.Compose(aug.WithRandomSolarize(aug.WithSolarizeThreshold(125), aug.WithSolarizePvalue(1.0)))
// t, err := aug.Compose(aug.WithRandomAdjustSharpness(aug.WithSharpnessPvalue(1.0), aug.WithSharpnessFactor(10)))
// t, err := aug.Compose(aug.WithRandRotate(0, 360))
// t, err := aug.Compose(aug.WithResize(320, 320)) // NOTE. WithResize just works on CPU.

View File

@ -1526,7 +1526,7 @@ func Byte2FloatImage(x *ts.Tensor) *ts.Tensor {
// It's panic if input is not float dtype tensor.
func Float2ByteImage(x *ts.Tensor) *ts.Tensor {
dtype := x.DType()
if dtype != gotch.Float {
if dtype != gotch.Float && dtype != gotch.Double {
err := fmt.Errorf("Input tensor is not float dtype (%v)", dtype)
panic(err)
}

View File

@ -99,12 +99,18 @@ func newRandRotate(min, max float64) *RandRotateModule {
// Forward implements ts.Module for RandRotateModule
func (rr *RandRotateModule) Forward(x *ts.Tensor) *ts.Tensor {
out, err := RandomRotate(x, rr.minAngle, rr.maxAngle)
fx := Byte2FloatImage(x)
out, err := RandomRotate(fx, rr.minAngle, rr.maxAngle)
if err != nil {
log.Fatal(err)
}
return out
bx := Float2ByteImage(out)
fx.MustDrop()
out.MustDrop()
return bx
}
func WithRandRotate(minAngle, maxAngle float64) Option {