In their classic 2012 paper, Improving neural networks by preventing co-adaptation of feature detectors, Hinton, Srivastava, Krizhevsky, Sutskever and Salakhutdinov showed that using dropout in a feed-forward neural network significantly improved performance. In particular, they produced a well-known result for the MNIST dataset, showing how the test error evolves for networks trained with and without dropout.

However, if you try to reproduce this result in PyTorch using modern defaults such as the Adam optimiser, you may find that recovering the same outcome is not as straightforward as it first appears.
In this post I’m going to recreate their result in PyTorch, keeping as close as possible to the original methodology.
The Setup
Everything will be done with PyTorch, so we need to start by importing the relevant bits of that library:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
At some point we might also want to use a GPU, so let’s prepare for that:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
print("Device: ",device)
Now let’s break down the paper, starting with the simple stuff.
The Data
In the paper, they state that they use the 60,000 MNIST training images and 10,000 test images for the training and testing datasets, respectively. They break these datasets into mini-batches of 100 images.
batch_size = 100 # number of samples per mini-batch
We can then use the PyTorch MNIST dataset to load the data:
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)), # MNIST specific normalisation
torch.flatten])
train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_data = datasets.MNIST('data', train=False, transform=transform)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=True)
Note the Normalize transform in the code above. The Hinton+ paper doesn’t specify what normalisation was used for the MNIST data, so here I’ve adopted the standard approach of subtracting the mean of the training images and dividing by their standard deviation.
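If you would rather check those normalisation constants than take them on trust, a quick (if slightly memory-hungry) way to estimate them from the raw training images is something like:
# Rough check of the MNIST normalisation constants, computed over the
# raw (un-normalised) training images.
raw_data = datasets.MNIST('data', train=True, download=True,
                          transform=transforms.ToTensor())
imgs = torch.stack([img for img, _ in raw_data])  # shape: (60000, 1, 28, 28)
print(imgs.mean().item(), imgs.std().item())      # roughly 0.1307 and 0.3081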
The Network
To start with, let’s specify a 784-800-800-10 net with optional dropout. The Hinton+ paper doesn’t specify what activation function they used, so I’ve just put in a ReLU.
class Classifier_MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, hidden_drop=0.0, input_drop=0.0):
        super().__init__()

        self.h1 = nn.Linear(in_dim, hidden_dim)
        self.h2 = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, out_dim)

        self.dr1 = nn.Dropout(p=hidden_drop)  # dropout applied to the hidden layers
        self.dr2 = nn.Dropout(p=input_drop)   # dropout applied to the input layer

        # weight initialisation:
        # following: https://arxiv.org/pdf/1207.0580.pdf (Appendix A)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.dr2(x)
        x = F.relu(self.h1(x))
        x = self.dr1(x)
        x = F.relu(self.h2(x))
        x = self.dr1(x)
        x = self.out(x)

        return x, F.softmax(x, dim=1)
Note the weight initialisation loop at the end of the constructor. This follows the Hinton+ paper, which states:
“Weights are initialzed to small random values drawn from a zero-mean normal distribution with standard deviation 0.01.” [sic]
We are not explicitly told whether biases are included, but I’ve used them here and initialised their values to zero.
Each image has 28 x 28 pixels, which means that the first hidden layer has 28 x 28 = 784 input features, and the MNIST dataset contains ten different digits, so the output layer should predict over 10 classes.
We can call this network in our code using:
input_size = 784 # The number of features
hidden_size = 800 # The number of nodes at the hidden layer
num_classes = 10 # The number of output classes.
model = Classifier_MLP(in_dim=input_size, hidden_dim=hidden_size, out_dim=num_classes).to(device=device)
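For the dropout runs, the paper drops 50% of the hidden units and 20% of the input units, which with the constructor arguments above corresponds to something like the following (model_dropout is just my own variable name):
# Dropout variant: 50% dropout on the hidden layers, 20% on the inputs
model_dropout = Classifier_MLP(in_dim=input_size, hidden_dim=hidden_size,
                               out_dim=num_classes, hidden_drop=0.5,
                               input_drop=0.2).to(device=device)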
The Optimiser
This is where things start to get a bit more complicated. In the Hinton+ paper, the SGD updates are specified as

$$\Delta w^{t} = p^{t}\,\Delta w^{t-1} - \left(1 - p^{t}\right)\epsilon^{t}\,\langle \nabla_{w} L \rangle, \qquad w^{t} = w^{t-1} + \Delta w^{t},$$

where $w$ are the weights, $p^{t}$ is the momentum, $\epsilon^{t}$ is the learning rate and $L$ is the loss.
The reason this complicates things is that in PyTorch SGD is implemented slightly differently:

$$v^{t} = p\,v^{t-1} + \left(1 - d\right)\langle \nabla_{w} L \rangle, \qquad w^{t} = w^{t-1} - \epsilon\,v^{t},$$

where $v$ is the velocity and $d$ is a dampening parameter. It used to be the case that the dampening parameter defaulted to be equal to the momentum in PyTorch, but in more recent releases it has a default value of zero.
By setting $d = p$, i.e. dampening equal to momentum, we can come one step closer to the Hinton+ form of the SGD update. However, the way the learning rate is implemented means that there will still be a subtle difference: in PyTorch the learning rate multiplies the whole velocity at the weight update, whereas in the Hinton+ form it only scales the new gradient contribution inside the velocity update, so the two schemes diverge once the learning rate is decayed over time. If we wanted to use exactly the same SGD update we would need to create a new version of the PyTorch SGD optimiser class, which is actually much more straightforward than it might sound (see below).
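As a quick aside, this is roughly what that near-equivalent built-in configuration would look like (a sketch only, using the model defined above; optimizer_approx is my own name):
import torch.optim as optim

# Built-in SGD with dampening set equal to the momentum. This reproduces the
# (1 - momentum) scaling of the gradient term, but the learning rate still
# multiplies the whole velocity, so it is only an approximation of the
# Hinton+ update.
optimizer_approx = optim.SGD(model.parameters(), lr=0.1, momentum=0.5, dampening=0.5)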
The L2 threshold
Instead of specifying a weight decay and using L2 regularisation, the Hinton+ paper states that:
“Instead of penalizing the squared length (L2 norm) of the whole weight vector, we set an upper bound on the L2 norm of the incoming weight vector for each individual hidden unit. If a weight-update violates this constraint, we renormalize the weights of the hidden unit by division.”
To implement this constraint we need to calculate the L2 norm of the incoming weights for each hidden unit before each optimiser update, i.e. once per mini-batch.
It seems to me that there are two possible ways to do this. Firstly, a function could be added to the model class to calculate the L2 norm and renormalise at the beginning of each batch update. Secondly, the same kind of functionality could be built into the optimiser class and run as part of the update step. To me, the second option seems less open to user error than the first, so that is what I propose here; a sketch of the first option follows for comparison.
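For completeness, here is a minimal sketch of the first option, assuming a standalone helper (apply_max_norm is my own name, not from the paper) that would be called immediately after each optimizer.step():
def apply_max_norm(model, l2_limit=15.0):
    """Renormalise each unit's incoming weight vector if its L2 norm exceeds l2_limit."""
    with torch.no_grad():
        for module in model.modules():
            if isinstance(module, nn.Linear):
                # each row of .weight holds the incoming weights of one unit
                norms = module.weight.norm(p=2, dim=1, keepdim=True)
                scale = (l2_limit / norms).clamp(max=1.0)
                module.weight.mul_(scale)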
And here is the second option: an example of a Hinton+-style SGD update including an L2 threshold check and renormalisation.
import torch
from torch.optim.optimizer import Optimizer, required
class SGDHinton(Optimizer):
    r"""Implements stochastic gradient descent with momentum following the update
    method described in Hinton+ 2012 (https://arxiv.org/pdf/1207.0580.pdf),
    including the renormalisation of the incoming weights of individual neurons
    if they exceed a specified L2 norm threshold.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        l2_limit (float, optional): L2 threshold per neuron (default: 15; set to 0 to remove)

    Example:
        >>> optimizer = SGDHinton(model.parameters(), lr=0.1, momentum=0.9, l2_limit=0)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    .. note::
        The implementation of SGD with momentum differs from the standard
        PyTorch SGD implementation. Here the update can be written as

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} - (1 - \mu) * \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} + v_{t+1},
            \end{aligned}

        This is in contrast to the standard PyTorch implementation, which employs
        an update of the form

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.
    """
    def __init__(self, params, lr=required, momentum=0, l2_limit=15.):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if l2_limit < 0.0:
            raise ValueError("Invalid L2 threshold value: {}".format(l2_limit))

        defaults = dict(lr=lr, momentum=momentum, l2_limit=l2_limit)
        super(SGDHinton, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGDHinton, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            momentum = group['momentum']
            learning_rate = group['lr']
            l2_limit = group['l2_limit']

            for p in group['params']:
                if p.grad is None:
                    continue

                d_p = p.grad
                # per-element rescaling factors for the L2 constraint,
                # created on the same device as the parameter
                renorm = torch.ones_like(p)

                if momentum != 0:
                    param_state = self.state[p]
                    alpha = learning_rate * (1. - momentum)
                    if 'momentum_buffer' not in param_state:
                        # v_0 = 0, so the first velocity is -(1 - momentum) * lr * g
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach().mul_(-alpha)
                    else:
                        buf = param_state['momentum_buffer']
                        # v_t = momentum * v_{t-1} - (1 - momentum) * lr * g_t
                        buf.mul_(momentum).add_(d_p, alpha=-1. * alpha)
                    d_p = buf
                else:
                    # no momentum: fall back to a plain gradient descent step
                    d_p = d_p.mul(-learning_rate)

                # L2 threshold: if the incoming weight vector of a unit exceeds
                # the limit, rescale that row back down to the limit
                p_tmp = torch.clone(p).detach()
                if l2_limit > 0 and len(p_tmp.size()) > 1:
                    l2_norm = [torch.sqrt(torch.sum(p_tmp[i, :]**2)) for i in range(p_tmp.size(0))]
                    for i in range(p_tmp.size(0)):
                        if l2_norm[i].item() > l2_limit:
                            renorm[i, :] = l2_limit / l2_norm[i].item()

                p.add_(d_p, alpha=1.).mul_(renorm)

        return loss
We can call this optimiser in our main code using:
learning_rate = 1e-1 # Initial learning rate
p_i = 5e-1 # Initial momentum
optimizer = SGDHinton(model.parameters(), lr=learning_rate, momentum=p_i, l2_limit=15.)
Learning Rate & Momentum Schedule
You can see in the code block above that the momentum and learning rate values used to initialise the optimiser class are described as initial. This is because the Hinton+ paper evolves both of these parameters as a function of epoch. Specifically, the learning rate decays exponentially,

$$\epsilon_t = \epsilon_0\,\gamma^{t},$$

where $t$ is the epoch, $\epsilon_0$ is the initial learning rate and $\gamma$ is a decay factor; the momentum is given by

$$p_t = \begin{cases} \dfrac{t}{T}\,p_f + \left(1 - \dfrac{t}{T}\right) p_i, & t < T, \\ p_f, & t \ge T, \end{cases}$$

where $p_i$ is the initial momentum and $p_f$ is the final momentum, reached at epoch $T$.
Note: the equation for the momentum evolution specified above differs from the one in Appendix A of the Hinton+ paper, because there is a typo in the paper that swaps $p_i$ and $p_f$.
We can encapsulate these schedules in two small helper functions:
# -----------------------------------------------------------
def get_lr(epoch, lr0, gamma):
    return lr0 * gamma**epoch
# -----------------------------------------------------------
def get_momentum(epoch, p_i, p_f, T):
    if epoch < T:
        p = (epoch / T) * p_f + (1 - (epoch / T)) * p_i
    else:
        p = p_f
    return p
# -----------------------------------------------------------
and then update the momentum and learning rate at the start of each epoch using
optimizer.param_groups[0]['lr'] = get_lr(epoch, learning_rate, gamma)
optimizer.param_groups[0]['momentum'] = get_momentum(epoch, p_i, p_f, T)
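Putting everything together, a minimal sketch of the training loop I have in mind looks something like this (the values of n_epochs, gamma, p_f and T below are placeholders for illustration, not values taken from the paper):
n_epochs = 100    # illustrative only
gamma = 0.998     # illustrative learning rate decay factor
p_f = 0.99        # illustrative final momentum
T = 500           # illustrative epoch at which the momentum stops increasing

for epoch in range(n_epochs):
    # update the learning rate and momentum according to the schedules above
    optimizer.param_groups[0]['lr'] = get_lr(epoch, learning_rate, gamma)
    optimizer.param_groups[0]['momentum'] = get_momentum(epoch, p_i, p_f, T)

    model.train()
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        logits, _ = model(x)              # the model returns (logits, softmax)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()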
Open Questions
- Can I just use dampening = momentum in normal PyTorch SGD and get a similar output?
- What happens if I ignore the L2 threshold or use a weight decay instead?
- Will the Adam optimiser give me the same results as SGD with dampening?