{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Hamiltonian Monte Carlo\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "#\n# The `infer` package provides Hamiltonian Monte Carlo (HMC) as a sampling method.\n# HMC is a Markov chain Monte Carlo (MCMC) method used to sample from probability\n# distributions. We also support the No-U-Turn Sampler (NUTS), which is\n# an improvement on HMC. Note that both the HMC and NUTS implementations have not\n# gone through rigorous use and are considered experimental.\n#\n# Let's start with the initial imports.\n\nimport matplotlib\nmatplotlib.use(\"Agg\")\nimport matplotlib.pyplot as plt\n\nimport borch\nfrom borch import Module, distributions as dist\nfrom borch.utils.state_dict import add_state_dict_to_state, sample_state\nfrom borch.posterior import PointMass\nfrom borch.infer.model_conversion import model_to_neg_log_prob_closure\nfrom borch.infer.hmc import hmc_step\nfrom borch.infer.nuts import nuts_step, dual_averaging, find_reasonable_epsilon\nfrom borch.utils.torch_utils import detach_tensor_dict"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In this example we are going to sample from a model with a `HalfNormal` and a `Normal`\nlatent variable, where the model is written using borch.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def model(latent):\n    latent.weight1 = dist.HalfNormal(.5)\n    latent.weight2 = dist.Normal(loc=1, scale=2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In order to use HMC from `borch.infer`, the log joint (here its negative, built by\n`model_to_neg_log_prob_closure`) should be provided as a closure,\ni.e. a function or other Python callable that takes no arguments. The parameters used for\nHMC also need to be in the unconstrained space, which is easy to do with borch.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "trace = PointMass()\nlatent = borch.Module(posterior=trace)\n\n\ndef model_call():\n    borch.sample(latent)\n    model(latent)\n\n\nclosure = model_to_neg_log_prob_closure(model_call, latent)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To call `hmc_step`, the parameters need to be provided as a list, so for a borch model\none has to run the closure once to create them. `hmc_step` takes\n`epsilon` and `L` as arguments: `epsilon` is the step size HMC uses when it\nexplores the space, and `L` is the number of leapfrog steps HMC takes before it\nproposes a new sample.\n\n"
      ]
    },
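    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The roles of `epsilon` and `L` can be made concrete with the standard leapfrog integrator.\n",
        "The following is only a sketch: it assumes a unit mass matrix, writes $U(q)$ for the negative\n",
        "log probability that the closure is named to return, and uses the symbols $q$, $p$, $U$ and $H$\n",
        "purely for illustration. The internals of `hmc_step` may differ. One leapfrog step with step\n",
        "size $\\epsilon$ is\n",
        "\n",
        "$$\n",
        "\\begin{aligned}\n",
        "p &\\leftarrow p - \\tfrac{\\epsilon}{2}\\,\\nabla U(q) \\\\\n",
        "q &\\leftarrow q + \\epsilon\\,p \\\\\n",
        "p &\\leftarrow p - \\tfrac{\\epsilon}{2}\\,\\nabla U(q)\n",
        "\\end{aligned}\n",
        "$$\n",
        "\n",
        "After `L` such steps, the proposal $(q', p')$ is accepted with probability\n",
        "$\\min\\left(1, \\exp\\left(H(q, p) - H(q', p')\\right)\\right)$, where\n",
        "$H(q, p) = U(q) + \\tfrac{1}{2} p^\\top p$. A smaller `epsilon` tracks the trajectory more\n",
        "accurately but needs a larger `L` to travel the same distance.\n",
        "\n"
      ]
    },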
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "state = []\nclosure()\nsamples = {}\nfor i in range(100):\n    accept_prob = hmc_step(\n        epsilon=.1,\n        L=10,\n        parameters=trace.parameters(),\n        closure=closure\n    )\n    state = add_state_dict_to_state(latent, state)\n    samples[i] = detach_tensor_dict(dict(borch.named_random_variables(latent)))\nplt.hist([samp['posterior.weight2'] for samp in samples.values()], bins=100)\nplt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To restore a stored sample so that one can generate from the model,\none simply does\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "for _ in range(1):\n    sample_state(latent, state)\n    borch.sample(latent)\n    pred = model(latent)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In a similar fashion one can use the No-U-Turn Sampler (NUTS), an extension\nto Hamiltonian Monte Carlo that eliminates the need to set the number of\nleapfrog steps. Empirically, NUTS performs at least as efficiently as, and\nsometimes more efficiently than, a well-tuned standard HMC method, without\nrequiring user intervention.\n\nNote that the NUTS implementation has not gone through rigorous use and\nis considered experimental.\n\n"
      ]
    },
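    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "During warmup, NUTS usually tunes the step size with dual averaging. As a sketch of the\n",
        "scheme from Hoffman and Gelman's NUTS formulation (it is an assumption that\n",
        "`dual_averaging` follows it; $\\delta$, $\\gamma$, $t_0$, $\\kappa$ and $\\mu$ are that\n",
        "formulation's notation, not names from borch), at warmup iteration $m$ with acceptance\n",
        "statistic $\\alpha_m$ and target acceptance rate $\\delta$:\n",
        "\n",
        "$$\n",
        "\\begin{aligned}\n",
        "\\bar{H}_m &= \\left(1 - \\frac{1}{m + t_0}\\right) \\bar{H}_{m-1} + \\frac{1}{m + t_0}\\,(\\delta - \\alpha_m) \\\\\n",
        "\\log \\epsilon_m &= \\mu - \\frac{\\sqrt{m}}{\\gamma}\\,\\bar{H}_m \\\\\n",
        "\\log \\bar{\\epsilon}_m &= m^{-\\kappa} \\log \\epsilon_m + \\left(1 - m^{-\\kappa}\\right) \\log \\bar{\\epsilon}_{m-1}\n",
        "\\end{aligned}\n",
        "$$\n",
        "\n",
        "Here $\\mu = \\log(10\\,\\epsilon_0)$, with $\\epsilon_0$ the initial step size. Once warmup ends,\n",
        "the step size is held fixed at the averaged value $\\bar{\\epsilon}$.\n",
        "\n"
      ]
    },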
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "samples_nuts = {}\ninitial_epsilon, epsilon_bar, h_bar = find_reasonable_epsilon(\n    trace.parameters(), closure)\nepsilon = initial_epsilon\nfor i in range(1, 100):\n    # use the adapted step size rather than a hard-coded value\n    accept_prob = nuts_step(epsilon, trace.parameters(), closure)\n    if i < 25:  # the warmup stage\n        epsilon, epsilon_bar, h_bar = dual_averaging(\n            accept_prob, i, initial_epsilon, epsilon_bar, h_bar)\n    else:\n        epsilon = epsilon_bar\n    samples_nuts[i] = detach_tensor_dict(dict(borch.module.named_random_variables(latent)))\nplt.hist([samp['posterior.weight2'] for samp in samples_nuts.values()], bins=10)\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}