fix(vision/image): fixed resizePreserveAspectRatioHWC

This commit is contained in:
sugarme 2020-07-01 12:37:06 +10:00
parent fcbc4ca870
commit 98624fad6e
4 changed files with 34 additions and 28 deletions

View File

@ -30,14 +30,11 @@ func main() {
imageNet := vision.NewImageNet()
// Load the image file and resize it to the usual imagenet dimension of 224x224.
// image, err := imageNet.LoadImageAndResize224(imageFile)
image, err := imageNet.LoadImage(imageFile)
image, err := imageNet.LoadImageAndResize224(imageFile)
if err != nil {
log.Fatal(err)
}
// image.MustSave("resize224.jpg")
// Load the Python saved module.
model, err := ts.ModuleLoad(modelPath)
if err != nil {

View File

@ -8,7 +8,7 @@ import (
lib "github.com/sugarme/gotch/libtch"
)
// LoadHwc returns a tensor of shape [width, height, channels] on success.
// LoadHwc returns a tensor of shape [height, width, channels] on success.
func LoadHwc(path string) (retVal Tensor, err error) {
ctensor := lib.AtLoadImage(path)
@ -18,19 +18,20 @@ func LoadHwc(path string) (retVal Tensor, err error) {
}
retVal = Tensor{ctensor}
return retVal, nil
}
// SaveHwc save an image from tensor. It expects a tensor of shape [width,
// height, channels]
// SaveHwc saves the image held in a tensor to the file at path.
// The tensor is expected to have shape [height, width, channels].
// The returned error is whatever TorchErr reports once the save call
// has been made.
func SaveHwc(ts Tensor, path string) (err error) {
	lib.AtSaveImage(ts.ctensor, path)
	err = TorchErr()
	return err
}
// ResizeHwc expects a tensor of shape [width, height, channels].
// On success returns a tensor of shape [width, height, channels].
// ResizeHwc expects a tensor of shape [height, width, channels].
// On success returns a tensor of shape [height, width, channels].
func ResizeHwc(ts Tensor, outWidth, outHeight int64) (retVal Tensor, err error) {
ctensor := lib.AtResizeImage(ts.ctensor, outWidth, outHeight)

View File

@ -54,7 +54,7 @@ func Load(path string) (retVal ts.Tensor, err error) {
// The tensor input should be of kind UInt8 with values ranging from
// 0 to 255.
func Save(tensor ts.Tensor, path string) (err error) {
t, err := tensor.Totype(gotch.Uint8, true)
t, err := tensor.Totype(gotch.Uint8, false) // false to keep the input tensor
if err != nil {
err = fmt.Errorf("Save - Tensor.Totype() error: %v\n", err)
return err
@ -91,16 +91,18 @@ func Resize(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) {
return retVal, nil
}
// TODO: implement
func resizePreserveAspectRatioHWC(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) {
tsSize, err := t.Size()
if err != nil {
err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.Size() method call err: %v\n", err)
return retVal, err
}
w := tsSize[0]
// TODO: check it
h := tsSize[1]
if w*outH == h*outW {
w := tsSize[0]
if (w * outH) == (h * outW) {
tmpTs, err := ts.ResizeHwc(t, outW, outH)
if err != nil {
err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.ResizeHwc() method call err: %v\n", err)
@ -108,36 +110,35 @@ func resizePreserveAspectRatioHWC(t ts.Tensor, outW int64, outH int64) (retVal t
}
return hwcToCHW(tmpTs), nil
} else {
ratioW := float64(outW) / float64(w)
ratioH := float64(outH) / float64(h)
ratioW := float64(outW) / float64(h)
ratioH := float64(outH) / float64(w)
ratio := math.Max(ratioW, ratioH)
resizeW := int64(ratio) * h
resizeH := int64(ratio) * w
resizeW := int64(ratio * float64(h))
resizeH := int64(ratio * float64(w))
resizeW = maxInt64(resizeW, outW)
resizeH = maxInt64(resizeH, outH)
resizeW = int64(math.Max(float64(resizeW), float64(outW)))
resizeH = int64(math.Max(float64(resizeH), float64(outH)))
tmpTs, err := ts.ResizeHwc(t, resizeW, resizeH)
tensor := hwcToCHW(tmpTs)
var tensorW ts.Tensor
var tensorH ts.Tensor
if resizeW != outW {
if resizeW == outW {
tensorW = tensor
} else {
tensorW, err = tensor.Narrow(2, (resizeW-outW)/2, outW, true)
if err != nil {
err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.Narrow() method call err: %v\n", err)
return retVal, err
}
} else {
tensorW = tensor
}
if resizeH == outH {
if int64(resizeH) == outH {
retVal = tensorW
} else {
tensorH, err = tensor.Narrow(2, (resizeH-outH)/2, outH, true)
tensorH, err = tensor.Narrow(1, (resizeH-outH)/2, outH, true)
if err != nil {
err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.Narrow() method call err: %v\n", err)
return retVal, err
@ -170,7 +171,6 @@ func LoadAndResize(path string, outW int64, outH int64) (retVal ts.Tensor, err e
func LoadDir(dir string, outW int64, outH int64) (retVal ts.Tensor, err error) {
var filePaths []string // "dir/filename.ext"
var tensors []ts.Tensor
// TODO: implement it
files, err := ioutil.ReadDir(dir)
if err != nil {
err = fmt.Errorf("LoadDir - Read directory error: %v\n", err)
@ -191,3 +191,11 @@ func LoadDir(dir string, outW int64, outH int64) (retVal ts.Tensor, err error) {
return ts.Stack(tensors, 0)
}
// maxInt64 returns the larger of the two int64 values v1 and v2.
// When both are equal, that common value is returned.
func maxInt64(v1, v2 int64) int64 {
	if v2 >= v1 {
		return v2
	}
	return v1
}

View File

@ -121,7 +121,7 @@ func (in ImageNet) LoadImageAndResize(path string, w, h int64) (retVal ts.Tensor
return retVal, err
}
return in.Normalize(tensor)
return tensor, nil
}
// LoadImageAndResize224 loads an image from a file and resize it to 224x224.