The Borch Graph

A core component of borch is borch.Graph. It is the foundation on which RandomVariables are built, but it is useful for many other things as well. A Graph is a borch.Module that can also act as a tensor: its forward method takes no arguments and returns a single tensor. That tensor is stored with the graph, so the graph itself can be used in place of the tensor.
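To make the idea concrete without borch or torch, here is a minimal plain-Python sketch. It is not borch's implementation: LazyGraph, ExpGraph and Param are hypothetical names, and floats stand in for tensors. It only shows the two properties described above: forward takes no arguments and its result is stored on the object, and the object can be used in arithmetic as if it were that result.

```python
import math


class Param:
    """Hypothetical mutable scalar standing in for torch.nn.Parameter."""

    def __init__(self, data):
        self.data = data


class LazyGraph:
    """Hypothetical illustration of the idea behind borch.Graph (not its API).

    forward() takes no arguments and returns a single value. The value is
    computed once, stored on the object, and reused in arithmetic, so the
    object can stand in for the value itself.
    """

    def __init__(self):
        self._value = None  # None means "not computed yet"
        self.forward_calls = 0  # counts how often forward() actually ran

    def forward(self):
        raise NotImplementedError

    @property
    def value(self):
        # Compute on first use, then reuse the stored result.
        if self._value is None:
            self._value = self.forward()
            self.forward_calls += 1
        return self._value

    # Delegate arithmetic to the stored value so the graph acts like it.
    def __mul__(self, other):
        return self.value * other

    __rmul__ = __mul__

    def __add__(self, other):
        return self.value + other

    __radd__ = __add__


class ExpGraph(LazyGraph):
    """Constrain an unconstrained parameter to be positive via exp."""

    def __init__(self, param):
        super().__init__()
        self.param = param

    def forward(self):
        return math.exp(self.param.data)


param = Param(0.0)
g = ExpGraph(param)
print(g * 2, 3 + g, g.forward_calls)  # 2.0 4.0 1
```

Only a few operators are delegated here; a full tensor stand-in would need to forward many more operations.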

This may all sound a bit abstract, so here are a few ways it can be used. Let's start with a basic example: we have an unconstrained parameter but want to constrain it (here, to be positive) when we use it in a model.

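import borch
import torch


class Exp(borch.Graph):
    """Apply the exp transform to a parameter."""

    def __init__(self, param):
        super().__init__()
        # Register param on the graph so it appears in exp.parameters()
        self.register_param_or_buffer("param", param)

    def forward(self):
        # Takes no arguments; the returned tensor is what the graph acts as
        return torch.exp(self.param)


param = torch.nn.Parameter(torch.zeros(1))  # unconstrained parameter
exp = Exp(param)
print(exp * 1)  # the graph behaves like exp(param)
print(list(exp.parameters()))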

Out:

tensor([1.], grad_fn=<MulBackward0>)
[Parameter containing:
tensor([0.], requires_grad=True)]

Here we bundled the transform logic into a single object, which saves some bookkeeping: the graph carries its parameter with it, so anything that uses the graph also picks up that parameter. We use this a lot in borch when creating approximating distributions for posteriors, for example:

rv = borch.distributions.Normal(torch.ones(1), exp)
print(list(rv.parameters()))

Out:

[Parameter containing:
tensor([0.], requires_grad=True)]

Since this is a common use case, borch provides borch.Transform, which does the same without writing a custom class:

exp2 = borch.Transform(torch.exp, param)

One thing to keep in mind is when the graph is updated: once its value has been computed, it does not change until borch.sample is called on the graph itself or on any of its parent modules.

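param.data += 1  # change the parameter in place
print(exp)  # still holds the old value, exp(0) = 1
borch.sample(exp)  # recompute the graph
print(exp)  # now exp(1) = 2.7183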

Out:

Exp:
 tensor([1.], grad_fn=<ExpBackward0>)
Exp:
 tensor([2.7183], grad_fn=<ExpBackward0>)
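The refresh rule can also be modelled in plain Python. This sketch is not borch's implementation: Node, ExpNode and sample_tree are hypothetical names, floats replace tensors, and the child-before-parent refresh order is our own choice. It only mirrors the behaviour described above: a stored value stays stale until a refresh is triggered on the node itself or on any node that contains it.

```python
import math


class Node:
    """Hypothetical module-like node (not borch's API) with child nodes
    and a stored value."""

    def __init__(self, children=()):
        self.children = list(children)
        self.value = self.compute()

    def compute(self):
        # Plain containers (like a model) have no value of their own.
        return None


class ExpNode(Node):
    """Stores exp(param[0]); param is a one-element list used as a mutable scalar."""

    def __init__(self, param):
        self.param = param  # must be set before Node.__init__ calls compute()
        super().__init__()

    def compute(self):
        return math.exp(self.param[0])


def sample_tree(node):
    """Recompute `node` and every node below it, mimicking the documented
    behaviour that sampling a parent also refreshes the graphs it contains."""
    # Post-order: refresh children first so a parent that depends on them
    # would see up-to-date values.
    for child in node.children:
        sample_tree(child)
    node.value = node.compute()


param = [0.0]
exp_node = ExpNode(param)
model = Node(children=[exp_node])  # parent that contains the graph

param[0] += 1
print(exp_node.value)  # 1.0: the stored value is stale
sample_tree(model)  # refresh through the parent
print(exp_node.value)  # 2.718281828459045
```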

Since a Graph can be used just like a tensor, it is easy to use it as a drop-in replacement for a tensor or parameter.

This enables a very useful pattern when one wants a specific approximating distribution for a variable in a model but does not want to write a custom posterior for it.

class RVPair(borch.Graph):
    """
    Provide a prior and the corresponding approximating
    distribution.

    This is useful when one wants a custom approximating
    distribution.
    """

    def __init__(self, p_dist, q_dist):
        # A manual posterior whose approximating distribution is q_dist
        posterior = borch.posterior.Manual()
        posterior.distribution = q_dist
        super().__init__(posterior=posterior)
        # The prior
        self.distribution = p_dist

    def forward(self):
        """Return the prior distribution."""

A complete example that uses both the RVPair class above and borch.Transform is a basic linear regression:

import borch.distributions as dist


class LinearRegression(borch.Module):
    def __init__(self):
        super().__init__()
        # Intercept with a Normal(0, 3) prior
        self.a = dist.Normal(0, 3)
        # Slope: Normal(0, 3) prior paired with a custom Normal
        # approximating distribution whose scale is kept positive by exp
        self.b = RVPair(
            dist.Normal(0, 3),
            dist.Normal(
                torch.nn.Parameter(torch.zeros(1)),
                borch.Transform(torch.exp, torch.nn.Parameter(torch.zeros(1))),
            ),
        )
        # Let's constrain sigma to be positive using `exp`
        self.sigma = borch.Transform(torch.exp, dist.Normal(-0.5, 0.4))

    def forward(self, x):
        mu = self.b * x + self.a
        # Likelihood; its value is observed via model.observe(y=...) below
        self.y = dist.Normal(mu, self.sigma)
        return self.y, mu
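For reference, the model above can be written out as follows. This assumes the second argument of dist.Normal is the standard deviation, as in torch.distributions.Normal. The symbol s for the latent variable that borch.Transform maps through exp, and m and \rho for the two trainable parameters of the custom approximating distribution of b, are our own notation.

```latex
\begin{aligned}
a &\sim \mathcal{N}(0,\, 3^2), \qquad b \sim \mathcal{N}(0,\, 3^2), \qquad s \sim \mathcal{N}(-0.5,\, 0.4^2), \qquad \sigma = e^{s},\\
y_i &\sim \mathcal{N}\big(b\,x_i + a,\; \sigma^2\big), \qquad i = 1, \dots, n,\\
q(b) &= \mathcal{N}\big(m,\; (e^{\rho})^2\big), \qquad m,\ \rho \ \text{trainable, initialised at } 0.
\end{aligned}
```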



import numpy as np


# Let's generate some fake data around the line y = 2x + 4 (noise std 2)
def generate_dataset(n=100):
    x = np.linspace(0, 10, n)
    y = 2 * x + 4 + np.random.normal(0, 2, n)
    return torch.tensor(y, dtype=torch.float32), torch.tensor(x, dtype=torch.float32)


y, x = generate_dataset(10)
model = LinearRegression()
model.observe(y=y)  # condition the model on the observed y
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, amsgrad=True)
subsamples = 10  # Monte Carlo samples per optimization step
for i in range(500):
    optimizer.zero_grad()
    loss = 0
    for _ in range(subsamples):
        borch.sample(model)  # draw new values for the random variables
        yhat, mu = model(x)
        loss += borch.infer.vi_loss(**borch.pq_to_infer(model))
    loss.backward()
    torch.nn.utils.clip_grad_value_(model.parameters(), 2)
    optimizer.step()

    if i % 100 == 0:
        print("Loss: {}".format(loss.item()))

Out:

Loss: 73243.046875
Loss: 2297.2294921875
Loss: 406.4707336425781
Loss: 345.68798828125
Loss: 334.0998229980469
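What the training loop optimizes can be summarised as below, assuming borch.infer.vi_loss returns a single-sample estimate of the negative evidence lower bound (ELBO) for the current draw; we have not checked borch's exact definition, which might, for example, use analytic KL terms instead. Here z^{(s)} are the latent values drawn by borch.sample in inner iteration s, \phi are the trainable parameters, and S = subsamples = 10:

```latex
\mathcal{L}(\phi) = \sum_{s=1}^{S} \Big[ \log q_\phi\big(z^{(s)}\big) - \log p\big(y, z^{(s)}\big) \Big],
\qquad z^{(s)} \sim q_\phi, \qquad S = 10
```

Each bracketed term has expectation equal to the negative ELBO, E_q[log q_phi(z) - log p(y, z)], so the loop's sum is S times a Monte Carlo estimate of it. Averaging instead of summing would only rescale the objective, although the rescaling does interact with the element-wise gradient clipping at 2.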

Let's look at the predictions:

import matplotlib.pyplot as plt

model.observe(None)  # stop conditioning on y so new values of y are sampled
preds, loc = [], []
for i in range(20):
    borch.sample(model)
    ynew, mu = model(x)
    ynew, mu = ynew.detach().numpy(), mu.detach().numpy()
    preds.append(ynew)
    loc.append(mu)
    plt.plot(x, mu, 'blue', linewidth=2.0)  # one sampled regression line
mean_pred = np.stack(preds).mean(0)
mean_loc = np.stack(loc).mean(0)
plt.plot(x, mean_loc, 'g', label='MeanLoc', linewidth=5)
plt.scatter(x, y, color='r', label='Actual', s=100)
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.show()
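[Figure: 20 sampled regression lines (blue), their mean (green, MeanLoc) and the observed data (red, Actual); axes x and y]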
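The data are generated around the line y = 2x + 4 (see generate_dataset), so the green mean line should lie near it. An independent check that does not involve borch is the closed-form ordinary least squares fit. The plain-Python sketch below (least_squares_line is a hypothetical helper, not part of borch) recovers slope 2 and intercept 4 on noise-free data from the same grid of x values:

```python
def least_squares_line(xs, ys):
    """Closed-form ordinary least squares fit of ys ~ slope * xs + intercept."""
    n = len(xs)
    if n != len(ys) or n < 2:
        raise ValueError("need two equal-length sequences with at least 2 points")
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxx = sum((xi - mean_x) ** 2 for xi in xs)
    if sxx == 0:
        raise ValueError("xs must not all be equal")
    sxy = sum((xi - mean_x) * (yi - mean_y) for xi, yi in zip(xs, ys))
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    return slope, intercept


# Noise-free data from the line used in generate_dataset: y = 2x + 4.
xs = [10 * i / 9 for i in range(10)]  # same grid as np.linspace(0, 10, 10)
ys = [2 * xi + 4 for xi in xs]
slope, intercept = least_squares_line(xs, ys)
print(round(slope, 6), round(intercept, 6))  # 2.0 4.0
```

To compare with the tutorial's model, call least_squares_line(x.tolist(), y.tolist()); with only 10 noisy points the estimates will generally not be exactly 2 and 4.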


Total running time of the script: ( 0 minutes 13.879 seconds)

Gallery generated by Sphinx-Gallery