Hinton MNIST Dropout in PyTorch

In their classic 2012 paper, “Improving neural networks by preventing co-adaptation of feature detectors”, Hinton, Srivastava, Krizhevsky, Sutskever and Salakhutdinov showed that using dropout in a feed-forward neural network improved performance significantly. In particular, they demonstrated the improvement on the MNIST dataset.

However, if you try to recreate this result in PyTorch using modern standards such as the Adam optimiser, you may find that recovering the same outcome is not as straightforward as it first appears.

In this post I’m going to recreate their result in PyTorch, trying to keep as close as possible to the original methodology.

The Setup

Everything will be done with PyTorch, so we need to start by importing the relevant bits of that library:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

At some point we might also want to use a GPU, so let’s prepare for that:

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
print("Device: ",device)

Now let’s break down the paper, starting with the simple stuff.

The Data

In the paper, they state that they use the 60,000 MNIST training images and 10,000 test images for the training and testing datasets, respectively. They break these datasets into mini-batches of 100 images.

batch_size    = 100       # number of samples per mini-batch

We can then use the PyTorch MNIST dataset to load the data:

transform = transforms.Compose([
    transforms.ToTensor(),                       # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST specific normalisation
    ])

train_data = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

valid_data = datasets.MNIST('data', train=False, transform=transform)
valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=True)

Note the Normalize line. The Hinton+ paper doesn’t specify what normalisation was used for the MNIST data, so here I’ve adopted the standard approach of subtracting the mean (0.1307) and dividing by the standard deviation (0.3081) of the MNIST training pixels.
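To get a feel for what this normalisation does, here is a tiny plain-Python sketch of the arithmetic, assuming the pixel values have already been scaled to the range [0, 1] (which is what transforms.ToTensor does). The function name normalise is just an illustrative choice of mine.

```python
mean, std = 0.1307, 0.3081   # the values passed to transforms.Normalize


def normalise(pixel):
    """Apply (pixel - mean) / std to a single pixel value in [0, 1]."""
    return (pixel - mean) / std


black = normalise(0.0)   # background pixels end up slightly negative
white = normalise(1.0)   # fully "on" pixels end up well above zero
print(black, white)
```

So the mostly-black MNIST background sits just below zero after normalisation, while the digit strokes take positive values.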

The Network

To start with let’s specify a 784-800-800-10 net with optional dropout. The Hinton+ paper doesn’t specify what activation function they used, so I’ve just put in a ReLU.

class Classifier_MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, hidden_drop=0.0, input_drop=0.0):
        super().__init__()  # must be called before any sub-modules are registered
        self.h1  = nn.Linear(in_dim, hidden_dim)
        self.h2  = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, out_dim)
        self.dr1 = nn.Dropout(p=hidden_drop)
        self.dr2 = nn.Dropout(p=input_drop)
        # weight initialisation:
        # following: https://arxiv.org/pdf/1207.0580.pdf (Appendix A)
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten each image into a vector of in_dim features
        x = self.dr2(x)
        x = F.relu(self.h1(x))
        x = self.dr1(x)
        x = F.relu(self.h2(x))
        x = self.dr1(x)
        x = self.out(x)
        return x, F.softmax(x,dim=1)

Note the weight initialisation at the end of __init__. This follows the Hinton+ paper, which states:

“Weights are initialzed to small random values drawn from a zero-mean normal distribution with standard deviation 0.01.” [sic]

We are not explicitly told whether biases are included but I’ve used them here and initialised their values to zero.

Each image has 28 x 28 pixels, which means that the first layer takes 28 x 28 = 784 input features, and the MNIST data set contains ten different digits, so the output layer should predict over 10 classes.

We can call this network in our code using:

input_size    = 784       # The number of features
hidden_size   = 800       # The number of nodes at the hidden layer
num_classes   = 10        # The number of output classes. 

model = Classifier_MLP(in_dim=input_size, hidden_dim=hidden_size, out_dim=num_classes).to(device=device)

The Optimiser

This is where things start to get a bit more complicated. In the Hinton+ paper, the SGD updates are specified as:

w^t = w^{t-1} + \Delta w^t

\Delta w^t = p^t \Delta w^{t-1} - (1 - p^t)\epsilon^t \nabla_w \mathcal{L}

where w are the weights, p is the momentum, \epsilon is the learning rate and \mathcal{L} is the loss.

The reason this complicates things is that in PyTorch SGD is implemented slightly differently:

w^t = w^{t-1} - \epsilon^t \Delta w^t

\Delta w^t = p^t \Delta w^{t-1} + (1 - d^t) \nabla_w \mathcal{L}

where d is a dampening parameter. It used to be the case that the dampening parameter defaulted to be equal to the momentum in PyTorch, but in more current releases it has a default value of zero.

By setting d^t = p^t we can come one step closer to the Hinton+ form of SGD update. However, in the PyTorch form the learning rate multiplies the whole accumulated velocity rather than each gradient as it arrives, so there will still be a subtle difference whenever the learning rate changes between updates. If we wanted to use exactly the same SGD update we would need to create a new version of the PyTorch SGD optimiser class, which is actually much more straightforward than it might sound (see below).
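To make that difference concrete, here is a minimal scalar sketch in plain Python (no PyTorch) that applies the two sets of equations above to a toy quadratic loss L(w) = w^2/2, whose gradient is simply w. The function names, the toy loss and the decay factor of 0.9 are my own illustrative choices. The second function follows the equations as written here and ignores any implementation details of torch.optim.SGD itself, such as how the first velocity is initialised.

```python
def hinton_sgd(w, lrs, p):
    """Hinton+ form: dw = p*dw - (1 - p)*lr*grad;  w = w + dw."""
    dw = 0.0
    for lr in lrs:
        grad = w                      # gradient of the toy loss w**2 / 2
        dw = p * dw - (1.0 - p) * lr * grad
        w = w + dw
    return w


def pytorch_style_sgd(w, lrs, p, d):
    """PyTorch-style form: dw = p*dw + (1 - d)*grad;  w = w - lr*dw."""
    dw = 0.0
    for lr in lrs:
        grad = w
        dw = p * dw + (1.0 - d) * grad
        w = w - lr * dw
    return w


p = 0.5
constant_lr = [0.1] * 20
decaying_lr = [0.1 * 0.9 ** k for k in range(20)]   # exaggerated decay for illustration

same = abs(hinton_sgd(1.0, constant_lr, p) - pytorch_style_sgd(1.0, constant_lr, p, d=p))
diff = abs(hinton_sgd(1.0, decaying_lr, p) - pytorch_style_sgd(1.0, decaying_lr, p, d=p))
print("constant lr, |difference| =", same)
print("decaying lr, |difference| =", diff)
```

With a constant learning rate and d = p the two trajectories agree to within floating-point rounding. With a decaying learning rate they drift apart, because the Hinton+ form bakes each learning rate into the velocity at the time the gradient arrives, whereas the PyTorch-style form applies the current learning rate to everything accumulated so far.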

The L2 threshold

Instead of specifying a weight decay and using L2 regularisation, the Hinton+ paper states that:

“Instead of penalizing the squared length (L2 norm) of the whole weight vector, we set an upper bound on the L2 norm of the incoming weight vector for each individual hidden unit. If a weight-update violates this constraint, we renormalize the weights of the hidden unit by division.”

To implement this same constraint we need to calculate the L2 norm of the weights for each neuron before each optimiser update, i.e. per mini-batch.

It seems to me that there are two possible ways to do this. Firstly, a function could be added into the model class to calculate the L2 norm and renormalise at the beginning of each batch update. Secondly, the same kind of functionality could be added into the optimiser class and run as part of the update. To me, the second option seems less open to user error than the first, and so that is what I propose here.
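As a rough illustration of the constraint itself, independent of where it gets called from, the sketch below rescales any row of a weight matrix whose L2 norm exceeds a limit. The helper name apply_max_norm is hypothetical, and the sketch relies on the nn.Linear convention that the weight has shape (out_features, in_features), so that row i holds the incoming weights of unit i.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def apply_max_norm(layer, l2_limit):
    """Rescale rows of layer.weight whose L2 norm exceeds l2_limit (hypothetical helper)."""
    # one norm per unit: each row is the incoming weight vector of that unit
    norms = layer.weight.norm(p=2, dim=1, keepdim=True)
    # factor is 1 where the norm is within the limit and l2_limit/norm otherwise
    scale = torch.clamp(l2_limit / norms, max=1.0)
    layer.weight.mul_(scale)


layer = nn.Linear(4, 3)
with torch.no_grad():
    layer.weight.fill_(1.0)   # every row now has L2 norm 2
    layer.weight[0] *= 10.0   # row 0 now has L2 norm 20

apply_max_norm(layer, l2_limit=15.0)
row_norms = layer.weight.norm(dim=1)
print(row_norms)
```

Only the row that violated the limit is rescaled; the others are left untouched, which is the behaviour the quoted passage describes.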

Here is an example of a Hinton+ style SGD update including an L2 threshold check and renormalisation.

import torch
from torch.optim.optimizer import Optimizer, required

class SGDHinton(Optimizer):
    r"""Implements stochastic gradient descent with momentum following the update method described in Hinton+ 2012 https://arxiv.org/pdf/1207.0580.pdf including the renormalisation of input weights for individual neurons if they exceed a specified L2 norm threshold.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        l2_limit (float, optional): L2 threshold per neuron (default: 15; set to 0 to remove)

    Example:
        >>> optimizer = SGDHinton(model.parameters(), lr=0.1, momentum=0.9, l2_limit=0)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    .. note::
        The implementation of SGD with Momentum differs from
        the standard PyTorch SGD implementation

        The update can be written as

        .. math::
            v_{t+1} & = \mu * v_{t} - (1 - \mu) * \text{lr} * g_{t+1}, \\
            p_{t+1} & = p_{t} + v_{t+1},

        This is in contrast to the standard PyTorch implementation which employs an update of the form

        .. math::
            v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
            p_{t+1} & = p_{t} - \text{lr} * v_{t+1},

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.
    """

    def __init__(self, params, lr=required, momentum=0, l2_limit=15.):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if l2_limit < 0.0:
            raise ValueError("Invalid L2 threshold value: {}".format(l2_limit))

        defaults = dict(lr=lr, momentum=momentum, l2_limit=l2_limit)
        super(SGDHinton, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGDHinton, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            momentum = group['momentum']
            learning_rate = group['lr']
            l2_limit = group['l2_limit']
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad
                # per-element rescaling factors; 1 leaves a neuron unchanged
                renorm = torch.ones_like(p)
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    # the velocity starts from zero
                    param_state['momentum_buffer'] = torch.zeros_like(p)
                buf = param_state['momentum_buffer']
                # v <- momentum*v - (1 - momentum)*lr*grad
                alpha = learning_rate*(1 - momentum)
                buf.mul_(momentum).add_(d_p, alpha=-1.*alpha)
                # only weight matrices (not bias vectors) are renormalised;
                # the norms are those of the weights before this update
                if l2_limit > 0 and p.dim() > 1:
                    l2_norm = torch.sqrt(torch.sum(p**2, dim=1))
                    for i in range(p.size(0)):
                        if l2_norm[i].item() > l2_limit:
                            renorm[i,:] = l2_limit/l2_norm[i].item()
                # w <- w + v, then rescale any neuron that exceeded the threshold
                p.add_(buf).mul_(renorm)
        return loss

We can call this optimiser in our main code using:

learning_rate = 1e-1      # Initial learning rate
p_i           = 5e-1      # Initial momentum

optimizer = SGDHinton(model.parameters(), lr=learning_rate, momentum=p_i, l2_limit=15.)

Learning Rate & Momentum Schedule

You can see in the code block above that the momentum and learning rate values that we use to initialise the optimiser class are described as initial. This is because the Hinton+ paper evolves both of these parameters as a function of epoch. Specifically,

\epsilon^k = \epsilon_0 \times \gamma^k

where k is the epoch, \epsilon_0 = 0.1 is the initial learning rate and \gamma=0.998 is a decay factor; the momentum is given by

p^k = \begin{cases} \frac{k}{K} p_f+ \left( 1 - \frac{k}{K} \right) p_i &  k < K \\ p_f & k \geq K \end{cases}

where p_i =0.5 is the initial momentum and p_f =0.99 is the final momentum at epoch K=500.

Note: the equation for the momentum evolution specified above differs from the one in Appendix A of the Hinton+ paper, because there is a typo in the paper that swaps p_i and p_f.

By defining the following two functions to encapsulate these changes,

# -----------------------------------------------------------

def get_lr(epoch, lr0, gamma):

    return lr0*gamma**epoch
# -----------------------------------------------------------
def get_momentum(epoch, p_i, p_f, T):

    if epoch<T:
        p = (epoch/T)*p_f + (1 - (epoch/T))*p_i
    else:
        p = p_f
    return p

# -----------------------------------------------------------

we can then update the momentum and learning rate at each epoch using

optimizer.param_groups[0]['lr'] = get_lr(epoch, learning_rate, gamma)
optimizer.param_groups[0]['momentum'] = get_momentum(epoch, p_i, p_f, K)
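As a quick sanity check on the schedule, the following plain-Python sketch evaluates both functions at a few epochs using the values quoted above (\epsilon_0 = 0.1, \gamma = 0.998, p_i = 0.5, p_f = 0.99, K = 500). The two functions are restated so that the snippet runs on its own; the choice of epochs to inspect is arbitrary.

```python
def get_lr(epoch, lr0, gamma):
    """Exponentially decayed learning rate."""
    return lr0 * gamma ** epoch


def get_momentum(epoch, p_i, p_f, T):
    """Linear ramp from p_i to p_f over T epochs, constant at p_f afterwards."""
    if epoch < T:
        return (epoch / T) * p_f + (1 - epoch / T) * p_i
    return p_f


lr0, gamma = 0.1, 0.998          # initial learning rate and decay factor
p_i, p_f, K = 0.5, 0.99, 500     # initial/final momentum and ramp length

# (learning rate, momentum) at a handful of epochs
schedule = {k: (get_lr(k, lr0, gamma), get_momentum(k, p_i, p_f, K))
            for k in (0, 250, 500, 1000)}
for k, (lr, p) in schedule.items():
    print(f"epoch {k:4d}: lr = {lr:.5f}, momentum = {p:.3f}")
```

The momentum reaches its final value at epoch K and stays there, while the learning rate keeps decaying for as long as training runs.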

Open Questions

  1. Can I just use dampening = momentum in normal PyTorch SGD and get a similar output?
  2. What happens if I ignore the L2 threshold or use a weight decay instead?
  3. Will the Adam optimiser give me the same results as SGD with dampening?
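For anyone wanting to try the first of these questions, here is a minimal sketch of how the comparison could be set up with the stock torch.optim.SGD. It only shows the configuration and makes no claim about the outcome. The nn.Linear stand-in model, the random batch and the helper name set_momentum are illustrative assumptions of mine; the point is simply that dampening needs to be updated alongside momentum whenever the schedule changes it.

```python
import torch
import torch.nn as nn

model = nn.Linear(784, 10)   # stand-in for the MLP defined earlier
p_i = 0.5                    # initial momentum

# stock PyTorch SGD with dampening set equal to momentum
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=p_i, dampening=p_i)


def set_momentum(optimizer, p):
    """Keep dampening tied to momentum when the schedule changes it (hypothetical helper)."""
    for group in optimizer.param_groups:
        group['momentum'] = p
        group['dampening'] = p


# one dummy update on random data, just to show the calls involved
x = torch.randn(8, 784)
target = torch.randint(0, 10, (8,))
optimizer.zero_grad()
loss = nn.functional.cross_entropy(model(x), target)
loss.backward()
optimizer.step()

set_momentum(optimizer, 0.745)   # e.g. the scheduled value half-way through the ramp
```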
