{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Just-in-time (JIT) compilation\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This notebook covers how `torch.jit` can be used together with borch, and which parts do not work yet.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import io\nimport torch\nimport borch\nfrom borch import distributions as dist, nn, as_tensor"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To use JIT-compiled functions with `RandomVariable`s, one needs to\nmanually pass in the underlying torch tensor, for example via `rv.tensor`\nor `as_tensor(rv)`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
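        "@torch.jit.script\ndef my_function(x):\n    # data-dependent control flow is kept by torch.jit.script\n    if x.sum() > 10:\n        return x\n    return x**2\n\nrv = dist.StudentT(1, torch.tensor([20., 30.]), 4)\n# pass the underlying tensor, not the RandomVariable itself\nprint(my_function(rv.tensor))\nprint(my_function(as_tensor(rv)))"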
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In normal usage this is not a big deal, since getting a `RandomVariable`\nas an attribute of a `borch.Module` (via `getattr`) already returns the tensor.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model = borch.Module()\nmodel.rv = rv\nprint(my_function(model.rv))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "At the time of writing, calling `torch.jit.trace` on a `borch.Module`\ndoes not work as one would hope: it freezes the network at the\ncurrent sample and will not draw new ones. To work around this, one\ncan register a forward pre-hook that triggers a resample before each call.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class Perceptron(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.fc1 = nn.Linear(3, 3)\n        self.relu = torch.nn.ReLU()\n\n    def forward(self, x):\n        x = self.fc1(x)\n        x = self.relu(x)\n        return x\n\n\nnet = Perceptron()\n# each call to borch.sample draws a new sample of the weights\nfor _ in range(2):\n    borch.sample(net)\n    print(net(torch.ones(2, 3)))\n\n\ndef trigger_sample(net, input):\n    # resample the network before every forward pass\n    borch.sample(net)\n\n\nnet.register_forward_pre_hook(trigger_sample)\n\n# check_trace=False: the outputs are random, so re-running the trace\n# for checking would not reproduce them\ntraced_net = torch.jit.trace(net, torch.ones(2, 3), check_trace=False)\nfor _ in range(3):\n    print(traced_net(torch.ones(2, 3)))"
      ]
    },
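    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see one plausible reason why a trace can freeze a sample, the plain-PyTorch sketch below may help. It does not use borch: the `PlainTensorWeight` module and its `resample` method are hypothetical stand-ins, assuming that a resample replaces the sampled tensor with a new tensor object. The traced module keeps using the tensor it saw at trace time, so replacing the attribute on the original Python module afterwards does not change the traced output. Whether borch stores its samples exactly this way is an assumption here.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\n\n\nclass PlainTensorWeight(torch.nn.Module):\n    # Hypothetical module: the weight is a plain tensor attribute that\n    # resample() replaces, loosely mimicking a freshly drawn sample.\n    def __init__(self):\n        super().__init__()\n        self.weight = torch.randn(3, 3)\n\n    def resample(self):\n        self.weight = torch.randn(3, 3)\n\n    def forward(self, x):\n        return x @ self.weight\n\n\nmodule = PlainTensorWeight()\nx = torch.ones(2, 3)\ntraced = torch.jit.trace(module, x)\n\nbefore = traced(x)\nmodule.resample()  # new weight on the Python module only\nafter = traced(x)\n\n# the traced module still uses the weight it saw when it was traced\nprint(torch.equal(before, after))\n# the eager module uses the new weight, so this almost surely differs\nprint(torch.equal(after, module(x)))"
      ]
    },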
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Sadly, there is currently no ONNX support, because some of the operators\nused by `torch.distributions.Distribution` are not supported by ONNX.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "try:\n    torch.onnx.export(net, torch.ones(2, 3), io.BytesIO())\nexcept Exception as e:\n    print(e)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Also, at the time of writing, calling `torch.jit.script` on a `borch.Module`\ndoes not work.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "try:\n    net_jit = torch.jit.script(net)\n    net_jit.sample()\nexcept Exception as e:\n    print(e)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}