diff --git a/README.md b/README.md
index cd3f44f..4d8d80d 100644
--- a/README.md
+++ b/README.md
@@ -53,7 +53,6 @@ func basicOps() {
fmt.Printf("%8.3f\n", xs)
fmt.Printf("%i", xs)
- // output
/*
(1,.,.) =
0.391 0.055 0.638 0.514 0.757 0.446
@@ -93,46 +92,42 @@ func basicOps() {
mul := ts1.MustMatmul(ts2, false)
defer mul.MustDrop()
- fmt.Println("ts1: ")
- ts1.Print()
- fmt.Println("ts2: ")
- ts2.Print()
- fmt.Println("mul tensor (ts1 x ts2): ")
- mul.Print()
- //ts1:
- // 0 1 2
- // 3 4 5
- //[ CPULongType{2,3} ]
- //ts2:
- // 1 1 1 1
- // 1 1 1 1
- // 1 1 1 1
- //[ CPULongType{3,4} ]
- //mul tensor (ts1 x ts2):
- // 3 3 3 3
- // 12 12 12 12
- //[ CPULongType{2,4} ]
+ fmt.Printf("ts1:\n%2d", ts1)
+ fmt.Printf("ts2:\n%2d", ts2)
+ fmt.Printf("mul tensor (ts1 x ts2):\n%2d", mul)
+
+ /*
+ ts1:
+ 0 1 2
+ 3 4 5
+
+ ts2:
+ 1 1 1 1
+ 1 1 1 1
+ 1 1 1 1
+
+ mul tensor (ts1 x ts2):
+ 3 3 3 3
+ 12 12 12 12
+ */
// In-place operation
ts3 := ts.MustOnes([]int64{2, 3}, gotch.Float, gotch.CPU)
- fmt.Println("Before:")
- ts3.Print()
+ fmt.Printf("Before:\n%v", ts3)
ts3.MustAdd1_(ts.FloatScalar(2.0))
- fmt.Printf("After (ts3 + 2.0): \n")
- ts3.Print()
- ts3.MustDrop()
+ fmt.Printf("After (ts3 + 2.0):\n%v", ts3)
- //Before:
- // 1 1 1
- // 1 1 1
- //[ CPUFloatType{2,3} ]
- //After (ts3 + 2.0):
- // 3 3 3
- // 3 3 3
- //[ CPUFloatType{2,3} ]
+ /*
+ Before:
+ 1 1 1
+ 1 1 1
+ After (ts3 + 2.0):
+ 3 3 3
+ 3 3 3
+ */
}
```
@@ -142,30 +137,32 @@ func basicOps() {
```go
import (
+ "fmt"
+
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
ts "github.com/sugarme/gotch/tensor"
)
type Net struct {
- conv1 nn.Conv2D
- conv2 nn.Conv2D
- fc nn.Linear
+ conv1 *nn.Conv2D
+ conv2 *nn.Conv2D
+ fc *nn.Linear
}
- func newNet(vs nn.Path) Net {
+ func newNet(vs *nn.Path) *Net {
conv1 := nn.NewConv2D(vs, 1, 16, 2, nn.DefaultConv2DConfig())
conv2 := nn.NewConv2D(vs, 16, 10, 2, nn.DefaultConv2DConfig())
fc := nn.NewLinear(vs, 10, 10, nn.DefaultLinearConfig())
- return Net{
+ return &Net{
conv1,
conv2,
fc,
}
}
- func (n Net) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
+ func (n Net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
xs = xs.MustView([]int64{-1, 1, 8, 8}, false)
outC1 := xs.Apply(n.conv1)
@@ -177,10 +174,8 @@ func basicOps() {
outView2 := outMP2.MustView([]int64{-1, 10}, true)
defer outView2.MustDrop()
- outFC := outView2.Apply(&n.fc)
-
+ outFC := outView2.Apply(n.fc)
return outFC.MustRelu(true)
-
}
func main() {
@@ -191,29 +186,33 @@ func basicOps() {
xs := ts.MustOnes([]int64{8, 8}, gotch.Float, gotch.CPU)
logits := net.ForwardT(xs, false)
- logits.Print()
+ fmt.Printf("Logits: %0.3f", logits)
}
- // 0.0000 0.0000 0.0000 0.2477 0.2437 0.0000 0.0000 0.0000 0.0000 0.0171
- //[ CPUFloatType{1,10} ]
-
-
+ //Logits: 0.000 0.000 0.000 0.225 0.321 0.147 0.000 0.207 0.000 0.000
```
- Real application examples can be found at [example folder](example/README.md)
+## Play with GoTch on Google Colab
+
+1. [Tensor Initiation](tensor/tensor-initiation.ipynb)
+2. [Tensor Indexing](tensor/tensor-indexing.ipynb)
+3. [MNIST](mnist)
+4. [Tokenizer - BPE model](tokenizer/bpe.ipynb)
+5. [transformer - BERT Mask Language Model](transformer/bert-mask-lm.ipynb)
+6. [YOLO v3 model infering](yolo/yolo.ipynb)
+
+More coming soon...
+
## Getting Started
-- [Documentations](docs/README.md)
-
- See [pkg.go.dev](https://pkg.go.dev/github.com/sugarme/gotch?tab=doc) for detail APIs
-
## License
GoTch is Apache 2.0 licensed.
-
## Acknowledgement
- This project has been inspired and used many concepts from [tch-rs](https://github.com/LaurentMazare/tch-rs)
diff --git a/example/mnist/GoTch_MNIST_CNN.ipynb b/example/mnist/GoTch_MNIST_CNN.ipynb
deleted file mode 100644
index c7521b1..0000000
--- a/example/mnist/GoTch_MNIST_CNN.ipynb
+++ /dev/null
@@ -1,496 +0,0 @@
-{
- "nbformat": 4,
- "nbformat_minor": 0,
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "name": "GoTch-MNIST-CNN.ipynb",
- "provenance": [],
- "collapsed_sections": [
- "fsitX3NbLbTB",
- "l6aOKEMarRNH",
- "mWe6_MtPK8Kh"
- ],
- "toc_visible": true,
- "include_colab_link": true
- },
- "kernelspec": {
- "display_name": "Go",
- "name": "gophernotes"
- }
- },
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "view-in-github",
- "colab_type": "text"
- },
- "source": [
- "
"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "dkq9izrfMXPV"
- },
- "source": [
- "# MNIST CNN Training Using GoTch - Pytorch C++ APIs Go Binding\n",
- "\n",
- "This notebook using\n",
- "\n",
- "1. [GoTch - Pytorch C++ APIs Go bindind](https://github.com/sugarme/gotch)\n",
- "2. [GopherNotes - Jupyter Notebook Go kernel](https://github.com/gopherdata/gophernotes)\n",
- "3. [MNIST dataset](http://yann.lecun.com/exdb/mnist/)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "fsitX3NbLbTB"
- },
- "source": [
- "## Install Go kernel - GopherNotes\n",
- "\n",
- "*NOTE: refresh/reload (browser) after this step.*"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "caT1iMfshw62",
- "outputId": "5be1cd7d-c12e-4083-d523-a035a9fc06a6",
- "colab": {
- "base_uri": "https://localhost:8080/"
- }
- },
- "source": [
- "# run this cell first time using python runtime\n",
- "!add-apt-repository ppa:longsleep/golang-backports -y > /dev/null\n",
- "!apt update > /dev/null \n",
- "!apt install golang-go > /dev/null\n",
- "%env GOPATH=/root/go\n",
- "!go get -u github.com/gopherdata/gophernotes\n",
- "!cp ~/go/bin/gophernotes /usr/bin/\n",
- "!mkdir /usr/local/share/jupyter/kernels/gophernotes\n",
- "!cp ~/go/src/github.com/gopherdata/gophernotes/kernel/* \\\n",
- " /usr/local/share/jupyter/kernels/gophernotes\n",
- "# then refresh (browser), it will now use gophernotes. Skip to golang in later cells"
- ],
- "execution_count": null,
- "outputs": [
- {
- "output_type": "stream",
- "text": [
- "\n",
- "WARNING: apt does not have a stable CLI interface. Use with caution in scripts.\n",
- "\n",
- "\n",
- "WARNING: apt does not have a stable CLI interface. Use with caution in scripts.\n",
- "\n",
- "env: GOPATH=/root/go\n"
- ],
- "name": "stdout"
- }
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "l6aOKEMarRNH"
- },
- "source": [
- "## Install Pytorch C++ APIs and Go binding - GoTch\n",
- "\n",
- "NOTE: `ldconfig` (GLIBC) current version 2.27 is currently broken when linking Libtorch library\n",
- "\n",
- "see issue: https://discuss.pytorch.org/libtorch-c-so-files-truncated-error-when-ldconfig/46404/6\n",
- "\n",
- "Google Colab default settings:\n",
- "```bash\n",
- "LD_LIBRARY_PATH=/usr/lib64-nvidia\n",
- "LIBRARY_PATH=/usr/local/cuda/lib64/stubs\n",
- "```\n",
- "We copy directly `libtorch/lib` to those paths as a hacky way. "
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "0N_jGbVft2Y7",
- "outputId": "e2052bf1-862f-49e7-ebed-23f09f7fb9e4",
- "colab": {
- "base_uri": "https://localhost:8080/"
- }
- },
- "source": [
- "$wget -q --show-progress --progress=bar:force:noscroll -O /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip https://download.pytorch.org/libtorch/cu101/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip\n",
- "$unzip -qq /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip -d /usr/local\n",
- "$unzip -qq -j /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip libtorch/lib/* -d /usr/lib64-nvidia/\n",
- "$unzip -qq -j /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip libtorch/lib/* -d /usr/local/cuda/lib64/stubs/"
- ],
- "execution_count": 1,
- "outputs": [
- {
- "output_type": "stream",
- "text": [
- "/tmp/libtorch-cxx11 100%[===================>] 765.61M 22.4MB/s in 29s \n"
- ],
- "name": "stderr"
- }
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "dVo3An-pGzFE"
- },
- "source": [
- "import(\"os\")\n",
- "os.Setenv(\"CPATH\", \"usr/local/libtorch/lib:/usr/local/libtorch/include:/usr/local/libtorch/include/torch/csrc/api/include\")"
- ],
- "execution_count": 2,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "x_9y6CpZGlPJ",
- "outputId": "79b538ff-9320-4097-8e76-43160253a60d",
- "colab": {
- "base_uri": "https://localhost:8080/"
- }
- },
- "source": [
- "$rm -f -- go.mod\n",
- "$go mod init github.com/sugarme/playgo\n",
- "$go get github.com/sugarme/gotch@v0.3.2"
- ],
- "execution_count": 4,
- "outputs": [
- {
- "output_type": "stream",
- "text": [
- "go: creating new go.mod: module github.com/sugarme/playgo\n",
- "go: downloading github.com/sugarme/gotch v0.3.2\n"
- ],
- "name": "stderr"
- }
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "mWe6_MtPK8Kh"
- },
- "source": [
- "## Download MNIST dataset"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "sYyKlDWchSv3",
- "outputId": "ba7b2bb6-f42f-4dbd-c96e-633f4b66e0f1",
- "colab": {
- "base_uri": "https://localhost:8080/"
- }
- },
- "source": [
- "$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
- "$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
- "$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
- "$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
- "\n",
- "$gunzip train-images-idx3-ubyte.gz\n",
- "$gunzip train-labels-idx1-ubyte.gz\n",
- "$gunzip t10k-images-idx3-ubyte.gz\n",
- "$gunzip t10k-labels-idx1-ubyte.gz"
- ],
- "execution_count": 5,
- "outputs": [
- {
- "output_type": "stream",
- "text": [
- "train-images-idx3-u 100%[===================>] 9.45M 11.2MB/s in 0.8s \n",
- "train-labels-idx1-u 100%[===================>] 28.20K --.-KB/s in 0.06s \n",
- "t10k-images-idx3-ub 100%[===================>] 1.57M 2.03MB/s in 0.8s \n",
- "t10k-labels-idx1-ub 100%[===================>] 4.44K --.-KB/s in 0s \n"
- ],
- "name": "stderr"
- }
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "7MzMZDHiLJQR"
- },
- "source": [
- "## Create Convolution Neural Network (CNN)"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "YkLmNrswJHYY"
- },
- "source": [
- "import(\n",
- " \"fmt\"\n",
- " \"time\"\n",
- "\n",
- " \"github.com/sugarme/gotch\"\n",
- " \"github.com/sugarme/gotch/nn\"\n",
- " ts \"github.com/sugarme/gotch/tensor\"\n",
- " \"github.com/sugarme/gotch/vision\"\n",
- ") \n",
- "\n",
- "const (\n",
- " MnistDirCNN string = \"./\"\n",
- " epochsCNN = 100\n",
- " batchCNN = 256\n",
- " batchSize = 256\n",
- " LrCNN = 1e-4\n",
- ")\n",
- "\n",
- "type Net struct {\n",
- " conv1 *nn.Conv2D\n",
- " conv2 *nn.Conv2D\n",
- " fc1 *nn.Linear\n",
- " fc2 *nn.Linear\n",
- "}\n",
- "\n",
- "func newNet(vs *nn.Path) Net {\n",
- " conv1 := nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig())\n",
- " conv2 := nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig())\n",
- " fc1 := nn.NewLinear(vs, 1024, 1024, nn.DefaultLinearConfig())\n",
- " fc2 := nn.NewLinear(vs, 1024, 10, nn.DefaultLinearConfig())\n",
- "\n",
- " return Net{conv1,conv2,fc1,fc2}\n",
- "}\n",
- "\n",
- "func (n Net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {\n",
- " outView1 := xs.MustView([]int64{-1, 1, 28, 28}, false)\n",
- " defer outView1.MustDrop()\n",
- " outC1 := outView1.Apply(n.conv1)\n",
- " outMP1 := outC1.MaxPool2DDefault(2, true)\n",
- " defer outMP1.MustDrop()\n",
- " outC2 := outMP1.Apply(n.conv2)\n",
- " outMP2 := outC2.MaxPool2DDefault(2, true)\n",
- " outView2 := outMP2.MustView([]int64{-1, 1024}, true)\n",
- " defer outView2.MustDrop()\n",
- " outFC1 := outView2.Apply(n.fc1)\n",
- " outRelu := outFC1.MustRelu(true)\n",
- " defer outRelu.MustDrop()\n",
- " outDropout := ts.MustDropout(outRelu, 0.5, train)\n",
- " defer outDropout.MustDrop()\n",
- " return outDropout.Apply(n.fc2)\n",
- "}\n",
- "\n",
- "func trainCNN(){\n",
- " var ds *vision.Dataset\n",
- " ds = vision.LoadMNISTDir(MnistDirCNN)\n",
- " testImages := ds.TestImages\n",
- " testLabels := ds.TestLabels\n",
- "\n",
- " cuda := gotch.CudaBuilder(0)\n",
- " vs := nn.NewVarStore(cuda.CudaIfAvailable())\n",
- "\n",
- " var cnn Net = newNet(vs.Root())\n",
- " opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)\n",
- " if err != nil {fmt.Print(err)}\n",
- "\n",
- " var bestAccuracy float64 = 0.0\n",
- " startTime := time.Now()\n",
- "\n",
- " for epoch := 0; epoch < epochsCNN; epoch++ {\n",
- " totalSize := ds.TrainImages.MustSize()[0]\n",
- " samples := int(totalSize)\n",
- " index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)\n",
- " imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)\n",
- " labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)\n",
- "\n",
- " batches := samples / batchSize\n",
- " batchIndex := 0\n",
- " var epocLoss *ts.Tensor\n",
- " for i := 0; i < batches; i++ {\n",
- " start := batchIndex * batchSize\n",
- " size := batchSize\n",
- " if samples-start < batchSize {break}\n",
- " batchIndex += 1\n",
- "\n",
- " // Indexing\n",
- " narrowIndex := ts.NewNarrow(int64(start), int64(start+size))\n",
- " bImages := imagesTs.Idx(narrowIndex)\n",
- " bLabels := labelsTs.Idx(narrowIndex)\n",
- "\n",
- " bImages = bImages.MustTo(vs.Device(), true)\n",
- " bLabels = bLabels.MustTo(vs.Device(), true)\n",
- "\n",
- " logits := cnn.ForwardT(bImages, true)\n",
- " loss := logits.CrossEntropyForLogits(bLabels)\n",
- "\n",
- " opt.BackwardStep(loss)\n",
- "\n",
- " epocLoss = loss.MustShallowClone()\n",
- " epocLoss.Detach_()\n",
- "\n",
- " bImages.MustDrop()\n",
- " bLabels.MustDrop()\n",
- " }\n",
- "\n",
- " testAccuracy := nn.BatchAccuracyForLogits(vs, cnn, testImages, testLabels, vs.Device(), 1024)\n",
- " fmt.Printf(\"Epoch: %v\\t Loss: %.2f \\t Test accuracy: %.2f%%\\n\", epoch, epocLoss.Float64Values()[0], testAccuracy*100.0)\n",
- " if testAccuracy > bestAccuracy {\n",
- " bestAccuracy = testAccuracy\n",
- " }\n",
- "\n",
- " epocLoss.MustDrop()\n",
- " imagesTs.MustDrop()\n",
- " labelsTs.MustDrop()\n",
- " }\n",
- "\n",
- " fmt.Printf(\"Best test accuracy: %.2f%%\\n\", bestAccuracy*100.0)\n",
- " fmt.Printf(\"Taken time:\\t%.2f mins\\n\", time.Since(startTime).Minutes())\n",
- "}"
- ],
- "execution_count": 6,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "dIwh8dLLOzJ9"
- },
- "source": [
- "## Run train and evaluation"
- ]
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "0AW2nIHjKA9W",
- "outputId": "0123b6df-7edc-4f35-b874-2b4f1ec69e54",
- "colab": {
- "base_uri": "https://localhost:8080/"
- }
- },
- "source": [
- "trainCNN()"
- ],
- "execution_count": 7,
- "outputs": [
- {
- "output_type": "stream",
- "text": [
- "Epoch: 0\t Loss: 0.24 \t Test accuracy: 93.34%\n",
- "Epoch: 1\t Loss: 0.19 \t Test accuracy: 96.02%\n",
- "Epoch: 2\t Loss: 0.09 \t Test accuracy: 97.10%\n",
- "Epoch: 3\t Loss: 0.10 \t Test accuracy: 97.66%\n",
- "Epoch: 4\t Loss: 0.03 \t Test accuracy: 98.13%\n",
- "Epoch: 5\t Loss: 0.05 \t Test accuracy: 98.43%\n",
- "Epoch: 6\t Loss: 0.09 \t Test accuracy: 98.60%\n",
- "Epoch: 7\t Loss: 0.05 \t Test accuracy: 98.80%\n",
- "Epoch: 8\t Loss: 0.03 \t Test accuracy: 98.80%\n",
- "Epoch: 9\t Loss: 0.05 \t Test accuracy: 98.89%\n",
- "Epoch: 10\t Loss: 0.02 \t Test accuracy: 98.88%\n",
- "Epoch: 11\t Loss: 0.03 \t Test accuracy: 98.98%\n",
- "Epoch: 12\t Loss: 0.03 \t Test accuracy: 99.05%\n",
- "Epoch: 13\t Loss: 0.04 \t Test accuracy: 99.06%\n",
- "Epoch: 14\t Loss: 0.02 \t Test accuracy: 99.07%\n",
- "Epoch: 15\t Loss: 0.02 \t Test accuracy: 98.98%\n",
- "Epoch: 16\t Loss: 0.02 \t Test accuracy: 99.06%\n",
- "Epoch: 17\t Loss: 0.01 \t Test accuracy: 99.09%\n",
- "Epoch: 18\t Loss: 0.02 \t Test accuracy: 99.14%\n",
- "Epoch: 19\t Loss: 0.01 \t Test accuracy: 99.09%\n",
- "Epoch: 20\t Loss: 0.02 \t Test accuracy: 99.12%\n",
- "Epoch: 21\t Loss: 0.03 \t Test accuracy: 99.13%\n",
- "Epoch: 22\t Loss: 0.02 \t Test accuracy: 99.11%\n",
- "Epoch: 23\t Loss: 0.01 \t Test accuracy: 99.10%\n",
- "Epoch: 24\t Loss: 0.01 \t Test accuracy: 99.15%\n",
- "Epoch: 25\t Loss: 0.01 \t Test accuracy: 99.15%\n",
- "Epoch: 26\t Loss: 0.01 \t Test accuracy: 99.22%\n",
- "Epoch: 27\t Loss: 0.00 \t Test accuracy: 99.27%\n",
- "Epoch: 28\t Loss: 0.01 \t Test accuracy: 99.15%\n",
- "Epoch: 29\t Loss: 0.00 \t Test accuracy: 99.09%\n",
- "Epoch: 30\t Loss: 0.01 \t Test accuracy: 99.15%\n",
- "Epoch: 31\t Loss: 0.00 \t Test accuracy: 99.18%\n",
- "Epoch: 32\t Loss: 0.00 \t Test accuracy: 99.15%\n",
- "Epoch: 33\t Loss: 0.00 \t Test accuracy: 99.23%\n",
- "Epoch: 34\t Loss: 0.01 \t Test accuracy: 99.22%\n",
- "Epoch: 35\t Loss: 0.00 \t Test accuracy: 99.16%\n",
- "Epoch: 36\t Loss: 0.00 \t Test accuracy: 99.22%\n",
- "Epoch: 37\t Loss: 0.01 \t Test accuracy: 99.26%\n",
- "Epoch: 38\t Loss: 0.00 \t Test accuracy: 99.15%\n",
- "Epoch: 39\t Loss: 0.01 \t Test accuracy: 99.19%\n",
- "Epoch: 40\t Loss: 0.01 \t Test accuracy: 99.24%\n",
- "Epoch: 41\t Loss: 0.01 \t Test accuracy: 99.24%\n",
- "Epoch: 42\t Loss: 0.01 \t Test accuracy: 99.22%\n",
- "Epoch: 43\t Loss: 0.00 \t Test accuracy: 99.22%\n",
- "Epoch: 44\t Loss: 0.00 \t Test accuracy: 99.14%\n",
- "Epoch: 45\t Loss: 0.00 \t Test accuracy: 99.23%\n",
- "Epoch: 46\t Loss: 0.00 \t Test accuracy: 99.24%\n",
- "Epoch: 47\t Loss: 0.00 \t Test accuracy: 99.21%\n",
- "Epoch: 48\t Loss: 0.00 \t Test accuracy: 99.25%\n",
- "Epoch: 49\t Loss: 0.00 \t Test accuracy: 99.21%\n",
- "Epoch: 50\t Loss: 0.01 \t Test accuracy: 99.26%\n",
- "Epoch: 51\t Loss: 0.01 \t Test accuracy: 99.14%\n",
- "Epoch: 52\t Loss: 0.01 \t Test accuracy: 99.26%\n",
- "Epoch: 53\t Loss: 0.00 \t Test accuracy: 99.23%\n",
- "Epoch: 54\t Loss: 0.00 \t Test accuracy: 99.28%\n",
- "Epoch: 55\t Loss: 0.02 \t Test accuracy: 99.24%\n",
- "Epoch: 56\t Loss: 0.00 \t Test accuracy: 99.26%\n",
- "Epoch: 57\t Loss: 0.00 \t Test accuracy: 99.25%\n",
- "Epoch: 58\t Loss: 0.01 \t Test accuracy: 99.22%\n",
- "Epoch: 59\t Loss: 0.01 \t Test accuracy: 99.25%\n",
- "Epoch: 60\t Loss: 0.00 \t Test accuracy: 99.26%\n",
- "Epoch: 61\t Loss: 0.01 \t Test accuracy: 99.24%\n",
- "Epoch: 62\t Loss: 0.00 \t Test accuracy: 99.25%\n",
- "Epoch: 63\t Loss: 0.00 \t Test accuracy: 99.24%\n",
- "Epoch: 64\t Loss: 0.00 \t Test accuracy: 99.27%\n",
- "Epoch: 65\t Loss: 0.00 \t Test accuracy: 99.27%\n",
- "Epoch: 66\t Loss: 0.00 \t Test accuracy: 99.31%\n",
- "Epoch: 67\t Loss: 0.00 \t Test accuracy: 99.25%\n",
- "Epoch: 68\t Loss: 0.00 \t Test accuracy: 99.25%\n",
- "Epoch: 69\t Loss: 0.00 \t Test accuracy: 99.31%\n",
- "Epoch: 70\t Loss: 0.00 \t Test accuracy: 99.24%\n",
- "Epoch: 71\t Loss: 0.00 \t Test accuracy: 99.22%\n",
- "Epoch: 72\t Loss: 0.00 \t Test accuracy: 99.25%\n",
- "Epoch: 73\t Loss: 0.00 \t Test accuracy: 99.32%\n",
- "Epoch: 74\t Loss: 0.00 \t Test accuracy: 99.26%\n",
- "Epoch: 75\t Loss: 0.00 \t Test accuracy: 99.21%\n",
- "Epoch: 76\t Loss: 0.00 \t Test accuracy: 99.22%\n",
- "Epoch: 77\t Loss: 0.01 \t Test accuracy: 99.24%\n",
- "Epoch: 78\t Loss: 0.00 \t Test accuracy: 99.25%\n",
- "Epoch: 79\t Loss: 0.00 \t Test accuracy: 99.31%\n",
- "Epoch: 80\t Loss: 0.00 \t Test accuracy: 99.18%\n",
- "Epoch: 81\t Loss: 0.00 \t Test accuracy: 99.31%\n",
- "Epoch: 82\t Loss: 0.00 \t Test accuracy: 99.28%\n",
- "Epoch: 83\t Loss: 0.00 \t Test accuracy: 99.18%\n",
- "Epoch: 84\t Loss: 0.01 \t Test accuracy: 99.24%\n",
- "Epoch: 85\t Loss: 0.00 \t Test accuracy: 99.24%\n",
- "Epoch: 86\t Loss: 0.00 \t Test accuracy: 99.29%\n",
- "Epoch: 87\t Loss: 0.00 \t Test accuracy: 99.33%\n",
- "Epoch: 88\t Loss: 0.00 \t Test accuracy: 99.28%\n",
- "Epoch: 89\t Loss: 0.00 \t Test accuracy: 99.24%\n",
- "Epoch: 90\t Loss: 0.00 \t Test accuracy: 99.35%\n",
- "Epoch: 91\t Loss: 0.00 \t Test accuracy: 99.22%\n",
- "Epoch: 92\t Loss: 0.00 \t Test accuracy: 99.23%\n",
- "Epoch: 93\t Loss: 0.00 \t Test accuracy: 99.22%\n",
- "Epoch: 94\t Loss: 0.00 \t Test accuracy: 99.25%\n",
- "Epoch: 95\t Loss: 0.02 \t Test accuracy: 99.25%\n",
- "Epoch: 96\t Loss: 0.00 \t Test accuracy: 99.26%\n",
- "Epoch: 97\t Loss: 0.00 \t Test accuracy: 99.29%\n",
- "Epoch: 98\t Loss: 0.00 \t Test accuracy: 99.23%\n",
- "Epoch: 99\t Loss: 0.00 \t Test accuracy: 99.20%\n",
- "Best test accuracy: 99.35%\n",
- "Taken time:\t1.97 mins\n"
- ],
- "name": "stdout"
- }
- ]
- }
- ]
-}
\ No newline at end of file
diff --git a/example/mnist/README.md b/example/mnist/README.md
index 9fd67dd..4502bef 100644
--- a/example/mnist/README.md
+++ b/example/mnist/README.md
@@ -1,5 +1,7 @@
# Linear Regression, NN, and CNN on MNIST dataset
+[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sugarme/nb/blob/master/mnist/mnist.ipynb)
+
## MNIST
- MNIST files can be obtained from [this source](http://yann.lecun.com/exdb/mnist/) and put in `data/mnist` from
@@ -29,8 +31,6 @@
- Accuracy should be about **99.3%**.
-## Run on Cloud
-
-[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sugarme/gotch/blob/master/example/mnist/gotch_mnist.ipynb)
+