Borchification of Networks
In this tutorial we will create a torch neural network and use the
borchify functionality in borch to turn it into a Bayesian neural
network. We will then train the two models on MNIST classification
and see how they compare.
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn import Module
import torch.nn.functional as F
import borch
from borch import nn
Next we construct a fairly small torch
CNN with batch normalisation after the
first two convolutional layers:
class CNN(Module):
    """Four-layer non-Bayesian neural network with two convolutional layers
    (each followed by batch normalisation) and two fully connected layers."""

    def __init__(self, n_in, n_conv1, n_conv2, n_fc1, n_out):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(n_in, n_conv1, kernel_size=5, stride=2)
        self.bn1 = torch.nn.BatchNorm2d(n_conv1)
        self.conv2 = torch.nn.Conv2d(n_conv1, n_conv2, kernel_size=5, stride=2)
        self.bn2 = torch.nn.BatchNorm2d(n_conv2)
        self.n_at_fc1 = 4 * 4 * n_conv2  # NB manually calculated for 28x28 inputs
        self.fc1 = torch.nn.Linear(self.n_at_fc1, n_fc1)
        self.bn3 = torch.nn.BatchNorm1d(n_fc1)
        self.fc2 = torch.nn.Linear(n_fc1, n_out)

    def forward(self, x):
        h = self.bn1(F.relu(self.conv1(x)))
        h = self.bn2(F.relu(self.conv2(h)))
        h = self.bn3(F.relu(self.fc1(h.view(-1, self.n_at_fc1))))
        h = self.fc2(h)
        # Expose the class label as a random variable so that observed
        # labels can be conditioned on once the network is borchified.
        self.cls = borch.RandomVariable(borch.distributions.Categorical(logits=h))
        return h
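As a quick sanity check of the manually calculated flattened size, we can push
a dummy MNIST-shaped batch through the network (a throwaway sketch, assuming
1x28x28 inputs as in MNIST; _net and _out are illustrative names, not part of
the tutorial's models):

# Throwaway check: a random batch shaped like MNIST (1 channel, 28x28).
_net = CNN(1, 10, 10, 100, 10)
_net.eval()  # use running statistics so the tiny batch passes batch norm
_out = _net(torch.randn(2, 1, 28, 28))
print(_out.shape)  # torch.Size([2, 10]); 28 -> 12 -> 4 under the two strided convs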
Using borchify_network one can turn a torch network entirely Bayesian. It also
allows specifying a posterior_creator, which determines what posterior is used
for the random variables.
cnn = CNN(1, 10, 10, 100, 10)
bcnn = CNN(1, 10, 10, 100, 10)
bcnn = nn.borchify_network(
    bcnn,
    posterior_creator=borch.posterior.Automatic,
)
# NB borchify is not in-place
After instantiation, both cnn and bcnn are frequentist networks; by
borchifying the second network we replace its torch.nn modules with
borch.nn modules.
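We can confirm the swap by inspecting a submodule; a quick hedged check (the
exact borch class path printed depends on the borch version):

print(type(cnn.conv1))   # <class 'torch.nn.modules.conv.Conv2d'>
print(type(bcnn.conv1))  # the borch.nn counterpart after borchification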
Note that the borchified network has roughly twice as many learnable
parameters: every weight and bias in the cnn has been converted to a normal
distribution whose posterior requires both a loc and a scale. Each random
variable additionally needs a prior distribution, though the prior's
parameters are not optimised.
def total_params(net):
    """Count the learnable parameters of a network."""
    # numel() gives the number of elements in each parameter tensor.
    return sum(p.numel() for p in net.parameters())
print(f"Total parameters in cnn: {total_params(cnn)}")
print(f"Total parameters in bcnn: {total_params(bcnn)}")
Out:
Total parameters in cnn: 20120
Total parameters in bcnn: 40000
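The cnn figure is plain arithmetic over the layer shapes, and the bcnn figure
follows under the assumption (consistent with the counts above) that borchify
turns every non-batch-norm parameter into a loc/scale pair while leaving the
batch-norm parameters as point estimates:

# Per-layer learnable parameters of cnn (weights + biases):
# conv1: 10*1*5*5 + 10  = 260
# bn1:   10 + 10        = 20
# conv2: 10*10*5*5 + 10 = 2510
# bn2:   10 + 10        = 20
# fc1:   100*160 + 100  = 16100
# bn3:   100 + 100      = 200
# fc2:   10*100 + 10    = 1010
# total                 = 20120
# bcnn: 2 * (20120 - 240) + 240 = 40000  (240 = batch-norm parameters)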
Total running time of the script: ( 0 minutes 0.018 seconds)