{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Use PyTorch Lightning\n\nPyTorch Lightning handles a lot of the boring engineering code,\nallowing one to focus on the research code. See https://www.pytorchlightning.ai/ for more information.\nThe main benefits of PyTorch Lightning are:\n - Models become hardware agnostic\n - Code is easier to read because engineering code is abstracted away\n - Easier to reproduce\n - Fewer mistakes, because Lightning handles the tricky engineering\n - Keeps all the flexibility (LightningModules are still PyTorch modules), but removes a ton of boilerplate\n\nSome of the functionality that is available out of the box with minimal setup:\n - 16 bit precision\n - Multi-GPU training\n - Multi-node training\n - Early stopping\n - Model checkpointing\n\nLet's start off with some imports\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from borch import nn, infer, distributions\nimport torch\nimport pytorch_lightning as pl\nimport borch\nimport torch.nn.functional as F\n\nfrom torch.utils.data import DataLoader, Dataset\nfrom torchvision.datasets import MNIST\nfrom torchvision import transforms\nfrom pytorch_lightning.callbacks import ModelCheckpoint"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Here we want to show how to use pytorch_lightning to create a model for MNIST.\n\nThe main challenge with using pytorch-lightning is how to scale the `vi_loss`.\nThe loss function normally looks something like\n`loss = infer.vi_loss(**borch.pq_to_infer(self.net), kl_scaling=x.shape[0]/len(dataset))`\nand specifically, the `kl_scaling` needs both the length of the dataset and the size of the batch.\n\nPyTorch Lightning abstracts away the training loop and only exposes a `training_step`, which receives the current batch:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def training_step(self, batch, batch_idx):\n    (input, target) = batch\n    self.net.observe(classification=target)\n    borch.sample(self.net)\n    self.net(input)\n    return infer.vi_loss(**borch.pq_to_infer(self.net))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The issue here is that we don't get access to the data loader or the dataset, just the batch.\nOne can of course set the dataset as a global variable and access it to get the length, but by\ndoing so we lose the generality of the model to handle different datasets.\n\nInstead, we have introduced a dataset `AddDatasetLength` into borch. It simply wraps any\n`torch.utils.data.Dataset` and makes it return the length of the dataset together with each element.\n\n"
      ]
    },
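    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To illustrate the idea, here is a minimal sketch of such a wrapper in plain Python. The class name `DatasetWithLength` is hypothetical and this is not borch's implementation; it only assumes that the wrapped object supports `len()` and indexing, like a map-style `torch.utils.data.Dataset`. The unpacking `ds_len, (x, target) = ...` mirrors how the `training_step` unpacks each batch.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class DatasetWithLength:\n    \"\"\"Hypothetical wrapper: return (len(dataset), item) for every index.\"\"\"\n\n    def __init__(self, dataset):\n        self.dataset = dataset\n\n    def __len__(self):\n        # The wrapper has the same length as the wrapped dataset\n        return len(self.dataset)\n\n    def __getitem__(self, index):\n        # Prepend the dataset length so that it travels with every element\n        return len(self.dataset), self.dataset[index]\n\n\n# A toy list of (image, label) pairs stands in for MNIST here\ntoy = DatasetWithLength([(\"img0\", 0), (\"img1\", 1), (\"img2\", 2)])\nds_len, (x, target) = toy[1]\nprint(ds_len, x, target)  # prints: 3 img1 1"
      ]
    },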
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from borch.utils.data import AddDatasetLength\n\ntransform = transforms.Compose([\n    transforms.ToTensor(),\n    transforms.Normalize((0.1307,), (0.3081,)),  # MNIST mean and std\n])\nmnist = MNIST('mnist', train=True, download=True, transform=transform)\n# Each element is now (len(mnist), (image, label))\ndata_set = AddDatasetLength(mnist)\ndata_loader = DataLoader(data_set, batch_size=3)\nprint(next(iter(data_loader)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now that we have access to the length of the dataset, we can write a proper `training_step`\nusing borch and pytorch-lightning.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
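        "def training_step(self, batch, batch_idx):\n    # AddDatasetLength prepends the dataset length to every element\n    ds_len, (x, target) = batch\n    self.net.observe(classification=target)\n    loss = 0\n    # Average the loss over several samples\n    for _ in range(self.subsamples):\n        borch.sample(self)\n        self(x)\n        # Scale the KL term by the fraction of the dataset covered by this batch\n        loss += infer.vi_loss(**borch.pq_to_infer(self.net), kl_scaling=x.shape[0]/ds_len[0])\n    loss /= self.subsamples\n    loss /= x.shape[0]\n    self.net.observe(None)\n    return loss"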
      ]
    },
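    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check of the `kl_scaling` factor, the plain-Python sketch below computes `x.shape[0]/ds_len[0]` for every batch of one epoch without calling borch. It assumes the 60000-image MNIST training set and a batch size of 128 (the `LitModel` default). Since the smaller last batch gets a proportionally smaller factor, the factors add up to one, so over one epoch the KL term is weighted as if it were counted once.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Assumptions: 60000 MNIST training images and batch_size=128 (the LitModel default)\nds_len = 60000\nbatch_size = 128\n\n# Sizes of all batches in one epoch; DataLoader keeps the smaller last batch by default\nbatch_sizes = [min(batch_size, ds_len - start) for start in range(0, ds_len, batch_size)]\n\n# The kl_scaling factor used in training_step for each batch\nkl_scalings = [n / ds_len for n in batch_sizes]\n\nprint(len(batch_sizes), batch_sizes[-1])  # 469 batches, the last one has 96 images\nprint(round(kl_scalings[0], 6))  # 0.002133\nprint(round(sum(kl_scalings), 6))  # 1.0"
      ]
    },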
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In order to keep the example clear, we will separate the network code `NeuralNetwork`\nand the pytorch_lightning code `LitModel`, and simply connect them using `self.net = NeuralNetwork()`\nin the init method of the `LitModel`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class NeuralNetwork(nn.Module):\n    \"\"\"Our MNIST classification network\"\"\"\n    def __init__(self):\n        super(NeuralNetwork, self).__init__(posterior=borch.posterior.Automatic())\n        self.flatten = nn.Flatten()\n        self.linear_relu_stack = nn.Sequential(\n            nn.Linear(28*28, 512),\n            nn.ReLU(),\n            nn.Linear(512, 512),\n            nn.ReLU(),\n            nn.Linear(512, 10),\n        )\n    def forward(self, x):\n        \"\"\"Run through the network and construct the likelihood\"\"\"\n        x = self.flatten(x)\n        logits = self.linear_relu_stack(x)\n        self.classification = distributions.Categorical(logits=logits)\n        return self.classification\n\nclass LitModel(pl.LightningModule):\n    \"\"\"The pytorch_lightning module that handles the training\"\"\"\n    def __init__(self, subsamples=2, learning_rate=0.001, batch_size=128):\n        \"\"\"Store the hyperparameters used during training\"\"\"\n        super().__init__()\n        self.net = NeuralNetwork()\n        self.subsamples = subsamples\n        self.learning_rate = learning_rate\n        self.batch_size = batch_size\n    def forward(self, x):\n        \"\"\"Run through the network\"\"\"\n        return self.net(x)\n    def training_step(self, batch, batch_idx, prefix='train'):\n        ds_len, (x, target) = batch\n        self.net.observe(classification=target)\n        loss = 0\n        for _ in range(self.subsamples):\n            borch.sample(self)\n            self(x)\n            # Scale the KL term by the fraction of the dataset covered by this batch\n            loss += infer.vi_loss(**borch.pq_to_infer(self.net), kl_scaling=x.shape[0]/ds_len[0])\n        loss /= self.subsamples\n        loss /= x.shape[0]\n        self.net.observe(None)\n        self.log(f\"{prefix}_loss\", loss)\n        acc = (self.net.prior.classification.tensor == target).long().sum()/len(target)\n        self.log(f\"{prefix}_accuracy\", acc)\n        return loss\n    def validation_step(self, batch, batch_idx):\n        \"\"\"Use the train step as the validation step\"\"\"\n        return self.training_step(batch, batch_idx, 'validation')\n    def test_step(self, batch, batch_idx):\n        \"\"\"Use the train step as the test step\"\"\"\n        return self.training_step(batch, batch_idx, 'test')\n    def configure_optimizers(self):\n        \"\"\"Set up the optimizer\"\"\"\n        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n    def load_state_dict(self, state_dict, strict=False):\n        \"\"\"Change the strict=True default\"\"\"\n        # Given how we store some intermediate samples, we might end up with\n        # keys missing from the state dict, and some of the pytorch-lightning functions\n        # do not allow us to control the `strict` argument. So the easiest\n        # thing is to change the default here\n        return super().load_state_dict(state_dict, strict=strict)\n    def train_dataloader(self):\n        \"\"\"Train data loader\"\"\"\n        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n        mnist_train = MNIST('data/', train=True, download=True, transform=transform)\n        return DataLoader(AddDatasetLength(mnist_train), batch_size=self.batch_size, num_workers=8)\n    def val_dataloader(self):\n        \"\"\"Use the train data loader as the validation data loader\"\"\"\n        return self.train_dataloader()\n    def test_dataloader(self):\n        \"\"\"Use the validation data loader as the test data loader\"\"\"\n        return self.val_dataloader()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Here we use `train_accuracy` as the metric to monitor for checkpointing,\nbut it would be better to use e.g. `validation_accuracy`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Higher accuracy is better, so use mode=\"max\" (the default mode is \"min\")\ncheckpoint_callback = ModelCheckpoint(monitor=\"train_accuracy\", mode=\"max\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We construct a `trainer` object that holds all the configuration related to training;\nsee the pytorch-lightning docs for all the possible settings.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "trainer = pl.Trainer(\n    fast_dev_run=5, # just run 5 batches as a quick test\n    max_epochs=5, # maximum number of epochs, note that `fast_dev_run` will make it stop sooner\n    min_epochs=1, # minimum number of epochs to run, note that `fast_dev_run` will make it stop sooner\n    # precision=16, # use 16 bit precision, requires that the hardware supports it\n    gradient_clip_val=5, # use gradient clipping\n    # gradient_clip_algorithm=\"value\", # gradient clipping algorithm, the default is to clip based on the norm\n    # stochastic_weight_avg=True, # use stochastic weight averaging\n    auto_lr_find=True, # automatically find the best learning_rate (not guaranteed to work)\n    # auto_scale_batch_size=\"binsearch\", # find the largest batch_size that fits in memory, for higher GPU utilization\n    callbacks=[checkpoint_callback],\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The first step is to `tune` the trainer so that it can find the best learning_rate and\nbatch_size, if those settings are enabled.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model = LitModel()\ntrainer.tune(model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Once tuning is done, we can fit the model.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "trainer.fit(model)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
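        "The metrics we logged can now be viewed in TensorBoard with the command `tensorboard --logdir ./lightning_logs`.\n\nRemember to load the best checkpoint. Note that `fast_dev_run` disables checkpointing, so in this quick test run `best_model_path` may be empty.\n\n"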
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "if checkpoint_callback.best_model_path:\n    model = LitModel.load_from_checkpoint(checkpoint_callback.best_model_path, strict=False)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}