{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Introduction to Borch\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Borch's universal borch allows the creation of probabilistic models with arbitrary\ncontrol flow. The core components of the borch are ``borch.RandomVariable`` and\n``borch.nn.Module``.\n\nLets start of with the imports:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib\n\nmatplotlib.use(\"Agg\")\nimport matplotlib.pyplot as plt\nimport torch\n\nimport borch\nfrom borch import infer, distributions as dist"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## RandomVariable\nThe `borch.RandomVariable` merges a ``torch.distributions.Distribution``, ``torch.nn.Module`` and a\n``torch.tensor``, it acts like a tensor and can be used just like one,\nbut it also support methods such as ``.log_prob()``, ``.entropy()``, ``.sample()``,\n``.rsample()`` etc. like a  ``torch.distributions.Distribution``. It also support \nmethods like ``.paramaters()``, ``.children()`` etc. lake a ``torch.nn.Module``.\n\nA random variable is instanciated as a `borch.distributions`\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "rvar = dist.Normal(0, 1)\nprint(rvar)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Everytime time the random variable is called the value of the random variable gets updated\nwith a sample from the distribution.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(rvar())\nprint(rvar)\nprint(rvar())\nprint(rvar)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The tensor that represent the value of the random variable is accessible via `.tensor`\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(rvar.tensor)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "It can be used just like a normal tensor\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(rvar * 100)\nprint(rvar * torch.randn(10))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The distribution that the ``borch.RandomVariable`` is initialized with, is accessible\n the method ``.distribution()``. The method on the ``borch.RandomVariable`` differs\nsightly form that of a ``torch.distributions.Distribution`` in that feeds in its\nown tensor as the input if no args are provided.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(rvar.log_prob())\nrvar.log_prob() == rvar.log_prob(rvar.tensor)\nprint(rvar.log_prob(torch.zeros(1)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "It also supports sampling(`.sample()`) and reparameterized sampeling(`.rsample()`)\nif available. Note that this does not update the value of the ``RandomVariable``, only\ncalling the ``RandomVariable`` update its value.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "rvar.sample()\nprint(rvar)\n\nplt.hist([rvar.sample().item() for i in range(1000)])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Module\nThe ``borch.nn.Module`` is an object that supports attaching and book keeping of\n``borch.RandomVariable``'s. It also got a posterior that specifies how the approximating\ndistributions will look. It is the recommended practice to write models in two ways,\neither the same object oriented design as one does with ``torch.nn`` Modules.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class Model(borch.Module):\n    def __init__(self):\n        module.weight1 = dist.Gamma(1, 1 / 2)\n        module.weight2 = dist.Normal(loc=1, scale=2)\n        module.weight3 = dist.Normal(loc=1, scale=2)\n\n    def forward(module):\n        mu = module.weight1 + module.weight2 + module.weight3\n        module.obs = dist.Normal(mu, 1)\n        return mu"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Or as functions that have a ``borch.nn.Module`` as a first argument.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def forward(module):\n    module.weight1 = dist.Gamma(1, 1 / 2)\n    module.weight2 = dist.Normal(loc=1, scale=2)\n    module.weight3 = dist.Normal(loc=1, scale=2)\n    mu = module.weight1 + module.weight2 + module.weight3\n    module.obs = dist.Normal(mu, 1)\n    return mu"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Depending on what type of model the different syntax will be more suited then\nthe other. Worth noting is that placing the instantiation of random variables\nin the `__init__` method will result in less overhead then creating them for every\ncall.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "By feeding in a ``borch.nn.Module`` to the model function, the weights will be added\nto the ``borch.nn.Module`` object in place. The method ``borch.pq_to_infer(model)``\nconverts the RandomVariables that are attached to the Model object into a dict with\nlists, that contains p_dist, q_dist and value that can be used in the infer package.\n\nIn order to access all the parameters that we want to optimize, we run trough the\nmodel once before creating the optimizer.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "module = borch.Module()\nborch.sample(module) # this will sample all `RandomVariable`s in the network\nforward(module)\noptimizer = torch.optim.Adam(module.parameters())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Fitting a model using the infer package looks like this:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "for ii in range(10):\n    optimizer.zero_grad()\n    borch.sample(module)\n    forward(module)\n    loss = infer.vi_loss(**borch.pq_to_infer(module))\n    loss.backward()\n    optimizer.step()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Then the fitted ``borch.nn.Module`` object can be used to generate samples. Simply by\nusing it in the model function i.e.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "for _ in range(5):\n    print(forward(module))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Condition\nFor convince it would be better to specify what to condition(observe) on\nafter the model is created and also allowing one to switch what to condition on.\nthis can be done with\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "module.observe(weight2=torch.zeros(1), obs=torch.ones(1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "It will continue to condition on the values until others are provided,\none can stop all conditioning by specifying\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "module.observe(None)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Exercises\n1) Create three ``RandomVariable``'s one with a discrete distribution, a continuous\n   distribution that is unconstrained and a continuous distribution that has a\n   positive support. Test the different methods on them like:\n\n   - ``.samlpe()``\n   - ``.rsamlpe()``\n   - ``.log_prob()``\n   - ``.entropy()``\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}