"""
Just-in-time (JIT) compilation
==============================
"""
####################################################
# We will cover how one can use torch.jit together with borch.
#
#
# 
import io
import torch
import borch
from borch import distributions as dist, nn, as_tensor

###################################################
# In order to use jit functions with `RandomVariable`s, one needs to
# manually pass in just the underlying torch tensor, e.g. via
# `rv.tensor` or `as_tensor(rv)`.

@torch.jit.script
def my_function(x):
    """Return ``x`` itself if its elements sum to more than 10, else ``x ** 2``."""
    total = x.sum()
    # data-dependent branching like this is what torch.jit.script preserves
    if total > 10:
        return x
    squared = x ** 2
    return squared

rv = dist.StudentT(1, torch.tensor([20., 30.]), 4)
# Two ways of handing the scripted function the plain tensor behind `rv`:
underlying = rv.tensor
print(my_function(underlying))
print(my_function(as_tensor(rv)))

###################################################
# In normal usage this is not a big deal, as accessing a `RandomVariable`
# attribute on a `borch.Module` (i.e. `getattr`) only gives the tensor anyway.
model = borch.Module()
# Attach the random variable, then read it back through attribute access.
setattr(model, "rv", rv)
print(my_function(getattr(model, "rv")))


#############################################################
# At the time of this writing, calling torch.jit.trace on a borch.Module
# does not work as one would hope. It effectively freezes the network at the
# current sample and will not generate new ones. To get around this, one
# needs to register a forward pre-hook that triggers a resample.

class Perceptron(nn.Module):
    """A single 3 -> 3 ``borch.nn.Linear`` layer followed by a ReLU."""

    def __init__(self):
        super().__init__()
        # registration order (fc1, then relu) is kept as-is
        self.fc1 = nn.Linear(3, 3)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Apply ``fc1`` to ``x`` and pass the result through the ReLU."""
        return self.relu(self.fc1(x))


net = Perceptron()
for _draw in range(2):
    # call borch.sample on the network before each forward pass
    borch.sample(net)
    output = net(torch.ones(2, 3))
    print(output)

def trigger_sample(module, args):
    """Forward pre-hook that calls ``borch.sample`` on ``module``.

    PyTorch invokes forward pre-hooks positionally as ``hook(module, args)``,
    where ``args`` holds the positional inputs to ``forward``; they are unused
    here. Returning ``None`` leaves those inputs unchanged.
    """
    # The parameters are named ``module``/``args`` (not ``net``/``input``) so
    # they neither shadow the global ``net`` nor the builtin ``input``.
    borch.sample(module)


net.register_forward_pre_hook(trigger_sample)

# ``check_trace=False``: the forward pre-hook calls ``borch.sample`` on every
# call, so outputs are not expected to match when torch.jit.trace re-runs the
# example inputs to verify the trace.
traced_net = torch.jit.trace(net, torch.ones(2, 3), check_trace=False)
for _ in range(3):
    # Print each result so the effect of the resampling hook is visible;
    # previously the outputs were discarded and the loop showed nothing.
    print(traced_net(torch.ones(2, 3)))

######################################################
# Sadly, there is no ONNX support at this time, because some of the operators
# used by `torch.distributions.Distribution` are not supported by ONNX.
try:
    torch.onnx.export(
        net,
        torch.ones(2, 3),
        io.BytesIO(),  # throwaway in-memory target; only the attempt matters
    )
except Exception as err:
    # The failure is what this section demonstrates: report it and move on.
    print(err)

#############################################################
# Also, at the time of this writing, calling torch.jit.script on a
# borch.Module does not work.
try:
    # Either scripting the module or sampling from the scripted result may
    # raise; both are covered by the same handler.
    net_jit = torch.jit.script(net)
    net_jit.sample()
except Exception as err:
    # Expected to fail at the time of writing: show the error, don't abort.
    print(err)

