{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Posteriors\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The concept of posteriors is as important as the concept of modules in the borch\nframework.\nThe posterior's job is to create ``RandomVariable``s for approximating distributions.\nWhenever we add a ``RandomVariable`` to a ``Module``, the posterior will pick it up and\ncreate an approximating distribution for it.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nimport torch\nfrom torch import optim\n\nimport borch\nfrom borch import Module, posterior, distributions as dist, infer"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The default posterior of a module is the ``Automatic`` posterior, which infers\nfrom the prior what approximating distribution to use.\nWhenever we assign a ``RandomVariable`` to a\nmodule, that ``RandomVariable`` becomes the ``prior`` for that attribute,\nand the posterior uses it to create an approximating distribution.\nBy changing the posterior, we change the way the approximating\ndistribution is created. The ``Normal`` posterior creates a normal distribution centered around\na sample from the prior distribution.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "module = Module(posterior=posterior.Normal())\nmodule.rv = dist.StudentT(4, 0, 1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The prior of the variable ``rv`` is a Student's t-distribution, StudentT(4, 0, 1). The\napproximating distribution, however, is a normal distribution whose mean is a sample from\nthe prior distribution. For stable training, the ``Normal`` posterior instantiates\nthis normal distribution with ``log_scale == -3``. Very wide\napproximating distributions give very high-variance gradients early in training,\nso to get a clear gradient signal at the start we make the\ndistribution narrow. Over the course of training, as the optimiser balances the\ndivergence against the negative log likelihood, we expect these\napproximating distributions to widen.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(f\"Prior of rv {module.prior.rv.distribution}\")\nprint(f\"Posterior of rv {module.posterior.rv.distribution}\")\nprint(\n    f\"the scale comes from exp(-3): {np.exp(-3)}, and the location is a sample from \"\n    f\"the prior. \"\n)"
      ]
    },
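    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of the initialisation described above, the next cell builds a comparable approximating distribution with plain ``torch.distributions``.\nIt is a sketch under stated assumptions, not borch's implementation: the names ``sketch_prior``, ``sketch_loc``, ``sketch_log_scale`` and ``sketch_q`` are hypothetical and exist only in this example.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nfrom torch import distributions as td\n\n# Hypothetical stand-in for the prior assigned to module.rv (not a borch object).\nsketch_prior = td.StudentT(df=4.0, loc=0.0, scale=1.0)\n\n# Location: a single draw from the prior, wrapped so it can be optimised.\nsketch_loc = torch.nn.Parameter(sketch_prior.sample())\n# Log-scale starts at -3, so the initial scale is exp(-3), roughly 0.05.\nsketch_log_scale = torch.nn.Parameter(torch.tensor(-3.0))\n\nsketch_q = td.Normal(sketch_loc, sketch_log_scale.exp())\nprint(f\"sketch loc: {sketch_loc.item():.3f}\")\nprint(f\"sketch scale: {sketch_q.scale.item():.4f}\")"
      ]
    },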
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The optimisable parameters of the module are the posterior parameters, and they can be\naccessed with ``.parameters()``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(f\"optimisable parameters of module: {list(module.parameters())}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The first parameter is the loc and the second parameter is the log_scale of\nthe q-distribution. We can update the priors of the model without having to change\nthe learned posterior distributions.\nSeveral posteriors are available, including the ``Automatic``, ``Normal`` and ``Manual``\nposteriors that appear in this tutorial.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now we will see how to use manual posteriors in borch. Manual posteriors allow one to freely\nspecify the approximating distribution, whose control flow can\ndiffer from that of the model.\nFor simplicity, we will infer the loc and scale of a normal distribution.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def forward(mod):\n    mod.test = dist.Normal(5, 1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The posterior is specified in the same way. For the parameters to be\nlearnable, we need to define them as ``torch.nn.Parameter`` objects. The ``Parameter`` wrapper\nfor tensors simply lets the framework know that the tensor should be returned when calling\n``model.parameters()``, which makes it convenient to pass them to an\noptimiser.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "man_posterior = posterior.Manual()\n# Free parameters of the approximating distribution.\nman_posterior.mean = torch.nn.Parameter(torch.ones(1))\nman_posterior.sd = torch.nn.Parameter(torch.ones(1))\n\n\ndef forward_posterior(post):\n    # exp keeps the scale positive; the offset keeps it away from zero.\n    scale = torch.exp(post.sd) + 0.01\n    mean = post.mean.abs()\n    post.test = dist.Normal(mean, scale)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "When a ``torch.nn.Parameter`` is added to the posterior, it becomes accessible through\nthe posterior's ``.parameters()``\nand through the ``.parameters()`` of the model, which lets us optimise it.\n\nWhen running inference with a manual posterior, one has to run the posterior before the\nmodel each time to reinstantiate the distributions on it using the learned\nparameters.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "forward_posterior(man_posterior)\n\nlatent = Module(posterior=man_posterior)\noptimizer = optim.Adam(latent.parameters(), lr=0.1)\n\nfor _ in range(500):\n    # Clear gradients left over from the previous step.\n    optimizer.zero_grad()\n    loss = 0\n    # Accumulate the loss over 10 samples to reduce its variance.\n    for _ in range(10):\n        forward_posterior(man_posterior)\n        forward(latent)\n        loss += infer.vi_loss(**borch.pq_to_infer(latent))\n    loss.backward()\n    torch.nn.utils.clip_grad_norm_(latent.parameters(), 1)\n    optimizer.step()\n\n# Report the values as they enter the distribution in forward_posterior.\nprint('mean: ', man_posterior.mean.abs().item())\nprint('sd: ', (torch.exp(man_posterior.sd) + 0.01).item())"
      ]
    }
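    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the loss in the loop above more concrete, the next cell repeats the same fit with plain ``torch.distributions``.\nIt assumes that, with no observed data, the variational loss reduces to a Monte Carlo estimate of the KL divergence between the approximating distribution and the prior; this is an assumption about what ``infer.vi_loss`` computes, not a description of its implementation.\nAll ``sketch_*`` names are hypothetical. With enough steps, the fitted values should move toward the prior's loc of 5 and scale of 1.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nfrom torch import distributions as td\n\n# Hypothetical free parameters, mirroring man_posterior.mean and man_posterior.sd.\nsketch_mean = torch.nn.Parameter(torch.ones(1))\nsketch_sd = torch.nn.Parameter(torch.ones(1))\nsketch_target = td.Normal(5.0, 1.0)  # same prior as in forward()\nsketch_opt = torch.optim.Adam([sketch_mean, sketch_sd], lr=0.1)\n\nfor _ in range(500):\n    sketch_opt.zero_grad()\n    # Rebuild q from the current parameters, as forward_posterior does.\n    sketch_approx = td.Normal(sketch_mean.abs(), sketch_sd.exp() + 0.01)\n    # Reparameterised draws keep gradients flowing to the parameters.\n    z = sketch_approx.rsample((10,))\n    # Monte Carlo estimate of KL(q || p) = E_q[log q(z) - log p(z)].\n    sketch_loss = (sketch_approx.log_prob(z) - sketch_target.log_prob(z)).mean()\n    sketch_loss.backward()\n    sketch_opt.step()\n\nprint('sketch mean: ', sketch_mean.abs().item())\nprint('sketch sd: ', (sketch_sd.exp() + 0.01).item())"
      ]
    }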
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}