"""
The Borch Graph
==========================

A core component of ``borch`` is ``borch.Graph``; it is the foundation on which
``RandomVariable`` objects are built, but it is useful for many other things as well.
A ``Graph`` is a ``borch.Module`` that can also act as a tensor: a graph's
``forward`` takes no arguments and returns a single tensor. This tensor is stored with
the graph, and the graph itself can then be used in place of that tensor.

This may all sound a bit abstract, so here we show a few ways it can be used.
Let's start with a basic example where we have an unconstrained parameter but want to
constrain it (here, to be positive) when we use it in a model.
"""
import borch
import torch

class Exp(borch.Graph):
    """Graph whose value is the element-wise ``exp`` of a wrapped parameter.

    The raw, unconstrained ``param`` is registered on the graph and
    ``forward`` maps it through ``torch.exp``, so the graph's value is
    always positive.
    """

    def __init__(self, param):
        super().__init__()
        # Register the wrapped value under the name "param" (presumably as a
        # parameter when a ``torch.nn.Parameter`` is given, else as a buffer).
        self.register_param_or_buffer("param", param)

    def forward(self):
        """Return ``exp(self.param)``."""
        unconstrained = self.param
        return torch.exp(unconstrained)
# An unconstrained parameter, initialised to zero.
param = torch.nn.Parameter(torch.zeros(1))
# Wrap it in the Exp graph; the graph's value is exp(param).
exp = Exp(param)
# The graph behaves like a tensor, so arithmetic on it works directly.
print(exp*1)
# ``param`` was registered on the graph, so it is expected to be listed here.
print(list(exp.parameters()))

#############################################
# Here we bundled the logic for the transform into one object, which
# minimizes bookkeeping in some situations. ``borch`` uses this a lot when
# creating approximating distributions in the posteriors, for example:
# Use the Exp graph directly as the (positive) scale of a Normal random variable.
rv = borch.distributions.Normal(torch.ones(1), exp)
# The unconstrained ``param`` wrapped by ``exp`` is expected to appear here as well.
print(list(rv.parameters()))

#############################################
# Since this is a common use case, ``borch`` provides ``borch.Transform``,
# which can be used like this:
# Plays the same role as ``Exp(param)``, without writing a custom class.
exp2 = borch.Transform(torch.exp, param)

#############################################
# One thing to keep in mind is how to update/refresh the graph. Once the
# computation has been done, its value will not change until
# ``borch.sample`` has been called on the module or on any parent module.
# Change the underlying parameter in place (via ``.data``, outside autograd).
param.data += 1
# Expected to still show the previously computed value (exp(0) = 1):
# the graph has not been refreshed yet.
print(exp)
# Re-run the graph's computation with the current parameter value.
borch.sample(exp)
# Now reflects the updated parameter, exp(1).
print(exp)


#############################################
# Since a ``Graph`` can be used just like a tensor, one can easily
# use it as a drop-in replacement for a tensor or parameter.
#
# This enables a very useful pattern when one wants to use
# a specific approximating distribution in a model without
# writing a custom posterior for it.

class RVPair(borch.Graph):
    """Pair a prior with a hand-picked approximating distribution.

    Useful when a model needs a specific approximating distribution for
    one random variable without writing a custom posterior for the whole
    model.

    Args:
        p_dist: the prior distribution; ``forward`` returns it, so it
            becomes the graph's value.
        q_dist: the approximating distribution, installed as the
            ``distribution`` attribute of a ``borch.posterior.Manual``
            posterior.
    """

    def __init__(self, p_dist, q_dist):
        # The posterior is built before ``super().__init__()`` because it is
        # handed to the base-class constructor.
        manual_posterior = borch.posterior.Manual()
        manual_posterior.distribution = q_dist
        super().__init__(posterior=manual_posterior)
        self.distribution = p_dist

    def forward(self):
        """Return the prior distribution stored on this graph."""
        prior = self.distribution
        return prior

###########################################
# A complete example that uses both ``RVPair`` (defined above) and
# ``borch.Transform`` can be illustrated with a basic linear regression.

import borch.distributions as dist
class LinearRegression(borch.Module):
    """Bayesian linear regression ``y ~ Normal(b * x + a, sigma)``.

    * ``a`` (intercept) has a ``Normal(0, 3)`` prior.
    * ``b`` (slope) is an ``RVPair``: a ``Normal(0, 3)`` prior together with a
      hand-specified Normal approximating distribution whose scale is kept
      positive through ``exp``.
    * ``sigma`` (noise scale) is ``exp`` of a ``Normal(-0.5, 0.4)`` random
      variable, so it is always positive.
    """

    def __init__(self):
        super().__init__()
        self.a = dist.Normal(0, 3)
        # Slope: prior plus a custom approximating distribution. The pieces
        # are created in the same order as a single nested expression would.
        slope_prior = dist.Normal(0, 3)
        slope_q_loc = torch.nn.Parameter(torch.zeros(1))
        # exp keeps the approximating scale positive while the raw parameter
        # is optimized unconstrained.
        slope_q_scale = borch.Transform(torch.exp, torch.nn.Parameter(torch.zeros(1)))
        self.b = RVPair(slope_prior, dist.Normal(slope_q_loc, slope_q_scale))
        # Constrain sigma to be positive using ``exp``.
        self.sigma = borch.Transform(torch.exp, dist.Normal(-.5, .4))

    def forward(self, x):
        """Return ``(y, mean)``.

        ``y`` is the likelihood random variable (also stored as ``self.y``)
        and ``mean`` is the regression line ``b * x + a``.
        """
        mean = self.b * x + self.a
        self.y = dist.Normal(mean, self.sigma)
        return self.y, mean



import numpy as np
# Let's generate some fake data to use.
def generate_dataset(n=100, slope=2.0, intercept=4.0, noise_std=2.0, rng=None):
    """Generate a noisy 1-D linear-regression dataset.

    ``x`` holds ``n`` evenly spaced points on ``[0, 10]`` and the targets are
    ``y = slope * x + intercept + eps`` with ``eps ~ Normal(0, noise_std)``.

    Args:
        n: number of data points.
        slope: slope of the underlying line.
        intercept: intercept of the underlying line.
        noise_std: standard deviation of the Gaussian noise added to ``y``.
        rng: object with a NumPy-style ``normal(loc, scale, size)`` method,
            e.g. ``numpy.random.default_rng(seed)``. Defaults to NumPy's
            global random state (the original behaviour); pass a seeded
            generator for reproducible data.

    Returns:
        ``(y, x)`` as ``float32`` torch tensors of shape ``(n,)``.
    """
    if rng is None:
        # ``np.random`` itself exposes ``normal``, so the default draws from
        # the global state exactly as before.
        rng = np.random
    x = np.linspace(0, 10, n)
    y = slope * x + intercept + rng.normal(0, noise_std, n)
    return torch.tensor(y, dtype=torch.float32), torch.tensor(x, dtype=torch.float32)

y, x = generate_dataset(10)
model = LinearRegression()
# Condition the model on the data; presumably matched by name to the random
# variable stored as ``self.y`` in ``LinearRegression.forward``.
model.observe(y=y)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, amsgrad=True)
# Number of Monte Carlo samples of the VI loss summed per optimization step.
subsamples = 10
for i in range(500):
    optimizer.zero_grad()
    loss = 0
    for _ in range(subsamples):
        # Draw fresh samples before re-running the model.
        borch.sample(model)
        # Only the side effect matters here: forward rebuilds the likelihood.
        model(x)
        loss += borch.infer.vi_loss(**borch.pq_to_infer(model))
    loss.backward()
    # Clip each gradient element to [-2, 2] to keep early steps stable.
    torch.nn.utils.clip_grad_value_(model.parameters(), 2)
    optimizer.step()

    if i % 100 == 0:
        # ``.item()`` prints the plain number instead of a tensor repr with grad_fn.
        print(f"Loss: {loss.item():.4f}")


####################################################################
# Let's look at the predictions.

import matplotlib
import matplotlib.pyplot as plt

# NOTE: no backend is forced here. ``matplotlib.use("TkAgg")`` fails on
# headless machines (e.g. when building the documentation); matplotlib
# selects a suitable backend for the environment on its own.

# Drop the observation so ``y`` is sampled again (assumes ``observe(None)``
# clears all observations — confirm in the borch docs).
model.observe(None)
n_draws = 20
preds, loc = [], []
for draw in range(n_draws):
    borch.sample(model)
    ynew, mu = model(x)
    ynew, mu = ynew.detach().numpy(), mu.detach().numpy()
    preds.append(ynew)
    loc.append(mu)
    # Label only the first draw so the legend gets a single entry for them.
    plt.plot(x, mu, 'blue', linewidth=2.0,
             label='SampledLoc' if draw == 0 else '_nolegend_')
mean_pred = np.stack(preds).mean(0)
mean_loc = np.stack(loc).mean(0)
plt.plot(x, mean_loc, 'g', label='MeanLoc', linewidth=5)
plt.plot(x, mean_pred, 'k--', label='MeanPred', linewidth=2)
plt.scatter(x, y, color='r', label='Actual', s=100)
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.show()
