
Linear Regression, NN, and CNN on MNIST dataset


MNIST

  • MNIST files can be obtained from the original MNIST source and put in data/mnist at the root of this project.

  • MNIST data is loaded with a helper function from the vision sub-package, as sketched below.
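
A minimal sketch of loading the data, assuming vision.LoadMNISTDir returns the dataset directly and that the Dataset fields are named as below (check the vision sub-package for the exact API):

```go
package main

import (
	"fmt"

	"github.com/sugarme/gotch/vision"
)

func main() {
	// Reads the four MNIST files from the given directory and returns
	// flattened images (one row of 784 pixels per digit) plus labels.
	ds := vision.LoadMNISTDir("./data/mnist")

	fmt.Printf("trainImages: %v\n", ds.TrainImages.MustSize()) // [60000 784]
	fmt.Printf("testImages: %v\n", ds.TestImages.MustSize())   // [10000 784]
	fmt.Printf("testLabels: %v\n", ds.TestLabels.MustSize())   // [10000]
}
```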

Linear Regression

  • Run with go clean -cache -testcache && go run . -model="linear"

  • Accuracy should be about 91.68%.
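
The linear model is a single affine map from 784 pixel inputs to 10 class logits, trained with cross-entropy. Below is a hedged sketch of one forward/backward step, assuming gotch's CrossEntropyForLogits and AccuracyForLogits helpers (see linear.go for the actual training loop):

```go
package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/ts"
	"github.com/sugarme/gotch/vision"
)

func main() {
	device := gotch.CPU
	ds := vision.LoadMNISTDir("./data/mnist")

	// A single affine layer: logits = x·W + b.
	ws := ts.MustZeros([]int64{784, 10}, gotch.Float, device).MustSetRequiresGrad(true, false)
	bs := ts.MustZeros([]int64{10}, gotch.Float, device).MustSetRequiresGrad(true, false)

	// One full-batch forward/backward pass; the real example repeats
	// this, updating W and b from their gradients each iteration.
	logits := ds.TrainImages.MustMm(ws, false).MustAdd(bs, true)
	loss := logits.CrossEntropyForLogits(ds.TrainLabels)
	loss.MustBackward()

	// Accuracy of the (here still untrained) model on the test set.
	testLogits := ds.TestImages.MustMm(ws, false).MustAdd(bs, true)
	acc := testLogits.AccuracyForLogits(ds.TestLabels)
	fmt.Printf("test accuracy: %.2f%%\n", acc.Float64Values()[0]*100)
}
```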

Neural Network (NN)

  • Run with go clean -cache -testcache && go run . -model="nn"

  • Accuracy should be about 94%.
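
The NN model adds one hidden layer. A sketch using gotch's VarStore/optimizer pattern follows; the hidden size of 128 and the helper names are assumptions (see nn.go for the real values):

```go
package main

import (
	"fmt"

	"github.com/sugarme/gotch"
	"github.com/sugarme/gotch/nn"
	"github.com/sugarme/gotch/vision"
)

func main() {
	ds := vision.LoadMNISTDir("./data/mnist")

	// All trainable weights live in the VarStore; the optimizer updates them.
	vs := nn.NewVarStore(gotch.CPU)
	path := vs.Root()
	l1 := nn.NewLinear(path, 784, 128, nn.DefaultLinearConfig())
	l2 := nn.NewLinear(path, 128, 10, nn.DefaultLinearConfig())

	opt, err := nn.DefaultAdamConfig().Build(vs, 3e-4)
	if err != nil {
		panic(err)
	}

	for epoch := 0; epoch < 30; epoch++ {
		// Forward pass: linear -> ReLU -> linear, full batch.
		logits := ds.TrainImages.Apply(l1).MustRelu(true).Apply(l2)
		loss := logits.CrossEntropyForLogits(ds.TrainLabels)
		opt.BackwardStep(loss) // zero grads, backprop, Adam update
		loss.MustDrop()
	}

	acc := ds.TestImages.Apply(l1).MustRelu(true).Apply(l2).AccuracyForLogits(ds.TestLabels)
	fmt.Printf("test accuracy: %.2f%%\n", acc.Float64Values()[0]*100)
}
```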

Convolutional Neural Network (CNN)

  • Run with go clean -cache -testcache && go run . -model="cnn"

  • Accuracy should be about 99.3%.
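
The Go model in cnn.go mirrors the PyTorch Net shown in the benchmark below. Here is a rough sketch of that architecture with gotch's nn package (constructor and config names assumed; not the exact cnn.go code):

```go
package mnist

import (
	"github.com/sugarme/gotch/nn"
	"github.com/sugarme/gotch/ts"
)

// Two 5x5 conv layers with 2x2 max-pooling, then 1024 -> 1024 -> 10
// fully connected layers with ReLU and dropout, as in the Python Net.
type net struct {
	conv1 *nn.Conv2D
	conv2 *nn.Conv2D
	fc1   *nn.Linear
	fc2   *nn.Linear
}

func newNet(p *nn.Path) *net {
	return &net{
		conv1: nn.NewConv2D(p, 1, 32, 5, nn.DefaultConv2DConfig()),
		conv2: nn.NewConv2D(p, 32, 64, 5, nn.DefaultConv2DConfig()),
		fc1:   nn.NewLinear(p, 1024, 1024, nn.DefaultLinearConfig()),
		fc2:   nn.NewLinear(p, 1024, 10, nn.DefaultLinearConfig()),
	}
}

// ForwardT runs the forward pass; train toggles dropout.
func (m *net) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
	out := xs.MustView([]int64{-1, 1, 28, 28}, false)
	// conv(5x5) then pool(2x2): 28x28 -> 24x24 -> 12x12
	out = out.Apply(m.conv1).MustMaxPool2d([]int64{2, 2}, []int64{2, 2}, []int64{0, 0}, []int64{1, 1}, false, true)
	// conv(5x5) then pool(2x2): 12x12 -> 8x8 -> 4x4
	out = out.Apply(m.conv2).MustMaxPool2d([]int64{2, 2}, []int64{2, 2}, []int64{0, 0}, []int64{1, 1}, false, true)
	out = out.MustView([]int64{-1, 1024}, true) // 64 channels * 4 * 4 = 1024
	out = out.Apply(m.fc1).MustRelu(true)
	out = out.MustDropout(0.5, train, true)
	return out.Apply(m.fc2)
}
```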

Benchmark against Python

  • Train batch size: 256
  • Test batch size: 1000
  • Adam optimizer, learning rate = 3e-4
  • Epochs: 30

The PyTorch reference script below was presumably run with --epochs 30 --lr 3e-4, since its built-in defaults (14 epochs, lr = 1e-4) differ from the settings above:
```python
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import time

class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 5, 1)
        self.conv2 = nn.Conv2d(32, 64, 5, 1)
        self.fc1 = nn.Linear(1024, 1024)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.max_pool2d(x, 2)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(
                output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(
                dim=1,
                keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=256,
                        metavar='N',
                        help='input batch size for training (default: 256)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=14,
                        metavar='N',
                        help='number of epochs to train (default: 14)')
    parser.add_argument('--lr',
                        type=float,
                        default=1e-4,
                        metavar='LR',
                        help='learning rate (default: 1e-4)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    if use_cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1, 'pin_memory': True, 'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform = transforms.Compose([
        transforms.ToTensor(),
        #  transforms.Normalize((0.1307, ), (0.3081, )),
    ])
    dataset1 = datasets.MNIST('../data',
                              train=True,
                              download=True,
                              transform=transform)
    dataset2 = datasets.MNIST('../data', train=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    start = time.time()

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)

    end = time.time()

    print("taken time: {:.2f}mins".format((end - start) / 60.0))

    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")


if __name__ == '__main__':
    main()
```

Output (30 epochs):

```
Test set: Average loss: 0.1101, Accuracy: 9666/10000 (96.66%)
Test set: Average loss: 0.0697, Accuracy: 9779/10000 (97.79%)
Test set: Average loss: 0.0442, Accuracy: 9856/10000 (98.56%)
Test set: Average loss: 0.0384, Accuracy: 9873/10000 (98.73%)
Test set: Average loss: 0.0358, Accuracy: 9875/10000 (98.75%)
Test set: Average loss: 0.0323, Accuracy: 9898/10000 (98.98%)
Test set: Average loss: 0.0290, Accuracy: 9906/10000 (99.06%)
Test set: Average loss: 0.0272, Accuracy: 9910/10000 (99.10%)
Test set: Average loss: 0.0280, Accuracy: 9913/10000 (99.13%)
Test set: Average loss: 0.0295, Accuracy: 9908/10000 (99.08%)
Test set: Average loss: 0.0251, Accuracy: 9919/10000 (99.19%)
Test set: Average loss: 0.0246, Accuracy: 9924/10000 (99.24%)
Test set: Average loss: 0.0258, Accuracy: 9921/10000 (99.21%)
Test set: Average loss: 0.0296, Accuracy: 9911/10000 (99.11%)
Test set: Average loss: 0.0271, Accuracy: 9912/10000 (99.12%)
Test set: Average loss: 0.0251, Accuracy: 9918/10000 (99.18%)
Test set: Average loss: 0.0276, Accuracy: 9916/10000 (99.16%)
Test set: Average loss: 0.0291, Accuracy: 9912/10000 (99.12%)
Test set: Average loss: 0.0291, Accuracy: 9920/10000 (99.20%)
Test set: Average loss: 0.0333, Accuracy: 9904/10000 (99.04%)
Test set: Average loss: 0.0268, Accuracy: 9919/10000 (99.19%)
Test set: Average loss: 0.0265, Accuracy: 9931/10000 (99.31%)
Test set: Average loss: 0.0316, Accuracy: 9918/10000 (99.18%)
Test set: Average loss: 0.0299, Accuracy: 9917/10000 (99.17%)
Test set: Average loss: 0.0303, Accuracy: 9923/10000 (99.23%)
Test set: Average loss: 0.0327, Accuracy: 9914/10000 (99.14%)
Test set: Average loss: 0.0314, Accuracy: 9918/10000 (99.18%)
Test set: Average loss: 0.0316, Accuracy: 9920/10000 (99.20%)
Test set: Average loss: 0.0346, Accuracy: 9916/10000 (99.16%)
Test set: Average loss: 0.0308, Accuracy: 9923/10000 (99.23%)
taken time: 5.63mins
```

Gotch CNN performance

Output from go run . -model="cnn" (30 epochs):

```
testImages: [10000 784]
testLabels: [10000]
Epoch: 0         Loss: 0.16      Test accuracy: 96.53%
Epoch: 1         Loss: 0.08      Test accuracy: 97.27%
Epoch: 2         Loss: 0.14      Test accuracy: 97.28%
Epoch: 3         Loss: 0.08      Test accuracy: 97.64%
Epoch: 4         Loss: 0.07      Test accuracy: 98.44%
Epoch: 5         Loss: 0.05      Test accuracy: 98.59%
Epoch: 6         Loss: 0.06      Test accuracy: 98.67%
Epoch: 7         Loss: 0.07      Test accuracy: 98.80%
Epoch: 8         Loss: 0.11      Test accuracy: 98.01%
Epoch: 9         Loss: 0.07      Test accuracy: 98.81%
Epoch: 10        Loss: 0.05      Test accuracy: 98.76%
Epoch: 11        Loss: 0.04      Test accuracy: 98.78%
Epoch: 12        Loss: 0.02      Test accuracy: 98.81%
Epoch: 13        Loss: 0.05      Test accuracy: 98.78%
Epoch: 14        Loss: 0.05      Test accuracy: 98.74%
Epoch: 15        Loss: 0.06      Test accuracy: 98.86%
Epoch: 16        Loss: 0.07      Test accuracy: 98.95%
Epoch: 17        Loss: 0.03      Test accuracy: 98.93%
Epoch: 18        Loss: 0.04      Test accuracy: 98.99%
Epoch: 19        Loss: 0.05      Test accuracy: 99.05%
Epoch: 20        Loss: 0.06      Test accuracy: 99.11%
Epoch: 21        Loss: 0.03      Test accuracy: 98.78%
Epoch: 22        Loss: 0.05      Test accuracy: 98.88%
Epoch: 23        Loss: 0.02      Test accuracy: 99.04%
Epoch: 24        Loss: 0.04      Test accuracy: 99.08%
Epoch: 25        Loss: 0.03      Test accuracy: 98.96%
Epoch: 26        Loss: 0.07      Test accuracy: 98.78%
Epoch: 27        Loss: 0.05      Test accuracy: 98.81%
Epoch: 28        Loss: 0.03      Test accuracy: 98.79%
Epoch: 29        Loss: 0.07      Test accuracy: 98.82%
Best test accuracy: 99.11%
Taken time:     2.81 mins
```