WIP(example/mnist): auto-grad still not working
This commit is contained in:
parent
d3ad29cb53
commit
cf61333fab
|
@ -2,7 +2,6 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
|
|
||||||
"github.com/sugarme/gotch"
|
"github.com/sugarme/gotch"
|
||||||
ts "github.com/sugarme/gotch/tensor"
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
|
@ -27,21 +26,32 @@ func runLinear() {
|
||||||
fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize())
|
fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize())
|
||||||
|
|
||||||
device := (gotch.CPU).CInt()
|
device := (gotch.CPU).CInt()
|
||||||
dtype := (gotch.Double).CInt()
|
dtype := (gotch.Float).CInt()
|
||||||
|
|
||||||
ws := ts.MustZeros([]int64{ImageDim, Label}, dtype, device).MustSetRequiresGrad(true)
|
ws := ts.MustZeros([]int64{ImageDim, Label}, dtype, device).MustSetRequiresGrad(true)
|
||||||
|
|
||||||
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
|
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
|
||||||
|
|
||||||
fmt.Println(ws.MustSize())
|
|
||||||
fmt.Println(bs.MustSize())
|
|
||||||
|
|
||||||
for epoch := 0; epoch < epochs; epoch++ {
|
for epoch := 0; epoch < epochs; epoch++ {
|
||||||
}
|
logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
|
||||||
}
|
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
|
||||||
|
|
||||||
|
ws.ZeroGrad()
|
||||||
|
bs.ZeroGrad()
|
||||||
|
loss.Backward()
|
||||||
|
|
||||||
|
wsGrad := ws.MustGrad().MustMul1(ts.FloatScalar(-1.0))
|
||||||
|
bsGrad := bs.MustGrad().MustMul1(ts.FloatScalar(-1.0))
|
||||||
|
|
||||||
|
wsClone := ws.MustShallowClone()
|
||||||
|
bsClone := bs.MustShallowClone()
|
||||||
|
|
||||||
|
// wsClone.MustAdd_(wsGrad)
|
||||||
|
// bsClone.MustAdd_(bsGrad)
|
||||||
|
|
||||||
|
testLogits := ds.TestImages.MustMm(wsClone.MustAdd(wsGrad)).MustAdd(bsClone.MustAdd(bsGrad))
|
||||||
|
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
|
||||||
|
|
||||||
|
fmt.Printf("Epoch: %v - Train loss: %v - Test accuracy: %v\n", epoch, loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}), testAccuracy*100)
|
||||||
|
|
||||||
func handleError(err error) {
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -57,11 +57,26 @@ func AtgMul(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||||
C.atg_mul(ptr, self, other)
|
C.atg_mul(ptr, self, other)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// void atg_mul_(tensor *, tensor self, tensor other);
|
||||||
|
func AtgMul_(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||||
|
C.atg_mul_(ptr, self, other)
|
||||||
|
}
|
||||||
|
|
||||||
|
// void atg_mul1(tensor *, tensor self, scalar other);
|
||||||
|
func AtgMul1(ptr *Ctensor, self Ctensor, other Cscalar) {
|
||||||
|
C.atg_mul1(ptr, self, other)
|
||||||
|
}
|
||||||
|
|
||||||
// void atg_add(tensor *, tensor self, tensor other);
|
// void atg_add(tensor *, tensor self, tensor other);
|
||||||
func AtgAdd(ptr *Ctensor, self Ctensor, other Ctensor) {
|
func AtgAdd(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||||
C.atg_add(ptr, self, other)
|
C.atg_add(ptr, self, other)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// void atg_add_(tensor *, tensor self, tensor other);
|
||||||
|
func AtgAdd_(ptr *Ctensor, self Ctensor, other Ctensor) {
|
||||||
|
C.atg_add_(ptr, self, other)
|
||||||
|
}
|
||||||
|
|
||||||
// void atg_totype(tensor *, tensor self, int scalar_type);
|
// void atg_totype(tensor *, tensor self, int scalar_type);
|
||||||
func AtgTotype(ptr *Ctensor, self Ctensor, scalar_type int32) {
|
func AtgTotype(ptr *Ctensor, self Ctensor, scalar_type int32) {
|
||||||
cscalar_type := *(*C.int)(unsafe.Pointer(&scalar_type))
|
cscalar_type := *(*C.int)(unsafe.Pointer(&scalar_type))
|
||||||
|
@ -200,3 +215,21 @@ func AtgStack(ptr *Ctensor, tensorsData []Ctensor, tensorsLen int, dim int64) {
|
||||||
|
|
||||||
C.atg_stack(ptr, tensorsDataPtr, ctensorsLen, cdim)
|
C.atg_stack(ptr, tensorsDataPtr, ctensorsLen, cdim)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// void atg_mm(tensor *, tensor self, tensor mat2);
|
||||||
|
func AtgMm(ptr *Ctensor, self Ctensor, mat2 Ctensor) {
|
||||||
|
C.atg_mm(ptr, self, mat2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// void atg_view(tensor *, tensor self, int64_t *size_data, int size_len);
|
||||||
|
func AtgView(ptr *Ctensor, self Ctensor, sizeData []int64, sizeLen int) {
|
||||||
|
sizeDataPtr := (*C.int64_t)(unsafe.Pointer(&sizeData[0]))
|
||||||
|
csizeLen := *(*C.int)(unsafe.Pointer(&sizeLen))
|
||||||
|
|
||||||
|
C.atg_view(ptr, self, sizeDataPtr, csizeLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// void atg_div1(tensor *, tensor self, scalar other);
|
||||||
|
func AtgDiv1(ptr *Ctensor, self Ctensor, other Cscalar) {
|
||||||
|
C.atg_div1(ptr, self, other)
|
||||||
|
}
|
||||||
|
|
|
@ -101,25 +101,23 @@ func (ts Tensor) MustDetach_() (retVal Tensor) {
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts Tensor) Zero_() (retVal Tensor, err error) {
|
func (ts Tensor) Zero_() (err error) {
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
defer C.free(unsafe.Pointer(ptr))
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
lib.AtgZero_(ptr, ts.ctensor)
|
lib.AtgZero_(ptr, ts.ctensor)
|
||||||
|
|
||||||
if err = TorchErr(); err != nil {
|
if err = TorchErr(); err != nil {
|
||||||
return retVal, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return Tensor{ctensor: *ptr}, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts Tensor) MustZero_() (retVal Tensor) {
|
func (ts Tensor) MustZero_() {
|
||||||
retVal, err := ts.Zero_()
|
err := ts.Zero_()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return retVal
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ts Tensor) SetRequiresGrad(rb bool) (retVal Tensor, err error) {
|
func (ts Tensor) SetRequiresGrad(rb bool) (retVal Tensor, err error) {
|
||||||
|
@ -170,6 +168,46 @@ func (ts Tensor) MustMul(other Tensor) (retVal Tensor) {
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Mul1(other Scalar) (retVal Tensor, err error) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
lib.AtgMul1(ptr, ts.ctensor, other.cscalar)
|
||||||
|
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return Tensor{ctensor: *ptr}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustMul1(other Scalar) (retVal Tensor) {
|
||||||
|
retVal, err := ts.Mul1(other)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Mul_(other Tensor) (err error) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
lib.AtgMul_(ptr, ts.ctensor, other.ctensor)
|
||||||
|
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustMul_(other Tensor) {
|
||||||
|
err := ts.Mul_(other)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (ts Tensor) Add(other Tensor) (retVal Tensor, err error) {
|
func (ts Tensor) Add(other Tensor) (retVal Tensor, err error) {
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
defer C.free(unsafe.Pointer(ptr))
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
@ -191,6 +229,26 @@ func (ts Tensor) MustAdd(other Tensor) (retVal Tensor) {
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Add_(other Tensor) (err error) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
lib.AtgAdd_(ptr, ts.ctensor, other.ctensor)
|
||||||
|
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustAdd_(other Tensor) {
|
||||||
|
err := ts.Add_(other)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (ts Tensor) AddG(other Tensor) (err error) {
|
func (ts Tensor) AddG(other Tensor) (err error) {
|
||||||
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
defer C.free(unsafe.Pointer(ptr))
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
@ -468,3 +526,174 @@ func Stack(tensors []Tensor, dim int64) (retVal Tensor, err error) {
|
||||||
|
|
||||||
return retVal, nil
|
return retVal, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Mm(mat2 Tensor) (retVal Tensor, err error) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
|
||||||
|
lib.AtgMm(ptr, ts.ctensor, mat2.ctensor)
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Tensor{ctensor: *ptr}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustMm(mat2 Tensor) (retVal Tensor) {
|
||||||
|
retVal, err := ts.Mm(mat2)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) LogSoftmax(dim int64, dtype int32) (retVal Tensor, err error) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
|
||||||
|
lib.AtgLogSoftmax(ptr, ts.ctensor, dim, dtype)
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Tensor{ctensor: *ptr}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustLogSoftmax(dim int64, dtype int32) (retVal Tensor) {
|
||||||
|
retVal, err := ts.LogSoftmax(dim, dtype)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) NllLoss(target Tensor) (retVal Tensor, err error) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
|
||||||
|
weight := NewTensor()
|
||||||
|
|
||||||
|
reduction := int64(1) // Mean of loss
|
||||||
|
ignoreIndex := int64(-100)
|
||||||
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
|
||||||
|
lib.AtgNllLoss(ptr, ts.ctensor, target.ctensor, weight.ctensor, reduction, ignoreIndex)
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Tensor{ctensor: *ptr}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustNllLoss(target Tensor) (retVal Tensor) {
|
||||||
|
retVal, err := ts.NllLoss(target)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Argmax(dim int64, keepDim bool) (retVal Tensor, err error) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
|
||||||
|
var ckeepDim int = 0
|
||||||
|
if keepDim {
|
||||||
|
ckeepDim = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
lib.AtgArgmax(ptr, ts.ctensor, dim, ckeepDim)
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Tensor{ctensor: *ptr}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustArgmax(dim int64, keepDim bool) (retVal Tensor) {
|
||||||
|
retVal, err := ts.Argmax(dim, keepDim)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Mean(dtype int32) (retVal Tensor, err error) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
|
||||||
|
lib.AtgMean(ptr, ts.ctensor, dtype)
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Tensor{ctensor: *ptr}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustMean(dtype int32) (retVal Tensor) {
|
||||||
|
retVal, err := ts.Mean(dtype)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) View(sizeData []int64) (retVal Tensor, err error) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
|
||||||
|
lib.AtgView(ptr, ts.ctensor, sizeData, len(sizeData))
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Tensor{ctensor: *ptr}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustView(sizeData []int64) (retVal Tensor) {
|
||||||
|
retVal, err := ts.View(sizeData)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) Div1(other Scalar) (retVal Tensor, err error) {
|
||||||
|
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
|
||||||
|
defer C.free(unsafe.Pointer(ptr))
|
||||||
|
|
||||||
|
lib.AtgDiv1(ptr, ts.ctensor, other.cscalar)
|
||||||
|
if err = TorchErr(); err != nil {
|
||||||
|
return retVal, err
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = Tensor{ctensor: *ptr}
|
||||||
|
|
||||||
|
return retVal, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustDiv1(other Scalar) (retVal Tensor) {
|
||||||
|
retVal, err := ts.Div1(other)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
|
@ -311,6 +311,15 @@ func (ts Tensor) Eq1(other Tensor) (retVal Tensor, err error) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustEq1(other Tensor) (retVal Tensor) {
|
||||||
|
retVal, err := ts.Eq1(other)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
// Float64Value returns a float value on tensors holding a single element.
|
// Float64Value returns a float value on tensors holding a single element.
|
||||||
// An error is returned otherwise.
|
// An error is returned otherwise.
|
||||||
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
|
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
|
||||||
|
@ -330,6 +339,14 @@ func (ts Tensor) Float64Value(idx []int64) (retVal float64, err error) {
|
||||||
return retVal, err
|
return retVal, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ts Tensor) MustFloat64Value(idx []int64) (retVal float64) {
|
||||||
|
retVal, err := ts.Float64Value(idx)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
return retVal
|
||||||
|
}
|
||||||
|
|
||||||
// Int64Value returns an int value on tensors holding a single element. An error is
|
// Int64Value returns an int value on tensors holding a single element. An error is
|
||||||
// returned otherwise.
|
// returned otherwise.
|
||||||
func (ts Tensor) Int64Value(idx []int64) (retVal int64, err error) {
|
func (ts Tensor) Int64Value(idx []int64) (retVal int64, err error) {
|
||||||
|
|
205
vision/mnist.go
205
vision/mnist.go
|
@ -6,123 +6,126 @@ package vision
|
||||||
// http://yann.lecun.com/exdb/mnist/
|
// http://yann.lecun.com/exdb/mnist/
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"fmt"
|
||||||
"io"
|
|
||||||
"log"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/sugarme/gotch"
|
||||||
ts "github.com/sugarme/gotch/tensor"
|
ts "github.com/sugarme/gotch/tensor"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Image holds the pixel intensities of an image.
|
// readInt32 read 4 bytes and convert to MSB first (big endian) interger.
|
||||||
// 255 is foreground (black), 0 is background (white).
|
func readInt32(f *os.File) (retVal int, err error) {
|
||||||
type RawImage []byte
|
buf := make([]byte, 4)
|
||||||
|
n, err := f.Read(buf)
|
||||||
const numLabels = 10
|
switch {
|
||||||
const pixelRange = 255
|
case err != nil:
|
||||||
|
return 0, err
|
||||||
const (
|
case n != 4:
|
||||||
imageMagic = 0x00000803
|
err = fmt.Errorf("Invalid format: %v", f.Name())
|
||||||
labelMagic = 0x00000801
|
return 0, err
|
||||||
// Width of the input tensor / picture
|
|
||||||
Width = 28
|
|
||||||
// Height of the input tensor / picture
|
|
||||||
Height = 28
|
|
||||||
)
|
|
||||||
|
|
||||||
func readLabels(r io.Reader, e error) (retVal ts.Tensor) {
|
|
||||||
if e != nil {
|
|
||||||
log.Fatalf("readLabels errors: %v\n", e)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
// flip to big endian
|
||||||
magic int32
|
var v int = 0
|
||||||
n int32
|
for _, i := range buf {
|
||||||
err error
|
v = v*256 + int(i)
|
||||||
)
|
|
||||||
|
|
||||||
// Check magic number
|
|
||||||
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
|
|
||||||
log.Fatalf("readLabels - binary.Read error: %v\n", err)
|
|
||||||
}
|
|
||||||
if magic != labelMagic {
|
|
||||||
log.Fatal(os.ErrInvalid)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now decode number
|
return v, nil
|
||||||
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
|
}
|
||||||
log.Fatalf("readLabels - binary.Read error: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// label is a digit number range 0 - 9
|
// checkMagicNumber checks the magic number located at the first 4 bytes of
|
||||||
labels := make([]uint8, n)
|
// mnist files.
|
||||||
for i := 0; i < int(n); i++ {
|
func checkMagicNumber(f *os.File, wantNumber int) (err error) {
|
||||||
var l uint8
|
gotNumber, err := readInt32(f)
|
||||||
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
|
|
||||||
log.Fatalf("readLabels - binary.Read error: %v\n", err)
|
|
||||||
}
|
|
||||||
labels[i] = l
|
|
||||||
}
|
|
||||||
|
|
||||||
retVal, err = ts.OfSlice(labels)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("readLabels - ts.OfSlice error: %v\n", err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if gotNumber != wantNumber {
|
||||||
|
err = fmt.Errorf("incorrect magic number: got %v want %v\n", gotNumber, wantNumber)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readLabels(filename string) (retVal ts.Tensor) {
|
||||||
|
|
||||||
|
f, err := os.Open(filename)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("readLabels errors: %v\n", err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if err = checkMagicNumber(f, 2049); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
samples, err := readInt32(f)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var data []uint8 = make([]uint8, samples)
|
||||||
|
len, err := f.Read(data)
|
||||||
|
if err != nil || len != samples {
|
||||||
|
err = fmt.Errorf("invalid format %v", f.Name())
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
labelsTs, err := ts.OfSlice(data)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("create label tensor err.")
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
retVal = labelsTs.MustTotype(gotch.Int64)
|
||||||
|
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
|
||||||
func readImages(r io.Reader, e error) (retVal ts.Tensor) {
|
func readImages(filename string) (retVal ts.Tensor) {
|
||||||
if e != nil {
|
f, err := os.Open(filename)
|
||||||
log.Fatalf("readLabels errors: %v\n", e)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
magic int32
|
|
||||||
n int32
|
|
||||||
nrow int32
|
|
||||||
ncol int32
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
// Check magic number
|
|
||||||
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
|
|
||||||
log.Fatalf("readImages - binary.Read error: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if magic != imageMagic {
|
|
||||||
log.Fatalf("readImages - incorrect imageMagic: %v\n", err) // err is os.ErrInvalid
|
|
||||||
}
|
|
||||||
|
|
||||||
// Now, decode image
|
|
||||||
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
|
|
||||||
log.Fatalf("readImages - binary.Read error: %v\n", err)
|
|
||||||
}
|
|
||||||
if err = binary.Read(r, binary.BigEndian, &nrow); err != nil {
|
|
||||||
log.Fatalf("readImages - binary.Read error: %v\n", err)
|
|
||||||
}
|
|
||||||
if err = binary.Read(r, binary.BigEndian, &ncol); err != nil {
|
|
||||||
log.Fatalf("readImages - binary.Read error: %v\n", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
imgs := make([]RawImage, n)
|
|
||||||
m := int(nrow * ncol)
|
|
||||||
for i := 0; i < int(n); i++ {
|
|
||||||
imgs[i] = make(RawImage, m)
|
|
||||||
m_, err := io.ReadFull(r, imgs[i])
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("readImages - io.ReadFull error: %v\n", err)
|
|
||||||
}
|
|
||||||
if m_ != int(m) {
|
|
||||||
log.Fatalf("readImages - image matrix size mismatched error: %v\n", os.ErrInvalid)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
retVal, err = ts.NewTensorFromData(imgs, []int64{int64(n), int64(nrow * ncol)})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("readImages - ts.NewTensorFromData error: %v\n", err)
|
log.Fatalf("readImages errors: %v\n", err)
|
||||||
}
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
if err = checkMagicNumber(f, 2051); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
samples, err := readInt32(f)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := readInt32(f)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
cols, err := readInt32(f)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dataLen := samples * rows * cols
|
||||||
|
var data []uint8 = make([]uint8, dataLen)
|
||||||
|
len, err := f.Read(data)
|
||||||
|
if err != nil || len != dataLen {
|
||||||
|
err = fmt.Errorf("invalid format %v", f.Name())
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
imagesTs, err := ts.OfSlice(data)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("create images tensor err.")
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
retVal = imagesTs.MustView([]int64{int64(samples), int64(rows * cols)}).MustTotype(gotch.Float).MustDiv1(ts.FloatScalar(255.0))
|
||||||
|
|
||||||
return retVal
|
return retVal
|
||||||
}
|
}
|
||||||
|
@ -141,10 +144,10 @@ func LoadMNISTDir(dir string) (retVal Dataset) {
|
||||||
testLabelsFile := filepath.Join(dir, testLabels)
|
testLabelsFile := filepath.Join(dir, testLabels)
|
||||||
testImagesFile := filepath.Join(dir, testImages)
|
testImagesFile := filepath.Join(dir, testImages)
|
||||||
|
|
||||||
trainImagesTs := readImages(os.Open(trainImagesFile))
|
trainImagesTs := readImages(trainImagesFile)
|
||||||
trainLabelsTs := readLabels(os.Open(trainLabelsFile))
|
trainLabelsTs := readLabels(trainLabelsFile)
|
||||||
testImagesTs := readImages(os.Open(testImagesFile))
|
testImagesTs := readImages(testImagesFile)
|
||||||
testLabelsTs := readLabels(os.Open(testLabelsFile))
|
testLabelsTs := readLabels(testLabelsFile)
|
||||||
|
|
||||||
return Dataset{
|
return Dataset{
|
||||||
TrainImages: trainImagesTs,
|
TrainImages: trainImagesTs,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user