feat(vision/mnist): completed. WIP(example/mnist)

This commit is contained in:
sugarme 2020-06-15 13:19:42 +10:00
parent 9efc686748
commit d3ad29cb53
9 changed files with 255 additions and 320 deletions

View File

@ -2,6 +2,7 @@ package gotch
import (
"fmt"
"log"
// "log"
"reflect"
)
@ -101,6 +102,15 @@ func DType2CInt(dt DType) (retVal CInt, err error) {
return retVal, nil
}
func (dt DType) CInt() (retVal CInt) {
retVal, err := DType2CInt(dt)
if err != nil {
log.Fatal(err)
}
return retVal
}
func CInt2DType(v CInt) (dtype DType, err error) {
var found = false
for key, val := range dtypeCInt {

View File

@ -1,89 +0,0 @@
package main
import (
"encoding/binary"
"io"
"os"
)
const numLabels = 10
const pixelRange = 255
const (
imageMagic = 0x00000803
labelMagic = 0x00000801
// Width of the input tensor / picture
Width = 28
// Height of the input tensor / picture
Height = 28
)
func readLabelFile(r io.Reader, e error) (labels []Label, err error) {
if e != nil {
return nil, e
}
var (
magic int32
n int32
)
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
return nil, err
}
if magic != labelMagic {
return nil, os.ErrInvalid
}
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
return nil, err
}
labels = make([]Label, n)
for i := 0; i < int(n); i++ {
var l Label
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
return nil, err
}
labels[i] = l
}
return labels, nil
}
func readImageFile(r io.Reader, e error) (imgs []RawImage, err error) {
if e != nil {
return nil, e
}
var (
magic int32
n int32
nrow int32
ncol int32
)
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
return nil, err
}
if magic != imageMagic {
return nil, err /*os.ErrInvalid*/
}
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
return nil, err
}
if err = binary.Read(r, binary.BigEndian, &nrow); err != nil {
return nil, err
}
if err = binary.Read(r, binary.BigEndian, &ncol); err != nil {
return nil, err
}
imgs = make([]RawImage, n)
m := int(nrow * ncol)
for i := 0; i < int(n); i++ {
imgs[i] = make(RawImage, m)
m_, err := io.ReadFull(r, imgs[i])
if err != nil {
return nil, err
}
if m_ != int(m) {
return nil, os.ErrInvalid
}
}
return imgs, nil
}

View File

@ -1,86 +0,0 @@
package main
import (
"fmt"
"log"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
const (
N = 60000 // number of rows in train data (training size)
inDim = 784 // input features - columns in train data (image data is 28x28pixel matrix)
outDim = 10 // output features (probabilities for digits 0-9)
batchSize = 50
batches = N / batchSize
epochs = 100
)
var (
trainX ts.Tensor
trainY ts.Tensor
testX ts.Tensor
testY ts.Tensor
err error
)
func init() {
// load the train set
// trainX is input tensor with shape{60000, 784} (image size: 28x28 pixels)
// trainY is target tensor with shape{6000, 10}
// (represent probabilities for digit 0-9)
// E.g. [0.1 0.1 0.1 0.1 0.1 0.9 0.1 0.1 0.1 0.1]
trainX, trainY, err = Load("train", "../testdata/mnist", gotch.Double)
handleError(err)
// load our test set
testX, testY, err = Load("test", "../testdata/mnist", gotch.Double)
handleError(err)
}
func main() {
for epoch := 0; epoch < epochs; epoch++ {
for i := 0; i < batches; i++ {
// NOTE: `m.Reset()` does not delete data. It just moves pointer to starting point.
start := i * batchSize
end := start + batchSize
dims, err := trainX.Size()
handleError(err)
if start > ts.FlattenDim(dims) || end > ts.FlattenDim(dims) {
break
}
index := ts.NewNarrow(int64(start), int64(end))
batchX := trainX.Idx(index)
batchX.Print()
fmt.Printf("Processed epoch %v - sample %v\n", epoch, i)
panic("Stop")
// batchX, err := trainX.Slice(nn.MakeRangedSlice(start, end))
// handleError(err)
// batchY, err := trainY.Slice(nn.MakeRangedSlice(start, end))
// handleError(err)
// xi := batchX.Data().([]float64)
// yi := batchY.Data().([]float64)
//
// xiT := ts.New(ts.WithShape(batchSize, inDim), ts.WithBacking(xi))
// yiT := ts.New(ts.WithShape(batchSize, outDim), ts.WithBacking(yi))
}
}
}
func handleError(err error) {
if err != nil {
log.Fatal(err)
}
}

View File

@ -1,143 +0,0 @@
package main
import (
"fmt"
"log"
"os"
"path/filepath"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
// Image holds the pixel intensities of an image.
// 255 is foreground (black), 0 is background (white).
type RawImage []byte
// Label is a digit label in 0 to 9
type Label uint8
// Load loads the mnist data into two tensors
//
// typ can be "train", "test"
//
// loc represents where the mnist files are held
func Load(typ, loc string, as gotch.DType) (inputs, targets ts.Tensor, err error) {
const (
trainLabel = "train-labels-idx1-ubyte"
trainData = "train-images-idx3-ubyte"
testLabel = "t10k-labels-idx1-ubyte"
testData = "t10k-images-idx3-ubyte"
)
var labelFile, dataFile string
switch typ {
case "train", "dev":
labelFile = filepath.Join(loc, trainLabel)
dataFile = filepath.Join(loc, trainData)
case "test":
labelFile = filepath.Join(loc, testLabel)
dataFile = filepath.Join(loc, testData)
}
var labelData []Label
var imageData []RawImage
if labelData, err = readLabelFile(os.Open(labelFile)); err != nil {
return inputs, targets, fmt.Errorf("Unable to read Labels: %v\n", err)
}
if imageData, err = readImageFile(os.Open(dataFile)); err != nil {
return inputs, targets, fmt.Errorf("Unable to read image data: %v\n", err)
}
inputs = prepareX(imageData, as)
targets = prepareY(labelData, as)
return
}
func pixelWeight(px byte) float64 {
retVal := float64(px)/pixelRange*0.9 + 0.1
if retVal == 1.0 {
return 0.999
}
return retVal
}
func reversePixelWeight(px float64) byte {
return byte((pixelRange*px - pixelRange) / 0.9)
}
func prepareX(M []RawImage, dt gotch.DType) (retVal ts.Tensor) {
rows := len(M)
cols := len(M[0])
var backing interface{}
switch dt {
case gotch.Double:
b := make([]float64, rows*cols, rows*cols)
b = b[:0]
for i := 0; i < rows; i++ {
for j := 0; j < len(M[i]); j++ {
b = append(b, pixelWeight(M[i][j]))
}
}
backing = b
case gotch.Float:
b := make([]float32, rows*cols, rows*cols)
b = b[:0]
for i := 0; i < rows; i++ {
for j := 0; j < len(M[i]); j++ {
b = append(b, float32(pixelWeight(M[i][j])))
}
}
backing = b
}
retVal, err := ts.NewTensorFromData(backing, []int64{int64(rows), int64(cols)})
if err != nil {
log.Fatalf("Prepare X - NewTensorFromData error: %v\n", err)
}
return
}
func prepareY(N []Label, dt gotch.DType) (retVal ts.Tensor) {
rows := len(N)
cols := 10
var backing interface{}
switch dt {
case gotch.Double:
b := make([]float64, rows*cols, rows*cols)
b = b[:0]
for i := 0; i < rows; i++ {
for j := 0; j < 10; j++ {
if j == int(N[i]) {
b = append(b, 0.9)
} else {
b = append(b, 0.1)
}
}
}
backing = b
case gotch.Float:
b := make([]float32, rows*cols, rows*cols)
b = b[:0]
for i := 0; i < rows; i++ {
for j := 0; j < 10; j++ {
if j == int(N[i]) {
b = append(b, 0.9)
} else {
b = append(b, 0.1)
}
}
}
backing = b
}
retVal, err := ts.NewTensorFromData(backing, []int64{int64(rows), int64(cols)})
if err != nil {
log.Fatalf("Prepare Y - NewTensorFromData error: %v\n", err)
}
return
}

47
example/mnist/linear.go Normal file
View File

@ -0,0 +1,47 @@
package main
import (
"fmt"
"log"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
"github.com/sugarme/gotch/vision"
)
const (
ImageDim int64 = 784
Label int64 = 10
MnistDir string = "../../data/mnist"
epochs = 200
)
func runLinear() {
var ds vision.Dataset
ds = vision.LoadMNISTDir(MnistDir)
fmt.Printf("Train image size: %v\n", ds.TrainImages.MustSize())
fmt.Printf("Train label size: %v\n", ds.TrainLabels.MustSize())
fmt.Printf("Test image size: %v\n", ds.TestImages.MustSize())
fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize())
device := (gotch.CPU).CInt()
dtype := (gotch.Double).CInt()
ws := ts.MustZeros([]int64{ImageDim, Label}, dtype, device).MustSetRequiresGrad(true)
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
fmt.Println(ws.MustSize())
fmt.Println(bs.MustSize())
for epoch := 0; epoch < epochs; epoch++ {
}
}
func handleError(err error) {
if err != nil {
log.Fatal(err)
}
}

7
example/mnist/main.go Normal file
View File

@ -0,0 +1,7 @@
package main
import ()
func main() {
runLinear()
}

View File

@ -304,7 +304,7 @@ func (ts Tensor) IndexSelect(dim int64, index Tensor) (retVal Tensor, err error)
return retVal, nil
}
func (ts Tensor) Zeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
func Zeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
@ -318,7 +318,15 @@ func (ts Tensor) Zeros(size []int64, optionsKind, optionsDevice int32) (retVal T
return retVal, nil
}
func (ts Tensor) Ones(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
func MustZeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor) {
retVal, err := Zeros(size, optionsKind, optionsDevice)
if err != nil {
log.Fatal(err)
}
return retVal
}
func Ones(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
@ -332,6 +340,14 @@ func (ts Tensor) Ones(size []int64, optionsKind, optionsDevice int32) (retVal Te
return retVal, nil
}
func MustOnes(size []int64, optionsKind, optionsDevice int32) (retVal Tensor) {
retVal, err := Ones(size, optionsKind, optionsDevice)
if err != nil {
log.Fatal(err)
}
return retVal
}
// NOTE: `_` denotes "in-place".
func (ts Tensor) Uniform_(from float64, to float64) {
var err error

17
vision/dataset.go Normal file
View File

@ -0,0 +1,17 @@
package vision
// A simple dataset structure shared by various computer vision datasets.
import (
ts "github.com/sugarme/gotch/tensor"
)
type Dataset struct {
TrainImages ts.Tensor
TrainLabels ts.Tensor
TestImages ts.Tensor
TestLabels ts.Tensor
Labels int64
}
// TODO: implement methods

156
vision/mnist.go Normal file
View File

@ -0,0 +1,156 @@
package vision
// The MNIST hand-written digit dataset.
//
// The files can be obtained from the following link:
// http://yann.lecun.com/exdb/mnist/
import (
"encoding/binary"
"io"
"log"
"os"
"path/filepath"
ts "github.com/sugarme/gotch/tensor"
)
// Image holds the pixel intensities of an image.
// 255 is foreground (black), 0 is background (white).
type RawImage []byte
const numLabels = 10
const pixelRange = 255
const (
imageMagic = 0x00000803
labelMagic = 0x00000801
// Width of the input tensor / picture
Width = 28
// Height of the input tensor / picture
Height = 28
)
func readLabels(r io.Reader, e error) (retVal ts.Tensor) {
if e != nil {
log.Fatalf("readLabels errors: %v\n", e)
}
var (
magic int32
n int32
err error
)
// Check magic number
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
log.Fatalf("readLabels - binary.Read error: %v\n", err)
}
if magic != labelMagic {
log.Fatal(os.ErrInvalid)
}
// Now decode number
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
log.Fatalf("readLabels - binary.Read error: %v\n", err)
}
// label is a digit number range 0 - 9
labels := make([]uint8, n)
for i := 0; i < int(n); i++ {
var l uint8
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
log.Fatalf("readLabels - binary.Read error: %v\n", err)
}
labels[i] = l
}
retVal, err = ts.OfSlice(labels)
if err != nil {
log.Fatalf("readLabels - ts.OfSlice error: %v\n", err)
}
return retVal
}
func readImages(r io.Reader, e error) (retVal ts.Tensor) {
if e != nil {
log.Fatalf("readLabels errors: %v\n", e)
}
var (
magic int32
n int32
nrow int32
ncol int32
err error
)
// Check magic number
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
log.Fatalf("readImages - binary.Read error: %v\n", err)
}
if magic != imageMagic {
log.Fatalf("readImages - incorrect imageMagic: %v\n", err) // err is os.ErrInvalid
}
// Now, decode image
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
log.Fatalf("readImages - binary.Read error: %v\n", err)
}
if err = binary.Read(r, binary.BigEndian, &nrow); err != nil {
log.Fatalf("readImages - binary.Read error: %v\n", err)
}
if err = binary.Read(r, binary.BigEndian, &ncol); err != nil {
log.Fatalf("readImages - binary.Read error: %v\n", err)
}
imgs := make([]RawImage, n)
m := int(nrow * ncol)
for i := 0; i < int(n); i++ {
imgs[i] = make(RawImage, m)
m_, err := io.ReadFull(r, imgs[i])
if err != nil {
log.Fatalf("readImages - io.ReadFull error: %v\n", err)
}
if m_ != int(m) {
log.Fatalf("readImages - image matrix size mismatched error: %v\n", os.ErrInvalid)
}
}
retVal, err = ts.NewTensorFromData(imgs, []int64{int64(n), int64(nrow * ncol)})
if err != nil {
log.Fatalf("readImages - ts.NewTensorFromData error: %v\n", err)
}
return retVal
}
// LoadMNISTDir loads all MNIST data from a given directory to Dataset
func LoadMNISTDir(dir string) (retVal Dataset) {
const (
trainLabels = "train-labels-idx1-ubyte"
trainImages = "train-images-idx3-ubyte"
testLabels = "t10k-labels-idx1-ubyte"
testImages = "t10k-images-idx3-ubyte"
)
trainLabelsFile := filepath.Join(dir, trainLabels)
trainImagesFile := filepath.Join(dir, trainImages)
testLabelsFile := filepath.Join(dir, testLabels)
testImagesFile := filepath.Join(dir, testImages)
trainImagesTs := readImages(os.Open(trainImagesFile))
trainLabelsTs := readLabels(os.Open(trainLabelsFile))
testImagesTs := readImages(os.Open(testImagesFile))
testLabelsTs := readLabels(os.Open(testLabelsFile))
return Dataset{
TrainImages: trainImagesTs,
TrainLabels: trainLabelsTs,
TestImages: testImagesTs,
TestLabels: testLabelsTs,
Labels: 10,
}
}