Created MNIST CNN colab notebook

Sugarme 2020-11-13 12:47:07 +11:00
parent 8296fa1a40
commit 6056096eb1

@@ -0,0 +1,496 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "GoTch-MNIST-CNN.ipynb",
"provenance": [],
"collapsed_sections": [
"fsitX3NbLbTB",
"l6aOKEMarRNH",
"mWe6_MtPK8Kh"
],
"toc_visible": true,
"include_colab_link": true
},
"kernelspec": {
"display_name": "Go",
"name": "gophernotes"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/sugarme/gotch/blob/master/example/mnist/GoTch_MNIST_CNN.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dkq9izrfMXPV"
},
"source": [
"# MNIST CNN Training Using GoTch - PyTorch C++ API Go Bindings\n",
"\n",
"This notebook uses:\n",
"\n",
"1. [GoTch - Go bindings for the PyTorch C++ API](https://github.com/sugarme/gotch)\n",
"2. [GopherNotes - Go kernel for Jupyter Notebook](https://github.com/gopherdata/gophernotes)\n",
"3. [MNIST dataset](http://yann.lecun.com/exdb/mnist/)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fsitX3NbLbTB"
},
"source": [
"## Install Go kernel - GopherNotes\n",
"\n",
"*NOTE: refresh/reload the browser page after this step.*"
]
},
{
"cell_type": "code",
"metadata": {
"id": "caT1iMfshw62",
"outputId": "5be1cd7d-c12e-4083-d523-a035a9fc06a6",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"# Run this cell first, using the Python runtime.\n",
"!add-apt-repository ppa:longsleep/golang-backports -y > /dev/null\n",
"!apt update > /dev/null \n",
"!apt install golang-go > /dev/null\n",
"%env GOPATH=/root/go\n",
"!go get -u github.com/gopherdata/gophernotes\n",
"!cp ~/go/bin/gophernotes /usr/bin/\n",
"!mkdir -p /usr/local/share/jupyter/kernels/gophernotes\n",
"!cp ~/go/src/github.com/gopherdata/gophernotes/kernel/* \\\n",
" /usr/local/share/jupyter/kernels/gophernotes\n",
"# Then refresh the browser page; the notebook will now use the gophernotes kernel. The later cells are written in Go."
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"\n",
"WARNING: apt does not have a stable CLI interface. Use with caution in scripts.\n",
"\n",
"\n",
"WARNING: apt does not have a stable CLI interface. Use with caution in scripts.\n",
"\n",
"env: GOPATH=/root/go\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "l6aOKEMarRNH"
},
"source": [
"## Install the PyTorch C++ API (Libtorch) and its Go binding - GoTch\n",
"\n",
"NOTE: `ldconfig` from the current GLIBC version (2.27) is broken when linking the Libtorch library.\n",
"\n",
"See issue: https://discuss.pytorch.org/libtorch-c-so-files-truncated-error-when-ldconfig/46404/6\n",
"\n",
"Google Colab default settings:\n",
"```bash\n",
"LD_LIBRARY_PATH=/usr/lib64-nvidia\n",
"LIBRARY_PATH=/usr/local/cuda/lib64/stubs\n",
"```\n",
"As a workaround, we copy the contents of `libtorch/lib` directly into those paths."
]
},
{
"cell_type": "code",
"metadata": {
"id": "0N_jGbVft2Y7",
"outputId": "e2052bf1-862f-49e7-ebed-23f09f7fb9e4",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"$wget -q --show-progress --progress=bar:force:noscroll -O /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip https://download.pytorch.org/libtorch/cu101/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip\n",
"$unzip -qq /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip -d /usr/local\n",
"$unzip -qq -j /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip libtorch/lib/* -d /usr/lib64-nvidia/\n",
"$unzip -qq -j /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip libtorch/lib/* -d /usr/local/cuda/lib64/stubs/"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"/tmp/libtorch-cxx11 100%[===================>] 765.61M 22.4MB/s in 29s \n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "dVo3An-pGzFE"
},
"source": [
"import(\"os\")\n",
"os.Setenv(\"CPATH\", \"/usr/local/libtorch/lib:/usr/local/libtorch/include:/usr/local/libtorch/include/torch/csrc/api/include\")"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "x_9y6CpZGlPJ",
"outputId": "79b538ff-9320-4097-8e76-43160253a60d",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"$rm -f -- go.mod\n",
"$go mod init github.com/sugarme/playgo\n",
"$go get github.com/sugarme/gotch@v0.3.2"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"go: creating new go.mod: module github.com/sugarme/playgo\n",
"go: downloading github.com/sugarme/gotch v0.3.2\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mWe6_MtPK8Kh"
},
"source": [
"## Download MNIST dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "sYyKlDWchSv3",
"outputId": "ba7b2bb6-f42f-4dbd-c96e-633f4b66e0f1",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
"$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
"$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
"$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
"\n",
"$gunzip train-images-idx3-ubyte.gz\n",
"$gunzip train-labels-idx1-ubyte.gz\n",
"$gunzip t10k-images-idx3-ubyte.gz\n",
"$gunzip t10k-labels-idx1-ubyte.gz"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"train-images-idx3-u 100%[===================>] 9.45M 11.2MB/s in 0.8s \n",
"train-labels-idx1-u 100%[===================>] 28.20K --.-KB/s in 0.06s \n",
"t10k-images-idx3-ub 100%[===================>] 1.57M 2.03MB/s in 0.8s \n",
"t10k-labels-idx1-ub 100%[===================>] 4.44K --.-KB/s in 0s \n"
],
"name": "stderr"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7MzMZDHiLJQR"
},
"source": [
"## Create Convolutional Neural Network (CNN)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "YkLmNrswJHYY"
},
"source": [
"import(\n",
" \"fmt\"\n",
" \"time\"\n",
"\n",
" \"github.com/sugarme/gotch\"\n",
" \"github.com/sugarme/gotch/nn\"\n",
" ts \"github.com/sugarme/gotch/tensor\"\n",
" \"github.com/sugarme/gotch/vision\"\n",
") \n",
"\n",
"const (\n",
" MnistDirCNN string = \"./\"\n",
" epochsCNN = 100\n",
" batchSize = 256 // number of samples per training batch\n",
" LrCNN = 1e-4\n",
")\n",
"\n",
"type Net struct {\n",
" conv1 *nn.Conv2D\n",
" conv2 *nn.Conv2D\n",
" fc1 *nn.Linear\n",
" fc2 *nn.Linear\n",
"}\n",
"\n",
"func newNet(vs *nn.Path) Net {\n",
" conv1 := nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig())\n",
" conv2 := nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig())\n",
" fc1 := nn.NewLinear(vs, 1024, 1024, nn.DefaultLinearConfig())\n",
" fc2 := nn.NewLinear(vs, 1024, 10, nn.DefaultLinearConfig())\n",
"\n",
" return Net{conv1,conv2,fc1,fc2}\n",
"}\n",
"\n",
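"func (n Net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {\n",
" // Reshape the input to [batch, 1, 28, 28].\n",
" outView1 := xs.MustView([]int64{-1, 1, 28, 28}, false)\n",
" defer outView1.MustDrop()\n",
" // conv1 (5x5): [batch, 32, 24, 24]; max-pool (2x2): [batch, 32, 12, 12].\n",
" outC1 := outView1.Apply(n.conv1)\n",
" outMP1 := outC1.MaxPool2DDefault(2, true)\n",
" defer outMP1.MustDrop()\n",
" // conv2 (5x5): [batch, 64, 8, 8]; max-pool (2x2): [batch, 64, 4, 4].\n",
" outC2 := outMP1.Apply(n.conv2)\n",
" outMP2 := outC2.MaxPool2DDefault(2, true)\n",
" // Flatten to [batch, 1024] (64 * 4 * 4).\n",
" outView2 := outMP2.MustView([]int64{-1, 1024}, true)\n",
" defer outView2.MustDrop()\n",
" // Fully connected layer, ReLU, then dropout (p = 0.5, applied only when train is true).\n",
" outFC1 := outView2.Apply(n.fc1)\n",
" outRelu := outFC1.MustRelu(true)\n",
" defer outRelu.MustDrop()\n",
" outDropout := ts.MustDropout(outRelu, 0.5, train)\n",
" defer outDropout.MustDrop()\n",
" // Output layer: [batch, 10] class logits.\n",
" return outDropout.Apply(n.fc2)\n",
"}\n",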
"\n",
"func trainCNN(){\n",
" var ds *vision.Dataset\n",
" ds = vision.LoadMNISTDir(MnistDirCNN)\n",
" testImages := ds.TestImages\n",
" testLabels := ds.TestLabels\n",
"\n",
" cuda := gotch.CudaBuilder(0)\n",
" vs := nn.NewVarStore(cuda.CudaIfAvailable())\n",
"\n",
" var cnn Net = newNet(vs.Root())\n",
" opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)\n",
" if err != nil {panic(err)} // stop here: opt is unusable if the optimizer failed to build\n",
"\n",
" var bestAccuracy float64 = 0.0\n",
" startTime := time.Now()\n",
"\n",
" for epoch := 0; epoch < epochsCNN; epoch++ {\n",
" totalSize := ds.TrainImages.MustSize()[0]\n",
" samples := int(totalSize)\n",
" index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)\n",
" imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)\n",
" labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)\n",
"\n",
" batches := samples / batchSize\n",
" batchIndex := 0\n",
" var epocLoss *ts.Tensor\n",
" for i := 0; i < batches; i++ {\n",
" start := batchIndex * batchSize\n",
" size := batchSize\n",
" if samples-start < batchSize {break}\n",
" batchIndex += 1\n",
"\n",
" // Indexing\n",
" narrowIndex := ts.NewNarrow(int64(start), int64(start+size))\n",
" bImages := imagesTs.Idx(narrowIndex)\n",
" bLabels := labelsTs.Idx(narrowIndex)\n",
"\n",
" bImages = bImages.MustTo(vs.Device(), true)\n",
" bLabels = bLabels.MustTo(vs.Device(), true)\n",
"\n",
" logits := cnn.ForwardT(bImages, true)\n",
" loss := logits.CrossEntropyForLogits(bLabels)\n",
"\n",
" opt.BackwardStep(loss)\n",
"\n",
" epocLoss = loss.MustShallowClone()\n",
" epocLoss.Detach_()\n",
"\n",
" bImages.MustDrop()\n",
" bLabels.MustDrop()\n",
" }\n",
"\n",
" testAccuracy := nn.BatchAccuracyForLogits(vs, cnn, testImages, testLabels, vs.Device(), 1024)\n",
" fmt.Printf(\"Epoch: %v\\t Loss: %.2f \\t Test accuracy: %.2f%%\\n\", epoch, epocLoss.Float64Values()[0], testAccuracy*100.0)\n",
" if testAccuracy > bestAccuracy {\n",
" bestAccuracy = testAccuracy\n",
" }\n",
"\n",
" epocLoss.MustDrop()\n",
" imagesTs.MustDrop()\n",
" labelsTs.MustDrop()\n",
" }\n",
"\n",
" fmt.Printf(\"Best test accuracy: %.2f%%\\n\", bestAccuracy*100.0)\n",
" fmt.Printf(\"Taken time:\\t%.2f mins\\n\", time.Since(startTime).Minutes())\n",
"}"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "dIwh8dLLOzJ9"
},
"source": [
"## Run training and evaluation"
]
},
{
"cell_type": "code",
"metadata": {
"id": "0AW2nIHjKA9W",
"outputId": "0123b6df-7edc-4f35-b874-2b4f1ec69e54",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"trainCNN()"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch: 0\t Loss: 0.24 \t Test accuracy: 93.34%\n",
"Epoch: 1\t Loss: 0.19 \t Test accuracy: 96.02%\n",
"Epoch: 2\t Loss: 0.09 \t Test accuracy: 97.10%\n",
"Epoch: 3\t Loss: 0.10 \t Test accuracy: 97.66%\n",
"Epoch: 4\t Loss: 0.03 \t Test accuracy: 98.13%\n",
"Epoch: 5\t Loss: 0.05 \t Test accuracy: 98.43%\n",
"Epoch: 6\t Loss: 0.09 \t Test accuracy: 98.60%\n",
"Epoch: 7\t Loss: 0.05 \t Test accuracy: 98.80%\n",
"Epoch: 8\t Loss: 0.03 \t Test accuracy: 98.80%\n",
"Epoch: 9\t Loss: 0.05 \t Test accuracy: 98.89%\n",
"Epoch: 10\t Loss: 0.02 \t Test accuracy: 98.88%\n",
"Epoch: 11\t Loss: 0.03 \t Test accuracy: 98.98%\n",
"Epoch: 12\t Loss: 0.03 \t Test accuracy: 99.05%\n",
"Epoch: 13\t Loss: 0.04 \t Test accuracy: 99.06%\n",
"Epoch: 14\t Loss: 0.02 \t Test accuracy: 99.07%\n",
"Epoch: 15\t Loss: 0.02 \t Test accuracy: 98.98%\n",
"Epoch: 16\t Loss: 0.02 \t Test accuracy: 99.06%\n",
"Epoch: 17\t Loss: 0.01 \t Test accuracy: 99.09%\n",
"Epoch: 18\t Loss: 0.02 \t Test accuracy: 99.14%\n",
"Epoch: 19\t Loss: 0.01 \t Test accuracy: 99.09%\n",
"Epoch: 20\t Loss: 0.02 \t Test accuracy: 99.12%\n",
"Epoch: 21\t Loss: 0.03 \t Test accuracy: 99.13%\n",
"Epoch: 22\t Loss: 0.02 \t Test accuracy: 99.11%\n",
"Epoch: 23\t Loss: 0.01 \t Test accuracy: 99.10%\n",
"Epoch: 24\t Loss: 0.01 \t Test accuracy: 99.15%\n",
"Epoch: 25\t Loss: 0.01 \t Test accuracy: 99.15%\n",
"Epoch: 26\t Loss: 0.01 \t Test accuracy: 99.22%\n",
"Epoch: 27\t Loss: 0.00 \t Test accuracy: 99.27%\n",
"Epoch: 28\t Loss: 0.01 \t Test accuracy: 99.15%\n",
"Epoch: 29\t Loss: 0.00 \t Test accuracy: 99.09%\n",
"Epoch: 30\t Loss: 0.01 \t Test accuracy: 99.15%\n",
"Epoch: 31\t Loss: 0.00 \t Test accuracy: 99.18%\n",
"Epoch: 32\t Loss: 0.00 \t Test accuracy: 99.15%\n",
"Epoch: 33\t Loss: 0.00 \t Test accuracy: 99.23%\n",
"Epoch: 34\t Loss: 0.01 \t Test accuracy: 99.22%\n",
"Epoch: 35\t Loss: 0.00 \t Test accuracy: 99.16%\n",
"Epoch: 36\t Loss: 0.00 \t Test accuracy: 99.22%\n",
"Epoch: 37\t Loss: 0.01 \t Test accuracy: 99.26%\n",
"Epoch: 38\t Loss: 0.00 \t Test accuracy: 99.15%\n",
"Epoch: 39\t Loss: 0.01 \t Test accuracy: 99.19%\n",
"Epoch: 40\t Loss: 0.01 \t Test accuracy: 99.24%\n",
"Epoch: 41\t Loss: 0.01 \t Test accuracy: 99.24%\n",
"Epoch: 42\t Loss: 0.01 \t Test accuracy: 99.22%\n",
"Epoch: 43\t Loss: 0.00 \t Test accuracy: 99.22%\n",
"Epoch: 44\t Loss: 0.00 \t Test accuracy: 99.14%\n",
"Epoch: 45\t Loss: 0.00 \t Test accuracy: 99.23%\n",
"Epoch: 46\t Loss: 0.00 \t Test accuracy: 99.24%\n",
"Epoch: 47\t Loss: 0.00 \t Test accuracy: 99.21%\n",
"Epoch: 48\t Loss: 0.00 \t Test accuracy: 99.25%\n",
"Epoch: 49\t Loss: 0.00 \t Test accuracy: 99.21%\n",
"Epoch: 50\t Loss: 0.01 \t Test accuracy: 99.26%\n",
"Epoch: 51\t Loss: 0.01 \t Test accuracy: 99.14%\n",
"Epoch: 52\t Loss: 0.01 \t Test accuracy: 99.26%\n",
"Epoch: 53\t Loss: 0.00 \t Test accuracy: 99.23%\n",
"Epoch: 54\t Loss: 0.00 \t Test accuracy: 99.28%\n",
"Epoch: 55\t Loss: 0.02 \t Test accuracy: 99.24%\n",
"Epoch: 56\t Loss: 0.00 \t Test accuracy: 99.26%\n",
"Epoch: 57\t Loss: 0.00 \t Test accuracy: 99.25%\n",
"Epoch: 58\t Loss: 0.01 \t Test accuracy: 99.22%\n",
"Epoch: 59\t Loss: 0.01 \t Test accuracy: 99.25%\n",
"Epoch: 60\t Loss: 0.00 \t Test accuracy: 99.26%\n",
"Epoch: 61\t Loss: 0.01 \t Test accuracy: 99.24%\n",
"Epoch: 62\t Loss: 0.00 \t Test accuracy: 99.25%\n",
"Epoch: 63\t Loss: 0.00 \t Test accuracy: 99.24%\n",
"Epoch: 64\t Loss: 0.00 \t Test accuracy: 99.27%\n",
"Epoch: 65\t Loss: 0.00 \t Test accuracy: 99.27%\n",
"Epoch: 66\t Loss: 0.00 \t Test accuracy: 99.31%\n",
"Epoch: 67\t Loss: 0.00 \t Test accuracy: 99.25%\n",
"Epoch: 68\t Loss: 0.00 \t Test accuracy: 99.25%\n",
"Epoch: 69\t Loss: 0.00 \t Test accuracy: 99.31%\n",
"Epoch: 70\t Loss: 0.00 \t Test accuracy: 99.24%\n",
"Epoch: 71\t Loss: 0.00 \t Test accuracy: 99.22%\n",
"Epoch: 72\t Loss: 0.00 \t Test accuracy: 99.25%\n",
"Epoch: 73\t Loss: 0.00 \t Test accuracy: 99.32%\n",
"Epoch: 74\t Loss: 0.00 \t Test accuracy: 99.26%\n",
"Epoch: 75\t Loss: 0.00 \t Test accuracy: 99.21%\n",
"Epoch: 76\t Loss: 0.00 \t Test accuracy: 99.22%\n",
"Epoch: 77\t Loss: 0.01 \t Test accuracy: 99.24%\n",
"Epoch: 78\t Loss: 0.00 \t Test accuracy: 99.25%\n",
"Epoch: 79\t Loss: 0.00 \t Test accuracy: 99.31%\n",
"Epoch: 80\t Loss: 0.00 \t Test accuracy: 99.18%\n",
"Epoch: 81\t Loss: 0.00 \t Test accuracy: 99.31%\n",
"Epoch: 82\t Loss: 0.00 \t Test accuracy: 99.28%\n",
"Epoch: 83\t Loss: 0.00 \t Test accuracy: 99.18%\n",
"Epoch: 84\t Loss: 0.01 \t Test accuracy: 99.24%\n",
"Epoch: 85\t Loss: 0.00 \t Test accuracy: 99.24%\n",
"Epoch: 86\t Loss: 0.00 \t Test accuracy: 99.29%\n",
"Epoch: 87\t Loss: 0.00 \t Test accuracy: 99.33%\n",
"Epoch: 88\t Loss: 0.00 \t Test accuracy: 99.28%\n",
"Epoch: 89\t Loss: 0.00 \t Test accuracy: 99.24%\n",
"Epoch: 90\t Loss: 0.00 \t Test accuracy: 99.35%\n",
"Epoch: 91\t Loss: 0.00 \t Test accuracy: 99.22%\n",
"Epoch: 92\t Loss: 0.00 \t Test accuracy: 99.23%\n",
"Epoch: 93\t Loss: 0.00 \t Test accuracy: 99.22%\n",
"Epoch: 94\t Loss: 0.00 \t Test accuracy: 99.25%\n",
"Epoch: 95\t Loss: 0.02 \t Test accuracy: 99.25%\n",
"Epoch: 96\t Loss: 0.00 \t Test accuracy: 99.26%\n",
"Epoch: 97\t Loss: 0.00 \t Test accuracy: 99.29%\n",
"Epoch: 98\t Loss: 0.00 \t Test accuracy: 99.23%\n",
"Epoch: 99\t Loss: 0.00 \t Test accuracy: 99.20%\n",
"Best test accuracy: 99.35%\n",
"Taken time:\t1.97 mins\n"
],
"name": "stdout"
}
]
}
]
}