WIP(vision/image.go)
This commit is contained in:
parent
3b74f1fd16
commit
c1221e959e
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -15,6 +15,7 @@
|
||||||
target/
|
target/
|
||||||
_build/
|
_build/
|
||||||
data/
|
data/
|
||||||
|
example/testdata/
|
||||||
tmp/
|
tmp/
|
||||||
gen/.merlin
|
gen/.merlin
|
||||||
**/*.rs.bk
|
**/*.rs.bk
|
||||||
|
|
89
example/mnist-linear/io.go
Normal file
89
example/mnist-linear/io.go
Normal file
|
@ -0,0 +1,89 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
const numLabels = 10
|
||||||
|
const pixelRange = 255
|
||||||
|
|
||||||
|
const (
|
||||||
|
imageMagic = 0x00000803
|
||||||
|
labelMagic = 0x00000801
|
||||||
|
// Width of the input tensor / picture
|
||||||
|
Width = 28
|
||||||
|
// Height of the input tensor / picture
|
||||||
|
Height = 28
|
||||||
|
)
|
||||||
|
|
||||||
|
func readLabelFile(r io.Reader, e error) (labels []Label, err error) {
|
||||||
|
if e != nil {
|
||||||
|
return nil, e
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
magic int32
|
||||||
|
n int32
|
||||||
|
)
|
||||||
|
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if magic != labelMagic {
|
||||||
|
return nil, os.ErrInvalid
|
||||||
|
}
|
||||||
|
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
labels = make([]Label, n)
|
||||||
|
for i := 0; i < int(n); i++ {
|
||||||
|
var l Label
|
||||||
|
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
labels[i] = l
|
||||||
|
}
|
||||||
|
return labels, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readImageFile(r io.Reader, e error) (imgs []RawImage, err error) {
|
||||||
|
if e != nil {
|
||||||
|
return nil, e
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
magic int32
|
||||||
|
n int32
|
||||||
|
nrow int32
|
||||||
|
ncol int32
|
||||||
|
)
|
||||||
|
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if magic != imageMagic {
|
||||||
|
return nil, err /*os.ErrInvalid*/
|
||||||
|
}
|
||||||
|
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = binary.Read(r, binary.BigEndian, &nrow); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err = binary.Read(r, binary.BigEndian, &ncol); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
imgs = make([]RawImage, n)
|
||||||
|
m := int(nrow * ncol)
|
||||||
|
for i := 0; i < int(n); i++ {
|
||||||
|
imgs[i] = make(RawImage, m)
|
||||||
|
m_, err := io.ReadFull(r, imgs[i])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if m_ != int(m) {
|
||||||
|
return nil, os.ErrInvalid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return imgs, nil
|
||||||
|
}
|
86
example/mnist-linear/main.go
Normal file
86
example/mnist-linear/main.go
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
N = 60000 // number of rows in train data (training size)
|
||||||
|
inDim = 784 // input features - columns in train data (image data is 28x28pixel matrix)
|
||||||
|
outDim = 10 // output features (probabilities for digits 0-9)
|
||||||
|
batchSize = 50
|
||||||
|
batches = N / batchSize
|
||||||
|
epochs = 100
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
trainX ts.Tensor
|
||||||
|
trainY ts.Tensor
|
||||||
|
testX ts.Tensor
|
||||||
|
testY ts.Tensor
|
||||||
|
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// load the train set
|
||||||
|
// trainX is input tensor with shape{60000, 784} (image size: 28x28 pixels)
|
||||||
|
// trainY is target tensor with shape{6000, 10}
|
||||||
|
// (represent probabilities for digit 0-9)
|
||||||
|
// E.g. [0.1 0.1 0.1 0.1 0.1 0.9 0.1 0.1 0.1 0.1]
|
||||||
|
trainX, trainY, err = Load("train", "../testdata/mnist", gotch.Double)
|
||||||
|
handleError(err)
|
||||||
|
// load our test set
|
||||||
|
testX, testY, err = Load("test", "../testdata/mnist", gotch.Double)
|
||||||
|
handleError(err)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
|
||||||
|
for epoch := 0; epoch < epochs; epoch++ {
|
||||||
|
for i := 0; i < batches; i++ {
|
||||||
|
// NOTE: `m.Reset()` does not delete data. It just moves pointer to starting point.
|
||||||
|
start := i * batchSize
|
||||||
|
end := start + batchSize
|
||||||
|
|
||||||
|
dims, err := trainX.Size()
|
||||||
|
handleError(err)
|
||||||
|
|
||||||
|
if start > ts.FlattenDim(dims) || end > ts.FlattenDim(dims) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
index := ts.NewNarrow(int64(start), int64(end))
|
||||||
|
|
||||||
|
batchX := trainX.Idx(index)
|
||||||
|
batchX.Print()
|
||||||
|
|
||||||
|
fmt.Printf("Processed epoch %v - sample %v\n", epoch, i)
|
||||||
|
|
||||||
|
panic("Stop")
|
||||||
|
|
||||||
|
// batchX, err := trainX.Slice(nn.MakeRangedSlice(start, end))
|
||||||
|
// handleError(err)
|
||||||
|
// batchY, err := trainY.Slice(nn.MakeRangedSlice(start, end))
|
||||||
|
// handleError(err)
|
||||||
|
// xi := batchX.Data().([]float64)
|
||||||
|
// yi := batchY.Data().([]float64)
|
||||||
|
//
|
||||||
|
// xiT := ts.New(ts.WithShape(batchSize, inDim), ts.WithBacking(xi))
|
||||||
|
// yiT := ts.New(ts.WithShape(batchSize, outDim), ts.WithBacking(yi))
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleError does nothing for a nil error and otherwise hands the error to
// log.Fatal.
func handleError(err error) {
	if err == nil {
		return
	}
	log.Fatal(err)
}
|
143
example/mnist-linear/mnist.go
Normal file
143
example/mnist-linear/mnist.go
Normal file
|
@ -0,0 +1,143 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RawImage holds the pixel intensities of an image.
|
||||||
|
// 255 is foreground (black), 0 is background (white).
|
||||||
|
type RawImage []byte
|
||||||
|
|
||||||
|
// Label is a digit label in 0 to 9
|
||||||
|
type Label uint8
|
||||||
|
|
||||||
|
// Load loads the mnist data into two tensors
|
||||||
|
//
|
||||||
|
// typ can be "train", "test"
|
||||||
|
//
|
||||||
|
// loc represents where the mnist files are held
|
||||||
|
func Load(typ, loc string, as gotch.DType) (inputs, targets ts.Tensor, err error) {
|
||||||
|
const (
|
||||||
|
trainLabel = "train-labels-idx1-ubyte"
|
||||||
|
trainData = "train-images-idx3-ubyte"
|
||||||
|
testLabel = "t10k-labels-idx1-ubyte"
|
||||||
|
testData = "t10k-images-idx3-ubyte"
|
||||||
|
)
|
||||||
|
|
||||||
|
var labelFile, dataFile string
|
||||||
|
switch typ {
|
||||||
|
case "train", "dev":
|
||||||
|
labelFile = filepath.Join(loc, trainLabel)
|
||||||
|
dataFile = filepath.Join(loc, trainData)
|
||||||
|
case "test":
|
||||||
|
labelFile = filepath.Join(loc, testLabel)
|
||||||
|
dataFile = filepath.Join(loc, testData)
|
||||||
|
}
|
||||||
|
|
||||||
|
var labelData []Label
|
||||||
|
var imageData []RawImage
|
||||||
|
|
||||||
|
if labelData, err = readLabelFile(os.Open(labelFile)); err != nil {
|
||||||
|
return inputs, targets, fmt.Errorf("Unable to read Labels: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if imageData, err = readImageFile(os.Open(dataFile)); err != nil {
|
||||||
|
return inputs, targets, fmt.Errorf("Unable to read image data: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
inputs = prepareX(imageData, as)
|
||||||
|
targets = prepareY(labelData, as)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func pixelWeight(px byte) float64 {
|
||||||
|
retVal := float64(px)/pixelRange*0.9 + 0.1
|
||||||
|
if retVal == 1.0 {
|
||||||
|
return 0.999
|
||||||
|
}
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func reversePixelWeight(px float64) byte {
|
||||||
|
return byte((pixelRange*px - pixelRange) / 0.9)
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareX(M []RawImage, dt gotch.DType) (retVal ts.Tensor) {
|
||||||
|
rows := len(M)
|
||||||
|
cols := len(M[0])
|
||||||
|
|
||||||
|
var backing interface{}
|
||||||
|
switch dt {
|
||||||
|
case gotch.Double:
|
||||||
|
b := make([]float64, rows*cols, rows*cols)
|
||||||
|
b = b[:0]
|
||||||
|
for i := 0; i < rows; i++ {
|
||||||
|
for j := 0; j < len(M[i]); j++ {
|
||||||
|
b = append(b, pixelWeight(M[i][j]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
backing = b
|
||||||
|
case gotch.Float:
|
||||||
|
b := make([]float32, rows*cols, rows*cols)
|
||||||
|
b = b[:0]
|
||||||
|
for i := 0; i < rows; i++ {
|
||||||
|
for j := 0; j < len(M[i]); j++ {
|
||||||
|
b = append(b, float32(pixelWeight(M[i][j])))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
backing = b
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal, err := ts.NewTensorFromData(backing, []int64{int64(rows), int64(cols)})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Prepare X - NewTensorFromData error: %v\n", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareY(N []Label, dt gotch.DType) (retVal ts.Tensor) {
|
||||||
|
rows := len(N)
|
||||||
|
cols := 10
|
||||||
|
|
||||||
|
var backing interface{}
|
||||||
|
switch dt {
|
||||||
|
case gotch.Double:
|
||||||
|
b := make([]float64, rows*cols, rows*cols)
|
||||||
|
b = b[:0]
|
||||||
|
for i := 0; i < rows; i++ {
|
||||||
|
for j := 0; j < 10; j++ {
|
||||||
|
if j == int(N[i]) {
|
||||||
|
b = append(b, 0.9)
|
||||||
|
} else {
|
||||||
|
b = append(b, 0.1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
backing = b
|
||||||
|
case gotch.Float:
|
||||||
|
b := make([]float32, rows*cols, rows*cols)
|
||||||
|
b = b[:0]
|
||||||
|
for i := 0; i < rows; i++ {
|
||||||
|
for j := 0; j < 10; j++ {
|
||||||
|
if j == int(N[i]) {
|
||||||
|
b = append(b, 0.9)
|
||||||
|
} else {
|
||||||
|
b = append(b, 0.1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
backing = b
|
||||||
|
|
||||||
|
}
|
||||||
|
retVal, err := ts.NewTensorFromData(backing, []int64{int64(rows), int64(cols)})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Prepare Y - NewTensorFromData error: %v\n", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
5
go.mod
5
go.mod
|
@ -1,3 +1,8 @@
|
||||||
module github.com/sugarme/gotch
|
module github.com/sugarme/gotch
|
||||||
|
|
||||||
go 1.14
|
go 1.14
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/pkg/errors v0.9.1
|
||||||
|
gorgonia.org/tensor v0.9.7
|
||||||
|
)
|
||||||
|
|
52
go.sum
Normal file
52
go.sum
Normal file
|
@ -0,0 +1,52 @@
|
||||||
|
github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=
|
||||||
|
github.com/chewxy/hm v1.0.0 h1:zy/TSv3LV2nD3dwUEQL2VhXeoXbb9QkpmdRAVUFiA6k=
|
||||||
|
github.com/chewxy/hm v1.0.0/go.mod h1:qg9YI4q6Fkj/whwHR1D+bOGeF7SniIP40VweVepLjg0=
|
||||||
|
github.com/chewxy/math32 v1.0.0/go.mod h1:Miac6hA1ohdDUTagnvJy/q+aNnEk16qWUdb8ZVhvCN0=
|
||||||
|
github.com/chewxy/math32 v1.0.4 h1:dfqy3+BbCmet2zCkaDaIQv9fpMxnmYYlAEV2Iqe3DZo=
|
||||||
|
github.com/chewxy/math32 v1.0.4/go.mod h1:dOB2rcuFrCn6UHrze36WSLVPKtzPMRAQvBvUwkSsLqs=
|
||||||
|
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
|
||||||
|
github.com/gogo/protobuf v1.3.0 h1:G8O7TerXerS4F6sx9OV7/nRfJdnXgHZu/S/7F2SN+UE=
|
||||||
|
github.com/gogo/protobuf v1.3.0/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o=
|
||||||
|
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
|
||||||
|
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
|
||||||
|
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||||
|
github.com/google/flatbuffers v1.11.0 h1:O7CEyB8Cb3/DmtxODGtLHcEvpr81Jm5qLg/hsHnxA2A=
|
||||||
|
github.com/google/flatbuffers v1.11.0/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
|
||||||
|
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
||||||
|
github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00=
|
||||||
|
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||||
|
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||||
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||||
|
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
|
||||||
|
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||||
|
github.com/xtgo/set v1.0.0 h1:6BCNBRv3ORNDQ7fyoJXRv+tstJz3m1JVFQErfeZz2pY=
|
||||||
|
github.com/xtgo/set v1.0.0/go.mod h1:d3NHzGzSa0NmB2NhFyECA+QdRp29oEn2xbT+TpeFoM8=
|
||||||
|
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
|
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
|
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
|
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
|
||||||
|
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo=
|
||||||
|
gonum.org/v1/gonum v0.7.0 h1:Hdks0L0hgznZLG9nzXb8vZ0rRvqNvAcgAp84y7Mwkgw=
|
||||||
|
gonum.org/v1/gonum v0.7.0/go.mod h1:L02bwd0sqlsvRv41G7wGWFCsVNZFv/k1xzGIxeANHGM=
|
||||||
|
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
|
||||||
|
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
|
||||||
|
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||||
|
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||||
|
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||||
|
gorgonia.org/tensor v0.9.7 h1:RncmNWe66zWDGMpDYFRXmReFkkMK7KOstELU/joamao=
|
||||||
|
gorgonia.org/tensor v0.9.7/go.mod h1:yYvRwsd34UdhG98GhzsB4YUVt3cQAQ4amoD/nuyhX+c=
|
||||||
|
gorgonia.org/vecf32 v0.9.0 h1:PClazic1r+JVJ1dEzRXgeiVl4g1/Hf/w+wUSqnco1Xg=
|
||||||
|
gorgonia.org/vecf32 v0.9.0/go.mod h1:NCc+5D2oxddRL11hd+pCB1PEyXWOyiQxfZ/1wwhOXCA=
|
||||||
|
gorgonia.org/vecf64 v0.9.0 h1:bgZDP5x0OzBF64PjMGC3EvTdOoMEcmfAh1VCUnZFm1A=
|
||||||
|
gorgonia.org/vecf64 v0.9.0/go.mod h1:hp7IOWCnRiVQKON73kkC/AUMtEXyf9kGlVrtPQ9ccVA=
|
||||||
|
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
|
@ -139,3 +139,55 @@ func AtgFill_(ptr *Ctensor, self Ctensor, value Cscalar) {
|
||||||
func AtgRandnLike(ptr *Ctensor, self Ctensor) {
|
func AtgRandnLike(ptr *Ctensor, self Ctensor) {
|
||||||
C.atg_rand_like(ptr, self)
|
C.atg_rand_like(ptr, self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// void atg_log_softmax(tensor *, tensor self, int64_t dim, int dtype);
|
||||||
|
func AtgLogSoftmax(ptr *Ctensor, self Ctensor, dim int64, dtype int32) {
|
||||||
|
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||||
|
cdtype := *(*C.int)(unsafe.Pointer(&dtype))
|
||||||
|
|
||||||
|
C.atg_log_softmax(ptr, self, cdim, cdtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
// void atg_nll_loss(tensor *, tensor self, tensor target, tensor weight, int64_t reduction, int64_t ignore_index);
|
||||||
|
func AtgNllLoss(ptr *Ctensor, self Ctensor, target Ctensor, weight Ctensor, reduction int64, ignoreIndex int64) {
|
||||||
|
creduction := *(*C.int64_t)(unsafe.Pointer(&reduction))
|
||||||
|
cignoreIndex := *(*C.int64_t)(unsafe.Pointer(&ignoreIndex))
|
||||||
|
|
||||||
|
C.atg_nll_loss(ptr, self, target, weight, creduction, cignoreIndex)
|
||||||
|
}
|
||||||
|
|
||||||
|
// void atg_argmax(tensor *, tensor self, int64_t dim, int keepdim);
|
||||||
|
func AtgArgmax(ptr *Ctensor, self Ctensor, dim int64, keepDim int) {
|
||||||
|
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||||
|
ckeepDim := *(*C.int)(unsafe.Pointer(&keepDim))
|
||||||
|
|
||||||
|
C.atg_argmax(ptr, self, cdim, ckeepDim)
|
||||||
|
}
|
||||||
|
|
||||||
|
// void atg_mean(tensor *, tensor self, int dtype);
|
||||||
|
func AtgMean(ptr *Ctensor, self Ctensor, dtype int32) {
|
||||||
|
cdtype := *(*C.int)(unsafe.Pointer(&dtype))
|
||||||
|
|
||||||
|
C.atg_mean(ptr, self, cdtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
// void atg_permute(tensor *, tensor self, int64_t *dims_data, int dims_len);
|
||||||
|
func AtgPermute(ptr *Ctensor, self Ctensor, dims []int64, dimLen int) {
|
||||||
|
// just get pointer of the first element of the shape
|
||||||
|
cdimsPtr := (*C.int64_t)(unsafe.Pointer(&dims[0]))
|
||||||
|
cdimLen := *(*C.int)(unsafe.Pointer(&dimLen))
|
||||||
|
|
||||||
|
C.atg_permute(ptr, self, cdimsPtr, cdimLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// void atg_squeeze1(tensor *, tensor self, int64_t dim);
|
||||||
|
func AtgSqueeze1(ptr *Ctensor, self Ctensor, dim int64) {
|
||||||
|
cdim := *(*C.int64_t)(unsafe.Pointer(&dim))
|
||||||
|
|
||||||
|
C.atg_squeeze1(ptr, self, cdim)
|
||||||
|
}
|
||||||
|
|
||||||
|
// void atg_squeeze_(tensor *, tensor self);
|
||||||
|
func AtgSqueeze1_(ptr *Ctensor, self Ctensor) {
|
||||||
|
C.atg_squeeze_(ptr, self)
|
||||||
|
}
|
||||||
|
|
58
nn/module.go
Normal file
58
nn/module.go
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
package nn
|
||||||
|
|
||||||
|
import (
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Module interface is a container with only one method `Forward`
|
||||||
|
//
|
||||||
|
// The following is the `module` concept from the PyTorch documentation:
|
||||||
|
// Base class for all neural network modules. Your models should also subclass this class.
|
||||||
|
// Modules can also contain other Modules, allowing to nest them in a tree structure.
|
||||||
|
// You can assign the submodules as regular attributes. Submodules assigned in this way will
|
||||||
|
// be registered, and will have their parameters converted too when you call .cuda(), etc.
|
||||||
|
type Module interface {
|
||||||
|
// ModuleT
|
||||||
|
Forward(xs ts.Tensor) ts.Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModuleT is a `Module` with an additional train parameter
|
||||||
|
// The train parameter is commonly used to have different behavior
|
||||||
|
// between training and evaluation. E.g. When using dropout or batch-normalization.
|
||||||
|
type ModuleT interface {
|
||||||
|
ForwardT(xs ts.Tensor, train bool) ts.Tensor
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultModuleT implements default method `BatchAccuracyForLogits`.
|
||||||
|
// NOTE: when creating a struct that implement `ModuleT`, it should
|
||||||
|
// embed `DefaultModuleT` so that the default method `BatchAccuracyForLogits`
|
||||||
|
// is automatically implemented.
|
||||||
|
// Concept taken from Rust language trait **Default Implementation**
|
||||||
|
// Ref: https://doc.rust-lang.org/1.22.1/book/second-edition/ch10-02-traits.html
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// type FooModule struct{
|
||||||
|
// DefaultModuleT
|
||||||
|
// OtherField string
|
||||||
|
// }
|
||||||
|
type DefaultModuleT struct{}
|
||||||
|
|
||||||
|
func (dmt *DefaultModuleT) BatchAccuracyForLogits(xs, ys ts.Tensor, batchSize int) float64 {
|
||||||
|
|
||||||
|
var (
|
||||||
|
sumAccuracy float64 = 0.0
|
||||||
|
sampleCount float64 = 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
// TODO: implement Iter2...
|
||||||
|
|
||||||
|
return sumAccuracy / sampleCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: should we include tensor in `Module` and `ModuleT` interfaces???
|
||||||
|
// I.e.:
|
||||||
|
// type Module interface{
|
||||||
|
// t.Tensor
|
||||||
|
// Forward(xs ts.Tensor) ts.Tensor
|
||||||
|
// }
|
|
@ -28,6 +28,16 @@ func (ts Tensor) To(device gt.Device) (retVal Tensor, err error) {
|
||||||
return Tensor{ctensor: *ptr}, nil
|
return Tensor{ctensor: *ptr}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustTo(device gt.Device) (retVal Tensor) {
|
||||||
|
var err error
|
||||||
|
retVal, err = ts.To(device)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
func (ts Tensor) Matmul(other Tensor) (retVal Tensor, err error) {
|
func (ts Tensor) Matmul(other Tensor) (retVal Tensor, err error) {
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
defer C.free(unsafe.Pointer(ptr))
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
@ -372,3 +382,54 @@ func (ts Tensor) RandnLike() (retVal Tensor, err error) {
|
||||||
|
|
||||||
return retVal, nil
|
return retVal, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Permute(dims []int64) (retVal Tensor, err error) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
|
||||||
|
lib.AtgPermute(ptr, ts.ctensor, dims)
|
||||||
|
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Tensor{ctensor: *ptr}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Squeeze1(dim int64) (retVal Tensor, err error) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
|
||||||
|
lib.AtgSqueeze1(ptr, ts.ctensor, dim)
|
||||||
|
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Tensor{ctensor: *ptr}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustSqueeze1(dim int64) (retVal Tensor) {
|
||||||
|
var err error
|
||||||
|
retVal, err = ts.Squeeze1(dim)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Squeeze_() {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
|
||||||
|
lib.AtgSqueeze_(ptr, ts.ctensor)
|
||||||
|
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -461,3 +461,12 @@ func (f *Func) Invoke() interface{} {
|
||||||
// How do we match them to output order of signature function
|
// How do we match them to output order of signature function
|
||||||
return f.val.Call(f.meta.InArgs)
|
return f.val.Call(f.meta.InArgs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Must is a helper to unwrap function it wraps. If having error,
|
||||||
|
// it will cause panic.
|
||||||
|
func Must(ts Tensor, err error) (retVal Tensor) {
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return ts
|
||||||
|
}
|
||||||
|
|
128
vision/image.go
Normal file
128
vision/image.go
Normal file
|
@ -0,0 +1,128 @@
|
||||||
|
package vision
|
||||||
|
|
||||||
|
// Utility functions to manipulate images.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// (height, width, channel) -> (channel, height, width)
|
||||||
|
func hwcToCHW(tensor ts.Tensor) (retVal ts.Tensor) {
|
||||||
|
var err error
|
||||||
|
retVal, err = tensor.Permute([]int64{2, 0, 1})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("hwcToCHW error: %v\n", err)
|
||||||
|
}
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func chwToHWC(tensor ts.Tensor) (retVal ts.Tensor) {
|
||||||
|
var err error
|
||||||
|
retVal, err = tensor.Permute([]int64{1, 2, 0})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("hwcToCHW error: %v\n", err)
|
||||||
|
}
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load loads an image from a file.
|
||||||
|
//
|
||||||
|
// On success returns a tensor of shape [channel, height, width].
|
||||||
|
func Load(path string) (retVal ts.Tensor, err error) {
|
||||||
|
var tensor ts.Tensor
|
||||||
|
tensor, err = ts.LoadHwc(path)
|
||||||
|
if err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = hwcToCHW(tensor)
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save saves an image to a file.
|
||||||
|
//
|
||||||
|
// This expects as input a tensor of shape [channel, height, width].
|
||||||
|
// The image format is based on the filename suffix, supported suffixes
|
||||||
|
// are jpg, png, tga, and bmp.
|
||||||
|
// The tensor input should be of kind UInt8 with values ranging from
|
||||||
|
// 0 to 255.
|
||||||
|
func Save(tensor ts.Tensor, path string) (err error) {
|
||||||
|
t, err := tensor.Totype(gotch.Uint8)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("Save - Tensor.Totype() error: %v\n", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
shape, err := t.Size()
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("Save - Tensor.Size() error: %v\n", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case len(shape) == 4 && shape[0] == 1:
|
||||||
|
return ts.SaveHwc(chwToHWC(t.MustSqueeze1(int64(0)).MustTo(gotch.CPU)), path)
|
||||||
|
case len(shape) == 3:
|
||||||
|
return ts.SaveHwc(chwToHWC(t.MustTo(gotch.CPU)), path)
|
||||||
|
default:
|
||||||
|
err = fmt.Errorf("Unexpected size (%v) for image tensor.\n", len(shape))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Resize resizes an image.
|
||||||
|
//
|
||||||
|
// This expects as input a tensor of shape [channel, height, width] and returns
|
||||||
|
// a tensor of shape [channel, out_h, out_w].
|
||||||
|
func Resize(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) {
|
||||||
|
tmpTs, err := ts.ResizeHwc(t, outW, outH)
|
||||||
|
if err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
retVal = hwcToCHW(tmpTs)
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: implement
|
||||||
|
func resizePreserveAspectRatioHWC(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) {
|
||||||
|
// TODO: implement
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResizePreserveAspectRatio resizes an image, preserve the aspect ratio by taking a center crop.
|
||||||
|
//
|
||||||
|
// This expects as input a tensor of shape [channel, height, width] and returns
|
||||||
|
func ResizePreserveAspectRatio(t ts.Tensor, outW int64, outH int64) (retVal ts.Tensor, err error) {
|
||||||
|
return resizePreserveAspectRatioHWC(chwToHWC(t), outW, outH)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadAndResize loads and resizes an image, preserve the aspect ratio by taking a center crop.
|
||||||
|
func LoadAndResize(path string, outW int64, outH int64) (retVal ts.Tensor, err error) {
|
||||||
|
tensor, err := ts.LoadHwc(path)
|
||||||
|
if err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return resizePreserveAspectRatioHWC(tensor, outW, outH)
|
||||||
|
}
|
||||||
|
|
||||||
|
// visitDirs is a placeholder: it ignores dir and files and always returns
// nil.
//
// TODO: decide whether this helper is needed at all.
func visitDirs(dir string, files []string) (err error) {
	return nil
}
|
||||||
|
|
||||||
|
// LoadDir loads all the images in a directory.
|
||||||
|
func LoadDir(path string, outW int64, outH int64) (retVal ts.Tensor, err error) {
|
||||||
|
// var files []string
|
||||||
|
// TODO: implement it
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user