WIP(vision/image.go)

This commit is contained in:
sugarme 2020-06-14 22:46:36 +10:00
parent 3b74f1fd16
commit c1221e959e
11 changed files with 684 additions and 0 deletions

1
.gitignore vendored
View File

@ -15,6 +15,7 @@
target/
_build/
data/
example/testdata/
tmp/
gen/.merlin
**/*.rs.bk

View File

@ -0,0 +1,89 @@
package main
import (
"encoding/binary"
"io"
"os"
)
// Dataset parameters and IDX header constants shared by the MNIST readers.
const (
	// numLabels is the number of label classes (presumably the digits 0-9).
	numLabels = 10
	// pixelRange is the divisor used when raw pixel bytes are scaled.
	pixelRange = 255

	// imageMagic is the magic number expected at the start of an image file.
	imageMagic = 0x00000803
	// labelMagic is the magic number expected at the start of a label file.
	labelMagic = 0x00000801

	// Width of the input tensor / picture
	Width = 28
	// Height of the input tensor / picture
	Height = 28
)
// readLabelFile decodes a label file from r.
//
// The (r, e) parameter pair lets the result of os.Open be passed in directly:
// when e is non-nil it is returned unchanged and r is not touched.
//
// The stream must start with the big-endian int32 labelMagic, followed by a
// big-endian int32 label count and then one byte per label. os.ErrInvalid is
// returned for a wrong magic number or a negative count; read errors
// (including a truncated stream) are returned as-is.
//
// If r also implements io.Closer (as the *os.File from os.Open does) it is
// closed before returning so the file handle is not leaked.
func readLabelFile(r io.Reader, e error) (labels []Label, err error) {
	if e != nil {
		return nil, e
	}
	if c, ok := r.(io.Closer); ok {
		// Read-only handle: a Close error carries no useful information here.
		defer c.Close()
	}

	var magic, n int32
	if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
		return nil, err
	}
	if magic != labelMagic {
		return nil, os.ErrInvalid
	}
	if err = binary.Read(r, binary.BigEndian, &n); err != nil {
		return nil, err
	}
	// A negative count would make the allocations below panic.
	if n < 0 {
		return nil, os.ErrInvalid
	}

	// Read every label in one call instead of one reflective binary.Read per byte.
	raw := make([]byte, n)
	if _, err = io.ReadFull(r, raw); err != nil {
		return nil, err
	}
	labels = make([]Label, n)
	for i, b := range raw {
		labels[i] = Label(b)
	}
	return labels, nil
}
// readImageFile decodes an image file from r.
//
// The (r, e) parameter pair lets the result of os.Open be passed in directly:
// when e is non-nil it is returned unchanged and r is not touched.
//
// The stream must start with the big-endian int32 imageMagic, followed by
// three big-endian int32 values (image count, rows, columns) and then
// rows*columns bytes per image. os.ErrInvalid is returned for a wrong magic
// number or a negative/oversized header value; read errors (including a
// truncated stream) are returned as-is.
//
// If r also implements io.Closer (as the *os.File from os.Open does) it is
// closed before returning so the file handle is not leaked.
func readImageFile(r io.Reader, e error) (imgs []RawImage, err error) {
	if e != nil {
		return nil, e
	}
	if c, ok := r.(io.Closer); ok {
		// Read-only handle: a Close error carries no useful information here.
		defer c.Close()
	}

	var magic, n, nrow, ncol int32
	if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
		return nil, err
	}
	if magic != imageMagic {
		// err is nil at this point, so an explicit error is required; returning
		// err here would report success with no images.
		return nil, os.ErrInvalid
	}
	for _, field := range []*int32{&n, &nrow, &ncol} {
		if err = binary.Read(r, binary.BigEndian, field); err != nil {
			return nil, err
		}
	}
	// Negative values would make the allocations below panic.
	if n < 0 || nrow < 0 || ncol < 0 {
		return nil, os.ErrInvalid
	}
	// Multiply in 64 bits: the int32 product can wrap around.
	size := int64(nrow) * int64(ncol)
	if int64(int(size)) != size {
		return nil, os.ErrInvalid
	}

	imgs = make([]RawImage, n)
	for i := range imgs {
		imgs[i] = make(RawImage, int(size))
		// io.ReadFull returns an error whenever fewer than len(buf) bytes were
		// read, so no separate length check is needed.
		if _, err = io.ReadFull(r, imgs[i]); err != nil {
			return nil, err
		}
	}
	return imgs, nil
}

View File

@ -0,0 +1,86 @@
package main
import (
"fmt"
"log"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
// Training configuration for the MNIST example.
const (
N = 60000 // number of rows in train data (training size)
inDim = 784 // input features - columns in train data (image data is 28x28pixel matrix)
outDim = 10 // output features (probabilities for digits 0-9)
batchSize = 50 // number of rows taken per batch in main
batches = N / batchSize // number of batches iterated per epoch
epochs = 100 // number of passes of the outer loop in main
)
// Package-level state filled in by init.
var (
trainX ts.Tensor // inputs returned by Load("train", ...)
trainY ts.Tensor // targets returned by Load("train", ...)
testX ts.Tensor // inputs returned by Load("test", ...)
testY ts.Tensor // targets returned by Load("test", ...)
err error // receives the error results of the Load calls in init
)
// init loads the MNIST train and test sets into the package-level tensors
// and terminates the program (via handleError) if either load fails.
func init() {
	const mnistDir = "../testdata/mnist"

	// Training set: trainX is expected to have shape {60000, 784}
	// (one 28x28 image per row) and trainY shape {60000, 10}, one row of
	// per-digit scores per image,
	// e.g. [0.1 0.1 0.1 0.1 0.1 0.9 0.1 0.1 0.1 0.1].
	trainX, trainY, err = Load("train", mnistDir, gotch.Double)
	handleError(err)

	// Test set, loaded the same way.
	testX, testY, err = Load("test", mnistDir, gotch.Double)
	handleError(err)
}
// main is a work-in-progress training loop: it walks the training tensor in
// batches of batchSize rows, prints the first batch and then stops on purpose.
func main() {
for epoch := 0; epoch < epochs; epoch++ {
for i := 0; i < batches; i++ {
// Row range [start, end) of the current batch.
start := i * batchSize
end := start + batchSize
// NOTE: this err is local to the loop body and shadows the package-level err.
dims, err := trainX.Size()
handleError(err)
// NOTE(review): the bounds are compared against ts.FlattenDim(dims), which
// presumably is the total element count rather than the number of rows;
// confirm this is the intended limit.
if start > ts.FlattenDim(dims) || end > ts.FlattenDim(dims) {
break
}
// NOTE(review): NewNarrow receives (start, end); verify whether its second
// argument is an end index or a length.
index := ts.NewNarrow(int64(start), int64(end))
batchX := trainX.Idx(index)
batchX.Print()
fmt.Printf("Processed epoch %v - sample %v\n", epoch, i)
// Debugging stop: execution never gets past the first batch.
panic("Stop")
// Earlier slicing-based implementation, kept for reference:
// batchX, err := trainX.Slice(nn.MakeRangedSlice(start, end))
// handleError(err)
// batchY, err := trainY.Slice(nn.MakeRangedSlice(start, end))
// handleError(err)
// xi := batchX.Data().([]float64)
// yi := batchY.Data().([]float64)
//
// xiT := ts.New(ts.WithShape(batchSize, inDim), ts.WithBacking(xi))
// yiT := ts.New(ts.WithShape(batchSize, outDim), ts.WithBacking(yi))
}
}
}
// handleError logs err and exits the program via log.Fatal when err is
// non-nil; a nil error is ignored.
func handleError(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}

View File

@ -0,0 +1,143 @@
package main
import (
"fmt"
"log"
"os"
"path/filepath"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
// RawImage holds the pixel intensities of an image, one byte per pixel.
// 255 is foreground (black), 0 is background (white).
type RawImage []byte
// Label is a digit label in 0 to 9.
type Label uint8
// Load loads the mnist data into two tensors.
//
// typ selects the data set: "train" (or its alias "dev") reads the training
// files, "test" reads the t10k files. Any other value is an error.
//
// loc is the directory that holds the mnist files.
//
// as is the element type of the returned tensors; it is passed on to
// prepareX and prepareY.
//
// On success inputs holds one row per image and targets one row per label.
// An error is returned when typ is unknown, when either file cannot be read,
// or when the two files are empty or disagree on the number of samples.
func Load(typ, loc string, as gotch.DType) (inputs, targets ts.Tensor, err error) {
	const (
		trainLabel = "train-labels-idx1-ubyte"
		trainData  = "train-images-idx3-ubyte"
		testLabel  = "t10k-labels-idx1-ubyte"
		testData   = "t10k-images-idx3-ubyte"
	)

	var labelFile, dataFile string
	switch typ {
	case "train", "dev":
		labelFile = filepath.Join(loc, trainLabel)
		dataFile = filepath.Join(loc, trainData)
	case "test":
		labelFile = filepath.Join(loc, testLabel)
		dataFile = filepath.Join(loc, testData)
	default:
		// Without this, both paths stay empty and the failure surfaces as a
		// confusing "open : no such file" error.
		return inputs, targets, fmt.Errorf("unknown mnist data set %q (want \"train\", \"dev\" or \"test\")", typ)
	}

	labelData, err := readLabelFile(os.Open(labelFile))
	if err != nil {
		return inputs, targets, fmt.Errorf("unable to read labels: %w", err)
	}
	imageData, err := readImageFile(os.Open(dataFile))
	if err != nil {
		return inputs, targets, fmt.Errorf("unable to read image data: %w", err)
	}

	// prepareX indexes the first image, so reject empty data here with a
	// proper error; also make sure every image has a label.
	if len(imageData) == 0 || len(imageData) != len(labelData) {
		return inputs, targets, fmt.Errorf("inconsistent mnist data: %d images, %d labels", len(imageData), len(labelData))
	}

	inputs = prepareX(imageData, as)
	targets = prepareY(labelData, as)
	return inputs, targets, nil
}
// pixelWeight scales a raw pixel byte to float64(px)/pixelRange*0.9 + 0.1.
// A result of exactly 1.0 is replaced by 0.999.
func pixelWeight(px byte) float64 {
	weight := float64(px)/pixelRange*0.9 + 0.1
	if weight != 1.0 {
		return weight
	}
	return 0.999
}
// reversePixelWeight maps a weight produced by pixelWeight back to the raw
// pixel byte.
//
// pixelWeight computes w = px/pixelRange*0.9 + 0.1, so the inverse is
// px = (w - 0.1) / 0.9 * pixelRange. The previous formula subtracted
// pixelRange instead of 0.1*pixelRange, which produced a negative float for
// every weight below 1 and converted it to byte (an implementation-dependent
// result).
//
// The result is rounded to the nearest integer and clamped to [0, 255], so the
// 0.999 that pixelWeight returns for a full-intensity pixel maps back to 255.
// NaN maps to 0.
func reversePixelWeight(px float64) byte {
	raw := (px - 0.1) / 0.9 * pixelRange
	switch {
	case !(raw > 0): // also catches NaN
		return 0
	case raw >= pixelRange:
		return pixelRange
	}
	return byte(raw + 0.5)
}
// prepareX converts raw images into a rows x cols input tensor, where rows is
// len(M) and cols is the pixel count of the first image. Every pixel is scaled
// with pixelWeight.
//
// dt selects the element type of the backing slice: gotch.Double gives
// float64, gotch.Float gives float32.
//
// Like the tensor-creation failure path, invalid input (no images, images of
// differing length, unsupported dt) is fatal: the problem is logged and the
// process exits.
func prepareX(M []RawImage, dt gotch.DType) (retVal ts.Tensor) {
	if len(M) == 0 {
		// M[0] below would otherwise panic with an index-out-of-range error.
		log.Fatal("Prepare X - no images given")
	}
	rows, cols := len(M), len(M[0])
	for i, img := range M {
		if len(img) != cols {
			log.Fatalf("Prepare X - image %d has %d pixels, want %d\n", i, len(img), cols)
		}
	}

	var backing interface{}
	switch dt {
	case gotch.Double:
		b := make([]float64, 0, rows*cols)
		for _, img := range M {
			for _, px := range img {
				b = append(b, pixelWeight(px))
			}
		}
		backing = b
	case gotch.Float:
		b := make([]float32, 0, rows*cols)
		for _, img := range M {
			for _, px := range img {
				b = append(b, float32(pixelWeight(px)))
			}
		}
		backing = b
	default:
		log.Fatalf("Prepare X - unsupported dtype: %v\n", dt)
	}

	retVal, err := ts.NewTensorFromData(backing, []int64{int64(rows), int64(cols)})
	if err != nil {
		log.Fatalf("Prepare X - NewTensorFromData error: %v\n", err)
	}
	return retVal
}
// prepareY converts labels into a len(N) x 10 target tensor. Each row is
// filled with 0.1 except the column whose index equals the label, which is
// set to 0.9; a label of 10 or more leaves its row at 0.1 throughout.
//
// dt selects the element type of the backing slice: gotch.Double gives
// float64, gotch.Float gives float32. For any other dt no backing slice is
// built and a nil backing is handed to ts.NewTensorFromData. A tensor-creation
// error is logged and terminates the process.
func prepareY(N []Label, dt gotch.DType) (retVal ts.Tensor) {
	const classes = 10
	count := len(N)

	var backing interface{}
	switch dt {
	case gotch.Double:
		oneHot := make([]float64, count*classes)
		for k := range oneHot {
			oneHot[k] = 0.1
		}
		for i, label := range N {
			if int(label) < classes {
				oneHot[i*classes+int(label)] = 0.9
			}
		}
		backing = oneHot
	case gotch.Float:
		oneHot := make([]float32, count*classes)
		for k := range oneHot {
			oneHot[k] = 0.1
		}
		for i, label := range N {
			if int(label) < classes {
				oneHot[i*classes+int(label)] = 0.9
			}
		}
		backing = oneHot
	}

	retVal, err := ts.NewTensorFromData(backing, []int64{int64(count), int64(classes)})
	if err != nil {
		log.Fatalf("Prepare Y - NewTensorFromData error: %v\n", err)
	}
	return
}

5
go.mod
View File

@ -1,3 +1,8 @@
module github.com/sugarme/gotch
go 1.14
require (
github.com/pkg/errors v0.9.1
gorgonia.org/tensor v0.9.7
)

52
go.sum Normal file
View File

@ -0,0 +1,52 @@
github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=
github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k=
github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0=
github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0=
github.com/chewxy/math32 v1.0.4 h1:dfqy3+BbCmet2zCkaDaIQv9fpMxnmYYlAEV2Iqe3DZo=
github.com/chewxy/math32 v1.0.4/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/gogo/protobuf v1.3.0 h1:G8O7TerXerS4F6sx9OV7/nRfJdnXgHZu/S/7F2SN+UE=
github.com/gogo/protobuf v1.3.0/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/flatbuffers v1.11.0 h1:O7CEyB8Cb3/DmtxODGtLHcEvpr81Jm5qLg/hsHnxA2A=
github.com/google/flatbuffers v1.11.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
gonum.org/v1/gonum v0.7.0 h1:Hdks0L0hgznZLG9nzXb8vZ0rRvqNvAcgAp84y7Mwkgw=
gonum.org/v1/gonum v0.7.0/go.mod h1:L02bwd0sqlsvRv41G7wGWFCsVNZFv/k1xzGIxeANHGM=
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gorgonia.org/tensor v0.9.7 h1:RncmNWe66zWDGMpDYFRXmReFkkMK7KOstELU/joamao=
gorgonia.org/tensor v0.9.7/go.mod h1:yYvRwsd34UdhG98GhzsB4YUVt3cQAQ4amoD/nuyhX+c=
gorgonia.org/vecf32 v0.9.0 h1:PClazic1r+JVJ1dEzRXgeiVl4g1/Hf/w+wUSqnco1Xg=
gorgonia.org/vecf32 v0.9.0/go.mod h1:NCc+5D2oxddRL11hd+pCB1PEyXWOyiQxfZ/1wwhOXCA=
gorgonia.org/vecf64 v0.9.0 h1:bgZDP5x0OzBF64PjMGC3EvTdOoMEcmfAh1VCUnZFm1A=
gorgonia.org/vecf64 v0.9.0/go.mod h1:hp7IOWCnRiVQKON73kkC/AUMtEXyf9kGlVrtPQ9ccVA=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@ -139,3 +139,55 @@ func AtgFill_(ptr *Ctensor, self Ctensor, value Cscalar) {
func AtgRandnLike(ptr *Ctensor, self Ctensor) {
	// Call the normal-distribution variant that the function name promises;
	// this previously called atg_rand_like, which is the uniform variant.
	C.atg_randn_like(ptr, self)
}
// AtgLogSoftmax forwards to the C function atg_log_softmax, converting dim
// and dtype to their C integer types.
//
// void atg_log_softmax(tensor *, tensor self, int64_t dim, int dtype);
func AtgLogSoftmax(ptr *Ctensor, self Ctensor, dim int64, dtype int32) {
	C.atg_log_softmax(ptr, self, C.int64_t(dim), C.int(dtype))
}
// AtgNllLoss forwards to the C function atg_nll_loss, converting reduction
// and ignoreIndex to int64_t.
//
// void atg_nll_loss(tensor *, tensor self, tensor target, tensor weight, int64_t reduction, int64_t ignore_index);
func AtgNllLoss(ptr *Ctensor, self Ctensor, target Ctensor, weight Ctensor, reduction int64, ignoreIndex int64) {
	C.atg_nll_loss(ptr, self, target, weight, C.int64_t(reduction), C.int64_t(ignoreIndex))
}
// AtgArgmax forwards to the C function atg_argmax.
//
// keepDim is converted to a C int by value. The previous pointer cast
// reinterpreted the first bytes of a Go int (8 bytes on 64-bit platforms) as a
// 4-byte C int, which only yields the right value on little-endian machines.
//
// void atg_argmax(tensor *, tensor self, int64_t dim, int keepdim);
func AtgArgmax(ptr *Ctensor, self Ctensor, dim int64, keepDim int) {
	C.atg_argmax(ptr, self, C.int64_t(dim), C.int(keepDim))
}
// AtgMean forwards to the C function atg_mean, converting dtype to a C int.
//
// void atg_mean(tensor *, tensor self, int dtype);
func AtgMean(ptr *Ctensor, self Ctensor, dtype int32) {
	C.atg_mean(ptr, self, C.int(dtype))
}
// AtgPermute forwards to the C function atg_permute.
//
// dims is handed to C as a pointer to its first element together with dimLen,
// the number of entries C should read. An empty dims is passed as a nil
// pointer instead of panicking on dims[0].
//
// dimLen is converted to a C int by value; reinterpreting the memory of a Go
// int (8 bytes on 64-bit platforms) as a 4-byte C int is only correct on
// little-endian machines.
//
// void atg_permute(tensor *, tensor self, int64_t *dims_data, int dims_len);
func AtgPermute(ptr *Ctensor, self Ctensor, dims []int64, dimLen int) {
	var cdimsPtr *C.int64_t
	if len(dims) > 0 {
		// just get pointer of the first element of the shape
		cdimsPtr = (*C.int64_t)(unsafe.Pointer(&dims[0]))
	}
	C.atg_permute(ptr, self, cdimsPtr, C.int(dimLen))
}
// AtgSqueeze1 forwards to the C function atg_squeeze1, converting dim to
// int64_t.
//
// void atg_squeeze1(tensor *, tensor self, int64_t dim);
func AtgSqueeze1(ptr *Ctensor, self Ctensor, dim int64) {
	C.atg_squeeze1(ptr, self, C.int64_t(dim))
}
// AtgSqueeze_ forwards to the C function atg_squeeze_.
//
// void atg_squeeze_(tensor *, tensor self);
func AtgSqueeze_(ptr *Ctensor, self Ctensor) {
	C.atg_squeeze_(ptr, self)
}

// AtgSqueeze1_ is the original name of this binding. The "1" is misleading
// because the function calls atg_squeeze_, not a dim-taking variant; the name
// is kept so existing callers keep compiling.
//
// Deprecated: use AtgSqueeze_.
func AtgSqueeze1_(ptr *Ctensor, self Ctensor) {
	AtgSqueeze_(ptr, self)
}

58
nn/module.go Normal file
View File

@ -0,0 +1,58 @@
package nn
import (
ts "github.com/sugarme/gotch/tensor"
)
// Module is an interface with a single method, `Forward`.
//
// The following is the `module` concept from the PyTorch documentation:
// Base class for all neural network modules. Your models should also subclass this class.
// Modules can also contain other Modules, allowing to nest them in a tree structure.
// You can assign the submodules as regular attributes. Submodules assigned in this way will
// be registered, and will have their parameters converted too when you call .cuda(), etc.
type Module interface {
// ModuleT
// Forward takes the input tensor xs and returns a tensor.
Forward(xs ts.Tensor) ts.Tensor
}
// ModuleT is a `Module` with an additional train parameter.
// The train parameter is commonly used to have different behavior
// between training and evaluation, e.g. when using dropout or batch-normalization.
type ModuleT interface {
// ForwardT takes the input tensor xs plus the train flag and returns a tensor.
ForwardT(xs ts.Tensor, train bool) ts.Tensor
}
// DefaultModuleT implements the default method `BatchAccuracyForLogits`.
// NOTE: a struct that implements `ModuleT` should embed `DefaultModuleT`
// so that the 'default' method `BatchAccuracyForLogits`
// is automatically available on it.
// The concept is taken from the Rust language feature **Default Implementation** of traits.
// Ref: https://doc.rust-lang.org/1.22.1/book/second-edition/ch10-02-traits.html
//
// Example:
//
// type FooModule struct{
// DefaultModuleT
// OtherField string
// }
type DefaultModuleT struct{}
// BatchAccuracyForLogits is meant to return the accuracy accumulated over
// batches of xs and ys, i.e. sumAccuracy / sampleCount.
//
// The batch iteration is not implemented yet, so no samples are ever counted
// and the method currently always returns 0. The zero-sample guard replaces
// the previous 0/0 division, which returned NaN.
func (dmt *DefaultModuleT) BatchAccuracyForLogits(xs, ys ts.Tensor, batchSize int) float64 {
	var sumAccuracy, sampleCount float64

	// TODO: implement Iter2...

	if sampleCount == 0 {
		return 0
	}
	return sumAccuracy / sampleCount
}
// TODO: should the `Module` and `ModuleT` interfaces also embed a tensor?
// I.e.:
// type Module interface{
// ts.Tensor
// Forward(xs ts.Tensor) ts.Tensor
// }

View File

@ -28,6 +28,16 @@ func (ts Tensor) To(device gt.Device) (retVal Tensor, err error) {
return Tensor{ctensor: *ptr}, nil
}
// MustTo is like To, but logs the error and exits the program via log.Fatal
// instead of returning it.
func (ts Tensor) MustTo(device gt.Device) (retVal Tensor) {
	moved, err := ts.To(device)
	if err != nil {
		log.Fatal(err)
	}
	return moved
}
func (ts Tensor) Matmul(other Tensor) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
@ -372,3 +382,54 @@ func (ts Tensor) RandnLike() (retVal Tensor, err error) {
return retVal, nil
}
// Permute calls lib.AtgPermute with dims and wraps the tensor it produces.
// It returns the error reported by TorchErr, if any.
func (ts Tensor) Permute(dims []int64) (retVal Tensor, err error) {
	// The C side stores the resulting tensor handle in *ptr, so the buffer
	// must be large enough for one handle; a malloc(0) block must not be
	// written to.
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(C.size_t(unsafe.Sizeof(retVal.ctensor)))))
	defer C.free(unsafe.Pointer(ptr))

	// lib.AtgPermute takes the number of dims as an explicit fourth argument.
	lib.AtgPermute(ptr, ts.ctensor, dims, len(dims))
	if err = TorchErr(); err != nil {
		return retVal, err
	}

	retVal = Tensor{ctensor: *ptr}
	return retVal, nil
}
// Squeeze1 calls lib.AtgSqueeze1 with dim and wraps the tensor it produces.
// It returns the error reported by TorchErr, if any.
func (ts Tensor) Squeeze1(dim int64) (retVal Tensor, err error) {
	// The C side stores the resulting tensor handle in *ptr, so the buffer
	// must be large enough for one handle; a malloc(0) block must not be
	// written to.
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(C.size_t(unsafe.Sizeof(retVal.ctensor)))))
	defer C.free(unsafe.Pointer(ptr))

	lib.AtgSqueeze1(ptr, ts.ctensor, dim)
	if err = TorchErr(); err != nil {
		return retVal, err
	}

	retVal = Tensor{ctensor: *ptr}
	return retVal, nil
}
// MustSqueeze1 is like Squeeze1, but logs the error and exits the program via
// log.Fatal instead of returning it.
func (ts Tensor) MustSqueeze1(dim int64) (retVal Tensor) {
	squeezed, err := ts.Squeeze1(dim)
	if err != nil {
		log.Fatal(err)
	}
	return squeezed
}
// Squeeze_ invokes the atg_squeeze_ binding on the tensor. It returns nothing;
// if TorchErr reports an error, the error is logged and the program exits via
// log.Fatal.
func (ts Tensor) Squeeze_() {
	// The C side stores a tensor handle in *ptr, so the buffer must be large
	// enough for one handle; a malloc(0) block must not be written to.
	ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(C.size_t(unsafe.Sizeof(ts.ctensor)))))
	defer C.free(unsafe.Pointer(ptr))

	// lib.AtgSqueeze1_ is the name under which the atg_squeeze_ binding is
	// declared in package lib.
	lib.AtgSqueeze1_(ptr, ts.ctensor)
	// err has to be declared here, and a function without results must not
	// return a value.
	if err := TorchErr(); err != nil {
		log.Fatal(err)
	}
}

View File

@ -461,3 +461,12 @@ func (f *Func) Invoke() interface{} {
// How do we match them to output order of signature function
return f.val.Call(f.meta.InArgs)
}
// Must unwraps the (Tensor, error) result of the call it wraps: it returns the
// tensor when err is nil and panics with err otherwise.
func Must(ts Tensor, err error) (retVal Tensor) {
	if err == nil {
		return ts
	}
	panic(err)
}

128
vision/image.go Normal file
View File

@ -0,0 +1,128 @@
package vision
// Utility functions to manipulate images.
import (
"fmt"
"log"
"path/filepath"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
// hwcToCHW permutes the dimensions of tensor with the order {2, 0, 1}:
// (height, width, channel) -> (channel, height, width).
// A Permute error is logged and terminates the program.
func hwcToCHW(tensor ts.Tensor) (retVal ts.Tensor) {
	permuted, err := tensor.Permute([]int64{2, 0, 1})
	if err != nil {
		log.Fatalf("hwcToCHW error: %v\n", err)
	}
	return permuted
}
// chwToHWC permutes the dimensions of tensor with the order {1, 2, 0}:
// (channel, height, width) -> (height, width, channel).
// A Permute error is logged and terminates the program.
func chwToHWC(tensor ts.Tensor) (retVal ts.Tensor) {
	var err error
	retVal, err = tensor.Permute([]int64{1, 2, 0})
	if err != nil {
		// The message previously named hwcToCHW (copy-paste), pointing at the
		// wrong function.
		log.Fatalf("chwToHWC error: %v\n", err)
	}
	return retVal
}
// Load loads an image from a file.
//
// The file is read with ts.LoadHwc and the result is permuted by hwcToCHW, so
// on success the returned tensor has shape [channel, height, width].
func Load(path string) (retVal ts.Tensor, err error) {
	hwc, err := ts.LoadHwc(path)
	if err != nil {
		return retVal, err
	}
	return hwcToCHW(hwc), nil
}
// Save saves an image to a file.
//
// This expects as input a tensor of shape [channel, height, width]; a tensor
// of shape [1, channel, height, width] is also accepted and has its leading
// dimension squeezed away first. Any other shape is an error.
// The image format is based on the filename suffix, supported suffixes
// are jpg, png, tga, and bmp.
// The tensor input should be of kind UInt8 with values ranging from
// 0 to 255; it is converted with Totype(gotch.Uint8), moved to the CPU and
// permuted to [height, width, channel] before being passed to ts.SaveHwc.
//
// NOTE: the MustSqueeze1, MustTo and chwToHWC helpers used here terminate the
// program via log.Fatal on failure instead of returning an error.
func Save(tensor ts.Tensor, path string) (err error) {
	t, err := tensor.Totype(gotch.Uint8)
	if err != nil {
		return fmt.Errorf("Save - Tensor.Totype() error: %w", err)
	}

	shape, err := t.Size()
	if err != nil {
		return fmt.Errorf("Save - Tensor.Size() error: %w", err)
	}

	switch {
	case len(shape) == 4 && shape[0] == 1:
		return ts.SaveHwc(chwToHWC(t.MustSqueeze1(0).MustTo(gotch.CPU)), path)
	case len(shape) == 3:
		return ts.SaveHwc(chwToHWC(t.MustTo(gotch.CPU)), path)
	default:
		// Report the whole shape: the rank alone does not explain why a 4-D
		// tensor with a leading dimension other than 1 is rejected.
		return fmt.Errorf("unexpected shape %v for image tensor (want 3 dimensions, or 4 with a leading 1)", shape)
	}
}
// Resize resizes an image.
//
// This expects as input a tensor of shape [channel, height, width] and returns
// a tensor of shape [channel, out_h, out_w].
//
// ts.ResizeHwc works on [height, width, channel] data, so the input is
// permuted with chwToHWC before resizing and the result is permuted back with
// hwcToCHW. Previously the channel-first input was passed to ResizeHwc
// unconverted.
func Resize(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) {
	resized, err := ts.ResizeHwc(chwToHWC(t), outW, outH)
	if err != nil {
		return retVal, err
	}
	return hwcToCHW(resized), nil
}
// resizePreserveAspectRatioHWC is intended to resize a
// [height, width, channel] image to outW x outH while preserving the aspect
// ratio.
//
// TODO: implement. Until then it reports an error instead of silently
// returning an empty tensor with a nil error.
func resizePreserveAspectRatioHWC(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) {
	return retVal, fmt.Errorf("resizePreserveAspectRatioHWC: not implemented")
}
// ResizePreserveAspectRatio resizes an image, preserving the aspect ratio by
// taking a center crop.
//
// This expects as input a tensor of shape [channel, height, width]. The input
// is permuted with chwToHWC and the result of resizePreserveAspectRatioHWC is
// returned unchanged.
func ResizePreserveAspectRatio(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) {
	hwc := chwToHWC(t)
	retVal, err = resizePreserveAspectRatioHWC(hwc, outW, outH)
	return retVal, err
}
// LoadAndResize loads an image with ts.LoadHwc and resizes it with
// resizePreserveAspectRatioHWC, which is meant to preserve the aspect ratio by
// taking a center crop. A load error is returned as-is.
func LoadAndResize(path string, outW int64, outH int64) (retVal ts.Tensor, err error) {
	var hwc ts.Tensor
	if hwc, err = ts.LoadHwc(path); err != nil {
		return retVal, err
	}
	return resizePreserveAspectRatioHWC(hwc, outW, outH)
}
// visitDirs is a placeholder: it ignores dir and files and always returns nil.
// TODO: decide whether this function is needed at all.
func visitDirs(dir string, files []string) (err error) {
return nil
}
// LoadDir is intended to load all the images in a directory.
//
// TODO: implement it. Until then it reports an error instead of silently
// returning an empty tensor with a nil error.
func LoadDir(path string, outW int64, outH int64) (retVal ts.Tensor, err error) {
	return retVal, fmt.Errorf("LoadDir: not implemented")
}