feat(vision/mnist): completed. WIP(example/mnist)
This commit is contained in:
parent
9efc686748
commit
d3ad29cb53
10
dtype.go
10
dtype.go
|
@ -2,6 +2,7 @@ package gotch
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"log"
|
||||||
// "log"
|
// "log"
|
||||||
"reflect"
|
"reflect"
|
||||||
)
|
)
|
||||||
|
@ -101,6 +102,15 @@ func DType2CInt(dt DType) (retVal CInt, err error) {
|
||||||
return retVal, nil
|
return retVal, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (dt DType) CInt() (retVal CInt) {
|
||||||
|
retVal, err := DType2CInt(dt)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
func CInt2DType(v CInt) (dtype DType, err error) {
|
func CInt2DType(v CInt) (dtype DType, err error) {
|
||||||
var found = false
|
var found = false
|
||||||
for key, val := range dtypeCInt {
|
for key, val := range dtypeCInt {
|
||||||
|
|
|
@ -1,89 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
)
|
|
||||||
|
|
||||||
const numLabels = 10
|
|
||||||
const pixelRange = 255
|
|
||||||
|
|
||||||
const (
|
|
||||||
imageMagic = 0x00000803
|
|
||||||
labelMagic = 0x00000801
|
|
||||||
// Width of the input tensor / picture
|
|
||||||
Width = 28
|
|
||||||
// Height of the input tensor / picture
|
|
||||||
Height = 28
|
|
||||||
)
|
|
||||||
|
|
||||||
func readLabelFile(r io.Reader, e error) (labels []Label, err error) {
|
|
||||||
if e != nil {
|
|
||||||
return nil, e
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
magic int32
|
|
||||||
n int32
|
|
||||||
)
|
|
||||||
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if magic != labelMagic {
|
|
||||||
return nil, os.ErrInvalid
|
|
||||||
}
|
|
||||||
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
labels = make([]Label, n)
|
|
||||||
for i := 0; i < int(n); i++ {
|
|
||||||
var l Label
|
|
||||||
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
labels[i] = l
|
|
||||||
}
|
|
||||||
return labels, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func readImageFile(r io.Reader, e error) (imgs []RawImage, err error) {
|
|
||||||
if e != nil {
|
|
||||||
return nil, e
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
magic int32
|
|
||||||
n int32
|
|
||||||
nrow int32
|
|
||||||
ncol int32
|
|
||||||
)
|
|
||||||
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if magic != imageMagic {
|
|
||||||
return nil, err /*os.ErrInvalid*/
|
|
||||||
}
|
|
||||||
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = binary.Read(r, binary.BigEndian, &nrow); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err = binary.Read(r, binary.BigEndian, &ncol); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
imgs = make([]RawImage, n)
|
|
||||||
m := int(nrow * ncol)
|
|
||||||
for i := 0; i < int(n); i++ {
|
|
||||||
imgs[i] = make(RawImage, m)
|
|
||||||
m_, err := io.ReadFull(r, imgs[i])
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if m_ != int(m) {
|
|
||||||
return nil, os.ErrInvalid
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return imgs, nil
|
|
||||||
}
|
|
|
@ -1,86 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
|
||||||
ts "github.com/sugarme/gotch/tensor"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
N = 60000 // number of rows in train data (training size)
|
|
||||||
inDim = 784 // input features - columns in train data (image data is 28x28pixel matrix)
|
|
||||||
outDim = 10 // output features (probabilities for digits 0-9)
|
|
||||||
batchSize = 50
|
|
||||||
batches = N / batchSize
|
|
||||||
epochs = 100
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
trainX ts.Tensor
|
|
||||||
trainY ts.Tensor
|
|
||||||
testX ts.Tensor
|
|
||||||
testY ts.Tensor
|
|
||||||
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
// load the train set
|
|
||||||
// trainX is input tensor with shape{60000, 784} (image size: 28x28 pixels)
|
|
||||||
// trainY is target tensor with shape{6000, 10}
|
|
||||||
// (represent probabilities for digit 0-9)
|
|
||||||
// E.g. [0.1 0.1 0.1 0.1 0.1 0.9 0.1 0.1 0.1 0.1]
|
|
||||||
trainX, trainY, err = Load("train", "../testdata/mnist", gotch.Double)
|
|
||||||
handleError(err)
|
|
||||||
// load our test set
|
|
||||||
testX, testY, err = Load("test", "../testdata/mnist", gotch.Double)
|
|
||||||
handleError(err)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
|
|
||||||
for epoch := 0; epoch < epochs; epoch++ {
|
|
||||||
for i := 0; i < batches; i++ {
|
|
||||||
// NOTE: `m.Reset()` does not delete data. It just moves pointer to starting point.
|
|
||||||
start := i * batchSize
|
|
||||||
end := start + batchSize
|
|
||||||
|
|
||||||
dims, err := trainX.Size()
|
|
||||||
handleError(err)
|
|
||||||
|
|
||||||
if start > ts.FlattenDim(dims) || end > ts.FlattenDim(dims) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
index := ts.NewNarrow(int64(start), int64(end))
|
|
||||||
|
|
||||||
batchX := trainX.Idx(index)
|
|
||||||
batchX.Print()
|
|
||||||
|
|
||||||
fmt.Printf("Processed epoch %v - sample %v\n", epoch, i)
|
|
||||||
|
|
||||||
panic("Stop")
|
|
||||||
|
|
||||||
// batchX, err := trainX.Slice(nn.MakeRangedSlice(start, end))
|
|
||||||
// handleError(err)
|
|
||||||
// batchY, err := trainY.Slice(nn.MakeRangedSlice(start, end))
|
|
||||||
// handleError(err)
|
|
||||||
// xi := batchX.Data().([]float64)
|
|
||||||
// yi := batchY.Data().([]float64)
|
|
||||||
//
|
|
||||||
// xiT := ts.New(ts.WithShape(batchSize, inDim), ts.WithBacking(xi))
|
|
||||||
// yiT := ts.New(ts.WithShape(batchSize, outDim), ts.WithBacking(yi))
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleError(err error) {
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,143 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
|
||||||
ts "github.com/sugarme/gotch/tensor"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Image holds the pixel intensities of an image.
|
|
||||||
// 255 is foreground (black), 0 is background (white).
|
|
||||||
type RawImage []byte
|
|
||||||
|
|
||||||
// Label is a digit label in 0 to 9
|
|
||||||
type Label uint8
|
|
||||||
|
|
||||||
// Load loads the mnist data into two tensors
|
|
||||||
//
|
|
||||||
// typ can be "train", "test"
|
|
||||||
//
|
|
||||||
// loc represents where the mnist files are held
|
|
||||||
func Load(typ, loc string, as gotch.DType) (inputs, targets ts.Tensor, err error) {
|
|
||||||
const (
|
|
||||||
trainLabel = "train-labels-idx1-ubyte"
|
|
||||||
trainData = "train-images-idx3-ubyte"
|
|
||||||
testLabel = "t10k-labels-idx1-ubyte"
|
|
||||||
testData = "t10k-images-idx3-ubyte"
|
|
||||||
)
|
|
||||||
|
|
||||||
var labelFile, dataFile string
|
|
||||||
switch typ {
|
|
||||||
case "train", "dev":
|
|
||||||
labelFile = filepath.Join(loc, trainLabel)
|
|
||||||
dataFile = filepath.Join(loc, trainData)
|
|
||||||
case "test":
|
|
||||||
labelFile = filepath.Join(loc, testLabel)
|
|
||||||
dataFile = filepath.Join(loc, testData)
|
|
||||||
}
|
|
||||||
|
|
||||||
var labelData []Label
|
|
||||||
var imageData []RawImage
|
|
||||||
|
|
||||||
if labelData, err = readLabelFile(os.Open(labelFile)); err != nil {
|
|
||||||
return inputs, targets, fmt.Errorf("Unable to read Labels: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if imageData, err = readImageFile(os.Open(dataFile)); err != nil {
|
|
||||||
return inputs, targets, fmt.Errorf("Unable to read image data: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
inputs = prepareX(imageData, as)
|
|
||||||
targets = prepareY(labelData, as)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func pixelWeight(px byte) float64 {
|
|
||||||
retVal := float64(px)/pixelRange*0.9 + 0.1
|
|
||||||
if retVal == 1.0 {
|
|
||||||
return 0.999
|
|
||||||
}
|
|
||||||
return retVal
|
|
||||||
}
|
|
||||||
|
|
||||||
func reversePixelWeight(px float64) byte {
|
|
||||||
return byte((pixelRange*px - pixelRange) / 0.9)
|
|
||||||
}
|
|
||||||
|
|
||||||
func prepareX(M []RawImage, dt gotch.DType) (retVal ts.Tensor) {
|
|
||||||
rows := len(M)
|
|
||||||
cols := len(M[0])
|
|
||||||
|
|
||||||
var backing interface{}
|
|
||||||
switch dt {
|
|
||||||
case gotch.Double:
|
|
||||||
b := make([]float64, rows*cols, rows*cols)
|
|
||||||
b = b[:0]
|
|
||||||
for i := 0; i < rows; i++ {
|
|
||||||
for j := 0; j < len(M[i]); j++ {
|
|
||||||
b = append(b, pixelWeight(M[i][j]))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
backing = b
|
|
||||||
case gotch.Float:
|
|
||||||
b := make([]float32, rows*cols, rows*cols)
|
|
||||||
b = b[:0]
|
|
||||||
for i := 0; i < rows; i++ {
|
|
||||||
for j := 0; j < len(M[i]); j++ {
|
|
||||||
b = append(b, float32(pixelWeight(M[i][j])))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
backing = b
|
|
||||||
}
|
|
||||||
|
|
||||||
retVal, err := ts.NewTensorFromData(backing, []int64{int64(rows), int64(cols)})
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Prepare X - NewTensorFromData error: %v\n", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func prepareY(N []Label, dt gotch.DType) (retVal ts.Tensor) {
|
|
||||||
rows := len(N)
|
|
||||||
cols := 10
|
|
||||||
|
|
||||||
var backing interface{}
|
|
||||||
switch dt {
|
|
||||||
case gotch.Double:
|
|
||||||
b := make([]float64, rows*cols, rows*cols)
|
|
||||||
b = b[:0]
|
|
||||||
for i := 0; i < rows; i++ {
|
|
||||||
for j := 0; j < 10; j++ {
|
|
||||||
if j == int(N[i]) {
|
|
||||||
b = append(b, 0.9)
|
|
||||||
} else {
|
|
||||||
b = append(b, 0.1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
backing = b
|
|
||||||
case gotch.Float:
|
|
||||||
b := make([]float32, rows*cols, rows*cols)
|
|
||||||
b = b[:0]
|
|
||||||
for i := 0; i < rows; i++ {
|
|
||||||
for j := 0; j < 10; j++ {
|
|
||||||
if j == int(N[i]) {
|
|
||||||
b = append(b, 0.9)
|
|
||||||
} else {
|
|
||||||
b = append(b, 0.1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
backing = b
|
|
||||||
|
|
||||||
}
|
|
||||||
retVal, err := ts.NewTensorFromData(backing, []int64{int64(rows), int64(cols)})
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("Prepare Y - NewTensorFromData error: %v\n", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
47
example/mnist/linear.go
Normal file
47
example/mnist/linear.go
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
"github.com/sugarme/gotch/vision"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ImageDim int64 = 784
|
||||||
|
Label int64 = 10
|
||||||
|
MnistDir string = "../../data/mnist"
|
||||||
|
|
||||||
|
epochs = 200
|
||||||
|
)
|
||||||
|
|
||||||
|
func runLinear() {
|
||||||
|
var ds vision.Dataset
|
||||||
|
ds = vision.LoadMNISTDir(MnistDir)
|
||||||
|
|
||||||
|
fmt.Printf("Train image size: %v\n", ds.TrainImages.MustSize())
|
||||||
|
fmt.Printf("Train label size: %v\n", ds.TrainLabels.MustSize())
|
||||||
|
fmt.Printf("Test image size: %v\n", ds.TestImages.MustSize())
|
||||||
|
fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize())
|
||||||
|
|
||||||
|
device := (gotch.CPU).CInt()
|
||||||
|
dtype := (gotch.Double).CInt()
|
||||||
|
|
||||||
|
ws := ts.MustZeros([]int64{ImageDim, Label}, dtype, device).MustSetRequiresGrad(true)
|
||||||
|
|
||||||
|
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
|
||||||
|
|
||||||
|
fmt.Println(ws.MustSize())
|
||||||
|
fmt.Println(bs.MustSize())
|
||||||
|
|
||||||
|
for epoch := 0; epoch < epochs; epoch++ {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleError(err error) {
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
7
example/mnist/main.go
Normal file
7
example/mnist/main.go
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import ()
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
runLinear()
|
||||||
|
}
|
|
@ -304,7 +304,7 @@ func (ts Tensor) IndexSelect(dim int64, index Tensor) (retVal Tensor, err error)
|
||||||
return retVal, nil
|
return retVal, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts Tensor) Zeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
|
func Zeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
defer C.free(unsafe.Pointer(ptr))
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
|
||||||
|
@ -318,7 +318,15 @@ func (ts Tensor) Zeros(size []int64, optionsKind, optionsDevice int32) (retVal T
|
||||||
return retVal, nil
|
return retVal, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts Tensor) Ones(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
|
func MustZeros(size []int64, optionsKind, optionsDevice int32) (retVal Tensor) {
|
||||||
|
retVal, err := Zeros(size, optionsKind, optionsDevice)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func Ones(size []int64, optionsKind, optionsDevice int32) (retVal Tensor, err error) {
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
defer C.free(unsafe.Pointer(ptr))
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
|
||||||
|
@ -332,6 +340,14 @@ func (ts Tensor) Ones(size []int64, optionsKind, optionsDevice int32) (retVal Te
|
||||||
return retVal, nil
|
return retVal, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MustOnes(size []int64, optionsKind, optionsDevice int32) (retVal Tensor) {
|
||||||
|
retVal, err := Ones(size, optionsKind, optionsDevice)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
// NOTE: `_` denotes "in-place".
|
// NOTE: `_` denotes "in-place".
|
||||||
func (ts Tensor) Uniform_(from float64, to float64) {
|
func (ts Tensor) Uniform_(from float64, to float64) {
|
||||||
var err error
|
var err error
|
||||||
|
|
17
vision/dataset.go
Normal file
17
vision/dataset.go
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
package vision
|
||||||
|
|
||||||
|
// A simple dataset structure shared by various computer vision datasets.
|
||||||
|
|
||||||
|
import (
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dataset struct {
|
||||||
|
TrainImages ts.Tensor
|
||||||
|
TrainLabels ts.Tensor
|
||||||
|
TestImages ts.Tensor
|
||||||
|
TestLabels ts.Tensor
|
||||||
|
Labels int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: implement methods
|
156
vision/mnist.go
Normal file
156
vision/mnist.go
Normal file
|
@ -0,0 +1,156 @@
|
||||||
|
package vision
|
||||||
|
|
||||||
|
// The MNIST hand-written digit dataset.
|
||||||
|
//
|
||||||
|
// The files can be obtained from the following link:
|
||||||
|
// http://yann.lecun.com/exdb/mnist/
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Image holds the pixel intensities of an image.
|
||||||
|
// 255 is foreground (black), 0 is background (white).
|
||||||
|
type RawImage []byte
|
||||||
|
|
||||||
|
const numLabels = 10
|
||||||
|
const pixelRange = 255
|
||||||
|
|
||||||
|
const (
|
||||||
|
imageMagic = 0x00000803
|
||||||
|
labelMagic = 0x00000801
|
||||||
|
// Width of the input tensor / picture
|
||||||
|
Width = 28
|
||||||
|
// Height of the input tensor / picture
|
||||||
|
Height = 28
|
||||||
|
)
|
||||||
|
|
||||||
|
func readLabels(r io.Reader, e error) (retVal ts.Tensor) {
|
||||||
|
if e != nil {
|
||||||
|
log.Fatalf("readLabels errors: %v\n", e)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
magic int32
|
||||||
|
n int32
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check magic number
|
||||||
|
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
|
||||||
|
log.Fatalf("readLabels - binary.Read error: %v\n", err)
|
||||||
|
}
|
||||||
|
if magic != labelMagic {
|
||||||
|
log.Fatal(os.ErrInvalid)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now decode number
|
||||||
|
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
|
||||||
|
log.Fatalf("readLabels - binary.Read error: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// label is a digit number range 0 - 9
|
||||||
|
labels := make([]uint8, n)
|
||||||
|
for i := 0; i < int(n); i++ {
|
||||||
|
var l uint8
|
||||||
|
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
|
||||||
|
log.Fatalf("readLabels - binary.Read error: %v\n", err)
|
||||||
|
}
|
||||||
|
labels[i] = l
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal, err = ts.OfSlice(labels)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("readLabels - ts.OfSlice error: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func readImages(r io.Reader, e error) (retVal ts.Tensor) {
|
||||||
|
if e != nil {
|
||||||
|
log.Fatalf("readLabels errors: %v\n", e)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
magic int32
|
||||||
|
n int32
|
||||||
|
nrow int32
|
||||||
|
ncol int32
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
// Check magic number
|
||||||
|
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
|
||||||
|
log.Fatalf("readImages - binary.Read error: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if magic != imageMagic {
|
||||||
|
log.Fatalf("readImages - incorrect imageMagic: %v\n", err) // err is os.ErrInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now, decode image
|
||||||
|
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
|
||||||
|
log.Fatalf("readImages - binary.Read error: %v\n", err)
|
||||||
|
}
|
||||||
|
if err = binary.Read(r, binary.BigEndian, &nrow); err != nil {
|
||||||
|
log.Fatalf("readImages - binary.Read error: %v\n", err)
|
||||||
|
}
|
||||||
|
if err = binary.Read(r, binary.BigEndian, &ncol); err != nil {
|
||||||
|
log.Fatalf("readImages - binary.Read error: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
imgs := make([]RawImage, n)
|
||||||
|
m := int(nrow * ncol)
|
||||||
|
for i := 0; i < int(n); i++ {
|
||||||
|
imgs[i] = make(RawImage, m)
|
||||||
|
m_, err := io.ReadFull(r, imgs[i])
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("readImages - io.ReadFull error: %v\n", err)
|
||||||
|
}
|
||||||
|
if m_ != int(m) {
|
||||||
|
log.Fatalf("readImages - image matrix size mismatched error: %v\n", os.ErrInvalid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal, err = ts.NewTensorFromData(imgs, []int64{int64(n), int64(nrow * ncol)})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("readImages - ts.NewTensorFromData error: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadMNISTDir loads all MNIST data from a given directory to Dataset
|
||||||
|
func LoadMNISTDir(dir string) (retVal Dataset) {
|
||||||
|
const (
|
||||||
|
trainLabels = "train-labels-idx1-ubyte"
|
||||||
|
trainImages = "train-images-idx3-ubyte"
|
||||||
|
testLabels = "t10k-labels-idx1-ubyte"
|
||||||
|
testImages = "t10k-images-idx3-ubyte"
|
||||||
|
)
|
||||||
|
|
||||||
|
trainLabelsFile := filepath.Join(dir, trainLabels)
|
||||||
|
trainImagesFile := filepath.Join(dir, trainImages)
|
||||||
|
testLabelsFile := filepath.Join(dir, testLabels)
|
||||||
|
testImagesFile := filepath.Join(dir, testImages)
|
||||||
|
|
||||||
|
trainImagesTs := readImages(os.Open(trainImagesFile))
|
||||||
|
trainLabelsTs := readLabels(os.Open(trainLabelsFile))
|
||||||
|
testImagesTs := readImages(os.Open(testImagesFile))
|
||||||
|
testLabelsTs := readLabels(os.Open(testLabelsFile))
|
||||||
|
|
||||||
|
return Dataset{
|
||||||
|
TrainImages: trainImagesTs,
|
||||||
|
TrainLabels: trainLabelsTs,
|
||||||
|
TestImages: testImagesTs,
|
||||||
|
TestLabels: testLabelsTs,
|
||||||
|
Labels: 10,
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user