{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Training an image classifier\nIn this tutorial you will learn the basics of creating an image classifier with the\n`borch.nn` package and fitting it with the `borch.infer` package.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's start off by importing what we need.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nfrom torch.utils.data import TensorDataset, DataLoader\nimport borch\nfrom borch import infer, distributions\nimport torch.nn.functional as F"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The module ``borch.nn`` provides implementations of neural network modules for\ndeep probabilistic programming. Its interface is almost identical to that of the\n`torch.nn` modules, and in many cases it is enough to switch from\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from torch import nn"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "to\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from borch import nn"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Data\nIn this example we use simulated data and do not run the fitting until convergence;\nthe aim is to show how the model is set up and how to construct the training loop.\nWe simply generate random data, where ``data`` represents the images and\n``labels`` holds their classes.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "data = torch.randn(20, 1, 32, 32)\nlabels = torch.randperm(2).repeat(10)\ndata_set = TensorDataset(data, labels)\nloader = DataLoader(data_set, batch_size=20)"
      ]
    },
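    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick sanity check of the data can catch shape mistakes before they reach the model. The cell below is a plain-torch sketch that re-creates data of the same form under hypothetical names (`demo_data`, `demo_labels`, `demo_loader`), so the variables above are left untouched. It confirms the image shape, that the labels are split evenly between the two classes, and that the loader yields a single batch per epoch, which means `kl_scaling=1 / len(loader)` in the training loop evaluates to 1.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nfrom torch.utils.data import TensorDataset, DataLoader\n\n# Simulated data of the same form as above (new random values, hypothetical names)\ndemo_data = torch.randn(20, 1, 32, 32)      # 20 single-channel 32x32 images\ndemo_labels = torch.randperm(2).repeat(10)  # [0, 1] or [1, 0] repeated: 10 labels per class\ndemo_loader = DataLoader(TensorDataset(demo_data, demo_labels), batch_size=20)\n\nbatch_x, batch_y = next(iter(demo_loader))\nprint(batch_x.shape)                     # torch.Size([20, 1, 32, 32])\nprint(torch.bincount(batch_y).tolist())  # [10, 10]: the two classes are balanced\nprint(len(demo_loader))                  # 1: all 20 examples fit in one batch"
      ]
    },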
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Model\nLet's set up the model.\nTo use `infer` and ``borch`` to the fullest, we need to select a\nlikelihood distribution. For classification, `distributions.Categorical`\nis suitable.\n\n"
      ]
    },
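    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see what the likelihood does, the cell below sketches the same idea with PyTorch's own `torch.distributions.Categorical`; we assume, without relying on it here, that borch's `distributions.Categorical` takes `logits` in the same way. The logits are turned into class probabilities by a softmax, `sample()` returns one class index per example, and `log_prob` gives the log-likelihood of the observed labels. The logits and labels are random stand-ins, not outputs of `Net`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nfrom torch.distributions import Categorical\n\nlogits = torch.randn(4, 2)             # stand-in network output: 4 examples, 2 classes\nobserved = torch.tensor([0, 1, 1, 0])  # stand-in observed class labels\n\nlikelihood = Categorical(logits=logits)\nprint(likelihood.probs)                  # softmax of the logits; each row sums to 1\nprint(likelihood.sample())               # one sampled class index per example\nlog_lik = likelihood.log_prob(observed)  # log p(label | logits), one value per example\nprint(log_lik)\nprint(log_lik.sum())                     # log-likelihood of the whole batch"
      ]
    },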
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class Net(borch.Module):\n    def __init__(self):\n        super(Net, self).__init__(posterior=borch.posterior.Automatic())\n        # 1 input image channel, 6 output channels, 5x5 square convolution\n        # kernel\n        self.conv1 = nn.Conv2d(1, 6, 5)\n        self.conv2 = nn.Conv2d(6, 16, 5)\n        # an affine operation: y = Wx + b\n        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n        self.fc2 = nn.Linear(120, 2)\n\n    def forward(self, x):\n        # Max pooling over a (2, 2) window\n        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n        # If the size is a square you can only specify a single number\n        x = F.max_pool2d(F.relu(self.conv2(x)), 2)\n        x = x.view(-1, self.num_flat_features(x))\n        x = F.relu(self.fc1(x))\n        x = self.fc2(x)\n        # Specifying the likelihood function\n        self.classification = distributions.Categorical(logits=x)\n        return self.classification\n\n    def num_flat_features(self, x):\n        size = x.size()[1:]  # all dimensions except the batch dimension\n        num_features = 1\n        for s in size:\n            num_features *= s\n        return num_features\n\n\nnet = Net()\nprint(net)"
      ]
    },
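    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The input size of `fc1`, `16 * 5 * 5`, follows from how the layers shrink a 32x32 image: each 5x5 convolution without padding removes 4 pixels from each spatial dimension, and each 2x2 max pool halves it. The sketch below traces the shapes with plain `torch.nn` layers of the same sizes (`tnn`, `conv1`, `conv2` and `flat` are illustration-only names, not part of `Net`). It also shows `torch.flatten(x, 1)`, a built-in alternative to the hand-written `num_flat_features` helper.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport torch.nn.functional as F\nfrom torch import nn as tnn  # plain torch layers, so the borch `nn` above is not shadowed\n\n# Plain torch convolutions with the same sizes as in Net, used only to trace shapes\nconv1 = tnn.Conv2d(1, 6, 5)\nconv2 = tnn.Conv2d(6, 16, 5)\n\nx = torch.randn(20, 1, 32, 32)\nx = F.max_pool2d(F.relu(conv1(x)), 2)  # 32x32 -> 28x28 (5x5 conv) -> 14x14 (2x2 pool)\nprint(x.shape)                         # torch.Size([20, 6, 14, 14])\nx = F.max_pool2d(F.relu(conv2(x)), 2)  # 14x14 -> 10x10 -> 5x5\nprint(x.shape)                         # torch.Size([20, 16, 5, 5])\nflat = torch.flatten(x, 1)             # same values as x.view(-1, num_flat_features(x))\nprint(flat.shape)                      # torch.Size([20, 400]), i.e. 16 * 5 * 5"
      ]
    },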
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Fit the model\nFinally, we can set up the training loop.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
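        "optim = torch.optim.Adam(net.parameters())\nfor epoch in range(1):  # a single epoch; real training needs many more\n    for data, target in loader:\n        # condition the `classification` likelihood on the observed labels\n        net.observe(classification=target)\n        # draw new samples for the network's random variables\n        borch.sample(net)\n        net(data)\n        # variational inference loss, with the KL term scaled by 1 / number of batches\n        loss = infer.vi_loss(**borch.pq_to_infer(net), kl_scaling=1 / len(loader))\n        loss.backward()\n        optim.step()\n        optim.zero_grad()"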
      ]
    },
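    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `kl_scaling=1 / len(loader)` argument spreads the KL term over the mini-batches of an epoch. As a sketch, assuming `infer.vi_loss` follows the common mini-batch form of the negative evidence lower bound (ELBO) and that `kl_scaling` multiplies the KL term (we have not checked borch's exact normalisation), the loss for batch $b$ out of $B$ batches, with the weights $w$ drawn from the approximate posterior $q$, is\n\n$$\\mathcal{L}_b = -\\sum_{i \\in b} \\log p(y_i \\mid x_i, w) + \\frac{1}{B}\\,\\mathrm{KL}\\big(q(w) \\,\\|\\, p(w)\\big), \\qquad w \\sim q(w).$$\n\nSumming $\\mathcal{L}_b$ over the $B$ batches gives a Monte Carlo estimate of the negative ELBO for the whole data set, with the KL term counted exactly once. Here the loader yields a single batch, so $B = 1$ and the KL term is not scaled down.\n\n"
      ]
    },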
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now we can check the accuracy. Note that one should stop conditioning on the\ntarget by calling `net.observe(None)`.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "net.observe(None)\ntot_acc = 0\nwith torch.no_grad():\n    for i, (data, target) in enumerate(loader):\n        borch.sample(net)\n        out = net(data)\n        acc = float((target == out).sum().float() / target.shape[0]) * 100\n        tot_acc += acc\n    tot_acc /= i + 1\nprint(tot_acc)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The accuracy is essentially random. This is expected, since we are fitting white\nnoise.\n\nIf you have trouble getting higher accuracy on real data, consider\nrunning for more epochs, setting up an augmentation pipeline (see\nthe data loading tutorial), and changing your posterior.\nThe posterior can be changed using\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "net.apply(borch.set_posteriors(borch.posterior.Automatic))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One can also set the posterior when creating the module:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "nn.Linear(10, 10, posterior=borch.posterior.Normal(log_scale=-3))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "See the ``borch.posterior`` documentation for other posteriors and the parameters\nyou can set. Note that not all posteriors work with all parameters, but you can\nuse different posteriors for different ``borch.Module``s in your network.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Exercises\n1) Use what you have learned to train an image classifier for MNIST; you should\n   achieve an accuracy above 98%.\n   Note: you can access MNIST using ``torchvision.datasets.MNIST``.\n\n2) Fit the same model architecture with plain torch and compare the likelihood\n   with that of the borch network. What are the differences, and why?\n\n3) Port the model to CIFAR and see how you can improve the accuracy.\n\n4) Show how the `Categorical` distribution is related to the cross-entropy loss\n   function that is commonly used in frequentist deep learning.\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}