How to train a CNN to 99% accuracy on MNIST in less than a second on a laptop

Overview

Training a NN to 99% accuracy on MNIST in 0.76 seconds

A quick study of how fast you can reach 99% accuracy on MNIST with a single laptop. Our answer is 0.76 seconds, reaching 99% accuracy in just one epoch of training. This is more than 200 times faster than the default training code from PyTorch. To see the final results, check 8_Final_00s76.ipynb. If you're interested in the process, read on for a step-by-step description of the changes made.

The repo is organized into Jupyter notebooks that show, in chronological order, the changes required to go from the initial PyTorch tutorial, which trains for about 3 minutes, to less than a second of training time on a laptop with a GeForce GTX 1660 Ti GPU. I aimed for a coordinate-ascent-like procedure, changing only one thing at a time so that the source of each improvement is clear, but sometimes I bunched up correlated or small changes.

Requirements

Python 3.x and PyTorch 1.8 (most likely works with >= 1.3). For fast times you'll need CUDA and a compatible GPU.

0_Pytorch_initial_2m_52s.ipynb: Starting benchmark

First we need to benchmark the starting performance. This can be found in the file 0_Pytorch_initial_2m_52s.ipynb. Note that the code downloads the dataset if it is not already present, so I report the time of the second run. Each run trains for 14 epochs; the average accuracy of two runs is 99.185% on the test set, and the mean runtime is 2min 52s ± 38.1ms.

1_Early_stopping_57s40.ipynb: Stop early

Since our goal is only to reach 99% accuracy, we don't need the full training time. Our first modification is to simply stop training after the first epoch in which we hit 99% test accuracy. This typically happens within 3-5 epochs, with an average final accuracy of 99.07%, cutting training time to around a third of the original at 57.4s ± 6.85s.
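A minimal sketch of this stopping rule, not the notebook's actual code: `train_one_epoch` and `evaluate` are hypothetical callables standing in for the tutorial's train and test loops, and the accuracy curve in the usage example is made up for illustration.

```python
from typing import Callable, Tuple


def train_until_target(
    train_one_epoch: Callable[[], None],
    evaluate: Callable[[], float],
    target_accuracy: float = 99.0,
    max_epochs: int = 14,
) -> Tuple[int, float]:
    """Train epoch by epoch and stop once test accuracy reaches the target.

    Returns (epochs_run, final_accuracy). Both callables are hypothetical
    placeholders for a real training loop and a real test-set evaluation.
    """
    accuracy = 0.0
    for epoch in range(1, max_epochs + 1):
        train_one_epoch()
        accuracy = evaluate()
        if accuracy >= target_accuracy:
            return epoch, accuracy
    return max_epochs, accuracy


# Usage with stand-in callables (a fake accuracy curve, not measured results).
fake_curve = iter([98.1, 98.7, 99.05, 99.2])
epochs_run, final_acc = train_until_target(lambda: None, lambda: next(fake_curve))
print(epochs_run, final_acc)
```

Note that this rule uses test accuracy as the stopping signal, which is what the benchmark here is defined on.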

2_Smaller_NN_30s30.ipynb: Reduce network size

Next we employ the trick of reducing both network size and regularization to speed up convergence. This is done by adding a 2x2 max pool layer after the first conv layer, which reduces the number of parameters in our fully connected layers by more than 4x. To compensate, we also remove one of the 2 dropout layers. This reduces the number of epochs we need to converge to 2-3, and the training time to 30.3s ± 5.28s.
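To see where the "more than 4x" comes from, here is a rough parameter count for the first fully connected layer. It assumes the starting network matches the standard PyTorch MNIST example (unpadded 3x3 convolutions with stride 1, 64 output channels from the second conv, one 2x2 max pool before a 128-unit fully connected layer); the helper names are made up for this sketch.

```python
def conv_out(size: int, kernel: int = 3) -> int:
    """Output side length of an unpadded, stride-1 convolution (assumed setup)."""
    return size - kernel + 1


def fc1_weights(pool_after_conv1: bool, hidden: int = 128, channels: int = 64) -> int:
    """Weight count of the first fully connected layer for a 28x28 input."""
    side = conv_out(28)        # conv1: 28 -> 26
    if pool_after_conv1:
        side //= 2             # added 2x2 max pool: 26 -> 13
    side = conv_out(side)      # conv2: 26 -> 24, or 13 -> 11
    side //= 2                 # existing 2x2 max pool: 24 -> 12, or 11 -> 5
    return channels * side * side * hidden


before = fc1_weights(pool_after_conv1=False)  # 64 * 12 * 12 * 128
after = fc1_weights(pool_after_conv1=True)    # 64 * 5 * 5 * 128
print(before, after, round(before / after, 2))
```

Under these assumptions the layer shrinks from 1,179,648 to 204,800 weights, a factor of about 5.76.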

3_Data_loading_07s31.ipynb: Optimize Data Loading!

This is probably the biggest and most surprising time save of this project. Just by better optimizing the data loading process we can save 75% of the entire training run time. It turns out that torch.utils.data.DataLoader is really inefficient for small datasets like MNIST. Instead of reading the data from disk one batch at a time, we can simply load the entire dataset into GPU memory at once and keep it there. To do this, we save the entire dataset, with the same processing we had before, onto disk in a single PyTorch array using data_loader.save_data(). This takes around 10s and is not counted in the training time, as it has to be done only once. With this optimization, our average training time goes down to 7.31s ± 1.36s.
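The general idea can be sketched as follows. This is not the repo's data_loader module: `gpu_batches` is a hypothetical helper, and the tensors are small random stand-ins rather than the preprocessed MNIST data (the real training set has 60,000 images). The point is only that once the tensors live on the device, a batch is a single indexing operation.

```python
import torch


def gpu_batches(images: torch.Tensor, labels: torch.Tensor, batch_size: int, shuffle: bool = True):
    """Yield (images, labels) batches by indexing tensors already on the device."""
    n = images.shape[0]
    if shuffle:
        order = torch.randperm(n, device=images.device)
    else:
        order = torch.arange(n, device=images.device)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        yield images[idx], labels[idx]


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in data with MNIST's per-image shape; moved to the device once, up front.
images = torch.randn(1000, 1, 28, 28).to(device)
labels = torch.randint(0, 10, (1000,)).to(device)

batches = list(gpu_batches(images, labels, batch_size=64))
print(len(batches), batches[0][0].shape, batches[-1][0].shape)
```

The last batch is smaller than the rest whenever the dataset size is not a multiple of the batch size.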

4_128_Batch_04s66.ipynb: Increase batch size

Now that we have optimized data loading, increasing batch size can significantly increase the speed of training. Simply increasing the batch size from 64 to 128 reduces our average train time to 4.66s ± 583ms.

5_Onecycle_lr_03s14.ipynb: Better learning rate schedule

For this step, we turn our attention to the learning rate schedule. Previously we used an exponential decay, where after each epoch the lr is multiplied by 0.7. We replace this with the one-cycle schedule behind Superconvergence (OneCycleLR in PyTorch), where the learning rate starts close to 0, is increased linearly (or with a cosine schedule) to its peak value at the middle of training, and is then slowly lowered to zero again at the end. This allows using much higher learning rates than otherwise. We used a peak LR of 4.0, 4 times higher than the starting lr used previously. The network now reaches 99% in 2 epochs every time, and this takes our training time down to 3.14s ± 4.72ms.
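The shape of such a schedule can be sketched in a few lines. This is a hand-written linear version for illustration, not PyTorch's torch.optim.lr_scheduler.OneCycleLR and not necessarily the exact settings used in the notebook; `one_cycle_lr` and its `start_lr` default of 0.0 are assumptions.

```python
def one_cycle_lr(step: int, total_steps: int, peak_lr: float = 4.0, start_lr: float = 0.0) -> float:
    """Linear one-cycle schedule: ramp up to peak_lr at the midpoint, then down to zero."""
    if total_steps < 2:
        raise ValueError("total_steps must be at least 2")
    half = (total_steps - 1) / 2
    if step <= half:
        # Warm-up phase: start_lr -> peak_lr.
        return start_lr + (peak_lr - start_lr) * step / half
    # Cool-down phase: peak_lr -> 0.
    return peak_lr * (total_steps - 1 - step) / half


# A tiny 5-step run, just to show the shape of the curve.
schedule = [one_cycle_lr(s, total_steps=5) for s in range(5)]
print(schedule)
```

Because the schedule depends on `total_steps`, the length of training has to be fixed in advance.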

6_256_Batch_02s31.ipynb: Increase batch size, again

With our better lr schedule we can once more double our batch size without hurting performance much. Note that this time around it doesn't reach 99% on all random seeds, but I count it as a success as long as I'm confident the mean accuracy is greater than 99%. This is because Superconvergence requires a fixed training length, so we can't guarantee that every seed works. This cuts our training time down to 2.31s ± 23.2ms.

7_Smaller_NN2_01s74.ipynb: Remove dropout and reduce size, again

Next we repeat our procedure from step 2 once again: we remove the remaining dropout layer and compensate by reducing the width of our convolutional layers, the first from 32 to 24 and the second from 64 to 32. This reduces the time to train an epoch and even increases accuracy, which averages around 99.1% after two epochs of training. This gives us a mean time of 1.74s ± 18.3ms.

8_Final_00s76.ipynb: Tune everything

Now that we have a fast working model and have grabbed most of the low-hanging improvements, it is time to dive into final fine-tuning. To start off, we simply move our max pool operations before the ReLU activation, which doesn't change the network's output but saves us a bit of compute.
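The reason the swap is safe is that ReLU is monotonic, so taking the max before or after it gives the same result. A small pure-Python check on a made-up 4x4 activation map (the helper functions are written for this sketch only):

```python
def relu_grid(grid):
    """Apply ReLU element-wise to a list-of-lists."""
    return [[max(v, 0.0) for v in row] for row in grid]


def max_pool_2x2(grid):
    """2x2 max pool with stride 2 over a list-of-lists with even side lengths."""
    return [
        [max(grid[r][c], grid[r][c + 1], grid[r + 1][c], grid[r + 1][c + 1])
         for c in range(0, len(grid[0]), 2)]
        for r in range(0, len(grid), 2)
    ]


# Made-up activations, including a block where every value is negative.
activations = [
    [-1.0, 2.0, 0.5, -3.0],
    [0.0, -4.0, -0.5, -2.0],
    [3.0, 1.0, -1.0, -1.5],
    [-2.0, 5.0, -6.0, -0.1],
]

relu_then_pool = max_pool_2x2(relu_grid(activations))  # ReLU touches 16 values
pool_then_relu = relu_grid(max_pool_2x2(activations))  # ReLU touches 4 values
print(relu_then_pool == pool_then_relu)
```

Pooling first means the ReLU runs on a quarter as many elements, which is where the small compute saving comes from.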

The next changes were the result of a large search operation, where I tried a number of different things, optimizing one hyperparameter at a time. For each change I trained on 30 different seeds and measured what gives us the highest mean accuracy. 30 seeds was necessary to make statistically significant conclusions on small changes, and it is worth noting training 30 seeds took less than a minute at this point. Higher accuracy can then be translated into faster times by cutting down on the number of epochs.
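As a rough illustration of why many seeds are needed, the sketch below summarizes per-seed accuracies by their mean and standard error. The `summarize` helper and the numbers are hypothetical, not results from the notebooks.

```python
import statistics


def summarize(accuracies):
    """Return (mean, standard error of the mean) for a list of per-seed accuracies."""
    mean = statistics.mean(accuracies)
    sem = statistics.stdev(accuracies) / len(accuracies) ** 0.5
    return mean, sem


# Hypothetical per-seed test accuracies for one configuration (not measured data).
example_accuracies = [99.0, 99.2, 99.1, 99.3]
mean_acc, sem_acc = summarize(example_accuracies)
print(f"{mean_acc:.3f} +/- {sem_acc:.3f}")
```

The standard error shrinks with the square root of the number of seeds, so small differences between two configurations only become distinguishable once enough seeds are averaged.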

First I actually made the network bigger in select places that didn't slow down performance too much. The kernel size of the first convolutional layer was increased from 3 to 5, and the final fully connected layer was increased from 128 to 256.

Next, it was time to change the optimizer. I found that with proper hyperparameters, Adam actually outperforms Adadelta, which we had used so far. The hyperparameters I changed from their defaults are the learning rate of 0.01 (default 0.001), beta1 of 0.7 (default 0.9) and beta2 of 0.9 (default 0.999).
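In PyTorch these settings map onto the `lr` and `betas` arguments of torch.optim.Adam. A minimal sketch, where the linear layer is only a stand-in for the real network and the learning rate schedule is left out:

```python
import torch

# Stand-in model; the actual network is the CNN described in the steps above.
model = torch.nn.Linear(784, 10)

# Adam with the non-default values listed above: lr 0.01, beta1 0.7, beta2 0.9.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, betas=(0.7, 0.9))
print(optimizer.defaults["lr"], optimizer.defaults["betas"])
```

Lower beta values make Adam's running averages forget old gradients faster, which plausibly suits a run that lasts only an epoch or two, though that explanation is my reading rather than something measured here.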

All of this led to a large boost in accuracy (99.245% after 2 epochs), which I was finally able to trade for faster training by cutting training down to just one epoch! Our final result is 99.04% mean accuracy in just 762ms ± 24.9ms.

Owner
Tuomas Oikarinen
PhD student at UC San Diego, trying to understand ML and hopefully make it more safe. Previously @MIT.