"""
Use PyTorch Lightning
=====================

PyTorch Lightning handles a lot of the boring engineering code,
allowing one to focus on the research code. See https://www.pytorchlightning.ai/ for more information.
The main benefits of PyTorch Lightning are:
 - Models become hardware agnostic
 - Code is clear to read because engineering code is abstracted away
 - Easier to reproduce
 - Make fewer mistakes because lightning handles the tricky engineering
 - Keeps all the flexibility (LightningModules are still PyTorch modules), but removes a ton of boilerplate

Some of the functionality that is available right out of the box with minimal setup is
 - 16 bit precision
 - Multi-GPU training
 - Multi-node training
 - Early stopping
 - Model checkpointing

Let's start off with some imports
"""
from borch import nn, infer, distributions
import torch
import pytorch_lightning as pl
import borch
import torch.nn.functional as F

from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST
from torchvision import transforms
from pytorch_lightning.callbacks import ModelCheckpoint

########################################################
# Here we want to show how to use pytorch_lightning to create a model for MNIST.
# 
# The main challenge with using pytorch-lightning is how to scale the vi_loss.
# The loss function normally looks something like
# `loss = infer.vi_loss(**borch.pq_to_infer(self.net), kl_scaling=x.shape[0]/len(dataset))`
# and, specifically, `kl_scaling` needs both the length of the data set and the size of the batch.
# 
# PyTorch Lightning abstracts away the training loop and only exposes a `training_step`:

def training_step(self, batch, batch_idx):
    """Naive training step that only has the batch to work with.

    The batch is unpacked into inputs and targets, the targets are observed
    as ``classification`` on ``self.net``, the network is sampled and run on
    the inputs, and ``infer.vi_loss`` is returned without any ``kl_scaling``.
    """
    features, labels = batch
    network = self.net
    network.observe(classification=labels)
    borch.sample(network)
    network(features)
    return infer.vi_loss(**borch.pq_to_infer(network))


########################################################
# The issue here is that we don't get access to the data loader or the data set, just the batch.
# One can of course set the dataset as a global variable and access it to get the length. But by
# doing so we lose the generality of the model to handle different data sets etc.
#
# Instead we have introduced a data set wrapper, `AddDatasetLength`, into borch. It simply wraps any
# `torch.utils.data.Dataset` and makes it return the length of the dataset with each element.
from borch.utils.data import AddDatasetLength
# Convert the images to tensors and normalise them with the constants below.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
mnist = MNIST("mnist", train=True, download=True, transform=transform)
# Wrap the dataset so that each element also carries the dataset length.
data_set = AddDatasetLength(mnist)
data_loader = DataLoader(data_set, batch_size=3)
# Print one batch to show what the wrapped data set yields.
first_batch = next(iter(data_loader))
print(first_batch)

#######################################################
# Now that we have access to the length of the dataset we can actually conduct a training_step
# using borch and pytorch-lightning.
def training_step(self, batch, batch_idx):
    """Training step whose KL term is scaled by batch size / dataset length.

    ``batch`` is unpacked as ``(ds_len, (x, target))``. The loss is
    accumulated over ``self.subsamples`` fresh samples of the model, then
    divided by the number of subsamples and by the batch size before it is
    returned. The observation on ``self.net`` is cleared again afterwards.
    """
    dataset_size, (inputs, labels) = batch
    batch_size = inputs.shape[0]
    self.net.observe(classification=labels)
    total = 0
    for _ in range(self.subsamples):
        borch.sample(self)
        self(inputs)
        # dataset_size is indexed with [0]; presumably every entry of the
        # batch carries the same dataset length -- confirm against
        # AddDatasetLength.
        total += infer.vi_loss(
            **borch.pq_to_infer(self.net),
            kl_scaling=batch_size / dataset_size[0],
        )
    total /= self.subsamples
    total /= batch_size
    self.net.observe(None)
    return total

########################################################
# In order to keep the example clear, we will separate the network code `NeuralNetwork`
# and the pytorch_lightning code `LitModel`, and simply connect them using `self.net = NeuralNetwork()`
# in the init method of the `LitModel`.

class NeuralNetwork(nn.Module):
    """Fully connected MNIST classification network built from borch modules.

    The input is flattened and passed through three linear layers
    (784 -> 512 -> 512 -> 10) with ReLU activations in between. The output
    logits parameterise a ``Categorical`` distribution that is stored on the
    module as ``classification``.
    """

    def __init__(self):
        """Build the layers, using an automatic posterior for the module."""
        super().__init__(posterior=borch.posterior.Automatic())
        self.flatten = nn.Flatten()
        layers = [
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Run through the network and construct the likelihood."""
        flat = self.flatten(x)
        self.classification = distributions.Categorical(
            logits=self.linear_relu_stack(flat)
        )
        # Return the attribute (not the local distribution) so the caller
        # gets whatever the module exposes for ``classification``.
        return self.classification

class LitModel(pl.LightningModule):
    """The pytorch_lightning module that helps with the training.

    It wraps a `NeuralNetwork` as ``self.net`` and provides the training,
    validation and test steps, the optimizer and the data loaders.

    Args:
        subsamples: how many times the model is re-sampled and evaluated in
            each step; the loss is averaged over these evaluations.
        learning_rate: learning rate passed to the Adam optimizer.
        batch_size: batch size used when building the data loaders.
    """

    def __init__(self, subsamples=2, learning_rate=0.001, batch_size=128):
        """Create the network and store the hyper parameters as attributes."""
        super().__init__()
        self.net = NeuralNetwork()
        self.subsamples = subsamples
        self.learning_rate = learning_rate
        self.batch_size = batch_size

    def forward(self, x):
        """We just run through the network."""
        return self.net(x)

    def training_step(self, batch, batch_idx, prefix='train'):
        """Compute and log the loss and accuracy for one batch.

        Args:
            batch: unpacked as ``(ds_len, (x, target))``, i.e. the layout
                produced by wrapping the data set in `AddDatasetLength`.
            batch_idx: index of the batch (unused).
            prefix: prefix of the logged metric names, giving
                ``{prefix}_loss`` and ``{prefix}_accuracy``.

        Returns:
            The loss averaged over the subsamples and divided by the batch
            size.
        """
        ds_len, (x, target) = batch
        self.net.observe(classification=target)
        loss = 0
        for _ in range(self.subsamples):
            borch.sample(self)
            self(x)
            # We scale the loss with how big a part of the dataset we run through
            loss+= infer.vi_loss(**borch.pq_to_infer(self.net), kl_scaling=x.shape[0]/ds_len[0])
        loss /= self.subsamples
        loss /= x.shape[0]
        # Clear the observation again so it does not leak into later calls.
        self.net.observe(None)
        self.log(f"{prefix}_loss", loss)
        # Fraction of entries where the prior's `classification` value equals
        # the target. NOTE(review): presumably this is the value from the
        # last forward pass above -- confirm against borch semantics.
        acc = (self.net.prior.classification.tensor == target).long().sum()/len(target)
        self.log(f"{prefix}_accuracy", acc)
        return loss

    def validation_step(self, batch, batch_idx):
        """Use the train step as the validation step"""
        return self.training_step(batch, batch_idx, 'validation')

    def test_step(self, batch, batch_idx):
        """Use the train step as the test step"""
        return self.training_step(batch, batch_idx, 'test')

    def configure_optimizers(self):
        """Set up the optimizers"""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def load_state_dict(self, state_dict, strict=False):
        """Change the strict=True default"""
        # Given how we store some intermediate samples we might end up with
        # keys in the state dict missing, and some of the pytorch-lightning
        # functions do not allow us to control the `strict` argument. So the
        # easiest thing is to change the default here.
        return super().load_state_dict(state_dict, strict=strict)

    def train_dataloader(self):
        """Train data loader"""
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        mnist_train = MNIST('data/',train=True, download=True, transform=transform)
        return DataLoader(AddDatasetLength(mnist_train), batch_size=self.batch_size, num_workers=8)

    def val_dataloader(self):
        """Train dataloader as validation data loader"""
        return self.train_dataloader()

    def test_dataloader(self):
        """Validation dataloader as test data loader"""
        # The validation loader is `val_dataloader`; the previously called
        # `validation_dataloader` is not defined on this class.
        return self.val_dataloader()

###############################################################
# Here we use `train_accuracy` as the metric to monitor for checkpointing,
# but it would be better to use e.g. `validation_accuracy`.
# Since a higher accuracy is better we must pass `mode="max"`; with the default
# `mode="min"` the checkpoint with the *lowest* accuracy would be kept as the best one.
checkpoint_callback = ModelCheckpoint(monitor="train_accuracy", mode="max")
###############################################################
# We construct a `trainer` object that holds all the configuration relating to training,
# see the docs from pytorch-lightning for all the possible settings.
trainer = pl.Trainer(
    fast_dev_run=5, # just run 5 batches as a quick test
    max_epochs= 5, # maximum epochs, note that `fast_dev_run` will make it stop sooner
    min_epochs = 1, # minimum epochs to run, note that `fast_dev_run` will make it stop sooner
    # precision=16, # use 16 bit precision, requires that the hardware supports it.
    gradient_clip_val=5, # use gradient clipping
    # gradient_clip_algorithm="value", # gradient clipping algorithm, standard is to clip based on the norm
    # stochastic_weight_avg=True,
    auto_lr_find=True, # automatically find the best learning_rate (not guaranteed to work)
    # auto_scale_batch_size= "binsearch", # find as big batch_size as possible before one gets memory errors, thus one can achieve higher GPU utilization
    callbacks=[checkpoint_callback],
)
################################################################
# The first step is to `tune` the trainer such that it can find the best learning_rate and
# batch_size if that setting is enabled.
model = LitModel()
trainer.tune(model)
################################################################
# When the tuning is done, we can go over to the fitting.
trainer.fit(model)
#######################################
# The metrics we logged can now be seen in tensorboard by using the command `tensorboard --logdir ./lightning_logs`
#
# Remember to load the best checkpoint. The path is empty when no checkpoint was written
# (per the Lightning docs `fast_dev_run` disables checkpointing -- confirm for your version),
# hence the guard.
if checkpoint_callback.best_model_path:
    model = LitModel.load_from_checkpoint(checkpoint_callback.best_model_path, strict=False)
