# -*- coding: utf-8 -*-
"""
Training an Image classifier
============================
You will learn the basics of how to create an image classifier using the
`borch.nn` package and fit it using the `infer` package.


"""
####################################
# Let's start off by importing what we need.

import torch
from torch.utils.data import TensorDataset, DataLoader
import borch
from borch import infer, distributions
import torch.nn.functional as F

#################################################################
# The module ``borch.nn`` provides implementations of neural network modules that are used
# for deep probabilistic programming. Its interface is almost identical to that of
# the ``torch.nn`` modules, so in many cases it is possible to just switch

from torch import nn

###############################################################
# to

from borch import nn

############################################################################
# Data
# ----------
# In this example we will use simulated data and not run the fitting until convergence,
# but show how the model is set up and how one can construct the training loop.
# We will just generate some random data, where ``data`` represents the images and
# ``labels`` holds the class of each image.

# Twenty single-channel 32x32 "images" of pure Gaussian noise.
data = torch.randn((20, 1, 32, 32))
# A random ordering of the two classes, tiled so every image gets one label.
labels = torch.randperm(2).repeat(data.shape[0] // 2)
data_set = TensorDataset(data, labels)
# One batch that holds the entire (20-sample) data set.
loader = DataLoader(dataset=data_set, batch_size=len(data_set))


##############################################################################
# Model
# ----------
# Let's set up the model.
# In order to use `infer` and ``borch`` to the fullest, we need to select a
# likelihood distribution. For classification, `distributions.Categorical`
# is suitable.


class Net(borch.Module):
    """LeNet-style convolutional classifier with a Categorical likelihood.

    The layer sizes fix the expected input to single-channel 32x32 images
    (32 -> conv5 -> 28 -> pool2 -> 14 -> conv5 -> 10 -> pool2 -> 5, hence the
    ``16 * 5 * 5`` input features of ``fc1``) and the output to 2 classes.
    The layers come from ``borch.nn`` and the module is constructed with an
    ``Automatic`` posterior.
    """

    def __init__(self):
        super().__init__(posterior=borch.posterior.Automatic())
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 2)

    def forward(self, x):
        """Return the Categorical likelihood over the 2 classes for ``x``.

        Args:
            x: input batch; presumably shaped (batch, 1, 32, 32) so that the
                flattened features match ``fc1``.

        Returns:
            ``self.classification``, a ``distributions.Categorical`` built from
            the ``fc2`` logits. It is stored as an attribute so that
            ``net.observe(classification=...)`` can condition on it.
        """
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square you can only specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything except the batch dimension. Giving the batch size
        # explicitly (instead of -1) keeps the batch axis intact and also works
        # for an empty batch, where view(-1, n) cannot infer the missing size.
        x = x.view(x.size(0), self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Specifying the likelihood function
        self.classification = distributions.Categorical(logits=x)
        return self.classification

    def num_flat_features(self, x):
        """Return the number of elements per sample of ``x``.

        This is the product of all dimensions except the first (batch)
        dimension; it is 1 when ``x`` has only a batch dimension.
        """
        return x.size()[1:].numel()


# Instantiate the model defined above and print it (for torch-style modules
# this shows the sub-layers; presumably the same for borch.Module — confirm).
net = Net()
print(net)

##############################################################################
# Fit the model
# -------------
# Finally, we can set up the training loop.

optim = torch.optim.Adam(net.parameters())
n_epochs = 1  # a single pass for the demo; increase it to actually fit the model
for _ in range(n_epochs):
    # ``inputs``/``targets`` instead of ``data``/``target`` so the loop does not
    # rebind the module-level ``data`` tensor that holds the whole data set.
    for inputs, targets in loader:
        # Condition the ``classification`` likelihood (set in ``Net.forward``)
        # on the true labels of this batch.
        net.observe(classification=targets)
        # NOTE(review): presumably draws fresh samples for the network's random
        # variables before the forward pass — confirm in the borch docs.
        borch.sample(net)
        net(inputs)
        # ``borch.pq_to_infer(net)`` presumably gathers the model (p) and
        # approximate-posterior (q) terms for ``infer.vi_loss``; the
        # ``kl_scaling=1 / len(loader)`` presumably spreads the KL term over the
        # batches of one epoch — confirm in the ``infer.vi_loss`` docs.
        loss = infer.vi_loss(**borch.pq_to_infer(net), kl_scaling=1 / len(loader))
        loss.backward()
        optim.step()
        optim.zero_grad()

###########################################################################
# Now we can check the accuracy. Note that one should stop conditioning on the
# target by calling `net.observe(None)`.

# Stop conditioning the ``classification`` likelihood on observed labels, so
# the outputs below are model predictions rather than the targets fed back in
# (as the tutorial text above explains).
net.observe(None)
n_correct = 0
n_seen = 0
with torch.no_grad():
    # Separate loop names so the module-level ``data`` tensor is not rebound.
    for inputs, targets in loader:
        borch.sample(net)
        predictions = net(inputs)
        # NOTE(review): relies on the returned Categorical comparing like its
        # sampled class tensor, as the original tutorial did — confirm.
        n_correct += int((targets == predictions).sum())
        n_seen += targets.shape[0]
# Accuracy in percent over *all* samples: averaging per-batch accuracies would
# over-weight a smaller final batch. An empty loader has no defined accuracy,
# so NaN is reported instead of failing on an undefined loop index.
tot_acc = 100.0 * n_correct / n_seen if n_seen else float("nan")
print(tot_acc)

#########################################################################
# The accuracy is basically random: we are fitting white noise, so this is to be
# expected.
#
# But in case you have trouble getting higher accuracy you should consider
# running for more epochs, setting up an augmentation pipeline (see:
# the data loading tutorial) and changing your posterior.
# The posterior can be changed using

# Re-assign posteriors across ``net`` via ``borch.set_posteriors``. Note that the
# ``Automatic`` posterior *class* is passed here, whereas ``Net.__init__`` passes
# an instance. NOTE(review): presumably ``set_posteriors`` returns a per-module
# callback and ``apply`` visits every sub-module, as ``torch.nn.Module.apply``
# does — confirm in the borch docs.
net.apply(borch.set_posteriors(borch.posterior.Automatic))

#######################################################################
# One can also set the posterior when one creates the module
nn.Linear(10, 10, posterior=borch.posterior.Normal(log_scale=-3))  # demo only: the layer is discarded
# NOTE(review): ``log_scale=-3`` presumably sets the initial log-scale of the
# Normal posterior — confirm against the ``borch.posterior.Normal`` docs.

#######################################################################
# See the ``borch.posterior`` documentation for other posteriors and what parameters
# you can set. Note that not all posteriors work with all parameters, but you can
# have different posteriors for the different ``borch.Module`` instances in your network.

#############################################################################
# Exercises
# ----------
# 1) Use what you have learned to train an image classifier for MNIST, you should
#    achieve an accuracy larger than 98 %.
#    Note: you can access MNIST using ``torchvision.datasets.MNIST``.
#
# 2) Fit the same model architecture with normal torch and compare the likelihood
#    with the borch network. What are the differences and why?
#
# 3) Port the model to CIFAR and see how you can improve the accuracy.
#
# 4) Show how the `Categorical` distribution is related to the cross-entropy loss
#    function that is commonly used in frequentist deep learning.
