{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Setting Priors\nThe prior distribution has many roles in Bayesian inference. Primarily it is used to\nencode domain knowledge into the model, such as encoding that a parameter must be positive.\nIn practice it also becomes a means of stabilizing inference in complex, high-dimensional\nproblems.\n\n\nThere is a lot of literature on the choice of prior one should use in statistical models,\nsuch as uniform priors, Jeffreys\u2019 priors, reference priors, maximum entropy priors\nand weakly informative priors. In order to select a good prior one has to take the entire\nmodel into account. A simple way of doing this is to check whether the constructed model generates\nsane values when sampling from the prior :cite:`2017Entrp..19..555G`.\n\nLet's do an analysis of how to set the prior for a neural network.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import math\n\nimport numpy\nimport matplotlib\n\nmatplotlib.use(\"Agg\")\nimport matplotlib.pyplot as plt\nimport torch\nfrom torch.nn import init\n\nimport borch\nfrom borch import nn\nimport borch.distributions as dist\nfrom borch.utils.torch_utils import get_device\n\nDEVICE = get_device()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As an illustration, the plot below compares samples from a prior, $N(0, 1)$, with samples from a narrower posterior, $N(0.25, 0.25)$.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "plt.figure()\nbins = numpy.linspace(-3, 3, 100)\nplt.hist(\n    [dist.Normal(0, 1).sample().item() for _ in range(1000)],\n    bins,\n    alpha=0.5,\n    label=\"Prior\",\n)\nplt.hist(\n    [dist.Normal(0.25, 0.25).sample().item() for _ in range(1000)],\n    bins,\n    alpha=0.5,\n    label=\"Posterior\",\n)\nplt.legend()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A very commonly used prior for weights in neural networks is the standard Gaussian\n$N(0,1)$. In order to set the prior for modules in `borch.nn`, one can use the\n``rv_factory`` argument. To get a standard Gaussian one can use\n``parameter_to_normal_rv`` as an ``rv_factory``. The example below instead passes the prior\ndistributions directly through the ``weight`` and ``bias`` keyword arguments. Let's see how this prior works for a\nchain of linear layers (for simplicity we neglect non-linearities between the linear layers).\n\n\nTo see how sane the prior is, we send random noise in and look at the output.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "net = nn.Sequential(\n    nn.Linear(\n        1000,\n        1000,\n        weight=dist.Normal(0, 1),\n        bias=dist.Normal(0, 1),\n        posterior=borch.posterior.Automatic(),\n    ),\n    nn.Linear(\n        1000,\n        1000,\n        weight=dist.Normal(0, 1),\n        bias=dist.Normal(0, 1),\n        posterior=borch.posterior.Automatic(),\n    ),\n)\n\n\nborch.sample(net)\nout = net(torch.randn(3, 1000, 1000))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's look at the mean and standard deviation of the output.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(out.mean())\nprint(out.std())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As we can see, the standard deviation is very large. Unless the target that you are trying to\nmodel can be described with that variance, the prior is set wrongly.\n\nLet's see how this prior works for a single linear layer in order to develop some\nunderstanding of what is happening.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "net = nn.Linear(\n    1000,\n    1000,\n    weight=dist.Normal(0, 1),\n    bias=dist.Normal(0, 1),\n    posterior=borch.posterior.Automatic(),\n)\nborch.sample(net)\nout = net(torch.randn(3, 1000, 1000))\nprint(out.std())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Basically, one layer increases the standard deviation roughly 30 times. When this output is sent\ninto the next layer, it creates a compounding effect that makes the standard deviation explode.\nTo avoid this behavior we should construct a prior that creates Gaussian noise\nas output from a linear layer when it is fed Gaussian noise.\n\n\n"
      ]
    },
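    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The factor of roughly 30 can be checked by hand. As a sketch, assume the inputs $x_i$, the weights $w_{ji}$ and the bias $b_j$ are all independent with zero mean and unit variance, as with the $N(0, 1)$ prior above, and write $\\text{fan_in} = 1000$ for the number of inputs. For one output $y_j = \\sum_i w_{ji} x_i + b_j$:\n\n\\begin{align}\\text{Var}(y_j) = \\text{fan_in} \\times \\text{Var}(w) \\times \\text{Var}(x) + \\text{Var}(b) = 1000 \\times 1 \\times 1 + 1 = 1001, \\qquad \\text{std}(y_j) = \\sqrt{1001} \\approx 31.6\\end{align}\n\nFeeding this into a second identical layer gives $\\text{Var} = 1000 \\times 1001 + 1 = 1001001$, i.e. a standard deviation of about $1000$, which is the compounding effect described above. Under these assumptions, keeping the output variance at one requires $\\text{Var}(w) \\approx 1 / \\text{fan_in}$.\n"
      ]
    },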
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can get an output standard deviation that is close to one by calculating the standard\ndeviation of the prior using the method described in \"Understanding the difficulty of training deep feedforward\nneural networks\" - Glorot, X. & Bengio, Y. (2010). The std can be used to construct\na $\\mathcal{N}(0, \\text{std})$ distribution, where\n\n\\begin{align}\\text{std} = \\text{gain} \\times \\sqrt{\\frac{2}{\\text{fan_in} + \\text{fan_out}}}\\end{align}\n\nThis is also known as Glorot (or Xavier) initialization.\n\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def xavier_normal_std(shape, gain=1):\n    \"\"\"\n    Args:\n        shape (tuple): a tuple with ints of shape of the tensor where the xavier init\n          method will be used.\n        gain (float):  an optional scaling factor\n\n    Returns:\n        float, the std to be used in a Gaussian Distribution to achieve a xavier\n         initialization.\n\n    \"\"\"\n    tensor = torch.randn(*shape)\n    if tensor.ndimension() < 2:\n        tensor = tensor.view(*tensor.shape, 1)\n    fan_in, fan_out = init._calculate_fan_in_and_fan_out(tensor)\n    std = gain * math.sqrt(2.0 / (fan_in + fan_out))\n    return std"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's try the network again with the new prior.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "net = nn.Sequential(\n    nn.Linear(\n        1000,\n        1000,\n        weight=dist.Normal(0, xavier_normal_std((1000, 1000))),\n        posterior=borch.posterior.Automatic(),\n    ),\n    nn.Linear(\n        1000,\n        1000,\n        weight=dist.Normal(0, xavier_normal_std((1000, 1000))),\n        posterior=borch.posterior.Automatic(),\n    ),\n)\n\n\nborch.sample(net)\nout = net(torch.randn(3, 1000, 1000))\nprint(out.mean())\nprint(out.std())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now we see that the mean is close to zero and the standard deviation is close to one.\nFrom this we can conclude that this is a more sane prior, as it generates values that\ndo not explode.\n\nLet's try the same analysis of a single layer for a ``Conv2d``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "net = nn.Conv2d(\n    16,\n    33,\n    (3, 5),\n    stride=(2, 1),\n    padding=(4, 2),\n    dilation=(3, 1),\n    weight=dist.Normal(0, 1),\n    posterior=borch.posterior.Automatic(),\n)\nborch.sample(net)\nout = net(torch.randn(20, 16, 50, 100))\nprint(out.std())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Again we see that the standard deviation is too large. Let's try setting the prior\nwith ``xavier_normal_std``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "net = nn.Conv2d(\n    16,\n    33,\n    (3, 5),\n    stride=1,\n    weight=dist.Normal(0, xavier_normal_std((33, 16, 3, 5))),\n    posterior=borch.posterior.Automatic(),\n)\nborch.sample(net)\nout = net(torch.randn(20, 16, 50, 100))\nprint(out.std())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now we see that we are on the right scale, but we are not at the desired standard\ndeviation of 1 for the output. To address this and also support activation functions\nwe can calculate the standard deviation according to the method described in\n\"Delving deep into rectifiers: Surpassing human-level\nperformance on ImageNet classification\" - He, K. et al. (2015), using a\nnormal distribution. The resulting tensor will have values sampled from\n$\\mathcal{N}(0, \\text{std})$ where\n\n\\begin{align}\\text{std} = \\sqrt{\\frac{2}{(1 + a^2) \\times \\text{fan_in}}}\\end{align}\n\nand $a$ is the negative slope of the rectifier used after the layer. This is also known as He initialization.\n\n\n\n"
      ]
    },
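    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For the convolution above, the two formulas can be compared by hand. This is only a sketch: it assumes independent unit-variance inputs, ignores the bias, and takes $\\text{fan_in}$ to be the number of input channels times the kernel height and width (and $\\text{fan_out}$ likewise with the output channels). For the weight shape $(33, 16, 3, 5)$:\n\n\\begin{align}\\text{fan_in} = 16 \\times 3 \\times 5 = 240, \\qquad \\text{fan_out} = 33 \\times 3 \\times 5 = 495\\end{align}\n\n\\begin{align}\\text{std}_{\\text{Glorot}} = \\sqrt{\\frac{2}{240 + 495}} \\approx 0.052, \\qquad \\text{std}_{\\text{out}} \\approx \\sqrt{240 \\times \\frac{2}{735}} \\approx 0.81\\end{align}\n\n\\begin{align}\\text{std}_{\\text{He}} = \\sqrt{\\frac{1}{240}} \\approx 0.065, \\qquad \\text{std}_{\\text{out}} \\approx \\sqrt{240 \\times \\frac{1}{240}} = 1\\end{align}\n\nThe He value here uses a gain of one, which corresponds to $a = 1$ in the formula above, i.e. no rectification. For the $1000 \\times 1000$ linear layer $\\text{fan_in} = \\text{fan_out}$, so both formulas give $\\sqrt{1 / 1000} \\approx 0.032$, which would explain why the Glorot prior already produced a standard deviation close to one there.\n"
      ]
    },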
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def kaiming_normal_std(shape, a=math.sqrt(5), mode=\"fan_in\", nonlinearity=\"linear\"):\n    \"\"\"\n    Args:\n        shape (tuple): a tuple with ints of shape of the tensor where\n          the kaiming init method will be used.\n        a (float): the negative slope of the rectifier, only used when\n          ``nonlinearity`` is \"leaky_relu\".\n        mode (str): \"fan_in\" or \"fan_out\", the fan used to scale the std.\n        nonlinearity (str): name of the non-linearity used to calculate the gain.\n\n    Returns:\n        float, the std to be used in a Gaussian Distribution to achieve\n          a kaiming initialization.\n    \"\"\"\n    tensor = torch.randn(*shape)\n    if tensor.ndimension() < 2:\n        tensor = tensor.view(*tensor.shape, 1)\n    # With the default \"linear\" non-linearity the gain is 1, so std = 1 / sqrt(fan).\n    gain = init.calculate_gain(nonlinearity, a)\n    fan = init._calculate_correct_fan(tensor, mode)\n    std = gain / math.sqrt(fan)\n    return std\n\n\nnet = nn.Conv2d(\n    16,\n    33,\n    (3, 5),\n    stride=1,\n    weight=dist.Normal(0, kaiming_normal_std((33, 16, 3, 5))),\n    posterior=borch.posterior.Automatic(),\n)\nborch.sample(net)\nout = net(torch.randn(20, 16, 50, 100))\nprint(out.std())\n\nnet = nn.Linear(\n    1000,\n    1000,\n    weight=dist.Normal(0, kaiming_normal_std((1000, 1000))),\n    posterior=borch.posterior.Automatic(),\n)\nborch.sample(net)\nout = net(torch.randn(3, 1000, 1000))\nprint(out.std())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now we have constructed a prior that works for both linear and conv layers,\nwhere we chose to get the standard deviation to be close to 1, but this will of course\nbe affected by the activation function one selects. When combining it with activation functions\none will have to multiply the std by an appropriate gain, set through the\n``nonlinearity`` argument of ``kaiming_normal_std``, according to what one has in the network.\n\n"
      ]
    },
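    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough guide for choosing ``nonlinearity``, the std returned by ``kaiming_normal_std`` is the gain divided by the square root of the selected fan. The gains listed here are the ones documented for ``torch.nn.init.calculate_gain``, with $a$ the negative slope of the leaky ReLU:\n\n\\begin{align}\\text{std} = \\frac{\\text{gain}}{\\sqrt{\\text{fan}}}, \\qquad \\text{gain}_{\\text{linear}} = 1, \\quad \\text{gain}_{\\text{tanh}} = \\frac{5}{3}, \\quad \\text{gain}_{\\text{ReLU}} = \\sqrt{2}, \\quad \\text{gain}_{\\text{leaky ReLU}} = \\sqrt{\\frac{2}{1 + a^2}}\\end{align}\n\nFor example, with ``nonlinearity=\"relu\"`` and $\\text{fan_in} = 1000$ this gives $\\text{std} = \\sqrt{2 / 1000} \\approx 0.045$.\n"
      ]
    },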
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Exercises\n1) Set up a model to solve CIFAR, where the base model has an accuracy of 80% or\n   higher, then test it with different priors and q_distributions. Test at least\n   two priors discussed above, but also try some other distribution like StudentT.\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}