updated example code on README and added Colab links

This commit is contained in:
sugarme 2020-11-15 10:32:37 +11:00
parent c0da658e6b
commit 992ff860d4
3 changed files with 53 additions and 550 deletions

101
README.md
View File

@ -53,7 +53,6 @@ func basicOps() {
fmt.Printf("%8.3f\n", xs)
fmt.Printf("%i", xs)
// output
/*
(1,.,.) =
0.391 0.055 0.638 0.514 0.757 0.446
@ -93,46 +92,42 @@ func basicOps() {
mul := ts1.MustMatmul(ts2, false)
defer mul.MustDrop()
fmt.Println("ts1: ")
ts1.Print()
fmt.Println("ts2: ")
ts2.Print()
fmt.Println("mul tensor (ts1 x ts2): ")
mul.Print()
//ts1:
// 0 1 2
// 3 4 5
//[ CPULongType{2,3} ]
//ts2:
// 1 1 1 1
// 1 1 1 1
// 1 1 1 1
//[ CPULongType{3,4} ]
//mul tensor (ts1 x ts2):
// 3 3 3 3
// 12 12 12 12
//[ CPULongType{2,4} ]
fmt.Printf("ts1:\n%2d", ts1)
fmt.Printf("ts2:\n%2d", ts2)
fmt.Printf("mul tensor (ts1 x ts2):\n%2d", mul)
/*
ts1:
0 1 2
3 4 5
ts2:
1 1 1 1
1 1 1 1
1 1 1 1
mul tensor (ts1 x ts2):
3 3 3 3
12 12 12 12
*/
// In-place operation
ts3 := ts.MustOnes([]int64{2, 3}, gotch.Float, gotch.CPU)
fmt.Println("Before:")
ts3.Print()
fmt.Printf("Before:\n%v", ts3)
ts3.MustAdd1_(ts.FloatScalar(2.0))
fmt.Printf("After (ts3 + 2.0): \n")
ts3.Print()
ts3.MustDrop()
fmt.Printf("After (ts3 + 2.0):\n%v", ts3)
//Before:
// 1 1 1
// 1 1 1
//[ CPUFloatType{2,3} ]
//After (ts3 + 2.0):
// 3 3 3
// 3 3 3
//[ CPUFloatType{2,3} ]
/*
Before:
1 1 1
1 1 1
After (ts3 + 2.0):
3 3 3
3 3 3
*/
}
```
@ -142,30 +137,32 @@ func basicOps() {
```go
import (
"fmt"
"github.com/sugarme/gotch"
"github.com/sugarme/gotch/nn"
ts "github.com/sugarme/gotch/tensor"
)
type Net struct {
conv1 nn.Conv2D
conv2 nn.Conv2D
fc nn.Linear
conv1 *nn.Conv2D
conv2 *nn.Conv2D
fc *nn.Linear
}
func newNet(vs nn.Path) Net {
func newNet(vs *nn.Path) *Net {
conv1 := nn.NewConv2D(vs, 1, 16, 2, nn.DefaultConv2DConfig())
conv2 := nn.NewConv2D(vs, 16, 10, 2, nn.DefaultConv2DConfig())
fc := nn.NewLinear(vs, 10, 10, nn.DefaultLinearConfig())
return Net{
return &Net{
conv1,
conv2,
fc,
}
}
func (n Net) ForwardT(xs ts.Tensor, train bool) (retVal ts.Tensor) {
func (n Net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
xs = xs.MustView([]int64{-1, 1, 8, 8}, false)
outC1 := xs.Apply(n.conv1)
@ -177,10 +174,8 @@ func basicOps() {
outView2 := outMP2.MustView([]int64{-1, 10}, true)
defer outView2.MustDrop()
outFC := outView2.Apply(&n.fc)
outFC := outView2.Apply(n.fc)
return outFC.MustRelu(true)
}
func main() {
@ -191,29 +186,33 @@ func basicOps() {
xs := ts.MustOnes([]int64{8, 8}, gotch.Float, gotch.CPU)
logits := net.ForwardT(xs, false)
logits.Print()
fmt.Printf("Logits: %0.3f", logits)
}
// 0.0000 0.0000 0.0000 0.2477 0.2437 0.0000 0.0000 0.0000 0.0000 0.0171
//[ CPUFloatType{1,10} ]
//Logits: 0.000 0.000 0.000 0.225 0.321 0.147 0.000 0.207 0.000 0.000
```
- Real application examples can be found at [example folder](example/README.md)
## Play with GoTch on Google Colab
1. [Tensor Initiation](tensor/tensor-initiation.ipynb) <a href="https://colab.research.google.com/github/sugarme/nb/blob/master/tensor/tensor-initiation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
2. [Tensor Indexing](tensor/tensor-indexing.ipynb) <a href="https://colab.research.google.com/github/sugarme/nb/blob/master/tensor/tensor-indexing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
3. [MNIST](mnist) <a href="https://colab.research.google.com/github/sugarme/nb/blob/master/mnist/mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
4. [Tokenizer - BPE model](tokenizer/bpe.ipynb) <a href="https://colab.research.google.com/github/sugarme/nb/blob/master/tokenizer/bpe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
5. [transformer - BERT Mask Language Model](transformer/bert-mask-lm.ipynb) <a href="https://colab.research.google.com/github/sugarme/nb/blob/master/transformer/bert-mask-lm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
6. [YOLO v3 model infering](yolo/yolo.ipynb) <a href="https://colab.research.google.com/github/sugarme/nb/blob/master/yolo/yolo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
More coming soon...
## Getting Started
- [Documentations](docs/README.md)
- See [pkg.go.dev](https://pkg.go.dev/github.com/sugarme/gotch?tab=doc) for detail APIs
## License
GoTch is Apache 2.0 licensed.
## Acknowledgement
- This project has been inspired and used many concepts from [tch-rs](https://github.com/LaurentMazare/tch-rs)

View File

@ -1,496 +0,0 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "GoTch-MNIST-CNN.ipynb",
"provenance": [],
"collapsed_sections": [
"fsitX3NbLbTB",
"l6aOKEMarRNH",
"mWe6_MtPK8Kh"
],
"toc_visible": true,
"include_colab_link": true
},
"kernelspec": {
"display_name": "Go",
"name": "gophernotes"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/sugarme/gotch/blob/master/example/mnist/GoTch_MNIST_CNN.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dkq9izrfMXPV"
},
"source": [
"# MNIST CNN Training Using GoTch - Pytorch C++ APIs Go Binding\n",
"\n",
"This notebook using\n",
"\n",
"1. [GoTch - Pytorch C++ APIs Go bindind](https://github.com/sugarme/gotch)\n",
"2. [GopherNotes - Jupyter Notebook Go kernel](https://github.com/gopherdata/gophernotes)\n",
"3. [MNIST dataset](http://yann.lecun.com/exdb/mnist/)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fsitX3NbLbTB"
},
"source": [
"## Install Go kernel - GopherNotes\n",
"\n",
"*NOTE: refresh/reload (browser) after this step.*"
]
},
{
"cell_type": "code",
"metadata": {
"id": "caT1iMfshw62",
"outputId": "5be1cd7d-c12e-4083-d523-a035a9fc06a6",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"# run this cell first time using python runtime\n",
"!add-apt-repository ppa:longsleep/golang-backports -y > /dev/null\n",
"!apt update > /dev/null \n",
"!apt install golang-go > /dev/null\n",
"%env GOPATH=/root/go\n",
"!go get -u github.com/gopherdata/gophernotes\n",
"!cp ~/go/bin/gophernotes /usr/bin/\n",
"!mkdir /usr/local/share/jupyter/kernels/gophernotes\n",
"!cp ~/go/src/github.com/gopherdata/gophernotes/kernel/* \\\n",
" /usr/local/share/jupyter/kernels/gophernotes\n",
"# then refresh (browser), it will now use gophernotes. Skip to golang in later cells"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"text": [
"\n",
"WARNING: apt does not have a stable CLI interface. Use with caution in scripts.\n",
"\n",
"\n",
"WARNING: apt does not have a stable CLI interface. Use with caution in scripts.\n",
"\n",
"env: GOPATH=/root/go\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "l6aOKEMarRNH"
},
"source": [
"## Install Pytorch C++ APIs and Go binding - GoTch\n",
"\n",
"NOTE: `ldconfig` (GLIBC) current version 2.27 is currently broken when linking Libtorch library\n",
"\n",
"see issue: https://discuss.pytorch.org/libtorch-c-so-files-truncated-error-when-ldconfig/46404/6\n",
"\n",
"Google Colab default settings:\n",
"```bash\n",
"LD_LIBRARY_PATH=/usr/lib64-nvidia\n",
"LIBRARY_PATH=/usr/local/cuda/lib64/stubs\n",
"```\n",
"We copy directly `libtorch/lib` to those paths as a hacky way. "
]
},
{
"cell_type": "code",
"metadata": {
"id": "0N_jGbVft2Y7",
"outputId": "e2052bf1-862f-49e7-ebed-23f09f7fb9e4",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"$wget -q --show-progress --progress=bar:force:noscroll -O /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip https://download.pytorch.org/libtorch/cu101/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip\n",
"$unzip -qq /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip -d /usr/local\n",
"$unzip -qq -j /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip libtorch/lib/* -d /usr/lib64-nvidia/\n",
"$unzip -qq -j /tmp/libtorch-cxx11-abi-shared-with-deps-1.7.0%2Bcu101.zip libtorch/lib/* -d /usr/local/cuda/lib64/stubs/"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"/tmp/libtorch-cxx11 100%[===================>] 765.61M 22.4MB/s in 29s \n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "dVo3An-pGzFE"
},
"source": [
"import(\"os\")\n",
"os.Setenv(\"CPATH\", \"usr/local/libtorch/lib:/usr/local/libtorch/include:/usr/local/libtorch/include/torch/csrc/api/include\")"
],
"execution_count": 2,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "x_9y6CpZGlPJ",
"outputId": "79b538ff-9320-4097-8e76-43160253a60d",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"$rm -f -- go.mod\n",
"$go mod init github.com/sugarme/playgo\n",
"$go get github.com/sugarme/gotch@v0.3.2"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"go: creating new go.mod: module github.com/sugarme/playgo\n",
"go: downloading github.com/sugarme/gotch v0.3.2\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mWe6_MtPK8Kh"
},
"source": [
"## Download MNIST dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "sYyKlDWchSv3",
"outputId": "ba7b2bb6-f42f-4dbd-c96e-633f4b66e0f1",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
"$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
"$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
"$wget -q --show-progress --progress=bar:force:noscroll http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
"\n",
"$gunzip train-images-idx3-ubyte.gz\n",
"$gunzip train-labels-idx1-ubyte.gz\n",
"$gunzip t10k-images-idx3-ubyte.gz\n",
"$gunzip t10k-labels-idx1-ubyte.gz"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"text": [
"train-images-idx3-u 100%[===================>] 9.45M 11.2MB/s in 0.8s \n",
"train-labels-idx1-u 100%[===================>] 28.20K --.-KB/s in 0.06s \n",
"t10k-images-idx3-ub 100%[===================>] 1.57M 2.03MB/s in 0.8s \n",
"t10k-labels-idx1-ub 100%[===================>] 4.44K --.-KB/s in 0s \n"
],
"name": "stderr"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7MzMZDHiLJQR"
},
"source": [
"## Create Convolution Neural Network (CNN)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "YkLmNrswJHYY"
},
"source": [
"import(\n",
" \"fmt\"\n",
" \"time\"\n",
"\n",
" \"github.com/sugarme/gotch\"\n",
" \"github.com/sugarme/gotch/nn\"\n",
" ts \"github.com/sugarme/gotch/tensor\"\n",
" \"github.com/sugarme/gotch/vision\"\n",
") \n",
"\n",
"const (\n",
" MnistDirCNN string = \"./\"\n",
" epochsCNN = 100\n",
" batchCNN = 256\n",
" batchSize = 256\n",
" LrCNN = 1e-4\n",
")\n",
"\n",
"type Net struct {\n",
" conv1 *nn.Conv2D\n",
" conv2 *nn.Conv2D\n",
" fc1 *nn.Linear\n",
" fc2 *nn.Linear\n",
"}\n",
"\n",
"func newNet(vs *nn.Path) Net {\n",
" conv1 := nn.NewConv2D(vs, 1, 32, 5, nn.DefaultConv2DConfig())\n",
" conv2 := nn.NewConv2D(vs, 32, 64, 5, nn.DefaultConv2DConfig())\n",
" fc1 := nn.NewLinear(vs, 1024, 1024, nn.DefaultLinearConfig())\n",
" fc2 := nn.NewLinear(vs, 1024, 10, nn.DefaultLinearConfig())\n",
"\n",
" return Net{conv1,conv2,fc1,fc2}\n",
"}\n",
"\n",
"func (n Net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {\n",
" outView1 := xs.MustView([]int64{-1, 1, 28, 28}, false)\n",
" defer outView1.MustDrop()\n",
" outC1 := outView1.Apply(n.conv1)\n",
" outMP1 := outC1.MaxPool2DDefault(2, true)\n",
" defer outMP1.MustDrop()\n",
" outC2 := outMP1.Apply(n.conv2)\n",
" outMP2 := outC2.MaxPool2DDefault(2, true)\n",
" outView2 := outMP2.MustView([]int64{-1, 1024}, true)\n",
" defer outView2.MustDrop()\n",
" outFC1 := outView2.Apply(n.fc1)\n",
" outRelu := outFC1.MustRelu(true)\n",
" defer outRelu.MustDrop()\n",
" outDropout := ts.MustDropout(outRelu, 0.5, train)\n",
" defer outDropout.MustDrop()\n",
" return outDropout.Apply(n.fc2)\n",
"}\n",
"\n",
"func trainCNN(){\n",
" var ds *vision.Dataset\n",
" ds = vision.LoadMNISTDir(MnistDirCNN)\n",
" testImages := ds.TestImages\n",
" testLabels := ds.TestLabels\n",
"\n",
" cuda := gotch.CudaBuilder(0)\n",
" vs := nn.NewVarStore(cuda.CudaIfAvailable())\n",
"\n",
" var cnn Net = newNet(vs.Root())\n",
" opt, err := nn.DefaultAdamConfig().Build(vs, LrCNN)\n",
" if err != nil {fmt.Print(err)}\n",
"\n",
" var bestAccuracy float64 = 0.0\n",
" startTime := time.Now()\n",
"\n",
" for epoch := 0; epoch < epochsCNN; epoch++ {\n",
" totalSize := ds.TrainImages.MustSize()[0]\n",
" samples := int(totalSize)\n",
" index := ts.MustRandperm(int64(totalSize), gotch.Int64, gotch.CPU)\n",
" imagesTs := ds.TrainImages.MustIndexSelect(0, index, false)\n",
" labelsTs := ds.TrainLabels.MustIndexSelect(0, index, false)\n",
"\n",
" batches := samples / batchSize\n",
" batchIndex := 0\n",
" var epocLoss *ts.Tensor\n",
" for i := 0; i < batches; i++ {\n",
" start := batchIndex * batchSize\n",
" size := batchSize\n",
" if samples-start < batchSize {break}\n",
" batchIndex += 1\n",
"\n",
" // Indexing\n",
" narrowIndex := ts.NewNarrow(int64(start), int64(start+size))\n",
" bImages := imagesTs.Idx(narrowIndex)\n",
" bLabels := labelsTs.Idx(narrowIndex)\n",
"\n",
" bImages = bImages.MustTo(vs.Device(), true)\n",
" bLabels = bLabels.MustTo(vs.Device(), true)\n",
"\n",
" logits := cnn.ForwardT(bImages, true)\n",
" loss := logits.CrossEntropyForLogits(bLabels)\n",
"\n",
" opt.BackwardStep(loss)\n",
"\n",
" epocLoss = loss.MustShallowClone()\n",
" epocLoss.Detach_()\n",
"\n",
" bImages.MustDrop()\n",
" bLabels.MustDrop()\n",
" }\n",
"\n",
" testAccuracy := nn.BatchAccuracyForLogits(vs, cnn, testImages, testLabels, vs.Device(), 1024)\n",
" fmt.Printf(\"Epoch: %v\\t Loss: %.2f \\t Test accuracy: %.2f%%\\n\", epoch, epocLoss.Float64Values()[0], testAccuracy*100.0)\n",
" if testAccuracy > bestAccuracy {\n",
" bestAccuracy = testAccuracy\n",
" }\n",
"\n",
" epocLoss.MustDrop()\n",
" imagesTs.MustDrop()\n",
" labelsTs.MustDrop()\n",
" }\n",
"\n",
" fmt.Printf(\"Best test accuracy: %.2f%%\\n\", bestAccuracy*100.0)\n",
" fmt.Printf(\"Taken time:\\t%.2f mins\\n\", time.Since(startTime).Minutes())\n",
"}"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "dIwh8dLLOzJ9"
},
"source": [
"## Run train and evaluation"
]
},
{
"cell_type": "code",
"metadata": {
"id": "0AW2nIHjKA9W",
"outputId": "0123b6df-7edc-4f35-b874-2b4f1ec69e54",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"source": [
"trainCNN()"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"text": [
"Epoch: 0\t Loss: 0.24 \t Test accuracy: 93.34%\n",
"Epoch: 1\t Loss: 0.19 \t Test accuracy: 96.02%\n",
"Epoch: 2\t Loss: 0.09 \t Test accuracy: 97.10%\n",
"Epoch: 3\t Loss: 0.10 \t Test accuracy: 97.66%\n",
"Epoch: 4\t Loss: 0.03 \t Test accuracy: 98.13%\n",
"Epoch: 5\t Loss: 0.05 \t Test accuracy: 98.43%\n",
"Epoch: 6\t Loss: 0.09 \t Test accuracy: 98.60%\n",
"Epoch: 7\t Loss: 0.05 \t Test accuracy: 98.80%\n",
"Epoch: 8\t Loss: 0.03 \t Test accuracy: 98.80%\n",
"Epoch: 9\t Loss: 0.05 \t Test accuracy: 98.89%\n",
"Epoch: 10\t Loss: 0.02 \t Test accuracy: 98.88%\n",
"Epoch: 11\t Loss: 0.03 \t Test accuracy: 98.98%\n",
"Epoch: 12\t Loss: 0.03 \t Test accuracy: 99.05%\n",
"Epoch: 13\t Loss: 0.04 \t Test accuracy: 99.06%\n",
"Epoch: 14\t Loss: 0.02 \t Test accuracy: 99.07%\n",
"Epoch: 15\t Loss: 0.02 \t Test accuracy: 98.98%\n",
"Epoch: 16\t Loss: 0.02 \t Test accuracy: 99.06%\n",
"Epoch: 17\t Loss: 0.01 \t Test accuracy: 99.09%\n",
"Epoch: 18\t Loss: 0.02 \t Test accuracy: 99.14%\n",
"Epoch: 19\t Loss: 0.01 \t Test accuracy: 99.09%\n",
"Epoch: 20\t Loss: 0.02 \t Test accuracy: 99.12%\n",
"Epoch: 21\t Loss: 0.03 \t Test accuracy: 99.13%\n",
"Epoch: 22\t Loss: 0.02 \t Test accuracy: 99.11%\n",
"Epoch: 23\t Loss: 0.01 \t Test accuracy: 99.10%\n",
"Epoch: 24\t Loss: 0.01 \t Test accuracy: 99.15%\n",
"Epoch: 25\t Loss: 0.01 \t Test accuracy: 99.15%\n",
"Epoch: 26\t Loss: 0.01 \t Test accuracy: 99.22%\n",
"Epoch: 27\t Loss: 0.00 \t Test accuracy: 99.27%\n",
"Epoch: 28\t Loss: 0.01 \t Test accuracy: 99.15%\n",
"Epoch: 29\t Loss: 0.00 \t Test accuracy: 99.09%\n",
"Epoch: 30\t Loss: 0.01 \t Test accuracy: 99.15%\n",
"Epoch: 31\t Loss: 0.00 \t Test accuracy: 99.18%\n",
"Epoch: 32\t Loss: 0.00 \t Test accuracy: 99.15%\n",
"Epoch: 33\t Loss: 0.00 \t Test accuracy: 99.23%\n",
"Epoch: 34\t Loss: 0.01 \t Test accuracy: 99.22%\n",
"Epoch: 35\t Loss: 0.00 \t Test accuracy: 99.16%\n",
"Epoch: 36\t Loss: 0.00 \t Test accuracy: 99.22%\n",
"Epoch: 37\t Loss: 0.01 \t Test accuracy: 99.26%\n",
"Epoch: 38\t Loss: 0.00 \t Test accuracy: 99.15%\n",
"Epoch: 39\t Loss: 0.01 \t Test accuracy: 99.19%\n",
"Epoch: 40\t Loss: 0.01 \t Test accuracy: 99.24%\n",
"Epoch: 41\t Loss: 0.01 \t Test accuracy: 99.24%\n",
"Epoch: 42\t Loss: 0.01 \t Test accuracy: 99.22%\n",
"Epoch: 43\t Loss: 0.00 \t Test accuracy: 99.22%\n",
"Epoch: 44\t Loss: 0.00 \t Test accuracy: 99.14%\n",
"Epoch: 45\t Loss: 0.00 \t Test accuracy: 99.23%\n",
"Epoch: 46\t Loss: 0.00 \t Test accuracy: 99.24%\n",
"Epoch: 47\t Loss: 0.00 \t Test accuracy: 99.21%\n",
"Epoch: 48\t Loss: 0.00 \t Test accuracy: 99.25%\n",
"Epoch: 49\t Loss: 0.00 \t Test accuracy: 99.21%\n",
"Epoch: 50\t Loss: 0.01 \t Test accuracy: 99.26%\n",
"Epoch: 51\t Loss: 0.01 \t Test accuracy: 99.14%\n",
"Epoch: 52\t Loss: 0.01 \t Test accuracy: 99.26%\n",
"Epoch: 53\t Loss: 0.00 \t Test accuracy: 99.23%\n",
"Epoch: 54\t Loss: 0.00 \t Test accuracy: 99.28%\n",
"Epoch: 55\t Loss: 0.02 \t Test accuracy: 99.24%\n",
"Epoch: 56\t Loss: 0.00 \t Test accuracy: 99.26%\n",
"Epoch: 57\t Loss: 0.00 \t Test accuracy: 99.25%\n",
"Epoch: 58\t Loss: 0.01 \t Test accuracy: 99.22%\n",
"Epoch: 59\t Loss: 0.01 \t Test accuracy: 99.25%\n",
"Epoch: 60\t Loss: 0.00 \t Test accuracy: 99.26%\n",
"Epoch: 61\t Loss: 0.01 \t Test accuracy: 99.24%\n",
"Epoch: 62\t Loss: 0.00 \t Test accuracy: 99.25%\n",
"Epoch: 63\t Loss: 0.00 \t Test accuracy: 99.24%\n",
"Epoch: 64\t Loss: 0.00 \t Test accuracy: 99.27%\n",
"Epoch: 65\t Loss: 0.00 \t Test accuracy: 99.27%\n",
"Epoch: 66\t Loss: 0.00 \t Test accuracy: 99.31%\n",
"Epoch: 67\t Loss: 0.00 \t Test accuracy: 99.25%\n",
"Epoch: 68\t Loss: 0.00 \t Test accuracy: 99.25%\n",
"Epoch: 69\t Loss: 0.00 \t Test accuracy: 99.31%\n",
"Epoch: 70\t Loss: 0.00 \t Test accuracy: 99.24%\n",
"Epoch: 71\t Loss: 0.00 \t Test accuracy: 99.22%\n",
"Epoch: 72\t Loss: 0.00 \t Test accuracy: 99.25%\n",
"Epoch: 73\t Loss: 0.00 \t Test accuracy: 99.32%\n",
"Epoch: 74\t Loss: 0.00 \t Test accuracy: 99.26%\n",
"Epoch: 75\t Loss: 0.00 \t Test accuracy: 99.21%\n",
"Epoch: 76\t Loss: 0.00 \t Test accuracy: 99.22%\n",
"Epoch: 77\t Loss: 0.01 \t Test accuracy: 99.24%\n",
"Epoch: 78\t Loss: 0.00 \t Test accuracy: 99.25%\n",
"Epoch: 79\t Loss: 0.00 \t Test accuracy: 99.31%\n",
"Epoch: 80\t Loss: 0.00 \t Test accuracy: 99.18%\n",
"Epoch: 81\t Loss: 0.00 \t Test accuracy: 99.31%\n",
"Epoch: 82\t Loss: 0.00 \t Test accuracy: 99.28%\n",
"Epoch: 83\t Loss: 0.00 \t Test accuracy: 99.18%\n",
"Epoch: 84\t Loss: 0.01 \t Test accuracy: 99.24%\n",
"Epoch: 85\t Loss: 0.00 \t Test accuracy: 99.24%\n",
"Epoch: 86\t Loss: 0.00 \t Test accuracy: 99.29%\n",
"Epoch: 87\t Loss: 0.00 \t Test accuracy: 99.33%\n",
"Epoch: 88\t Loss: 0.00 \t Test accuracy: 99.28%\n",
"Epoch: 89\t Loss: 0.00 \t Test accuracy: 99.24%\n",
"Epoch: 90\t Loss: 0.00 \t Test accuracy: 99.35%\n",
"Epoch: 91\t Loss: 0.00 \t Test accuracy: 99.22%\n",
"Epoch: 92\t Loss: 0.00 \t Test accuracy: 99.23%\n",
"Epoch: 93\t Loss: 0.00 \t Test accuracy: 99.22%\n",
"Epoch: 94\t Loss: 0.00 \t Test accuracy: 99.25%\n",
"Epoch: 95\t Loss: 0.02 \t Test accuracy: 99.25%\n",
"Epoch: 96\t Loss: 0.00 \t Test accuracy: 99.26%\n",
"Epoch: 97\t Loss: 0.00 \t Test accuracy: 99.29%\n",
"Epoch: 98\t Loss: 0.00 \t Test accuracy: 99.23%\n",
"Epoch: 99\t Loss: 0.00 \t Test accuracy: 99.20%\n",
"Best test accuracy: 99.35%\n",
"Taken time:\t1.97 mins\n"
],
"name": "stdout"
}
]
}
]
}

View File

@ -1,5 +1,7 @@
# Linear Regression, NN, and CNN on MNIST dataset
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sugarme/nb/blob/master/mnist/mnist.ipynb)
## MNIST
- MNIST files can be obtained from [this source](http://yann.lecun.com/exdb/mnist/) and put in `data/mnist` from
@ -29,8 +31,6 @@
- Accuracy should be about **99.3%**.
## Run on Cloud
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sugarme/gotch/blob/master/example/mnist/gotch_mnist.ipynb)