WIP(example/mnist): auto-grad still not working
This commit is contained in:
parent
d3ad29cb53
commit
cf61333fab
|
@ -2,7 +2,6 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
|
@ -27,21 +26,32 @@ func runLinear() {
|
|||
fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize())
|
||||
|
||||
device := (gotch.CPU).CInt()
|
||||
dtype := (gotch.Double).CInt()
|
||||
dtype := (gotch.Float).CInt()
|
||||
|
||||
ws := ts.MustZeros([]int64{ImageDim, Label}, dtype, device).MustSetRequiresGrad(true)
|
||||
|
||||
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
|
||||
|
||||
fmt.Println(ws.MustSize())
|
||||
fmt.Println(bs.MustSize())
|
||||
|
||||
for epoch := 0; epoch < epochs; epoch++ {
|
||||
}
|
||||
}
|
||||
logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
|
||||
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
|
||||
|
||||
ws.ZeroGrad()
|
||||
bs.ZeroGrad()
|
||||
loss.Backward()
|
||||
|
||||
wsGrad := ws.MustGrad().MustMul1(ts.FloatScalar(-1.0))
|
||||
bsGrad := bs.MustGrad().MustMul1(ts.FloatScalar(-1.0))
|
||||
|
||||
wsClone := ws.MustShallowClone()
|
||||
bsClone := bs.MustShallowClone()
|
||||
|
||||
// wsClone.MustAdd_(wsGrad)
|
||||
// bsClone.MustAdd_(bsGrad)
|
||||
|
||||
testLogits := ds.TestImages.MustMm(wsClone.MustAdd(wsGrad)).MustAdd(bsClone.MustAdd(bsGrad))
|
||||
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||
|
||||
fmt.Printf("Epoch: %v - Train loss: %v - Test accuracy: %v\n", epoch, loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}), testAccuracy*100)
|
||||
|
||||
func handleError(err error) {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -57,11 +57,26 @@ func AtgMul(ptr *Ctensor, self Ctensor, other Ctensor) {
|
|||
C.atg_mul(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_mul_(tensor *, tensor self, tensor other);
|
||||
func AtgMul_(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||
C.atg_mul_(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_mul1(tensor *, tensor self, scalar other);
|
||||
func AtgMul1(ptr *Ctensor, self Ctensor, other Cscalar) {
|
||||
C.atg_mul1(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_add(tensor *, tensor self, tensor other);
|
||||
func AtgAdd(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||
C.atg_add(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_add_(tensor *, tensor self, tensor other);
|
||||
func AtgAdd_(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||
C.atg_add_(ptr, self, other)
|
||||
}
|
||||
|
||||
// void atg_totype(tensor *, tensor self, int scalar_type);
|
||||
func AtgTotype(ptr *Ctensor, self Ctensor, scalar_type int32) {
|
||||
cscalar_type := *(*C.int)(unsafe.Pointer(&scalar_type))
|
||||
|
@ -200,3 +215,21 @@ func AtgStack(ptr *Ctensor, tensorsData []Ctensor, tensorsLen int, dim int64) {
|
|||
|
||||
C.atg_stack(ptr, tensorsDataPtr, ctensorsLen, cdim)
|
||||
}
|
||||
|
||||
// void atg_mm(tensor *, tensor self, tensor mat2);
|
||||
func AtgMm(ptr *Ctensor, self Ctensor, mat2 Ctensor) {
|
||||
C.atg_mm(ptr, self, mat2)
|
||||
}
|
||||
|
||||
// void atg_view(tensor *, tensor self, int64_t *size_data, int size_len);
|
||||
func AtgView(ptr *Ctensor, self Ctensor, sizeData []int64, sizeLen int) {
|
||||
sizeDataPtr := (*C.int64_t)(unsafe.Pointer(&sizeData[0]))
|
||||
csizeLen := *(*C.int)(unsafe.Pointer(&sizeLen))
|
||||
|
||||
C.atg_view(ptr, self, sizeDataPtr, csizeLen)
|
||||
}
|
||||
|
||||
// void atg_div1(tensor *, tensor self, scalar other);
|
||||
func AtgDiv1(ptr *Ctensor, self Ctensor, other Cscalar) {
|
||||
C.atg_div1(ptr, self, other)
|
||||
}
|
||||
|
|
|
@ -101,25 +101,23 @@ func (ts Tensor) MustDetach_() (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Zero_() (retVal Tensor, err error) {
|
||||
func (ts Tensor) Zero_() (err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgZero_(ptr, ts.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
return err
|
||||
}
|
||||
|
||||
return Tensor{ctensor: *ptr}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustZero_() (retVal Tensor) {
|
||||
retVal, err := ts.Zero_()
|
||||
func (ts Tensor) MustZero_() {
|
||||
err := ts.Zero_()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) SetRequiresGrad(rb bool) (retVal Tensor, err error) {
|
||||
|
@ -170,6 +168,46 @@ func (ts Tensor) MustMul(other Tensor) (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Mul1(other Scalar) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgMul1(ptr, ts.ctensor, other.cscalar)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
return Tensor{ctensor: *ptr}, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustMul1(other Scalar) (retVal Tensor) {
|
||||
retVal, err := ts.Mul1(other)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Mul_(other Tensor) (err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgMul_(ptr, ts.ctensor, other.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustMul_(other Tensor) {
|
||||
err := ts.Mul_(other)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (ts Tensor) Add(other Tensor) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
@ -191,6 +229,26 @@ func (ts Tensor) MustAdd(other Tensor) (retVal Tensor) {
|
|||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Add_(other Tensor) (err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
lib.AtgAdd_(ptr, ts.ctensor, other.ctensor)
|
||||
|
||||
if err = TorchErr(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (ts Tensor) MustAdd_(other Tensor) {
|
||||
err := ts.Add_(other)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (ts Tensor) AddG(other Tensor) (err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
@ -468,3 +526,174 @@ func Stack(tensors []Tensor, dim int64) (retVal Tensor, err error) {
|
|||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) Mm(mat2 Tensor) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgMm(ptr, ts.ctensor, mat2.ctensor)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustMm(mat2 Tensor) (retVal Tensor) {
|
||||
retVal, err := ts.Mm(mat2)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) LogSoftmax(dim int64, dtype int32) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgLogSoftmax(ptr, ts.ctensor, dim, dtype)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustLogSoftmax(dim int64, dtype int32) (retVal Tensor) {
|
||||
retVal, err := ts.LogSoftmax(dim, dtype)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) NllLoss(target Tensor) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
|
||||
weight := NewTensor()
|
||||
|
||||
reduction := int64(1) // Mean of loss
|
||||
ignoreIndex := int64(-100)
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgNllLoss(ptr, ts.ctensor, target.ctensor, weight.ctensor, reduction, ignoreIndex)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustNllLoss(target Tensor) (retVal Tensor) {
|
||||
retVal, err := ts.NllLoss(target)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Argmax(dim int64, keepDim bool) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
var ckeepDim int = 0
|
||||
if keepDim {
|
||||
ckeepDim = 1
|
||||
}
|
||||
|
||||
lib.AtgArgmax(ptr, ts.ctensor, dim, ckeepDim)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustArgmax(dim int64, keepDim bool) (retVal Tensor) {
|
||||
retVal, err := ts.Argmax(dim, keepDim)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Mean(dtype int32) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgMean(ptr, ts.ctensor, dtype)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustMean(dtype int32) (retVal Tensor) {
|
||||
retVal, err := ts.Mean(dtype)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) View(sizeData []int64) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgView(ptr, ts.ctensor, sizeData, len(sizeData))
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustView(sizeData []int64) (retVal Tensor) {
|
||||
retVal, err := ts.View(sizeData)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func (ts Tensor) Div1(other Scalar) (retVal Tensor, err error) {
|
||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||
defer C.free(unsafe.Pointer(ptr))
|
||||
|
||||
lib.AtgDiv1(ptr, ts.ctensor, other.cscalar)
|
||||
if err = TorchErr(); err != nil {
|
||||
return retVal, err
|
||||
}
|
||||
|
||||
retVal = Tensor{ctensor: *ptr}
|
||||
|
||||
return retVal, nil
|
||||
}
|
||||
|
||||
func (ts Tensor) MustDiv1(other Scalar) (retVal Tensor) {
|
||||
retVal, err := ts.Div1(other)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
|
|
@ -311,6 +311,15 @@ func (ts Tensor) Eq1(other Tensor) (retVal Tensor, err error) {
|
|||
|
||||
}
|
||||
|
||||
func (ts Tensor) MustEq1(other Tensor) (retVal Tensor) {
|
||||
retVal, err := ts.Eq1(other)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
// Float64Value returns a float value on tensors holding a single element.
|
||||
// An error is returned otherwise.
|
||||
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
|
||||
|
@ -330,6 +339,14 @@ func (ts Tensor) Float64Value(idx []int64) (retVal float64, err error) {
|
|||
return retVal, err
|
||||
}
|
||||
|
||||
func (ts Tensor) MustFloat64Value(idx []int64) (retVal float64) {
|
||||
retVal, err := ts.Float64Value(idx)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
return retVal
|
||||
}
|
||||
|
||||
// Int64Value returns an int value on tensors holding a single element. An error is
|
||||
// returned otherwise.
|
||||
func (ts Tensor) Int64Value(idx []int64) (retVal int64, err error) {
|
||||
|
|
205
vision/mnist.go
205
vision/mnist.go
|
@ -6,123 +6,126 @@ package vision
|
|||
// http://yann.lecun.com/exdb/mnist/
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/sugarme/gotch"
|
||||
ts "github.com/sugarme/gotch/tensor"
|
||||
)
|
||||
|
||||
// Image holds the pixel intensities of an image.
|
||||
// 255 is foreground (black), 0 is background (white).
|
||||
type RawImage []byte
|
||||
|
||||
const numLabels = 10
|
||||
const pixelRange = 255
|
||||
|
||||
const (
|
||||
imageMagic = 0x00000803
|
||||
labelMagic = 0x00000801
|
||||
// Width of the input tensor / picture
|
||||
Width = 28
|
||||
// Height of the input tensor / picture
|
||||
Height = 28
|
||||
)
|
||||
|
||||
func readLabels(r io.Reader, e error) (retVal ts.Tensor) {
|
||||
if e != nil {
|
||||
log.Fatalf("readLabels errors: %v\n", e)
|
||||
// readInt32 read 4 bytes and convert to MSB first (big endian) interger.
|
||||
func readInt32(f *os.File) (retVal int, err error) {
|
||||
buf := make([]byte, 4)
|
||||
n, err := f.Read(buf)
|
||||
switch {
|
||||
case err != nil:
|
||||
return 0, err
|
||||
case n != 4:
|
||||
err = fmt.Errorf("Invalid format: %v", f.Name())
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var (
|
||||
magic int32
|
||||
n int32
|
||||
err error
|
||||
)
|
||||
|
||||
// Check magic number
|
||||
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
|
||||
log.Fatalf("readLabels - binary.Read error: %v\n", err)
|
||||
}
|
||||
if magic != labelMagic {
|
||||
log.Fatal(os.ErrInvalid)
|
||||
// flip to big endian
|
||||
var v int = 0
|
||||
for _, i := range buf {
|
||||
v = v*256 + int(i)
|
||||
}
|
||||
|
||||
// Now decode number
|
||||
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
|
||||
log.Fatalf("readLabels - binary.Read error: %v\n", err)
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// label is a digit number range 0 - 9
|
||||
labels := make([]uint8, n)
|
||||
for i := 0; i < int(n); i++ {
|
||||
var l uint8
|
||||
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
|
||||
log.Fatalf("readLabels - binary.Read error: %v\n", err)
|
||||
}
|
||||
labels[i] = l
|
||||
}
|
||||
|
||||
retVal, err = ts.OfSlice(labels)
|
||||
// checkMagicNumber checks the magic number located at the first 4 bytes of
|
||||
// mnist files.
|
||||
func checkMagicNumber(f *os.File, wantNumber int) (err error) {
|
||||
gotNumber, err := readInt32(f)
|
||||
if err != nil {
|
||||
log.Fatalf("readLabels - ts.OfSlice error: %v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if gotNumber != wantNumber {
|
||||
err = fmt.Errorf("incorrect magic number: got %v want %v\n", gotNumber, wantNumber)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func readLabels(filename string) (retVal ts.Tensor) {
|
||||
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
log.Fatalf("readLabels errors: %v\n", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err = checkMagicNumber(f, 2049); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
samples, err := readInt32(f)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
var data []uint8 = make([]uint8, samples)
|
||||
len, err := f.Read(data)
|
||||
if err != nil || len != samples {
|
||||
err = fmt.Errorf("invalid format %v", f.Name())
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
labelsTs, err := ts.OfSlice(data)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("create label tensor err.")
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
retVal = labelsTs.MustTotype(gotch.Int64)
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
||||
func readImages(r io.Reader, e error) (retVal ts.Tensor) {
|
||||
if e != nil {
|
||||
log.Fatalf("readLabels errors: %v\n", e)
|
||||
}
|
||||
|
||||
var (
|
||||
magic int32
|
||||
n int32
|
||||
nrow int32
|
||||
ncol int32
|
||||
err error
|
||||
)
|
||||
|
||||
// Check magic number
|
||||
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
|
||||
log.Fatalf("readImages - binary.Read error: %v\n", err)
|
||||
}
|
||||
|
||||
if magic != imageMagic {
|
||||
log.Fatalf("readImages - incorrect imageMagic: %v\n", err) // err is os.ErrInvalid
|
||||
}
|
||||
|
||||
// Now, decode image
|
||||
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
|
||||
log.Fatalf("readImages - binary.Read error: %v\n", err)
|
||||
}
|
||||
if err = binary.Read(r, binary.BigEndian, &nrow); err != nil {
|
||||
log.Fatalf("readImages - binary.Read error: %v\n", err)
|
||||
}
|
||||
if err = binary.Read(r, binary.BigEndian, &ncol); err != nil {
|
||||
log.Fatalf("readImages - binary.Read error: %v\n", err)
|
||||
}
|
||||
|
||||
imgs := make([]RawImage, n)
|
||||
m := int(nrow * ncol)
|
||||
for i := 0; i < int(n); i++ {
|
||||
imgs[i] = make(RawImage, m)
|
||||
m_, err := io.ReadFull(r, imgs[i])
|
||||
if err != nil {
|
||||
log.Fatalf("readImages - io.ReadFull error: %v\n", err)
|
||||
}
|
||||
if m_ != int(m) {
|
||||
log.Fatalf("readImages - image matrix size mismatched error: %v\n", os.ErrInvalid)
|
||||
}
|
||||
}
|
||||
|
||||
retVal, err = ts.NewTensorFromData(imgs, []int64{int64(n), int64(nrow * ncol)})
|
||||
func readImages(filename string) (retVal ts.Tensor) {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
log.Fatalf("readImages - ts.NewTensorFromData error: %v\n", err)
|
||||
log.Fatalf("readImages errors: %v\n", err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err = checkMagicNumber(f, 2051); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
samples, err := readInt32(f)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := readInt32(f)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
cols, err := readInt32(f)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
dataLen := samples * rows * cols
|
||||
var data []uint8 = make([]uint8, dataLen)
|
||||
len, err := f.Read(data)
|
||||
if err != nil || len != dataLen {
|
||||
err = fmt.Errorf("invalid format %v", f.Name())
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
imagesTs, err := ts.OfSlice(data)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("create images tensor err.")
|
||||
log.Fatal(err)
|
||||
}
|
||||
retVal = imagesTs.MustView([]int64{int64(samples), int64(rows * cols)}).MustTotype(gotch.Float).MustDiv1(ts.FloatScalar(255.0))
|
||||
|
||||
return retVal
|
||||
}
|
||||
|
@ -141,10 +144,10 @@ func LoadMNISTDir(dir string) (retVal Dataset) {
|
|||
testLabelsFile := filepath.Join(dir, testLabels)
|
||||
testImagesFile := filepath.Join(dir, testImages)
|
||||
|
||||
trainImagesTs := readImages(os.Open(trainImagesFile))
|
||||
trainLabelsTs := readLabels(os.Open(trainLabelsFile))
|
||||
testImagesTs := readImages(os.Open(testImagesFile))
|
||||
testLabelsTs := readLabels(os.Open(testLabelsFile))
|
||||
trainImagesTs := readImages(trainImagesFile)
|
||||
trainLabelsTs := readLabels(trainLabelsFile)
|
||||
testImagesTs := readImages(testImagesFile)
|
||||
testLabelsTs := readLabels(testLabelsFile)
|
||||
|
||||
return Dataset{
|
||||
TrainImages: trainImagesTs,
|
||||
|
|
Loading…
Reference in New Issue
Block a user