"""
Posteriors
==========
"""

###############################################################
# The concept of posteriors is as important as the concept of modules in the borch
# framework.
# The posterior's job is to create ``RandomVariable`` objects for approximating
# distributions.
# Whenever we add a ``RandomVariable`` to a ``Module``, the posterior will pick it up and
# create an approximating distribution for it.

import numpy as np
import torch
from torch import optim

import borch
from borch import Module, posterior, distributions as dist, infer


#########################################################
# The default posterior for instantiating a module is the ``Automatic`` posterior that will infer
# from the prior what approximating distribution to use.
# Whenever we assign a ``RandomVariable`` to a
# module, that ``RandomVariable`` becomes the `prior` for that attribute.
# Whenever we assign a
# new prior on a module, the posterior will pick up on it and use this random variable to
# create an approximating distribution for it.
# By changing the posterior, we change the way the approximating
# distribution is created. The ``posterior.Normal`` posterior creates a normal
# approximating distribution centered around a sample from the prior distribution.

# Build a module that uses the ``posterior.Normal`` posterior, then assign a
# ``StudentT(4, 0, 1)`` random variable to it as the attribute ``rv``. The
# assignment order matters: the module must exist, with its posterior
# attached, before the random variable is set on it.
module = Module(posterior=posterior.Normal())
module.rv = dist.StudentT(4, 0, 1)


#########################################################
# The prior of the variable ``rv`` is a Student's t distribution; StudentT(4, 0, 1).
# The approximating distribution, however, is a normal distribution created around a
# mean, which is a sample from the prior distribution. For stable training the
# ``posterior.Normal`` posterior instantiates this normal distribution with a
# ``log_scale == -3``. This is because having very wide approximating distributions
# gives very high variance gradients at first in training, and in order to have a
# clear gradient at the beginning of training, we make the distribution narrow. Over
# the course of training, however, to find the equilibrium between divergence and
# negative log likelihood, we expect these approximating distributions to widen.


# Show the prior of ``rv`` next to the approximating distribution created for it.
print(f"Prior of rv {module.prior.rv.distribution}")
print(f"Posterior of rv {module.posterior.rv.distribution}")
# exp(-3) is evaluated here so the printed posterior scale can be compared with it.
print(
    f"the scale comes from exp(-3): {np.exp(-3)}, "
    "and the location is a sample from the prior. "
)


#########################################################
# The optimisable parameters of the module are the posterior parameters and can be
# accessed by `.parameters()`

# Materialise the parameter iterator into a list so its contents are printed.
print(f"optimisable parameters of module: {list(module.parameters())}")

#########################################################
# The first parameter is the loc and the second parameter is the log_scale of
# the q-distribution. We can update the priors of the model without having to change
# the learned posterior distributions.
# We have a few different posteriors available:


#########################################################
# Now we will see how to use manual posteriors in borch. Manual posteriors allow one
# to freely specify the approximating distribution, where the control flow can be
# different compared to the model.
# For simplicity we will infer the loc and scale of a normal.

def forward(mod):
    """Run the model: assign a ``Normal(5, 1)`` random variable to ``mod.test``."""
    target = dist.Normal(5, 1)
    mod.test = target


######################################################
# The posterior is specified in the same way. In order for the parameters to be
# learnable we need to define them as ``torch.nn.Parameter``. The ``Parameter`` wrapper
# for tensors just lets the framework know that it should be returned when calling
# ``model.parameters()``, and thus it is a convenient way to pass them to an
# optimiser.

# Create a manual posterior and attach two learnable one-element tensors to it,
# both initialised to 1. ``sd`` is stored in log space: ``forward_posterior``
# exponentiates it before using it as the scale.
man_posterior = posterior.Manual()
man_posterior.mean = torch.nn.Parameter(torch.ones(1))
man_posterior.sd = torch.nn.Parameter(torch.ones(1))


def forward_posterior(posterior):
    """Re-create the approximating distribution for ``test`` on *posterior*.

    Reads the learnable ``mean`` and ``sd`` attributes of *posterior* and
    assigns ``dist.Normal(mean, scale)`` to ``posterior.test``.

    Note: inside this function the parameter name shadows the imported
    ``posterior`` module; the name is kept so existing callers keep working.
    """
    # ``sd`` lives in log space: exponentiate it to get a positive scale and
    # add 0.01 so the scale stays strictly above zero.
    scale = torch.exp(posterior.sd) + 0.01
    # The loc of a normal is unconstrained, so the parameter is used as-is.
    # Wrapping it in ``abs()`` would make the distribution's loc differ from
    # the ``mean`` parameter that is reported after training.
    posterior.test = dist.Normal(posterior.mean, scale)

######################################################
# When a ``torch.nn.Parameter`` is added to the posterior, it will be accessible in
# the ``.parameters()`` of the posterior
# or in the ``.parameters()`` of the model, thus enabling us to optimise them.
#
# When running inference with a manual posterior, one has to run the posterior before
# the model each time to reinstantiate the distributions on it using the learned
# parameters.


# Run the posterior once before handing it to the module.
forward_posterior(man_posterior)

latent = Module(posterior=man_posterior)
optimizer = optim.Adam(latent.parameters(), lr=.1)

for _ in range(500):
    # Clear the gradients left over from the previous step; without this,
    # ``backward()`` keeps adding onto them and every update would use the
    # sum of all gradients seen so far.
    optimizer.zero_grad()
    loss = 0
    # Sum the loss over 10 posterior/model passes before taking one step.
    for _ in range(10):
        # The posterior must run before the model on every pass so that its
        # distribution is rebuilt from the current parameter values.
        forward_posterior(man_posterior)
        forward(latent)
        loss += infer.vi_loss(**borch.pq_to_infer(latent))
    loss.backward()
    # Cap the gradient norm at 1 before the optimiser step.
    torch.nn.utils.clip_grad_norm_(latent.parameters(), 1)
    optimizer.step()

print('mean: ', man_posterior.mean.item())
# NOTE: this prints exp(sd); the scale used in ``forward_posterior`` is
# exp(sd) + 0.01.
print('sd: ', torch.exp(man_posterior.sd).item())
