Note
Click here to download the full example code
The Borch Graph
A core component of borch is borch.Graph; it is the foundation on which the
RandomVariables are built. But it is useful for many other things as well.
A Graph is a borch.Module that can also act as a tensor: the graph's
forward takes no arguments and returns a single tensor. This tensor is stored with the graph,
and the graph itself can then act as that tensor.
This may all sound a bit abstract, so here we will show a few ways it can be used. Let's start with a basic example where we have an unconstrained parameter but want to constrain it when we use it in a model.
import borch
import torch
class Exp(borch.Graph):
    """Graph whose tensor value is the exponential of ``param``."""

    def __init__(self, param):
        """Register ``param`` on the graph under the name ``"param"``."""
        super().__init__()
        self.register_param_or_buffer("param", param)

    def forward(self):
        """Return ``torch.exp`` applied to the registered ``param``."""
        unconstrained = self.param
        return torch.exp(unconstrained)
# Start from an unconstrained parameter at zero, so exp(param) is one.
param = torch.nn.Parameter(torch.zeros(1))
exp = Exp(param)
# The graph can take part in tensor arithmetic directly.
print(exp*1)
# The wrapped parameter is reachable through the usual parameters() call.
print(list(exp.parameters()))
Out:
tensor([1.], grad_fn=<MulBackward0>)
[Parameter containing:
tensor([0.], requires_grad=True)]
Here we basically bundled the logic for the transform into one object
so that we can minimize some bookkeeping in some situations. We use this in borch
a lot when we create approximating distributions in the posteriors.
For example:
# Pass the Exp graph as the second argument of Normal in place of a plain tensor.
rv = borch.distributions.Normal(torch.ones(1), exp)
# Listing the parameters of rv prints the parameter held by the graph.
print(list(rv.parameters()))
Out:
[Parameter containing:
tensor([0.], requires_grad=True)]
Since this is a common use case for us, we have borch.Transform,
which can be used for this like so:
# Same idea as the Exp class, using the ready-made helper: the callable comes first, then its argument.
exp2 = borch.Transform(torch.exp, param)
One thing to keep in mind is how to update/refresh the graph:
once the computation has been done, the stored value will not change until
borch.sample
has been called on the module or any parent module.
# Change the underlying parameter in place.
param.data += 1
# The graph still prints the value computed before the change ...
print(exp)
# ... until borch.sample is called on it, which recomputes the stored tensor.
borch.sample(exp)
print(exp)
Out:
Exp:
tensor([1.], grad_fn=<ExpBackward0>)
Exp:
tensor([2.7183], grad_fn=<ExpBackward0>)
Since the Graph can be used just like a tensor, one can easily use it as a drop-in replacement for a tensor or parameter.
This opens up a very useful pattern: when writing a model, one can use a specific approximating distribution without having to create a custom posterior for it.
class RVPair(borch.Graph):
    """
    Provide a prior and the corresponding approximating
    distribution.

    This is useful when one wants a custom approximating
    distribution.

    Args:
        p_dist: the prior; stored on the graph as ``self.distribution``.
        q_dist: the approximating distribution; assigned to the
            ``distribution`` attribute of a ``borch.posterior.Manual``
            instance that is passed to ``borch.Graph.__init__``.
    """

    def __init__(self, p_dist, q_dist):
        # The posterior is built first because it is handed to the
        # base-class constructor below.
        posterior = borch.posterior.Manual()
        posterior.distribution = q_dist
        super().__init__(posterior=posterior)
        self.distribution = p_dist

    def forward(self):
        """Return ``self.distribution`` (set to ``p_dist`` in ``__init__``)."""
        return self.distribution
A complete example of how one can utilize both borch.RVPair
and borch.Transform
can be illustrated using a basic linear regression:
import borch.distributions as dist
class LinearRegression(borch.Module):
    """Linear regression model ``y ~ Normal(b * x + a, sigma)``.

    ``a`` is a ``Normal(0, 3)`` variable, ``b`` pairs a ``Normal(0, 3)``
    prior with a hand-written Normal approximating distribution, and
    ``sigma`` is ``exp`` applied to a ``Normal(-0.5, 0.4)`` variable.
    """

    def __init__(self):
        super().__init__()
        self.a = dist.Normal(0, 3)

        # Slope: a prior together with a custom approximating distribution
        # whose second argument is a parameter passed through ``exp``.
        slope_prior = dist.Normal(0, 3)
        q_loc = torch.nn.Parameter(torch.zeros(1))
        q_scale = borch.Transform(torch.exp, torch.nn.Parameter(torch.zeros(1)))
        self.b = RVPair(slope_prior, dist.Normal(q_loc, q_scale))

        # Keep sigma positive by wrapping its variable in `exp`.
        log_sigma = dist.Normal(-.5, .4)
        self.sigma = borch.Transform(torch.exp, log_sigma)

    def forward(self, x):
        """Return the variable ``y`` and the mean ``b * x + a`` for inputs ``x``.

        The ``y`` variable is also stored on the module as ``self.y``.
        """
        mean = self.b * x + self.a
        self.y = dist.Normal(mean, self.sigma)
        return self.y, mean
import numpy as np
# Lets generate some fake data to use
def generate_dataset(n=100):
    """Simulate ``n`` noisy observations of the line ``y = 2 * x + 4``.

    ``x`` is evenly spaced over ``[0, 10]``; Gaussian noise with mean 0 and
    standard deviation 2 is added to ``y``.

    Returns:
        A ``(y, x)`` tuple of float32 tensors -- the targets come first.
    """
    inputs = np.linspace(0, 10, n)
    noise = np.random.normal(0, 2, n)
    targets = 2 * inputs + 4 + noise
    return (
        torch.tensor(targets, dtype=torch.float32),
        torch.tensor(inputs, dtype=torch.float32),
    )
# generate_dataset returns the targets first, hence the ``y, x`` order.
y, x = generate_dataset(10)
model = LinearRegression()
# Attach the generated targets to the model's ``y`` variable
# (presumably conditioning on them -- confirm against borch's observe docs).
model.observe(y=y)
optimizer=torch.optim.Adam(model.parameters(), lr=0.01, amsgrad=True)
# Number of borch.sample draws whose losses are accumulated per optimisation step.
subsamples = 10
for i in range(500):
    optimizer.zero_grad()
    loss = 0
    for _ in range(subsamples):
        # Refresh the model before each forward pass.
        borch.sample(model)
        yhat, mu = model(x)
        loss += borch.infer.vi_loss(**borch.pq_to_infer(model))
    # The loss is the sum (not the mean) over the sub-samples.
    loss.backward()
    # Clip every gradient element to the range [-2, 2] before the update.
    torch.nn.utils.clip_grad_value_(model.parameters(), 2)
    optimizer.step()
    # Report progress every 100 iterations.
    if i % 100 == 0:
        print("Loss: {}".format(loss))
Out:
Loss: 73243.046875
Loss: 2297.2294921875
Loss: 406.4707336425781
Loss: 345.68798828125
Loss: 334.0998229980469
Let's look at the predictions.
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
# Call observe with None before predicting (presumably this clears the
# observed targets so ``y`` is sampled again -- confirm against borch docs).
model.observe(None)
preds, loc = [], []
for i in range(20):
    borch.sample(model)
    ynew, mu = model(x)
    # Detach from autograd and convert to NumPy for stacking and plotting.
    ynew, mu = ynew.detach().numpy(), mu.detach().numpy()
    preds.append(ynew)
    loc.append(mu)
    # NOTE(review): indentation was lost in this listing; this call is assumed
    # to sit inside the loop so that every sampled mean is drawn -- confirm.
    plt.plot(x, mu, 'blue', linewidth=2.0)
# mean_pred is computed but not used by the plotting calls below.
mean_pred = np.stack(preds).mean(0)
mean_loc= np.stack(loc).mean(0)
plt.plot(x, mean_loc, 'g', label='MeanLoc', linewidth=5)
plt.scatter(x, y, color='r', label='Actual', s=100)
plt.xlabel('x')
plt.ylabel('y')
plt.legend()
plt.show()
Out:
/home/docs/checkouts/readthedocs.org/user_builds/borch/checkouts/latest/tutorials/plot_graph.py:137: UserWarning:
This call to matplotlib.use() has no effect because the backend has already
been chosen; matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
or matplotlib.backends is imported for the first time.
The backend was *originally* set to 'agg' by the following code:
File "/home/docs/.pyenv/versions/3.7.9/lib/python3.7/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/home/docs/.pyenv/versions/3.7.9/lib/python3.7/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/docs/checkouts/readthedocs.org/user_builds/borch/envs/latest/lib/python3.7/site-packages/sphinx/__main__.py", line 15, in <module>
sys.exit(main(sys.argv[1:]))
File "/home/docs/checkouts/readthedocs.org/user_builds/borch/envs/latest/lib/python3.7/site-packages/sphinx/cmd/build.py", line 290, in main
return build_main(argv)
File "/home/docs/checkouts/readthedocs.org/user_builds/borch/envs/latest/lib/python3.7/site-packages/sphinx/cmd/build.py", line 275, in build_main
args.tags, args.verbosity, args.jobs, args.keep_going)
File "/home/docs/checkouts/readthedocs.org/user_builds/borch/envs/latest/lib/python3.7/site-packages/sphinx/application.py", line 276, in __init__
self._init_builder()
File "/home/docs/checkouts/readthedocs.org/user_builds/borch/envs/latest/lib/python3.7/site-packages/sphinx/application.py", line 337, in _init_builder
self.events.emit('builder-inited')
File "/home/docs/checkouts/readthedocs.org/user_builds/borch/envs/latest/lib/python3.7/site-packages/sphinx/events.py", line 103, in emit
results.append(callback(self.app, *args))
File "/home/docs/checkouts/readthedocs.org/user_builds/borch/envs/latest/lib/python3.7/site-packages/sphinx_gallery/gen_gallery.py", line 426, in generate_gallery_rst
gallery_conf = parse_config(app)
File "/home/docs/checkouts/readthedocs.org/user_builds/borch/envs/latest/lib/python3.7/site-packages/sphinx_gallery/gen_gallery.py", line 124, in parse_config
check_keys)
File "/home/docs/checkouts/readthedocs.org/user_builds/borch/envs/latest/lib/python3.7/site-packages/sphinx_gallery/gen_gallery.py", line 244, in _complete_gallery_conf
_import_matplotlib()
File "/home/docs/checkouts/readthedocs.org/user_builds/borch/envs/latest/lib/python3.7/site-packages/sphinx_gallery/scrapers.py", line 59, in _import_matplotlib
import matplotlib.pyplot as plt
File "/home/docs/checkouts/readthedocs.org/user_builds/borch/envs/latest/lib/python3.7/site-packages/matplotlib/pyplot.py", line 71, in <module>
from matplotlib.backends import pylab_setup
File "/home/docs/checkouts/readthedocs.org/user_builds/borch/envs/latest/lib/python3.7/site-packages/matplotlib/backends/__init__.py", line 16, in <module>
line for line in traceback.format_stack()
matplotlib.use("TkAgg")
Total running time of the script: ( 0 minutes 13.879 seconds)