WIP(vision/image.go)
This commit is contained in:
parent
3b74f1fd16
commit
c1221e959e
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -15,6 +15,7 @@
|
|||
target/
|
||||
_build/
|
||||
data/
|
||||
example/testdata/
|
||||
tmp/
|
||||
gen/.merlin
|
||||
**/*.rs.bk
|
||||
|
|
89
example/mnist-linear/io.go
Normal file
89
example/mnist-linear/io.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
const numLabels = 10
|
||||
const pixelRange = 255
|
||||
|
||||
const (
|
||||
imageMagic = 0x00000803
|
||||
labelMagic = 0x00000801
|
||||
// Width of the input tensor / picture
|
||||
Width = 28
|
||||
// Height of the input tensor / picture
|
||||
Height = 28
|
||||
)
|
||||
|
||||
func readLabelFile(r io.Reader, e error) (labels []Label, err error) {
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
|
||||
var (
|
||||
magic int32
|
||||
n int32
|
||||
)
|
||||
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if magic != labelMagic {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
labels = make([]Label, n)
|
||||
for i := 0; i < int(n); i++ {
|
||||
var l Label
|
||||
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
labels[i] = l
|
||||
}
|
||||
return labels, nil
|
||||
}
|
||||
|
||||
func readImageFile(r io.Reader, e error) (imgs []RawImage, err error) {
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
|
||||
var (
|
||||
magic int32
|
||||
n int32
|
||||
nrow int32
|
||||
ncol int32
|
||||
)
|
||||
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if magic != imageMagic {
|
||||
return nil, err /*os.ErrInvalid*/
|
||||
}
|
||||
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = binary.Read(r, binary.BigEndian, &nrow); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = binary.Read(r, binary.BigEndian, &ncol); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
imgs = make([]RawImage, n)
|
||||
m := int(nrow * ncol)
|
||||
for i := 0; i < int(n); i++ {
|
||||
imgs[i] = make(RawImage, m)
|
||||
m_, err := io.ReadFull(r, imgs[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if m_ != int(m) {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
}
|
||||
return imgs, nil
|
||||
}
|
86
example/mnist-linear/main.go
Normal file
86
example/mnist-linear/main.go
Normal file
|
@ -0,0 +1,86 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
const (
|
||||
N = 60000 // number of rows in train data (training size)
|
||||
inDim = 784 // input features - columns in train data (image data is 28x28pixel matrix)
|
||||
outDim = 10 // output features (probabilities for digits 0-9)
|
||||
batchSize = 50
|
||||
batches = N / batchSize
|
||||
epochs = 100
|
||||
)
|
||||
|
||||
var (
|
||||
trainX ts.Tensor
|
||||
trainY ts.Tensor
|
||||
testX ts.Tensor
|
||||
testY ts.Tensor
|
||||
|
||||
err error
|
||||
)
|
||||
|
||||
func init() {
|
||||
// load the train set
|
||||
// trainX is input tensor with shape{60000, 784} (image size: 28x28 pixels)
|
||||
// trainY is target tensor with shape{6000, 10}
|
||||
// (represent probabilities for digit 0-9)
|
||||
// E.g. [0.1 0.1 0.1 0.1 0.1 0.9 0.1 0.1 0.1 0.1]
|
||||
trainX, trainY, err = Load("train", "../testdata/mnist", gotch.Double)
|
||||
handleError(err)
|
||||
// load our test set
|
||||
testX, testY, err = Load("test", "../testdata/mnist", gotch.Double)
|
||||
handleError(err)
|
||||
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
for epoch := 0; epoch < epochs; epoch++ {
|
||||
for i := 0; i < batches; i++ {
|
||||
// NOTE: `m.Reset()` does not delete data. It just moves pointer to starting point.
|
||||
start := i * batchSize
|
||||
end := start + batchSize
|
||||
|
||||
dims, err := trainX.Size()
|
||||
handleError(err)
|
||||
|
||||
if start > ts.FlattenDim(dims) || end > ts.FlattenDim(dims) {
|
||||
break
|
||||
}
|
||||
|
||||
index := ts.NewNarrow(int64(start), int64(end))
|
||||
|
||||
batchX := trainX.Idx(index)
|
||||
batchX.Print()
|
||||
|
||||
fmt.Printf("Processed epoch %v - sample %v\n", epoch, i)
|
||||
|
||||
panic("Stop")
|
||||
|
||||
// batchX, err := trainX.Slice(nn.MakeRangedSlice(start, end))
|
||||
// handleError(err)
|
||||
// batchY, err := trainY.Slice(nn.MakeRangedSlice(start, end))
|
||||
// handleError(err)
|
||||
// xi := batchX.Data().([]float64)
|
||||
// yi := batchY.Data().([]float64)
|
||||
//
|
||||
// xiT := ts.New(ts.WithShape(batchSize, inDim), ts.WithBacking(xi))
|
||||
// yiT := ts.New(ts.WithShape(batchSize, outDim), ts.WithBacking(yi))
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func handleError(err error) {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
143
example/mnist-linear/mnist.go
Normal file
143
example/mnist-linear/mnist.go
Normal file
|
@ -0,0 +1,143 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// Image holds the pixel intensities of an image.
|
||||
// 255 is foreground (black), 0 is background (white).
|
||||
type RawImage []byte
|
||||
|
||||
// Label is a digit label in 0 to 9
|
||||
type Label uint8
|
||||
|
||||
// Load loads the mnist data into two tensors
|
||||
//
|
||||
// typ can be "train", "test"
|
||||
//
|
||||
// loc represents where the mnist files are held
|
||||
func Load(typ, loc string, as gotch.DType) (inputs, targets ts.Tensor, err error) {
|
||||
const (
|
||||
trainLabel = "train-labels-idx1-ubyte"
|
||||
trainData = "train-images-idx3-ubyte"
|
||||
testLabel = "t10k-labels-idx1-ubyte"
|
||||
testData = "t10k-images-idx3-ubyte"
|
||||
)
|
||||
|
||||
var labelFile, dataFile string
|
||||
switch typ {
|
||||
case "train", "dev":
|
||||
labelFile = filepath.Join(loc, trainLabel)
|
||||
dataFile = filepath.Join(loc, trainData)
|
||||
case "test":
|
||||
labelFile = filepath.Join(loc, testLabel)
|
||||
dataFile = filepath.Join(loc, testData)
|
||||
}
|
||||
|
||||
var labelData []Label
|
||||
var imageData []RawImage
|
||||
|
||||
if labelData, err = readLabelFile(os.Open(labelFile)); err != nil {
|
||||
return inputs, targets, fmt.Errorf("Unable to read Labels: %v\n", err)
|
||||
}
|
||||
|
||||
if imageData, err = readImageFile(os.Open(dataFile)); err != nil {
|
||||
return inputs, targets, fmt.Errorf("Unable to read image data: %v\n", err)
|
||||
}
|
||||
|
||||
inputs = prepareX(imageData, as)
|
||||
targets = prepareY(labelData, as)
|
||||
return
|
||||
}
|
||||
|
||||
func pixelWeight(px byte) float64 {
|
||||
retVal := float64(px)/pixelRange*0.9 + 0.1
|
||||
if retVal == 1.0 {
|
||||
return 0.999
|
||||
}
|
||||
return retVal
|
||||
}
|
||||
|
||||
func reversePixelWeight(px float64) byte {
|
||||
return byte((pixelRange*px - pixelRange) / 0.9)
|
||||
}
|
||||
|
||||
func prepareX(M []RawImage, dt gotch.DType) (retVal ts.Tensor) {
|
||||
rows := len(M)
|
||||
cols := len(M[0])
|
||||
|
||||
var backing interface{}
|
||||
switch dt {
|
||||
case gotch.Double:
|
||||
b := make([]float64, rows*cols, rows*cols)
|
||||
b = b[:0]
|
||||
for i := 0; i < rows; i++ {
|
||||
for j := 0; j < len(M[i]); j++ {
|
||||
b = append(b, pixelWeight(M[i][j]))
|
||||
}
|
||||
}
|
||||
backing = b
|
||||
case gotch.Float:
|
||||
b := make([]float32, rows*cols, rows*cols)
|
||||
b = b[:0]
|
||||
for i := 0; i < rows; i++ {
|
||||
for j := 0; j < len(M[i]); j++ {
|
||||
b = append(b, float32(pixelWeight(M[i][j])))
|
||||
}
|
||||
}
|
||||
backing = b
|
||||
}
|
||||
|
||||
retVal, err := ts.NewTensorFromData(backing, []int64{int64(rows), int64(cols)})
|
||||
if err != nil {
|
||||
log.Fatalf("Prepare X - NewTensorFromData error: %v\n", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func prepareY(N []Label, dt gotch.DType) (retVal ts.Tensor) {
|
||||
rows := len(N)
|
||||
cols := 10
|
||||
|
||||
var backing interface{}
|
||||
switch dt {
|
||||
case gotch.Double:
|
||||
b := make([]float64, rows*cols, rows*cols)
|
||||
b = b[:0]
|
||||
for i := 0; i < rows; i++ {
|
||||
for j := 0; j < 10; j++ {
|
||||
if j == int(N[i]) {
|
||||
b = append(b, 0.9)
|
||||
} else {
|
||||
b = append(b, 0.1)
|
||||
}
|
||||
}
|
||||
}
|
||||
backing = b
|
||||
case gotch.Float:
|
||||
b := make([]float32, rows*cols, rows*cols)
|
||||
b = b[:0]
|
||||
for i := 0; i < rows; i++ {
|
||||
for j := 0; j < 10; j++ {
|
||||
if j == int(N[i]) {
|
||||
b = append(b, 0.9)
|
||||
} else {
|
||||
b = append(b, 0.1)
|
||||
}
|
||||
}
|
||||
}
|
||||
backing = b
|
||||
|
||||
}
|
||||
retVal, err := ts.NewTensorFromData(backing, []int64{int64(rows), int64(cols)})
|
||||
if err != nil {
|
||||
log.Fatalf("Prepare Y - NewTensorFromData error: %v\n", err)
|
||||
}
|
||||
return
|
||||
}
|
5
go.mod
5
go.mod
|
@ -1,3 +1,8 @@
|
|||
module github.com/sugarme/gotch
|
||||
|
||||
go 1.14
|
||||
|
||||
require (
|
||||
github.com/pkg/errors v0.9.1
|
||||
gorgonia.org/tensor v0.9.7
|
||||
)
|
||||
|
|
52
go.sum
Normal file
52
go.sum
Normal file
|
@ -0,0 +1,52 @@
|
|||
github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=
|
||||
github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k=
|
||||
github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0=
|
||||
github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0=
|
||||
github.com/chewxy/math32 v1.0.4 h1:dfqy3+BbCmet2zCkaDaIQv9fpMxnmYYlAEV2Iqe3DZo=
|
||||
github.com/chewxy/math32 v1.0.4/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs=
|
||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
|
||||
github.com/gogo/protobuf v1.3.0 h1:G8O7TerXerS4F6sx9OV7/nRfJdnXgHZu/S/7F2SN+UE=
|
||||
github.com/gogo/protobuf v1.3.0/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
|
||||
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
|
||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/google/flatbuffers v1.11.0 h1:O7CEyB8Cb3/DmtxODGtLHcEvpr81Jm5qLg/hsHnxA2A=
|
||||
github.com/google/flatbuffers v1.11.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
|
||||
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
||||
github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
|
||||
github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8=
|
||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
|
||||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
|
||||
gonum.org/v1/gonum v0.7.0 h1:Hdks0L0hgznZLG9nzXb8vZ0rRvqNvAcgAp84y7Mwkgw=
|
||||
gonum.org/v1/gonum v0.7.0/go.mod h1:L02bwd0sqlsvRv41G7wGWFCsVNZFv/k1xzGIxeANHGM=
|
||||
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
|
||||
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gorgonia.org/tensor v0.9.7 h1:RncmNWe66zWDGMpDYFRXmReFkkMK7KOstELU/joamao=
|
||||
gorgonia.org/tensor v0.9.7/go.mod h1:yYvRwsd34UdhG98GhzsB4YUVt3cQAQ4amoD/nuyhX+c=
|
||||
gorgonia.org/vecf32 v0.9.0 h1:PClazic1r+JVJ1dEzRXgeiVl4g1/Hf/w+wUSqnco1Xg=
|
||||
gorgonia.org/vecf32 v0.9.0/go.mod h1:NCc+5D2oxddRL11hd+pCB1PEyXWOyiQxfZ/1wwhOXCA=
|
||||
gorgonia.org/vecf64 v0.9.0 h1:bgZDP5x0OzBF64PjMGC3EvTdOoMEcmfAh1VCUnZFm1A=
|
||||
gorgonia.org/vecf64 v0.9.0/go.mod h1:hp7IOWCnRiVQKON73kkC/AUMtEXyf9kGlVrtPQ9ccVA=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
|
@ -139,3 +139,55 @@ func AtgFill_(ptr *Ctensor, self Ctensor, value Cscalar) {
|
|||
func AtgRandnLike(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_rand_like(ptr, self)
|
||||
}
|
||||
|
||||
// void atg_log_softmax(tensor *, tensor self, int64_t dim, int dtype);
|
||||
func AtgLogSoftmax(ptr *Ctensor, self Ctensor, dim int64, dtype int32) {
|
||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||
cdtype := *(*C.int)(unsafe.Pointer(&dtype))
|
||||
|
||||
C.atg_log_softmax(ptr, self, cdim, cdtype)
|
||||
}
|
||||
|
||||
// void atg_nll_loss(tensor *, tensor self, tensor target, tensor weight, int64_t reduction, int64_t ignore_index);
|
||||
func AtgNllLoss(ptr *Ctensor, self Ctensor, target Ctensor, weight Ctensor, reduction int64, ignoreIndex int64) {
|
||||
creduction := *(*C.int64_t)(unsafe.Pointer(&reduction))
|
||||
cignoreIndex := *(*C.int64_t)(unsafe.Pointer(&ignoreIndex))
|
||||
|
||||
C.atg_nll_loss(ptr, self, target, weight, creduction, cignoreIndex)
|
||||
}
|
||||
|
||||
// void atg_argmax(tensor *, tensor self, int64_t dim, int keepdim);
|
||||
func AtgArgmax(ptr *Ctensor, self Ctensor, dim int64, keepDim int) {
|
||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||
ckeepDim := *(*C.int)(unsafe.Pointer(&keepDim))
|
||||
|
||||
C.atg_argmax(ptr, self, cdim, ckeepDim)
|
||||
}
|
||||
|
||||
// void atg_mean(tensor *, tensor self, int dtype);
|
||||
func AtgMean(ptr *Ctensor, self Ctensor, dtype int32) {
|
||||
cdtype := *(*C.int)(unsafe.Pointer(&dtype))
|
||||
|
||||
C.atg_mean(ptr, self, cdtype)
|
||||
}
|
||||
|
||||
// void atg_permute(tensor *, tensor self, int64_t *dims_data, int dims_len);
|
||||
func AtgPermute(ptr *Ctensor, self Ctensor, dims []int64, dimLen int) {
|
||||
// just get pointer of the first element of the shape
|
||||
cdimsPtr := (*C.int64_t)(unsafe.Pointer(&dims[0]))
|
||||
cdimLen := *(*C.int)(unsafe.Pointer(&dimLen))
|
||||
|
||||
C.atg_permute(ptr, self, cdimsPtr, cdimLen)
|
||||
}
|
||||
|
||||
// void atg_squeeze1(tensor *, tensor self, int64_t dim);
|
||||
func AtgSqueeze1(ptr *Ctensor, self Ctensor, dim int64) {
|
||||
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||
|
||||
C.atg_squeeze1(ptr, self, cdim)
|
||||
}
|
||||
|
||||
// void atg_squeeze_(tensor *, tensor self);
|
||||
func AtgSqueeze1_(ptr *Ctensor, self Ctensor) {
|
||||
C.atg_squeeze_(ptr, self)
|
||||
}
|
||||
|
|
58
nn/module.go
Normal file
58
nn/module.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package nn
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// Module interface is a container with only one method `Forward`
|
||||
//
|
||||
// The following is `module` concept from Pytorch documenation:
|
||||
// Base class for all neural network modules. Your models should also subclass this class.
|
||||
// Modules can also contain other Modules, allowing to nest them in a tree structure.
|
||||
// You can assign the submodules as regular attributes. Submodules assigned in this way will
|
||||
// be registered, and will have their parameters converted too when you call .cuda(), etc.
|
||||
type Module interface {
|
||||
// ModuleT
|
||||
Forward(xs ts.Tensor) ts.Tensor
|
||||
}
|
||||
|
||||
// ModuleT is a `Module` with an additional train parameter
|
||||
// The train parameter is commonly used to have different behavior
|
||||
// between training and evaluation. E.g. When using dropout or batch-normalization.
|
||||
type ModuleT interface {
|
||||
ForwardT(xs ts.Tensor, train bool) ts.Tensor
|
||||
}
|
||||
|
||||
// DefaultModuleT implements default method `BatchAccuracyForLogits`.
|
||||
// NOTE: when creating a struct that implement `ModuleT`, it should
|
||||
// include `DefaultModule` so that the 'default' methods `BatchAccuracyForLogits`
|
||||
// is automatically implemented.
|
||||
// Concept taken from Rust language trait **Default Implementation**
|
||||
// Ref: https://doc.rust-lang.org/1.22.1/book/second-edition/ch10-02-traits.html
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type FooModule struct{
|
||||
// DefaultModuleT
|
||||
// OtherField string
|
||||
// }
|
||||
type DefaultModuleT struct{}
|
||||
|
||||
func (dmt *DefaultModuleT) BatchAccuracyForLogits(xs, ys ts.Tensor, batchSize int) float64 {
|
||||
|
||||
var (
|
||||
sumAccuracy float64 = 0.0
|
||||
sampleCount float64 = 0.0
|
||||
)
|
||||
|
||||
// TODO: implement Iter2...
|
||||
|
||||
return sumAccuracy / sampleCount
|
||||
}
|
||||
|
||||
// TODO: should we include tensor in `Module` and `ModuleT` interfaces???
|
||||
// I.e.:
|
||||
// type Module interface{
|
||||
// t.Tensor
|
||||
// Forward(xs ts.Tensor) ts.Tensor
|
||||
// }
|
|
@ -28,6 +28,16 @@ func (ts Tensor) To(device gt.Device) (retVal Tensor, err error) {
|
|||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustTo(device gt.Device) (retVal Tensor) {
|
||||
var err error
|
||||
retVal, err = ts.To(device)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Matmul(other Tensor) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
@ -372,3 +382,54 @@ func (ts Tensor) RandnLike() (retVal Tensor, err error) {
|
|||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) Permute(dims []int64) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgPermute(ptr, ts.ctensor, dims)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) Squeeze1(dim int64) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgSqueeze1(ptr, ts.ctensor, dim)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustSqueeze1(dim int64) (retVal Tensor) {
|
||||
var err error
|
||||
retVal, err = ts.Squeeze1(dim)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Squeeze_() {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgSqueeze_(ptr, ts.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -461,3 +461,12 @@ func (f *Func) Invoke() interface{} {
|
|||
// How do we match them to output order of signature function
|
||||
return f.val.Call(f.meta.InArgs)
|
||||
}
|
||||
|
||||
// Must is a helper to unwrap function it wraps. If having error,
|
||||
// it will cause panic.
|
||||
func Must(ts Tensor, err error) (retVal Tensor) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ts
|
||||
}
|
||||
|
|
128
vision/image.go
Normal file
128
vision/image.go
Normal file
|
@ -0,0 +1,128 @@
|
|||
package vision
|
||||
|
||||
// Utility functions to manipulate images.
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// (height, width, channel) -> (channel, height, width)
|
||||
func hwcToCHW(tensor ts.Tensor) (retVal ts.Tensor) {
|
||||
var err error
|
||||
retVal, err = tensor.Permute([]int64{2, 0, 1})
|
||||
if err != nil {
|
||||
log.Fatalf("hwcToCHW error: %v\n", err)
|
||||
}
|
||||
return retVal
|
||||
}
|
||||
|
||||
func chwToHWC(tensor ts.Tensor) (retVal ts.Tensor) {
|
||||
var err error
|
||||
retVal, err = tensor.Permute([]int64{1, 2, 0})
|
||||
if err != nil {
|
||||
log.Fatalf("hwcToCHW error: %v\n", err)
|
||||
}
|
||||
return retVal
|
||||
}
|
||||
|
||||
// Load loads an image from a file.
|
||||
//
|
||||
// On success returns a tensor of shape [channel, height, width].
|
||||
func Load(path string) (retVal ts.Tensor, err error) {
|
||||
var tensor ts.Tensor
|
||||
tensor, err = ts.LoadHwc(path)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = hwcToCHW(tensor)
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// Save saves an image to a file.
|
||||
//
|
||||
// This expects as input a tensor of shape [channel, height, width].
|
||||
// The image format is based on the filename suffix, supported suffixes
|
||||
// are jpg, png, tga, and bmp.
|
||||
// The tensor input should be of kind UInt8 with values ranging from
|
||||
// 0 to 255.
|
||||
func Save(tensor ts.Tensor, path string) (err error) {
|
||||
t, err := tensor.Totype(gotch.Uint8)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Save - Tensor.Totype() error: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
shape, err := t.Size()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Save - Tensor.Size() error: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
switch {
|
||||
case len(shape) == 4 && shape[0] == 1:
|
||||
return ts.SaveHwc(chwToHWC(t.MustSqueeze1(int64(0)).MustTo(gotch.CPU)), path)
|
||||
case len(shape) == 3:
|
||||
return ts.SaveHwc(chwToHWC(t.MustTo(gotch.CPU)), path)
|
||||
default:
|
||||
err = fmt.Errorf("Unexpected size (%v) for image tensor.\n", len(shape))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Resize resizes an image.
|
||||
//
|
||||
// This expects as input a tensor of shape [channel, height, width] and returns
|
||||
// a tensor of shape [channel, out_h, out_w].
|
||||
func Resize(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) {
|
||||
tmpTs, err := ts.ResizeHwc(t, outW, outH)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
retVal = hwcToCHW(tmpTs)
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
// TODO: implement
|
||||
func resizePreserveAspectRatioHWC(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) {
|
||||
// TODO: implement
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// ResizePreserveAspectRatio resizes an image, preserve the aspect ratio by taking a center crop.
|
||||
//
|
||||
// This expects as input a tensor of shape [channel, height, width] and returns
|
||||
func ResizePreserveAspectRatio(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) {
|
||||
return resizePreserveAspectRatioHWC(chwToHWC(t), outW, outH)
|
||||
}
|
||||
|
||||
// LoadAndResize loads and resizes an image, preserve the aspect ratio by taking a center crop.
|
||||
func LoadAndResize(path string, outW int64, outH int64) (retVal ts.Tensor, err error) {
|
||||
tensor, err := ts.LoadHwc(path)
|
||||
if err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return resizePreserveAspectRatioHWC(tensor, outW, outH)
|
||||
}
|
||||
|
||||
// TODO: should we need this func???
|
||||
func visitDirs(dir string, files []string) (err error) {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadDir loads all the images in a directory.
|
||||
func LoadDir(path string, outW int64, outH int64) (retVal ts.Tensor, err error) {
|
||||
// var files []string
|
||||
// TODO: implement it
|
||||
|
||||
return
|
||||
}
|
Loading…
Reference in New Issue
Block a user