feat(vision/image): completed
This commit is contained in:
parent
c1221e959e
commit
9efc686748
|
@ -188,6 +188,15 @@ func AtgSqueeze1(ptr *Ctensor, self Ctensor, dim int64) {
|
|||
}
|
||||
|
||||
// void atg_squeeze_(tensor *, tensor self);
|
||||
func AtgSqueeze1_(ptr *Ctensor, self Ctensor) {
|
||||
func AtgSqueeze_(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_squeeze_(ptr, self)
|
||||
}
|
||||
|
||||
// void atg_stack(tensor *, tensor *tensors_data, int tensors_len, int64_t dim);
|
||||
func AtgStack(ptr *Ctensor, tensorsData []Ctensor, tensorsLen int, dim int64) {
|
||||
tensorsDataPtr := (*Ctensor)(unsafe.Pointer(&tensorsData[0]))
|
||||
ctensorsLen := *(*C.int)(unsafe.Pointer(&tensorsLen))
|
||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||
|
||||
C.atg_stack(ptr, tensorsDataPtr, ctensorsLen, cdim)
|
||||
}
|
||||
|
|
|
@ -387,7 +387,7 @@ func (ts Tensor) Permute(dims []int64) (retVal Tensor, err error) {
|
|||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgPermute(ptr, ts.ctensor, dims)
|
||||
lib.AtgPermute(ptr, ts.ctensor, dims, len(dims))
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
|
@ -423,6 +423,7 @@ func (ts Tensor) MustSqueeze1(dim int64) (retVal Tensor) {
|
|||
}
|
||||
|
||||
func (ts Tensor) Squeeze_() {
|
||||
var err error
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
|
@ -431,5 +432,23 @@ func (ts Tensor) Squeeze_() {
|
|||
if err = TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Stack(tensors []Tensor, dim int64) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
var ctensors []lib.Ctensor
|
||||
for _, t := range tensors {
|
||||
ctensors = append(ctensors, t.ctensor)
|
||||
}
|
||||
|
||||
lib.AtgStack(ptr, ctensors, len(tensors), dim)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
|
|
@ -4,8 +4,10 @@ package vision
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"path/filepath"
|
||||
"math"
|
||||
// "path/filepath"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
|
@ -91,9 +93,58 @@ func Resize(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) {
|
|||
|
||||
// TODO: implement
|
||||
func resizePreserveAspectRatioHWC(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) {
|
||||
// TODO: implement
|
||||
tsSize, err := t.Size()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.Size() method call err: %v\n", err)
|
||||
return retVal, err
|
||||
}
|
||||
w := tsSize[0]
|
||||
h := tsSize[1]
|
||||
if w*outH == h*outW {
|
||||
tmpTs, err := ts.ResizeHwc(t, outW, outH)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.ResizeHwc() method call err: %v\n", err)
|
||||
return retVal, err
|
||||
}
|
||||
return hwcToCHW(tmpTs), nil
|
||||
} else {
|
||||
|
||||
return
|
||||
ratioW := float64(outW) / float64(w)
|
||||
ratioH := float64(outH) / float64(h)
|
||||
ratio := math.Max(ratioW, ratioH)
|
||||
|
||||
resizeW := int64(ratio) * h
|
||||
resizeH := int64(ratio) * w
|
||||
|
||||
resizeW = int64(math.Max(float64(resizeW), float64(outW)))
|
||||
resizeH = int64(math.Max(float64(resizeH), float64(outH)))
|
||||
tmpTs, err := ts.ResizeHwc(t, resizeW, resizeH)
|
||||
|
||||
tensor := hwcToCHW(tmpTs)
|
||||
|
||||
var tensorW ts.Tensor
|
||||
var tensorH ts.Tensor
|
||||
if resizeW != outW {
|
||||
tensorW, err = tensor.Narrow(2, (resizeW-outW)/2, outW)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.Narrow() method call err: %v\n", err)
|
||||
return retVal, err
|
||||
}
|
||||
}
|
||||
|
||||
if resizeH == outH {
|
||||
retVal = tensorW
|
||||
} else {
|
||||
tensorH, err = tensor.Narrow(2, (resizeH-outH)/2, outH)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("resizePreserveAspectRatioHWC - ts.Narrow() method call err: %v\n", err)
|
||||
return retVal, err
|
||||
}
|
||||
retVal = tensorH
|
||||
}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ResizePreserveAspectRatio resizes an image, preserve the aspect ratio by taking a center crop.
|
||||
|
@ -113,16 +164,28 @@ func LoadAndResize(path string, outW int64, outH int64) (retVal ts.Tensor, err e
|
|||
return resizePreserveAspectRatioHWC(tensor, outW, outH)
|
||||
}
|
||||
|
||||
// TODO: should we need this func???
|
||||
func visitDirs(dir string, files []string) (err error) {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadDir loads all the images in a directory.
|
||||
func LoadDir(path string, outW int64, outH int64) (retVal ts.Tensor, err error) {
|
||||
// var files []string
|
||||
func LoadDir(dir string, outW int64, outH int64) (retVal ts.Tensor, err error) {
|
||||
var filePaths []string // "dir/filename.ext"
|
||||
var tensors []ts.Tensor
|
||||
// TODO: implement it
|
||||
files, err := ioutil.ReadDir(dir)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("LoadDir - Read directory error: %v\n", err)
|
||||
return retVal, err
|
||||
}
|
||||
for _, f := range files {
|
||||
filePaths = append(filePaths, fmt.Sprintf("%v%v", dir, f.Name()))
|
||||
}
|
||||
|
||||
return
|
||||
for _, path := range filePaths {
|
||||
tensor, err := LoadAndResize(path, outW, outH)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("LoadDir - LoadAndResize method call error: %v\n", err)
|
||||
return retVal, err
|
||||
}
|
||||
tensors = append(tensors, tensor)
|
||||
}
|
||||
|
||||
return ts.Stack(tensors, 0)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user