feat(vision/image): completed

This commit is contained in:
sugarme 2020-06-15 10:25:10 +10:00
parent c1221e959e
commit 9efc686748
3 changed files with 106 additions and 15 deletions

View File

@ -188,6 +188,15 @@ func AtgSqueeze1(ptr *Ctensor, self Ctensor, dim int64) {
}
// void atg_squeeze_(tensor *, tensor self);
func AtgSqueeze1_(ptr *Ctensor, self Ctensor) {
func AtgSqueeze_(ptr *Ctensor, self Ctensor) {
C.atg_squeeze_(ptr, self)
}
// void atg_stack(tensor *, tensor *tensors_data, int tensors_len, int64_t dim);
func AtgStack(ptr *Ctensor, tensorsData []Ctensor, tensorsLen int, dim int64) {
tensorsDataPtr := (*Ctensor)(unsafe.Pointer(&tensorsData[0]))
ctensorsLen := *(*C.int)(unsafe.Pointer(&tensorsLen))
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
C.atg_stack(ptr, tensorsDataPtr, ctensorsLen, cdim)
}

View File

@ -387,7 +387,7 @@ func (ts Tensor) Permute(dims []int64) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgPermute(ptr, ts.ctensor, dims)
lib.AtgPermute(ptr, ts.ctensor, dims, len(dims))
if err = TorchErr(); err != nil {
return retVal, err
@ -423,6 +423,7 @@ func (ts Tensor) MustSqueeze1(dim int64) (retVal Tensor) {
}
func (ts Tensor) Squeeze_() {
var err error
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
@ -431,5 +432,23 @@ func (ts Tensor) Squeeze_() {
if err = TorchErr(); err != nil {
log.Fatal(err)
}
return nil
}
func Stack(tensors []Tensor, dim int64) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
var ctensors []lib.Ctensor
for _, t := range tensors {
ctensors = append(ctensors, t.ctensor)
}
lib.AtgStack(ptr, ctensors, len(tensors), dim)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}

View File

@ -4,8 +4,10 @@ package vision
import (
"fmt"
"io/ioutil"
"log"
"path/filepath"
"math"
// "path/filepath"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
@ -91,9 +93,58 @@ func Resize(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) {
// TODO: implement
func resizePreserveAspectRatioHWC(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) {
// TODO: implement
tsSize, err := t.Size()
if err != nil {
err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.Size() method call err: %v\n", err)
return retVal, err
}
w := tsSize[0]
h := tsSize[1]
if w*outH == h*outW {
tmpTs, err := ts.ResizeHwc(t, outW, outH)
if err != nil {
err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.ResizeHwc() method call err: %v\n", err)
return retVal, err
}
return hwcToCHW(tmpTs), nil
} else {
return
ratioW := float64(outW) / float64(w)
ratioH := float64(outH) / float64(h)
ratio := math.Max(ratioW, ratioH)
resizeW := int64(ratio) * h
resizeH := int64(ratio) * w
resizeW = int64(math.Max(float64(resizeW), float64(outW)))
resizeH = int64(math.Max(float64(resizeH), float64(outH)))
tmpTs, err := ts.ResizeHwc(t, resizeW, resizeH)
tensor := hwcToCHW(tmpTs)
var tensorW ts.Tensor
var tensorH ts.Tensor
if resizeW != outW {
tensorW, err = tensor.Narrow(2, (resizeW-outW)/2, outW)
if err != nil {
err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.Narrow() method call err: %v\n", err)
return retVal, err
}
}
if resizeH == outH {
retVal = tensorW
} else {
tensorH, err = tensor.Narrow(2, (resizeH-outH)/2, outH)
if err != nil {
err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.Narrow() method call err: %v\n", err)
return retVal, err
}
retVal = tensorH
}
return retVal, nil
}
}
// ResizePreserveAspectRatio resizes an image, preserve the aspect ratio by taking a center crop.
@ -113,16 +164,28 @@ func LoadAndResize(path string, outW int64, outH int64) (retVal ts.Tensor, err e
return resizePreserveAspectRatioHWC(tensor, outW, outH)
}
// TODO: should we need this func???
func visitDirs(dir string, files []string) (err error) {
return nil
}
// LoadDir loads all the images in a directory.
func LoadDir(path string, outW int64, outH int64) (retVal ts.Tensor, err error) {
// var files []string
func LoadDir(dir string, outW int64, outH int64) (retVal ts.Tensor, err error) {
var filePaths []string // "dir/filename.ext"
var tensors []ts.Tensor
// TODO: implement it
files, err := ioutil.ReadDir(dir)
if err != nil {
err = fmt.Errorf("LoadDir - Read directory error: %v\n", err)
return retVal, err
}
for _, f := range files {
filePaths = append(filePaths, fmt.Sprintf("%v%v", dir, f.Name()))
}
return
for _, path := range filePaths {
tensor, err := LoadAndResize(path, outW, outH)
if err != nil {
err = fmt.Errorf("LoadDir - LoadAndResize method call error: %v\n", err)
return retVal, err
}
tensors = append(tensors, tensor)
}
return ts.Stack(tensors, 0)
}