feat(vision/mnist): completed. WIP(example/mnist)
This commit is contained in:
parent
9efc686748
commit
d3ad29cb53
10
dtype.go
10
dtype.go
|
@ -2,6 +2,7 @@ package gotch
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
// "log"
|
||||
"reflect"
|
||||
)
|
||||
|
@ -101,6 +102,15 @@ func DType2CInt(dt DType) (retVal CInt, err error) {
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (dt DType) CInt() (retVal CInt) {
|
||||
retVal, err := DType2CInt(dt)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func CInt2DType(v CInt) (dtype DType, err error) {
|
||||
var found = false
|
||||
for key, val := range dtypeCInt {
|
||||
|
|
|
@ -1,89 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
const numLabels = 10
|
||||
const pixelRange = 255
|
||||
|
||||
const (
|
||||
imageMagic = 0x00000803
|
||||
labelMagic = 0x00000801
|
||||
// Width of the input tensor / picture
|
||||
Width = 28
|
||||
// Height of the input tensor / picture
|
||||
Height = 28
|
||||
)
|
||||
|
||||
func readLabelFile(r io.Reader, e error) (labels []Label, err error) {
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
|
||||
var (
|
||||
magic int32
|
||||
n int32
|
||||
)
|
||||
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if magic != labelMagic {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
labels = make([]Label, n)
|
||||
for i := 0; i < int(n); i++ {
|
||||
var l Label
|
||||
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
labels[i] = l
|
||||
}
|
||||
return labels, nil
|
||||
}
|
||||
|
||||
func readImageFile(r io.Reader, e error) (imgs []RawImage, err error) {
|
||||
if e != nil {
|
||||
return nil, e
|
||||
}
|
||||
|
||||
var (
|
||||
magic int32
|
||||
n int32
|
||||
nrow int32
|
||||
ncol int32
|
||||
)
|
||||
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if magic != imageMagic {
|
||||
return nil, err /*os.ErrInvalid*/
|
||||
}
|
||||
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = binary.Read(r, binary.BigEndian, &nrow); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = binary.Read(r, binary.BigEndian, &ncol); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
imgs = make([]RawImage, n)
|
||||
m := int(nrow * ncol)
|
||||
for i := 0; i < int(n); i++ {
|
||||
imgs[i] = make(RawImage, m)
|
||||
m_, err := io.ReadFull(r, imgs[i])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if m_ != int(m) {
|
||||
return nil, os.ErrInvalid
|
||||
}
|
||||
}
|
||||
return imgs, nil
|
||||
}
|
|
@ -1,86 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
const (
|
||||
N = 60000 // number of rows in train data (training size)
|
||||
inDim = 784 // input features - columns in train data (image data is 28x28pixel matrix)
|
||||
outDim = 10 // output features (probabilities for digits 0-9)
|
||||
batchSize = 50
|
||||
batches = N / batchSize
|
||||
epochs = 100
|
||||
)
|
||||
|
||||
var (
|
||||
trainX ts.Tensor
|
||||
trainY ts.Tensor
|
||||
testX ts.Tensor
|
||||
testY ts.Tensor
|
||||
|
||||
err error
|
||||
)
|
||||
|
||||
func init() {
|
||||
// load the train set
|
||||
// trainX is input tensor with shape{60000, 784} (image size: 28x28 pixels)
|
||||
// trainY is target tensor with shape{6000, 10}
|
||||
// (represent probabilities for digit 0-9)
|
||||
// E.g. [0.1 0.1 0.1 0.1 0.1 0.9 0.1 0.1 0.1 0.1]
|
||||
trainX, trainY, err = Load("train", "../testdata/mnist", gotch.Double)
|
||||
handleError(err)
|
||||
// load our test set
|
||||
testX, testY, err = Load("test", "../testdata/mnist", gotch.Double)
|
||||
handleError(err)
|
||||
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
for epoch := 0; epoch < epochs; epoch++ {
|
||||
for i := 0; i < batches; i++ {
|
||||
// NOTE: `m.Reset()` does not delete data. It just moves pointer to starting point.
|
||||
start := i * batchSize
|
||||
end := start + batchSize
|
||||
|
||||
dims, err := trainX.Size()
|
||||
handleError(err)
|
||||
|
||||
if start > ts.FlattenDim(dims) || end > ts.FlattenDim(dims) {
|
||||
break
|
||||
}
|
||||
|
||||
index := ts.NewNarrow(int64(start), int64(end))
|
||||
|
||||
batchX := trainX.Idx(index)
|
||||
batchX.Print()
|
||||
|
||||
fmt.Printf("Processed epoch %v - sample %v\n", epoch, i)
|
||||
|
||||
panic("Stop")
|
||||
|
||||
// batchX, err := trainX.Slice(nn.MakeRangedSlice(start, end))
|
||||
// handleError(err)
|
||||
// batchY, err := trainY.Slice(nn.MakeRangedSlice(start, end))
|
||||
// handleError(err)
|
||||
// xi := batchX.Data().([]float64)
|
||||
// yi := batchY.Data().([]float64)
|
||||
//
|
||||
// xiT := ts.New(ts.WithShape(batchSize, inDim), ts.WithBacking(xi))
|
||||
// yiT := ts.New(ts.WithShape(batchSize, outDim), ts.WithBacking(yi))
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func handleError(err error) {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
|
@ -1,143 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// Image holds the pixel intensities of an image.
|
||||
// 255 is foreground (black), 0 is background (white).
|
||||
type RawImage []byte
|
||||
|
||||
// Label is a digit label in 0 to 9
|
||||
type Label uint8
|
||||
|
||||
// Load loads the mnist data into two tensors
|
||||
//
|
||||
// typ can be "train", "test"
|
||||
//
|
||||
// loc represents where the mnist files are held
|
||||
func Load(typ, loc string, as gotch.DType) (inputs, targets ts.Tensor, err error) {
|
||||
const (
|
||||
trainLabel = "train-labels-idx1-ubyte"
|
||||
trainData = "train-images-idx3-ubyte"
|
||||
testLabel = "t10k-labels-idx1-ubyte"
|
||||
testData = "t10k-images-idx3-ubyte"
|
||||
)
|
||||
|
||||
var labelFile, dataFile string
|
||||
switch typ {
|
||||
case "train", "dev":
|
||||
labelFile = filepath.Join(loc, trainLabel)
|
||||
dataFile = filepath.Join(loc, trainData)
|
||||
case "test":
|
||||
labelFile = filepath.Join(loc, testLabel)
|
||||
dataFile = filepath.Join(loc, testData)
|
||||
}
|
||||
|
||||
var labelData []Label
|
||||
var imageData []RawImage
|
||||
|
||||
if labelData, err = readLabelFile(os.Open(labelFile)); err != nil {
|
||||
return inputs, targets, fmt.Errorf("Unable to read Labels: %v\n", err)
|
||||
}
|
||||
|
||||
if imageData, err = readImageFile(os.Open(dataFile)); err != nil {
|
||||
return inputs, targets, fmt.Errorf("Unable to read image data: %v\n", err)
|
||||
}
|
||||
|
||||
inputs = prepareX(imageData, as)
|
||||
targets = prepareY(labelData, as)
|
||||
return
|
||||
}
|
||||
|
||||
func pixelWeight(px byte) float64 {
|
||||
retVal := float64(px)/pixelRange*0.9 + 0.1
|
||||
if retVal == 1.0 {
|
||||
return 0.999
|
||||
}
|
||||
return retVal
|
||||
}
|
||||
|
||||
func reversePixelWeight(px float64) byte {
|
||||
return byte((pixelRange*px - pixelRange) / 0.9)
|
||||
}
|
||||
|
||||
func prepareX(M []RawImage, dt gotch.DType) (retVal ts.Tensor) {
|
||||
rows := len(M)
|
||||
cols := len(M[0])
|
||||
|
||||
var backing interface{}
|
||||
switch dt {
|
||||
case gotch.Double:
|
||||
b := make([]float64, rows*cols, rows*cols)
|
||||
b = b[:0]
|
||||
for i := 0; i < rows; i++ {
|
||||
for j := 0; j < len(M[i]); j++ {
|
||||
b = append(b, pixelWeight(M[i][j]))
|
||||
}
|
||||
}
|
||||
backing = b
|
||||
case gotch.Float:
|
||||
b := make([]float32, rows*cols, rows*cols)
|
||||
b = b[:0]
|
||||
for i := 0; i < rows; i++ {
|
||||
for j := 0; j < len(M[i]); j++ {
|
||||
b = append(b, float32(pixelWeight(M[i][j])))
|
||||
}
|
||||
}
|
||||
backing = b
|
||||
}
|
||||
|
||||
retVal, err := ts.NewTensorFromData(backing, []int64{int64(rows), int64(cols)})
|
||||
if err != nil {
|
||||
log.Fatalf("Prepare X - NewTensorFromData error: %v\n", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func prepareY(N []Label, dt gotch.DType) (retVal ts.Tensor) {
|
||||
rows := len(N)
|
||||
cols := 10
|
||||
|
||||
var backing interface{}
|
||||
switch dt {
|
||||
case gotch.Double:
|
||||
b := make([]float64, rows*cols, rows*cols)
|
||||
b = b[:0]
|
||||
for i := 0; i < rows; i++ {
|
||||
for j := 0; j < 10; j++ {
|
||||
if j == int(N[i]) {
|
||||
b = append(b, 0.9)
|
||||
} else {
|
||||
b = append(b, 0.1)
|
||||
}
|
||||
}
|
||||
}
|
||||
backing = b
|
||||
case gotch.Float:
|
||||
b := make([]float32, rows*cols, rows*cols)
|
||||
b = b[:0]
|
||||
for i := 0; i < rows; i++ {
|
||||
for j := 0; j < 10; j++ {
|
||||
if j == int(N[i]) {
|
||||
b = append(b, 0.9)
|
||||
} else {
|
||||
b = append(b, 0.1)
|
||||
}
|
||||
}
|
||||
}
|
||||
backing = b
|
||||
|
||||
}
|
||||
retVal, err := ts.NewTensorFromData(backing, []int64{int64(rows), int64(cols)})
|
||||
if err != nil {
|
||||
log.Fatalf("Prepare Y - NewTensorFromData error: %v\n", err)
|
||||
}
|
||||
return
|
||||
}
|
47
example/mnist/linear.go
Normal file
47
example/mnist/linear.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
"github.com/sugarme/gotch/vision"
|
||||
)
|
||||
|
||||
const (
|
||||
ImageDim int64 = 784
|
||||
Label int64 = 10
|
||||
MnistDir string = "../../data/mnist"
|
||||
|
||||
epochs = 200
|
||||
)
|
||||
|
||||
func runLinear() {
|
||||
var ds vision.Dataset
|
||||
ds = vision.LoadMNISTDir(MnistDir)
|
||||
|
||||
fmt.Printf("Train image size: %v\n", ds.TrainImages.MustSize())
|
||||
fmt.Printf("Train label size: %v\n", ds.TrainLabels.MustSize())
|
||||
fmt.Printf("Test image size: %v\n", ds.TestImages.MustSize())
|
||||
fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize())
|
||||
|
||||
device := (gotch.CPU).CInt()
|
||||
dtype := (gotch.Double).CInt()
|
||||
|
||||
ws := ts.MustZeros([]int64{ImageDim, Label}, dtype, device).MustSetRequiresGrad(true)
|
||||
|
||||
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
|
||||
|
||||
fmt.Println(ws.MustSize())
|
||||
fmt.Println(bs.MustSize())
|
||||
|
||||
for epoch := 0; epoch < epochs; epoch++ {
|
||||
}
|
||||
}
|
||||
|
||||
func handleError(err error) {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
7
example/mnist/main.go
Normal file
7
example/mnist/main.go
Normal file
|
@ -0,0 +1,7 @@
|
|||
package main
|
||||
|
||||
import ()
|
||||
|
||||
func main() {
|
||||
runLinear()
|
||||
}
|
|
@ -304,7 +304,7 @@ func (ts Tensor) IndexSelect(dim int64, index Tensor) (retVal Tensor, err error)
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) Zeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
|
||||
func Zeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
|
@ -318,7 +318,15 @@ func (ts Tensor) Zeros(size []int64, optionsKind, optionsDevice int32) (retVal T
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) Ones(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
|
||||
func MustZeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor) {
|
||||
retVal, err := Zeros(size, optionsKind, optionsDevice)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return retVal
|
||||
}
|
||||
|
||||
func Ones(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
|
@ -332,6 +340,14 @@ func (ts Tensor) Ones(size []int64, optionsKind, optionsDevice int32) (retVal Te
|
|||
return retVal, nil
|
||||
}
|
||||
|
||||
func MustOnes(size []int64, optionsKind, optionsDevice int32) (retVal Tensor) {
|
||||
retVal, err := Ones(size, optionsKind, optionsDevice)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return retVal
|
||||
}
|
||||
|
||||
// NOTE: `_` denotes "in-place".
|
||||
func (ts Tensor) Uniform_(from float64, to float64) {
|
||||
var err error
|
||||
|
|
17
vision/dataset.go
Normal file
17
vision/dataset.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package vision
|
||||
|
||||
// A simple dataset structure shared by various computer vision datasets.
|
||||
|
||||
import (
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
type Dataset struct {
|
||||
TrainImages ts.Tensor
|
||||
TrainLabels ts.Tensor
|
||||
TestImages ts.Tensor
|
||||
TestLabels ts.Tensor
|
||||
Labels int64
|
||||
}
|
||||
|
||||
// TODO: implement methods
|
156
vision/mnist.go
Normal file
156
vision/mnist.go
Normal file
|
@ -0,0 +1,156 @@
|
|||
package vision
|
||||
|
||||
// The MNIST hand-written digit dataset.
|
||||
//
|
||||
// The files can be obtained from the following link:
|
||||
// http://yann.lecun.com/exdb/mnist/
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// Image holds the pixel intensities of an image.
|
||||
// 255 is foreground (black), 0 is background (white).
|
||||
type RawImage []byte
|
||||
|
||||
const numLabels = 10
|
||||
const pixelRange = 255
|
||||
|
||||
const (
|
||||
imageMagic = 0x00000803
|
||||
labelMagic = 0x00000801
|
||||
// Width of the input tensor / picture
|
||||
Width = 28
|
||||
// Height of the input tensor / picture
|
||||
Height = 28
|
||||
)
|
||||
|
||||
func readLabels(r io.Reader, e error) (retVal ts.Tensor) {
|
||||
if e != nil {
|
||||
log.Fatalf("readLabels errors: %v\n", e)
|
||||
}
|
||||
|
||||
var (
|
||||
magic int32
|
||||
n int32
|
||||
err error
|
||||
)
|
||||
|
||||
// Check magic number
|
||||
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
|
||||
log.Fatalf("readLabels - binary.Read error: %v\n", err)
|
||||
}
|
||||
if magic != labelMagic {
|
||||
log.Fatal(os.ErrInvalid)
|
||||
}
|
||||
|
||||
// Now decode number
|
||||
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
|
||||
log.Fatalf("readLabels - binary.Read error: %v\n", err)
|
||||
}
|
||||
|
||||
// label is a digit number range 0 - 9
|
||||
labels := make([]uint8, n)
|
||||
for i := 0; i < int(n); i++ {
|
||||
var l uint8
|
||||
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
|
||||
log.Fatalf("readLabels - binary.Read error: %v\n", err)
|
||||
}
|
||||
labels[i] = l
|
||||
}
|
||||
|
||||
retVal, err = ts.OfSlice(labels)
|
||||
if err != nil {
|
||||
log.Fatalf("readLabels - ts.OfSlice error: %v\n", err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func readImages(r io.Reader, e error) (retVal ts.Tensor) {
|
||||
if e != nil {
|
||||
log.Fatalf("readLabels errors: %v\n", e)
|
||||
}
|
||||
|
||||
var (
|
||||
magic int32
|
||||
n int32
|
||||
nrow int32
|
||||
ncol int32
|
||||
err error
|
||||
)
|
||||
|
||||
// Check magic number
|
||||
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
|
||||
log.Fatalf("readImages - binary.Read error: %v\n", err)
|
||||
}
|
||||
|
||||
if magic != imageMagic {
|
||||
log.Fatalf("readImages - incorrect imageMagic: %v\n", err) // err is os.ErrInvalid
|
||||
}
|
||||
|
||||
// Now, decode image
|
||||
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
|
||||
log.Fatalf("readImages - binary.Read error: %v\n", err)
|
||||
}
|
||||
if err = binary.Read(r, binary.BigEndian, &nrow); err != nil {
|
||||
log.Fatalf("readImages - binary.Read error: %v\n", err)
|
||||
}
|
||||
if err = binary.Read(r, binary.BigEndian, &ncol); err != nil {
|
||||
log.Fatalf("readImages - binary.Read error: %v\n", err)
|
||||
}
|
||||
|
||||
imgs := make([]RawImage, n)
|
||||
m := int(nrow * ncol)
|
||||
for i := 0; i < int(n); i++ {
|
||||
imgs[i] = make(RawImage, m)
|
||||
m_, err := io.ReadFull(r, imgs[i])
|
||||
if err != nil {
|
||||
log.Fatalf("readImages - io.ReadFull error: %v\n", err)
|
||||
}
|
||||
if m_ != int(m) {
|
||||
log.Fatalf("readImages - image matrix size mismatched error: %v\n", os.ErrInvalid)
|
||||
}
|
||||
}
|
||||
|
||||
retVal, err = ts.NewTensorFromData(imgs, []int64{int64(n), int64(nrow * ncol)})
|
||||
if err != nil {
|
||||
log.Fatalf("readImages - ts.NewTensorFromData error: %v\n", err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// LoadMNISTDir loads all MNIST data from a given directory to Dataset
|
||||
func LoadMNISTDir(dir string) (retVal Dataset) {
|
||||
const (
|
||||
trainLabels = "train-labels-idx1-ubyte"
|
||||
trainImages = "train-images-idx3-ubyte"
|
||||
testLabels = "t10k-labels-idx1-ubyte"
|
||||
testImages = "t10k-images-idx3-ubyte"
|
||||
)
|
||||
|
||||
trainLabelsFile := filepath.Join(dir, trainLabels)
|
||||
trainImagesFile := filepath.Join(dir, trainImages)
|
||||
testLabelsFile := filepath.Join(dir, testLabels)
|
||||
testImagesFile := filepath.Join(dir, testImages)
|
||||
|
||||
trainImagesTs := readImages(os.Open(trainImagesFile))
|
||||
trainLabelsTs := readLabels(os.Open(trainLabelsFile))
|
||||
testImagesTs := readImages(os.Open(testImagesFile))
|
||||
testLabelsTs := readLabels(os.Open(testLabelsFile))
|
||||
|
||||
return Dataset{
|
||||
TrainImages: trainImagesTs,
|
||||
TrainLabels: trainLabelsTs,
|
||||
TestImages: testImagesTs,
|
||||
TestLabels: testLabelsTs,
|
||||
Labels: 10,
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user