{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Bayesian Graph Neural Networks\n\n`borch` is designed to integrate effortlessly with other `torch` projects.\nUsing `borch` with another PyTorch project comes down to having some\n`borch.Module` instances in the model, calling `borch.sample(model)`, and updating the loss function.\n\nHere we show how to use `borch` with `torch_geometric` by building a simple graph\nconvolutional network for the `Cora` dataset.\n\nFirst, some imports and fetching of the dataset.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport torch.nn.functional as F\nfrom torch_geometric.datasets import Planetoid\nfrom torch_geometric import nn\nimport borch\n\ndataset = Planetoid(root=\"/tmp/Cora\", name=\"Cora\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To build neural networks as usual, we want to control which modules\nare `borch` modules and which are plain `torch` modules.\n\nTo do that, we create a `borch` version of the `torch_geometric.nn` namespace,\nso that every module created from it is the `borch` version\nof that module.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "bnn = borch.nn.borchify_namespace(nn)\nprint(isinstance(bnn.GCNConv(2, 3), borch.nn.Module))"
      ]
    },
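    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Conceptually, a Bayesian layer replaces each fixed weight with a probability distribution, and a concrete weight is drawn from that distribution whenever the network is sampled. This is why the training loop calls `borch.sample(model)` before every forward pass.\n\nThe cell below is a minimal, hypothetical sketch of that idea in plain `torch`: a mean-field Gaussian linear layer that uses the reparameterization trick. `ToyBayesLinear` and its `resample` method are illustrative names, not part of `borch`, and `borch` may implement its modules quite differently.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport torch.nn.functional as F\n\n\nclass ToyBayesLinear(torch.nn.Module):\n    # Hypothetical mean-field Gaussian linear layer, for illustration only (not borch).\n\n    def __init__(self, in_features, out_features):\n        super().__init__()\n        # Variational parameters: a mean and an unconstrained scale for every weight.\n        self.w_mu = torch.nn.Parameter(0.1 * torch.randn(out_features, in_features))\n        self.w_rho = torch.nn.Parameter(torch.full((out_features, in_features), -3.0))\n        self.resample()\n\n    def resample(self):\n        # Reparameterization trick: w = mu + softplus(rho) * eps with eps ~ N(0, 1),\n        # so gradients flow back to w_mu and w_rho.\n        eps = torch.randn_like(self.w_mu)\n        self.weight = self.w_mu + F.softplus(self.w_rho) * eps\n\n    def forward(self, x):\n        # Reuses the most recent weight draw until resample() is called again.\n        return x @ self.weight.t()\n\n\nlayer = ToyBayesLinear(4, 2)\nx = torch.ones(1, 4)\ny1 = layer(x)\ny2 = layer(x)  # same weight draw, so the same output\nlayer.resample()\ny3 = layer(x)  # new weight draw, so (almost surely) a different output\nprint(torch.equal(y1, y2), torch.equal(y1, y3))"
      ]
    },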
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now we can create a small graph convolutional network where the first\nlayer, `conv1`, is the non-Bayesian version of the module and `conv2`\nis the Bayesian version.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class GCN(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = nn.GCNConv(dataset.num_node_features, 16)\n        self.conv2 = bnn.GCNConv(16, dataset.num_classes)\n\n    def forward(self, data):\n        x, edge_index = data.x, data.edge_index\n        x = self.conv1(x, edge_index)\n        x = F.relu(x)\n        x = F.dropout(x, training=self.training)\n        x = self.conv2(x, edge_index)\n        return F.log_softmax(x, dim=1)\n\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nmodel = GCN().to(device)"
      ]
    },
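    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, `GCNConv` implements the standard GCN propagation rule $X' = \\hat{D}^{-1/2} \\hat{A} \\hat{D}^{-1/2} X \\Theta$ (plus a bias by default), where $\\hat{A} = A + I$ is the adjacency matrix with self-loops and $\\hat{D}$ is its diagonal degree matrix. The cell below is a small dense sketch of that rule in plain `torch`; `gcn_propagate` is a hypothetical helper, and `torch_geometric` itself operates on the sparse `edge_index` representation instead of a dense matrix. Presumably the `borch` version used for `conv2` performs the same computation with sampled rather than fixed parameters.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\n\n\ndef gcn_propagate(x, edge_index, weight):\n    # Dense sketch of D^-1/2 (A + I) D^-1/2 X W, without the bias term.\n    num_nodes = x.size(0)\n    # Dense adjacency matrix from the (2, num_edges) edge list.\n    adj = torch.zeros(num_nodes, num_nodes)\n    adj[edge_index[0], edge_index[1]] = 1.0\n    # Self-loops let every node keep its own features.\n    adj_hat = adj + torch.eye(num_nodes)\n    # Symmetric normalization by the node degrees of A + I.\n    deg_inv_sqrt = adj_hat.sum(dim=1).pow(-0.5)\n    norm_adj = deg_inv_sqrt[:, None] * adj_hat * deg_inv_sqrt[None, :]\n    return norm_adj @ x @ weight\n\n\n# Toy undirected path graph 0 - 1 - 2, with each edge stored in both directions.\nedge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])\nx = torch.eye(3)  # one-hot node features\nweight = torch.ones(3, 2)  # fixed weights so the result is easy to check by hand\nout = gcn_propagate(x, edge_index, weight)\nprint(out)"
      ]
    },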
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The Cora dataset is an in-memory dataset containing a single graph, so there is only one entry in the dataset.\nWe take that one item and use it to train the network. Before training the network,\nwe'll look at some characteristics of the dataset.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "data = dataset[0].to(device)\nprint(f\"Dataset: {dataset}:\")\nprint(\"======================\")\nprint(f\"Number of graphs: {len(dataset)}\")\nprint(f\"Number of features: {dataset.num_features}\")\nprint(f\"Number of classes: {dataset.num_classes}\")\nprint(f\"Number of nodes: {data.num_nodes}\")\nprint(f\"Number of edges: {data.num_edges}\")\nprint(f\"Average node degree: {data.num_edges / data.num_nodes:.2f}\")\nprint(f\"Number of training nodes: {data.train_mask.sum()}\")\nprint(f\"Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}\")\nprint(f\"Contains isolated nodes: {data.has_isolated_nodes()}\")\nprint(f\"Contains self-loops: {data.has_self_loops()}\")\nprint(f\"Is undirected: {data.is_undirected()}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If you would like to visualize the graph, you can use `networkx` as follows:\n\n```python\nimport matplotlib.pyplot as plt\nimport networkx as nx\nfrom torch_geometric.utils import to_networkx\n\nG = to_networkx(data, to_undirected=True)\nnx.draw(G)\nplt.show()\n```\n\nNow let's get on with the training.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "optimizer = torch.optim.AdamW(model.parameters(), lr=0.01, weight_decay=5e-4)\nmodel.train()\nfor epoch in range(200):\n    optimizer.zero_grad()\n    borch.sample(model)  # remember to sample the network before each forward pass\n    out = model(data)\n    # Sum (rather than average) the negative log-likelihood over the training nodes\n    # so that it is on the same scale as the variational loss term added below.\n    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask], reduction=\"sum\")\n    loss += borch.infer.vi_loss(**borch.pq_to_infer(model), kl_scaling=1)\n    loss.backward()\n    optimizer.step()"
      ]
    },
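    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Up to the details of `borch.infer.vi_loss`, the training loss above corresponds to the negative evidence lower bound (ELBO), $-\\mathbb{E}_{q(w)}[\\log p(y \\mid w)] + \\mathrm{KL}(q(w) \\,\\|\\, p(w))$. The log-likelihood term is a sum over all training nodes while the KL term appears only once, which is why the negative log-likelihood uses `reduction=\"sum\"`: pairing a mean-reduced likelihood with an unscaled KL term would weight the prior too heavily by a factor equal to the number of training nodes.\n\nThe cell below illustrates the scaling with made-up tensors in plain `torch`; `kl` stands in for the variational term, which is an assumption about what `vi_loss` computes.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport torch.nn.functional as F\nfrom torch.distributions import Normal, kl_divergence\n\n# Made-up stand-ins: N training nodes, C classes and 10 Bayesian weights.\nnum_train, num_classes = 100, 7\nlog_probs = F.log_softmax(torch.randn(num_train, num_classes), dim=1)\ntargets = torch.randint(num_classes, (num_train,))\n\n# KL between a mean-field Gaussian posterior q(w) and a standard-normal prior p(w).\nq = Normal(torch.zeros(10), torch.full((10,), 0.5))\np = Normal(torch.zeros(10), torch.ones(10))\nkl = kl_divergence(q, p).sum()\n\n# Negative ELBO: the NLL is summed over nodes, the KL term is added once.\nneg_elbo = F.nll_loss(log_probs, targets, reduction=\"sum\") + kl\n\n# A mean-reduced NLL gives the same objective (divided by N) only if the KL is also divided by N.\nper_node = F.nll_loss(log_probs, targets, reduction=\"mean\") + kl / num_train\nprint(neg_elbo.item(), per_node.item() * num_train)"
      ]
    },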
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's see what the fit looks like by running a few predictions,\nsampling the network before each prediction.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model.eval()\nfor _ in range(10):\n    borch.sample(model)\n    pred = model(data).argmax(dim=1)\n    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()\n    acc = int(correct) / int(data.test_mask.sum())\n    print(\"Accuracy: {:.4f}\".format(acc))"
      ]
    }
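    ,
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Each accuracy above comes from a single posterior sample. A common alternative is to average the predicted class probabilities over several samples (Bayesian model averaging) and to use the entropy of the averaged distribution as a per-node uncertainty score. The cell below sketches this in plain `torch`; `mc_predict` is a hypothetical helper, demonstrated on made-up log-probabilities. In this notebook its input would be the list of `model(data)` outputs collected after each `borch.sample(model)` call, ideally inside `torch.no_grad()`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\n\n\ndef mc_predict(sample_log_probs):\n    # sample_log_probs: list of (num_nodes, num_classes) log-probability tensors,\n    # one per forward pass after re-sampling the network.\n    probs = torch.stack([lp.exp() for lp in sample_log_probs]).mean(dim=0)\n    # Entropy of the averaged distribution; higher means more uncertain.\n    entropy = -(probs * probs.clamp(min=1e-12).log()).sum(dim=1)\n    return probs, entropy\n\n\n# Two made-up samples for two nodes: the samples disagree slightly on node 0\n# and are maximally uncertain about node 1.\na = torch.tensor([[0.9, 0.1], [0.5, 0.5]]).log()\nb = torch.tensor([[0.7, 0.3], [0.5, 0.5]]).log()\nprobs, entropy = mc_predict([a, b])\npred = probs.argmax(dim=1)  # predicted class per node\nprint(probs)  # averaged class probabilities per node\nprint(entropy)  # node 1 has the larger entropy"
      ]
    }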
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}