From 992ff860d41e6300dc85e2ac58cb3f659302d749 Mon Sep 17 00:00:00 2001 From: sugarme Date: Sun, 15 Nov 2020 10:32:37 +1100 Subject: [PATCH] updated example code on README and added Colab links --- README.md | 101 +++--- example/mnist/GoTch_MNIST_CNN.ipynb | 496 ---------------------------- example/mnist/README.md | 6 +- 3 files changed, 53 insertions(+), 550 deletions(-) delete mode 100644 example/mnist/GoTch_MNIST_CNN.ipynb diff --git a/README.md b/README.md index cd3f44f..4d8d80d 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,6 @@ func basicOps() { fmt.Printf("%8.3f\n", xs) fmt.Printf("%i", xs) - // output /* (1,.,.) = 0.391 0.055 0.638 0.514 0.757 0.446 @@ -93,46 +92,42 @@ func basicOps() { mul := ts1.MustMatmul(ts2, false) defer mul.MustDrop() - fmt.Println("ts1: ") - ts1.Print() - fmt.Println("ts2: ") - ts2.Print() - fmt.Println("mul tensor (ts1 x ts2): ") - mul.Print() - //ts1: - // 0 1 2 - // 3 4 5 - //[ CPULongType{2,3} ] - //ts2: - // 1 1 1 1 - // 1 1 1 1 - // 1 1 1 1 - //[ CPULongType{3,4} ] - //mul tensor (ts1 x ts2): - // 3 3 3 3 - // 12 12 12 12 - //[ CPULongType{2,4} ] + fmt.Printf("ts1:\n%2d", ts1) + fmt.Printf("ts2:\n%2d", ts2) + fmt.Printf("mul tensor (ts1 x ts2):\n%2d", mul) + + /* + ts1: + 0 1 2 + 3 4 5 + + ts2: + 1 1 1 1 + 1 1 1 1 + 1 1 1 1 + + mul tensor (ts1 x ts2): + 3 3 3 3 + 12 12 12 12 + */ // In-place operation ts3 := ts.MustOnes([]int64{2, 3}, gotch.Float, gotch.CPU) - fmt.Println("Before:") - ts3.Print() + fmt.Printf("Before:\n%v", ts3) ts3.MustAdd1_(ts.FloatScalar(2.0)) - fmt.Printf("After (ts3 + 2.0): \n") - ts3.Print() - ts3.MustDrop() + fmt.Printf("After (ts3 + 2.0):\n%v", ts3) - //Before: - // 1 1 1 - // 1 1 1 - //[ CPUFloatType{2,3} ] - //After (ts3 + 2.0): - // 3 3 3 - // 3 3 3 - //[ CPUFloatType{2,3} ] + /* + Before: + 1 1 1 + 1 1 1 + After (ts3 + 2.0): + 3 3 3 + 3 3 3 + */ } ``` @@ -142,30 +137,32 @@ func basicOps() { ```go import ( + "fmt" + "github.com/sugarme/gotch" "github.com/sugarme/gotch/nn" ts "github.com/sugarme/gotch/tensor" ) type Net struct { - conv1 nn.Conv2D - conv2 nn.Conv2D - fc nn.Linear + conv1 *nn.Conv2D + conv2 *nn.Conv2D + fc *nn.Linear } - func newNet(vs nn.Path) Net { + func newNet(vs *nn.Path) *Net { conv1 := nn.NewConv2D(vs, 1, 16, 2, nn.DefaultConv2DConfig()) conv2 := nn.NewConv2D(vs, 16, 10, 2, nn.DefaultConv2DConfig()) fc := nn.NewLinear(vs, 10, 10, nn.DefaultLinearConfig()) - return Net{ + return &Net{ conv1, conv2, fc, } } - func (n Net) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) { + func (n Net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor { xs = xs.MustView([]int64{-1, 1, 8, 8}, false) outC1 := xs.Apply(n.conv1) @@ -177,10 +174,8 @@ func basicOps() { outView2 := outMP2.MustView([]int64{-1, 10}, true) defer outView2.MustDrop() - outFC := outView2.Apply(&n.fc) - + outFC := outView2.Apply(n.fc) return outFC.MustRelu(true) - } func main() { @@ -191,29 +186,33 @@ func basicOps() { xs := ts.MustOnes([]int64{8, 8}, gotch.Float, gotch.CPU) logits := net.ForwardT(xs, false) - logits.Print() + fmt.Printf("Logits: %0.3f", logits) } - // 0.0000 0.0000 0.0000 0.2477 0.2437 0.0000 0.0000 0.0000 0.0000 0.0171 - //[ CPUFloatType{1,10} ] - - + //Logits: 0.000 0.000 0.000 0.225 0.321 0.147 0.000 0.207 0.000 0.000 ``` - Real application examples can be found at [example folder](example/README.md) +## Play with GoTch on Google Colab + +1. [Tensor Initiation](tensor/tensor-initiation.ipynb) Open In Colab +2. [Tensor Indexing](tensor/tensor-indexing.ipynb) Open In Colab +3. [MNIST](mnist) Open In Colab +4. [Tokenizer - BPE model](tokenizer/bpe.ipynb) Open In Colab +5. [transformer - BERT Mask Language Model](transformer/bert-mask-lm.ipynb) Open In Colab +6. [YOLO v3 model infering](yolo/yolo.ipynb) Open In Colab + +More coming soon... + ## Getting Started -- [Documentations](docs/README.md) - - See [pkg.go.dev](https://pkg.go.dev/github.com/sugarme/gotch?tab=doc) for detail APIs - ## License GoTch is Apache 2.0 licensed. - ## Acknowledgement - This project has been inspired and used many concepts from [tch-rs](https://github.com/LaurentMazare/tch-rs) diff --git a/example/mnist/GoTch_MNIST_CNN.ipynb b/example/mnist/GoTch_MNIST_CNN.ipynb deleted file mode 100644 index c7521b1..0000000 --- a/example/mnist/GoTch_MNIST_CNN.ipynb +++ /dev/null @@ -1,496 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "accelerator": "GPU", - "colab": { - "name": "GoTch-MNIST-CNN.ipynb", - "provenance": [], - "collapsed_sections": [ - "fsitX3NbLbTB", - "l6aOKEMarRNH", - "mWe6_MtPK8Kh" - ], - "toc_visible": true, - "include_colab_link": true - }, - "kernelspec": { - "display_name": "Go", - "name": "gophernotes" - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dkq9izrfMXPV" - }, - "source": [ - "# MNIST CNN Training Using GoTch - Pytorch C++ APIs Go Binding\n", - "\n", - "This notebook using\n", - "\n", - "1. [GoTch - Pytorch C++ APIs Go bindind](https://github.com/sugarme/gotch)\n", - "2. [GopherNotes - Jupyter Notebook Go kernel](https://github.com/gopherdata/gophernotes)\n", - "3. [MNIST dataset](http://yann.lecun.com/exdb/mnist/)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fsitX3NbLbTB" - }, - "source": [ - "## Install Go kernel - GopherNotes\n", - "\n", - "*NOTE: refresh/reload (browser) after this step.*" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "caT1iMfshw62", - "outputId": "5be1cd7d-c12e-4083-d523-a035a9fc06a6", - "colab": { - "base_uri": "https://localhost:8080/" - } - }, - "source": [ - "# run this cell first time using python runtime\n", - "!add-apt-repository ppa:longsleep/golang-backports -y > /dev/null\n", - "!apt update > /dev/null \n", - "!apt install golang-go > /dev/null\n", - "%env GOPATH=/root/go\n", - "!go get -u github.com/gopherdata/gophernotes\n", - "!cp ~/go/bin/gophernotes /usr/bin/\n", - "!mkdir /usr/local/share/jupyter/kernels/gophernotes\n", - "!cp ~/go/src/github.com/gopherdata/gophernotes/kernel/* \\\n", - " /usr/local/share/jupyter/kernels/gophernotes\n", - "# then refresh (browser), it will now use gophernotes. Skip to golang in later cells" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "text": [ - "\n", - "WARNING: apt does not have a stable CLI interface. Use with caution in scripts.\n", - "\n", - "\n", - "WARNING: apt does not have a stable CLI interface. Use with caution in scripts.\n", - "\n", - "env: GOPATH=/root/go\n" - ], - "name": "stdout" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "l6aOKEMarRNH" - }, - "source": [ - "## Install Pytorch C++ APIs and Go binding - GoTch\n", - "\n", - "NOTE: `ldconfig` (GLIBC) current version 2.27 is currently broken when linking Libtorch library\n", - "\n", - "see issue: https://discuss.pytorch.org/libtorch-c-so-files-truncated-error-when-ldconfig/46404/6\n", - "\n", - "Google Colab default settings:\n", - "```bash\n", - "LD_LIBRARY_PATH=/usr/lib64-nvidia\n", - "LIBRARY_PATH=/usr/local/cuda/lib64/stubs\n", - "```\n", - "We copy directly `libtorch/lib` to those paths as a hacky way. " - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "0N_jGbVft2Y7", - "outputId": "e2052bf1-862f-49e7-ebed-23f09f7fb9e4", - "colab": { - "base_uri": "https://localhost:8080/" - } - }, - "source": [ - "$wget -q --show-progress --progress=bar:force:noscroll -O /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip https://download.pytorch.org/libtorch/cu101/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip\n", - "$unzip -qq /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip -d /usr/local\n", - "$unzip -qq -j /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip libtorch/lib/* -d /usr/lib64-nvidia/\n", - "$unzip -qq -j /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip libtorch/lib/* -d /usr/local/cuda/lib64/stubs/" - ], - "execution_count": 1, - "outputs": [ - { - "output_type": "stream", - "text": [ - "/tmp/libtorch-cxx11 100%[===================>] 765.61M 22.4MB/s in 29s \n" - ], - "name": "stderr" - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "dVo3An-pGzFE" - }, - "source": [ - "import(\"os\")\n", - "os.Setenv(\"CPATH\", \"usr/local/libtorch/lib:/usr/local/libtorch/include:/usr/local/libtorch/include/torch/csrc/api/include\")" - ], - "execution_count": 2, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "x_9y6CpZGlPJ", - "outputId": "79b538ff-9320-4097-8e76-43160253a60d", - "colab": { - "base_uri": "https://localhost:8080/" - } - }, - "source": [ - "$rm -f -- go.mod\n", - "$go mod init github.com/sugarme/playgo\n", - "$go get github.com/sugarme/gotch@v0.3.2" - ], - "execution_count": 4, - "outputs": [ - { - "output_type": "stream", - "text": [ - "go: creating new go.mod: module github.com/sugarme/playgo\n", - "go: downloading github.com/sugarme/gotch v0.3.2\n" - ], - "name": "stderr" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mWe6_MtPK8Kh" - }, - "source": [ - "## Download MNIST dataset" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "sYyKlDWchSv3", - "outputId": "ba7b2bb6-f42f-4dbd-c96e-633f4b66e0f1", - "colab": { - "base_uri": "https://localhost:8080/" - } - }, - "source": [ - "$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", - "$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", - "$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", - "$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", - "\n", - "$gunzip train-images-idx3-ubyte.gz\n", - "$gunzip train-labels-idx1-ubyte.gz\n", - "$gunzip t10k-images-idx3-ubyte.gz\n", - "$gunzip t10k-labels-idx1-ubyte.gz" - ], - "execution_count": 5, - "outputs": [ - { - "output_type": "stream", - "text": [ - "train-images-idx3-u 100%[===================>] 9.45M 11.2MB/s in 0.8s \n", - "train-labels-idx1-u 100%[===================>] 28.20K --.-KB/s in 0.06s \n", - "t10k-images-idx3-ub 100%[===================>] 1.57M 2.03MB/s in 0.8s \n", - "t10k-labels-idx1-ub 100%[===================>] 4.44K --.-KB/s in 0s \n" - ], - "name": "stderr" - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7MzMZDHiLJQR" - }, - "source": [ - "## Create Convolution Neural Network (CNN)" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "YkLmNrswJHYY" - }, - "source": [ - "import(\n", - " \"fmt\"\n", - " \"time\"\n", - "\n", - " \"github.com/sugarme/gotch\"\n", - " \"github.com/sugarme/gotch/nn\"\n", - " ts \"github.com/sugarme/gotch/tensor\"\n", - " \"github.com/sugarme/gotch/vision\"\n", - ") \n", - "\n", - "const (\n", - " MnistDirCNN string = \"./\"\n", - " epochsCNN = 100\n", - " batchCNN = 256\n", - " batchSize = 256\n", - " LrCNN = 1e-4\n", - ")\n", - "\n", - "type Net struct {\n", - " conv1 *nn.Conv2D\n", - " conv2 *nn.Conv2D\n", - " fc1 *nn.Linear\n", - " fc2 *nn.Linear\n", - "}\n", - "\n", - "func newNet(vs *nn.Path) Net {\n", - " conv1 := nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig())\n", - " conv2 := nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig())\n", - " fc1 := nn.NewLinear(vs, 1024, 1024, nn.DefaultLinearConfig())\n", - " fc2 := nn.NewLinear(vs, 1024, 10, nn.DefaultLinearConfig())\n", - "\n", - " return Net{conv1,conv2,fc1,fc2}\n", - "}\n", - "\n", - "func (n Net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {\n", - " outView1 := xs.MustView([]int64{-1, 1, 28, 28}, false)\n", - " defer outView1.MustDrop()\n", - " outC1 := outView1.Apply(n.conv1)\n", - " outMP1 := outC1.MaxPool2DDefault(2, true)\n", - " defer outMP1.MustDrop()\n", - " outC2 := outMP1.Apply(n.conv2)\n", - " outMP2 := outC2.MaxPool2DDefault(2, true)\n", - " outView2 := outMP2.MustView([]int64{-1, 1024}, true)\n", - " defer outView2.MustDrop()\n", - " outFC1 := outView2.Apply(n.fc1)\n", - " outRelu := outFC1.MustRelu(true)\n", - " defer outRelu.MustDrop()\n", - " outDropout := ts.MustDropout(outRelu, 0.5, train)\n", - " defer outDropout.MustDrop()\n", - " return outDropout.Apply(n.fc2)\n", - "}\n", - "\n", - "func trainCNN(){\n", - " var ds *vision.Dataset\n", - " ds = vision.LoadMNISTDir(MnistDirCNN)\n", - " testImages := ds.TestImages\n", - " testLabels := ds.TestLabels\n", - "\n", - " cuda := gotch.CudaBuilder(0)\n", - " vs := nn.NewVarStore(cuda.CudaIfAvailable())\n", - "\n", - " var cnn Net = newNet(vs.Root())\n", - " opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)\n", - " if err != nil {fmt.Print(err)}\n", - "\n", - " var bestAccuracy float64 = 0.0\n", - " startTime := time.Now()\n", - "\n", - " for epoch := 0; epoch < epochsCNN; epoch++ {\n", - " totalSize := ds.TrainImages.MustSize()[0]\n", - " samples := int(totalSize)\n", - " index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)\n", - " imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)\n", - " labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)\n", - "\n", - " batches := samples / batchSize\n", - " batchIndex := 0\n", - " var epocLoss *ts.Tensor\n", - " for i := 0; i < batches; i++ {\n", - " start := batchIndex * batchSize\n", - " size := batchSize\n", - " if samples-start < batchSize {break}\n", - " batchIndex += 1\n", - "\n", - " // Indexing\n", - " narrowIndex := ts.NewNarrow(int64(start), int64(start+size))\n", - " bImages := imagesTs.Idx(narrowIndex)\n", - " bLabels := labelsTs.Idx(narrowIndex)\n", - "\n", - " bImages = bImages.MustTo(vs.Device(), true)\n", - " bLabels = bLabels.MustTo(vs.Device(), true)\n", - "\n", - " logits := cnn.ForwardT(bImages, true)\n", - " loss := logits.CrossEntropyForLogits(bLabels)\n", - "\n", - " opt.BackwardStep(loss)\n", - "\n", - " epocLoss = loss.MustShallowClone()\n", - " epocLoss.Detach_()\n", - "\n", - " bImages.MustDrop()\n", - " bLabels.MustDrop()\n", - " }\n", - "\n", - " testAccuracy := nn.BatchAccuracyForLogits(vs, cnn, testImages, testLabels, vs.Device(), 1024)\n", - " fmt.Printf(\"Epoch: %v\\t Loss: %.2f \\t Test accuracy: %.2f%%\\n\", epoch, epocLoss.Float64Values()[0], testAccuracy*100.0)\n", - " if testAccuracy > bestAccuracy {\n", - " bestAccuracy = testAccuracy\n", - " }\n", - "\n", - " epocLoss.MustDrop()\n", - " imagesTs.MustDrop()\n", - " labelsTs.MustDrop()\n", - " }\n", - "\n", - " fmt.Printf(\"Best test accuracy: %.2f%%\\n\", bestAccuracy*100.0)\n", - " fmt.Printf(\"Taken time:\\t%.2f mins\\n\", time.Since(startTime).Minutes())\n", - "}" - ], - "execution_count": 6, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dIwh8dLLOzJ9" - }, - "source": [ - "## Run train and evaluation" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "0AW2nIHjKA9W", - "outputId": "0123b6df-7edc-4f35-b874-2b4f1ec69e54", - "colab": { - "base_uri": "https://localhost:8080/" - } - }, - "source": [ - "trainCNN()" - ], - "execution_count": 7, - "outputs": [ - { - "output_type": "stream", - "text": [ - "Epoch: 0\t Loss: 0.24 \t Test accuracy: 93.34%\n", - "Epoch: 1\t Loss: 0.19 \t Test accuracy: 96.02%\n", - "Epoch: 2\t Loss: 0.09 \t Test accuracy: 97.10%\n", - "Epoch: 3\t Loss: 0.10 \t Test accuracy: 97.66%\n", - "Epoch: 4\t Loss: 0.03 \t Test accuracy: 98.13%\n", - "Epoch: 5\t Loss: 0.05 \t Test accuracy: 98.43%\n", - "Epoch: 6\t Loss: 0.09 \t Test accuracy: 98.60%\n", - "Epoch: 7\t Loss: 0.05 \t Test accuracy: 98.80%\n", - "Epoch: 8\t Loss: 0.03 \t Test accuracy: 98.80%\n", - "Epoch: 9\t Loss: 0.05 \t Test accuracy: 98.89%\n", - "Epoch: 10\t Loss: 0.02 \t Test accuracy: 98.88%\n", - "Epoch: 11\t Loss: 0.03 \t Test accuracy: 98.98%\n", - "Epoch: 12\t Loss: 0.03 \t Test accuracy: 99.05%\n", - "Epoch: 13\t Loss: 0.04 \t Test accuracy: 99.06%\n", - "Epoch: 14\t Loss: 0.02 \t Test accuracy: 99.07%\n", - "Epoch: 15\t Loss: 0.02 \t Test accuracy: 98.98%\n", - "Epoch: 16\t Loss: 0.02 \t Test accuracy: 99.06%\n", - "Epoch: 17\t Loss: 0.01 \t Test accuracy: 99.09%\n", - "Epoch: 18\t Loss: 0.02 \t Test accuracy: 99.14%\n", - "Epoch: 19\t Loss: 0.01 \t Test accuracy: 99.09%\n", - "Epoch: 20\t Loss: 0.02 \t Test accuracy: 99.12%\n", - "Epoch: 21\t Loss: 0.03 \t Test accuracy: 99.13%\n", - "Epoch: 22\t Loss: 0.02 \t Test accuracy: 99.11%\n", - "Epoch: 23\t Loss: 0.01 \t Test accuracy: 99.10%\n", - "Epoch: 24\t Loss: 0.01 \t Test accuracy: 99.15%\n", - "Epoch: 25\t Loss: 0.01 \t Test accuracy: 99.15%\n", - "Epoch: 26\t Loss: 0.01 \t Test accuracy: 99.22%\n", - "Epoch: 27\t Loss: 0.00 \t Test accuracy: 99.27%\n", - "Epoch: 28\t Loss: 0.01 \t Test accuracy: 99.15%\n", - "Epoch: 29\t Loss: 0.00 \t Test accuracy: 99.09%\n", - "Epoch: 30\t Loss: 0.01 \t Test accuracy: 99.15%\n", - "Epoch: 31\t Loss: 0.00 \t Test accuracy: 99.18%\n", - "Epoch: 32\t Loss: 0.00 \t Test accuracy: 99.15%\n", - "Epoch: 33\t Loss: 0.00 \t Test accuracy: 99.23%\n", - "Epoch: 34\t Loss: 0.01 \t Test accuracy: 99.22%\n", - "Epoch: 35\t Loss: 0.00 \t Test accuracy: 99.16%\n", - "Epoch: 36\t Loss: 0.00 \t Test accuracy: 99.22%\n", - "Epoch: 37\t Loss: 0.01 \t Test accuracy: 99.26%\n", - "Epoch: 38\t Loss: 0.00 \t Test accuracy: 99.15%\n", - "Epoch: 39\t Loss: 0.01 \t Test accuracy: 99.19%\n", - "Epoch: 40\t Loss: 0.01 \t Test accuracy: 99.24%\n", - "Epoch: 41\t Loss: 0.01 \t Test accuracy: 99.24%\n", - "Epoch: 42\t Loss: 0.01 \t Test accuracy: 99.22%\n", - "Epoch: 43\t Loss: 0.00 \t Test accuracy: 99.22%\n", - "Epoch: 44\t Loss: 0.00 \t Test accuracy: 99.14%\n", - "Epoch: 45\t Loss: 0.00 \t Test accuracy: 99.23%\n", - "Epoch: 46\t Loss: 0.00 \t Test accuracy: 99.24%\n", - "Epoch: 47\t Loss: 0.00 \t Test accuracy: 99.21%\n", - "Epoch: 48\t Loss: 0.00 \t Test accuracy: 99.25%\n", - "Epoch: 49\t Loss: 0.00 \t Test accuracy: 99.21%\n", - "Epoch: 50\t Loss: 0.01 \t Test accuracy: 99.26%\n", - "Epoch: 51\t Loss: 0.01 \t Test accuracy: 99.14%\n", - "Epoch: 52\t Loss: 0.01 \t Test accuracy: 99.26%\n", - "Epoch: 53\t Loss: 0.00 \t Test accuracy: 99.23%\n", - "Epoch: 54\t Loss: 0.00 \t Test accuracy: 99.28%\n", - "Epoch: 55\t Loss: 0.02 \t Test accuracy: 99.24%\n", - "Epoch: 56\t Loss: 0.00 \t Test accuracy: 99.26%\n", - "Epoch: 57\t Loss: 0.00 \t Test accuracy: 99.25%\n", - "Epoch: 58\t Loss: 0.01 \t Test accuracy: 99.22%\n", - "Epoch: 59\t Loss: 0.01 \t Test accuracy: 99.25%\n", - "Epoch: 60\t Loss: 0.00 \t Test accuracy: 99.26%\n", - "Epoch: 61\t Loss: 0.01 \t Test accuracy: 99.24%\n", - "Epoch: 62\t Loss: 0.00 \t Test accuracy: 99.25%\n", - "Epoch: 63\t Loss: 0.00 \t Test accuracy: 99.24%\n", - "Epoch: 64\t Loss: 0.00 \t Test accuracy: 99.27%\n", - "Epoch: 65\t Loss: 0.00 \t Test accuracy: 99.27%\n", - "Epoch: 66\t Loss: 0.00 \t Test accuracy: 99.31%\n", - "Epoch: 67\t Loss: 0.00 \t Test accuracy: 99.25%\n", - "Epoch: 68\t Loss: 0.00 \t Test accuracy: 99.25%\n", - "Epoch: 69\t Loss: 0.00 \t Test accuracy: 99.31%\n", - "Epoch: 70\t Loss: 0.00 \t Test accuracy: 99.24%\n", - "Epoch: 71\t Loss: 0.00 \t Test accuracy: 99.22%\n", - "Epoch: 72\t Loss: 0.00 \t Test accuracy: 99.25%\n", - "Epoch: 73\t Loss: 0.00 \t Test accuracy: 99.32%\n", - "Epoch: 74\t Loss: 0.00 \t Test accuracy: 99.26%\n", - "Epoch: 75\t Loss: 0.00 \t Test accuracy: 99.21%\n", - "Epoch: 76\t Loss: 0.00 \t Test accuracy: 99.22%\n", - "Epoch: 77\t Loss: 0.01 \t Test accuracy: 99.24%\n", - "Epoch: 78\t Loss: 0.00 \t Test accuracy: 99.25%\n", - "Epoch: 79\t Loss: 0.00 \t Test accuracy: 99.31%\n", - "Epoch: 80\t Loss: 0.00 \t Test accuracy: 99.18%\n", - "Epoch: 81\t Loss: 0.00 \t Test accuracy: 99.31%\n", - "Epoch: 82\t Loss: 0.00 \t Test accuracy: 99.28%\n", - "Epoch: 83\t Loss: 0.00 \t Test accuracy: 99.18%\n", - "Epoch: 84\t Loss: 0.01 \t Test accuracy: 99.24%\n", - "Epoch: 85\t Loss: 0.00 \t Test accuracy: 99.24%\n", - "Epoch: 86\t Loss: 0.00 \t Test accuracy: 99.29%\n", - "Epoch: 87\t Loss: 0.00 \t Test accuracy: 99.33%\n", - "Epoch: 88\t Loss: 0.00 \t Test accuracy: 99.28%\n", - "Epoch: 89\t Loss: 0.00 \t Test accuracy: 99.24%\n", - "Epoch: 90\t Loss: 0.00 \t Test accuracy: 99.35%\n", - "Epoch: 91\t Loss: 0.00 \t Test accuracy: 99.22%\n", - "Epoch: 92\t Loss: 0.00 \t Test accuracy: 99.23%\n", - "Epoch: 93\t Loss: 0.00 \t Test accuracy: 99.22%\n", - "Epoch: 94\t Loss: 0.00 \t Test accuracy: 99.25%\n", - "Epoch: 95\t Loss: 0.02 \t Test accuracy: 99.25%\n", - "Epoch: 96\t Loss: 0.00 \t Test accuracy: 99.26%\n", - "Epoch: 97\t Loss: 0.00 \t Test accuracy: 99.29%\n", - "Epoch: 98\t Loss: 0.00 \t Test accuracy: 99.23%\n", - "Epoch: 99\t Loss: 0.00 \t Test accuracy: 99.20%\n", - "Best test accuracy: 99.35%\n", - "Taken time:\t1.97 mins\n" - ], - "name": "stdout" - } - ] - } - ] -} \ No newline at end of file diff --git a/example/mnist/README.md b/example/mnist/README.md index 9fd67dd..4502bef 100644 --- a/example/mnist/README.md +++ b/example/mnist/README.md @@ -1,5 +1,7 @@ # Linear Regression, NN, and CNN on MNIST dataset +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sugarme/nb/blob/master/mnist/mnist.ipynb) + ## MNIST - MNIST files can be obtained from [this source](http://yann.lecun.com/exdb/mnist/) and put in `data/mnist` from @@ -29,8 +31,6 @@ - Accuracy should be about **99.3%**. -## Run on Cloud - -[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sugarme/gotch/blob/master/example/mnist/gotch_mnist.ipynb) +