{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# The Borch Graph\n\nA core component of ``borch`` is ``borch.Graph``; it is the foundation on which\n``RandomVariable``s are built, but it is useful for many other things as well.\nA ``Graph`` is a ``borch.Module`` that can also act as a tensor: a graph's\n``forward`` takes no arguments and returns a single tensor. This tensor is stored with the graph,\nand the graph itself can then be used wherever that tensor could be.\n\nThis may all sound a bit abstract, so here we show a few ways it can be used.\nLet's start with a basic example where we have an unconstrained parameter but want to\nconstrain it to be positive when we use it in a model.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
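        "import borch\nimport torch\n\nclass Exp(borch.Graph):\n    'Apply the exp transform'\n    def __init__(self, param):\n        super().__init__()\n        # Register `param` so that it shows up in `exp.parameters()`\n        self.register_param_or_buffer(\"param\", param)\n    def forward(self):\n        # No arguments: the value of the graph is exp(param), which is always positive\n        return torch.exp(self.param)\n\nparam = torch.nn.Parameter(torch.zeros(1))\nexp = Exp(param)\n# The graph can be used directly in tensor arithmetic\nprint(exp*1)\nprint(list(exp.parameters()))"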
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Here we bundled the logic for the transform into one object, which saves some\nbookkeeping in certain situations. We use this a lot in ``borch`` when creating\napproximating distributions in the posteriors, for example:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "rv = borch.distributions.Normal(torch.ones(1), exp)\nprint(list(rv.parameters()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Since this is a common use case, ``borch`` provides ``borch.Transform``, which\nwraps a function and its argument in the same way as ``Exp`` above:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "exp2 = borch.Transform(torch.exp, param)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One thing to keep in mind is how the graph gets updated. Once the computation has\nbeen done, the stored value does not change, even if the underlying parameters do,\nuntil ``borch.sample`` has been called on the module or on any of its parent modules.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
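        "param.data += 1  # change the parameter in place\nprint(exp)  # not updated yet: still the previously computed value\nborch.sample(exp)  # recompute the graph\nprint(exp)  # now reflects the new value of param"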
      ]
    },
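    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the refresh behaviour concrete without relying on ``borch`` internals, here is a\n",
        "minimal pure-PyTorch sketch of the idea. ``CachedTransform`` and its ``refresh`` method are\n",
        "hypothetical names for illustration only, not part of ``borch``: ``refresh`` plays the role\n",
        "that ``borch.sample`` plays above. This is a simplified model of the described behaviour,\n",
        "not how ``borch.Graph`` is actually implemented.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "\n",
        "\n",
        "class CachedTransform(torch.nn.Module):\n",
        "    # Hypothetical stand-in for a transform graph: stores fn(param) until refreshed\n",
        "    def __init__(self, fn, param):\n",
        "        super().__init__()\n",
        "        self.fn = fn\n",
        "        self.param = param  # an nn.Parameter, so it is registered automatically\n",
        "        self.refresh()\n",
        "\n",
        "    def refresh(self):\n",
        "        # Recompute and store the value (the role borch.sample plays in borch)\n",
        "        self.value = self.fn(self.param)\n",
        "        return self.value\n",
        "\n",
        "    def forward(self):\n",
        "        # No arguments: return the stored value, ignoring later edits to param\n",
        "        return self.value\n",
        "\n",
        "\n",
        "param = torch.nn.Parameter(torch.zeros(1))\n",
        "t = CachedTransform(torch.exp, param)\n",
        "before = t().item()  # exp(0) = 1.0\n",
        "param.data += 1\n",
        "stale = t().item()  # still 1.0: nothing has been recomputed\n",
        "t.refresh()\n",
        "fresh = t().item()  # exp(1), roughly 2.718\n",
        "print(before, stale, fresh)\n",
        "```\n",
        "\n",
        "In this sketch, caching means ``fn`` is evaluated once per refresh rather than on every\n",
        "access; the price is that the stored value can be stale until the next refresh."
      ]
    },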
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Since a ``Graph`` can be used just like a tensor, it can easily serve as a drop-in\nreplacement for a tensor or parameter.\n\nThis enables a very useful pattern when one wants to use a specific approximating\ndistribution while writing a model, without creating a custom posterior for it.\n``RVPair`` below pairs a prior ``p_dist`` with an approximating distribution ``q_dist``\nby attaching a ``borch.posterior.Manual`` posterior that holds ``q_dist``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class RVPair(borch.Graph):\n    \"\"\"\n    Provide a prior and the corresponding approximating\n    distribution.\n\n    This is useful when one wants a custom approximating\n    distribution.\n    \"\"\"\n\n    def __init__(self, p_dist, q_dist):\n        # A manual posterior that holds the approximating distribution q\n        posterior = borch.posterior.Manual()\n        posterior.distribution = q_dist\n        super().__init__(posterior=posterior)\n        # The prior p\n        self.distribution = p_dist\n\n    def forward(self):\n        \"\"\"Return ``self.distribution``.\"\"\"\n        return self.distribution"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
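        "A complete example that uses both ``RVPair`` and ``borch.Transform`` is a basic linear regression,\n``y = b * x + a + noise``. The slope ``b`` gets a custom approximating distribution through ``RVPair``,\nand the noise scale ``sigma`` is kept positive by passing it through ``exp`` with ``borch.Transform``.\n\n"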
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import numpy as np\nimport borch.distributions as dist\n\n\nclass LinearRegression(borch.Module):\n    def __init__(self):\n        super().__init__()\n        # Intercept with a Normal(0, 3) prior\n        self.a = dist.Normal(0, 3)\n        # Slope with a Normal(0, 3) prior and a custom approximating distribution\n        # whose scale is kept positive using `exp`\n        self.b = RVPair(\n            dist.Normal(0, 3),\n            dist.Normal(\n                torch.nn.Parameter(torch.zeros(1)),\n                borch.Transform(torch.exp, torch.nn.Parameter(torch.zeros(1)))\n            )\n        )\n        # Let's constrain sigma to be positive using `exp`\n        self.sigma = borch.Transform(torch.exp, dist.Normal(-.5, .4))\n\n    def forward(self, x):\n        mu = self.b * x + self.a\n        self.y = dist.Normal(mu, self.sigma)\n        return self.y, mu\n\n\n# Let's generate some fake data: y = 2x + 4 plus Gaussian noise with std 2\ndef generate_dataset(n=100):\n    x = np.linspace(0, 10, n)\n    y = 2*x + 4 + np.random.normal(0, 2, n)\n    return torch.tensor(y, dtype=torch.float32), torch.tensor(x, dtype=torch.float32)\n\ny, x = generate_dataset(10)\nmodel = LinearRegression()\n# Condition the random variable `y` on the generated data\nmodel.observe(y=y)\noptimizer = torch.optim.Adam(model.parameters(), lr=0.01, amsgrad=True)\nsubsamples = 10  # samples used to estimate the loss at each step\nfor i in range(500):\n    optimizer.zero_grad()\n    loss = 0\n    for _ in range(subsamples):\n        # Draw new samples and refresh every graph in the model\n        borch.sample(model)\n        yhat, mu = model(x)\n        loss += borch.infer.vi_loss(**borch.pq_to_infer(model))\n    loss.backward()\n    torch.nn.utils.clip_grad_value_(model.parameters(), 2)\n    optimizer.step()\n\n    if i % 100 == 0:\n        print(\"Loss: {}\".format(loss.item()))"
      ]
    },
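    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a sanity check that does not depend on ``borch``, ordinary least squares on data drawn from\n",
        "the same generating process (``y = 2x + 4`` plus Gaussian noise with standard deviation 2) should\n",
        "recover a slope near 2 and an intercept near 4. The sketch below is an addition to the original\n",
        "example: it uses ``numpy.polyfit``, a fixed seed and 100 points instead of 10 so the estimates are\n",
        "tight, and it names its arrays ``x_np`` and ``y_np`` so it does not overwrite ``x`` and ``y`` above.\n",
        "With only 10 points and the ``Normal(0, 3)`` priors, one would expect the fitted posterior of\n",
        "``b`` and ``a`` to be noticeably wider and only roughly centred near these values.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "n = 100\n",
        "x_np = np.linspace(0, 10, n)\n",
        "# Same process as generate_dataset: y = 2x + 4 + N(0, 2) noise\n",
        "y_np = 2 * x_np + 4 + rng.normal(0, 2, n)\n",
        "\n",
        "# A degree-1 polyfit returns the coefficients [slope, intercept]\n",
        "slope, intercept = np.polyfit(x_np, y_np, 1)\n",
        "print('slope: {:.2f}, intercept: {:.2f}'.format(slope, intercept))\n",
        "```\n"
      ]
    },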
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
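        "Let's look at the predictions. Each thin blue line is the regression line ``mu`` for one posterior sample,\nthe thick green line is their mean, and the red dots are the observed data.\n\n"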
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n\n# Clear the observation so that `y` is no longer fixed to the data\nmodel.observe(None)\npreds, loc = [], []\nfor i in range(20):\n    borch.sample(model)\n    ynew, mu = model(x)\n    ynew, mu = ynew.detach().numpy(), mu.detach().numpy()\n    preds.append(ynew)\n    loc.append(mu)\n    # One regression line per posterior sample\n    plt.plot(x, mu, 'blue', linewidth=2.0)\nmean_pred = np.stack(preds).mean(0)\nmean_loc = np.stack(loc).mean(0)\nplt.plot(x, mean_loc, 'g', label='MeanLoc', linewidth=5)\nplt.scatter(x, y, color='r', label='Actual', s=100)\nplt.xlabel('x')\nplt.ylabel('y')\nplt.legend()\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}