From c1221e959e0b1166b77cd3414023ed2a44ced6f9 Mon Sep 17 00:00:00 2001 From: sugarme Date: Sun, 14 Jun 2020 22:46:36 +1000 Subject: [PATCH] WIP(vision/image.go) --- .gitignore | 1 + example/mnist-linear/io.go | 89 +++++++++++++++++++ example/mnist-linear/main.go | 86 ++++++++++++++++++ example/mnist-linear/mnist.go | 143 ++++++++++++++++++++++++++++++ go.mod | 5 ++ go.sum | 52 +++++++++++ libtch/c-generated-sample.go | 52 +++++++++++ nn/module.go | 58 ++++++++++++ tensor/tensor-generated-sample.go | 61 +++++++++++++ tensor/util.go | 9 ++ vision/image.go | 128 ++++++++++++++++++++++++++ 11 files changed, 684 insertions(+) create mode 100644 example/mnist-linear/io.go create mode 100644 example/mnist-linear/main.go create mode 100644 example/mnist-linear/mnist.go create mode 100644 go.sum create mode 100644 nn/module.go create mode 100644 vision/image.go diff --git a/.gitignore b/.gitignore index 931dedc..f5e2b25 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ target/ _build/ data/ +example/testdata/ tmp/ gen/.merlin **/*.rs.bk diff --git a/example/mnist-linear/io.go b/example/mnist-linear/io.go new file mode 100644 index 0000000..f9c8a2d --- /dev/null +++ b/example/mnist-linear/io.go @@ -0,0 +1,89 @@ +package main + +import ( + "encoding/binary" + "io" + "os" +) + +const numLabels = 10 +const pixelRange = 255 + +const ( + imageMagic = 0x00000803 + labelMagic = 0x00000801 + // Width of the input tensor / picture + Width = 28 + // Height of the input tensor / picture + Height = 28 +) + +func readLabelFile(r io.Reader, e error) (labels []Label, err error) { + if e != nil { + return nil, e + } + + var ( + magic int32 + n int32 + ) + if err = binary.Read(r, binary.BigEndian, &magic); err != nil { + return nil, err + } + if magic != labelMagic { + return nil, os.ErrInvalid + } + if err = binary.Read(r, binary.BigEndian, &n); err != nil { + return nil, err + } + labels = make([]Label, n) + for i := 0; i < int(n); i++ { + var l Label + if err := binary.Read(r, binary.BigEndian, &l); err != nil { + return nil, err + } + labels[i] = l + } + return labels, nil +} + +func readImageFile(r io.Reader, e error) (imgs []RawImage, err error) { + if e != nil { + return nil, e + } + + var ( + magic int32 + n int32 + nrow int32 + ncol int32 + ) + if err = binary.Read(r, binary.BigEndian, &magic); err != nil { + return nil, err + } + if magic != imageMagic { + return nil, err /*os.ErrInvalid*/ + } + if err = binary.Read(r, binary.BigEndian, &n); err != nil { + return nil, err + } + if err = binary.Read(r, binary.BigEndian, &nrow); err != nil { + return nil, err + } + if err = binary.Read(r, binary.BigEndian, &ncol); err != nil { + return nil, err + } + imgs = make([]RawImage, n) + m := int(nrow * ncol) + for i := 0; i < int(n); i++ { + imgs[i] = make(RawImage, m) + m_, err := io.ReadFull(r, imgs[i]) + if err != nil { + return nil, err + } + if m_ != int(m) { + return nil, os.ErrInvalid + } + } + return imgs, nil +} diff --git a/example/mnist-linear/main.go b/example/mnist-linear/main.go new file mode 100644 index 0000000..09e19d2 --- /dev/null +++ b/example/mnist-linear/main.go @@ -0,0 +1,86 @@ +package main + +import ( + "fmt" + "log" + + "github.com/sugarme/gotch" + ts "github.com/sugarme/gotch/tensor" +) + +const ( + N = 60000 // number of rows in train data (training size) + inDim = 784 // input features - columns in train data (image data is 28x28pixel matrix) + outDim = 10 // output features (probabilities for digits 0-9) + batchSize = 50 + batches = N / batchSize + epochs = 100 +) + +var ( + trainX ts.Tensor + trainY ts.Tensor + testX ts.Tensor + testY ts.Tensor + + err error +) + +func init() { + // load the train set + // trainX is input tensor with shape{60000, 784} (image size: 28x28 pixels) + // trainY is target tensor with shape{6000, 10} + // (represent probabilities for digit 0-9) + // E.g. [0.1 0.1 0.1 0.1 0.1 0.9 0.1 0.1 0.1 0.1] + trainX, trainY, err = Load("train", "../testdata/mnist", gotch.Double) + handleError(err) + // load our test set + testX, testY, err = Load("test", "../testdata/mnist", gotch.Double) + handleError(err) + +} + +func main() { + + for epoch := 0; epoch < epochs; epoch++ { + for i := 0; i < batches; i++ { + // NOTE: `m.Reset()` does not delete data. It just moves pointer to starting point. + start := i * batchSize + end := start + batchSize + + dims, err := trainX.Size() + handleError(err) + + if start > ts.FlattenDim(dims) || end > ts.FlattenDim(dims) { + break + } + + index := ts.NewNarrow(int64(start), int64(end)) + + batchX := trainX.Idx(index) + batchX.Print() + + fmt.Printf("Processed epoch %v - sample %v\n", epoch, i) + + panic("Stop") + + // batchX, err := trainX.Slice(nn.MakeRangedSlice(start, end)) + // handleError(err) + // batchY, err := trainY.Slice(nn.MakeRangedSlice(start, end)) + // handleError(err) + // xi := batchX.Data().([]float64) + // yi := batchY.Data().([]float64) + // + // xiT := ts.New(ts.WithShape(batchSize, inDim), ts.WithBacking(xi)) + // yiT := ts.New(ts.WithShape(batchSize, outDim), ts.WithBacking(yi)) + + } + + } +} + +func handleError(err error) { + if err != nil { + log.Fatal(err) + } +} diff --git a/example/mnist-linear/mnist.go b/example/mnist-linear/mnist.go new file mode 100644 index 0000000..9788e33 --- /dev/null +++ b/example/mnist-linear/mnist.go @@ -0,0 +1,143 @@ +package main + +import ( + "fmt" + "log" + "os" + "path/filepath" + + "github.com/sugarme/gotch" + ts "github.com/sugarme/gotch/tensor" +) + +// Image holds the pixel intensities of an image. +// 255 is foreground (black), 0 is background (white). +type RawImage []byte + +// Label is a digit label in 0 to 9 +type Label uint8 + +// Load loads the mnist data into two tensors +// +// typ can be "train", "test" +// +// loc represents where the mnist files are held +func Load(typ, loc string, as gotch.DType) (inputs, targets ts.Tensor, err error) { + const ( + trainLabel = "train-labels-idx1-ubyte" + trainData = "train-images-idx3-ubyte" + testLabel = "t10k-labels-idx1-ubyte" + testData = "t10k-images-idx3-ubyte" + ) + + var labelFile, dataFile string + switch typ { + case "train", "dev": + labelFile = filepath.Join(loc, trainLabel) + dataFile = filepath.Join(loc, trainData) + case "test": + labelFile = filepath.Join(loc, testLabel) + dataFile = filepath.Join(loc, testData) + } + + var labelData []Label + var imageData []RawImage + + if labelData, err = readLabelFile(os.Open(labelFile)); err != nil { + return inputs, targets, fmt.Errorf("Unable to read Labels: %v\n", err) + } + + if imageData, err = readImageFile(os.Open(dataFile)); err != nil { + return inputs, targets, fmt.Errorf("Unable to read image data: %v\n", err) + } + + inputs = prepareX(imageData, as) + targets = prepareY(labelData, as) + return +} + +func pixelWeight(px byte) float64 { + retVal := float64(px)/pixelRange*0.9 + 0.1 + if retVal == 1.0 { + return 0.999 + } + return retVal +} + +func reversePixelWeight(px float64) byte { + return byte((pixelRange*px - pixelRange) / 0.9) +} + +func prepareX(M []RawImage, dt gotch.DType) (retVal ts.Tensor) { + rows := len(M) + cols := len(M[0]) + + var backing interface{} + switch dt { + case gotch.Double: + b := make([]float64, rows*cols, rows*cols) + b = b[:0] + for i := 0; i < rows; i++ { + for j := 0; j < len(M[i]); j++ { + b = append(b, pixelWeight(M[i][j])) + } + } + backing = b + case gotch.Float: + b := make([]float32, rows*cols, rows*cols) + b = b[:0] + for i := 0; i < rows; i++ { + for j := 0; j < len(M[i]); j++ { + b = append(b, float32(pixelWeight(M[i][j]))) + } + } + backing = b + } + + retVal, err := ts.NewTensorFromData(backing, []int64{int64(rows), int64(cols)}) + if err != nil { + log.Fatalf("Prepare X - NewTensorFromData error: %v\n", err) + } + return +} + +func prepareY(N []Label, dt gotch.DType) (retVal ts.Tensor) { + rows := len(N) + cols := 10 + + var backing interface{} + switch dt { + case gotch.Double: + b := make([]float64, rows*cols, rows*cols) + b = b[:0] + for i := 0; i < rows; i++ { + for j := 0; j < 10; j++ { + if j == int(N[i]) { + b = append(b, 0.9) + } else { + b = append(b, 0.1) + } + } + } + backing = b + case gotch.Float: + b := make([]float32, rows*cols, rows*cols) + b = b[:0] + for i := 0; i < rows; i++ { + for j := 0; j < 10; j++ { + if j == int(N[i]) { + b = append(b, 0.9) + } else { + b = append(b, 0.1) + } + } + } + backing = b + + } + retVal, err := ts.NewTensorFromData(backing, []int64{int64(rows), int64(cols)}) + if err != nil { + log.Fatalf("Prepare Y - NewTensorFromData error: %v\n", err) + } + return +} diff --git a/go.mod b/go.mod index 3f67ae4..98318d7 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,8 @@ module github.com/sugarme/gotch go 1.14 + +require ( + github.com/pkg/errors v0.9.1 + gorgonia.org/tensor v0.9.7 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..d8abd6a --- /dev/null +++ b/go.sum @@ -0,0 +1,52 @@ +github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= +github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k= +github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0= +github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0= +github.com/chewxy/math32 v1.0.4 h1:dfqy3+BbCmet2zCkaDaIQv9fpMxnmYYlAEV2Iqe3DZo= +github.com/chewxy/math32 v1.0.4/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= +github.com/gogo/protobuf v1.3.0 h1:G8O7TerXerS4F6sx9OV7/nRfJdnXgHZu/S/7F2SN+UE= +github.com/gogo/protobuf v1.3.0/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= +github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= +github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/flatbuffers v1.11.0 h1:O7CEyB8Cb3/DmtxODGtLHcEvpr81Jm5qLg/hsHnxA2A= +github.com/google/flatbuffers v1.11.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= +github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY= +github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8= +golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= +golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= +golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= +gonum.org/v1/gonum v0.7.0 h1:Hdks0L0hgznZLG9nzXb8vZ0rRvqNvAcgAp84y7Mwkgw= +gonum.org/v1/gonum v0.7.0/go.mod h1:L02bwd0sqlsvRv41G7wGWFCsVNZFv/k1xzGIxeANHGM= +gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= +gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gorgonia.org/tensor v0.9.7 h1:RncmNWe66zWDGMpDYFRXmReFkkMK7KOstELU/joamao= +gorgonia.org/tensor v0.9.7/go.mod h1:yYvRwsd34UdhG98GhzsB4YUVt3cQAQ4amoD/nuyhX+c= +gorgonia.org/vecf32 v0.9.0 h1:PClazic1r+JVJ1dEzRXgeiVl4g1/Hf/w+wUSqnco1Xg= +gorgonia.org/vecf32 v0.9.0/go.mod h1:NCc+5D2oxddRL11hd+pCB1PEyXWOyiQxfZ/1wwhOXCA= +gorgonia.org/vecf64 v0.9.0 h1:bgZDP5x0OzBF64PjMGC3EvTdOoMEcmfAh1VCUnZFm1A= +gorgonia.org/vecf64 v0.9.0/go.mod h1:hp7IOWCnRiVQKON73kkC/AUMtEXyf9kGlVrtPQ9ccVA= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/libtch/c-generated-sample.go b/libtch/c-generated-sample.go index 880eb4f..305c26f 100644 --- a/libtch/c-generated-sample.go +++ b/libtch/c-generated-sample.go @@ -139,3 +139,55 @@ func AtgFill_(ptr *Ctensor, self Ctensor, value Cscalar) { func AtgRandnLike(ptr *Ctensor, self Ctensor) { C.atg_rand_like(ptr, self) } + +// void atg_log_softmax(tensor *, tensor self, int64_t dim, int dtype); +func AtgLogSoftmax(ptr *Ctensor, self Ctensor, dim int64, dtype int32) { + cdim := *(*C.int64_t)(unsafe.Pointer(&dim)) + cdtype := *(*C.int)(unsafe.Pointer(&dtype)) + + C.atg_log_softmax(ptr, self, cdim, cdtype) +} + +// void atg_nll_loss(tensor *, tensor self, tensor target, tensor weight, int64_t reduction, int64_t ignore_index); +func AtgNllLoss(ptr *Ctensor, self Ctensor, target Ctensor, weight Ctensor, reduction int64, ignoreIndex int64) { + creduction := *(*C.int64_t)(unsafe.Pointer(&reduction)) + cignoreIndex := *(*C.int64_t)(unsafe.Pointer(&ignoreIndex)) + + C.atg_nll_loss(ptr, self, target, weight, creduction, cignoreIndex) +} + +// void atg_argmax(tensor *, tensor self, int64_t dim, int keepdim); +func AtgArgmax(ptr *Ctensor, self Ctensor, dim int64, keepDim int) { + cdim := *(*C.int64_t)(unsafe.Pointer(&dim)) + ckeepDim := *(*C.int)(unsafe.Pointer(&keepDim)) + + C.atg_argmax(ptr, self, cdim, ckeepDim) +} + +// void atg_mean(tensor *, tensor self, int dtype); +func AtgMean(ptr *Ctensor, self Ctensor, dtype int32) { + cdtype := *(*C.int)(unsafe.Pointer(&dtype)) + + C.atg_mean(ptr, self, cdtype) +} + +// void atg_permute(tensor *, tensor self, int64_t *dims_data, int dims_len); +func AtgPermute(ptr *Ctensor, self Ctensor, dims []int64, dimLen int) { + // just get pointer of the first element of the shape + cdimsPtr := (*C.int64_t)(unsafe.Pointer(&dims[0])) + cdimLen := *(*C.int)(unsafe.Pointer(&dimLen)) + + C.atg_permute(ptr, self, cdimsPtr, cdimLen) +} + +// void atg_squeeze1(tensor *, tensor self, int64_t dim); +func AtgSqueeze1(ptr *Ctensor, self Ctensor, dim int64) { + cdim := *(*C.int64_t)(unsafe.Pointer(&dim)) + + C.atg_squeeze1(ptr, self, cdim) +} + +// void atg_squeeze_(tensor *, tensor self); +func AtgSqueeze1_(ptr *Ctensor, self Ctensor) { + C.atg_squeeze_(ptr, self) +} diff --git a/nn/module.go b/nn/module.go new file mode 100644 index 0000000..e0e34cb --- /dev/null +++ b/nn/module.go @@ -0,0 +1,58 @@ +package nn + +import ( + ts "github.com/sugarme/gotch/tensor" +) + +// Module interface is a container with only one method `Forward` +// +// The following is `module` concept from Pytorch documenation: +// Base class for all neural network modules. Your models should also subclass this class. +// Modules can also contain other Modules, allowing to nest them in a tree structure. +// You can assign the submodules as regular attributes. Submodules assigned in this way will +// be registered, and will have their parameters converted too when you call .cuda(), etc. +type Module interface { + // ModuleT + Forward(xs ts.Tensor) ts.Tensor +} + +// ModuleT is a `Module` with an additional train parameter +// The train parameter is commonly used to have different behavior +// between training and evaluation. E.g. When using dropout or batch-normalization. +type ModuleT interface { + ForwardT(xs ts.Tensor, train bool) ts.Tensor +} + +// DefaultModuleT implements default method `BatchAccuracyForLogits`. +// NOTE: when creating a struct that implement `ModuleT`, it should +// include `DefaultModule` so that the 'default' methods `BatchAccuracyForLogits` +// is automatically implemented. +// Concept taken from Rust language trait **Default Implementation** +// Ref: https://doc.rust-lang.org/1.22.1/book/second-edition/ch10-02-traits.html +// +// Example: +// +// type FooModule struct{ +// DefaultModuleT +// OtherField string +// } +type DefaultModuleT struct{} + +func (dmt *DefaultModuleT) BatchAccuracyForLogits(xs, ys ts.Tensor, batchSize int) float64 { + + var ( + sumAccuracy float64 = 0.0 + sampleCount float64 = 0.0 + ) + + // TODO: implement Iter2... + + return sumAccuracy / sampleCount +} + +// TODO: should we include tensor in `Module` and `ModuleT` interfaces??? +// I.e.: +// type Module interface{ +// t.Tensor +// Forward(xs ts.Tensor) ts.Tensor +// } diff --git a/tensor/tensor-generated-sample.go b/tensor/tensor-generated-sample.go index fbdf524..a5a71c5 100644 --- a/tensor/tensor-generated-sample.go +++ b/tensor/tensor-generated-sample.go @@ -28,6 +28,16 @@ func (ts Tensor) To(device gt.Device) (retVal Tensor, err error) { return Tensor{ctensor: *ptr}, nil } +func (ts Tensor) MustTo(device gt.Device) (retVal Tensor) { + var err error + retVal, err = ts.To(device) + if err != nil { + log.Fatal(err) + } + + return retVal +} + func (ts Tensor) Matmul(other Tensor) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) defer C.free(unsafe.Pointer(ptr)) @@ -372,3 +382,54 @@ func (ts Tensor) RandnLike() (retVal Tensor, err error) { return retVal, nil } + +func (ts Tensor) Permute(dims []int64) (retVal Tensor, err error) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + defer C.free(unsafe.Pointer(ptr)) + + lib.AtgPermute(ptr, ts.ctensor, dims) + + if err = TorchErr(); err != nil { + return retVal, err + } + + retVal = Tensor{ctensor: *ptr} + + return retVal, nil +} + +func (ts Tensor) Squeeze1(dim int64) (retVal Tensor, err error) { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + defer C.free(unsafe.Pointer(ptr)) + + lib.AtgSqueeze1(ptr, ts.ctensor, dim) + + if err = TorchErr(); err != nil { + return retVal, err + } + + retVal = Tensor{ctensor: *ptr} + + return retVal, nil +} + +func (ts Tensor) MustSqueeze1(dim int64) (retVal Tensor) { + var err error + retVal, err = ts.Squeeze1(dim) + if err != nil { + log.Fatal(err) + } + return retVal +} + +func (ts Tensor) Squeeze_() { + ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) + defer C.free(unsafe.Pointer(ptr)) + + lib.AtgSqueeze_(ptr, ts.ctensor) + + if err = TorchErr(); err != nil { + log.Fatal(err) + } + return nil +} diff --git a/tensor/util.go b/tensor/util.go index 6737754..74a5be5 100644 --- a/tensor/util.go +++ b/tensor/util.go @@ -461,3 +461,12 @@ func (f *Func) Invoke() interface{} { // How do we match them to output order of signature function return f.val.Call(f.meta.InArgs) } + +// Must is a helper to unwrap function it wraps. If having error, +// it will cause panic. +func Must(ts Tensor, err error) (retVal Tensor) { + if err != nil { + panic(err) + } + return ts +} diff --git a/vision/image.go b/vision/image.go new file mode 100644 index 0000000..8531510 --- /dev/null +++ b/vision/image.go @@ -0,0 +1,128 @@ +package vision + +// Utility functions to manipulate images. + +import ( + "fmt" + "log" + "path/filepath" + + "github.com/sugarme/gotch" + ts "github.com/sugarme/gotch/tensor" +) + +// (height, width, channel) -> (channel, height, width) +func hwcToCHW(tensor ts.Tensor) (retVal ts.Tensor) { + var err error + retVal, err = tensor.Permute([]int64{2, 0, 1}) + if err != nil { + log.Fatalf("hwcToCHW error: %v\n", err) + } + return retVal +} + +func chwToHWC(tensor ts.Tensor) (retVal ts.Tensor) { + var err error + retVal, err = tensor.Permute([]int64{1, 2, 0}) + if err != nil { + log.Fatalf("hwcToCHW error: %v\n", err) + } + return retVal +} + +// Load loads an image from a file. +// +// On success returns a tensor of shape [channel, height, width]. +func Load(path string) (retVal ts.Tensor, err error) { + var tensor ts.Tensor + tensor, err = ts.LoadHwc(path) + if err != nil { + return retVal, err + } + + retVal = hwcToCHW(tensor) + return retVal, nil +} + +// Save saves an image to a file. +// +// This expects as input a tensor of shape [channel, height, width]. +// The image format is based on the filename suffix, supported suffixes +// are jpg, png, tga, and bmp. +// The tensor input should be of kind UInt8 with values ranging from +// 0 to 255. +func Save(tensor ts.Tensor, path string) (err error) { + t, err := tensor.Totype(gotch.Uint8) + if err != nil { + err = fmt.Errorf("Save - Tensor.Totype() error: %v\n", err) + return err + } + + shape, err := t.Size() + if err != nil { + err = fmt.Errorf("Save - Tensor.Size() error: %v\n", err) + return err + } + + switch { + case len(shape) == 4 && shape[0] == 1: + return ts.SaveHwc(chwToHWC(t.MustSqueeze1(int64(0)).MustTo(gotch.CPU)), path) + case len(shape) == 3: + return ts.SaveHwc(chwToHWC(t.MustTo(gotch.CPU)), path) + default: + err = fmt.Errorf("Unexpected size (%v) for image tensor.\n", len(shape)) + return err + } +} + +// Resize resizes an image. +// +// This expects as input a tensor of shape [channel, height, width] and returns +// a tensor of shape [channel, out_h, out_w]. +func Resize(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) { + tmpTs, err := ts.ResizeHwc(t, outW, outH) + if err != nil { + return retVal, err + } + retVal = hwcToCHW(tmpTs) + + return retVal, nil +} + +// TODO: implement +func resizePreserveAspectRatioHWC(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) { + // TODO: implement + + return +} + +// ResizePreserveAspectRatio resizes an image, preserve the aspect ratio by taking a center crop. +// +// This expects as input a tensor of shape [channel, height, width] and returns +func ResizePreserveAspectRatio(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) { + return resizePreserveAspectRatioHWC(chwToHWC(t), outW, outH) +} + +// LoadAndResize loads and resizes an image, preserve the aspect ratio by taking a center crop. +func LoadAndResize(path string, outW int64, outH int64) (retVal ts.Tensor, err error) { + tensor, err := ts.LoadHwc(path) + if err != nil { + return retVal, err + } + + return resizePreserveAspectRatioHWC(tensor, outW, outH) +} + +// TODO: should we need this func??? +func visitDirs(dir string, files []string) (err error) { + + return nil +} + +// LoadDir loads all the images in a directory. +func LoadDir(path string, outW int64, outH int64) (retVal ts.Tensor, err error) { + // var files []string + // TODO: implement it + + return +}