Just-in-time (JIT) compilation

We will cover how one can use torch.jit together with borch.

import io
import torch
import borch
from borch import distributions as dist, nn, as_tensor

In order to use jit functions with `RandomVariable`s, one needs to manually pass in just the underlying torch tensor.

@torch.jit.script
def my_function(x):
    if x.sum() > 10:
        return x
    return x**2

rv = dist.StudentT(1, torch.tensor([20., 30.]), 4)
print(my_function(rv.tensor))
print(my_function(as_tensor(rv)))

Out:

tensor([21161.6992,   844.1932])
tensor([21161.6992,   844.1932])
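
Passing the `RandomVariable` itself is expected to fail, since a scripted function does not accept arbitrary Python objects as arguments (a quick, hedged check; the exact error depends on the torch and borch versions):

try:
    my_function(rv)
except Exception as e:
    # scripted functions expect plain tensors, so this should be rejected
    print(type(e).__name__)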

In normal usage this is not a big deal, as getattr on a borch.Module only returns the tensor anyway.

model = borch.Module()
model.rv = rv
print(my_function(model.rv))

Out:

tensor([33.8703, 21.6021], grad_fn=<AddBackward0>)
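
One can verify that attribute access returns the plain tensor rather than the `RandomVariable` wrapper (a minimal check; the exact class printed may vary between borch versions):

print(type(model.rv))  # expected: a plain torch.Tensor, not a RandomVariable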

At the time of this writing, calling torch.jit.trace on a borch.Module does not work as one would hope: it freezes the network at the current sample and will not generate new ones. To get around this, one can register a forward pre-hook that triggers a resample. Note that check_trace=False is needed below, since resampling makes the outputs non-deterministic and the trace check would otherwise fail.

class Perceptron(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 3)
        self.relu = torch.nn.ReLU()
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        return x


net = Perceptron()
for _ in range(2):
    borch.sample(net)
    print(net(torch.ones(2, 3)))

def trigger_sample(net, input):
    # resample all random variables in the module before each forward pass
    borch.sample(net)

net.register_forward_pre_hook(trigger_sample)

traced_net = torch.jit.trace(net, torch.ones(2, 3), check_trace=False)
for _ in range(3):
    traced_net(torch.ones(2, 3))

Out:

tensor([[0.2484, 0.3649, 0.0000],
        [0.2484, 0.3649, 0.0000]], grad_fn=<ReluBackward0>)
tensor([[0.1692, 0.2857, 0.0000],
        [0.1692, 0.2857, 0.0000]], grad_fn=<ReluBackward0>)
/home/docs/checkouts/readthedocs.org/user_builds/borch/envs/latest/lib/python3.7/site-packages/torch/_tensor.py:1013: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at  aten/src/ATen/core/TensorBody.h:417.)
  return self._grad
/home/docs/checkouts/readthedocs.org/user_builds/borch/checkouts/latest/src/borch/graph.py:152: UserWarning: volatile was removed (Variable.volatile is always False)
  return getattr(self.tensor, prop)
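
To see why the pre-hook matters, trace a fresh network without one: the traced graph replays the sample captured at trace time, so repeated calls should return identical outputs (a minimal sketch illustrating the freezing described above):

frozen = Perceptron()
traced_frozen = torch.jit.trace(frozen, torch.ones(2, 3), check_trace=False)
# no resampling hook: both calls reuse the weights sampled at trace time
out1 = traced_frozen(torch.ones(2, 3))
out2 = traced_frozen(torch.ones(2, 3))
print(torch.equal(out1, out2))  # expected: True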

Sadly, there is no ONNX support at this time; some of the operators that torch.distributions.Distribution uses are not yet supported by ONNX.

try:
    torch.onnx.export(net, torch.ones(2, 3), io.BytesIO())
except Exception as e:
    print(e)

Out:

Exporting the operator normal to ONNX opset version 9 is not supported. Support for this operator was added in version 11, try exporting with this version.
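
The message suggests retrying with opset version 11, but since several of the distribution operators are unsupported, the export is still expected to fail on current versions (a hedged sketch, not verified output):

try:
    torch.onnx.export(net, torch.ones(2, 3), io.BytesIO(), opset_version=11)
    print("export succeeded")
except Exception as e:
    print(e)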

Also at the time of this writing, calling torch.jit.script on a borch.Module does not work.

try:
    net_jit = torch.jit.script(net)
    net_jit.sample()
except Exception as e:
    print(e)

Out:

/home/docs/checkouts/readthedocs.org/user_builds/borch/envs/latest/lib/python3.7/site-packages/torch/_tensor.py:1013: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at  aten/src/ATen/core/TensorBody.h:417.)
  return self._grad
/home/docs/checkouts/readthedocs.org/user_builds/borch/checkouts/latest/src/borch/graph.py:152: UserWarning: volatile was removed (Variable.volatile is always False)
  return getattr(self.tensor, prop)
type object 'type' has no attribute '__globals__'
