From 6056096eb1fef2a0204715d1315773f797298c38 Mon Sep 17 00:00:00 2001 From: Sugarme Date: Fri, 13 Nov 2020 12:47:07 +1100 Subject: [PATCH] Created MNIST CNN colab notebook --- example/mnist/GoTch_MNIST_CNN.ipynb | 496 ++++++++++++++++++++++++++++ 1 file changed, 496 insertions(+) create mode 100644 example/mnist/GoTch_MNIST_CNN.ipynb diff --git a/example/mnist/GoTch_MNIST_CNN.ipynb b/example/mnist/GoTch_MNIST_CNN.ipynb new file mode 100644 index 0000000..c7521b1 --- /dev/null +++ b/example/mnist/GoTch_MNIST_CNN.ipynb @@ -0,0 +1,496 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "accelerator": "GPU", + "colab": { + "name": "GoTch-MNIST-CNN.ipynb", + "provenance": [], + "collapsed_sections": [ + "fsitX3NbLbTB", + "l6aOKEMarRNH", + "mWe6_MtPK8Kh" + ], + "toc_visible": true, + "include_colab_link": true + }, + "kernelspec": { + "display_name": "Go", + "name": "gophernotes" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dkq9izrfMXPV" + }, + "source": [ + "# MNIST CNN Training Using GoTch - Pytorch C++ APIs Go Binding\n", + "\n", + "This notebook using\n", + "\n", + "1. [GoTch - Pytorch C++ APIs Go bindind](https://github.com/sugarme/gotch)\n", + "2. [GopherNotes - Jupyter Notebook Go kernel](https://github.com/gopherdata/gophernotes)\n", + "3. [MNIST dataset](http://yann.lecun.com/exdb/mnist/)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fsitX3NbLbTB" + }, + "source": [ + "## Install Go kernel - GopherNotes\n", + "\n", + "*NOTE: refresh/reload (browser) after this step.*" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "caT1iMfshw62", + "outputId": "5be1cd7d-c12e-4083-d523-a035a9fc06a6", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "source": [ + "# run this cell first time using python runtime\n", + "!add-apt-repository ppa:longsleep/golang-backports -y > /dev/null\n", + "!apt update > /dev/null \n", + "!apt install golang-go > /dev/null\n", + "%env GOPATH=/root/go\n", + "!go get -u github.com/gopherdata/gophernotes\n", + "!cp ~/go/bin/gophernotes /usr/bin/\n", + "!mkdir /usr/local/share/jupyter/kernels/gophernotes\n", + "!cp ~/go/src/github.com/gopherdata/gophernotes/kernel/* \\\n", + " /usr/local/share/jupyter/kernels/gophernotes\n", + "# then refresh (browser), it will now use gophernotes. Skip to golang in later cells" + ], + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "text": [ + "\n", + "WARNING: apt does not have a stable CLI interface. Use with caution in scripts.\n", + "\n", + "\n", + "WARNING: apt does not have a stable CLI interface. Use with caution in scripts.\n", + "\n", + "env: GOPATH=/root/go\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "l6aOKEMarRNH" + }, + "source": [ + "## Install Pytorch C++ APIs and Go binding - GoTch\n", + "\n", + "NOTE: `ldconfig` (GLIBC) current version 2.27 is currently broken when linking Libtorch library\n", + "\n", + "see issue: https://discuss.pytorch.org/libtorch-c-so-files-truncated-error-when-ldconfig/46404/6\n", + "\n", + "Google Colab default settings:\n", + "```bash\n", + "LD_LIBRARY_PATH=/usr/lib64-nvidia\n", + "LIBRARY_PATH=/usr/local/cuda/lib64/stubs\n", + "```\n", + "We copy directly `libtorch/lib` to those paths as a hacky way. " + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0N_jGbVft2Y7", + "outputId": "e2052bf1-862f-49e7-ebed-23f09f7fb9e4", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "source": [ + "$wget -q --show-progress --progress=bar:force:noscroll -O /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip https://download.pytorch.org/libtorch/cu101/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip\n", + "$unzip -qq /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip -d /usr/local\n", + "$unzip -qq -j /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip libtorch/lib/* -d /usr/lib64-nvidia/\n", + "$unzip -qq -j /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip libtorch/lib/* -d /usr/local/cuda/lib64/stubs/" + ], + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "text": [ + "/tmp/libtorch-cxx11 100%[===================>] 765.61M 22.4MB/s in 29s \n" + ], + "name": "stderr" + } + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "dVo3An-pGzFE" + }, + "source": [ + "import(\"os\")\n", + "os.Setenv(\"CPATH\", \"usr/local/libtorch/lib:/usr/local/libtorch/include:/usr/local/libtorch/include/torch/csrc/api/include\")" + ], + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "x_9y6CpZGlPJ", + "outputId": "79b538ff-9320-4097-8e76-43160253a60d", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "source": [ + "$rm -f -- go.mod\n", + "$go mod init github.com/sugarme/playgo\n", + "$go get github.com/sugarme/gotch@v0.3.2" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "text": [ + "go: creating new go.mod: module github.com/sugarme/playgo\n", + "go: downloading github.com/sugarme/gotch v0.3.2\n" + ], + "name": "stderr" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mWe6_MtPK8Kh" + }, + "source": [ + "## Download MNIST dataset" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "sYyKlDWchSv3", + "outputId": "ba7b2bb6-f42f-4dbd-c96e-633f4b66e0f1", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "source": [ + "$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", + "$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", + "$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n", + "$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", + "\n", + "$gunzip train-images-idx3-ubyte.gz\n", + "$gunzip train-labels-idx1-ubyte.gz\n", + "$gunzip t10k-images-idx3-ubyte.gz\n", + "$gunzip t10k-labels-idx1-ubyte.gz" + ], + "execution_count": 5, + "outputs": [ + { + "output_type": "stream", + "text": [ + "train-images-idx3-u 100%[===================>] 9.45M 11.2MB/s in 0.8s \n", + "train-labels-idx1-u 100%[===================>] 28.20K --.-KB/s in 0.06s \n", + "t10k-images-idx3-ub 100%[===================>] 1.57M 2.03MB/s in 0.8s \n", + "t10k-labels-idx1-ub 100%[===================>] 4.44K --.-KB/s in 0s \n" + ], + "name": "stderr" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7MzMZDHiLJQR" + }, + "source": [ + "## Create Convolution Neural Network (CNN)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "YkLmNrswJHYY" + }, + "source": [ + "import(\n", + " \"fmt\"\n", + " \"time\"\n", + "\n", + " \"github.com/sugarme/gotch\"\n", + " \"github.com/sugarme/gotch/nn\"\n", + " ts \"github.com/sugarme/gotch/tensor\"\n", + " \"github.com/sugarme/gotch/vision\"\n", + ") \n", + "\n", + "const (\n", + " MnistDirCNN string = \"./\"\n", + " epochsCNN = 100\n", + " batchCNN = 256\n", + " batchSize = 256\n", + " LrCNN = 1e-4\n", + ")\n", + "\n", + "type Net struct {\n", + " conv1 *nn.Conv2D\n", + " conv2 *nn.Conv2D\n", + " fc1 *nn.Linear\n", + " fc2 *nn.Linear\n", + "}\n", + "\n", + "func newNet(vs *nn.Path) Net {\n", + " conv1 := nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig())\n", + " conv2 := nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig())\n", + " fc1 := nn.NewLinear(vs, 1024, 1024, nn.DefaultLinearConfig())\n", + " fc2 := nn.NewLinear(vs, 1024, 10, nn.DefaultLinearConfig())\n", + "\n", + " return Net{conv1,conv2,fc1,fc2}\n", + "}\n", + "\n", + "func (n Net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {\n", + " outView1 := xs.MustView([]int64{-1, 1, 28, 28}, false)\n", + " defer outView1.MustDrop()\n", + " outC1 := outView1.Apply(n.conv1)\n", + " outMP1 := outC1.MaxPool2DDefault(2, true)\n", + " defer outMP1.MustDrop()\n", + " outC2 := outMP1.Apply(n.conv2)\n", + " outMP2 := outC2.MaxPool2DDefault(2, true)\n", + " outView2 := outMP2.MustView([]int64{-1, 1024}, true)\n", + " defer outView2.MustDrop()\n", + " outFC1 := outView2.Apply(n.fc1)\n", + " outRelu := outFC1.MustRelu(true)\n", + " defer outRelu.MustDrop()\n", + " outDropout := ts.MustDropout(outRelu, 0.5, train)\n", + " defer outDropout.MustDrop()\n", + " return outDropout.Apply(n.fc2)\n", + "}\n", + "\n", + "func trainCNN(){\n", + " var ds *vision.Dataset\n", + " ds = vision.LoadMNISTDir(MnistDirCNN)\n", + " testImages := ds.TestImages\n", + " testLabels := ds.TestLabels\n", + "\n", + " cuda := gotch.CudaBuilder(0)\n", + " vs := nn.NewVarStore(cuda.CudaIfAvailable())\n", + "\n", + " var cnn Net = newNet(vs.Root())\n", + " opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)\n", + " if err != nil {fmt.Print(err)}\n", + "\n", + " var bestAccuracy float64 = 0.0\n", + " startTime := time.Now()\n", + "\n", + " for epoch := 0; epoch < epochsCNN; epoch++ {\n", + " totalSize := ds.TrainImages.MustSize()[0]\n", + " samples := int(totalSize)\n", + " index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)\n", + " imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)\n", + " labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)\n", + "\n", + " batches := samples / batchSize\n", + " batchIndex := 0\n", + " var epocLoss *ts.Tensor\n", + " for i := 0; i < batches; i++ {\n", + " start := batchIndex * batchSize\n", + " size := batchSize\n", + " if samples-start < batchSize {break}\n", + " batchIndex += 1\n", + "\n", + " // Indexing\n", + " narrowIndex := ts.NewNarrow(int64(start), int64(start+size))\n", + " bImages := imagesTs.Idx(narrowIndex)\n", + " bLabels := labelsTs.Idx(narrowIndex)\n", + "\n", + " bImages = bImages.MustTo(vs.Device(), true)\n", + " bLabels = bLabels.MustTo(vs.Device(), true)\n", + "\n", + " logits := cnn.ForwardT(bImages, true)\n", + " loss := logits.CrossEntropyForLogits(bLabels)\n", + "\n", + " opt.BackwardStep(loss)\n", + "\n", + " epocLoss = loss.MustShallowClone()\n", + " epocLoss.Detach_()\n", + "\n", + " bImages.MustDrop()\n", + " bLabels.MustDrop()\n", + " }\n", + "\n", + " testAccuracy := nn.BatchAccuracyForLogits(vs, cnn, testImages, testLabels, vs.Device(), 1024)\n", + " fmt.Printf(\"Epoch: %v\\t Loss: %.2f \\t Test accuracy: %.2f%%\\n\", epoch, epocLoss.Float64Values()[0], testAccuracy*100.0)\n", + " if testAccuracy > bestAccuracy {\n", + " bestAccuracy = testAccuracy\n", + " }\n", + "\n", + " epocLoss.MustDrop()\n", + " imagesTs.MustDrop()\n", + " labelsTs.MustDrop()\n", + " }\n", + "\n", + " fmt.Printf(\"Best test accuracy: %.2f%%\\n\", bestAccuracy*100.0)\n", + " fmt.Printf(\"Taken time:\\t%.2f mins\\n\", time.Since(startTime).Minutes())\n", + "}" + ], + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dIwh8dLLOzJ9" + }, + "source": [ + "## Run train and evaluation" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "0AW2nIHjKA9W", + "outputId": "0123b6df-7edc-4f35-b874-2b4f1ec69e54", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "source": [ + "trainCNN()" + ], + "execution_count": 7, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Epoch: 0\t Loss: 0.24 \t Test accuracy: 93.34%\n", + "Epoch: 1\t Loss: 0.19 \t Test accuracy: 96.02%\n", + "Epoch: 2\t Loss: 0.09 \t Test accuracy: 97.10%\n", + "Epoch: 3\t Loss: 0.10 \t Test accuracy: 97.66%\n", + "Epoch: 4\t Loss: 0.03 \t Test accuracy: 98.13%\n", + "Epoch: 5\t Loss: 0.05 \t Test accuracy: 98.43%\n", + "Epoch: 6\t Loss: 0.09 \t Test accuracy: 98.60%\n", + "Epoch: 7\t Loss: 0.05 \t Test accuracy: 98.80%\n", + "Epoch: 8\t Loss: 0.03 \t Test accuracy: 98.80%\n", + "Epoch: 9\t Loss: 0.05 \t Test accuracy: 98.89%\n", + "Epoch: 10\t Loss: 0.02 \t Test accuracy: 98.88%\n", + "Epoch: 11\t Loss: 0.03 \t Test accuracy: 98.98%\n", + "Epoch: 12\t Loss: 0.03 \t Test accuracy: 99.05%\n", + "Epoch: 13\t Loss: 0.04 \t Test accuracy: 99.06%\n", + "Epoch: 14\t Loss: 0.02 \t Test accuracy: 99.07%\n", + "Epoch: 15\t Loss: 0.02 \t Test accuracy: 98.98%\n", + "Epoch: 16\t Loss: 0.02 \t Test accuracy: 99.06%\n", + "Epoch: 17\t Loss: 0.01 \t Test accuracy: 99.09%\n", + "Epoch: 18\t Loss: 0.02 \t Test accuracy: 99.14%\n", + "Epoch: 19\t Loss: 0.01 \t Test accuracy: 99.09%\n", + "Epoch: 20\t Loss: 0.02 \t Test accuracy: 99.12%\n", + "Epoch: 21\t Loss: 0.03 \t Test accuracy: 99.13%\n", + "Epoch: 22\t Loss: 0.02 \t Test accuracy: 99.11%\n", + "Epoch: 23\t Loss: 0.01 \t Test accuracy: 99.10%\n", + "Epoch: 24\t Loss: 0.01 \t Test accuracy: 99.15%\n", + "Epoch: 25\t Loss: 0.01 \t Test accuracy: 99.15%\n", + "Epoch: 26\t Loss: 0.01 \t Test accuracy: 99.22%\n", + "Epoch: 27\t Loss: 0.00 \t Test accuracy: 99.27%\n", + "Epoch: 28\t Loss: 0.01 \t Test accuracy: 99.15%\n", + "Epoch: 29\t Loss: 0.00 \t Test accuracy: 99.09%\n", + "Epoch: 30\t Loss: 0.01 \t Test accuracy: 99.15%\n", + "Epoch: 31\t Loss: 0.00 \t Test accuracy: 99.18%\n", + "Epoch: 32\t Loss: 0.00 \t Test accuracy: 99.15%\n", + "Epoch: 33\t Loss: 0.00 \t Test accuracy: 99.23%\n", + "Epoch: 34\t Loss: 0.01 \t Test accuracy: 99.22%\n", + "Epoch: 35\t Loss: 0.00 \t Test accuracy: 99.16%\n", + "Epoch: 36\t Loss: 0.00 \t Test accuracy: 99.22%\n", + "Epoch: 37\t Loss: 0.01 \t Test accuracy: 99.26%\n", + "Epoch: 38\t Loss: 0.00 \t Test accuracy: 99.15%\n", + "Epoch: 39\t Loss: 0.01 \t Test accuracy: 99.19%\n", + "Epoch: 40\t Loss: 0.01 \t Test accuracy: 99.24%\n", + "Epoch: 41\t Loss: 0.01 \t Test accuracy: 99.24%\n", + "Epoch: 42\t Loss: 0.01 \t Test accuracy: 99.22%\n", + "Epoch: 43\t Loss: 0.00 \t Test accuracy: 99.22%\n", + "Epoch: 44\t Loss: 0.00 \t Test accuracy: 99.14%\n", + "Epoch: 45\t Loss: 0.00 \t Test accuracy: 99.23%\n", + "Epoch: 46\t Loss: 0.00 \t Test accuracy: 99.24%\n", + "Epoch: 47\t Loss: 0.00 \t Test accuracy: 99.21%\n", + "Epoch: 48\t Loss: 0.00 \t Test accuracy: 99.25%\n", + "Epoch: 49\t Loss: 0.00 \t Test accuracy: 99.21%\n", + "Epoch: 50\t Loss: 0.01 \t Test accuracy: 99.26%\n", + "Epoch: 51\t Loss: 0.01 \t Test accuracy: 99.14%\n", + "Epoch: 52\t Loss: 0.01 \t Test accuracy: 99.26%\n", + "Epoch: 53\t Loss: 0.00 \t Test accuracy: 99.23%\n", + "Epoch: 54\t Loss: 0.00 \t Test accuracy: 99.28%\n", + "Epoch: 55\t Loss: 0.02 \t Test accuracy: 99.24%\n", + "Epoch: 56\t Loss: 0.00 \t Test accuracy: 99.26%\n", + "Epoch: 57\t Loss: 0.00 \t Test accuracy: 99.25%\n", + "Epoch: 58\t Loss: 0.01 \t Test accuracy: 99.22%\n", + "Epoch: 59\t Loss: 0.01 \t Test accuracy: 99.25%\n", + "Epoch: 60\t Loss: 0.00 \t Test accuracy: 99.26%\n", + "Epoch: 61\t Loss: 0.01 \t Test accuracy: 99.24%\n", + "Epoch: 62\t Loss: 0.00 \t Test accuracy: 99.25%\n", + "Epoch: 63\t Loss: 0.00 \t Test accuracy: 99.24%\n", + "Epoch: 64\t Loss: 0.00 \t Test accuracy: 99.27%\n", + "Epoch: 65\t Loss: 0.00 \t Test accuracy: 99.27%\n", + "Epoch: 66\t Loss: 0.00 \t Test accuracy: 99.31%\n", + "Epoch: 67\t Loss: 0.00 \t Test accuracy: 99.25%\n", + "Epoch: 68\t Loss: 0.00 \t Test accuracy: 99.25%\n", + "Epoch: 69\t Loss: 0.00 \t Test accuracy: 99.31%\n", + "Epoch: 70\t Loss: 0.00 \t Test accuracy: 99.24%\n", + "Epoch: 71\t Loss: 0.00 \t Test accuracy: 99.22%\n", + "Epoch: 72\t Loss: 0.00 \t Test accuracy: 99.25%\n", + "Epoch: 73\t Loss: 0.00 \t Test accuracy: 99.32%\n", + "Epoch: 74\t Loss: 0.00 \t Test accuracy: 99.26%\n", + "Epoch: 75\t Loss: 0.00 \t Test accuracy: 99.21%\n", + "Epoch: 76\t Loss: 0.00 \t Test accuracy: 99.22%\n", + "Epoch: 77\t Loss: 0.01 \t Test accuracy: 99.24%\n", + "Epoch: 78\t Loss: 0.00 \t Test accuracy: 99.25%\n", + "Epoch: 79\t Loss: 0.00 \t Test accuracy: 99.31%\n", + "Epoch: 80\t Loss: 0.00 \t Test accuracy: 99.18%\n", + "Epoch: 81\t Loss: 0.00 \t Test accuracy: 99.31%\n", + "Epoch: 82\t Loss: 0.00 \t Test accuracy: 99.28%\n", + "Epoch: 83\t Loss: 0.00 \t Test accuracy: 99.18%\n", + "Epoch: 84\t Loss: 0.01 \t Test accuracy: 99.24%\n", + "Epoch: 85\t Loss: 0.00 \t Test accuracy: 99.24%\n", + "Epoch: 86\t Loss: 0.00 \t Test accuracy: 99.29%\n", + "Epoch: 87\t Loss: 0.00 \t Test accuracy: 99.33%\n", + "Epoch: 88\t Loss: 0.00 \t Test accuracy: 99.28%\n", + "Epoch: 89\t Loss: 0.00 \t Test accuracy: 99.24%\n", + "Epoch: 90\t Loss: 0.00 \t Test accuracy: 99.35%\n", + "Epoch: 91\t Loss: 0.00 \t Test accuracy: 99.22%\n", + "Epoch: 92\t Loss: 0.00 \t Test accuracy: 99.23%\n", + "Epoch: 93\t Loss: 0.00 \t Test accuracy: 99.22%\n", + "Epoch: 94\t Loss: 0.00 \t Test accuracy: 99.25%\n", + "Epoch: 95\t Loss: 0.02 \t Test accuracy: 99.25%\n", + "Epoch: 96\t Loss: 0.00 \t Test accuracy: 99.26%\n", + "Epoch: 97\t Loss: 0.00 \t Test accuracy: 99.29%\n", + "Epoch: 98\t Loss: 0.00 \t Test accuracy: 99.23%\n", + "Epoch: 99\t Loss: 0.00 \t Test accuracy: 99.20%\n", + "Best test accuracy: 99.35%\n", + "Taken time:\t1.97 mins\n" + ], + "name": "stdout" + } + ] + } + ] +} \ No newline at end of file