vision/aug: added double dtype
This commit is contained in:
parent
121908de21
commit
b6f5a89f7b
Binary file not shown.
Before Width: | Height: | Size: 264 KiB After Width: | Height: | Size: 278 KiB |
|
@ -46,7 +46,8 @@ func tOne() {
|
|||
device := gotch.CPU
|
||||
imgTs := img.MustTo(device, true)
|
||||
|
||||
t, err := aug.Compose(aug.WithRandomSolarize(aug.WithSolarizeThreshold(125), aug.WithSolarizePvalue(1.0)))
|
||||
t, err := aug.Compose(aug.WithRandomAutocontrast(1.0))
|
||||
// t, err := aug.Compose(aug.WithRandomSolarize(aug.WithSolarizeThreshold(125), aug.WithSolarizePvalue(1.0)))
|
||||
// t, err := aug.Compose(aug.WithRandomAdjustSharpness(aug.WithSharpnessPvalue(1.0), aug.WithSharpnessFactor(10)))
|
||||
// t, err := aug.Compose(aug.WithRandRotate(0, 360))
|
||||
// t, err := aug.Compose(aug.WithResize(320, 320)) // NOTE. WithResize just works on CPU.
|
||||
|
|
|
@ -1526,7 +1526,7 @@ func Byte2FloatImage(x *ts.Tensor) *ts.Tensor {
|
|||
// It's panic if input is not float dtype tensor.
|
||||
func Float2ByteImage(x *ts.Tensor) *ts.Tensor {
|
||||
dtype := x.DType()
|
||||
if dtype != gotch.Float {
|
||||
if dtype != gotch.Float && dtype != gotch.Double {
|
||||
err := fmt.Errorf("Input tensor is not float dtype (%v)", dtype)
|
||||
panic(err)
|
||||
}
|
||||
|
|
|
@ -99,12 +99,18 @@ func newRandRotate(min, max float64) *RandRotateModule {
|
|||
|
||||
// Forward implements ts.Module for RandRotateModule
|
||||
func (rr *RandRotateModule) Forward(x *ts.Tensor) *ts.Tensor {
|
||||
out, err := RandomRotate(x, rr.minAngle, rr.maxAngle)
|
||||
fx := Byte2FloatImage(x)
|
||||
|
||||
out, err := RandomRotate(fx, rr.minAngle, rr.maxAngle)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return out
|
||||
bx := Float2ByteImage(out)
|
||||
fx.MustDrop()
|
||||
out.MustDrop()
|
||||
|
||||
return bx
|
||||
}
|
||||
|
||||
func WithRandRotate(minAngle, maxAngle float64) Option {
|
||||
|
|
Loading…
Reference in New Issue
Block a user