"""
Borchification of Networks
==========================
"""

###########################################
# In this tutorial we will create a ``torch`` neural network and use the ``borchify``
# functionality in ``borch`` to turn it into a Bayesian neural network. We will then
# train the two models to predict MNIST digit labels and see how they compare.

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.nn import Module
import torch.nn.functional as F

import borch
from borch import nn

###########################################
# Next we construct a fairly small ``torch`` CNN with batch normalisation after each
# of the two convolutional layers and after the first fully connected layer:


class CNN(Module):
    """Four layer non-Bayesian neural network with two convolutional and two
    fully connected layers.

    Each convolution is followed by a ReLU and batch normalisation, and the
    first fully connected layer is followed by a ReLU and batch normalisation
    as well. The second fully connected layer produces the output logits.

    Args:
        n_in: Number of input channels.
        n_conv1: Number of output channels of the first convolution.
        n_conv2: Number of output channels of the second convolution.
        n_fc1: Width of the hidden fully connected layer.
        n_out: Number of output logits.
        input_size: Height/width of the square input images. The default of
            28 (MNIST) gives ``4 * 4 * n_conv2`` features at ``fc1``.

    Raises:
        ValueError: If ``input_size`` is too small for the two unpadded,
            strided convolutions to leave at least one spatial position.
    """

    def __init__(self, n_in, n_conv1, n_conv2, n_fc1, n_out, input_size=28):
        super().__init__()
        kernel_size, stride = 5, 2
        self.conv1 = torch.nn.Conv2d(
            n_in, n_conv1, kernel_size=kernel_size, stride=stride)
        self.bn1 = torch.nn.BatchNorm2d(n_conv1)
        self.conv2 = torch.nn.Conv2d(
            n_conv1, n_conv2, kernel_size=kernel_size, stride=stride)
        self.bn2 = torch.nn.BatchNorm2d(n_conv2)
        # Spatial side length after each unpadded convolution:
        # (size - kernel) // stride + 1, i.e. 28 -> 12 -> 4 for MNIST.
        side = input_size
        for _ in range(2):
            side = (side - kernel_size) // stride + 1
        if side < 1:
            raise ValueError(
                f"input_size={input_size} is too small for two "
                f"{kernel_size}x{kernel_size} convolutions with stride {stride}")
        self.n_at_fc1 = side * side * n_conv2
        self.fc1 = torch.nn.Linear(self.n_at_fc1, n_fc1)
        self.bn3 = torch.nn.BatchNorm1d(n_fc1)
        self.fc2 = torch.nn.Linear(n_fc1, n_out)

    def forward(self, x):
        """Return the output logits of the network for the input batch ``x``.

        Also stores ``self.cls``: a ``borch.RandomVariable`` wrapping a
        ``Categorical`` distribution parameterised by those logits.
        """
        h = self.bn1(F.relu(self.conv1(x)))
        h = self.bn2(F.relu(self.conv2(h)))
        # flatten(start_dim=1) keeps the batch dimension intact, so a spatial
        # size mismatch fails in fc1 instead of silently regrouping samples
        # the way view(-1, n_at_fc1) could.
        h = self.bn3(F.relu(self.fc1(h.flatten(start_dim=1))))
        h = self.fc2(h)
        # NOTE(review): presumably borch uses this RandomVariable attribute as
        # the likelihood over the class labels — confirm against borch docs.
        self.cls = borch.RandomVariable(borch.distributions.Categorical(logits=h))
        return h

##########################################
# Using ``borchify_network`` one can turn a ``torch`` network into an entirely
# Bayesian one. It also accepts a ``posterior_creator`` argument that determines
# which posterior is used.

# Plain torch baseline network.
cnn = CNN(n_in=1, n_conv1=10, n_conv2=10, n_fc1=100, n_out=10)

# A second, identically shaped network converted to a Bayesian one.
# NB borchify is not in-place, so we keep its return value.
bcnn = nn.borchify_network(
    CNN(n_in=1, n_conv1=10, n_conv2=10, n_fc1=100, n_out=10),
    posterior_creator=borch.posterior.Automatic,
)

###########################################
# After instantiation, both ``cnn`` and ``bcnn`` are frequentist networks, but
# borchifying ``bcnn`` replaces its ``torch.nn`` modules with ``borch.nn`` modules.
# Note that ``bcnn`` now has four times as many parameters: every weight/bias in the
# ``cnn`` has been converted to a normal distribution that requires both a ``loc``
# and a ``scale``, and furthermore we need both a prior and a posterior distribution.

def total_params(net):
    """Return the total number of scalar values in ``net.parameters()``.

    Each parameter tensor contributes its element count (the product of its
    dimensions), e.g. a ``(4, 3)`` weight matrix contributes 12.
    """
    # numel() is the product of the sizes; summing x.size() would only add
    # the dimensions together (4 + 3 = 7 for the example above).
    return sum(p.numel() for p in net.parameters())


# Labels are padded so that both parameter counts start in the same column.
print(f"Total parameters in cnn:  {total_params(cnn)}")
print(f"Total parameters in bcnn: {total_params(bcnn)}")


