"""Bayesian Graph Neural Networks
=================================

`borch` is designed to integrate effortlessly with other `torch` projects.
Using borch in an existing PyTorch project comes down to three things: having
some `borch.nn.Module` in the model, calling `borch.sample(model)`, and updating
the loss function.

Here we show how to use `borch` with `torch_geometric` by building a simple graph
convolutional neural network for the `Cora` dataset.

First, some imports, and then we fetch the dataset.
"""

import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric import nn
import borch

# Load the "Cora" Planetoid dataset, using /tmp/Cora as its storage root.
dataset = Planetoid(name="Cora", root="/tmp/Cora")

#######################################
# To build neural networks as usual, we want to control which layers are
# `borch` modules and which are plain `torch` modules.
#
# To do that, we create a `borch` version of the `torch_geometric.nn` module,
# so that every layer created from this new namespace is the `borch` version
# of that layer.

# ``bnn`` mirrors ``torch_geometric.nn``, but the modules it creates are the
# borch versions; ``nn`` itself is left untouched for non-Bayesian layers.
bnn = borch.nn.borchify_namespace(nn)
# Sanity check: a layer built from ``bnn`` is expected to be a
# ``borch.nn.Module``, so this should print ``True``.
print(isinstance(bnn.GCNConv(2, 3), borch.nn.Module))


#######################################
# Now we can create a small graph convolutional neural network in which the
# first layer, `conv1`, is the non-Bayesian version of the module and the
# second layer, `conv2`, is the Bayesian version.


class GCN(torch.nn.Module):
    """Two-layer graph convolutional network with a Bayesian output layer.

    ``conv1`` is a plain ``torch_geometric`` ``GCNConv``, while ``conv2`` is
    built from the borchified namespace ``bnn`` and is therefore the Bayesian
    (borch) version of the layer.

    Args:
        in_channels: Number of input node features. Defaults to
            ``dataset.num_node_features`` (the module-level Cora dataset).
        hidden_channels: Width of the hidden layer between the two convolutions.
        out_channels: Number of output classes. Defaults to
            ``dataset.num_classes``.
    """

    def __init__(self, in_channels=None, hidden_channels=16, out_channels=None):
        super().__init__()
        # Fall back to the module-level dataset so ``GCN()`` keeps working
        # exactly as before, while still allowing explicit sizes.
        if in_channels is None:
            in_channels = dataset.num_node_features
        if out_channels is None:
            out_channels = dataset.num_classes
        self.conv1 = nn.GCNConv(in_channels, hidden_channels)
        self.conv2 = bnn.GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        """Return per-node log-probabilities (log-softmax over dim 1).

        Args:
            data: Graph object exposing node features ``data.x`` and
                connectivity ``data.edge_index``.
        """
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # Dropout is tied to the module's train/eval mode, so it is disabled
        # after ``model.eval()``.
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)


# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Instantiate the network and move its parameters to the chosen device.
model = GCN().to(device)

####################################
# The Cora dataset is an in-memory dataset containing a single graph, so it has only one entry.
# We take that one item and use it to train the network. Before moving on to training,
# we'll look at some of the characteristics of the dataset.

# Cora holds a single graph, so the dataset's only entry is the whole graph.
data = dataset[0].to(device)
# Convert the mask count to a Python int once. It then prints as a plain number
# instead of a tensor repr (which on GPU would also include the device), and it
# is reused for the label rate below.
num_train_nodes = int(data.train_mask.sum())
print(f"Dataset: {dataset}:")
print("======================")
print(f"Number of graphs: {len(dataset)}")
print(f"Number of features: {dataset.num_features}")
print(f"Number of classes: {dataset.num_classes}")
print(f"Number of nodes: {data.num_nodes}")
print(f"Number of edges: {data.num_edges}")
print(f"Average node degree: {data.num_edges / data.num_nodes:.2f}")
print(f"Number of training nodes: {num_train_nodes}")
print(f"Training node label rate: {num_train_nodes / data.num_nodes:.2f}")
print(f"Contains isolated nodes: {data.has_isolated_nodes()}")
print(f"Contains self-loops: {data.has_self_loops()}")
print(f"Is undirected: {data.is_undirected()}")

####################################
# To visualize this graph, you can use networkx as follows:
#
# import matplotlib.pyplot as plt
# import networkx as nx
# from torch_geometric.utils import to_networkx
# G = to_networkx(data, to_undirected=True)
# nx.draw(G)
# plt.draw()
#
# Now let's get on with the training.

# Optimize all of the model's parameters (both the plain and the borch layer).
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01, weight_decay=5e-4)
model.train()
for _ in range(200):
    optimizer.zero_grad()
    borch.sample(model)  # remember to sample the network before each step
    # Exactly one forward pass per step. The original ran ``model(data)`` twice
    # and discarded the first result, wasting a full pass every epoch.
    out = model(data)
    # Negative log-likelihood of the training labels. ``reduction="sum"`` makes
    # this a total over the training nodes rather than a mean, to balance it
    # against the variational term added by ``vi_loss`` (presumably a sum as
    # well, given ``kl_scaling=1``).
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask], reduction="sum")
    loss += borch.infer.vi_loss(**borch.pq_to_infer(model), kl_scaling=1)
    loss.backward()
    optimizer.step()


##################################
# Let's see how well the model fits by running a few predictions,
# drawing a new sample of the network before each prediction.
model.eval()  # also disables the dropout in GCN.forward
# The test-set size does not change between samples, so compute it once.
num_test = int(data.test_mask.sum())
# Inference only: skip building autograd graphs for these forward passes.
with torch.no_grad():
    for _ in range(10):
        # Resample the network so each run reflects a fresh draw; the
        # reported accuracies may therefore differ between runs.
        borch.sample(model)
        pred = model(data).argmax(dim=1)
        correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
        acc = int(correct) / num_test
        print(f"Accuracy: {acc:.4f}")
