160 lines
3.4 KiB
Go
160 lines
3.4 KiB
Go
package vision
|
|
|
|
// The MNIST hand-written digit dataset.
|
|
//
|
|
// The files can be obtained from the following link:
|
|
// http://yann.lecun.com/exdb/mnist/
|
|
|
|
import (
	"fmt"
	"io"
	"log"
	"os"
	"path/filepath"

	"github.com/sugarme/gotch"
	ts "github.com/sugarme/gotch/tensor"
)
|
|
|
|
// readInt32 reads the next 4 bytes from f and decodes them as an unsigned
// big-endian (most significant byte first) integer.
//
// It returns the decoded value, or 0 and a non-nil error when the read fails.
// If the file ends after 1-3 bytes, an "Invalid format" error naming the file
// is returned; if it ends before any byte is read, the reader's own error
// (io.EOF) is returned unchanged.
func readInt32(f *os.File) (retVal int, err error) {
	buf := make([]byte, 4)

	// A single Read call is allowed to return fewer bytes than requested
	// without an error, so keep reading until the buffer is full.
	if _, err = io.ReadFull(f, buf); err != nil {
		if err == io.ErrUnexpectedEOF {
			return 0, fmt.Errorf("Invalid format: %v", f.Name())
		}
		return 0, err
	}

	// Accumulate the bytes most-significant first.
	v := 0
	for _, b := range buf {
		v = v<<8 | int(b)
	}

	return v, nil
}
|
|
|
|
// checkMagicNumber checks the magic number located at the first 4 bytes of
|
|
// mnist files.
|
|
func checkMagicNumber(f *os.File, wantNumber int) (err error) {
|
|
gotNumber, err := readInt32(f)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if gotNumber != wantNumber {
|
|
err = fmt.Errorf("incorrect magic number: got %v want %v\n", gotNumber, wantNumber)
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func readLabels(filename string) (retVal ts.Tensor) {
|
|
|
|
f, err := os.Open(filename)
|
|
if err != nil {
|
|
log.Fatalf("readLabels errors: %v\n", err)
|
|
}
|
|
defer f.Close()
|
|
|
|
if err = checkMagicNumber(f, 2049); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
samples, err := readInt32(f)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
var data []uint8 = make([]uint8, samples)
|
|
len, err := f.Read(data)
|
|
if err != nil || len != samples {
|
|
err = fmt.Errorf("invalid format %v", f.Name())
|
|
log.Fatal(err)
|
|
}
|
|
|
|
labelsTs, err := ts.OfSlice(data)
|
|
if err != nil {
|
|
err = fmt.Errorf("create label tensor err.")
|
|
log.Fatal(err)
|
|
}
|
|
|
|
retVal = labelsTs.MustTotype(gotch.Int64, true)
|
|
|
|
return retVal
|
|
}
|
|
|
|
func readImages(filename string) (retVal ts.Tensor) {
|
|
f, err := os.Open(filename)
|
|
if err != nil {
|
|
log.Fatalf("readImages errors: %v\n", err)
|
|
}
|
|
defer f.Close()
|
|
|
|
if err = checkMagicNumber(f, 2051); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
samples, err := readInt32(f)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
rows, err := readInt32(f)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
cols, err := readInt32(f)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
dataLen := samples * rows * cols
|
|
var data []uint8 = make([]uint8, dataLen)
|
|
len, err := f.Read(data)
|
|
if err != nil || len != dataLen {
|
|
err = fmt.Errorf("invalid format %v", f.Name())
|
|
log.Fatal(err)
|
|
}
|
|
|
|
imagesTs, err := ts.OfSlice(data)
|
|
if err != nil {
|
|
err = fmt.Errorf("create images tensor err.")
|
|
log.Fatal(err)
|
|
}
|
|
retVal = imagesTs.MustView([]int64{int64(samples), int64(rows * cols)}, true).MustTotype(gotch.Float, true).MustDiv1(ts.FloatScalar(255.0), true)
|
|
|
|
return retVal
|
|
}
|
|
|
|
// LoadMNISTDir loads all MNIST data from a given directory to Dataset
|
|
func LoadMNISTDir(dir string) (retVal Dataset) {
|
|
const (
|
|
trainLabels = "train-labels-idx1-ubyte"
|
|
trainImages = "train-images-idx3-ubyte"
|
|
testLabels = "t10k-labels-idx1-ubyte"
|
|
testImages = "t10k-images-idx3-ubyte"
|
|
)
|
|
|
|
trainLabelsFile := filepath.Join(dir, trainLabels)
|
|
trainImagesFile := filepath.Join(dir, trainImages)
|
|
testLabelsFile := filepath.Join(dir, testLabels)
|
|
testImagesFile := filepath.Join(dir, testImages)
|
|
|
|
trainImagesTs := readImages(trainImagesFile)
|
|
trainLabelsTs := readLabels(trainLabelsFile)
|
|
testImagesTs := readImages(testImagesFile)
|
|
testLabelsTs := readLabels(testLabelsFile)
|
|
|
|
return Dataset{
|
|
TrainImages: trainImagesTs,
|
|
TrainLabels: trainLabelsTs,
|
|
TestImages: testImagesTs,
|
|
TestLabels: testLabelsTs,
|
|
Labels: 10,
|
|
}
|
|
}
|