From d3ad29cb53a87a6b76362620232b45c6588ee7b0 Mon Sep 17 00:00:00 2001 From: sugarme Date: Mon, 15 Jun 2020 13:19:42 +1000 Subject: [PATCH] feat(vision/mnist): completed. WIP(example/mnist) --- dtype.go | 10 ++ example/mnist-linear/io.go | 89 ----------------- example/mnist-linear/main.go | 86 ---------------- example/mnist-linear/mnist.go | 143 --------------------------- example/mnist/linear.go | 47 +++++++++ example/mnist/main.go | 7 ++ tensor/tensor-generated-sample.go | 20 +++- vision/dataset.go | 17 ++++ vision/mnist.go | 156 ++++++++++++++++++++++++++++++ 9 files changed, 255 insertions(+), 320 deletions(-) delete mode 100644 example/mnist-linear/io.go delete mode 100644 example/mnist-linear/main.go delete mode 100644 example/mnist-linear/mnist.go create mode 100644 example/mnist/linear.go create mode 100644 example/mnist/main.go create mode 100644 vision/dataset.go create mode 100644 vision/mnist.go diff --git a/dtype.go b/dtype.go index 0d8f928..f1ba17d 100644 --- a/dtype.go +++ b/dtype.go @@ -2,6 +2,7 @@ package gotch import ( "fmt" + "log" // "log" "reflect" ) @@ -101,6 +102,15 @@ func DType2CInt(dt DType) (retVal CInt, err error) { return retVal, nil } +func (dt DType) CInt() (retVal CInt) { + retVal, err := DType2CInt(dt) + if err != nil { + log.Fatal(err) + } + + return retVal +} + func CInt2DType(v CInt) (dtype DType, err error) { var found = false for key, val := range dtypeCInt { diff --git a/example/mnist-linear/io.go b/example/mnist-linear/io.go deleted file mode 100644 index f9c8a2d..0000000 --- a/example/mnist-linear/io.go +++ /dev/null @@ -1,89 +0,0 @@ -package main - -import ( - "encoding/binary" - "io" - "os" -) - -const numLabels = 10 -const pixelRange = 255 - -const ( - imageMagic = 0x00000803 - labelMagic = 0x00000801 - // Width of the input tensor / picture - Width = 28 - // Height of the input tensor / picture - Height = 28 -) - -func readLabelFile(r io.Reader, e error) (labels []Label, err error) { - if e != nil { - return nil, e - } - - var ( - magic int32 - n int32 - ) - if err = binary.Read(r, binary.BigEndian, &magic); err != nil { - return nil, err - } - if magic != labelMagic { - return nil, os.ErrInvalid - } - if err = binary.Read(r, binary.BigEndian, &n); err != nil { - return nil, err - } - labels = make([]Label, n) - for i := 0; i < int(n); i++ { - var l Label - if err := binary.Read(r, binary.BigEndian, &l); err != nil { - return nil, err - } - labels[i] = l - } - return labels, nil -} - -func readImageFile(r io.Reader, e error) (imgs []RawImage, err error) { - if e != nil { - return nil, e - } - - var ( - magic int32 - n int32 - nrow int32 - ncol int32 - ) - if err = binary.Read(r, binary.BigEndian, &magic); err != nil { - return nil, err - } - if magic != imageMagic { - return nil, err /*os.ErrInvalid*/ - } - if err = binary.Read(r, binary.BigEndian, &n); err != nil { - return nil, err - } - if err = binary.Read(r, binary.BigEndian, &nrow); err != nil { - return nil, err - } - if err = binary.Read(r, binary.BigEndian, &ncol); err != nil { - return nil, err - } - imgs = make([]RawImage, n) - m := int(nrow * ncol) - for i := 0; i < int(n); i++ { - imgs[i] = make(RawImage, m) - m_, err := io.ReadFull(r, imgs[i]) - if err != nil { - return nil, err - } - if m_ != int(m) { - return nil, os.ErrInvalid - } - } - return imgs, nil -} diff --git a/example/mnist-linear/main.go b/example/mnist-linear/main.go deleted file mode 100644 index 09e19d2..0000000 --- a/example/mnist-linear/main.go +++ /dev/null @@ -1,86 +0,0 @@ -package main - -import ( - "fmt" - "log" - - "github.com/sugarme/gotch" - ts "github.com/sugarme/gotch/tensor" -) - -const ( - N = 60000 // number of rows in train data (training size) - inDim = 784 // input features - columns in train data (image data is 28x28pixel matrix) - outDim = 10 // output features (probabilities for digits 0-9) - batchSize = 50 - batches = N / batchSize - epochs = 100 -) - -var ( - trainX ts.Tensor - trainY ts.Tensor - testX ts.Tensor - testY ts.Tensor - - err error -) - -func init() { - // load the train set - // trainX is input tensor with shape{60000, 784} (image size: 28x28 pixels) - // trainY is target tensor with shape{6000, 10} - // (represent probabilities for digit 0-9) - // E.g. [0.1 0.1 0.1 0.1 0.1 0.9 0.1 0.1 0.1 0.1] - trainX, trainY, err = Load("train", "../testdata/mnist", gotch.Double) - handleError(err) - // load our test set - testX, testY, err = Load("test", "../testdata/mnist", gotch.Double) - handleError(err) - -} - -func main() { - - for epoch := 0; epoch < epochs; epoch++ { - for i := 0; i < batches; i++ { - // NOTE: `m.Reset()` does not delete data. It just moves pointer to starting point. - start := i * batchSize - end := start + batchSize - - dims, err := trainX.Size() - handleError(err) - - if start > ts.FlattenDim(dims) || end > ts.FlattenDim(dims) { - break - } - - index := ts.NewNarrow(int64(start), int64(end)) - - batchX := trainX.Idx(index) - batchX.Print() - - fmt.Printf("Processed epoch %v - sample %v\n", epoch, i) - - panic("Stop") - - // batchX, err := trainX.Slice(nn.MakeRangedSlice(start, end)) - // handleError(err) - // batchY, err := trainY.Slice(nn.MakeRangedSlice(start, end)) - // handleError(err) - // xi := batchX.Data().([]float64) - // yi := batchY.Data().([]float64) - // - // xiT := ts.New(ts.WithShape(batchSize, inDim), ts.WithBacking(xi)) - // yiT := ts.New(ts.WithShape(batchSize, outDim), ts.WithBacking(yi)) - - } - - } -} - -func handleError(err error) { - if err != nil { - log.Fatal(err) - } -} diff --git a/example/mnist-linear/mnist.go b/example/mnist-linear/mnist.go deleted file mode 100644 index 9788e33..0000000 --- a/example/mnist-linear/mnist.go +++ /dev/null @@ -1,143 +0,0 @@ -package main - -import ( - "fmt" - "log" - "os" - "path/filepath" - - "github.com/sugarme/gotch" - ts "github.com/sugarme/gotch/tensor" -) - -// Image holds the pixel intensities of an image. -// 255 is foreground (black), 0 is background (white). -type RawImage []byte - -// Label is a digit label in 0 to 9 -type Label uint8 - -// Load loads the mnist data into two tensors -// -// typ can be "train", "test" -// -// loc represents where the mnist files are held -func Load(typ, loc string, as gotch.DType) (inputs, targets ts.Tensor, err error) { - const ( - trainLabel = "train-labels-idx1-ubyte" - trainData = "train-images-idx3-ubyte" - testLabel = "t10k-labels-idx1-ubyte" - testData = "t10k-images-idx3-ubyte" - ) - - var labelFile, dataFile string - switch typ { - case "train", "dev": - labelFile = filepath.Join(loc, trainLabel) - dataFile = filepath.Join(loc, trainData) - case "test": - labelFile = filepath.Join(loc, testLabel) - dataFile = filepath.Join(loc, testData) - } - - var labelData []Label - var imageData []RawImage - - if labelData, err = readLabelFile(os.Open(labelFile)); err != nil { - return inputs, targets, fmt.Errorf("Unable to read Labels: %v\n", err) - } - - if imageData, err = readImageFile(os.Open(dataFile)); err != nil { - return inputs, targets, fmt.Errorf("Unable to read image data: %v\n", err) - } - - inputs = prepareX(imageData, as) - targets = prepareY(labelData, as) - return -} - -func pixelWeight(px byte) float64 { - retVal := float64(px)/pixelRange*0.9 + 0.1 - if retVal == 1.0 { - return 0.999 - } - return retVal -} - -func reversePixelWeight(px float64) byte { - return byte((pixelRange*px - pixelRange) / 0.9) -} - -func prepareX(M []RawImage, dt gotch.DType) (retVal ts.Tensor) { - rows := len(M) - cols := len(M[0]) - - var backing interface{} - switch dt { - case gotch.Double: - b := make([]float64, rows*cols, rows*cols) - b = b[:0] - for i := 0; i < rows; i++ { - for j := 0; j < len(M[i]); j++ { - b = append(b, pixelWeight(M[i][j])) - } - } - backing = b - case gotch.Float: - b := make([]float32, rows*cols, rows*cols) - b = b[:0] - for i := 0; i < rows; i++ { - for j := 0; j < len(M[i]); j++ { - b = append(b, float32(pixelWeight(M[i][j]))) - } - } - backing = b - } - - retVal, err := ts.NewTensorFromData(backing, []int64{int64(rows), int64(cols)}) - if err != nil { - log.Fatalf("Prepare X - NewTensorFromData error: %v\n", err) - } - return -} - -func prepareY(N []Label, dt gotch.DType) (retVal ts.Tensor) { - rows := len(N) - cols := 10 - - var backing interface{} - switch dt { - case gotch.Double: - b := make([]float64, rows*cols, rows*cols) - b = b[:0] - for i := 0; i < rows; i++ { - for j := 0; j < 10; j++ { - if j == int(N[i]) { - b = append(b, 0.9) - } else { - b = append(b, 0.1) - } - } - } - backing = b - case gotch.Float: - b := make([]float32, rows*cols, rows*cols) - b = b[:0] - for i := 0; i < rows; i++ { - for j := 0; j < 10; j++ { - if j == int(N[i]) { - b = append(b, 0.9) - } else { - b = append(b, 0.1) - } - } - } - backing = b - - } - retVal, err := ts.NewTensorFromData(backing, []int64{int64(rows), int64(cols)}) - if err != nil { - log.Fatalf("Prepare Y - NewTensorFromData error: %v\n", err) - } - return -} diff --git a/example/mnist/linear.go b/example/mnist/linear.go new file mode 100644 index 0000000..1395f58 --- /dev/null +++ b/example/mnist/linear.go @@ -0,0 +1,47 @@ +package main + +import ( + "fmt" + "log" + + "github.com/sugarme/gotch" + ts "github.com/sugarme/gotch/tensor" + "github.com/sugarme/gotch/vision" +) + +const ( + ImageDim int64 = 784 + Label int64 = 10 + MnistDir string = "../../data/mnist" + + epochs = 200 +) + +func runLinear() { + var ds vision.Dataset + ds = vision.LoadMNISTDir(MnistDir) + + fmt.Printf("Train image size: %v\n", ds.TrainImages.MustSize()) + fmt.Printf("Train label size: %v\n", ds.TrainLabels.MustSize()) + fmt.Printf("Test image size: %v\n", ds.TestImages.MustSize()) + fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize()) + + device := (gotch.CPU).CInt() + dtype := (gotch.Double).CInt() + + ws := ts.MustZeros([]int64{ImageDim, Label}, dtype, device).MustSetRequiresGrad(true) + + bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true) + + fmt.Println(ws.MustSize()) + fmt.Println(bs.MustSize()) + + for epoch := 0; epoch < epochs; epoch++ { + } +} + +func handleError(err error) { + if err != nil { + log.Fatal(err) + } +} diff --git a/example/mnist/main.go b/example/mnist/main.go new file mode 100644 index 0000000..15a945a --- /dev/null +++ b/example/mnist/main.go @@ -0,0 +1,7 @@ +package main + +import () + +func main() { + runLinear() +} diff --git a/tensor/tensor-generated-sample.go b/tensor/tensor-generated-sample.go index 8820504..b21e3d2 100644 --- a/tensor/tensor-generated-sample.go +++ b/tensor/tensor-generated-sample.go @@ -304,7 +304,7 @@ func (ts Tensor) IndexSelect(dim int64, index Tensor) (retVal Tensor, err error) return retVal, nil } -func (ts Tensor) Zeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) { +func Zeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) defer C.free(unsafe.Pointer(ptr)) @@ -318,7 +318,15 @@ func (ts Tensor) Zeros(size []int64, optionsKind, optionsDevice int32) (retVal T return retVal, nil } -func (ts Tensor) Ones(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) { +func MustZeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor) { + retVal, err := Zeros(size, optionsKind, optionsDevice) + if err != nil { + log.Fatal(err) + } + return retVal +} + +func Ones(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) { ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0))) defer C.free(unsafe.Pointer(ptr)) @@ -332,6 +340,14 @@ func (ts Tensor) Ones(size []int64, optionsKind, optionsDevice int32) (retVal Te return retVal, nil } +func MustOnes(size []int64, optionsKind, optionsDevice int32) (retVal Tensor) { + retVal, err := Ones(size, optionsKind, optionsDevice) + if err != nil { + log.Fatal(err) + } + return retVal +} + // NOTE: `_` denotes "in-place". func (ts Tensor) Uniform_(from float64, to float64) { var err error diff --git a/vision/dataset.go b/vision/dataset.go new file mode 100644 index 0000000..f483870 --- /dev/null +++ b/vision/dataset.go @@ -0,0 +1,17 @@ +package vision + +// A simple dataset structure shared by various computer vision datasets. + +import ( + ts "github.com/sugarme/gotch/tensor" +) + +type Dataset struct { + TrainImages ts.Tensor + TrainLabels ts.Tensor + TestImages ts.Tensor + TestLabels ts.Tensor + Labels int64 +} + +// TODO: implement methods diff --git a/vision/mnist.go b/vision/mnist.go new file mode 100644 index 0000000..c7b9138 --- /dev/null +++ b/vision/mnist.go @@ -0,0 +1,156 @@ +package vision + +// The MNIST hand-written digit dataset. +// +// The files can be obtained from the following link: +// http://yann.lecun.com/exdb/mnist/ + +import ( + "encoding/binary" + "io" + "log" + "os" + "path/filepath" + + ts "github.com/sugarme/gotch/tensor" +) + +// Image holds the pixel intensities of an image. +// 255 is foreground (black), 0 is background (white). +type RawImage []byte + +const numLabels = 10 +const pixelRange = 255 + +const ( + imageMagic = 0x00000803 + labelMagic = 0x00000801 + // Width of the input tensor / picture + Width = 28 + // Height of the input tensor / picture + Height = 28 +) + +func readLabels(r io.Reader, e error) (retVal ts.Tensor) { + if e != nil { + log.Fatalf("readLabels errors: %v\n", e) + } + + var ( + magic int32 + n int32 + err error + ) + + // Check magic number + if err = binary.Read(r, binary.BigEndian, &magic); err != nil { + log.Fatalf("readLabels - binary.Read error: %v\n", err) + } + if magic != labelMagic { + log.Fatal(os.ErrInvalid) + } + + // Now decode number + if err = binary.Read(r, binary.BigEndian, &n); err != nil { + log.Fatalf("readLabels - binary.Read error: %v\n", err) + } + + // label is a digit number range 0 - 9 + labels := make([]uint8, n) + for i := 0; i < int(n); i++ { + var l uint8 + if err := binary.Read(r, binary.BigEndian, &l); err != nil { + log.Fatalf("readLabels - binary.Read error: %v\n", err) + } + labels[i] = l + } + + retVal, err = ts.OfSlice(labels) + if err != nil { + log.Fatalf("readLabels - ts.OfSlice error: %v\n", err) + } + + return retVal +} + +func readImages(r io.Reader, e error) (retVal ts.Tensor) { + if e != nil { + log.Fatalf("readLabels errors: %v\n", e) + } + + var ( + magic int32 + n int32 + nrow int32 + ncol int32 + err error + ) + + // Check magic number + if err = binary.Read(r, binary.BigEndian, &magic); err != nil { + log.Fatalf("readImages - binary.Read error: %v\n", err) + } + + if magic != imageMagic { + log.Fatalf("readImages - incorrect imageMagic: %v\n", err) // err is os.ErrInvalid + } + + // Now, decode image + if err = binary.Read(r, binary.BigEndian, &n); err != nil { + log.Fatalf("readImages - binary.Read error: %v\n", err) + } + if err = binary.Read(r, binary.BigEndian, &nrow); err != nil { + log.Fatalf("readImages - binary.Read error: %v\n", err) + } + if err = binary.Read(r, binary.BigEndian, &ncol); err != nil { + log.Fatalf("readImages - binary.Read error: %v\n", err) + } + + imgs := make([]RawImage, n) + m := int(nrow * ncol) + for i := 0; i < int(n); i++ { + imgs[i] = make(RawImage, m) + m_, err := io.ReadFull(r, imgs[i]) + if err != nil { + log.Fatalf("readImages - io.ReadFull error: %v\n", err) + } + if m_ != int(m) { + log.Fatalf("readImages - image matrix size mismatched error: %v\n", os.ErrInvalid) + } + } + + retVal, err = ts.NewTensorFromData(imgs, []int64{int64(n), int64(nrow * ncol)}) + if err != nil { + log.Fatalf("readImages - ts.NewTensorFromData error: %v\n", err) + } + + return retVal +} + +// LoadMNISTDir loads all MNIST data from a given directory to Dataset +func LoadMNISTDir(dir string) (retVal Dataset) { + const ( + trainLabels = "train-labels-idx1-ubyte" + trainImages = "train-images-idx3-ubyte" + testLabels = "t10k-labels-idx1-ubyte" + testImages = "t10k-images-idx3-ubyte" + ) + + trainLabelsFile := filepath.Join(dir, trainLabels) + trainImagesFile := filepath.Join(dir, trainImages) + testLabelsFile := filepath.Join(dir, testLabels) + testImagesFile := filepath.Join(dir, testImages) + + trainImagesTs := readImages(os.Open(trainImagesFile)) + trainLabelsTs := readLabels(os.Open(trainLabelsFile)) + testImagesTs := readImages(os.Open(testImagesFile)) + testLabelsTs := readLabels(os.Open(testLabelsFile)) + + return Dataset{ + TrainImages: trainImagesTs, + TrainLabels: trainLabelsTs, + TestImages: testImagesTs, + TestLabels: testLabelsTs, + Labels: 10, + } +}