{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Borchification of Networks\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In this tutorial we will create a ``torch`` neural network and use the ``borchify``\nfunctionality in ``borch`` to turn it into a Bayesian neural network. We will then\ntrain both models to classify MNIST digits and see how they compare.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib\nmatplotlib.use(\"Agg\")\nimport matplotlib.pyplot as plt\nimport numpy as np\nimport torch\nfrom torch.nn import Module\nimport torch.nn.functional as F\n\nimport borch\nfrom borch import nn"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Next we construct a fairly small ``torch`` CNN with batch normalisation after each of\nthe two convolutional layers and after the first fully connected layer:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class CNN(Module):\n    \"\"\"Small non-Bayesian network with two convolutional layers and two\n    fully connected layers. Batch normalisation follows each convolutional\n    layer and the first fully connected layer.\"\"\"\n\n    def __init__(self, n_in, n_conv1, n_conv2, n_fc1, n_out):\n        super().__init__()\n        self.conv1 = torch.nn.Conv2d(n_in, n_conv1, kernel_size=5, stride=2)\n        self.bn1 = torch.nn.BatchNorm2d(n_conv1)\n        self.conv2 = torch.nn.Conv2d(n_conv1, n_conv2, kernel_size=5, stride=2)\n        self.bn2 = torch.nn.BatchNorm2d(n_conv2)\n        # NB manually calculated: for 28x28 MNIST inputs, two 5x5 stride-2\n        # convolutions shrink each side 28 -> 12 -> 4.\n        self.n_at_fc1 = 4 * 4 * n_conv2\n        self.fc1 = torch.nn.Linear(self.n_at_fc1, n_fc1)\n        self.bn3 = torch.nn.BatchNorm1d(n_fc1)\n        self.fc2 = torch.nn.Linear(n_fc1, n_out)\n\n    def forward(self, x):\n        h = self.bn1(F.relu(self.conv1(x)))\n        h = self.bn2(F.relu(self.conv2(h)))\n        # Flatten the feature maps before the fully connected layers.\n        h = self.bn3(F.relu(self.fc1(h.view(-1, self.n_at_fc1))))\n        h = self.fc2(h)\n        # Attach a categorical random variable parameterised by the class logits.\n        self.cls = borch.RandomVariable(borch.distributions.Categorical(logits=h))\n        return h"
      ]
    },
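    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The hard-coded ``4 * 4`` in ``n_at_fc1`` can be checked with a short sketch. It assumes\n28x28 MNIST inputs and no padding; the helper ``conv_out`` is a hypothetical name introduced\nhere for illustration and is not part of ``torch`` or ``borch``:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\n\n\ndef conv_out(size, kernel_size, stride=1, padding=0):\n    # Output length of a convolution along one spatial dimension (dilation 1).\n    return (size + 2 * padding - kernel_size) // stride + 1\n\n\n# Two 5x5 convolutions with stride 2 on a 28x28 image: 28 -> 12 -> 4.\nside = conv_out(conv_out(28, kernel_size=5, stride=2), kernel_size=5, stride=2)\nprint(side)\n\n# Cross-check by pushing a dummy MNIST-sized batch through equivalent layers.\nconvs = torch.nn.Sequential(\n    torch.nn.Conv2d(1, 10, kernel_size=5, stride=2),\n    torch.nn.Conv2d(10, 10, kernel_size=5, stride=2),\n)\nwith torch.no_grad():\n    out = convs(torch.zeros(1, 1, 28, 28))\nprint(tuple(out.shape))  # (batch, n_conv2, side, side)"
      ]
    },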
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Using `borchify_network`, one can turn a torch network entirely Bayesian.\nIt also lets one specify a `posterior_creator` that determines which posterior is used.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "cnn = CNN(1, 10, 10, 100, 10)\nbcnn = CNN(1, 10, 10, 100, 10)\nbcnn = nn.borchify_network(\n    bcnn,\n    posterior_creator=borch.posterior.Automatic)\n# NB borchify is not in-place"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "After instantiation, both ``cnn`` and ``bcnn`` are frequentist networks. By\nborchifying ``bcnn`` we replace its ``torch.nn`` modules with their ``borch.nn`` counterparts.\nNote that ``bcnn`` now has four times as many parameters as ``cnn``: every weight and bias\nhas been converted to a normal distribution that requires both a ``loc`` and a ``scale``,\nand each of them needs both a ``prior`` and a ``posterior`` distribution.\n\n"
      ]
    },
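    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The factor of four can be illustrated on a single layer. The sketch below is only a\ncounting argument, assuming one normal prior and one normal posterior per weight; the\ndictionary layout and key names are hypothetical and do not reflect how ``borch`` stores\nits parameters:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\n\nlinear = torch.nn.Linear(100, 10)\nn_point = sum(p.numel() for p in linear.parameters())  # 100 * 10 weights + 10 biases\n\n# Hypothetical layout: each weight/bias tensor gets a prior and a posterior,\n# and each of those carries its own loc and scale tensor of the same shape.\nvariational = {}\nfor name, p in linear.named_parameters():\n    for role in (\"prior\", \"posterior\"):\n        variational[f\"{role}.{name}.loc\"] = torch.zeros_like(p)\n        variational[f\"{role}.{name}.scale\"] = torch.ones_like(p)\n\nn_bayes = sum(t.numel() for t in variational.values())\nprint(n_point, n_bayes, n_bayes // n_point)"
      ]
    },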
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def total_params(net):\n    # numel() counts every element of a tensor; summing size() would only\n    # add up the lengths of its dimensions.\n    return sum(p.numel() for p in net.parameters())\n\n\nprint(f\"Total parameters in cnn:  {total_params(cnn)}\")\nprint(f\"Total parameters in bcnn: {total_params(bcnn)}\")"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}