# -*- coding: utf-8 -*-
"""
Neural Networks
===============

Neural networks can be constructed using the ``borch.nn`` package.

Now that you've had a glimpse of ``autograd``, ``nn`` depends on ``autograd``
to define models and differentiate them. An ``nn.Module`` contains layers,
and a method ``forward(input)`` that returns the ``output``.

For example, look at this network that classifies digit images:

.. figure:: /_static/img/mnist.png
   :alt: convnet

   convnet

It is a simple feed-forward network. It takes an input, feeds it through
several layers one after the other, and then finally gives the output.

A typical training procedure for a neural network is as follows:

* Define a network that has some learnable parameters and/or random variables
* For each batch in a dataset, do:

  - Process the input data through the network
  - Compute the loss (how far is the output from being correct?)
  - Propagate gradients back into the network’s parameters
  - Update the weights of the network, typically using a simple update rule:
    ``weight = weight - learning_rate * gradient``

Define the network
------------------

Let’s define this network:
"""
import torch
import torch.nn.functional as F
import borch
from borch import distributions, posterior, nn, infer


class Net(nn.Module):
    """LeNet-style convolutional classifier with 10 output classes.

    The model is built on ``borch.nn`` and passes ``posterior.Automatic()``
    to the base ``nn.Module`` constructor, so the approximating posterior
    over the weights is set up by borch rather than by hand.

    ``forward`` returns the raw logits and, as a side effect, assigns a
    ``Categorical`` likelihood over those logits to ``self.classification``.
    That attribute name is what ``net.observe(classification=...)`` refers to.
    """

    def __init__(self):
        super().__init__(posterior=posterior.Automatic())
        # 1 input image channel, 6 output channels, 5x5 convolution kernel
        self.conv1 = nn.Conv2d(1, 6, 5)

        # 6 input channels, 16 output channels, 5x5 convolution kernel
        self.conv2 = nn.Conv2d(6, 16, 5)

        # An affine operation: y = Wx + b
        # NB each 5x5 convolution (no padding) is followed by 2x2 max pooling,
        # so an image with initial dimension 32x32 shrinks to
        # 32 -> 28 -> 14 -> 10 -> 5, i.e. 5x5 (with 16 channels).
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x):
        """Compute class logits for a mini-batch of images.

        Args:
            x: image tensor, expected shape ``(batch, 1, 32, 32)``. ``fc1``
                requires 16 * 5 * 5 flattened features per sample, which
                corresponds to a 32x32 input.

        Returns:
            Logits of shape ``(batch, 10)``. The ``Categorical`` likelihood
            over these logits is also stored on ``self.classification``.
        """
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the window is square, a single number can be given instead of a tuple
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything except the batch dimension. Using -1 (rather than
        # a hard-coded 1) preserves the batch size, so mini-batches work.
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Specifying the likelihood function
        self.classification = distributions.Categorical(logits=x)
        return x

    def num_flat_features(self, x):
        """Return the number of elements per sample of ``x``.

        This is the product of all dimensions except the first (batch) one.
        """
        num_features = 1
        for s in x.size()[1:]:  # all dimensions except the batch dimension
            num_features *= s
        return num_features


# Instantiate the model and print its structure.
net = Net()
print(net)

########################################################################
# You just have to define the ``forward`` function, and the ``backward``
# function (where gradients are computed) is automatically defined for you
# using ``autograd``.
# You can use any of the Tensor operations in the ``forward`` function.
#
# The learnable parameters of a model are returned by ``net.parameters()``

params = list(net.parameters())
# Number of parameter tensors, then the shape of the first one (which tensor
# comes first depends on the order in which parameters were registered).
print(len(params))
print(params[0].size())

########################################################################
# Let's try a random 32x32 input

input = torch.randn(1, 1, 32, 32)
# NB: this name shadows the builtin ``input``; later cells reuse it.
out = net(input)
print(out)

########################################################################
# Zero the gradient buffers of all parameters and backprop with random
# gradients:
net.zero_grad()
# ``out`` is not a scalar, so ``backward`` needs an explicit gradient tensor
# with the same shape as ``out``, i.e. (1, 10).
out.backward(torch.randn(1, 10))


########################################################################
# .. note::
#
#     ``borch.nn`` only supports mini-batches. The entire ``borch.nn``
#     package only supports inputs that are a mini-batch of samples, and not
#     a single sample.
#
#     For example, ``nn.Conv2d`` will take in a 4D Tensor of
#     ``nSamples x nChannels x Height x Width``.
#
#     If you have a single sample, just use ``input.unsqueeze(0)`` to add
#     a fake batch dimension.
#
# Before proceeding further, let's recap all the classes you’ve seen so far.
#
# **Recap:**
#   -  ``torch.Tensor`` - A *multi-dimensional array* with support for autograd
#      operations like ``backward()``. Also *holds the gradient* w.r.t. the
#      tensor.
#   -  ``nn.Module`` - Neural network module. *Convenient way of
#      encapsulating parameters*, with helpers for moving them to GPU,
#      exporting, loading, etc.
#   -  ``nn.Parameter`` - A kind of Tensor, that is *automatically
#      registered as a parameter when assigned as an attribute to a*
#      ``Module``.
#   -  ``autograd.Function`` - Implements *forward and backward definitions
#      of an autograd operation*. Every ``Tensor`` operation creates at
#      least a single ``Function`` node that connects to the functions that
#      created a ``Tensor`` and *encodes its history*.
#
# **At this point, we covered:**
#   -  Defining a neural network
#   -  Processing inputs and calling backward
#
# **Still Left:**
#   -  Computing the loss
#   -  Updating the weights of the network
#
# Loss Function
# -------------
# A loss function takes the (output, target) pair of inputs, and computes a
# value that estimates how far away the output is from the target.
#
# There are several different
# `loss functions <https://pytorch.org/docs/stable/nn.html#loss-functions>`_
# in the nn package.
# A simple loss is ``nn.MSELoss``, which computes the mean-squared error
# between the input and the target. These losses, however, only correspond to
# a maximum likelihood approach in deep learning: they give point estimates
# of the weights and say nothing about their uncertainty.
#
#
# In order to infer the posterior of the weights, and thus also capture the
# uncertainty of the weights, we have to use the ``infer`` package. In this
# example we will use the ``infer.vi_loss`` function, which automatically creates
# the appropriate loss function for variational inference given the latent
# variables in your model.
#
# Similar to how it's done for individual random variables, we can also observe
# values on the module, using keyword arguments that match the names of the
# random variables we want to observe. This adds those random variables to the
# likelihood term, and we will not infer a distribution over them.
# For example:
# For example:

target = torch.randint(10, (1,))  # a dummy class label in [0, 10), shape (1,)
# Condition the ``classification`` random variable assigned in
# ``Net.forward`` on the observed label.
net.observe(classification=target)
# NOTE(review): presumably draws new values for the model's random
# variables before the forward pass; confirm against the borch docs.
borch.sample(net)
output = net(input)
# ``pq_to_infer`` gathers, as keyword arguments, what ``vi_loss`` needs from
# the model (presumably the prior/approximate-posterior pairs; verify).
loss = infer.vi_loss(**borch.pq_to_infer(net))
print(loss)

########################################################################
# Now, if you follow ``loss`` in the backward direction, you will see a graph
# of computations that looks like this:
# ::
#
#     input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d
#           -> view -> linear -> relu -> linear ->
#           -> loss
#
# So, when we call ``loss.backward()``, the whole graph is differentiated
# w.r.t. the loss, and all Tensors in the graph that have ``requires_grad=True``
# will have their ``.grad`` Tensor accumulated with the gradient.


########################################################################
# Backprop
# --------
# To backpropagate the error, all we have to do is call ``loss.backward()``.
# You need to clear the existing gradients first, though; otherwise the new
# gradients will be accumulated into the existing ones.
#
#
# Now we shall call ``loss.backward()`` and look at the gradient of the
# posterior parameters of conv1's bias before and after the backward pass.

net.zero_grad()  # clears the gradient buffers of all parameters (NOTE: depending on the PyTorch version this may reset them to None rather than zero)

##############################################
# The gradient of the ``loc`` parameter of the approximating distribution of
# ``conv1.bias`` after zeroing the gradients is
print(net.conv1.posterior.bias.loc.grad)

loss.backward()

##############################################
# After calling ``backward()``, the gradient is
print(net.conv1.posterior.bias.loc.grad)

########################################################################
# **The only thing left to learn is:**
#
#   - Updating the weights of the network
#
# Update the weights
# ------------------
# The simplest update rule used in practice is the Stochastic Gradient
# Descent (SGD):
#
#      ``weight = weight - learning_rate * gradient``
#
# We can implement this using simple python code:
#
# .. code:: python
#
#     learning_rate = 0.01
#     for f in net.parameters():
#         f.data.sub_(f.grad.data * learning_rate)
#
# However, when working with neural networks you will want to use various
# update rules such as SGD, Nesterov-SGD, Adam, RMSProp, etc.
# To enable this, PyTorch provides a small package, ``torch.optim``, that
# implements all these methods. Using it is very simple:

import torch.optim as optim

# create your optimizer
# NOTE(review): with the automatic posterior, ``net.parameters()`` presumably
# includes the posterior's parameters (e.g. ``net.conv1.posterior.bias.loc``)
# — confirm against the borch docs.
optimizer = optim.SGD(net.parameters(), lr=0.01)

# in your training loop:
n_batch_epoch = 10  # number of batches per epoch, usually ``len(dataloader)``
optimizer.zero_grad()  # zero the gradient buffers
borch.sample(net)
output = net(input)
# ``kl_scaling`` presumably scales the KL term of the loss; using
# 1 / n_batch_epoch would spread it over the batches of an epoch — verify
# in the borch docs.
loss = infer.vi_loss(**borch.pq_to_infer(net), kl_scaling=1 / n_batch_epoch)
loss.backward()
optimizer.step()  # Does the update

##############################################################################
# Exercises
# ----------
# 1) The neural network package contains various modules and loss functions
#    that form the building blocks of deep neural networks. Have a look at the
#    documentation to see what is available.
#
# 2) Try designing your own feed-forward networks with two different types of
#    nonlinearities, e.g. ReLU and tanh.
