WIP(example/mnist): auto-grad still not working

This commit is contained in:
sugarme 2020-06-16 01:59:41 +10:00
parent d3ad29cb53
commit cf61333fab
5 changed files with 411 additions and 119 deletions

View File

@ -2,7 +2,6 @@ package main
import (
"fmt"
"log"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
@ -27,21 +26,32 @@ func runLinear() {
fmt.Printf("Test label size: %v\n", ds.TestLabels.MustSize())
device := (gotch.CPU).CInt()
dtype := (gotch.Double).CInt()
dtype := (gotch.Float).CInt()
ws := ts.MustZeros([]int64{ImageDim, Label}, dtype, device).MustSetRequiresGrad(true)
bs := ts.MustZeros([]int64{Label}, dtype, device).MustSetRequiresGrad(true)
fmt.Println(ws.MustSize())
fmt.Println(bs.MustSize())
for epoch := 0; epoch < epochs; epoch++ {
}
}
logits := ds.TrainImages.MustMm(ws).MustAdd(bs)
loss := logits.MustLogSoftmax(-1, dtype).MustNllLoss(ds.TrainLabels)
ws.ZeroGrad()
bs.ZeroGrad()
loss.Backward()
wsGrad := ws.MustGrad().MustMul1(ts.FloatScalar(-1.0))
bsGrad := bs.MustGrad().MustMul1(ts.FloatScalar(-1.0))
wsClone := ws.MustShallowClone()
bsClone := bs.MustShallowClone()
// wsClone.MustAdd_(wsGrad)
// bsClone.MustAdd_(bsGrad)
testLogits := ds.TestImages.MustMm(wsClone.MustAdd(wsGrad)).MustAdd(bsClone.MustAdd(bsGrad))
testAccuracy := testLogits.MustArgmax(-1, false).MustEq1(ds.TestLabels).MustTotype(gotch.Float).MustMean(gotch.Float.CInt()).MustView([]int64{-1}).MustFloat64Value([]int64{0})
fmt.Printf("Epoch: %v - Train loss: %v - Test accuracy: %v\n", epoch, loss.MustView([]int64{-1}).MustFloat64Value([]int64{0}), testAccuracy*100)
func handleError(err error) {
if err != nil {
log.Fatal(err)
}
}

View File

@ -57,11 +57,26 @@ func AtgMul(ptr *Ctensor, self Ctensor, other Ctensor) {
C.atg_mul(ptr, self, other)
}
// void atg_mul_(tensor *, tensor self, tensor other);
func AtgMul_(ptr *Ctensor, self Ctensor, other Ctensor) {
C.atg_mul_(ptr, self, other)
}
// void atg_mul1(tensor *, tensor self, scalar other);
func AtgMul1(ptr *Ctensor, self Ctensor, other Cscalar) {
C.atg_mul1(ptr, self, other)
}
// void atg_add(tensor *, tensor self, tensor other);
func AtgAdd(ptr *Ctensor, self Ctensor, other Ctensor) {
C.atg_add(ptr, self, other)
}
// void atg_add_(tensor *, tensor self, tensor other);
func AtgAdd_(ptr *Ctensor, self Ctensor, other Ctensor) {
C.atg_add_(ptr, self, other)
}
// void atg_totype(tensor *, tensor self, int scalar_type);
func AtgTotype(ptr *Ctensor, self Ctensor, scalar_type int32) {
cscalar_type := *(*C.int)(unsafe.Pointer(&scalar_type))
@ -200,3 +215,21 @@ func AtgStack(ptr *Ctensor, tensorsData []Ctensor, tensorsLen int, dim int64) {
C.atg_stack(ptr, tensorsDataPtr, ctensorsLen, cdim)
}
// void atg_mm(tensor *, tensor self, tensor mat2);
func AtgMm(ptr *Ctensor, self Ctensor, mat2 Ctensor) {
C.atg_mm(ptr, self, mat2)
}
// void atg_view(tensor *, tensor self, int64_t *size_data, int size_len);
func AtgView(ptr *Ctensor, self Ctensor, sizeData []int64, sizeLen int) {
sizeDataPtr := (*C.int64_t)(unsafe.Pointer(&sizeData[0]))
csizeLen := *(*C.int)(unsafe.Pointer(&sizeLen))
C.atg_view(ptr, self, sizeDataPtr, csizeLen)
}
// void atg_div1(tensor *, tensor self, scalar other);
func AtgDiv1(ptr *Ctensor, self Ctensor, other Cscalar) {
C.atg_div1(ptr, self, other)
}

View File

@ -101,25 +101,23 @@ func (ts Tensor) MustDetach_() (retVal Tensor) {
return retVal
}
func (ts Tensor) Zero_() (retVal Tensor, err error) {
func (ts Tensor) Zero_() (err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgZero_(ptr, ts.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
return err
}
return Tensor{ctensor: *ptr}, nil
return nil
}
func (ts Tensor) MustZero_() (retVal Tensor) {
retVal, err := ts.Zero_()
func (ts Tensor) MustZero_() {
err := ts.Zero_()
if err != nil {
log.Fatal(err)
}
return retVal
}
func (ts Tensor) SetRequiresGrad(rb bool) (retVal Tensor, err error) {
@ -170,6 +168,46 @@ func (ts Tensor) MustMul(other Tensor) (retVal Tensor) {
return retVal
}
func (ts Tensor) Mul1(other Scalar) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgMul1(ptr, ts.ctensor, other.cscalar)
if err = TorchErr(); err != nil {
return retVal, err
}
return Tensor{ctensor: *ptr}, nil
}
func (ts Tensor) MustMul1(other Scalar) (retVal Tensor) {
retVal, err := ts.Mul1(other)
if err != nil {
log.Fatal(err)
}
return retVal
}
func (ts Tensor) Mul_(other Tensor) (err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgMul_(ptr, ts.ctensor, other.ctensor)
if err = TorchErr(); err != nil {
return err
}
return nil
}
func (ts Tensor) MustMul_(other Tensor) {
err := ts.Mul_(other)
if err != nil {
log.Fatal(err)
}
}
func (ts Tensor) Add(other Tensor) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
@ -191,6 +229,26 @@ func (ts Tensor) MustAdd(other Tensor) (retVal Tensor) {
return retVal
}
func (ts Tensor) Add_(other Tensor) (err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgAdd_(ptr, ts.ctensor, other.ctensor)
if err = TorchErr(); err != nil {
return err
}
return nil
}
func (ts Tensor) MustAdd_(other Tensor) {
err := ts.Add_(other)
if err != nil {
log.Fatal(err)
}
}
func (ts Tensor) AddG(other Tensor) (err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
@ -468,3 +526,174 @@ func Stack(tensors []Tensor, dim int64) (retVal Tensor, err error) {
return retVal, nil
}
func (ts Tensor) Mm(mat2 Tensor) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgMm(ptr, ts.ctensor, mat2.ctensor)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func (ts Tensor) MustMm(mat2 Tensor) (retVal Tensor) {
retVal, err := ts.Mm(mat2)
if err != nil {
log.Fatal(err)
}
return retVal
}
func (ts Tensor) LogSoftmax(dim int64, dtype int32) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgLogSoftmax(ptr, ts.ctensor, dim, dtype)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func (ts Tensor) MustLogSoftmax(dim int64, dtype int32) (retVal Tensor) {
retVal, err := ts.LogSoftmax(dim, dtype)
if err != nil {
log.Fatal(err)
}
return retVal
}
func (ts Tensor) NllLoss(target Tensor) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
weight := NewTensor()
reduction := int64(1) // Mean of loss
ignoreIndex := int64(-100)
defer C.free(unsafe.Pointer(ptr))
lib.AtgNllLoss(ptr, ts.ctensor, target.ctensor, weight.ctensor, reduction, ignoreIndex)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func (ts Tensor) MustNllLoss(target Tensor) (retVal Tensor) {
retVal, err := ts.NllLoss(target)
if err != nil {
log.Fatal(err)
}
return retVal
}
func (ts Tensor) Argmax(dim int64, keepDim bool) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
var ckeepDim int = 0
if keepDim {
ckeepDim = 1
}
lib.AtgArgmax(ptr, ts.ctensor, dim, ckeepDim)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func (ts Tensor) MustArgmax(dim int64, keepDim bool) (retVal Tensor) {
retVal, err := ts.Argmax(dim, keepDim)
if err != nil {
log.Fatal(err)
}
return retVal
}
func (ts Tensor) Mean(dtype int32) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgMean(ptr, ts.ctensor, dtype)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func (ts Tensor) MustMean(dtype int32) (retVal Tensor) {
retVal, err := ts.Mean(dtype)
if err != nil {
log.Fatal(err)
}
return retVal
}
func (ts Tensor) View(sizeData []int64) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgView(ptr, ts.ctensor, sizeData, len(sizeData))
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func (ts Tensor) MustView(sizeData []int64) (retVal Tensor) {
retVal, err := ts.View(sizeData)
if err != nil {
log.Fatal(err)
}
return retVal
}
func (ts Tensor) Div1(other Scalar) (retVal Tensor, err error) {
ptr := (*lib.Ctensor)(unsafe.Pointer(C.malloc(0)))
defer C.free(unsafe.Pointer(ptr))
lib.AtgDiv1(ptr, ts.ctensor, other.cscalar)
if err = TorchErr(); err != nil {
return retVal, err
}
retVal = Tensor{ctensor: *ptr}
return retVal, nil
}
func (ts Tensor) MustDiv1(other Scalar) (retVal Tensor) {
retVal, err := ts.Div1(other)
if err != nil {
log.Fatal(err)
}
return retVal
}

View File

@ -311,6 +311,15 @@ func (ts Tensor) Eq1(other Tensor) (retVal Tensor, err error) {
}
func (ts Tensor) MustEq1(other Tensor) (retVal Tensor) {
retVal, err := ts.Eq1(other)
if err != nil {
log.Fatal(err)
}
return retVal
}
// Float64Value returns a float value on tensors holding a single element.
// An error is returned otherwise.
// double at_double_value_at_indexes(tensor, int64_t *indexes, int indexes_len);
@ -330,6 +339,14 @@ func (ts Tensor) Float64Value(idx []int64) (retVal float64, err error) {
return retVal, err
}
func (ts Tensor) MustFloat64Value(idx []int64) (retVal float64) {
retVal, err := ts.Float64Value(idx)
if err != nil {
log.Fatal(err)
}
return retVal
}
// Int64Value returns an int value on tensors holding a single element. An error is
// returned otherwise.
func (ts Tensor) Int64Value(idx []int64) (retVal int64, err error) {

View File

@ -6,123 +6,126 @@ package vision
// http://yann.lecun.com/exdb/mnist/
import (
"encoding/binary"
"io"
"fmt"
"log"
"os"
"path/filepath"
"github.com/sugarme/gotch"
ts "github.com/sugarme/gotch/tensor"
)
// Image holds the pixel intensities of an image.
// 255 is foreground (black), 0 is background (white).
type RawImage []byte
const numLabels = 10
const pixelRange = 255
const (
imageMagic = 0x00000803
labelMagic = 0x00000801
// Width of the input tensor / picture
Width = 28
// Height of the input tensor / picture
Height = 28
)
func readLabels(r io.Reader, e error) (retVal ts.Tensor) {
if e != nil {
log.Fatalf("readLabels errors: %v\n", e)
// readInt32 read 4 bytes and convert to MSB first (big endian) interger.
func readInt32(f *os.File) (retVal int, err error) {
buf := make([]byte, 4)
n, err := f.Read(buf)
switch {
case err != nil:
return 0, err
case n != 4:
err = fmt.Errorf("Invalid format: %v", f.Name())
return 0, err
}
var (
magic int32
n int32
err error
)
// Check magic number
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
log.Fatalf("readLabels - binary.Read error: %v\n", err)
}
if magic != labelMagic {
log.Fatal(os.ErrInvalid)
// flip to big endian
var v int = 0
for _, i := range buf {
v = v*256 + int(i)
}
// Now decode number
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
log.Fatalf("readLabels - binary.Read error: %v\n", err)
}
return v, nil
}
// label is a digit number range 0 - 9
labels := make([]uint8, n)
for i := 0; i < int(n); i++ {
var l uint8
if err := binary.Read(r, binary.BigEndian, &l); err != nil {
log.Fatalf("readLabels - binary.Read error: %v\n", err)
}
labels[i] = l
}
retVal, err = ts.OfSlice(labels)
// checkMagicNumber checks the magic number located at the first 4 bytes of
// mnist files.
func checkMagicNumber(f *os.File, wantNumber int) (err error) {
gotNumber, err := readInt32(f)
if err != nil {
log.Fatalf("readLabels - ts.OfSlice error: %v\n", err)
return err
}
if gotNumber != wantNumber {
err = fmt.Errorf("incorrect magic number: got %v want %v\n", gotNumber, wantNumber)
return err
}
return nil
}
func readLabels(filename string) (retVal ts.Tensor) {
f, err := os.Open(filename)
if err != nil {
log.Fatalf("readLabels errors: %v\n", err)
}
defer f.Close()
if err = checkMagicNumber(f, 2049); err != nil {
log.Fatal(err)
}
samples, err := readInt32(f)
if err != nil {
log.Fatal(err)
}
var data []uint8 = make([]uint8, samples)
len, err := f.Read(data)
if err != nil || len != samples {
err = fmt.Errorf("invalid format %v", f.Name())
log.Fatal(err)
}
labelsTs, err := ts.OfSlice(data)
if err != nil {
err = fmt.Errorf("create label tensor err.")
log.Fatal(err)
}
retVal = labelsTs.MustTotype(gotch.Int64)
return retVal
}
func readImages(r io.Reader, e error) (retVal ts.Tensor) {
if e != nil {
log.Fatalf("readLabels errors: %v\n", e)
}
var (
magic int32
n int32
nrow int32
ncol int32
err error
)
// Check magic number
if err = binary.Read(r, binary.BigEndian, &magic); err != nil {
log.Fatalf("readImages - binary.Read error: %v\n", err)
}
if magic != imageMagic {
log.Fatalf("readImages - incorrect imageMagic: %v\n", err) // err is os.ErrInvalid
}
// Now, decode image
if err = binary.Read(r, binary.BigEndian, &n); err != nil {
log.Fatalf("readImages - binary.Read error: %v\n", err)
}
if err = binary.Read(r, binary.BigEndian, &nrow); err != nil {
log.Fatalf("readImages - binary.Read error: %v\n", err)
}
if err = binary.Read(r, binary.BigEndian, &ncol); err != nil {
log.Fatalf("readImages - binary.Read error: %v\n", err)
}
imgs := make([]RawImage, n)
m := int(nrow * ncol)
for i := 0; i < int(n); i++ {
imgs[i] = make(RawImage, m)
m_, err := io.ReadFull(r, imgs[i])
if err != nil {
log.Fatalf("readImages - io.ReadFull error: %v\n", err)
}
if m_ != int(m) {
log.Fatalf("readImages - image matrix size mismatched error: %v\n", os.ErrInvalid)
}
}
retVal, err = ts.NewTensorFromData(imgs, []int64{int64(n), int64(nrow * ncol)})
func readImages(filename string) (retVal ts.Tensor) {
f, err := os.Open(filename)
if err != nil {
log.Fatalf("readImages - ts.NewTensorFromData error: %v\n", err)
log.Fatalf("readImages errors: %v\n", err)
}
defer f.Close()
if err = checkMagicNumber(f, 2051); err != nil {
log.Fatal(err)
}
samples, err := readInt32(f)
if err != nil {
log.Fatal(err)
}
rows, err := readInt32(f)
if err != nil {
log.Fatal(err)
}
cols, err := readInt32(f)
if err != nil {
log.Fatal(err)
}
dataLen := samples * rows * cols
var data []uint8 = make([]uint8, dataLen)
len, err := f.Read(data)
if err != nil || len != dataLen {
err = fmt.Errorf("invalid format %v", f.Name())
log.Fatal(err)
}
imagesTs, err := ts.OfSlice(data)
if err != nil {
err = fmt.Errorf("create images tensor err.")
log.Fatal(err)
}
retVal = imagesTs.MustView([]int64{int64(samples), int64(rows * cols)}).MustTotype(gotch.Float).MustDiv1(ts.FloatScalar(255.0))
return retVal
}
@ -141,10 +144,10 @@ func LoadMNISTDir(dir string) (retVal Dataset) {
testLabelsFile := filepath.Join(dir, testLabels)
testImagesFile := filepath.Join(dir, testImages)
trainImagesTs := readImages(os.Open(trainImagesFile))
trainLabelsTs := readLabels(os.Open(trainLabelsFile))
testImagesTs := readImages(os.Open(testImagesFile))
testLabelsTs := readLabels(os.Open(testLabelsFile))
trainImagesTs := readImages(trainImagesFile)
trainLabelsTs := readLabels(trainLabelsFile)
testImagesTs := readImages(testImagesFile)
testLabelsTs := readLabels(testLabelsFile)
return Dataset{
TrainImages: trainImagesTs,