Borchification of Networks

In this tutorial we will create a torch neural network and use the borchify functionality in borch to turn it into a Bayesian neural network. We will then train the two models to classify MNIST digits and see how they compare.

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn import Module
import torch.nn.functional as F

import borch
from borch import nn

Next we construct a fairly small torch CNN, with batch normalisation after each of the two convolutional layers and after the first fully connected layer:

class CNN(Module):
    """Four-layer non-Bayesian neural network: two convolutional layers,
    each followed by batch normalisation, and two fully connected layers,
    with batch normalisation after the first of them."""

    def __init__(self, n_in, n_conv1, n_conv2, n_fc1, n_out):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(n_in, n_conv1, kernel_size=5, stride=2)
        self.bn1 = torch.nn.BatchNorm2d(n_conv1)
        self.conv2 = torch.nn.Conv2d(n_conv1, n_conv2, kernel_size=5, stride=2)
        self.bn2 = torch.nn.BatchNorm2d(n_conv2)
        # NB manually calculated for 28x28 (MNIST) inputs:
        # 28 -> 12 after conv1, 12 -> 4 after conv2
        self.n_at_fc1 = 4 * 4 * n_conv2
        self.fc1 = torch.nn.Linear(self.n_at_fc1, n_fc1)
        self.bn3 = torch.nn.BatchNorm1d(n_fc1)
        self.fc2 = torch.nn.Linear(n_fc1, n_out)

    def forward(self, x):
        # Each hidden layer: linear map -> ReLU -> batch normalisation
        h = self.bn1(F.relu(self.conv1(x)))
        h = self.bn2(F.relu(self.conv2(h)))
        # Flatten the convolutional features before the first fully connected layer
        h = self.bn3(F.relu(self.fc1(h.view(-1, self.n_at_fc1))))
        h = self.fc2(h)  # class logits
        # Categorical random variable over the classes, parameterised by the logits
        self.cls = borch.RandomVariable(borch.distributions.Categorical(logits=h))
        return h
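
The 4 * 4 * n_conv2 in __init__ is calculated by hand. Assuming 28x28 MNIST inputs and the Conv2d defaults of padding=0 and dilation=1, it follows from the output-size formula given in the torch.nn.Conv2d documentation. The helper conv_out below is our own illustration, not part of torch or borch:

```python
def conv_out(size, kernel_size, stride=1, padding=0, dilation=1):
    """Spatial output size of a convolution along one dimension
    (the formula documented for torch.nn.Conv2d)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


h = 28                                    # MNIST images are 28x28
h = conv_out(h, kernel_size=5, stride=2)  # after conv1: 12
h = conv_out(h, kernel_size=5, stride=2)  # after conv2: 4
n_conv2 = 10                              # as in CNN(1, 10, 10, 100, 10)
n_at_fc1 = h * h * n_conv2                # features entering fc1
print(h, n_at_fc1)
```

Changing kernel_size, stride or the input resolution therefore requires updating n_at_fc1 as well.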

Using borchify_network, one can turn a torch network into an entirely Bayesian one. It also lets one specify a posterior_creator, which determines what posterior is used; here we use borch.posterior.Automatic.

cnn = CNN(1, 10, 10, 100, 10)
bcnn = CNN(1, 10, 10, 100, 10)
bcnn = nn.borchify_network(
    bcnn,
    posterior_creator=borch.posterior.Automatic)
# NB borchify is not in-place
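
To make the idea concrete, here is a toy, pure-Python sketch of what converting a deterministic network into a Bayesian one means conceptually. It is not borch's implementation: the Normal class, the toy_borchify helper, the dict-of-lists "network" and the fixed initial scale of 0.1 are all hypothetical choices made for illustration.

```python
from dataclasses import dataclass


@dataclass
class Normal:
    """Hypothetical stand-in for a Normal distribution over one scalar weight."""
    loc: float
    scale: float

    def sample(self, eps: float) -> float:
        # Reparameterised draw: standard-normal noise eps is scaled and
        # shifted, so the draw is a deterministic function of (loc, scale).
        return self.loc + self.scale * eps


def toy_borchify(net: dict, init_scale: float = 0.1) -> dict:
    """Hypothetical helper: return a NEW "network" in which every scalar
    weight is replaced by a Normal centred on its original value.

    A fresh dict is built and the input is never mutated, mirroring the
    "borchify is not in-place" note above."""
    return {
        name: [Normal(loc=v, scale=init_scale) for v in values]
        for name, values in net.items()
    }


# A toy "network": parameter name -> flat list of scalar weights.
net = {"fc.weight": [0.5, -1.0, 2.0], "fc.bias": [0.0]}
bnet = toy_borchify(net)

# Each scalar now carries two numbers (loc and scale) instead of one.
n_before = sum(len(v) for v in net.values())
n_after = sum(2 * len(v) for v in bnet.values())
print(n_before, n_after)  # 4 8
print(bnet["fc.weight"][0].sample(eps=1.0))
```

In borch the choice of posterior is delegated to posterior_creator (here borch.posterior.Automatic), so how the real posteriors are initialised may differ from this sketch.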

After instantiation, both cnn and bcnn are frequentist networks, but borchifying bcnn replaces its torch.nn modules with borch.nn modules. Note that the borchified network has more parameters: every weight and bias of the original network has been converted into a normal distribution that requires both a loc and a scale, and, furthermore, we need both a prior and a posterior distribution.

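def total_params(net):
    # NB: this sums the sizes of each parameter tensor's dimensions (e.g. 21
    # for a 10x1x5x5 kernel), not its number of elements (250), so it is only
    # a rough size comparison; p.numel() would count the elements.
    return sum(sum(p.size()) for p in net.parameters())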


print(f"Total parameters in cnn:   {total_params(cnn)}")
print(f"Total parameters in bcnn: {total_params(bcnn)}")

Out:

Total parameters in cnn:   791
Total parameters in bcnn: 1342
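
total_params adds up the sizes of each parameter tensor's dimensions rather than counting its scalar entries. The pure-Python sketch below recomputes both quantities from the parameter shapes of CNN(1, 10, 10, 100, 10), assuming torch's defaults (a bias on every convolutional and linear layer, and affine batch normalisation with one weight and one bias per channel). In torch itself, sum(p.numel() for p in net.parameters()) gives the element count directly.

```python
from math import prod

# Parameter shapes of CNN(n_in=1, n_conv1=10, n_conv2=10, n_fc1=100, n_out=10),
# in the order torch registers them (assumes bias=True and affine=True defaults).
shapes = {
    "conv1.weight": (10, 1, 5, 5), "conv1.bias": (10,),
    "bn1.weight": (10,), "bn1.bias": (10,),
    "conv2.weight": (10, 10, 5, 5), "conv2.bias": (10,),
    "bn2.weight": (10,), "bn2.bias": (10,),
    "fc1.weight": (100, 160), "fc1.bias": (100,),
    "bn3.weight": (100,), "bn3.bias": (100,),
    "fc2.weight": (10, 100), "fc2.bias": (10,),
}

# What total_params computes: the sum of each shape's dimension sizes.
dim_sum = sum(sum(shape) for shape in shapes.values())

# The number of scalar parameters: the product of each shape's dimensions.
n_elements = sum(prod(shape) for shape in shapes.values())

print(f"sum of dimension sizes: {dim_sum}")
print(f"number of scalars:      {n_elements}")
```

Running it prints 791, matching the cnn line above, and 20120 for the actual number of scalar parameters in cnn. The bcnn figure is not recomputed here, since it depends on which parameters borchify_network creates.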

Total running time of the script: ( 0 minutes 0.018 seconds)
