Note
Click here to download the full example code
Training an Image Classifier
You will learn the basics of how to create an image classifier using the borch.nn package and fit it using the infer package.
Let's start off by importing what we need.
import torch
from torch.utils.data import TensorDataset, DataLoader
import borch
from borch import infer, distributions
import torch.nn.functional as F
The module borch.nn provides implementations of neural network modules that are used
for deep probabilistic programming. It provides an interface almost identical to
the torch.nn modules, and in many cases it is possible to simply switch
from torch import nn
to
from borch import nn
Data
In this example we will use simulated data and will not run the fitting until convergence;
instead, we show how the model is set up and how one can construct the training loop.
We simply generate some random data, where data represents the images and
labels holds the corresponding classes.
# Simulated data set: 20 random single-channel 32x32 "images".
data = torch.randn(20, 1, 32, 32)
# Class labels: a random permutation of [0, 1] tiled 10 times, so the
# 20 labels alternate between the two classes.
labels = torch.randperm(2).repeat(10)
data_set = TensorDataset(data, labels)
# A single batch holds the whole (20-sample) data set.
loader = DataLoader(dataset=data_set, batch_size=len(data_set))
Model
Let's set up the model.
In order to use infer and borch to the fullest, we need to select
a likelihood distribution. For classification, distributions.Categorical
is suitable.
class Net(borch.Module):
    """Small LeNet-style convolutional classifier with a Categorical likelihood.

    Two convolution + max-pool stages feed two fully connected layers that
    produce logits for 2 classes. ``forward`` wraps those logits in a
    ``distributions.Categorical``, stores it on ``self.classification`` and
    returns it.
    """

    def __init__(self):
        super().__init__(posterior=borch.posterior.Automatic())
        # Feature extractor: 1 input channel -> 6 -> 16 output channels,
        # each with a 5x5 square convolution kernel.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Affine head (y = Wx + b). 16 * 5 * 5 is the flattened feature size
        # after both conv/pool stages for a 32x32 input:
        # 32 -> conv 28 -> pool 14 -> conv 10 -> pool 5.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 2)

    def forward(self, x):
        """Map a batch of images ``x`` to a Categorical over the 2 classes."""
        # conv -> ReLU -> max pooling over a 2x2 window, for each conv stage.
        # A single number is enough for the window because it is square.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        flat = x.view(-1, self.num_flat_features(x))
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        # Specify the likelihood function. Storing it as a module attribute is
        # presumably how borch tracks it for observation/inference — confirm.
        self.classification = distributions.Categorical(logits=logits)
        return self.classification

    def num_flat_features(self, x):
        """Return the number of elements per sample in ``x``.

        This is the product of every dimension except the first (batch)
        dimension; it is 1 when ``x`` has no dimensions beyond the batch.
        """
        return x.size()[1:].numel()
# Instantiate the model and print its module structure.
net = Net()
print(net)
Out:
Net(
(posterior): Automatic()
(prior): Module()
(observed): Observed()
(conv1): Conv2d(
1, 6, kernel_size=(5, 5), stride=(1, 1)
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[[[-0.1567, 0.0194, 0.0769, -0.0673, -0.0263],
[-0.1055, -0.0411, -0.0689, -0.0195, -0.0636],
[-0.1941, 0.1049, -0.0666, 0.1288, 0.1553],
[-0.0577, 0.0011, 0.0145, -0.0518, -0.1235],
[-0.1586, 0.1436, -0.1999, -0.0364, -0.0872]]],
[[[-0.0281, 0.1596, -0.0365, -0.0633, -0.1375],
[ 0.0499, -0.0787, 0.1923, -0.1312, 0.0520],
[ 0.0351, -0.0756, -0.0147, 0.1359, 0.1740],
[ 0.0815, 0.0631, 0.0970, 0.0557, -0.1079],
[-0.0794, -0.0029, -0.1620, -0.1147, 0.0705]]],
[[[ 0.1613, 0.0524, -0.1935, 0.0940, -0.0838],
[-0.1790, -0.0222, -0.0098, -0.1600, -0.1524],
[ 0.1950, -0.1447, -0.1897, 0.1429, 0.0565],
[ 0.1752, 0.0891, -0.0210, -0.0461, 0.1512],
[ 0.1282, -0.1387, -0.1025, -0.1408, -0.0373]]],
[[[ 0.1134, 0.0826, -0.1147, -0.1228, -0.0040],
[-0.1923, 0.1244, 0.1793, 0.0183, -0.0433],
[-0.1840, -0.1691, 0.1184, 0.0151, -0.1467],
[-0.1828, 0.0363, 0.1660, -0.0660, -0.1319],
[-0.1644, 0.1835, -0.0681, -0.0800, 0.0668]]],
[[[ 0.0005, 0.1977, 0.1792, 0.0681, 0.1410],
[ 0.0081, -0.0216, -0.0456, -0.1985, 0.0049],
[-0.0568, -0.1405, 0.0604, -0.1020, 0.0084],
[-0.1587, -0.1756, 0.1148, 0.0872, -0.1307],
[ 0.1244, 0.0193, 0.0314, 0.0787, -0.1329]]],
[[[-0.1363, 0.0538, 0.0294, 0.0635, -0.0873],
[-0.0751, -0.0357, -0.0701, 0.0053, -0.1896],
[-0.0187, -0.0438, -0.0541, -0.1959, 0.0103],
[-0.1991, 0.0912, 0.1853, -0.0772, 0.0883],
[-0.1368, 0.0420, 0.0322, -0.1162, -0.0013]]]], requires_grad=True)
tensor: tensor([[[[-0.1464, 0.0288, 0.1006, -0.0143, -0.0394],
[-0.0963, -0.0275, -0.0248, -0.0250, -0.0185],
[-0.1034, 0.1739, -0.1335, 0.1078, 0.1196],
[-0.0705, -0.0096, -0.0125, -0.0423, 0.0023],
[-0.1915, 0.1488, -0.1672, -0.0354, -0.1115]]],
[[[-0.0702, 0.1708, -0.0110, -0.0176, -0.1317],
[ 0.1053, -0.0288, 0.2829, -0.0449, 0.0364],
[ 0.0569, -0.1476, -0.0458, 0.1452, 0.1287],
[ 0.1165, 0.0803, 0.0561, 0.0671, -0.1254],
[-0.0583, -0.0479, -0.0769, -0.0862, 0.0184]]],
[[[ 0.0326, 0.0916, -0.1446, 0.1062, -0.1284],
[-0.1355, -0.0286, 0.0212, -0.0864, -0.1912],
[ 0.2453, -0.0990, -0.2416, 0.1289, 0.0360],
[ 0.1220, 0.2028, -0.0364, -0.0095, 0.2029],
[ 0.0586, -0.1530, -0.1777, -0.1552, -0.0960]]],
[[[ 0.1117, 0.1044, -0.1346, -0.0786, -0.0155],
[-0.1331, 0.1088, 0.1338, 0.0570, -0.0485],
[-0.1223, -0.1886, 0.1482, 0.0753, -0.1289],
[-0.1578, 0.0648, 0.1323, 0.0027, -0.1692],
[-0.1471, 0.2222, -0.0199, -0.1038, -0.0085]]],
[[[-0.0257, 0.2232, 0.1758, -0.0237, 0.1489],
[-0.0155, 0.1008, -0.0155, -0.2084, 0.0323],
[-0.0005, -0.1623, 0.0254, -0.1075, -0.0103],
[-0.1946, -0.1818, 0.2106, 0.0592, -0.1350],
[ 0.1469, -0.0297, 0.0166, 0.0733, -0.1201]]],
[[[-0.1150, -0.0068, -0.0089, 0.0525, -0.0929],
[-0.0980, 0.0597, -0.0527, 0.0507, -0.2522],
[-0.0187, -0.0277, -0.0706, -0.2390, -0.0100],
[-0.1499, -0.0335, 0.1549, -0.1291, 0.0113],
[-0.0331, 0.1316, 0.0597, -0.1865, -0.0048]]]],
grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([-0.1428, -0.1690, -0.0722, 0.0667, 0.0933, -0.0337],
requires_grad=True)
tensor: tensor([-0.0361, -0.1412, -0.1430, 0.1021, 0.1712, 0.0340],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[[[-0., 0., 0., -0., -0.],
[-0., -0., -0., -0., -0.],
[-0., 0., -0., 0., 0.],
[-0., 0., 0., -0., -0.],
[-0., 0., -0., -0., -0.]]],
[[[-0., 0., -0., -0., -0.],
[0., -0., 0., -0., 0.],
[0., -0., -0., 0., 0.],
[0., 0., 0., 0., -0.],
[-0., -0., -0., -0., 0.]]],
[[[0., 0., -0., 0., -0.],
[-0., -0., -0., -0., -0.],
[0., -0., -0., 0., 0.],
[0., 0., -0., -0., 0.],
[0., -0., -0., -0., -0.]]],
[[[0., 0., -0., -0., -0.],
[-0., 0., 0., 0., -0.],
[-0., -0., 0., 0., -0.],
[-0., 0., 0., -0., -0.],
[-0., 0., -0., -0., 0.]]],
[[[0., 0., 0., 0., 0.],
[0., -0., -0., -0., 0.],
[-0., -0., 0., -0., 0.],
[-0., -0., 0., 0., -0.],
[0., 0., 0., 0., -0.]]],
[[[-0., 0., 0., 0., -0.],
[-0., -0., -0., 0., -0.],
[-0., -0., -0., -0., 0.],
[-0., 0., 0., -0., 0.],
[-0., 0., 0., -0., -0.]]]])
scale: tensor([[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[[[-0.1567, 0.0194, 0.0769, -0.0673, -0.0263],
[-0.1055, -0.0411, -0.0689, -0.0195, -0.0636],
[-0.1941, 0.1049, -0.0666, 0.1288, 0.1553],
[-0.0577, 0.0011, 0.0145, -0.0518, -0.1235],
[-0.1586, 0.1436, -0.1999, -0.0364, -0.0872]]],
[[[-0.0281, 0.1596, -0.0365, -0.0633, -0.1375],
[ 0.0499, -0.0787, 0.1923, -0.1312, 0.0520],
[ 0.0351, -0.0756, -0.0147, 0.1359, 0.1740],
[ 0.0815, 0.0631, 0.0970, 0.0557, -0.1079],
[-0.0794, -0.0029, -0.1620, -0.1147, 0.0705]]],
[[[ 0.1613, 0.0524, -0.1935, 0.0940, -0.0838],
[-0.1790, -0.0222, -0.0098, -0.1600, -0.1524],
[ 0.1950, -0.1447, -0.1897, 0.1429, 0.0565],
[ 0.1752, 0.0891, -0.0210, -0.0461, 0.1512],
[ 0.1282, -0.1387, -0.1025, -0.1408, -0.0373]]],
[[[ 0.1134, 0.0826, -0.1147, -0.1228, -0.0040],
[-0.1923, 0.1244, 0.1793, 0.0183, -0.0433],
[-0.1840, -0.1691, 0.1184, 0.0151, -0.1467],
[-0.1828, 0.0363, 0.1660, -0.0660, -0.1319],
[-0.1644, 0.1835, -0.0681, -0.0800, 0.0668]]],
[[[ 0.0005, 0.1977, 0.1792, 0.0681, 0.1410],
[ 0.0081, -0.0216, -0.0456, -0.1985, 0.0049],
[-0.0568, -0.1405, 0.0604, -0.1020, 0.0084],
[-0.1587, -0.1756, 0.1148, 0.0872, -0.1307],
[ 0.1244, 0.0193, 0.0314, 0.0787, -0.1329]]],
[[[-0.1363, 0.0538, 0.0294, 0.0635, -0.0873],
[-0.0751, -0.0357, -0.0701, 0.0053, -0.1896],
[-0.0187, -0.0438, -0.0541, -0.1959, 0.0103],
[-0.1991, 0.0912, 0.1853, -0.0772, 0.0883],
[-0.1368, 0.0420, 0.0322, -0.1162, -0.0013]]]])
(bias): Normal:
loc: tensor([-0., -0., -0., 0., 0., -0.])
scale: tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([-0.1428, -0.1690, -0.0722, 0.0667, 0.0933, -0.0337])
)
(observed): Observed()
)
(conv2): Conv2d(
6, 16, kernel_size=(5, 5), stride=(1, 1)
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
...,
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[[[-1.6434e-02, -1.3113e-02, -4.1464e-02, 1.6377e-02, 6.2508e-02],
[ 4.0202e-02, 4.9382e-02, 1.0253e-02, 3.4932e-02, -2.9341e-02],
[-6.4074e-02, 5.2063e-02, -3.8546e-02, 1.5865e-02, -4.8314e-02],
[ 2.0698e-02, 8.0888e-02, 7.9982e-02, 5.6686e-02, 1.1419e-02],
[-6.1541e-02, 6.8319e-02, -7.6261e-02, -2.5828e-03, -5.7380e-02]],
[[ 8.0672e-02, -7.8271e-02, 7.9289e-02, 2.1951e-02, 5.9238e-02],
[ 1.1237e-02, -5.2598e-02, 4.1551e-02, 4.3546e-02, -5.8132e-02],
[ 3.0288e-02, -7.5402e-03, 2.8025e-02, 1.5065e-02, -3.7260e-02],
[-2.1855e-02, 4.4424e-02, -6.7250e-02, 4.3303e-02, -4.0667e-02],
[ 4.9321e-02, -2.0541e-02, 5.6232e-02, -3.8466e-02, 5.1026e-02]],
[[-2.6513e-02, 4.9569e-02, 1.8676e-02, 6.9054e-02, -2.6767e-02],
[-5.7368e-02, -7.8838e-03, -7.9114e-02, -1.8557e-02, 2.8787e-04],
[ 1.3084e-02, 4.6926e-02, -1.3894e-02, 2.7116e-02, -7.9155e-02],
[ 7.5190e-02, -2.2396e-02, 5.5935e-02, 6.8270e-02, -3.2401e-02],
[-1.0770e-02, 3.6370e-02, -7.0476e-02, 6.4846e-02, -6.7540e-02]],
[[-3.7310e-02, 5.9737e-02, 5.9160e-05, 3.5063e-02, 7.1916e-02],
[ 1.7257e-02, -6.6353e-03, -6.3498e-03, -1.9595e-02, 4.9856e-02],
[-8.0101e-02, -4.1773e-02, -1.9761e-02, -1.6521e-02, -3.6788e-02],
[ 9.3012e-04, -2.4376e-02, -1.1352e-03, 1.1929e-02, -1.3466e-02],
[ 5.6119e-02, 2.2037e-02, 6.6074e-02, 3.7063e-02, 4.2302e-02]],
[[-3.6769e-02, 3.7521e-02, -9.7236e-03, -6.8646e-02, -3.6374e-02],
[-2.8501e-02, -8.0159e-02, -2.8122e-02, 1.8005e-02, -3.7241e-02],
[ 4.2443e-02, -4.8021e-02, -3.0217e-02, -5.2788e-02, 2.2794e-02],
[-6.3064e-02, 6.8354e-02, 4.3337e-02, -4.3442e-02, -6.5524e-02],
[-5.6791e-02, 7.4918e-02, 5.9263e-02, -1.2248e-04, -3.5795e-02]],
[[-2.1802e-02, -9.4069e-03, 7.6862e-02, 1.4458e-02, -7.0315e-02],
[-5.9874e-02, 5.0751e-02, 3.9215e-02, 7.7538e-02, 3.2766e-02],
[-4.4479e-02, 5.2775e-02, -1.9053e-03, 3.2342e-02, -4.6321e-02],
[ 3.8624e-02, 2.0750e-02, 4.6860e-02, -5.3962e-02, 6.4275e-02],
[-6.5864e-02, -1.3428e-02, 1.7864e-02, -7.0872e-02, -4.7788e-02]]],
[[[-6.8434e-03, 1.4991e-02, -4.0166e-02, -4.2922e-02, -3.7892e-02],
[-4.6419e-03, 7.9605e-02, 2.4481e-02, -1.8153e-02, -6.5677e-02],
[-4.4087e-02, -3.3966e-02, 7.7397e-02, -2.4133e-02, -2.1281e-02],
[ 7.2108e-02, 5.7730e-02, 1.8671e-02, -6.6956e-02, -7.8317e-02],
[ 5.1377e-02, 4.0833e-02, 1.3915e-02, -7.3671e-02, -6.4171e-02]],
[[-2.2992e-02, -1.5135e-02, -1.6468e-02, 3.4729e-02, -5.0686e-02],
[ 3.3976e-02, 4.7649e-02, 9.0078e-03, 6.1119e-02, 7.7885e-02],
[ 2.0299e-02, 4.7323e-02, 5.7702e-02, 4.8393e-02, -4.6623e-02],
[ 4.0535e-02, 5.1355e-02, -3.0590e-02, 7.3194e-02, -6.9639e-02],
[-2.0004e-02, -3.7198e-02, -7.3253e-02, -3.4263e-02, 1.5673e-02]],
[[-2.4125e-02, -2.1872e-02, -4.1021e-02, 1.5590e-02, 7.0267e-02],
[-1.2325e-02, 4.8418e-02, 3.7418e-02, 5.6973e-02, 1.5516e-02],
[-5.0112e-02, 4.1789e-02, 5.5392e-02, -3.5548e-02, 2.8206e-02],
[ 5.9003e-02, 7.3764e-03, 1.5419e-02, -1.6909e-02, -4.0654e-02],
[ 4.1070e-02, -2.2652e-02, -4.4021e-02, -8.1407e-03, -7.5206e-02]],
[[-6.3301e-02, -6.5342e-02, -3.3752e-03, -6.6840e-02, -1.8425e-02],
[-1.6499e-02, 5.8059e-02, 3.5353e-02, 9.1365e-03, -2.8343e-02],
[ 4.6293e-02, 3.0543e-02, -7.3024e-02, -6.8207e-02, -2.7875e-02],
[ 8.0904e-02, 1.1077e-02, 4.2119e-02, 5.4343e-02, 2.1999e-02],
[ 6.2393e-02, 2.9035e-02, 2.6746e-02, -2.8318e-02, -4.9989e-02]],
[[-6.2379e-02, -4.8560e-02, -5.2721e-02, -3.9246e-02, -6.8517e-03],
[-7.4939e-03, 9.3801e-03, 1.3410e-02, 5.7410e-02, 4.4898e-03],
[-1.7655e-02, 3.3112e-02, -6.4055e-02, -1.3577e-02, 6.3291e-02],
[-3.9472e-02, 1.4227e-02, -4.6944e-02, -1.8557e-02, -3.4740e-02],
[-6.3974e-02, -6.9448e-02, -2.1668e-02, 2.4177e-02, 3.0717e-02]],
[[ 2.8530e-02, 2.3247e-02, -3.8633e-02, -5.7804e-02, -1.9294e-02],
[ 5.4420e-02, -1.2094e-02, 2.1143e-02, -6.5897e-02, -2.8639e-02],
[ 4.7260e-02, 4.2903e-02, 7.4266e-02, 6.9400e-02, -6.4887e-02],
[-7.5132e-02, -5.4750e-02, 4.6103e-03, 3.0465e-02, -6.0162e-02],
[ 3.2707e-02, -3.7524e-02, -6.7505e-02, 2.1123e-02, -1.5651e-02]]],
[[[ 6.0358e-02, 5.9896e-02, -3.1081e-02, 5.8683e-02, 1.5452e-02],
[ 6.0256e-02, -7.1520e-02, 7.8586e-02, -3.8772e-02, -7.1890e-02],
[-5.1354e-02, -2.9084e-03, -3.9233e-02, -2.1499e-02, -2.9419e-02],
[-3.2572e-02, -6.9616e-02, 2.9291e-02, 7.2235e-02, 2.5144e-02],
[ 7.6527e-02, 3.1913e-03, -1.8299e-02, 1.5759e-02, 6.3982e-02]],
[[-6.2217e-02, 2.4136e-02, 6.5684e-02, -5.2996e-02, -8.5318e-03],
[ 5.4878e-02, -7.0118e-02, 5.6222e-02, -5.8217e-02, -1.2457e-02],
[-9.0102e-03, 4.0819e-02, -5.2410e-02, -4.3693e-02, -8.1261e-03],
[ 3.9352e-02, 4.2597e-02, 6.4178e-02, -1.6116e-02, 4.4007e-02],
[-4.6907e-02, 5.0872e-02, 1.4034e-02, -7.7642e-02, -3.1652e-02]],
[[ 1.8691e-02, 3.9128e-02, -1.8538e-03, 6.9222e-02, 4.7985e-02],
[ 2.5163e-02, -3.2308e-02, 5.8934e-02, 6.4200e-02, 7.5079e-02],
[-3.8752e-02, -6.2834e-02, -1.3630e-02, 4.7745e-02, -3.6710e-02],
[-6.5912e-02, -7.5509e-02, 2.0538e-03, 6.1806e-02, 4.7332e-02],
[ 5.2663e-02, -6.0765e-02, -1.1656e-02, -1.2399e-02, 6.5297e-02]],
[[-4.0377e-02, 2.2776e-02, 2.0396e-03, 6.3307e-02, 7.7342e-02],
[-8.6686e-03, 6.8417e-02, 4.9833e-02, -5.8394e-02, -6.8530e-02],
[-6.4711e-02, -6.4908e-02, -3.2846e-02, -3.7337e-02, -3.2760e-02],
[-7.8387e-02, 1.2714e-03, -3.3095e-02, -1.5624e-03, -1.5552e-02],
[-6.5617e-02, 5.9709e-02, 7.6255e-02, 7.4220e-02, -3.9595e-02]],
[[-2.4471e-02, -1.4723e-02, -4.3525e-03, -5.7851e-02, 1.1639e-02],
[ 4.5532e-02, 4.3314e-02, 2.7463e-02, -3.2127e-02, -6.0824e-02],
[ 4.7108e-02, 1.2112e-02, -1.6862e-02, -5.4160e-02, 4.8685e-03],
[-2.5893e-02, 6.4832e-02, 3.3282e-02, 4.9884e-02, 6.3713e-02],
[-1.9860e-02, 1.2712e-02, 7.0452e-04, 4.6135e-02, -5.4728e-02]],
[[-5.1852e-02, -1.5589e-02, -3.1799e-02, 4.5747e-02, -2.9827e-02],
[ 6.4932e-02, -5.3074e-02, -4.9272e-02, 1.8426e-02, -4.6095e-02],
[ 4.1712e-02, -7.9372e-02, -7.9577e-02, 1.2126e-02, 4.2022e-02],
[ 5.8650e-02, 1.3046e-02, -5.0546e-02, 7.1611e-02, 5.8748e-02],
[ 4.8559e-02, -8.1399e-02, 6.8672e-02, 7.3071e-02, -4.6508e-02]]],
...,
[[[ 7.2944e-02, -5.6815e-02, -7.2140e-02, -8.0878e-02, 8.1223e-02],
[-2.6174e-02, 4.4648e-02, -1.7627e-05, 6.8991e-02, 9.3131e-04],
[-1.9168e-02, -1.1712e-02, -3.9127e-02, -6.5451e-02, -1.9835e-02],
[-2.9851e-02, -7.2093e-02, 6.1742e-03, -6.0501e-02, -3.0240e-02],
[-3.6711e-02, -3.9918e-02, 1.8570e-02, 2.7867e-02, -3.9091e-02]],
[[ 4.4294e-02, -6.2949e-02, -4.9712e-02, -6.0654e-02, 2.6511e-02],
[-1.1918e-02, 1.9399e-02, 1.9778e-03, 4.6715e-02, 7.9662e-02],
[-3.9779e-02, -3.2971e-02, 2.6502e-03, 6.0599e-02, 6.1761e-02],
[-7.2092e-02, -3.8731e-02, -1.5203e-02, 3.3408e-02, -7.3232e-02],
[ 2.1354e-02, 4.9467e-02, 6.6561e-02, 6.4517e-02, -4.9400e-02]],
[[ 2.5279e-02, 7.5811e-02, -5.6423e-02, 7.1795e-03, -8.0469e-02],
[ 6.3054e-02, 1.5441e-02, -7.9545e-02, 6.0103e-02, 4.7542e-02],
[-2.6974e-02, 6.4899e-02, -7.6267e-02, 3.2200e-02, 9.7143e-03],
[ 4.1850e-02, -1.8550e-02, -5.0626e-02, 3.7149e-02, 7.6128e-02],
[-2.8989e-02, 2.5409e-03, -3.2850e-02, -1.1957e-02, 4.2580e-04]],
[[ 4.6736e-02, 7.7891e-02, 2.2977e-02, 3.5759e-02, -7.9195e-02],
[-4.3826e-02, -7.9846e-02, -5.2120e-02, -3.8209e-03, 2.4057e-02],
[-6.7396e-03, 2.7530e-02, 1.1896e-03, -1.6895e-02, -5.0218e-02],
[ 5.6456e-02, -7.6683e-02, 2.4498e-02, -5.4710e-02, 5.6294e-02],
[-3.0637e-03, -2.6177e-02, -3.8865e-02, -3.9652e-02, 4.4595e-02]],
[[-8.0799e-02, -7.9691e-02, -2.4048e-02, -6.6943e-02, 5.5213e-02],
[ 1.1116e-02, -3.8443e-02, 4.3369e-02, -7.6902e-02, 5.1385e-02],
[ 5.6263e-02, -5.4902e-02, 8.0991e-02, 1.6011e-02, 6.6421e-02],
[-2.4895e-02, -4.4881e-02, 4.6953e-02, -4.1781e-02, -4.2947e-02],
[ 8.0550e-02, -7.2696e-02, -4.6141e-02, 6.7832e-03, -1.6691e-03]],
[[ 2.6609e-02, -3.9203e-02, -7.8157e-03, 2.2936e-04, -2.7554e-02],
[ 4.0520e-02, 1.1102e-02, -2.2165e-02, 6.4671e-02, 1.1872e-02],
[ 2.5477e-02, 3.2211e-02, 5.6317e-02, 5.1697e-02, 5.5899e-02],
[-3.0296e-02, -3.9487e-02, -2.5797e-02, 5.7478e-02, -4.8781e-03],
[-6.3375e-02, -4.3827e-02, 3.5311e-03, 4.7217e-02, 6.8362e-02]]],
[[[-4.5381e-02, -7.7842e-02, -6.9001e-02, -7.6422e-03, 6.8520e-02],
[ 2.3377e-02, -5.9736e-03, -6.8239e-02, 7.2911e-02, -6.6242e-02],
[-1.5282e-02, 1.7386e-02, 3.9979e-02, -6.8327e-03, -1.7662e-03],
[ 5.4649e-02, -4.8377e-03, 7.7069e-02, -8.0424e-02, -2.7894e-02],
[-6.3750e-02, -2.7770e-02, 5.7462e-02, -1.8159e-02, 5.8960e-02]],
[[ 1.5038e-02, -8.0078e-02, 1.0708e-02, 2.2493e-02, 2.2514e-02],
[-2.7322e-02, 4.5916e-02, 7.1295e-02, -5.6998e-02, -5.2429e-02],
[-2.4198e-02, -4.0081e-02, -7.5517e-02, -6.0738e-02, -1.9848e-02],
[-8.0915e-02, 1.1733e-02, 7.0872e-02, 4.2211e-02, 3.7455e-03],
[ 5.6451e-02, -2.0291e-02, 5.9699e-02, -3.8810e-02, 9.7062e-03]],
[[ 7.0948e-02, 7.7596e-02, -5.9511e-02, -2.7747e-03, -2.9197e-02],
[ 5.6304e-02, -5.9313e-02, -3.6894e-03, -3.4498e-02, -3.1743e-02],
[ 6.2984e-02, 7.1278e-02, 1.8568e-02, -8.1057e-02, -7.4301e-02],
[-1.7063e-02, 3.7341e-02, -1.7987e-02, -6.2014e-03, -1.3535e-02],
[ 3.3733e-02, 3.2608e-02, -1.8692e-02, 6.1727e-02, 1.0257e-02]],
[[ 4.3113e-04, -6.9241e-02, 2.2611e-02, 4.1913e-02, -6.6395e-02],
[ 7.5128e-03, -7.3346e-02, 8.0353e-02, 1.2347e-02, 5.5333e-02],
[-9.7800e-03, 4.5897e-02, 2.8835e-02, -3.6708e-02, 3.9655e-02],
[ 2.7716e-02, -7.1659e-02, 7.1108e-03, 1.1511e-02, -4.8559e-02],
[-3.0865e-02, 7.5560e-02, 2.8310e-02, 7.4005e-02, -5.0888e-02]],
[[ 7.5087e-02, 6.3344e-02, 5.9466e-02, 1.0437e-02, 9.3939e-03],
[-1.4452e-03, -5.0765e-02, -3.6996e-02, -6.8923e-02, -7.4329e-02],
[ 1.1036e-02, -2.6916e-02, -6.9722e-02, 5.9740e-02, 4.6108e-02],
[ 2.0379e-02, 3.6167e-02, 4.8153e-02, -3.0691e-02, -5.5250e-02],
[ 3.5924e-02, 4.5421e-02, -4.7335e-02, 6.4587e-02, -5.7064e-02]],
[[-1.6970e-03, -7.8021e-02, -6.0369e-02, -8.0641e-02, 7.1452e-02],
[ 1.6848e-02, -7.5881e-02, 2.5285e-02, 2.5364e-02, -1.0818e-02],
[-3.0854e-02, -2.4429e-02, -6.4815e-02, 8.1414e-03, -7.9674e-02],
[-6.2038e-02, 7.4582e-02, -1.7759e-02, 2.3795e-02, -1.5795e-02],
[ 3.7823e-02, -3.3319e-04, 7.1363e-03, 7.7572e-02, 4.3771e-02]]],
[[[-4.8656e-03, 4.3062e-02, -3.2547e-02, 9.7140e-03, -5.3167e-02],
[ 4.2759e-02, -4.1656e-02, 6.4357e-02, 3.5642e-02, -7.8376e-02],
[-1.2937e-02, 6.4533e-02, 1.5182e-02, 1.1444e-02, -7.4220e-02],
[ 6.3483e-02, -1.1542e-02, -4.0774e-02, -1.2172e-02, -2.7794e-02],
[-8.1438e-03, -5.5991e-02, -2.9966e-02, -8.0014e-03, -5.2937e-02]],
[[ 9.9251e-03, -2.7150e-02, -1.5934e-02, -3.4809e-02, 2.2487e-02],
[-2.9249e-03, 6.8871e-02, 4.3621e-03, 2.6227e-02, 4.3713e-02],
[ 8.1283e-02, -3.1387e-02, -6.9915e-02, -1.7858e-02, -2.1714e-02],
[-3.5359e-02, -1.3766e-02, 3.6173e-02, 9.1202e-03, -3.9747e-02],
[-7.2135e-02, -7.3420e-02, 6.0504e-02, 3.1594e-02, 7.6891e-02]],
[[ 7.2759e-02, -6.5420e-02, 6.7763e-02, 7.2741e-02, -7.4671e-02],
[-5.5163e-02, -7.5269e-02, 1.3287e-02, 1.8645e-02, -3.4054e-02],
[ 6.5525e-02, -4.1262e-03, 4.6500e-02, -6.6291e-02, 5.8884e-02],
[ 3.0486e-02, 3.6131e-03, 1.1222e-02, -3.3646e-02, -6.5889e-02],
[ 4.7762e-02, 3.6352e-02, 9.7470e-03, 7.7495e-03, -5.5064e-02]],
[[ 4.2110e-02, 5.1736e-02, 4.9755e-02, 1.8245e-02, 3.1093e-02],
[-5.8074e-02, -4.1158e-02, -5.9566e-03, 6.2394e-02, 1.6582e-02],
[-7.2003e-02, 1.4616e-02, -3.5987e-02, 3.0575e-02, -4.4705e-02],
[-4.7500e-02, -1.9091e-02, 1.2661e-02, 2.4751e-02, -7.1824e-02],
[ 4.3771e-02, 4.9023e-02, 7.2368e-02, -2.3195e-02, -3.0777e-02]],
[[-8.2526e-03, -1.3523e-02, -6.9580e-02, 2.5552e-02, -1.5779e-02],
[ 2.2318e-03, 2.7111e-02, -8.7496e-03, -2.3582e-02, -6.8521e-02],
[ 7.4568e-02, -4.6680e-02, 7.4333e-02, -6.5834e-02, 8.0266e-02],
[ 1.0070e-02, 5.4708e-02, -1.4732e-03, 1.9077e-02, -2.5033e-02],
[-2.7357e-02, 1.9236e-02, -4.7921e-02, -5.5013e-02, -7.4643e-02]],
[[ 4.0488e-02, -7.1390e-02, -2.3527e-02, 1.3764e-02, -1.5115e-02],
[-3.7438e-02, -7.9287e-02, -6.0580e-02, -3.2224e-02, -2.1884e-02],
[-7.0937e-02, -5.7632e-02, -1.2339e-02, 5.2566e-02, -5.8696e-02],
[-6.4373e-02, 5.0876e-02, -6.8186e-02, 6.8750e-02, 4.6615e-02],
[ 3.0661e-02, -6.8377e-02, 1.7900e-02, -8.8543e-03, 1.4958e-02]]]],
requires_grad=True)
tensor: tensor([[[[-7.9040e-02, -3.7379e-03, 3.7772e-02, 3.2235e-03, 5.9213e-02],
[ 1.1029e-01, 5.3078e-02, -5.1533e-02, -1.9493e-02, -2.0278e-02],
[-3.1871e-02, 1.9686e-02, -2.9023e-02, 6.5129e-02, -1.2462e-01],
[ 1.5340e-02, 6.1715e-02, 6.1034e-02, 4.7931e-02, 4.5956e-02],
[-1.4301e-01, 7.0018e-02, -9.7622e-02, 1.0660e-01, -1.3165e-02]],
[[ 4.6854e-02, -7.3574e-02, 8.1291e-02, -8.2834e-02, 5.4429e-02],
[-2.5799e-02, -1.1986e-01, 9.7622e-02, -1.7025e-02, 5.7727e-02],
[ 1.7480e-02, 1.1082e-02, -2.8427e-02, 3.6260e-02, -3.2276e-02],
[-3.9838e-02, 2.2212e-02, -9.0553e-02, 8.0943e-02, -1.1631e-02],
[-1.0137e-02, -8.9539e-02, 8.1702e-02, -6.7811e-02, 1.3221e-01]],
[[-4.7101e-02, -6.6853e-02, -1.6752e-02, 8.6421e-02, -7.5081e-02],
[-9.1753e-03, -8.0882e-02, -1.0052e-01, -5.2712e-02, 6.5347e-03],
[ 1.2138e-01, 8.8738e-02, 2.7613e-02, 1.6467e-02, -2.6266e-02],
[ 1.4057e-01, -4.7242e-02, -2.1786e-03, 4.5640e-02, -1.3707e-01],
[-4.3512e-02, 6.7946e-02, 5.3988e-02, 5.0796e-02, -8.4169e-02]],
[[-3.3406e-02, 8.4323e-02, 1.2269e-02, -8.0027e-03, 6.7707e-02],
[ 3.4774e-02, 4.1956e-02, 4.8764e-02, -3.4235e-02, 8.3637e-02],
[-1.4833e-01, -2.8633e-02, -1.3108e-01, 2.1766e-03, 1.2428e-01],
[-3.0024e-02, -1.8942e-02, -3.2634e-02, 2.2048e-02, 7.3094e-03],
[ 2.8312e-02, 2.7016e-02, 1.1275e-01, -3.3630e-02, 5.8374e-02]],
[[-9.6861e-02, 5.6462e-02, -1.3434e-02, 1.3635e-02, 1.2531e-02],
[ 3.7592e-02, -4.0006e-02, -5.6685e-02, 4.3706e-02, -1.9282e-02],
[ 6.0806e-02, -1.0441e-01, -9.5292e-02, -8.2102e-03, 3.6336e-02],
[-7.0097e-03, 2.9375e-02, 2.9469e-02, -1.4813e-02, -1.4279e-01],
[-6.3974e-02, 1.3383e-01, 9.9203e-02, 3.4012e-02, -3.2517e-02]],
[[-7.2753e-03, 7.4082e-02, 1.8057e-03, -1.9648e-02, -6.6572e-02],
[-9.5007e-03, 1.5475e-01, -1.7310e-02, 1.0670e-01, 1.0902e-01],
[-2.5831e-02, 6.5105e-02, -2.5116e-03, 6.3226e-02, -8.0030e-02],
[ 9.8436e-02, 1.7270e-02, 2.2965e-03, 1.3588e-02, 9.2589e-02],
[-1.8574e-01, -4.7728e-02, -5.0609e-02, -4.6957e-02, -1.1836e-01]]],
[[[ 4.2598e-02, -4.7736e-03, -1.5925e-02, -1.3042e-01, -3.4655e-02],
[ 1.0344e-01, -8.2122e-02, -6.3094e-02, -1.3021e-02, -1.4217e-01],
[-2.2429e-02, -8.9024e-02, -2.6338e-02, -4.2198e-02, -3.3970e-02],
[ 6.1737e-02, 1.4096e-01, 8.1511e-02, -1.1718e-01, -7.9483e-02],
[ 1.3986e-01, 1.0246e-01, 1.3764e-02, -1.1664e-01, -1.2733e-02]],
[[ 4.0471e-02, -5.8450e-02, -9.7472e-02, -1.1759e-03, -2.6700e-02],
[-2.9101e-02, -1.7837e-02, -3.6606e-02, 7.5730e-02, 4.6381e-02],
[-6.4833e-02, 1.0151e-01, 4.6630e-02, -1.2530e-02, 1.4527e-02],
[-7.4575e-02, 4.4769e-02, 3.6406e-02, 3.2990e-02, -7.2557e-02],
[-3.3549e-02, -2.6702e-02, -5.1471e-02, -7.6490e-02, 6.8940e-02]],
[[-3.6364e-02, -6.5542e-03, 1.8009e-02, 5.2121e-02, 1.3648e-02],
[ 7.4661e-02, 5.3473e-02, 1.0069e-01, 4.3583e-02, -5.3611e-02],
[ 5.8234e-02, 1.0366e-01, 3.8814e-02, -1.0320e-01, 1.0476e-01],
[-5.2903e-03, 9.4834e-02, 8.0787e-02, 2.3429e-02, -7.9876e-02],
[ 1.9255e-02, -7.6146e-02, 1.4325e-02, 1.1370e-01, -6.8911e-02]],
[[ 2.2207e-02, -9.1128e-02, -4.1531e-02, -6.9817e-02, 6.1896e-02],
[-9.3361e-02, 4.2035e-02, 4.2046e-02, 3.0422e-02, 2.4328e-02],
[ 3.8398e-02, -5.3092e-02, -1.9244e-02, -5.9219e-02, -4.7959e-02],
[ 1.1761e-01, 8.1741e-03, 9.8403e-02, 9.8197e-02, 3.1676e-03],
[ 3.4369e-02, -4.2303e-02, 8.5386e-02, -3.3772e-02, -3.5091e-02]],
[[-2.3272e-02, -4.7407e-02, -7.4221e-02, -4.5022e-02, 6.2653e-02],
[-3.3159e-02, 8.0423e-02, -3.4854e-02, 1.0360e-01, -6.0444e-02],
[-6.4923e-02, 1.8660e-03, -5.3824e-02, -3.3233e-02, 5.3356e-02],
[ 2.6758e-02, 5.6524e-02, 1.4273e-02, 2.0408e-02, -6.0961e-03],
[-7.6553e-02, -1.5248e-01, -1.6458e-01, 6.0101e-02, -6.5330e-02]],
[[ 7.4326e-02, 5.1909e-02, -1.0165e-01, -5.6569e-02, 3.2163e-03],
[ 2.2902e-02, 4.9499e-02, 1.2251e-01, -1.2251e-02, -8.8281e-02],
[ 7.0423e-02, 7.8330e-03, 1.1058e-01, 4.6035e-02, 3.0721e-02],
[-1.3987e-01, -2.5666e-02, 6.7374e-03, 1.6375e-02, -1.1733e-01],
[ 7.9805e-02, -8.5537e-02, 5.2054e-02, 2.3272e-02, 1.0554e-01]]],
[[[ 7.6109e-02, 1.1910e-01, 5.8049e-03, 1.1300e-01, 4.7126e-02],
[ 4.7814e-02, -7.1319e-02, 9.9547e-02, -3.3557e-02, -1.1612e-01],
[-6.4900e-02, -3.6347e-02, -9.5536e-02, -3.6929e-02, 7.2433e-03],
[-8.8529e-02, -5.4179e-02, 1.1549e-01, 5.7148e-02, 2.8013e-02],
[ 1.2830e-01, 4.1872e-02, -3.4477e-02, 8.0472e-02, 6.5879e-02]],
[[-4.7669e-02, -9.9669e-03, 7.2031e-02, -7.7567e-02, -8.9658e-02],
[ 1.5114e-01, -9.9826e-02, 2.3776e-02, -2.5093e-02, -1.1878e-01],
[-4.9512e-02, 1.5042e-01, -5.2417e-02, -3.5301e-03, -5.2818e-02],
[ 9.4058e-03, 1.0676e-02, 8.0632e-02, 2.8295e-02, 6.7808e-02],
[-5.5620e-02, 1.2470e-01, 7.8352e-02, 1.2027e-02, -4.2022e-02]],
[[ 1.0013e-01, 1.1369e-01, -1.0589e-01, -2.8328e-02, 1.3427e-01],
[ 2.5859e-02, 7.2207e-02, 1.0146e-02, -4.0904e-02, -1.6968e-02],
[-1.2773e-02, -6.0562e-02, -1.5470e-02, 7.2821e-02, 1.7777e-02],
[-8.7816e-02, 1.0456e-03, 7.3066e-02, -9.0929e-03, 4.0644e-03],
[ 1.4177e-02, -4.2387e-02, 1.8873e-02, -2.0339e-02, 1.3091e-03]],
[[-1.1817e-01, 7.3392e-03, 2.8750e-04, 1.5765e-01, 1.0838e-01],
[-8.5690e-02, 7.2779e-03, 1.7669e-01, 3.7189e-03, -7.1858e-02],
[-1.0172e-01, 2.2364e-02, 8.7263e-03, -6.2377e-02, -6.5125e-02],
[ 1.8006e-02, -9.0645e-02, 4.0923e-02, -8.8384e-02, -3.2109e-02],
[-1.1282e-01, 7.5887e-02, -2.0569e-02, 6.8537e-02, -9.1945e-02]],
[[ 8.7547e-03, -7.3805e-02, -2.4725e-02, -9.7951e-02, 1.4958e-02],
[ 1.5039e-03, 3.5538e-02, 6.9676e-03, -3.4485e-02, -7.6818e-02],
[ 5.0455e-02, -1.6215e-02, -4.2158e-02, -4.7410e-02, 1.1642e-01],
[ 1.3294e-02, 1.3026e-01, -1.5100e-02, 8.2574e-02, 7.0877e-02],
[-8.0356e-02, 5.9577e-02, -2.4652e-02, 4.9606e-03, -5.2893e-02]],
[[-2.6808e-02, -4.1903e-02, 1.0307e-01, 4.8184e-02, 5.5327e-02],
[ 1.1018e-01, -1.9608e-02, -7.4902e-02, 4.5365e-02, 6.8104e-03],
[ 1.6191e-02, -1.0323e-01, -1.4568e-01, 1.1337e-01, 1.1910e-01],
[ 2.0553e-02, 2.9168e-02, -3.5671e-02, 9.2641e-02, -4.0162e-02],
[ 8.6601e-02, -1.6433e-01, 1.2801e-01, 3.8804e-02, -9.1625e-02]]],
...,
[[[ 6.2546e-02, -1.4037e-02, -6.0003e-02, -1.2287e-01, 1.0268e-01],
[-6.6390e-02, 1.0820e-01, 3.2333e-02, 9.1314e-02, -2.1904e-02],
[-7.6083e-02, -8.3289e-02, 2.6048e-02, 3.8830e-02, -7.6226e-02],
[-3.3048e-02, -3.1125e-02, -1.9828e-02, -7.9846e-02, -6.0296e-02],
[-5.7660e-02, 2.9841e-02, 4.6691e-02, -5.9227e-02, -8.6609e-03]],
[[ 1.9611e-02, -6.0364e-02, -6.5186e-02, -1.4345e-01, 6.7694e-02],
[ 2.5685e-03, -2.5887e-02, -3.7028e-02, 8.4842e-02, 5.7676e-02],
[-8.0969e-02, 8.7602e-03, 2.1688e-02, 1.3805e-01, 8.9147e-02],
[-1.3020e-01, 1.3120e-02, -2.7499e-02, 1.6288e-02, -8.0340e-02],
[ 4.6616e-02, 5.5087e-02, 1.2293e-01, 3.0804e-02, -3.4891e-02]],
[[-2.7551e-02, 1.1154e-01, -1.7306e-01, 6.3198e-02, -8.9437e-02],
[ 1.0023e-01, 7.0259e-02, -1.0258e-01, -5.0549e-02, 1.6971e-02],
[ 2.5675e-02, -1.1538e-02, -3.6890e-02, 1.8445e-02, 2.4550e-02],
[ 6.6892e-02, 7.9879e-02, -2.2027e-02, 7.8564e-03, 1.3598e-01],
[-1.3008e-02, 5.2810e-02, -1.6301e-01, -4.4813e-02, 2.5996e-02]],
[[ 8.5186e-02, 9.1058e-02, -1.3829e-02, 3.8593e-02, -4.9633e-02],
[-7.4777e-02, -1.3623e-01, -4.2233e-02, -2.4843e-02, 4.7866e-02],
[ 1.1859e-01, -3.9103e-02, 6.5936e-02, -6.2660e-03, -8.1931e-02],
[ 6.5915e-02, -3.5195e-02, -8.3154e-02, 2.5878e-02, 1.0228e-01],
[-2.0687e-02, 4.4530e-02, -9.8919e-03, 2.0240e-02, 1.0023e-01]],
[[-3.2076e-02, 3.3811e-02, -4.0684e-02, -1.0620e-01, 4.5157e-02],
[ 1.5433e-02, -5.7362e-02, -2.0055e-02, -1.7903e-01, 3.5422e-02],
[ 5.4437e-02, -6.6113e-02, 5.3977e-02, -1.0462e-02, 6.6393e-02],
[-4.2926e-02, -9.2950e-02, 7.8664e-03, -1.4528e-01, -7.3950e-02],
[ 1.1710e-01, -7.6187e-03, -3.6296e-02, -3.4596e-02, -8.7469e-02]],
[[-1.5697e-02, 1.8503e-03, -1.0732e-02, 4.4367e-02, 6.3441e-02],
[ 6.3576e-02, -7.5327e-02, -2.1156e-02, 1.0805e-01, 5.9616e-03],
[-5.1490e-02, 6.6770e-02, 3.8325e-02, 9.7658e-02, 6.7618e-02],
[-9.1422e-03, -2.5269e-02, -6.2489e-02, -3.4347e-02, -3.2664e-02],
[-7.2520e-02, -8.9599e-02, 6.2424e-02, 7.9191e-02, 7.9724e-02]]],
[[[-3.9622e-02, -4.6573e-02, -1.0472e-01, 1.1712e-02, 6.0660e-02],
[-1.1623e-01, -5.1286e-03, -1.4487e-03, 1.1639e-01, -1.7682e-01],
[-8.0438e-03, -2.5883e-02, -6.0686e-02, -4.3806e-02, -5.9083e-02],
[ 6.1406e-02, 2.7634e-02, 8.6463e-02, -2.1471e-02, 1.3222e-03],
[-3.6193e-02, -1.3771e-01, -3.6332e-02, -1.2857e-02, 4.2172e-02]],
[[ 6.7233e-02, -1.2348e-01, -1.0973e-02, 9.5008e-02, -2.0650e-02],
[-5.6706e-02, 6.6555e-02, 7.6050e-02, -1.5699e-01, -6.1664e-02],
[ 1.7243e-02, -5.0797e-02, -1.3130e-01, -6.8116e-02, -2.0173e-02],
[-1.7038e-01, 6.6691e-02, 1.3256e-02, 6.9085e-02, -2.6799e-02],
[ 6.1442e-02, -3.6632e-02, 3.3934e-02, -9.6146e-02, 1.9848e-02]],
[[ 5.0492e-02, 3.8202e-02, 8.3198e-03, -4.8258e-02, 1.4369e-01],
[ 1.2627e-01, -1.1377e-02, 2.7590e-02, -5.6298e-02, -1.0798e-01],
[ 9.1512e-02, 3.4963e-02, -3.0616e-02, 7.3307e-02, -8.8587e-02],
[-4.4653e-02, 9.5914e-02, -3.9202e-02, 1.3692e-02, -4.4345e-02],
[ 6.6634e-02, 5.8251e-02, 1.8750e-03, 2.9827e-02, -2.3503e-02]],
[[-2.7796e-02, -6.2559e-02, 1.1090e-01, -2.0772e-02, -6.1026e-02],
[ 3.4603e-06, -8.2480e-02, 1.2860e-01, -8.0894e-02, 1.0519e-01],
[-9.9585e-02, 1.1226e-01, 8.5616e-02, -3.4187e-02, -2.4726e-02],
[ 7.1098e-02, -8.6928e-02, 1.1157e-02, 5.2506e-02, -1.0666e-01],
[-1.0207e-01, 3.2715e-02, 2.3789e-02, 1.2132e-02, -5.6706e-02]],
[[ 1.4586e-01, 7.7736e-02, 8.8786e-02, 6.8187e-02, -3.1638e-02],
[ 3.9524e-03, -1.0926e-01, -1.1489e-01, -8.6353e-02, -1.1245e-01],
[ 2.9333e-02, -8.5913e-03, -5.3009e-02, 9.1673e-02, 8.5318e-02],
[ 6.1508e-02, 1.0515e-01, 1.2023e-01, -1.4283e-02, 4.9522e-02],
[ 3.0185e-02, 1.1706e-01, -2.6336e-02, 4.6261e-02, -5.2317e-02]],
[[-1.7293e-02, 2.8523e-02, -2.8706e-02, -1.2312e-01, 1.0700e-04],
[-2.4223e-02, -7.1036e-02, 1.6568e-02, 7.3946e-02, 3.1011e-02],
[-2.9600e-02, -4.9499e-02, -9.9627e-02, 5.0107e-02, -8.2558e-02],
[-8.7629e-02, 7.3977e-02, -1.1190e-02, 5.6551e-02, 5.6759e-03],
[ 7.6054e-02, 2.6358e-02, 5.8867e-03, 4.6918e-02, 8.0736e-02]]],
[[[ 7.6387e-02, 3.5164e-02, -1.0085e-01, 2.5791e-02, 1.2621e-02],
[ 4.0133e-02, -6.2389e-02, -7.0186e-03, -1.5747e-02, -5.8938e-02],
[-2.1233e-02, 5.9552e-02, -1.6451e-02, 5.8491e-02, -1.4009e-01],
[ 7.7164e-02, 7.3244e-02, 3.9004e-02, -5.1173e-02, 1.8552e-02],
[-9.8993e-03, -3.2891e-03, -1.2678e-02, 3.2900e-02, -1.6397e-02]],
[[ 4.8623e-02, -1.6285e-01, 2.6418e-02, -4.4003e-02, 1.0522e-02],
[ 6.5607e-02, -4.3938e-02, -4.4469e-02, 2.8729e-02, 7.2077e-02],
[ 8.6301e-02, -5.3798e-02, 2.2424e-03, -8.6657e-02, 2.3395e-02],
[-5.0502e-02, 2.6603e-02, 6.7704e-02, -7.1994e-02, -1.0376e-01],
[-1.1587e-01, -4.1993e-02, 4.0082e-02, 8.4338e-03, 5.3974e-02]],
[[ 9.3545e-03, -3.9613e-02, 1.1917e-01, 5.5167e-02, -5.5481e-02],
[-7.8236e-02, -3.6044e-02, -6.9676e-02, 2.0344e-02, 8.6851e-02],
[ 2.0640e-02, -2.4200e-02, 9.8858e-02, -1.3632e-02, 3.1430e-02],
[ 5.8873e-02, 9.5151e-03, 1.6263e-02, 4.1727e-02, -1.2870e-01],
[-4.9029e-02, 3.8735e-02, 2.3564e-02, -4.3188e-02, -6.6108e-02]],
[[ 4.8084e-02, 7.8817e-02, 3.1851e-02, 4.9276e-02, 3.8264e-02],
[-9.3058e-02, -6.3315e-02, -1.7968e-02, 7.5934e-02, 3.9757e-02],
[-1.2371e-01, 4.7944e-02, -2.0174e-02, -3.3382e-02, -4.8501e-02],
[-9.3338e-02, -1.7285e-03, 8.4074e-04, -1.8147e-02, -1.1448e-01],
[ 6.5431e-03, 1.3819e-01, 5.7124e-02, -5.3819e-02, 2.9704e-02]],
[[-3.9365e-02, 8.3383e-02, -5.7147e-02, -1.6916e-02, 8.8240e-02],
[-2.1344e-02, 6.3516e-02, -6.2374e-03, -6.7964e-02, -6.1580e-02],
[ 3.6108e-02, -8.6570e-02, -5.2997e-02, -6.8391e-02, 7.4299e-02],
[ 1.0596e-01, 1.0751e-02, 6.2705e-02, -2.0067e-02, -8.1794e-02],
[-3.2689e-02, 9.9929e-03, -6.4664e-02, -1.9183e-02, -1.2894e-01]],
[[ 5.6596e-02, 7.1493e-03, 2.1555e-02, 3.2155e-02, -4.1433e-02],
[-1.1520e-01, -1.6532e-02, -2.5351e-02, -1.0983e-01, -2.3312e-03],
[-1.0072e-01, -1.6166e-02, -4.6829e-03, 1.3726e-01, -3.1124e-02],
[-1.3071e-02, 1.0966e-01, -1.1188e-01, -2.2783e-02, 9.0083e-02],
[ 8.3369e-03, -9.1546e-02, 5.8801e-02, -2.5113e-03, -5.2018e-02]]]],
grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([ 0.0482, 0.0700, -0.0483, 0.0543, 0.0427, -0.0160, 0.0652, 0.0702,
-0.0693, -0.0685, 0.0211, 0.0540, 0.0356, 0.0235, 0.0208, 0.0305],
requires_grad=True)
tensor: tensor([-0.0018, 0.0913, -0.0619, 0.0713, 0.0755, -0.0322, 0.0876, 0.1089,
-0.1069, -0.0383, 0.0331, 0.0550, 0.0829, 0.0486, -0.0027, -0.0946],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[[[-0., -0., -0., 0., 0.],
[0., 0., 0., 0., -0.],
[-0., 0., -0., 0., -0.],
[0., 0., 0., 0., 0.],
[-0., 0., -0., -0., -0.]],
[[0., -0., 0., 0., 0.],
[0., -0., 0., 0., -0.],
[0., -0., 0., 0., -0.],
[-0., 0., -0., 0., -0.],
[0., -0., 0., -0., 0.]],
[[-0., 0., 0., 0., -0.],
[-0., -0., -0., -0., 0.],
[0., 0., -0., 0., -0.],
[0., -0., 0., 0., -0.],
[-0., 0., -0., 0., -0.]],
[[-0., 0., 0., 0., 0.],
[0., -0., -0., -0., 0.],
[-0., -0., -0., -0., -0.],
[0., -0., -0., 0., -0.],
[0., 0., 0., 0., 0.]],
[[-0., 0., -0., -0., -0.],
[-0., -0., -0., 0., -0.],
[0., -0., -0., -0., 0.],
[-0., 0., 0., -0., -0.],
[-0., 0., 0., -0., -0.]],
[[-0., -0., 0., 0., -0.],
[-0., 0., 0., 0., 0.],
[-0., 0., -0., 0., -0.],
[0., 0., 0., -0., 0.],
[-0., -0., 0., -0., -0.]]],
[[[-0., 0., -0., -0., -0.],
[-0., 0., 0., -0., -0.],
[-0., -0., 0., -0., -0.],
[0., 0., 0., -0., -0.],
[0., 0., 0., -0., -0.]],
[[-0., -0., -0., 0., -0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., -0.],
[0., 0., -0., 0., -0.],
[-0., -0., -0., -0., 0.]],
[[-0., -0., -0., 0., 0.],
[-0., 0., 0., 0., 0.],
[-0., 0., 0., -0., 0.],
[0., 0., 0., -0., -0.],
[0., -0., -0., -0., -0.]],
[[-0., -0., -0., -0., -0.],
[-0., 0., 0., 0., -0.],
[0., 0., -0., -0., -0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., -0., -0.]],
[[-0., -0., -0., -0., -0.],
[-0., 0., 0., 0., 0.],
[-0., 0., -0., -0., 0.],
[-0., 0., -0., -0., -0.],
[-0., -0., -0., 0., 0.]],
[[0., 0., -0., -0., -0.],
[0., -0., 0., -0., -0.],
[0., 0., 0., 0., -0.],
[-0., -0., 0., 0., -0.],
[0., -0., -0., 0., -0.]]],
[[[0., 0., -0., 0., 0.],
[0., -0., 0., -0., -0.],
[-0., -0., -0., -0., -0.],
[-0., -0., 0., 0., 0.],
[0., 0., -0., 0., 0.]],
[[-0., 0., 0., -0., -0.],
[0., -0., 0., -0., -0.],
[-0., 0., -0., -0., -0.],
[0., 0., 0., -0., 0.],
[-0., 0., 0., -0., -0.]],
[[0., 0., -0., 0., 0.],
[0., -0., 0., 0., 0.],
[-0., -0., -0., 0., -0.],
[-0., -0., 0., 0., 0.],
[0., -0., -0., -0., 0.]],
[[-0., 0., 0., 0., 0.],
[-0., 0., 0., -0., -0.],
[-0., -0., -0., -0., -0.],
[-0., 0., -0., -0., -0.],
[-0., 0., 0., 0., -0.]],
[[-0., -0., -0., -0., 0.],
[0., 0., 0., -0., -0.],
[0., 0., -0., -0., 0.],
[-0., 0., 0., 0., 0.],
[-0., 0., 0., 0., -0.]],
[[-0., -0., -0., 0., -0.],
[0., -0., -0., 0., -0.],
[0., -0., -0., 0., 0.],
[0., 0., -0., 0., 0.],
[0., -0., 0., 0., -0.]]],
...,
[[[0., -0., -0., -0., 0.],
[-0., 0., -0., 0., 0.],
[-0., -0., -0., -0., -0.],
[-0., -0., 0., -0., -0.],
[-0., -0., 0., 0., -0.]],
[[0., -0., -0., -0., 0.],
[-0., 0., 0., 0., 0.],
[-0., -0., 0., 0., 0.],
[-0., -0., -0., 0., -0.],
[0., 0., 0., 0., -0.]],
[[0., 0., -0., 0., -0.],
[0., 0., -0., 0., 0.],
[-0., 0., -0., 0., 0.],
[0., -0., -0., 0., 0.],
[-0., 0., -0., -0., 0.]],
[[0., 0., 0., 0., -0.],
[-0., -0., -0., -0., 0.],
[-0., 0., 0., -0., -0.],
[0., -0., 0., -0., 0.],
[-0., -0., -0., -0., 0.]],
[[-0., -0., -0., -0., 0.],
[0., -0., 0., -0., 0.],
[0., -0., 0., 0., 0.],
[-0., -0., 0., -0., -0.],
[0., -0., -0., 0., -0.]],
[[0., -0., -0., 0., -0.],
[0., 0., -0., 0., 0.],
[0., 0., 0., 0., 0.],
[-0., -0., -0., 0., -0.],
[-0., -0., 0., 0., 0.]]],
[[[-0., -0., -0., -0., 0.],
[0., -0., -0., 0., -0.],
[-0., 0., 0., -0., -0.],
[0., -0., 0., -0., -0.],
[-0., -0., 0., -0., 0.]],
[[0., -0., 0., 0., 0.],
[-0., 0., 0., -0., -0.],
[-0., -0., -0., -0., -0.],
[-0., 0., 0., 0., 0.],
[0., -0., 0., -0., 0.]],
[[0., 0., -0., -0., -0.],
[0., -0., -0., -0., -0.],
[0., 0., 0., -0., -0.],
[-0., 0., -0., -0., -0.],
[0., 0., -0., 0., 0.]],
[[0., -0., 0., 0., -0.],
[0., -0., 0., 0., 0.],
[-0., 0., 0., -0., 0.],
[0., -0., 0., 0., -0.],
[-0., 0., 0., 0., -0.]],
[[0., 0., 0., 0., 0.],
[-0., -0., -0., -0., -0.],
[0., -0., -0., 0., 0.],
[0., 0., 0., -0., -0.],
[0., 0., -0., 0., -0.]],
[[-0., -0., -0., -0., 0.],
[0., -0., 0., 0., -0.],
[-0., -0., -0., 0., -0.],
[-0., 0., -0., 0., -0.],
[0., -0., 0., 0., 0.]]],
[[[-0., 0., -0., 0., -0.],
[0., -0., 0., 0., -0.],
[-0., 0., 0., 0., -0.],
[0., -0., -0., -0., -0.],
[-0., -0., -0., -0., -0.]],
[[0., -0., -0., -0., 0.],
[-0., 0., 0., 0., 0.],
[0., -0., -0., -0., -0.],
[-0., -0., 0., 0., -0.],
[-0., -0., 0., 0., 0.]],
[[0., -0., 0., 0., -0.],
[-0., -0., 0., 0., -0.],
[0., -0., 0., -0., 0.],
[0., 0., 0., -0., -0.],
[0., 0., 0., 0., -0.]],
[[0., 0., 0., 0., 0.],
[-0., -0., -0., 0., 0.],
[-0., 0., -0., 0., -0.],
[-0., -0., 0., 0., -0.],
[0., 0., 0., -0., -0.]],
[[-0., -0., -0., 0., -0.],
[0., 0., -0., -0., -0.],
[0., -0., 0., -0., 0.],
[0., 0., -0., 0., -0.],
[-0., 0., -0., -0., -0.]],
[[0., -0., -0., 0., -0.],
[-0., -0., -0., -0., -0.],
[-0., -0., -0., 0., -0.],
[-0., 0., -0., 0., 0.],
[0., -0., 0., -0., 0.]]]])
scale: tensor([[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
...,
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[[[-1.6434e-02, -1.3113e-02, -4.1464e-02, 1.6377e-02, 6.2508e-02],
[ 4.0202e-02, 4.9382e-02, 1.0253e-02, 3.4932e-02, -2.9341e-02],
[-6.4074e-02, 5.2063e-02, -3.8546e-02, 1.5865e-02, -4.8314e-02],
[ 2.0698e-02, 8.0888e-02, 7.9982e-02, 5.6686e-02, 1.1419e-02],
[-6.1541e-02, 6.8319e-02, -7.6261e-02, -2.5828e-03, -5.7380e-02]],
[[ 8.0672e-02, -7.8271e-02, 7.9289e-02, 2.1951e-02, 5.9238e-02],
[ 1.1237e-02, -5.2598e-02, 4.1551e-02, 4.3546e-02, -5.8132e-02],
[ 3.0288e-02, -7.5402e-03, 2.8025e-02, 1.5065e-02, -3.7260e-02],
[-2.1855e-02, 4.4424e-02, -6.7250e-02, 4.3303e-02, -4.0667e-02],
[ 4.9321e-02, -2.0541e-02, 5.6232e-02, -3.8466e-02, 5.1026e-02]],
[[-2.6513e-02, 4.9569e-02, 1.8676e-02, 6.9054e-02, -2.6767e-02],
[-5.7368e-02, -7.8838e-03, -7.9114e-02, -1.8557e-02, 2.8787e-04],
[ 1.3084e-02, 4.6926e-02, -1.3894e-02, 2.7116e-02, -7.9155e-02],
[ 7.5190e-02, -2.2396e-02, 5.5935e-02, 6.8270e-02, -3.2401e-02],
[-1.0770e-02, 3.6370e-02, -7.0476e-02, 6.4846e-02, -6.7540e-02]],
[[-3.7310e-02, 5.9737e-02, 5.9160e-05, 3.5063e-02, 7.1916e-02],
[ 1.7257e-02, -6.6353e-03, -6.3498e-03, -1.9595e-02, 4.9856e-02],
[-8.0101e-02, -4.1773e-02, -1.9761e-02, -1.6521e-02, -3.6788e-02],
[ 9.3012e-04, -2.4376e-02, -1.1352e-03, 1.1929e-02, -1.3466e-02],
[ 5.6119e-02, 2.2037e-02, 6.6074e-02, 3.7063e-02, 4.2302e-02]],
[[-3.6769e-02, 3.7521e-02, -9.7236e-03, -6.8646e-02, -3.6374e-02],
[-2.8501e-02, -8.0159e-02, -2.8122e-02, 1.8005e-02, -3.7241e-02],
[ 4.2443e-02, -4.8021e-02, -3.0217e-02, -5.2788e-02, 2.2794e-02],
[-6.3064e-02, 6.8354e-02, 4.3337e-02, -4.3442e-02, -6.5524e-02],
[-5.6791e-02, 7.4918e-02, 5.9263e-02, -1.2248e-04, -3.5795e-02]],
[[-2.1802e-02, -9.4069e-03, 7.6862e-02, 1.4458e-02, -7.0315e-02],
[-5.9874e-02, 5.0751e-02, 3.9215e-02, 7.7538e-02, 3.2766e-02],
[-4.4479e-02, 5.2775e-02, -1.9053e-03, 3.2342e-02, -4.6321e-02],
[ 3.8624e-02, 2.0750e-02, 4.6860e-02, -5.3962e-02, 6.4275e-02],
[-6.5864e-02, -1.3428e-02, 1.7864e-02, -7.0872e-02, -4.7788e-02]]],
[[[-6.8434e-03, 1.4991e-02, -4.0166e-02, -4.2922e-02, -3.7892e-02],
[-4.6419e-03, 7.9605e-02, 2.4481e-02, -1.8153e-02, -6.5677e-02],
[-4.4087e-02, -3.3966e-02, 7.7397e-02, -2.4133e-02, -2.1281e-02],
[ 7.2108e-02, 5.7730e-02, 1.8671e-02, -6.6956e-02, -7.8317e-02],
[ 5.1377e-02, 4.0833e-02, 1.3915e-02, -7.3671e-02, -6.4171e-02]],
[[-2.2992e-02, -1.5135e-02, -1.6468e-02, 3.4729e-02, -5.0686e-02],
[ 3.3976e-02, 4.7649e-02, 9.0078e-03, 6.1119e-02, 7.7885e-02],
[ 2.0299e-02, 4.7323e-02, 5.7702e-02, 4.8393e-02, -4.6623e-02],
[ 4.0535e-02, 5.1355e-02, -3.0590e-02, 7.3194e-02, -6.9639e-02],
[-2.0004e-02, -3.7198e-02, -7.3253e-02, -3.4263e-02, 1.5673e-02]],
[[-2.4125e-02, -2.1872e-02, -4.1021e-02, 1.5590e-02, 7.0267e-02],
[-1.2325e-02, 4.8418e-02, 3.7418e-02, 5.6973e-02, 1.5516e-02],
[-5.0112e-02, 4.1789e-02, 5.5392e-02, -3.5548e-02, 2.8206e-02],
[ 5.9003e-02, 7.3764e-03, 1.5419e-02, -1.6909e-02, -4.0654e-02],
[ 4.1070e-02, -2.2652e-02, -4.4021e-02, -8.1407e-03, -7.5206e-02]],
[[-6.3301e-02, -6.5342e-02, -3.3752e-03, -6.6840e-02, -1.8425e-02],
[-1.6499e-02, 5.8059e-02, 3.5353e-02, 9.1365e-03, -2.8343e-02],
[ 4.6293e-02, 3.0543e-02, -7.3024e-02, -6.8207e-02, -2.7875e-02],
[ 8.0904e-02, 1.1077e-02, 4.2119e-02, 5.4343e-02, 2.1999e-02],
[ 6.2393e-02, 2.9035e-02, 2.6746e-02, -2.8318e-02, -4.9989e-02]],
[[-6.2379e-02, -4.8560e-02, -5.2721e-02, -3.9246e-02, -6.8517e-03],
[-7.4939e-03, 9.3801e-03, 1.3410e-02, 5.7410e-02, 4.4898e-03],
[-1.7655e-02, 3.3112e-02, -6.4055e-02, -1.3577e-02, 6.3291e-02],
[-3.9472e-02, 1.4227e-02, -4.6944e-02, -1.8557e-02, -3.4740e-02],
[-6.3974e-02, -6.9448e-02, -2.1668e-02, 2.4177e-02, 3.0717e-02]],
[[ 2.8530e-02, 2.3247e-02, -3.8633e-02, -5.7804e-02, -1.9294e-02],
[ 5.4420e-02, -1.2094e-02, 2.1143e-02, -6.5897e-02, -2.8639e-02],
[ 4.7260e-02, 4.2903e-02, 7.4266e-02, 6.9400e-02, -6.4887e-02],
[-7.5132e-02, -5.4750e-02, 4.6103e-03, 3.0465e-02, -6.0162e-02],
[ 3.2707e-02, -3.7524e-02, -6.7505e-02, 2.1123e-02, -1.5651e-02]]],
[[[ 6.0358e-02, 5.9896e-02, -3.1081e-02, 5.8683e-02, 1.5452e-02],
[ 6.0256e-02, -7.1520e-02, 7.8586e-02, -3.8772e-02, -7.1890e-02],
[-5.1354e-02, -2.9084e-03, -3.9233e-02, -2.1499e-02, -2.9419e-02],
[-3.2572e-02, -6.9616e-02, 2.9291e-02, 7.2235e-02, 2.5144e-02],
[ 7.6527e-02, 3.1913e-03, -1.8299e-02, 1.5759e-02, 6.3982e-02]],
[[-6.2217e-02, 2.4136e-02, 6.5684e-02, -5.2996e-02, -8.5318e-03],
[ 5.4878e-02, -7.0118e-02, 5.6222e-02, -5.8217e-02, -1.2457e-02],
[-9.0102e-03, 4.0819e-02, -5.2410e-02, -4.3693e-02, -8.1261e-03],
[ 3.9352e-02, 4.2597e-02, 6.4178e-02, -1.6116e-02, 4.4007e-02],
[-4.6907e-02, 5.0872e-02, 1.4034e-02, -7.7642e-02, -3.1652e-02]],
[[ 1.8691e-02, 3.9128e-02, -1.8538e-03, 6.9222e-02, 4.7985e-02],
[ 2.5163e-02, -3.2308e-02, 5.8934e-02, 6.4200e-02, 7.5079e-02],
[-3.8752e-02, -6.2834e-02, -1.3630e-02, 4.7745e-02, -3.6710e-02],
[-6.5912e-02, -7.5509e-02, 2.0538e-03, 6.1806e-02, 4.7332e-02],
[ 5.2663e-02, -6.0765e-02, -1.1656e-02, -1.2399e-02, 6.5297e-02]],
[[-4.0377e-02, 2.2776e-02, 2.0396e-03, 6.3307e-02, 7.7342e-02],
[-8.6686e-03, 6.8417e-02, 4.9833e-02, -5.8394e-02, -6.8530e-02],
[-6.4711e-02, -6.4908e-02, -3.2846e-02, -3.7337e-02, -3.2760e-02],
[-7.8387e-02, 1.2714e-03, -3.3095e-02, -1.5624e-03, -1.5552e-02],
[-6.5617e-02, 5.9709e-02, 7.6255e-02, 7.4220e-02, -3.9595e-02]],
[[-2.4471e-02, -1.4723e-02, -4.3525e-03, -5.7851e-02, 1.1639e-02],
[ 4.5532e-02, 4.3314e-02, 2.7463e-02, -3.2127e-02, -6.0824e-02],
[ 4.7108e-02, 1.2112e-02, -1.6862e-02, -5.4160e-02, 4.8685e-03],
[-2.5893e-02, 6.4832e-02, 3.3282e-02, 4.9884e-02, 6.3713e-02],
[-1.9860e-02, 1.2712e-02, 7.0452e-04, 4.6135e-02, -5.4728e-02]],
[[-5.1852e-02, -1.5589e-02, -3.1799e-02, 4.5747e-02, -2.9827e-02],
[ 6.4932e-02, -5.3074e-02, -4.9272e-02, 1.8426e-02, -4.6095e-02],
[ 4.1712e-02, -7.9372e-02, -7.9577e-02, 1.2126e-02, 4.2022e-02],
[ 5.8650e-02, 1.3046e-02, -5.0546e-02, 7.1611e-02, 5.8748e-02],
[ 4.8559e-02, -8.1399e-02, 6.8672e-02, 7.3071e-02, -4.6508e-02]]],
...,
[[[ 7.2944e-02, -5.6815e-02, -7.2140e-02, -8.0878e-02, 8.1223e-02],
[-2.6174e-02, 4.4648e-02, -1.7627e-05, 6.8991e-02, 9.3131e-04],
[-1.9168e-02, -1.1712e-02, -3.9127e-02, -6.5451e-02, -1.9835e-02],
[-2.9851e-02, -7.2093e-02, 6.1742e-03, -6.0501e-02, -3.0240e-02],
[-3.6711e-02, -3.9918e-02, 1.8570e-02, 2.7867e-02, -3.9091e-02]],
[[ 4.4294e-02, -6.2949e-02, -4.9712e-02, -6.0654e-02, 2.6511e-02],
[-1.1918e-02, 1.9399e-02, 1.9778e-03, 4.6715e-02, 7.9662e-02],
[-3.9779e-02, -3.2971e-02, 2.6502e-03, 6.0599e-02, 6.1761e-02],
[-7.2092e-02, -3.8731e-02, -1.5203e-02, 3.3408e-02, -7.3232e-02],
[ 2.1354e-02, 4.9467e-02, 6.6561e-02, 6.4517e-02, -4.9400e-02]],
[[ 2.5279e-02, 7.5811e-02, -5.6423e-02, 7.1795e-03, -8.0469e-02],
[ 6.3054e-02, 1.5441e-02, -7.9545e-02, 6.0103e-02, 4.7542e-02],
[-2.6974e-02, 6.4899e-02, -7.6267e-02, 3.2200e-02, 9.7143e-03],
[ 4.1850e-02, -1.8550e-02, -5.0626e-02, 3.7149e-02, 7.6128e-02],
[-2.8989e-02, 2.5409e-03, -3.2850e-02, -1.1957e-02, 4.2580e-04]],
[[ 4.6736e-02, 7.7891e-02, 2.2977e-02, 3.5759e-02, -7.9195e-02],
[-4.3826e-02, -7.9846e-02, -5.2120e-02, -3.8209e-03, 2.4057e-02],
[-6.7396e-03, 2.7530e-02, 1.1896e-03, -1.6895e-02, -5.0218e-02],
[ 5.6456e-02, -7.6683e-02, 2.4498e-02, -5.4710e-02, 5.6294e-02],
[-3.0637e-03, -2.6177e-02, -3.8865e-02, -3.9652e-02, 4.4595e-02]],
[[-8.0799e-02, -7.9691e-02, -2.4048e-02, -6.6943e-02, 5.5213e-02],
[ 1.1116e-02, -3.8443e-02, 4.3369e-02, -7.6902e-02, 5.1385e-02],
[ 5.6263e-02, -5.4902e-02, 8.0991e-02, 1.6011e-02, 6.6421e-02],
[-2.4895e-02, -4.4881e-02, 4.6953e-02, -4.1781e-02, -4.2947e-02],
[ 8.0550e-02, -7.2696e-02, -4.6141e-02, 6.7832e-03, -1.6691e-03]],
[[ 2.6609e-02, -3.9203e-02, -7.8157e-03, 2.2936e-04, -2.7554e-02],
[ 4.0520e-02, 1.1102e-02, -2.2165e-02, 6.4671e-02, 1.1872e-02],
[ 2.5477e-02, 3.2211e-02, 5.6317e-02, 5.1697e-02, 5.5899e-02],
[-3.0296e-02, -3.9487e-02, -2.5797e-02, 5.7478e-02, -4.8781e-03],
[-6.3375e-02, -4.3827e-02, 3.5311e-03, 4.7217e-02, 6.8362e-02]]],
[[[-4.5381e-02, -7.7842e-02, -6.9001e-02, -7.6422e-03, 6.8520e-02],
[ 2.3377e-02, -5.9736e-03, -6.8239e-02, 7.2911e-02, -6.6242e-02],
[-1.5282e-02, 1.7386e-02, 3.9979e-02, -6.8327e-03, -1.7662e-03],
[ 5.4649e-02, -4.8377e-03, 7.7069e-02, -8.0424e-02, -2.7894e-02],
[-6.3750e-02, -2.7770e-02, 5.7462e-02, -1.8159e-02, 5.8960e-02]],
[[ 1.5038e-02, -8.0078e-02, 1.0708e-02, 2.2493e-02, 2.2514e-02],
[-2.7322e-02, 4.5916e-02, 7.1295e-02, -5.6998e-02, -5.2429e-02],
[-2.4198e-02, -4.0081e-02, -7.5517e-02, -6.0738e-02, -1.9848e-02],
[-8.0915e-02, 1.1733e-02, 7.0872e-02, 4.2211e-02, 3.7455e-03],
[ 5.6451e-02, -2.0291e-02, 5.9699e-02, -3.8810e-02, 9.7062e-03]],
[[ 7.0948e-02, 7.7596e-02, -5.9511e-02, -2.7747e-03, -2.9197e-02],
[ 5.6304e-02, -5.9313e-02, -3.6894e-03, -3.4498e-02, -3.1743e-02],
[ 6.2984e-02, 7.1278e-02, 1.8568e-02, -8.1057e-02, -7.4301e-02],
[-1.7063e-02, 3.7341e-02, -1.7987e-02, -6.2014e-03, -1.3535e-02],
[ 3.3733e-02, 3.2608e-02, -1.8692e-02, 6.1727e-02, 1.0257e-02]],
[[ 4.3113e-04, -6.9241e-02, 2.2611e-02, 4.1913e-02, -6.6395e-02],
[ 7.5128e-03, -7.3346e-02, 8.0353e-02, 1.2347e-02, 5.5333e-02],
[-9.7800e-03, 4.5897e-02, 2.8835e-02, -3.6708e-02, 3.9655e-02],
[ 2.7716e-02, -7.1659e-02, 7.1108e-03, 1.1511e-02, -4.8559e-02],
[-3.0865e-02, 7.5560e-02, 2.8310e-02, 7.4005e-02, -5.0888e-02]],
[[ 7.5087e-02, 6.3344e-02, 5.9466e-02, 1.0437e-02, 9.3939e-03],
[-1.4452e-03, -5.0765e-02, -3.6996e-02, -6.8923e-02, -7.4329e-02],
[ 1.1036e-02, -2.6916e-02, -6.9722e-02, 5.9740e-02, 4.6108e-02],
[ 2.0379e-02, 3.6167e-02, 4.8153e-02, -3.0691e-02, -5.5250e-02],
[ 3.5924e-02, 4.5421e-02, -4.7335e-02, 6.4587e-02, -5.7064e-02]],
[[-1.6970e-03, -7.8021e-02, -6.0369e-02, -8.0641e-02, 7.1452e-02],
[ 1.6848e-02, -7.5881e-02, 2.5285e-02, 2.5364e-02, -1.0818e-02],
[-3.0854e-02, -2.4429e-02, -6.4815e-02, 8.1414e-03, -7.9674e-02],
[-6.2038e-02, 7.4582e-02, -1.7759e-02, 2.3795e-02, -1.5795e-02],
[ 3.7823e-02, -3.3319e-04, 7.1363e-03, 7.7572e-02, 4.3771e-02]]],
[[[-4.8656e-03, 4.3062e-02, -3.2547e-02, 9.7140e-03, -5.3167e-02],
[ 4.2759e-02, -4.1656e-02, 6.4357e-02, 3.5642e-02, -7.8376e-02],
[-1.2937e-02, 6.4533e-02, 1.5182e-02, 1.1444e-02, -7.4220e-02],
[ 6.3483e-02, -1.1542e-02, -4.0774e-02, -1.2172e-02, -2.7794e-02],
[-8.1438e-03, -5.5991e-02, -2.9966e-02, -8.0014e-03, -5.2937e-02]],
[[ 9.9251e-03, -2.7150e-02, -1.5934e-02, -3.4809e-02, 2.2487e-02],
[-2.9249e-03, 6.8871e-02, 4.3621e-03, 2.6227e-02, 4.3713e-02],
[ 8.1283e-02, -3.1387e-02, -6.9915e-02, -1.7858e-02, -2.1714e-02],
[-3.5359e-02, -1.3766e-02, 3.6173e-02, 9.1202e-03, -3.9747e-02],
[-7.2135e-02, -7.3420e-02, 6.0504e-02, 3.1594e-02, 7.6891e-02]],
[[ 7.2759e-02, -6.5420e-02, 6.7763e-02, 7.2741e-02, -7.4671e-02],
[-5.5163e-02, -7.5269e-02, 1.3287e-02, 1.8645e-02, -3.4054e-02],
[ 6.5525e-02, -4.1262e-03, 4.6500e-02, -6.6291e-02, 5.8884e-02],
[ 3.0486e-02, 3.6131e-03, 1.1222e-02, -3.3646e-02, -6.5889e-02],
[ 4.7762e-02, 3.6352e-02, 9.7470e-03, 7.7495e-03, -5.5064e-02]],
[[ 4.2110e-02, 5.1736e-02, 4.9755e-02, 1.8245e-02, 3.1093e-02],
[-5.8074e-02, -4.1158e-02, -5.9566e-03, 6.2394e-02, 1.6582e-02],
[-7.2003e-02, 1.4616e-02, -3.5987e-02, 3.0575e-02, -4.4705e-02],
[-4.7500e-02, -1.9091e-02, 1.2661e-02, 2.4751e-02, -7.1824e-02],
[ 4.3771e-02, 4.9023e-02, 7.2368e-02, -2.3195e-02, -3.0777e-02]],
[[-8.2526e-03, -1.3523e-02, -6.9580e-02, 2.5552e-02, -1.5779e-02],
[ 2.2318e-03, 2.7111e-02, -8.7496e-03, -2.3582e-02, -6.8521e-02],
[ 7.4568e-02, -4.6680e-02, 7.4333e-02, -6.5834e-02, 8.0266e-02],
[ 1.0070e-02, 5.4708e-02, -1.4732e-03, 1.9077e-02, -2.5033e-02],
[-2.7357e-02, 1.9236e-02, -4.7921e-02, -5.5013e-02, -7.4643e-02]],
[[ 4.0488e-02, -7.1390e-02, -2.3527e-02, 1.3764e-02, -1.5115e-02],
[-3.7438e-02, -7.9287e-02, -6.0580e-02, -3.2224e-02, -2.1884e-02],
[-7.0937e-02, -5.7632e-02, -1.2339e-02, 5.2566e-02, -5.8696e-02],
[-6.4373e-02, 5.0876e-02, -6.8186e-02, 6.8750e-02, 4.6615e-02],
[ 3.0661e-02, -6.8377e-02, 1.7900e-02, -8.8543e-03, 1.4958e-02]]]])
(bias): Normal:
loc: tensor([0., 0., -0., 0., 0., -0., 0., 0., -0., -0., 0., 0., 0., 0., 0., 0.])
scale: tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([ 0.0482, 0.0700, -0.0483, 0.0543, 0.0427, -0.0160, 0.0652, 0.0702,
-0.0693, -0.0685, 0.0211, 0.0540, 0.0356, 0.0235, 0.0208, 0.0305])
)
(observed): Observed()
)
(fc1): Linear(
in_features=400, out_features=120, bias=True
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
...,
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498]],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[-0.0222, 0.0364, -0.0196, ..., -0.0034, 0.0235, 0.0133],
[ 0.0133, 0.0200, -0.0305, ..., -0.0283, 0.0068, 0.0075],
[ 0.0367, -0.0009, 0.0018, ..., 0.0404, -0.0319, -0.0467],
...,
[-0.0015, -0.0323, 0.0344, ..., 0.0226, -0.0282, 0.0235],
[ 0.0195, 0.0296, 0.0472, ..., -0.0476, 0.0057, -0.0356],
[-0.0482, 0.0452, 0.0364, ..., 0.0029, 0.0322, 0.0217]],
requires_grad=True)
tensor: tensor([[-3.8304e-02, 5.6343e-02, -1.1556e-01, ..., 8.4700e-02,
-2.3077e-02, -9.6969e-02],
[-6.9593e-02, 4.2074e-02, 2.1176e-03, ..., -5.1170e-02,
-1.7429e-02, 6.8395e-02],
[-1.5431e-02, 4.9186e-03, -4.7218e-02, ..., 4.2745e-02,
-3.6945e-02, -8.5793e-02],
...,
[ 5.5533e-02, 2.4006e-02, 2.3173e-02, ..., 2.2089e-02,
-4.7982e-03, -3.0608e-02],
[-9.1586e-06, 6.9304e-02, 6.9180e-02, ..., -1.3455e-02,
4.2434e-02, -3.7071e-02],
[-2.3114e-02, 5.0826e-02, -7.1701e-02, ..., 5.6384e-02,
2.2557e-02, -4.0097e-02]], grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([-0.0030, -0.0106, -0.0147, 0.0311, -0.0320, 0.0124, -0.0280, 0.0140,
0.0088, -0.0196, 0.0488, -0.0001, 0.0003, 0.0401, -0.0014, 0.0427,
-0.0403, -0.0008, 0.0406, -0.0468, -0.0297, -0.0245, -0.0370, -0.0124,
0.0320, -0.0158, -0.0113, -0.0198, 0.0193, 0.0356, -0.0264, -0.0160,
0.0050, 0.0121, -0.0498, 0.0146, -0.0372, 0.0089, -0.0298, 0.0399,
0.0347, -0.0108, 0.0353, -0.0157, -0.0174, -0.0355, 0.0131, 0.0192,
0.0432, -0.0373, 0.0332, -0.0114, 0.0318, -0.0132, -0.0002, -0.0403,
0.0447, -0.0203, 0.0274, -0.0342, -0.0080, 0.0389, 0.0318, 0.0043,
0.0192, 0.0158, 0.0490, 0.0272, -0.0142, -0.0218, 0.0353, -0.0035,
0.0169, -0.0432, 0.0079, 0.0499, -0.0018, 0.0296, -0.0337, -0.0214,
0.0376, 0.0054, 0.0384, 0.0403, -0.0050, -0.0075, -0.0203, 0.0318,
0.0285, -0.0415, 0.0395, -0.0045, -0.0020, 0.0245, -0.0361, 0.0150,
0.0347, 0.0185, -0.0093, -0.0056, 0.0021, -0.0026, 0.0046, 0.0380,
-0.0403, -0.0422, -0.0218, -0.0198, -0.0446, 0.0296, 0.0276, 0.0089,
-0.0049, 0.0100, 0.0118, 0.0283, -0.0334, 0.0287, 0.0236, -0.0404],
requires_grad=True)
tensor: tensor([ 0.0523, -0.0292, 0.0273, 0.1130, -0.1221, -0.0301, -0.0207, -0.0004,
0.0262, -0.0813, 0.0444, -0.0165, -0.0252, 0.0806, -0.0251, -0.0477,
-0.0964, -0.0773, 0.0956, -0.1304, -0.0227, -0.0166, -0.0507, -0.0412,
0.0829, -0.0533, 0.0776, -0.0555, 0.1376, 0.0208, -0.0055, -0.0125,
-0.0041, -0.0186, 0.0247, 0.0902, -0.0040, -0.0262, -0.0154, 0.0425,
0.0448, -0.0340, 0.0756, -0.0699, -0.0605, -0.0363, 0.0278, 0.0412,
-0.0084, 0.0181, 0.0751, -0.0466, -0.0437, -0.0335, 0.0715, -0.0624,
0.0333, 0.0213, 0.0205, 0.0236, -0.0448, -0.0932, 0.0688, 0.0458,
0.0226, 0.0806, 0.0200, 0.0317, -0.0296, -0.0642, 0.0867, -0.1344,
0.0021, 0.0203, -0.0926, -0.0053, -0.0611, 0.0150, -0.0447, -0.0945,
0.0837, 0.0649, 0.0609, 0.0122, 0.0407, -0.0284, 0.0042, 0.0202,
0.0736, -0.0268, -0.0105, -0.0601, 0.0477, 0.0708, 0.0400, 0.0461,
-0.0125, -0.0138, -0.0739, 0.0362, 0.0398, -0.0847, -0.0054, 0.0625,
-0.1151, -0.0861, -0.1333, -0.1284, 0.0016, 0.0346, 0.0692, 0.0263,
-0.0565, -0.0479, -0.0471, 0.0167, -0.0569, 0.1117, 0.0053, -0.0830],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[-0., 0., -0., ..., -0., 0., 0.],
[0., 0., -0., ..., -0., 0., 0.],
[0., -0., 0., ..., 0., -0., -0.],
...,
[-0., -0., 0., ..., 0., -0., 0.],
[0., 0., 0., ..., -0., 0., -0.],
[-0., 0., 0., ..., 0., 0., 0.]])
scale: tensor([[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
...,
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[-0.0222, 0.0364, -0.0196, ..., -0.0034, 0.0235, 0.0133],
[ 0.0133, 0.0200, -0.0305, ..., -0.0283, 0.0068, 0.0075],
[ 0.0367, -0.0009, 0.0018, ..., 0.0404, -0.0319, -0.0467],
...,
[-0.0015, -0.0323, 0.0344, ..., 0.0226, -0.0282, 0.0235],
[ 0.0195, 0.0296, 0.0472, ..., -0.0476, 0.0057, -0.0356],
[-0.0482, 0.0452, 0.0364, ..., 0.0029, 0.0322, 0.0217]])
(bias): Normal:
loc: tensor([-0., -0., -0., 0., -0., 0., -0., 0., 0., -0., 0., -0., 0., 0., -0., 0., -0., -0., 0., -0., -0., -0., -0., -0.,
0., -0., -0., -0., 0., 0., -0., -0., 0., 0., -0., 0., -0., 0., -0., 0., 0., -0., 0., -0., -0., -0., 0., 0.,
0., -0., 0., -0., 0., -0., -0., -0., 0., -0., 0., -0., -0., 0., 0., 0., 0., 0., 0., 0., -0., -0., 0., -0.,
0., -0., 0., 0., -0., 0., -0., -0., 0., 0., 0., 0., -0., -0., -0., 0., 0., -0., 0., -0., -0., 0., -0., 0.,
0., 0., -0., -0., 0., -0., 0., 0., -0., -0., -0., -0., -0., 0., 0., 0., -0., 0., 0., 0., -0., 0., 0., -0.])
scale: tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([-0.0030, -0.0106, -0.0147, 0.0311, -0.0320, 0.0124, -0.0280, 0.0140,
0.0088, -0.0196, 0.0488, -0.0001, 0.0003, 0.0401, -0.0014, 0.0427,
-0.0403, -0.0008, 0.0406, -0.0468, -0.0297, -0.0245, -0.0370, -0.0124,
0.0320, -0.0158, -0.0113, -0.0198, 0.0193, 0.0356, -0.0264, -0.0160,
0.0050, 0.0121, -0.0498, 0.0146, -0.0372, 0.0089, -0.0298, 0.0399,
0.0347, -0.0108, 0.0353, -0.0157, -0.0174, -0.0355, 0.0131, 0.0192,
0.0432, -0.0373, 0.0332, -0.0114, 0.0318, -0.0132, -0.0002, -0.0403,
0.0447, -0.0203, 0.0274, -0.0342, -0.0080, 0.0389, 0.0318, 0.0043,
0.0192, 0.0158, 0.0490, 0.0272, -0.0142, -0.0218, 0.0353, -0.0035,
0.0169, -0.0432, 0.0079, 0.0499, -0.0018, 0.0296, -0.0337, -0.0214,
0.0376, 0.0054, 0.0384, 0.0403, -0.0050, -0.0075, -0.0203, 0.0318,
0.0285, -0.0415, 0.0395, -0.0045, -0.0020, 0.0245, -0.0361, 0.0150,
0.0347, 0.0185, -0.0093, -0.0056, 0.0021, -0.0026, 0.0046, 0.0380,
-0.0403, -0.0422, -0.0218, -0.0198, -0.0446, 0.0296, 0.0276, 0.0089,
-0.0049, 0.0100, 0.0118, 0.0283, -0.0334, 0.0287, 0.0236, -0.0404])
)
(observed): Observed()
)
(fc2): Linear(
in_features=120, out_features=2, bias=True
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498]], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[ 5.3861e-02, -3.2810e-02, 1.8512e-02, 4.5986e-02, 5.1913e-02,
-4.3799e-02, 5.6503e-03, 2.7175e-02, 6.8377e-02, 1.3057e-02,
-3.2690e-02, -6.1064e-02, 3.8826e-02, 6.6544e-02, 5.5398e-02,
3.0196e-02, 1.7987e-02, 5.5327e-02, -5.0021e-02, -7.3435e-02,
6.4632e-02, -5.7208e-02, -4.0015e-02, -8.9137e-02, -2.3537e-02,
-7.9152e-03, 2.1925e-02, 4.2662e-02, 2.8947e-06, -8.5279e-02,
-7.4841e-03, 9.8418e-03, -7.7589e-02, 8.2463e-02, 4.1143e-02,
3.5816e-03, 4.9777e-02, -4.8500e-02, 5.3974e-02, -6.0017e-02,
6.7712e-02, 2.5674e-02, 6.1376e-02, 4.4900e-03, -6.7039e-02,
-7.8965e-02, -6.7107e-02, -2.3485e-02, -1.5222e-02, -8.7112e-03,
-7.6909e-02, 7.4443e-02, -2.8730e-02, 5.1316e-02, 1.8881e-04,
-3.4971e-02, -7.5644e-02, -1.2047e-02, -5.6335e-02, -3.9648e-02,
-5.9713e-03, -5.5715e-02, 5.2674e-02, -1.8986e-02, 1.1715e-02,
7.8946e-02, 8.7757e-02, 7.8873e-02, -4.0747e-02, -9.0934e-02,
8.4907e-02, 6.0053e-02, 6.8620e-02, -6.7833e-02, 4.9746e-02,
-9.3147e-03, 4.5588e-02, 2.6278e-02, 7.8289e-02, 6.9648e-02,
-1.8778e-02, 7.0376e-02, 4.2418e-02, 8.1578e-02, -2.5462e-03,
-8.2635e-02, -3.2393e-02, -1.8944e-02, 5.2082e-02, 7.6844e-02,
7.8650e-02, 1.0107e-02, 5.8002e-02, 5.6146e-02, 1.0183e-02,
8.1826e-02, 2.2654e-02, 1.4227e-02, 6.7762e-02, -1.4747e-02,
-1.7642e-02, -6.6754e-02, -3.3121e-02, 8.3696e-02, -1.6725e-02,
5.1801e-02, 8.2761e-02, -1.6347e-03, -7.2732e-02, -8.5545e-02,
-2.1219e-02, -8.6543e-02, 1.6206e-02, -3.9126e-02, 5.2650e-02,
8.4007e-02, 6.5445e-03, -2.2124e-02, -2.6831e-03, 9.0644e-02],
[-7.6951e-02, -7.7030e-02, -2.8385e-02, 8.8830e-02, 1.8159e-02,
1.6277e-02, -8.6646e-02, -4.8097e-03, -7.2484e-02, -2.3892e-02,
-3.8862e-02, 8.1929e-02, 9.8960e-03, -7.4610e-02, 4.4434e-02,
9.1217e-02, 1.7353e-02, -4.1310e-02, 4.6109e-03, -6.0577e-03,
-5.4844e-02, -8.3153e-02, -7.1516e-02, -3.4588e-03, -8.7724e-02,
-2.7643e-02, 2.0528e-02, -3.5310e-02, -7.1740e-02, 9.2993e-03,
-4.4849e-02, -2.9480e-02, 8.8236e-02, 6.5175e-03, -1.8974e-02,
-6.3304e-02, -4.7579e-02, -7.5561e-02, 6.3819e-03, 6.4816e-02,
-5.2397e-03, 4.9444e-02, -7.8613e-02, 3.2730e-02, -4.0234e-02,
7.0352e-02, 9.4854e-03, -6.6504e-02, -5.6702e-02, 8.2212e-02,
-1.0986e-03, -7.2008e-02, -2.7400e-02, 2.1016e-02, 3.5173e-02,
-7.9398e-02, -4.5183e-02, -6.5127e-02, -3.0206e-02, 3.6873e-02,
-4.1272e-02, -5.1494e-03, -7.6218e-03, 6.9450e-02, 3.0325e-02,
1.2730e-02, -6.4263e-02, 8.1150e-02, -7.4637e-02, 4.3607e-02,
-5.5105e-02, 7.7573e-02, 7.7680e-02, 4.4945e-02, -7.6075e-02,
3.1183e-02, 4.1456e-02, -5.6365e-02, -6.4912e-02, 1.1328e-02,
-6.9009e-02, 2.9723e-02, 3.2704e-02, 1.9641e-02, -2.5972e-02,
-4.7365e-02, -6.7854e-02, 3.1308e-02, -4.7928e-02, -3.9339e-02,
1.7182e-03, -7.1567e-02, -7.8225e-02, 8.7315e-02, 4.9131e-02,
-8.9911e-02, -4.4713e-03, -3.8737e-02, -7.6371e-02, 3.4342e-02,
-1.7424e-02, 2.5424e-02, -7.5859e-02, 8.9858e-02, -6.9163e-02,
-7.2367e-02, 8.4212e-02, 5.0855e-02, 1.1414e-03, 7.7345e-02,
5.0493e-02, 7.5050e-02, -2.1079e-02, 7.8312e-02, 7.6594e-03,
4.9073e-02, 6.5474e-02, 2.7924e-02, -1.4094e-02, 4.2029e-02]],
requires_grad=True)
tensor: tensor([[ 0.0610, -0.0128, -0.0120, 0.0714, 0.0156, 0.0183, 0.0681, -0.0840,
0.1129, 0.0066, -0.0570, -0.0925, 0.0874, 0.0386, 0.0270, 0.0112,
0.0199, 0.0883, -0.0757, -0.0829, 0.0854, -0.1206, -0.0577, -0.0897,
0.0741, -0.0573, 0.0606, -0.0012, 0.0103, -0.1070, -0.0114, 0.1001,
-0.0395, -0.0006, 0.0756, 0.0616, 0.0758, -0.0650, 0.0724, -0.0838,
0.1125, 0.0607, 0.0729, -0.0289, -0.1317, -0.1161, -0.0362, 0.0083,
-0.0673, 0.0753, -0.0666, 0.0700, -0.0002, 0.0904, 0.0204, -0.0528,
-0.1112, 0.0246, -0.0508, -0.0364, -0.0725, -0.0273, 0.0819, 0.0152,
-0.0765, 0.0628, 0.0228, -0.0045, 0.0116, -0.1370, 0.0286, 0.0078,
0.0368, -0.1036, -0.0161, -0.0211, 0.0430, -0.0009, 0.0646, 0.0536,
-0.0243, 0.1421, -0.0034, 0.0893, -0.0097, -0.1771, 0.0357, -0.0548,
0.0414, 0.0090, 0.0756, -0.0161, 0.0447, 0.0344, 0.0186, 0.1482,
-0.0190, 0.0250, 0.0333, -0.1116, -0.0008, -0.0638, 0.0016, 0.1462,
0.0523, -0.0028, 0.1291, 0.0462, -0.0976, -0.0499, -0.0413, -0.0380,
0.0625, 0.0406, -0.0037, 0.0522, -0.0341, -0.0405, -0.0160, 0.1534],
[-0.0298, -0.1157, -0.0818, 0.0838, -0.0245, 0.0392, -0.1349, -0.0704,
-0.1346, -0.0168, 0.0534, 0.0690, -0.0314, -0.0291, 0.0875, 0.1339,
0.0260, -0.1362, 0.0432, 0.0740, -0.0278, -0.0642, -0.1747, 0.0349,
-0.0983, 0.0362, -0.0348, -0.1002, -0.0262, 0.0045, -0.1418, -0.0607,
0.0709, -0.0468, -0.0267, -0.0706, -0.0858, -0.0905, 0.0190, 0.0397,
-0.0106, -0.0135, -0.0531, -0.0055, -0.0194, 0.0866, 0.0335, -0.0849,
-0.0636, -0.0231, -0.0386, -0.0960, -0.0768, 0.0532, 0.0323, -0.0124,
-0.0476, -0.0436, -0.0623, 0.0220, -0.0545, 0.0014, -0.0546, 0.1407,
-0.0416, -0.0167, -0.0326, 0.0189, -0.0989, 0.0821, -0.0453, 0.1176,
0.0483, 0.1579, -0.1453, -0.0338, 0.1282, -0.0891, -0.0673, -0.0274,
-0.1313, 0.0243, -0.0285, -0.0651, -0.0463, 0.0518, -0.0594, -0.0117,
-0.0808, 0.0316, 0.0344, -0.0550, -0.0490, 0.0915, 0.0668, -0.1171,
0.0758, -0.0164, -0.1157, 0.1160, 0.0045, 0.0031, -0.1212, 0.0278,
-0.0541, -0.1058, 0.0086, 0.0400, 0.0132, 0.1076, 0.1130, 0.0112,
0.0434, 0.0284, 0.0447, 0.0700, 0.0243, -0.0166, -0.0543, 0.0415]],
grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([-0.0687, -0.0438], requires_grad=True)
tensor: tensor([-0.0407, 0.0235], grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[0., -0., 0., 0., 0., -0., 0., 0., 0., 0., -0., -0., 0., 0., 0., 0., 0., 0., -0., -0., 0., -0., -0., -0.,
-0., -0., 0., 0., 0., -0., -0., 0., -0., 0., 0., 0., 0., -0., 0., -0., 0., 0., 0., 0., -0., -0., -0., -0.,
-0., -0., -0., 0., -0., 0., 0., -0., -0., -0., -0., -0., -0., -0., 0., -0., 0., 0., 0., 0., -0., -0., 0., 0.,
0., -0., 0., -0., 0., 0., 0., 0., -0., 0., 0., 0., -0., -0., -0., -0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., -0., -0., -0., -0., 0., -0., 0., 0., -0., -0., -0., -0., -0., 0., -0., 0., 0., 0., -0., -0., 0.],
[-0., -0., -0., 0., 0., 0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., 0., -0., 0., -0., -0., -0., -0., -0.,
-0., -0., 0., -0., -0., 0., -0., -0., 0., 0., -0., -0., -0., -0., 0., 0., -0., 0., -0., 0., -0., 0., 0., -0.,
-0., 0., -0., -0., -0., 0., 0., -0., -0., -0., -0., 0., -0., -0., -0., 0., 0., 0., -0., 0., -0., 0., -0., 0.,
0., 0., -0., 0., 0., -0., -0., 0., -0., 0., 0., 0., -0., -0., -0., 0., -0., -0., 0., -0., -0., 0., 0., -0.,
-0., -0., -0., 0., -0., 0., -0., 0., -0., -0., 0., 0., 0., 0., 0., 0., -0., 0., 0., 0., 0., 0., -0., 0.]])
scale: tensor([[0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913],
[0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[ 5.3861e-02, -3.2810e-02, 1.8512e-02, 4.5986e-02, 5.1913e-02,
-4.3799e-02, 5.6503e-03, 2.7175e-02, 6.8377e-02, 1.3057e-02,
-3.2690e-02, -6.1064e-02, 3.8826e-02, 6.6544e-02, 5.5398e-02,
3.0196e-02, 1.7987e-02, 5.5327e-02, -5.0021e-02, -7.3435e-02,
6.4632e-02, -5.7208e-02, -4.0015e-02, -8.9137e-02, -2.3537e-02,
-7.9152e-03, 2.1925e-02, 4.2662e-02, 2.8947e-06, -8.5279e-02,
-7.4841e-03, 9.8418e-03, -7.7589e-02, 8.2463e-02, 4.1143e-02,
3.5816e-03, 4.9777e-02, -4.8500e-02, 5.3974e-02, -6.0017e-02,
6.7712e-02, 2.5674e-02, 6.1376e-02, 4.4900e-03, -6.7039e-02,
-7.8965e-02, -6.7107e-02, -2.3485e-02, -1.5222e-02, -8.7112e-03,
-7.6909e-02, 7.4443e-02, -2.8730e-02, 5.1316e-02, 1.8881e-04,
-3.4971e-02, -7.5644e-02, -1.2047e-02, -5.6335e-02, -3.9648e-02,
-5.9713e-03, -5.5715e-02, 5.2674e-02, -1.8986e-02, 1.1715e-02,
7.8946e-02, 8.7757e-02, 7.8873e-02, -4.0747e-02, -9.0934e-02,
8.4907e-02, 6.0053e-02, 6.8620e-02, -6.7833e-02, 4.9746e-02,
-9.3147e-03, 4.5588e-02, 2.6278e-02, 7.8289e-02, 6.9648e-02,
-1.8778e-02, 7.0376e-02, 4.2418e-02, 8.1578e-02, -2.5462e-03,
-8.2635e-02, -3.2393e-02, -1.8944e-02, 5.2082e-02, 7.6844e-02,
7.8650e-02, 1.0107e-02, 5.8002e-02, 5.6146e-02, 1.0183e-02,
8.1826e-02, 2.2654e-02, 1.4227e-02, 6.7762e-02, -1.4747e-02,
-1.7642e-02, -6.6754e-02, -3.3121e-02, 8.3696e-02, -1.6725e-02,
5.1801e-02, 8.2761e-02, -1.6347e-03, -7.2732e-02, -8.5545e-02,
-2.1219e-02, -8.6543e-02, 1.6206e-02, -3.9126e-02, 5.2650e-02,
8.4007e-02, 6.5445e-03, -2.2124e-02, -2.6831e-03, 9.0644e-02],
[-7.6951e-02, -7.7030e-02, -2.8385e-02, 8.8830e-02, 1.8159e-02,
1.6277e-02, -8.6646e-02, -4.8097e-03, -7.2484e-02, -2.3892e-02,
-3.8862e-02, 8.1929e-02, 9.8960e-03, -7.4610e-02, 4.4434e-02,
9.1217e-02, 1.7353e-02, -4.1310e-02, 4.6109e-03, -6.0577e-03,
-5.4844e-02, -8.3153e-02, -7.1516e-02, -3.4588e-03, -8.7724e-02,
-2.7643e-02, 2.0528e-02, -3.5310e-02, -7.1740e-02, 9.2993e-03,
-4.4849e-02, -2.9480e-02, 8.8236e-02, 6.5175e-03, -1.8974e-02,
-6.3304e-02, -4.7579e-02, -7.5561e-02, 6.3819e-03, 6.4816e-02,
-5.2397e-03, 4.9444e-02, -7.8613e-02, 3.2730e-02, -4.0234e-02,
7.0352e-02, 9.4854e-03, -6.6504e-02, -5.6702e-02, 8.2212e-02,
-1.0986e-03, -7.2008e-02, -2.7400e-02, 2.1016e-02, 3.5173e-02,
-7.9398e-02, -4.5183e-02, -6.5127e-02, -3.0206e-02, 3.6873e-02,
-4.1272e-02, -5.1494e-03, -7.6218e-03, 6.9450e-02, 3.0325e-02,
1.2730e-02, -6.4263e-02, 8.1150e-02, -7.4637e-02, 4.3607e-02,
-5.5105e-02, 7.7573e-02, 7.7680e-02, 4.4945e-02, -7.6075e-02,
3.1183e-02, 4.1456e-02, -5.6365e-02, -6.4912e-02, 1.1328e-02,
-6.9009e-02, 2.9723e-02, 3.2704e-02, 1.9641e-02, -2.5972e-02,
-4.7365e-02, -6.7854e-02, 3.1308e-02, -4.7928e-02, -3.9339e-02,
1.7182e-03, -7.1567e-02, -7.8225e-02, 8.7315e-02, 4.9131e-02,
-8.9911e-02, -4.4713e-03, -3.8737e-02, -7.6371e-02, 3.4342e-02,
-1.7424e-02, 2.5424e-02, -7.5859e-02, 8.9858e-02, -6.9163e-02,
-7.2367e-02, 8.4212e-02, 5.0855e-02, 1.1414e-03, 7.7345e-02,
5.0493e-02, 7.5050e-02, -2.1079e-02, 7.8312e-02, 7.6594e-03,
4.9073e-02, 6.5474e-02, 2.7924e-02, -1.4094e-02, 4.2029e-02]])
(bias): Normal:
loc: tensor([-0., -0.])
scale: tensor([0.7071, 0.7071])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([-0.0687, -0.0438])
)
(observed): Observed()
)
)
Fit the model¶
Finally, we can set up the training loop:
optim = torch.optim.Adam(net.parameters())
# Only one epoch: this example deliberately does not train to convergence.
num_epochs = 1
for _ in range(num_epochs):
    # Use `images` instead of `data` so the dataset tensor created earlier
    # is not silently rebound by the loop.
    for images, target in loader:
        # Condition the `classification` likelihood on the true labels.
        net.observe(classification=target)
        # NOTE(review): presumably draws fresh samples of the network's random
        # variables before the forward pass — confirm in borch docs.
        borch.sample(net)
        net(images)
        # The KL term is scaled by 1 / (number of batches per epoch).
        loss = infer.vi_loss(**borch.pq_to_infer(net), kl_scaling=1 / len(loader))
        loss.backward()
        optim.step()
        optim.zero_grad()
Now we can check the accuracy. Note that before evaluating, one should stop conditioning on the target by calling net.observe(None).
# Stop conditioning on the labels before measuring predictions.
net.observe(None)
correct = 0
total = 0
with torch.no_grad():
    # `images` instead of `data` so the dataset tensor defined earlier is not rebound.
    for images, target in loader:
        borch.sample(net)
        out = net(images)
        # NOTE(review): assumes net(...) returns predicted class labels that can
        # be compared element-wise with `target` — confirm against Net.forward.
        correct += int((target == out).sum())
        total += target.shape[0]
# Weight by sample count rather than averaging per-batch accuracies, so a
# smaller final batch is not over-counted. Report NaN instead of crashing
# when the loader yields no batches.
tot_acc = 100.0 * correct / total if total else float("nan")
print(tot_acc)
Out:
50.0
The accuracy is essentially random; this is expected, since we are fitting white noise.
If you have trouble reaching a higher accuracy on real data, consider training for more epochs, setting up an augmentation pipeline (see the data loading tutorial), and changing the posterior. The posterior can be changed using:
# Apply a posterior to every submodule of `net`. Replace
# borch.posterior.Automatic with the posterior class you want to try;
# Automatic is used here only as an example.
net.apply(borch.set_posteriors(borch.posterior.Automatic))
Out:
Net(
(posterior): Automatic()
(prior): Module(
(classification): Categorical:
logits: tensor([[ 0.3925, -0.3818],
[ 0.2547, -0.3470],
[ 0.3052, -0.3358],
[ 0.3430, -0.1107],
[ 0.2112, -0.4776],
[ 0.4108, -0.1150],
[ 0.5736, -0.1190],
[ 0.5299, -0.1948],
[ 0.4381, -0.1530],
[ 0.4474, -0.2609],
[ 0.1968, -0.2226],
[ 0.3060, -0.2168],
[ 0.5697, -0.3060],
[ 0.4261, -0.2987],
[ 0.3829, -0.3391],
[ 0.3738, -0.3848],
[ 0.2025, -0.4219],
[ 0.1772, -0.2161],
[ 0.3705, -0.2277],
[ 0.3645, -0.5349]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([])
)
(observed): Observed()
(conv1): Conv2d(
1, 6, kernel_size=(5, 5), stride=(1, 1)
(posterior): Automatic(
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([-0., -0., -0., 0., 0., -0.], requires_grad=True)
tensor: tensor([ 0.3744, 0.0923, 0.0835, -0.4157, -0.3875, 0.3015],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[[[-0., 0., 0., -0., -0.],
[-0., -0., -0., -0., -0.],
[-0., 0., -0., 0., 0.],
[-0., 0., 0., -0., -0.],
[-0., 0., -0., -0., -0.]]],
[[[-0., 0., -0., -0., -0.],
[0., -0., 0., -0., 0.],
[0., -0., -0., 0., 0.],
[0., 0., 0., 0., -0.],
[-0., -0., -0., -0., 0.]]],
[[[0., 0., -0., 0., -0.],
[-0., -0., -0., -0., -0.],
[0., -0., -0., 0., 0.],
[0., 0., -0., -0., 0.],
[0., -0., -0., -0., -0.]]],
[[[0., 0., -0., -0., -0.],
[-0., 0., 0., 0., -0.],
[-0., -0., 0., 0., -0.],
[-0., 0., 0., -0., -0.],
[-0., 0., -0., -0., 0.]]],
[[[0., 0., 0., 0., 0.],
[0., -0., -0., -0., 0.],
[-0., -0., 0., -0., 0.],
[-0., -0., 0., 0., -0.],
[0., 0., 0., 0., -0.]]],
[[[-0., 0., 0., 0., -0.],
[-0., -0., -0., 0., -0.],
[-0., -0., -0., -0., 0.],
[-0., 0., 0., -0., 0.],
[-0., 0., 0., -0., -0.]]]])
scale: tensor([[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[[[-0.1567, 0.0194, 0.0769, -0.0673, -0.0263],
[-0.1055, -0.0411, -0.0689, -0.0195, -0.0636],
[-0.1941, 0.1049, -0.0666, 0.1288, 0.1553],
[-0.0577, 0.0011, 0.0145, -0.0518, -0.1235],
[-0.1586, 0.1436, -0.1999, -0.0364, -0.0872]]],
[[[-0.0281, 0.1596, -0.0365, -0.0633, -0.1375],
[ 0.0499, -0.0787, 0.1923, -0.1312, 0.0520],
[ 0.0351, -0.0756, -0.0147, 0.1359, 0.1740],
[ 0.0815, 0.0631, 0.0970, 0.0557, -0.1079],
[-0.0794, -0.0029, -0.1620, -0.1147, 0.0705]]],
[[[ 0.1613, 0.0524, -0.1935, 0.0940, -0.0838],
[-0.1790, -0.0222, -0.0098, -0.1600, -0.1524],
[ 0.1950, -0.1447, -0.1897, 0.1429, 0.0565],
[ 0.1752, 0.0891, -0.0210, -0.0461, 0.1512],
[ 0.1282, -0.1387, -0.1025, -0.1408, -0.0373]]],
[[[ 0.1134, 0.0826, -0.1147, -0.1228, -0.0040],
[-0.1923, 0.1244, 0.1793, 0.0183, -0.0433],
[-0.1840, -0.1691, 0.1184, 0.0151, -0.1467],
[-0.1828, 0.0363, 0.1660, -0.0660, -0.1319],
[-0.1644, 0.1835, -0.0681, -0.0800, 0.0668]]],
[[[ 0.0005, 0.1977, 0.1792, 0.0681, 0.1410],
[ 0.0081, -0.0216, -0.0456, -0.1985, 0.0049],
[-0.0568, -0.1405, 0.0604, -0.1020, 0.0084],
[-0.1587, -0.1756, 0.1148, 0.0872, -0.1307],
[ 0.1244, 0.0193, 0.0314, 0.0787, -0.1329]]],
[[[-0.1363, 0.0538, 0.0294, 0.0635, -0.0873],
[-0.0751, -0.0357, -0.0701, 0.0053, -0.1896],
[-0.0187, -0.0438, -0.0541, -0.1959, 0.0103],
[-0.1991, 0.0912, 0.1853, -0.0772, 0.0883],
[-0.1368, 0.0420, 0.0322, -0.1162, -0.0013]]]])
(bias): Normal:
loc: tensor([-0., -0., -0., 0., 0., -0.])
scale: tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([-0.1428, -0.1690, -0.0722, 0.0667, 0.0933, -0.0337])
)
(observed): Observed()
)
(conv2): Conv2d(
6, 16, kernel_size=(5, 5), stride=(1, 1)
(posterior): Automatic(
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([0., 0., -0., 0., 0., -0., 0., 0., -0., -0., 0., 0., 0., 0., 0., 0.],
requires_grad=True)
tensor: tensor([ 0.3320, 0.1212, -0.0444, 0.3015, 0.2103, -0.0697, -0.1125, 0.1532,
-0.1150, 0.0043, -0.0299, -0.2236, 0.1937, 0.1273, -0.6395, -0.3135],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[[[-0., -0., -0., 0., 0.],
[0., 0., 0., 0., -0.],
[-0., 0., -0., 0., -0.],
[0., 0., 0., 0., 0.],
[-0., 0., -0., -0., -0.]],
[[0., -0., 0., 0., 0.],
[0., -0., 0., 0., -0.],
[0., -0., 0., 0., -0.],
[-0., 0., -0., 0., -0.],
[0., -0., 0., -0., 0.]],
[[-0., 0., 0., 0., -0.],
[-0., -0., -0., -0., 0.],
[0., 0., -0., 0., -0.],
[0., -0., 0., 0., -0.],
[-0., 0., -0., 0., -0.]],
[[-0., 0., 0., 0., 0.],
[0., -0., -0., -0., 0.],
[-0., -0., -0., -0., -0.],
[0., -0., -0., 0., -0.],
[0., 0., 0., 0., 0.]],
[[-0., 0., -0., -0., -0.],
[-0., -0., -0., 0., -0.],
[0., -0., -0., -0., 0.],
[-0., 0., 0., -0., -0.],
[-0., 0., 0., -0., -0.]],
[[-0., -0., 0., 0., -0.],
[-0., 0., 0., 0., 0.],
[-0., 0., -0., 0., -0.],
[0., 0., 0., -0., 0.],
[-0., -0., 0., -0., -0.]]],
[[[-0., 0., -0., -0., -0.],
[-0., 0., 0., -0., -0.],
[-0., -0., 0., -0., -0.],
[0., 0., 0., -0., -0.],
[0., 0., 0., -0., -0.]],
[[-0., -0., -0., 0., -0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., -0.],
[0., 0., -0., 0., -0.],
[-0., -0., -0., -0., 0.]],
[[-0., -0., -0., 0., 0.],
[-0., 0., 0., 0., 0.],
[-0., 0., 0., -0., 0.],
[0., 0., 0., -0., -0.],
[0., -0., -0., -0., -0.]],
[[-0., -0., -0., -0., -0.],
[-0., 0., 0., 0., -0.],
[0., 0., -0., -0., -0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., -0., -0.]],
[[-0., -0., -0., -0., -0.],
[-0., 0., 0., 0., 0.],
[-0., 0., -0., -0., 0.],
[-0., 0., -0., -0., -0.],
[-0., -0., -0., 0., 0.]],
[[0., 0., -0., -0., -0.],
[0., -0., 0., -0., -0.],
[0., 0., 0., 0., -0.],
[-0., -0., 0., 0., -0.],
[0., -0., -0., 0., -0.]]],
[[[0., 0., -0., 0., 0.],
[0., -0., 0., -0., -0.],
[-0., -0., -0., -0., -0.],
[-0., -0., 0., 0., 0.],
[0., 0., -0., 0., 0.]],
[[-0., 0., 0., -0., -0.],
[0., -0., 0., -0., -0.],
[-0., 0., -0., -0., -0.],
[0., 0., 0., -0., 0.],
[-0., 0., 0., -0., -0.]],
[[0., 0., -0., 0., 0.],
[0., -0., 0., 0., 0.],
[-0., -0., -0., 0., -0.],
[-0., -0., 0., 0., 0.],
[0., -0., -0., -0., 0.]],
[[-0., 0., 0., 0., 0.],
[-0., 0., 0., -0., -0.],
[-0., -0., -0., -0., -0.],
[-0., 0., -0., -0., -0.],
[-0., 0., 0., 0., -0.]],
[[-0., -0., -0., -0., 0.],
[0., 0., 0., -0., -0.],
[0., 0., -0., -0., 0.],
[-0., 0., 0., 0., 0.],
[-0., 0., 0., 0., -0.]],
[[-0., -0., -0., 0., -0.],
[0., -0., -0., 0., -0.],
[0., -0., -0., 0., 0.],
[0., 0., -0., 0., 0.],
[0., -0., 0., 0., -0.]]],
...,
[[[0., -0., -0., -0., 0.],
[-0., 0., -0., 0., 0.],
[-0., -0., -0., -0., -0.],
[-0., -0., 0., -0., -0.],
[-0., -0., 0., 0., -0.]],
[[0., -0., -0., -0., 0.],
[-0., 0., 0., 0., 0.],
[-0., -0., 0., 0., 0.],
[-0., -0., -0., 0., -0.],
[0., 0., 0., 0., -0.]],
[[0., 0., -0., 0., -0.],
[0., 0., -0., 0., 0.],
[-0., 0., -0., 0., 0.],
[0., -0., -0., 0., 0.],
[-0., 0., -0., -0., 0.]],
[[0., 0., 0., 0., -0.],
[-0., -0., -0., -0., 0.],
[-0., 0., 0., -0., -0.],
[0., -0., 0., -0., 0.],
[-0., -0., -0., -0., 0.]],
[[-0., -0., -0., -0., 0.],
[0., -0., 0., -0., 0.],
[0., -0., 0., 0., 0.],
[-0., -0., 0., -0., -0.],
[0., -0., -0., 0., -0.]],
[[0., -0., -0., 0., -0.],
[0., 0., -0., 0., 0.],
[0., 0., 0., 0., 0.],
[-0., -0., -0., 0., -0.],
[-0., -0., 0., 0., 0.]]],
[[[-0., -0., -0., -0., 0.],
[0., -0., -0., 0., -0.],
[-0., 0., 0., -0., -0.],
[0., -0., 0., -0., -0.],
[-0., -0., 0., -0., 0.]],
[[0., -0., 0., 0., 0.],
[-0., 0., 0., -0., -0.],
[-0., -0., -0., -0., -0.],
[-0., 0., 0., 0., 0.],
[0., -0., 0., -0., 0.]],
[[0., 0., -0., -0., -0.],
[0., -0., -0., -0., -0.],
[0., 0., 0., -0., -0.],
[-0., 0., -0., -0., -0.],
[0., 0., -0., 0., 0.]],
[[0., -0., 0., 0., -0.],
[0., -0., 0., 0., 0.],
[-0., 0., 0., -0., 0.],
[0., -0., 0., 0., -0.],
[-0., 0., 0., 0., -0.]],
[[0., 0., 0., 0., 0.],
[-0., -0., -0., -0., -0.],
[0., -0., -0., 0., 0.],
[0., 0., 0., -0., -0.],
[0., 0., -0., 0., -0.]],
[[-0., -0., -0., -0., 0.],
[0., -0., 0., 0., -0.],
[-0., -0., -0., 0., -0.],
[-0., 0., -0., 0., -0.],
[0., -0., 0., 0., 0.]]],
[[[-0., 0., -0., 0., -0.],
[0., -0., 0., 0., -0.],
[-0., 0., 0., 0., -0.],
[0., -0., -0., -0., -0.],
[-0., -0., -0., -0., -0.]],
[[0., -0., -0., -0., 0.],
[-0., 0., 0., 0., 0.],
[0., -0., -0., -0., -0.],
[-0., -0., 0., 0., -0.],
[-0., -0., 0., 0., 0.]],
[[0., -0., 0., 0., -0.],
[-0., -0., 0., 0., -0.],
[0., -0., 0., -0., 0.],
[0., 0., 0., -0., -0.],
[0., 0., 0., 0., -0.]],
[[0., 0., 0., 0., 0.],
[-0., -0., -0., 0., 0.],
[-0., 0., -0., 0., -0.],
[-0., -0., 0., 0., -0.],
[0., 0., 0., -0., -0.]],
[[-0., -0., -0., 0., -0.],
[0., 0., -0., -0., -0.],
[0., -0., 0., -0., 0.],
[0., 0., -0., 0., -0.],
[-0., 0., -0., -0., -0.]],
[[0., -0., -0., 0., -0.],
[-0., -0., -0., -0., -0.],
[-0., -0., -0., 0., -0.],
[-0., 0., -0., 0., 0.],
[0., -0., 0., -0., 0.]]]])
scale: tensor([[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
...,
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[[[-1.6434e-02, -1.3113e-02, -4.1464e-02, 1.6377e-02, 6.2508e-02],
[ 4.0202e-02, 4.9382e-02, 1.0253e-02, 3.4932e-02, -2.9341e-02],
[-6.4074e-02, 5.2063e-02, -3.8546e-02, 1.5865e-02, -4.8314e-02],
[ 2.0698e-02, 8.0888e-02, 7.9982e-02, 5.6686e-02, 1.1419e-02],
[-6.1541e-02, 6.8319e-02, -7.6261e-02, -2.5828e-03, -5.7380e-02]],
[[ 8.0672e-02, -7.8271e-02, 7.9289e-02, 2.1951e-02, 5.9238e-02],
[ 1.1237e-02, -5.2598e-02, 4.1551e-02, 4.3546e-02, -5.8132e-02],
[ 3.0288e-02, -7.5402e-03, 2.8025e-02, 1.5065e-02, -3.7260e-02],
[-2.1855e-02, 4.4424e-02, -6.7250e-02, 4.3303e-02, -4.0667e-02],
[ 4.9321e-02, -2.0541e-02, 5.6232e-02, -3.8466e-02, 5.1026e-02]],
[[-2.6513e-02, 4.9569e-02, 1.8676e-02, 6.9054e-02, -2.6767e-02],
[-5.7368e-02, -7.8838e-03, -7.9114e-02, -1.8557e-02, 2.8787e-04],
[ 1.3084e-02, 4.6926e-02, -1.3894e-02, 2.7116e-02, -7.9155e-02],
[ 7.5190e-02, -2.2396e-02, 5.5935e-02, 6.8270e-02, -3.2401e-02],
[-1.0770e-02, 3.6370e-02, -7.0476e-02, 6.4846e-02, -6.7540e-02]],
[[-3.7310e-02, 5.9737e-02, 5.9160e-05, 3.5063e-02, 7.1916e-02],
[ 1.7257e-02, -6.6353e-03, -6.3498e-03, -1.9595e-02, 4.9856e-02],
[-8.0101e-02, -4.1773e-02, -1.9761e-02, -1.6521e-02, -3.6788e-02],
[ 9.3012e-04, -2.4376e-02, -1.1352e-03, 1.1929e-02, -1.3466e-02],
[ 5.6119e-02, 2.2037e-02, 6.6074e-02, 3.7063e-02, 4.2302e-02]],
[[-3.6769e-02, 3.7521e-02, -9.7236e-03, -6.8646e-02, -3.6374e-02],
[-2.8501e-02, -8.0159e-02, -2.8122e-02, 1.8005e-02, -3.7241e-02],
[ 4.2443e-02, -4.8021e-02, -3.0217e-02, -5.2788e-02, 2.2794e-02],
[-6.3064e-02, 6.8354e-02, 4.3337e-02, -4.3442e-02, -6.5524e-02],
[-5.6791e-02, 7.4918e-02, 5.9263e-02, -1.2248e-04, -3.5795e-02]],
[[-2.1802e-02, -9.4069e-03, 7.6862e-02, 1.4458e-02, -7.0315e-02],
[-5.9874e-02, 5.0751e-02, 3.9215e-02, 7.7538e-02, 3.2766e-02],
[-4.4479e-02, 5.2775e-02, -1.9053e-03, 3.2342e-02, -4.6321e-02],
[ 3.8624e-02, 2.0750e-02, 4.6860e-02, -5.3962e-02, 6.4275e-02],
[-6.5864e-02, -1.3428e-02, 1.7864e-02, -7.0872e-02, -4.7788e-02]]],
[[[-6.8434e-03, 1.4991e-02, -4.0166e-02, -4.2922e-02, -3.7892e-02],
[-4.6419e-03, 7.9605e-02, 2.4481e-02, -1.8153e-02, -6.5677e-02],
[-4.4087e-02, -3.3966e-02, 7.7397e-02, -2.4133e-02, -2.1281e-02],
[ 7.2108e-02, 5.7730e-02, 1.8671e-02, -6.6956e-02, -7.8317e-02],
[ 5.1377e-02, 4.0833e-02, 1.3915e-02, -7.3671e-02, -6.4171e-02]],
[[-2.2992e-02, -1.5135e-02, -1.6468e-02, 3.4729e-02, -5.0686e-02],
[ 3.3976e-02, 4.7649e-02, 9.0078e-03, 6.1119e-02, 7.7885e-02],
[ 2.0299e-02, 4.7323e-02, 5.7702e-02, 4.8393e-02, -4.6623e-02],
[ 4.0535e-02, 5.1355e-02, -3.0590e-02, 7.3194e-02, -6.9639e-02],
[-2.0004e-02, -3.7198e-02, -7.3253e-02, -3.4263e-02, 1.5673e-02]],
[[-2.4125e-02, -2.1872e-02, -4.1021e-02, 1.5590e-02, 7.0267e-02],
[-1.2325e-02, 4.8418e-02, 3.7418e-02, 5.6973e-02, 1.5516e-02],
[-5.0112e-02, 4.1789e-02, 5.5392e-02, -3.5548e-02, 2.8206e-02],
[ 5.9003e-02, 7.3764e-03, 1.5419e-02, -1.6909e-02, -4.0654e-02],
[ 4.1070e-02, -2.2652e-02, -4.4021e-02, -8.1407e-03, -7.5206e-02]],
[[-6.3301e-02, -6.5342e-02, -3.3752e-03, -6.6840e-02, -1.8425e-02],
[-1.6499e-02, 5.8059e-02, 3.5353e-02, 9.1365e-03, -2.8343e-02],
[ 4.6293e-02, 3.0543e-02, -7.3024e-02, -6.8207e-02, -2.7875e-02],
[ 8.0904e-02, 1.1077e-02, 4.2119e-02, 5.4343e-02, 2.1999e-02],
[ 6.2393e-02, 2.9035e-02, 2.6746e-02, -2.8318e-02, -4.9989e-02]],
[[-6.2379e-02, -4.8560e-02, -5.2721e-02, -3.9246e-02, -6.8517e-03],
[-7.4939e-03, 9.3801e-03, 1.3410e-02, 5.7410e-02, 4.4898e-03],
[-1.7655e-02, 3.3112e-02, -6.4055e-02, -1.3577e-02, 6.3291e-02],
[-3.9472e-02, 1.4227e-02, -4.6944e-02, -1.8557e-02, -3.4740e-02],
[-6.3974e-02, -6.9448e-02, -2.1668e-02, 2.4177e-02, 3.0717e-02]],
[[ 2.8530e-02, 2.3247e-02, -3.8633e-02, -5.7804e-02, -1.9294e-02],
[ 5.4420e-02, -1.2094e-02, 2.1143e-02, -6.5897e-02, -2.8639e-02],
[ 4.7260e-02, 4.2903e-02, 7.4266e-02, 6.9400e-02, -6.4887e-02],
[-7.5132e-02, -5.4750e-02, 4.6103e-03, 3.0465e-02, -6.0162e-02],
[ 3.2707e-02, -3.7524e-02, -6.7505e-02, 2.1123e-02, -1.5651e-02]]],
[[[ 6.0358e-02, 5.9896e-02, -3.1081e-02, 5.8683e-02, 1.5452e-02],
[ 6.0256e-02, -7.1520e-02, 7.8586e-02, -3.8772e-02, -7.1890e-02],
[-5.1354e-02, -2.9084e-03, -3.9233e-02, -2.1499e-02, -2.9419e-02],
[-3.2572e-02, -6.9616e-02, 2.9291e-02, 7.2235e-02, 2.5144e-02],
[ 7.6527e-02, 3.1913e-03, -1.8299e-02, 1.5759e-02, 6.3982e-02]],
[[-6.2217e-02, 2.4136e-02, 6.5684e-02, -5.2996e-02, -8.5318e-03],
[ 5.4878e-02, -7.0118e-02, 5.6222e-02, -5.8217e-02, -1.2457e-02],
[-9.0102e-03, 4.0819e-02, -5.2410e-02, -4.3693e-02, -8.1261e-03],
[ 3.9352e-02, 4.2597e-02, 6.4178e-02, -1.6116e-02, 4.4007e-02],
[-4.6907e-02, 5.0872e-02, 1.4034e-02, -7.7642e-02, -3.1652e-02]],
[[ 1.8691e-02, 3.9128e-02, -1.8538e-03, 6.9222e-02, 4.7985e-02],
[ 2.5163e-02, -3.2308e-02, 5.8934e-02, 6.4200e-02, 7.5079e-02],
[-3.8752e-02, -6.2834e-02, -1.3630e-02, 4.7745e-02, -3.6710e-02],
[-6.5912e-02, -7.5509e-02, 2.0538e-03, 6.1806e-02, 4.7332e-02],
[ 5.2663e-02, -6.0765e-02, -1.1656e-02, -1.2399e-02, 6.5297e-02]],
[[-4.0377e-02, 2.2776e-02, 2.0396e-03, 6.3307e-02, 7.7342e-02],
[-8.6686e-03, 6.8417e-02, 4.9833e-02, -5.8394e-02, -6.8530e-02],
[-6.4711e-02, -6.4908e-02, -3.2846e-02, -3.7337e-02, -3.2760e-02],
[-7.8387e-02, 1.2714e-03, -3.3095e-02, -1.5624e-03, -1.5552e-02],
[-6.5617e-02, 5.9709e-02, 7.6255e-02, 7.4220e-02, -3.9595e-02]],
[[-2.4471e-02, -1.4723e-02, -4.3525e-03, -5.7851e-02, 1.1639e-02],
[ 4.5532e-02, 4.3314e-02, 2.7463e-02, -3.2127e-02, -6.0824e-02],
[ 4.7108e-02, 1.2112e-02, -1.6862e-02, -5.4160e-02, 4.8685e-03],
[-2.5893e-02, 6.4832e-02, 3.3282e-02, 4.9884e-02, 6.3713e-02],
[-1.9860e-02, 1.2712e-02, 7.0452e-04, 4.6135e-02, -5.4728e-02]],
[[-5.1852e-02, -1.5589e-02, -3.1799e-02, 4.5747e-02, -2.9827e-02],
[ 6.4932e-02, -5.3074e-02, -4.9272e-02, 1.8426e-02, -4.6095e-02],
[ 4.1712e-02, -7.9372e-02, -7.9577e-02, 1.2126e-02, 4.2022e-02],
[ 5.8650e-02, 1.3046e-02, -5.0546e-02, 7.1611e-02, 5.8748e-02],
[ 4.8559e-02, -8.1399e-02, 6.8672e-02, 7.3071e-02, -4.6508e-02]]],
...,
[[[ 7.2944e-02, -5.6815e-02, -7.2140e-02, -8.0878e-02, 8.1223e-02],
[-2.6174e-02, 4.4648e-02, -1.7627e-05, 6.8991e-02, 9.3131e-04],
[-1.9168e-02, -1.1712e-02, -3.9127e-02, -6.5451e-02, -1.9835e-02],
[-2.9851e-02, -7.2093e-02, 6.1742e-03, -6.0501e-02, -3.0240e-02],
[-3.6711e-02, -3.9918e-02, 1.8570e-02, 2.7867e-02, -3.9091e-02]],
[[ 4.4294e-02, -6.2949e-02, -4.9712e-02, -6.0654e-02, 2.6511e-02],
[-1.1918e-02, 1.9399e-02, 1.9778e-03, 4.6715e-02, 7.9662e-02],
[-3.9779e-02, -3.2971e-02, 2.6502e-03, 6.0599e-02, 6.1761e-02],
[-7.2092e-02, -3.8731e-02, -1.5203e-02, 3.3408e-02, -7.3232e-02],
[ 2.1354e-02, 4.9467e-02, 6.6561e-02, 6.4517e-02, -4.9400e-02]],
[[ 2.5279e-02, 7.5811e-02, -5.6423e-02, 7.1795e-03, -8.0469e-02],
[ 6.3054e-02, 1.5441e-02, -7.9545e-02, 6.0103e-02, 4.7542e-02],
[-2.6974e-02, 6.4899e-02, -7.6267e-02, 3.2200e-02, 9.7143e-03],
[ 4.1850e-02, -1.8550e-02, -5.0626e-02, 3.7149e-02, 7.6128e-02],
[-2.8989e-02, 2.5409e-03, -3.2850e-02, -1.1957e-02, 4.2580e-04]],
[[ 4.6736e-02, 7.7891e-02, 2.2977e-02, 3.5759e-02, -7.9195e-02],
[-4.3826e-02, -7.9846e-02, -5.2120e-02, -3.8209e-03, 2.4057e-02],
[-6.7396e-03, 2.7530e-02, 1.1896e-03, -1.6895e-02, -5.0218e-02],
[ 5.6456e-02, -7.6683e-02, 2.4498e-02, -5.4710e-02, 5.6294e-02],
[-3.0637e-03, -2.6177e-02, -3.8865e-02, -3.9652e-02, 4.4595e-02]],
[[-8.0799e-02, -7.9691e-02, -2.4048e-02, -6.6943e-02, 5.5213e-02],
[ 1.1116e-02, -3.8443e-02, 4.3369e-02, -7.6902e-02, 5.1385e-02],
[ 5.6263e-02, -5.4902e-02, 8.0991e-02, 1.6011e-02, 6.6421e-02],
[-2.4895e-02, -4.4881e-02, 4.6953e-02, -4.1781e-02, -4.2947e-02],
[ 8.0550e-02, -7.2696e-02, -4.6141e-02, 6.7832e-03, -1.6691e-03]],
[[ 2.6609e-02, -3.9203e-02, -7.8157e-03, 2.2936e-04, -2.7554e-02],
[ 4.0520e-02, 1.1102e-02, -2.2165e-02, 6.4671e-02, 1.1872e-02],
[ 2.5477e-02, 3.2211e-02, 5.6317e-02, 5.1697e-02, 5.5899e-02],
[-3.0296e-02, -3.9487e-02, -2.5797e-02, 5.7478e-02, -4.8781e-03],
[-6.3375e-02, -4.3827e-02, 3.5311e-03, 4.7217e-02, 6.8362e-02]]],
[[[-4.5381e-02, -7.7842e-02, -6.9001e-02, -7.6422e-03, 6.8520e-02],
[ 2.3377e-02, -5.9736e-03, -6.8239e-02, 7.2911e-02, -6.6242e-02],
[-1.5282e-02, 1.7386e-02, 3.9979e-02, -6.8327e-03, -1.7662e-03],
[ 5.4649e-02, -4.8377e-03, 7.7069e-02, -8.0424e-02, -2.7894e-02],
[-6.3750e-02, -2.7770e-02, 5.7462e-02, -1.8159e-02, 5.8960e-02]],
[[ 1.5038e-02, -8.0078e-02, 1.0708e-02, 2.2493e-02, 2.2514e-02],
[-2.7322e-02, 4.5916e-02, 7.1295e-02, -5.6998e-02, -5.2429e-02],
[-2.4198e-02, -4.0081e-02, -7.5517e-02, -6.0738e-02, -1.9848e-02],
[-8.0915e-02, 1.1733e-02, 7.0872e-02, 4.2211e-02, 3.7455e-03],
[ 5.6451e-02, -2.0291e-02, 5.9699e-02, -3.8810e-02, 9.7062e-03]],
[[ 7.0948e-02, 7.7596e-02, -5.9511e-02, -2.7747e-03, -2.9197e-02],
[ 5.6304e-02, -5.9313e-02, -3.6894e-03, -3.4498e-02, -3.1743e-02],
[ 6.2984e-02, 7.1278e-02, 1.8568e-02, -8.1057e-02, -7.4301e-02],
[-1.7063e-02, 3.7341e-02, -1.7987e-02, -6.2014e-03, -1.3535e-02],
[ 3.3733e-02, 3.2608e-02, -1.8692e-02, 6.1727e-02, 1.0257e-02]],
[[ 4.3113e-04, -6.9241e-02, 2.2611e-02, 4.1913e-02, -6.6395e-02],
[ 7.5128e-03, -7.3346e-02, 8.0353e-02, 1.2347e-02, 5.5333e-02],
[-9.7800e-03, 4.5897e-02, 2.8835e-02, -3.6708e-02, 3.9655e-02],
[ 2.7716e-02, -7.1659e-02, 7.1108e-03, 1.1511e-02, -4.8559e-02],
[-3.0865e-02, 7.5560e-02, 2.8310e-02, 7.4005e-02, -5.0888e-02]],
[[ 7.5087e-02, 6.3344e-02, 5.9466e-02, 1.0437e-02, 9.3939e-03],
[-1.4452e-03, -5.0765e-02, -3.6996e-02, -6.8923e-02, -7.4329e-02],
[ 1.1036e-02, -2.6916e-02, -6.9722e-02, 5.9740e-02, 4.6108e-02],
[ 2.0379e-02, 3.6167e-02, 4.8153e-02, -3.0691e-02, -5.5250e-02],
[ 3.5924e-02, 4.5421e-02, -4.7335e-02, 6.4587e-02, -5.7064e-02]],
[[-1.6970e-03, -7.8021e-02, -6.0369e-02, -8.0641e-02, 7.1452e-02],
[ 1.6848e-02, -7.5881e-02, 2.5285e-02, 2.5364e-02, -1.0818e-02],
[-3.0854e-02, -2.4429e-02, -6.4815e-02, 8.1414e-03, -7.9674e-02],
[-6.2038e-02, 7.4582e-02, -1.7759e-02, 2.3795e-02, -1.5795e-02],
[ 3.7823e-02, -3.3319e-04, 7.1363e-03, 7.7572e-02, 4.3771e-02]]],
[[[-4.8656e-03, 4.3062e-02, -3.2547e-02, 9.7140e-03, -5.3167e-02],
[ 4.2759e-02, -4.1656e-02, 6.4357e-02, 3.5642e-02, -7.8376e-02],
[-1.2937e-02, 6.4533e-02, 1.5182e-02, 1.1444e-02, -7.4220e-02],
[ 6.3483e-02, -1.1542e-02, -4.0774e-02, -1.2172e-02, -2.7794e-02],
[-8.1438e-03, -5.5991e-02, -2.9966e-02, -8.0014e-03, -5.2937e-02]],
[[ 9.9251e-03, -2.7150e-02, -1.5934e-02, -3.4809e-02, 2.2487e-02],
[-2.9249e-03, 6.8871e-02, 4.3621e-03, 2.6227e-02, 4.3713e-02],
[ 8.1283e-02, -3.1387e-02, -6.9915e-02, -1.7858e-02, -2.1714e-02],
[-3.5359e-02, -1.3766e-02, 3.6173e-02, 9.1202e-03, -3.9747e-02],
[-7.2135e-02, -7.3420e-02, 6.0504e-02, 3.1594e-02, 7.6891e-02]],
[[ 7.2759e-02, -6.5420e-02, 6.7763e-02, 7.2741e-02, -7.4671e-02],
[-5.5163e-02, -7.5269e-02, 1.3287e-02, 1.8645e-02, -3.4054e-02],
[ 6.5525e-02, -4.1262e-03, 4.6500e-02, -6.6291e-02, 5.8884e-02],
[ 3.0486e-02, 3.6131e-03, 1.1222e-02, -3.3646e-02, -6.5889e-02],
[ 4.7762e-02, 3.6352e-02, 9.7470e-03, 7.7495e-03, -5.5064e-02]],
[[ 4.2110e-02, 5.1736e-02, 4.9755e-02, 1.8245e-02, 3.1093e-02],
[-5.8074e-02, -4.1158e-02, -5.9566e-03, 6.2394e-02, 1.6582e-02],
[-7.2003e-02, 1.4616e-02, -3.5987e-02, 3.0575e-02, -4.4705e-02],
[-4.7500e-02, -1.9091e-02, 1.2661e-02, 2.4751e-02, -7.1824e-02],
[ 4.3771e-02, 4.9023e-02, 7.2368e-02, -2.3195e-02, -3.0777e-02]],
[[-8.2526e-03, -1.3523e-02, -6.9580e-02, 2.5552e-02, -1.5779e-02],
[ 2.2318e-03, 2.7111e-02, -8.7496e-03, -2.3582e-02, -6.8521e-02],
[ 7.4568e-02, -4.6680e-02, 7.4333e-02, -6.5834e-02, 8.0266e-02],
[ 1.0070e-02, 5.4708e-02, -1.4732e-03, 1.9077e-02, -2.5033e-02],
[-2.7357e-02, 1.9236e-02, -4.7921e-02, -5.5013e-02, -7.4643e-02]],
[[ 4.0488e-02, -7.1390e-02, -2.3527e-02, 1.3764e-02, -1.5115e-02],
[-3.7438e-02, -7.9287e-02, -6.0580e-02, -3.2224e-02, -2.1884e-02],
[-7.0937e-02, -5.7632e-02, -1.2339e-02, 5.2566e-02, -5.8696e-02],
[-6.4373e-02, 5.0876e-02, -6.8186e-02, 6.8750e-02, 4.6615e-02],
[ 3.0661e-02, -6.8377e-02, 1.7900e-02, -8.8543e-03, 1.4958e-02]]]])
(bias): Normal:
loc: tensor([0., 0., -0., 0., 0., -0., 0., 0., -0., -0., 0., 0., 0., 0., 0., 0.])
scale: tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([ 0.0482, 0.0700, -0.0483, 0.0543, 0.0427, -0.0160, 0.0652, 0.0702,
-0.0693, -0.0685, 0.0211, 0.0540, 0.0356, 0.0235, 0.0208, 0.0305])
)
(observed): Observed()
)
(fc1): Linear(
in_features=400, out_features=120, bias=True
(posterior): Automatic(
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([-0., -0., -0., 0., -0., 0., -0., 0., 0., -0., 0., -0., 0., 0., -0., 0., -0., -0., 0., -0., -0., -0., -0., -0.,
0., -0., -0., -0., 0., 0., -0., -0., 0., 0., -0., 0., -0., 0., -0., 0., 0., -0., 0., -0., -0., -0., 0., 0.,
0., -0., 0., -0., 0., -0., -0., -0., 0., -0., 0., -0., -0., 0., 0., 0., 0., 0., 0., 0., -0., -0., 0., -0.,
0., -0., 0., 0., -0., 0., -0., -0., 0., 0., 0., 0., -0., -0., -0., 0., 0., -0., 0., -0., -0., 0., -0., 0.,
0., 0., -0., -0., 0., -0., 0., 0., -0., -0., -0., -0., -0., 0., 0., 0., -0., 0., 0., 0., -0., 0., 0., -0.],
requires_grad=True)
tensor: tensor([ 0.0598, 0.0720, -0.2056, -0.0677, 0.0583, 0.0021, 0.0079, 0.0343,
0.1597, 0.1023, -0.0047, 0.0677, -0.0050, -0.0228, 0.0133, 0.1010,
0.0707, -0.1196, 0.0009, 0.2002, 0.0956, 0.0229, 0.0240, -0.0367,
0.0650, 0.2132, 0.0634, -0.0123, 0.0125, 0.0665, -0.0738, 0.0954,
0.0280, -0.1036, -0.1214, -0.0371, 0.0403, 0.0539, -0.0327, -0.0039,
0.0645, 0.0404, 0.0437, -0.1351, -0.0247, 0.1724, -0.0036, -0.1897,
-0.0021, 0.1507, 0.0524, 0.1568, 0.0330, 0.0332, 0.1069, -0.0358,
-0.0007, -0.0262, 0.0324, -0.0762, 0.1193, -0.0476, -0.0853, 0.0407,
-0.0146, 0.0540, -0.0463, 0.0140, 0.1053, -0.1393, 0.0070, 0.0167,
0.0542, 0.1323, 0.0909, 0.0447, 0.0251, -0.0830, -0.0047, -0.0313,
-0.0959, -0.0684, -0.1227, 0.1241, -0.0781, 0.0037, -0.0389, 0.1310,
-0.0885, 0.0379, 0.0523, 0.0267, -0.0916, 0.1664, 0.1827, -0.0398,
-0.0645, 0.1336, 0.0071, -0.0911, -0.0304, 0.1082, 0.1484, -0.0560,
-0.0540, 0.0596, 0.1664, 0.0995, 0.1534, -0.0178, -0.0171, -0.0267,
-0.0456, 0.0378, -0.0477, 0.0744, 0.1072, -0.1904, -0.0535, 0.1517],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[-0., 0., -0., ..., -0., 0., 0.],
[0., 0., -0., ..., -0., 0., 0.],
[0., -0., 0., ..., 0., -0., -0.],
...,
[-0., -0., 0., ..., 0., -0., 0.],
[0., 0., 0., ..., -0., 0., -0.],
[-0., 0., 0., ..., 0., 0., 0.]])
scale: tensor([[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
...,
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[-0.0222, 0.0364, -0.0196, ..., -0.0034, 0.0235, 0.0133],
[ 0.0133, 0.0200, -0.0305, ..., -0.0283, 0.0068, 0.0075],
[ 0.0367, -0.0009, 0.0018, ..., 0.0404, -0.0319, -0.0467],
...,
[-0.0015, -0.0323, 0.0344, ..., 0.0226, -0.0282, 0.0235],
[ 0.0195, 0.0296, 0.0472, ..., -0.0476, 0.0057, -0.0356],
[-0.0482, 0.0452, 0.0364, ..., 0.0029, 0.0322, 0.0217]])
(bias): Normal:
loc: tensor([-0., -0., -0., 0., -0., 0., -0., 0., 0., -0., 0., -0., 0., 0., -0., 0., -0., -0., 0., -0., -0., -0., -0., -0.,
0., -0., -0., -0., 0., 0., -0., -0., 0., 0., -0., 0., -0., 0., -0., 0., 0., -0., 0., -0., -0., -0., 0., 0.,
0., -0., 0., -0., 0., -0., -0., -0., 0., -0., 0., -0., -0., 0., 0., 0., 0., 0., 0., 0., -0., -0., 0., -0.,
0., -0., 0., 0., -0., 0., -0., -0., 0., 0., 0., 0., -0., -0., -0., 0., 0., -0., 0., -0., -0., 0., -0., 0.,
0., 0., -0., -0., 0., -0., 0., 0., -0., -0., -0., -0., -0., 0., 0., 0., -0., 0., 0., 0., -0., 0., 0., -0.])
scale: tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([-0.0030, -0.0106, -0.0147, 0.0311, -0.0320, 0.0124, -0.0280, 0.0140,
0.0088, -0.0196, 0.0488, -0.0001, 0.0003, 0.0401, -0.0014, 0.0427,
-0.0403, -0.0008, 0.0406, -0.0468, -0.0297, -0.0245, -0.0370, -0.0124,
0.0320, -0.0158, -0.0113, -0.0198, 0.0193, 0.0356, -0.0264, -0.0160,
0.0050, 0.0121, -0.0498, 0.0146, -0.0372, 0.0089, -0.0298, 0.0399,
0.0347, -0.0108, 0.0353, -0.0157, -0.0174, -0.0355, 0.0131, 0.0192,
0.0432, -0.0373, 0.0332, -0.0114, 0.0318, -0.0132, -0.0002, -0.0403,
0.0447, -0.0203, 0.0274, -0.0342, -0.0080, 0.0389, 0.0318, 0.0043,
0.0192, 0.0158, 0.0490, 0.0272, -0.0142, -0.0218, 0.0353, -0.0035,
0.0169, -0.0432, 0.0079, 0.0499, -0.0018, 0.0296, -0.0337, -0.0214,
0.0376, 0.0054, 0.0384, 0.0403, -0.0050, -0.0075, -0.0203, 0.0318,
0.0285, -0.0415, 0.0395, -0.0045, -0.0020, 0.0245, -0.0361, 0.0150,
0.0347, 0.0185, -0.0093, -0.0056, 0.0021, -0.0026, 0.0046, 0.0380,
-0.0403, -0.0422, -0.0218, -0.0198, -0.0446, 0.0296, 0.0276, 0.0089,
-0.0049, 0.0100, 0.0118, 0.0283, -0.0334, 0.0287, 0.0236, -0.0404])
)
(observed): Observed()
)
(fc2): Linear(
in_features=120, out_features=2, bias=True
(posterior): Automatic(
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.7071, 0.7071], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([-0., -0.], requires_grad=True)
tensor: tensor([0.3527, 0.0185], grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[0., -0., 0., 0., 0., -0., 0., 0., 0., 0., -0., -0., 0., 0., 0., 0., 0., 0., -0., -0., 0., -0., -0., -0.,
-0., -0., 0., 0., 0., -0., -0., 0., -0., 0., 0., 0., 0., -0., 0., -0., 0., 0., 0., 0., -0., -0., -0., -0.,
-0., -0., -0., 0., -0., 0., 0., -0., -0., -0., -0., -0., -0., -0., 0., -0., 0., 0., 0., 0., -0., -0., 0., 0.,
0., -0., 0., -0., 0., 0., 0., 0., -0., 0., 0., 0., -0., -0., -0., -0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., -0., -0., -0., -0., 0., -0., 0., 0., -0., -0., -0., -0., -0., 0., -0., 0., 0., 0., -0., -0., 0.],
[-0., -0., -0., 0., 0., 0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., 0., -0., 0., -0., -0., -0., -0., -0.,
-0., -0., 0., -0., -0., 0., -0., -0., 0., 0., -0., -0., -0., -0., 0., 0., -0., 0., -0., 0., -0., 0., 0., -0.,
-0., 0., -0., -0., -0., 0., 0., -0., -0., -0., -0., 0., -0., -0., -0., 0., 0., 0., -0., 0., -0., 0., -0., 0.,
0., 0., -0., 0., 0., -0., -0., 0., -0., 0., 0., 0., -0., -0., -0., 0., -0., -0., 0., -0., -0., 0., 0., -0.,
-0., -0., -0., 0., -0., 0., -0., 0., -0., -0., 0., 0., 0., 0., 0., 0., -0., 0., 0., 0., 0., 0., -0., 0.]])
scale: tensor([[0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913],
[0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[ 5.3861e-02, -3.2810e-02, 1.8512e-02, 4.5986e-02, 5.1913e-02,
-4.3799e-02, 5.6503e-03, 2.7175e-02, 6.8377e-02, 1.3057e-02,
-3.2690e-02, -6.1064e-02, 3.8826e-02, 6.6544e-02, 5.5398e-02,
3.0196e-02, 1.7987e-02, 5.5327e-02, -5.0021e-02, -7.3435e-02,
6.4632e-02, -5.7208e-02, -4.0015e-02, -8.9137e-02, -2.3537e-02,
-7.9152e-03, 2.1925e-02, 4.2662e-02, 2.8947e-06, -8.5279e-02,
-7.4841e-03, 9.8418e-03, -7.7589e-02, 8.2463e-02, 4.1143e-02,
3.5816e-03, 4.9777e-02, -4.8500e-02, 5.3974e-02, -6.0017e-02,
6.7712e-02, 2.5674e-02, 6.1376e-02, 4.4900e-03, -6.7039e-02,
-7.8965e-02, -6.7107e-02, -2.3485e-02, -1.5222e-02, -8.7112e-03,
-7.6909e-02, 7.4443e-02, -2.8730e-02, 5.1316e-02, 1.8881e-04,
-3.4971e-02, -7.5644e-02, -1.2047e-02, -5.6335e-02, -3.9648e-02,
-5.9713e-03, -5.5715e-02, 5.2674e-02, -1.8986e-02, 1.1715e-02,
7.8946e-02, 8.7757e-02, 7.8873e-02, -4.0747e-02, -9.0934e-02,
8.4907e-02, 6.0053e-02, 6.8620e-02, -6.7833e-02, 4.9746e-02,
-9.3147e-03, 4.5588e-02, 2.6278e-02, 7.8289e-02, 6.9648e-02,
-1.8778e-02, 7.0376e-02, 4.2418e-02, 8.1578e-02, -2.5462e-03,
-8.2635e-02, -3.2393e-02, -1.8944e-02, 5.2082e-02, 7.6844e-02,
7.8650e-02, 1.0107e-02, 5.8002e-02, 5.6146e-02, 1.0183e-02,
8.1826e-02, 2.2654e-02, 1.4227e-02, 6.7762e-02, -1.4747e-02,
-1.7642e-02, -6.6754e-02, -3.3121e-02, 8.3696e-02, -1.6725e-02,
5.1801e-02, 8.2761e-02, -1.6347e-03, -7.2732e-02, -8.5545e-02,
-2.1219e-02, -8.6543e-02, 1.6206e-02, -3.9126e-02, 5.2650e-02,
8.4007e-02, 6.5445e-03, -2.2124e-02, -2.6831e-03, 9.0644e-02],
[-7.6951e-02, -7.7030e-02, -2.8385e-02, 8.8830e-02, 1.8159e-02,
1.6277e-02, -8.6646e-02, -4.8097e-03, -7.2484e-02, -2.3892e-02,
-3.8862e-02, 8.1929e-02, 9.8960e-03, -7.4610e-02, 4.4434e-02,
9.1217e-02, 1.7353e-02, -4.1310e-02, 4.6109e-03, -6.0577e-03,
-5.4844e-02, -8.3153e-02, -7.1516e-02, -3.4588e-03, -8.7724e-02,
-2.7643e-02, 2.0528e-02, -3.5310e-02, -7.1740e-02, 9.2993e-03,
-4.4849e-02, -2.9480e-02, 8.8236e-02, 6.5175e-03, -1.8974e-02,
-6.3304e-02, -4.7579e-02, -7.5561e-02, 6.3819e-03, 6.4816e-02,
-5.2397e-03, 4.9444e-02, -7.8613e-02, 3.2730e-02, -4.0234e-02,
7.0352e-02, 9.4854e-03, -6.6504e-02, -5.6702e-02, 8.2212e-02,
-1.0986e-03, -7.2008e-02, -2.7400e-02, 2.1016e-02, 3.5173e-02,
-7.9398e-02, -4.5183e-02, -6.5127e-02, -3.0206e-02, 3.6873e-02,
-4.1272e-02, -5.1494e-03, -7.6218e-03, 6.9450e-02, 3.0325e-02,
1.2730e-02, -6.4263e-02, 8.1150e-02, -7.4637e-02, 4.3607e-02,
-5.5105e-02, 7.7573e-02, 7.7680e-02, 4.4945e-02, -7.6075e-02,
3.1183e-02, 4.1456e-02, -5.6365e-02, -6.4912e-02, 1.1328e-02,
-6.9009e-02, 2.9723e-02, 3.2704e-02, 1.9641e-02, -2.5972e-02,
-4.7365e-02, -6.7854e-02, 3.1308e-02, -4.7928e-02, -3.9339e-02,
1.7182e-03, -7.1567e-02, -7.8225e-02, 8.7315e-02, 4.9131e-02,
-8.9911e-02, -4.4713e-03, -3.8737e-02, -7.6371e-02, 3.4342e-02,
-1.7424e-02, 2.5424e-02, -7.5859e-02, 8.9858e-02, -6.9163e-02,
-7.2367e-02, 8.4212e-02, 5.0855e-02, 1.1414e-03, 7.7345e-02,
5.0493e-02, 7.5050e-02, -2.1079e-02, 7.8312e-02, 7.6594e-03,
4.9073e-02, 6.5474e-02, 2.7924e-02, -1.4094e-02, 4.2029e-02]])
(bias): Normal:
loc: tensor([-0., -0.])
scale: tensor([0.7071, 0.7071])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([-0.0687, -0.0438])
)
(observed): Observed()
)
)
The posterior can also be set when the module is created:
nn.Linear(10, 10, posterior=borch.posterior.Normal(log_scale=-3))
Out:
Linear(
in_features=10, out_features=10, bias=True
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498]], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[0., 0., -0., -0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., -0., -0., -0., -0., -0., -0., -0.],
[0., -0., 0., 0., -0., -0., -0., -0., -0., -0.],
[-0., 0., -0., -0., 0., -0., 0., -0., -0., -0.],
[0., -0., -0., 0., 0., 0., 0., 0., -0., -0.],
[0., 0., 0., 0., -0., 0., -0., 0., -0., 0.],
[0., 0., 0., 0., -0., -0., 0., 0., 0., -0.],
[0., 0., -0., 0., -0., -0., 0., -0., -0., -0.],
[-0., 0., 0., -0., -0., 0., -0., 0., -0., -0.],
[-0., -0., 0., 0., 0., 0., 0., 0., 0., 0.]], requires_grad=True)
tensor: tensor([[-0.0074, 0.0521, -0.0124, -0.0409, 0.0264, 0.0437, -0.0267, 0.0628,
0.0505, 0.0381],
[-0.0073, 0.0132, -0.0146, 0.0461, 0.0400, -0.0276, -0.0511, -0.0327,
0.0057, -0.1127],
[-0.0828, -0.0232, -0.0328, -0.0699, -0.0015, -0.0479, 0.0434, -0.0257,
-0.0307, -0.0163],
[-0.0135, 0.0351, -0.0783, -0.0044, 0.0843, 0.0475, 0.0530, 0.0148,
-0.0002, 0.0309],
[ 0.0177, 0.0899, -0.0343, -0.0520, 0.0359, -0.0923, 0.0123, 0.0111,
-0.0587, 0.0311],
[-0.0475, -0.0160, 0.0131, -0.0638, -0.0513, -0.0493, 0.0084, -0.0723,
0.0861, 0.0588],
[ 0.0710, 0.0325, -0.0378, -0.1117, 0.0043, -0.0647, 0.0132, 0.0545,
0.0478, -0.0080],
[-0.0560, 0.0021, -0.0093, 0.0394, -0.0130, 0.0267, -0.0057, -0.0756,
-0.0694, 0.0105],
[ 0.0652, 0.0100, 0.0970, -0.0560, 0.0092, -0.0568, 0.0200, -0.0449,
-0.0877, 0.0570],
[-0.0767, 0.0075, -0.0159, 0.0510, 0.0111, -0.0010, -0.0614, 0.0479,
0.0140, 0.0625]], grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([0., 0., 0., 0., 0., 0., -0., -0., 0., -0.], requires_grad=True)
tensor: tensor([-0.0160, -0.0384, 0.0385, 0.0560, 0.0037, -0.0786, -0.0205, -0.0698,
-0.0399, -0.0347], grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[0., 0., -0., -0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., -0., -0., -0., -0., -0., -0., -0.],
[0., -0., 0., 0., -0., -0., -0., -0., -0., -0.],
[-0., 0., -0., -0., 0., -0., 0., -0., -0., -0.],
[0., -0., -0., 0., 0., 0., 0., 0., -0., -0.],
[0., 0., 0., 0., -0., 0., -0., 0., -0., 0.],
[0., 0., 0., 0., -0., -0., 0., 0., 0., -0.],
[0., 0., -0., 0., -0., -0., 0., -0., -0., -0.],
[-0., 0., 0., -0., -0., 0., -0., 0., -0., -0.],
[-0., -0., 0., 0., 0., 0., 0., 0., 0., 0.]])
scale: tensor([[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162],
[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162],
[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162],
[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162],
[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162],
[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162],
[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162],
[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162],
[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162],
[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[ 0.0537, 0.1889, -0.0853, -0.2389, 0.2932, 0.1727, 0.2749, 0.0216,
0.2554, 0.2329],
[ 0.3147, 0.2312, 0.2066, -0.0289, -0.3128, -0.0764, -0.0221, -0.1260,
-0.0566, -0.2622],
[ 0.2383, -0.1345, 0.3103, 0.0020, -0.0411, -0.0745, -0.2340, -0.1291,
-0.1219, -0.2262],
[-0.1985, 0.1491, -0.2021, -0.0445, 0.3105, -0.0486, 0.1779, -0.0055,
-0.0425, -0.2117],
[ 0.0271, -0.0579, -0.3026, 0.2160, 0.0549, 0.1236, 0.2550, 0.0574,
-0.2970, -0.1247],
[ 0.3069, 0.2616, 0.2535, 0.3027, -0.0497, 0.2073, -0.0330, 0.1368,
-0.1539, 0.0674],
[ 0.2222, 0.2572, 0.1273, 0.0229, -0.0218, -0.2938, 0.0404, 0.0300,
0.0719, -0.2847],
[ 0.2613, 0.0033, -0.1948, 0.1279, -0.2686, -0.2299, 0.0870, -0.0815,
-0.2531, -0.1094],
[-0.0470, 0.1280, 0.0507, -0.2919, -0.2141, 0.2617, -0.0497, 0.2454,
-0.0208, -0.0923],
[-0.2441, -0.2966, 0.0220, 0.2931, 0.1903, 0.0392, 0.0920, 0.2355,
0.1240, 0.2759]])
(bias): Normal:
loc: tensor([0., 0., 0., 0., 0., 0., -0., -0., 0., -0.])
scale: tensor([0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([ 0.2346, 0.2745, 0.2184, 0.2540, 0.0325, 0.2008, -0.0309, -0.2569,
0.1033, -0.1293])
)
(observed): Observed()
)
See the borch.posterior documentation for the other available posteriors and the parameters they accept.
Note that not every posterior works with every parameter, but you can use different posteriors for
different borch.Modules in your network.
Exercises¶
Use what you have learned to train an image classifier for MNIST; you should achieve an accuracy above 98 %. Note: you can access MNIST using torchvision.datasets.MNIST.
Fit the same model architecture with plain torch and compare its likelihood with that of the borch network. What are the differences, and why?
Port the model to CIFAR and see how you can improve the accuracy.
Show how the Categorical distribution is related to the cross-entropy loss function commonly used in frequentist deep learning.
Total running time of the script: ( 0 minutes 0.222 seconds)