Neural Networks
Neural networks can be constructed using the borch.nn package.
Now that you have had a glimpse of autograd, note that nn depends on autograd
to define models and differentiate them. An nn.Module contains layers
and a method forward(input) that returns the output.
For example, look at this network that classifies digit images:
[Figure: convnet]
It is a simple feed-forward network. It takes an input, feeds it through several layers one after the other, and then finally gives the output.
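As a rough check of how the data shrinks on its way through such a network, the sketch below traces the spatial size of a 32x32 input through two 5x5 convolutions without padding, each followed by 2x2 max pooling. These are the layer sizes used by the network defined later in this tutorial. The helper name `window_output_size` is hypothetical and only illustrates the standard output-size formula for a sliding window.

```python
def window_output_size(size, kernel, stride=1, padding=0):
    """Length along one spatial axis after a convolution or pooling window.

    Hypothetical helper: applies floor((size + 2*padding - kernel) / stride) + 1.
    """
    return (size + 2 * padding - kernel) // stride + 1


size = 32                                            # assumed input: 32x32 image
size = window_output_size(size, kernel=5)            # first 5x5 convolution
size = window_output_size(size, kernel=2, stride=2)  # 2x2 max pooling
size = window_output_size(size, kernel=5)            # second 5x5 convolution
size = window_output_size(size, kernel=2, stride=2)  # 2x2 max pooling

# 16 output channels after the second convolution
flat_features = 16 * size * size
print(size, flat_features)
```

The final value of `flat_features` is the input width that the first fully connected layer has to accept.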
A typical training procedure for a neural network is as follows:
- Define a network that has some learnable parameters and/or random variables
- For each batch in a dataset:
  - Process the input data through the network
  - Compute the loss (how far is the output from being correct?)
  - Propagate gradients back into the network's parameters
  - Update the weights of the network, typically using a simple update rule:
    weight = weight - learning_rate * gradient
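The steps above can be sketched in a few lines. This is a minimal illustration in plain PyTorch, not borch's own inference API. The toy data, the single `torch.nn.Linear` layer standing in for a network, and the mean-squared-error loss are all assumptions made for the example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy regression data (assumption for illustration): y = 2x + 1
inputs = torch.linspace(-1.0, 1.0, 20).unsqueeze(1)
targets = 2.0 * inputs + 1.0

# A single linear layer stands in for "a network with learnable parameters"
model = torch.nn.Linear(1, 1)
learning_rate = 0.1
losses = []

for step in range(200):
    prediction = model(inputs)              # process the input through the network
    loss = F.mse_loss(prediction, targets)  # compute the loss
    model.zero_grad()                       # clear gradients from the previous step
    loss.backward()                         # propagate gradients back to the parameters
    with torch.no_grad():
        for weight in model.parameters():
            # weight = weight - learning_rate * gradient
            weight -= learning_rate * weight.grad
    losses.append(loss.item())

print(losses[0], losses[-1])
```

In practice the manual update is usually replaced by an optimizer such as `torch.optim.SGD`, which applies the same rule to every parameter.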
Define the network
Let’s define this network:
import torch
import torch.nn.functional as F
import borch
from borch import distributions, posterior, nn, infer
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__(posterior=posterior.Automatic())
        # 1 input image channel, 6 output channels, 5x5 convolution kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        # 6 input channels, 16 output channels, 5x5 convolution kernel
        self.conv2 = nn.Conv2d(6, 16, 5)
        # An affine operation: y = Wx + b
        # NB: after two 5x5 convolutions without padding, each followed by
        # 2x2 max pooling, an image with initial dimension 32x32 has a
        # spatial dimension of 5x5 (with 16 channels)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the window is square, you can specify just a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten to (batch, features); -1 lets view infer the batch size
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Specifying the likelihood function
        self.classification = distributions.Categorical(logits=x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

net = Net()
print(net)
Out:
Net(
(posterior): Automatic()
(prior): Module()
(observed): Observed()
(conv1): Conv2d(
1, 6, kernel_size=(5, 5), stride=(1, 1)
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[[[-0.0774, 0.1503, -0.1255, -0.0640, 0.0630],
[ 0.0595, 0.1217, -0.0309, -0.0112, 0.0043],
[-0.1666, 0.1739, -0.0512, 0.1344, -0.0321],
[-0.1129, -0.1998, 0.1292, 0.1143, 0.1476],
[-0.0027, 0.0059, 0.1920, 0.0098, 0.0048]]],
[[[ 0.0141, 0.0599, 0.1305, -0.0544, 0.1488],
[ 0.1270, -0.0026, -0.1380, -0.1876, 0.0489],
[ 0.1652, 0.0822, 0.0593, 0.1513, 0.0839],
[ 0.0577, -0.0596, -0.0329, 0.1613, -0.0014],
[ 0.1026, -0.1543, 0.1927, 0.0072, 0.1568]]],
[[[ 0.1353, -0.0796, 0.1384, 0.1742, 0.1012],
[-0.0558, 0.0222, 0.0847, 0.0301, -0.0592],
[-0.1298, 0.0103, 0.0711, -0.1305, -0.1070],
[ 0.1687, 0.1726, -0.1407, 0.0858, -0.1460],
[ 0.0043, -0.0089, 0.0811, -0.0686, -0.0656]]],
[[[ 0.1872, -0.0864, 0.0585, 0.1785, -0.0565],
[-0.1216, 0.1257, -0.0771, -0.0275, 0.0386],
[ 0.1893, -0.0486, 0.1845, -0.0519, 0.1210],
[ 0.0378, 0.0850, 0.0683, 0.1155, -0.0512],
[ 0.0487, 0.0793, -0.0349, 0.1779, -0.1606]]],
[[[-0.1173, 0.0174, 0.0488, 0.0573, 0.1199],
[ 0.0558, 0.0349, -0.1852, 0.0371, -0.0067],
[-0.1553, -0.1714, 0.0416, -0.0263, 0.1124],
[-0.1636, 0.0513, -0.1301, 0.1010, -0.1350],
[-0.0892, -0.1678, 0.1047, -0.0118, -0.1502]]],
[[[-0.0300, -0.1080, 0.0871, -0.1145, 0.0902],
[-0.1072, 0.1453, 0.1286, 0.1711, -0.0089],
[-0.0012, -0.1712, 0.0279, 0.0078, -0.0613],
[-0.0928, 0.0310, -0.0577, -0.0740, 0.0623],
[ 0.0680, 0.1578, 0.1678, 0.1280, -0.0358]]]], requires_grad=True)
tensor: tensor([[[[-0.0912, 0.1651, -0.0824, -0.1180, 0.1151],
[ 0.1682, 0.1379, 0.0727, 0.0383, 0.0096],
[-0.1481, 0.2358, -0.0234, 0.1954, -0.0651],
[-0.1287, -0.1321, 0.0942, 0.2415, 0.2055],
[ 0.0139, 0.0093, 0.1414, -0.0114, 0.0490]]],
[[[ 0.0070, 0.0804, 0.1924, -0.0535, 0.1665],
[ 0.1279, -0.0944, -0.0823, -0.1923, 0.0228],
[ 0.2292, 0.0453, 0.1051, 0.1499, 0.0351],
[ 0.0366, -0.0931, -0.0020, 0.1834, -0.0846],
[ 0.0633, -0.2066, 0.0871, 0.0926, 0.1730]]],
[[[ 0.1552, -0.0400, 0.1363, 0.1708, 0.1237],
[-0.0265, -0.0098, 0.0622, 0.0631, -0.1061],
[-0.1840, 0.0792, 0.1222, -0.2076, -0.0974],
[ 0.0809, 0.1328, -0.0938, -0.0133, -0.0958],
[-0.0706, 0.0437, 0.0714, -0.1809, -0.1122]]],
[[[ 0.2537, -0.0915, 0.0056, 0.1602, -0.1162],
[-0.1037, 0.2019, -0.0716, 0.0006, -0.0065],
[ 0.1379, -0.0639, 0.1727, -0.1089, 0.2738],
[ 0.0247, 0.0751, 0.0598, 0.0804, -0.1690],
[ 0.0494, 0.1083, -0.0208, 0.0432, -0.1384]]],
[[[-0.2125, 0.0266, 0.0491, 0.0178, 0.0777],
[ 0.0899, 0.0169, -0.2210, 0.0300, 0.0083],
[-0.1368, -0.2311, 0.0537, -0.0592, 0.1702],
[-0.2015, 0.0771, -0.1454, 0.0375, -0.2007],
[-0.0780, -0.1716, 0.0919, -0.0333, -0.1659]]],
[[[-0.0071, -0.1041, 0.1053, -0.1506, 0.0268],
[-0.1314, 0.1820, 0.0954, 0.1974, -0.0269],
[-0.0590, -0.2216, 0.0056, -0.1258, -0.1532],
[-0.0695, 0.0420, -0.0078, -0.1728, 0.1632],
[-0.0141, 0.2390, 0.1728, 0.1627, -0.1093]]]],
grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([-0.1635, 0.0177, 0.1092, 0.1259, 0.1528, 0.0503],
requires_grad=True)
tensor: tensor([-0.1425, 0.0416, 0.1181, 0.1505, 0.1053, 0.0927],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[[[-0., 0., -0., -0., 0.],
[0., 0., -0., -0., 0.],
[-0., 0., -0., 0., -0.],
[-0., -0., 0., 0., 0.],
[-0., 0., 0., 0., 0.]]],
[[[0., 0., 0., -0., 0.],
[0., -0., -0., -0., 0.],
[0., 0., 0., 0., 0.],
[0., -0., -0., 0., -0.],
[0., -0., 0., 0., 0.]]],
[[[0., -0., 0., 0., 0.],
[-0., 0., 0., 0., -0.],
[-0., 0., 0., -0., -0.],
[0., 0., -0., 0., -0.],
[0., -0., 0., -0., -0.]]],
[[[0., -0., 0., 0., -0.],
[-0., 0., -0., -0., 0.],
[0., -0., 0., -0., 0.],
[0., 0., 0., 0., -0.],
[0., 0., -0., 0., -0.]]],
[[[-0., 0., 0., 0., 0.],
[0., 0., -0., 0., -0.],
[-0., -0., 0., -0., 0.],
[-0., 0., -0., 0., -0.],
[-0., -0., 0., -0., -0.]]],
[[[-0., -0., 0., -0., 0.],
[-0., 0., 0., 0., -0.],
[-0., -0., 0., 0., -0.],
[-0., 0., -0., -0., 0.],
[0., 0., 0., 0., -0.]]]])
scale: tensor([[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[[[-0.0774, 0.1503, -0.1255, -0.0640, 0.0630],
[ 0.0595, 0.1217, -0.0309, -0.0112, 0.0043],
[-0.1666, 0.1739, -0.0512, 0.1344, -0.0321],
[-0.1129, -0.1998, 0.1292, 0.1143, 0.1476],
[-0.0027, 0.0059, 0.1920, 0.0098, 0.0048]]],
[[[ 0.0141, 0.0599, 0.1305, -0.0544, 0.1488],
[ 0.1270, -0.0026, -0.1380, -0.1876, 0.0489],
[ 0.1652, 0.0822, 0.0593, 0.1513, 0.0839],
[ 0.0577, -0.0596, -0.0329, 0.1613, -0.0014],
[ 0.1026, -0.1543, 0.1927, 0.0072, 0.1568]]],
[[[ 0.1353, -0.0796, 0.1384, 0.1742, 0.1012],
[-0.0558, 0.0222, 0.0847, 0.0301, -0.0592],
[-0.1298, 0.0103, 0.0711, -0.1305, -0.1070],
[ 0.1687, 0.1726, -0.1407, 0.0858, -0.1460],
[ 0.0043, -0.0089, 0.0811, -0.0686, -0.0656]]],
[[[ 0.1872, -0.0864, 0.0585, 0.1785, -0.0565],
[-0.1216, 0.1257, -0.0771, -0.0275, 0.0386],
[ 0.1893, -0.0486, 0.1845, -0.0519, 0.1210],
[ 0.0378, 0.0850, 0.0683, 0.1155, -0.0512],
[ 0.0487, 0.0793, -0.0349, 0.1779, -0.1606]]],
[[[-0.1173, 0.0174, 0.0488, 0.0573, 0.1199],
[ 0.0558, 0.0349, -0.1852, 0.0371, -0.0067],
[-0.1553, -0.1714, 0.0416, -0.0263, 0.1124],
[-0.1636, 0.0513, -0.1301, 0.1010, -0.1350],
[-0.0892, -0.1678, 0.1047, -0.0118, -0.1502]]],
[[[-0.0300, -0.1080, 0.0871, -0.1145, 0.0902],
[-0.1072, 0.1453, 0.1286, 0.1711, -0.0089],
[-0.0012, -0.1712, 0.0279, 0.0078, -0.0613],
[-0.0928, 0.0310, -0.0577, -0.0740, 0.0623],
[ 0.0680, 0.1578, 0.1678, 0.1280, -0.0358]]]])
(bias): Normal:
loc: tensor([-0., 0., 0., 0., 0., 0.])
scale: tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([-0.1635, 0.0177, 0.1092, 0.1259, 0.1528, 0.0503])
)
(observed): Observed()
)
(conv2): Conv2d(
6, 16, kernel_size=(5, 5), stride=(1, 1)
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
...,
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[[[ 1.7231e-02, 4.3731e-02, 5.4681e-02, 3.4570e-02, 4.3578e-02],
[-2.7786e-02, -4.7320e-02, 3.1995e-02, 4.2537e-02, 7.5819e-02],
[ 8.9881e-03, 2.6359e-02, -8.3574e-03, -5.6418e-02, -3.9465e-02],
[-1.9657e-02, 5.7231e-03, 5.5066e-03, -1.4926e-02, -7.4856e-02],
[-6.2679e-02, 3.7226e-02, 8.0120e-02, -2.3663e-02, -5.3511e-02]],
[[-7.5590e-02, 6.4561e-02, -4.3921e-02, -7.7799e-03, -5.2721e-02],
[-1.0355e-02, 5.9574e-02, 1.1166e-02, -4.8245e-02, 4.1576e-02],
[ 5.6966e-02, -7.2947e-03, -1.9590e-03, 6.9857e-02, -2.1827e-02],
[ 3.2483e-02, -6.1154e-02, 7.6961e-02, 4.5617e-02, -5.6834e-02],
[ 3.4771e-02, -3.0784e-02, 4.8407e-02, 2.7358e-02, 6.5453e-02]],
[[-7.3062e-02, -2.7048e-02, 1.0655e-02, 7.5364e-02, 6.3020e-02],
[ 1.2004e-02, -3.4794e-02, 4.0299e-02, 6.4137e-02, -2.4964e-02],
[ 7.4683e-02, -2.4405e-02, -6.7041e-02, 2.7558e-02, 2.7585e-02],
[-2.3105e-02, 3.5924e-02, -1.3142e-02, 3.6770e-02, -5.1734e-02],
[-3.0248e-03, -7.3805e-02, 1.5975e-02, -3.8458e-02, -3.4536e-02]],
[[ 1.2704e-02, -6.5938e-02, 5.8697e-04, -4.0066e-03, 6.0744e-02],
[-6.6682e-02, 7.7161e-02, 8.1204e-02, 1.0834e-02, -9.7154e-03],
[-4.5248e-02, 4.1191e-02, -3.1746e-02, 1.4713e-02, 1.5446e-02],
[-6.1854e-02, -5.4387e-02, 3.8440e-02, -3.7393e-02, 7.6462e-02],
[ 6.8184e-02, -3.0699e-02, -6.0217e-02, 3.2648e-02, -2.2715e-02]],
[[ 7.6031e-02, 6.2038e-02, 2.1564e-02, -6.5555e-02, -6.4686e-02],
[-4.2324e-02, 3.6401e-02, -6.7974e-02, -3.9416e-02, -4.4166e-02],
[ 6.7690e-02, -6.9973e-02, -4.7069e-02, -7.6832e-03, -4.0938e-03],
[-2.6045e-02, -4.6521e-02, 3.4383e-03, 6.4282e-02, -2.5852e-02],
[ 3.2686e-02, 3.1602e-02, 3.5400e-02, -7.0873e-02, 7.1903e-02]],
[[ 6.0529e-02, -5.2981e-02, -3.9675e-02, -6.5714e-02, 2.8020e-02],
[ 3.5057e-02, 1.3588e-02, -3.5319e-02, -7.1891e-02, -6.0487e-02],
[-6.3268e-02, 1.4282e-02, -3.6912e-02, 3.1696e-02, 5.1382e-02],
[-7.2075e-02, -5.2067e-02, 5.7057e-02, -6.0579e-02, -5.5177e-02],
[-2.2326e-03, 6.6584e-02, -6.6819e-02, -5.9396e-02, 4.9889e-02]]],
[[[ 6.1514e-02, 3.9812e-02, -7.7776e-02, -4.9428e-02, -2.5424e-02],
[-1.3566e-02, -3.6820e-02, -1.2480e-02, 4.0800e-02, -7.4332e-02],
[ 2.8230e-02, -2.4562e-02, -2.9086e-02, 6.3475e-02, -1.5058e-02],
[-1.5993e-02, 2.6303e-02, 8.7087e-03, -5.3252e-02, 2.7997e-02],
[ 2.8034e-02, -3.0461e-02, -1.0604e-02, 4.4506e-02, -1.2910e-02]],
[[-7.8455e-02, -7.8485e-02, -5.5244e-02, 4.7309e-03, -3.7835e-02],
[-5.3291e-02, 3.5150e-02, -6.4878e-02, -7.4903e-02, -5.5693e-02],
[ 3.2479e-02, 6.5471e-02, 4.4178e-02, -5.4057e-02, -7.9336e-02],
[-6.8247e-02, 6.5587e-02, -8.7773e-04, -3.5329e-02, -3.8033e-02],
[-3.8157e-02, 4.2570e-02, 7.6967e-02, -4.9785e-02, -2.8498e-02]],
[[ 4.3747e-02, 5.4973e-02, -2.6460e-02, -1.5235e-02, 7.8260e-02],
[ 1.8068e-02, 1.9185e-02, -6.4915e-02, 7.4767e-03, 6.4246e-02],
[-6.3240e-02, -7.0417e-03, -6.4609e-02, -3.5764e-03, 1.0317e-02],
[-4.5486e-02, 7.3897e-02, -2.8131e-02, -8.0268e-02, 1.2427e-02],
[ 2.6636e-02, 5.3320e-02, 8.0113e-02, -2.4988e-02, -3.3729e-02]],
[[-3.0252e-02, 2.0032e-02, 2.2845e-02, -3.2871e-02, -7.2805e-02],
[-1.2501e-02, 1.2056e-02, 1.7524e-02, 3.1636e-02, -3.6845e-02],
[ 2.2584e-02, 1.5433e-02, -8.1113e-02, 4.8879e-03, -4.5628e-02],
[-6.5354e-03, -1.2652e-02, -6.9698e-02, 6.7797e-02, 6.2843e-02],
[ 1.2404e-02, 1.8414e-02, -4.5958e-02, -6.2785e-03, 4.0931e-02]],
[[-1.8430e-02, -2.1989e-02, -8.0820e-03, 7.2810e-02, -9.3746e-03],
[-7.0612e-02, 5.1172e-02, -7.0489e-03, -4.5064e-02, 3.5534e-02],
[-3.2953e-02, 3.0727e-02, 4.8882e-02, -3.9965e-02, -2.3211e-02],
[ 6.9945e-02, 2.6839e-02, 3.3160e-03, -2.7622e-02, 7.6837e-02],
[-1.7656e-02, -4.1508e-03, 3.8821e-02, 3.6178e-02, 3.2764e-02]],
[[-3.1492e-02, -7.9136e-02, -4.2417e-02, 3.8367e-02, -7.3788e-02],
[-4.9922e-02, 2.4759e-02, 1.3668e-03, 7.1963e-02, -6.7763e-02],
[-4.6869e-02, 6.6676e-02, 3.5155e-02, -1.8898e-02, -3.7238e-02],
[-1.2625e-02, -5.9552e-02, -6.8551e-02, 5.6212e-02, -3.5397e-02],
[-3.0192e-02, -5.3948e-02, 4.1616e-03, -1.9126e-02, -2.7772e-02]]],
[[[ 7.5304e-02, 4.8642e-02, 6.6816e-03, -4.8956e-02, 7.0525e-03],
[ 5.3506e-02, -6.3276e-02, 3.2309e-02, -6.3837e-02, -7.6400e-03],
[ 8.0993e-02, 5.0589e-02, -4.4620e-02, -1.7407e-02, -4.0847e-02],
[-7.9109e-02, -5.0230e-02, -6.1707e-02, 3.5480e-02, 9.4730e-03],
[-4.0758e-02, 3.8149e-02, -2.6608e-02, 6.5586e-02, 6.7719e-02]],
[[ 3.7999e-02, 5.1426e-02, 4.6990e-02, -8.4217e-04, -7.8228e-02],
[ 7.4991e-02, -1.3921e-03, -5.2225e-02, -3.7669e-02, -1.3416e-02],
[-6.8178e-02, -3.8191e-02, -4.8837e-02, 5.6142e-02, 8.0981e-02],
[-4.9641e-02, 2.5925e-02, 3.2090e-02, 1.2274e-02, -7.0899e-02],
[ 2.0209e-02, 7.2872e-02, -2.2630e-02, -6.1033e-02, 4.0857e-02]],
[[-6.5934e-02, 6.1233e-02, 3.0293e-02, -7.4631e-02, 7.8141e-02],
[-1.4982e-02, 2.9501e-02, -6.2282e-02, 2.6212e-02, -2.5934e-02],
[ 3.0596e-02, 2.5406e-02, 4.6264e-02, -2.3452e-02, -1.9204e-02],
[ 8.5599e-03, 7.3909e-02, -6.4926e-02, -8.6878e-03, -3.7090e-02],
[ 3.4250e-02, -5.0255e-02, -7.8042e-02, -5.1085e-02, -8.0715e-02]],
[[ 2.6317e-02, -3.6153e-02, -5.3455e-02, 2.1931e-03, 6.2524e-04],
[-4.6965e-02, 3.9974e-02, -1.2460e-02, -4.2091e-02, -4.6497e-02],
[-3.4564e-02, -1.1512e-02, 2.1072e-02, -6.0303e-02, 1.7830e-02],
[-6.8750e-02, 2.9228e-02, 2.5533e-02, 5.8319e-02, -2.6735e-02],
[ 3.2792e-02, 3.2718e-02, 5.9794e-02, -7.6940e-02, -4.0923e-02]],
[[-7.0395e-02, -7.5975e-02, -2.7344e-02, -2.3934e-02, -2.9200e-02],
[ 5.4427e-02, -6.5287e-02, 3.4746e-02, 1.0117e-02, -6.4013e-02],
[-7.9322e-02, 7.8159e-02, -7.0473e-02, 4.4684e-02, 2.1939e-02],
[ 5.5900e-02, -6.5708e-02, 3.3804e-02, -1.8570e-02, -7.0815e-02],
[ 9.0783e-03, -7.2442e-02, 7.7275e-02, 5.2036e-03, 2.7754e-02]],
[[-3.5030e-02, -1.2288e-02, 6.1587e-02, 1.0093e-02, -3.9849e-02],
[ 3.1310e-02, 3.2095e-02, -7.3972e-02, 6.0673e-02, 6.6248e-02],
[ 7.5880e-02, 7.9613e-02, -1.4684e-02, 6.6668e-02, 1.1897e-02],
[-6.6942e-02, 6.9498e-02, 7.3219e-02, -5.7326e-02, 7.9767e-02],
[-2.2455e-02, 2.7784e-02, -5.6006e-02, -2.8644e-02, -4.7902e-02]]],
...,
[[[ 5.4641e-02, -2.8195e-02, 1.2863e-02, 6.9362e-02, -6.5228e-02],
[-1.4326e-02, -2.4860e-02, 6.7933e-02, 6.0904e-03, -5.2127e-02],
[-3.5126e-02, 1.9526e-02, -4.2415e-02, 7.7512e-03, 6.7621e-02],
[-7.7149e-02, -3.9294e-02, 2.8953e-02, 2.8484e-02, -5.9953e-02],
[ 3.0683e-02, -7.1082e-02, -1.7986e-02, 1.3298e-02, -6.4964e-02]],
[[ 3.5600e-02, -6.4155e-02, 4.7039e-02, -2.0131e-02, -2.7915e-02],
[ 2.9918e-02, 6.5944e-03, -6.6870e-02, -6.3787e-02, -4.9677e-02],
[-3.6079e-02, -2.8304e-02, 2.9721e-02, 2.8190e-02, -5.0218e-02],
[ 6.4923e-02, -4.9635e-02, -3.6667e-03, 7.9379e-02, 2.5979e-02],
[-4.8337e-02, 7.7505e-02, -7.3112e-02, 2.4510e-02, 2.5683e-02]],
[[ 2.8887e-03, -2.1671e-02, -5.2347e-02, -6.3329e-02, -4.0586e-02],
[-1.2757e-02, 3.3395e-02, -7.8268e-02, 7.3369e-02, 5.0369e-04],
[ 2.6221e-02, -2.9271e-02, -6.5565e-02, -1.6796e-02, -4.9055e-02],
[-5.8221e-02, -4.2509e-02, 4.6818e-06, 2.6047e-05, 4.1964e-02],
[ 1.0361e-02, -1.0747e-02, -5.5872e-02, -4.5506e-02, -2.9223e-02]],
[[-1.9352e-02, -8.0087e-02, -3.3809e-03, 3.9983e-02, 6.5648e-02],
[-2.5674e-02, 5.8915e-02, 1.8416e-02, 5.8460e-02, 3.2707e-02],
[-6.6357e-02, 6.9795e-02, 8.6752e-03, 5.9294e-02, 1.7985e-02],
[-6.5379e-02, 2.3563e-02, 5.0532e-02, 3.5488e-03, -4.1146e-02],
[ 8.1383e-02, 5.7224e-02, -7.2400e-02, -4.0180e-02, -6.1370e-02]],
[[ 5.9150e-02, 5.6013e-02, -4.5474e-02, -4.0693e-02, 4.2932e-02],
[-1.3553e-02, 4.4707e-02, -2.7249e-02, 2.3061e-02, 1.9638e-02],
[ 2.1247e-02, -6.3221e-02, 5.1882e-02, -1.6282e-02, 6.9770e-02],
[-7.0485e-02, 7.4524e-02, 4.4509e-02, -6.5970e-02, 2.9617e-02],
[-8.1458e-02, -4.9716e-02, 1.2315e-02, -2.0425e-02, 3.0172e-02]],
[[ 5.2212e-02, -3.5905e-02, 2.5783e-02, -6.0258e-02, 3.5215e-02],
[-7.5870e-02, -1.5704e-02, 3.3627e-02, -4.0729e-02, 5.7335e-02],
[-3.5374e-02, -7.5164e-02, -7.3468e-02, -1.1014e-02, 1.6214e-02],
[-3.1993e-02, -2.4012e-02, -1.6525e-03, -6.9434e-02, 2.8824e-02],
[-2.4923e-02, 7.8550e-02, 4.5400e-02, 2.7779e-02, -6.5854e-02]]],
[[[ 2.3195e-02, -6.8559e-02, -1.9293e-02, -3.7088e-02, -7.3186e-02],
[ 7.8055e-02, 5.0381e-03, 3.0678e-02, -5.8232e-02, 2.8428e-02],
[-2.2133e-02, -1.6136e-03, 6.5804e-02, 3.9714e-03, -3.9261e-02],
[ 2.5493e-02, -1.4515e-02, 3.1299e-02, -1.6629e-02, -2.5878e-02],
[-2.4748e-02, -6.8695e-02, 4.8038e-02, 1.7510e-02, -2.4795e-02]],
[[ 2.7344e-02, 5.4179e-02, 1.8617e-02, 6.7468e-02, 4.8763e-02],
[ 3.7600e-02, 3.9927e-02, 5.1062e-02, 2.1710e-02, -3.2169e-02],
[-3.8513e-02, -4.6700e-02, -3.3343e-02, 5.7257e-02, 7.1398e-02],
[ 3.1596e-02, -6.1682e-02, -1.1294e-02, -4.6606e-02, -1.9235e-02],
[-5.1762e-02, 4.1756e-03, 5.5901e-02, 5.0582e-02, -3.5234e-02]],
[[ 6.8751e-02, 1.9294e-02, 9.9260e-04, 4.8577e-02, 1.1296e-02],
[-2.5931e-03, 5.6043e-02, -3.9379e-02, -1.5890e-02, 2.7560e-02],
[-6.1309e-02, 4.4243e-02, -6.8550e-02, -7.5816e-02, 6.4328e-02],
[-3.5933e-02, 1.5707e-02, -4.1360e-03, 2.3218e-02, -5.3996e-02],
[-5.8497e-02, 2.2945e-02, -1.6730e-02, -4.9801e-02, 4.5134e-02]],
[[ 4.3283e-02, 1.4086e-02, -7.7765e-03, -2.4735e-02, -6.1307e-02],
[-3.7167e-02, 1.0755e-05, -3.5806e-02, -2.9059e-02, 7.9799e-02],
[-2.3940e-02, 3.7716e-02, 2.2327e-02, 4.3423e-02, -4.9689e-02],
[-3.8642e-02, -4.3529e-02, -2.7830e-02, -4.9579e-02, 6.6404e-02],
[ 1.3762e-02, -2.3933e-02, 7.1659e-02, -1.3726e-02, -8.0242e-02]],
[[-1.5402e-02, 1.8298e-02, -5.1471e-02, 2.4580e-02, 9.6023e-03],
[ 6.1876e-02, 7.8261e-02, 2.6394e-02, -7.8227e-02, -7.6062e-02],
[ 6.7169e-02, -7.8952e-03, -7.6834e-02, -7.2395e-02, -8.1512e-02],
[ 2.4895e-02, 7.4719e-02, -6.9676e-02, -2.0183e-02, 4.7940e-02],
[-1.6931e-02, 6.4322e-02, 4.9096e-02, 6.7067e-02, -5.1128e-02]],
[[-4.2134e-03, 7.9587e-02, -7.7337e-02, -1.6919e-02, -4.4513e-02],
[ 4.5003e-02, 2.9848e-02, -3.2239e-02, -5.3997e-02, 3.4833e-02],
[-4.6084e-02, 7.7325e-02, -8.2341e-04, -2.9711e-02, 4.7059e-02],
[ 7.1990e-02, -1.8925e-02, 6.9833e-02, -3.8232e-02, -5.3586e-02],
[ 5.6777e-02, 5.4212e-02, -7.0351e-02, 7.8116e-02, -5.8073e-03]]],
[[[-3.9872e-02, 2.2878e-02, -5.4838e-02, -7.8741e-02, -2.4075e-02],
[ 5.5670e-02, -7.5194e-02, -2.3993e-02, -1.3565e-02, -7.6118e-02],
[ 5.9835e-02, 7.7078e-02, -1.9101e-02, 3.7423e-02, 5.8969e-02],
[-7.6931e-02, 5.4068e-02, 7.6462e-02, 6.0935e-02, 6.1393e-02],
[-3.0153e-02, -6.9821e-02, -7.9367e-02, -6.8787e-02, 4.8573e-02]],
[[-7.7210e-04, 1.2697e-02, -7.1333e-02, -4.3644e-02, -3.0627e-02],
[ 6.4280e-02, -6.3660e-02, 5.7517e-02, 3.5869e-02, -2.4693e-02],
[ 4.2786e-03, -8.1200e-02, 6.9931e-02, -2.5703e-04, 1.7692e-02],
[ 8.5987e-03, -1.2595e-02, 7.9205e-02, 3.0073e-02, -3.2985e-02],
[-6.3697e-02, 3.2692e-02, -1.9431e-02, 5.3542e-02, -2.0049e-02]],
[[ 4.7488e-02, 3.1486e-02, 4.3938e-02, 3.8207e-02, -6.3004e-02],
[ 7.6382e-02, 1.8666e-02, 1.0028e-02, 6.2085e-02, 5.3552e-02],
[-3.0010e-02, 2.7386e-02, -2.2148e-02, -5.4034e-02, -2.1415e-02],
[-3.2287e-02, -4.1362e-02, 1.2052e-02, -6.5838e-02, -4.6819e-02],
[-6.8102e-02, 5.9098e-02, 2.8529e-02, -5.3848e-02, 2.3559e-02]],
[[ 2.5513e-02, 3.7517e-02, 5.5636e-02, 4.3730e-02, -3.5048e-02],
[-5.4454e-02, 7.0706e-02, -5.7952e-02, 2.3890e-02, 3.0251e-02],
[ 2.0294e-02, -6.2255e-02, -7.7577e-02, -6.8416e-02, -4.8070e-02],
[ 5.3928e-02, -6.0171e-02, 4.9991e-02, 4.6665e-02, -1.5579e-02],
[ 1.9901e-02, -6.1094e-02, -1.4091e-02, -6.6292e-02, 1.2545e-02]],
[[ 7.3009e-02, 7.1030e-02, -3.5882e-02, -5.9879e-02, 2.1529e-02],
[-2.7738e-02, -4.8476e-02, -3.5715e-03, -5.2242e-03, 4.8341e-03],
[-4.1100e-02, 2.7022e-02, -5.5728e-02, 5.0925e-02, 2.2531e-02],
[ 6.5409e-02, 2.5243e-02, 5.5194e-02, 5.0815e-02, 3.7556e-02],
[ 4.0211e-02, 3.1016e-02, -2.9596e-02, -4.3925e-03, -4.0317e-02]],
[[ 4.9369e-02, 4.1262e-02, 7.0892e-02, -7.3260e-02, 4.2668e-02],
[ 3.7235e-02, 3.5402e-02, 3.3255e-02, 5.4474e-02, 5.3561e-02],
[-1.7112e-02, -1.1525e-02, 1.9306e-02, 9.0656e-04, 2.4812e-02],
[-1.0841e-02, -3.1940e-02, 6.6983e-02, -1.3595e-02, 7.1947e-02],
[-1.4649e-02, 3.0190e-02, -4.8740e-02, -8.1647e-02, -4.3005e-02]]]],
requires_grad=True)
tensor: tensor([[[[ 1.8883e-02, 6.8035e-02, 4.9171e-02, 9.7844e-02, 8.8935e-02],
[-1.8658e-02, -6.0121e-02, 1.3134e-01, 1.0360e-01, 1.0962e-01],
[-4.9652e-02, 9.1593e-03, -1.0749e-01, -2.3633e-02, -6.8945e-02],
[-7.2921e-02, 5.4190e-03, -3.7902e-02, -9.1251e-02, -9.6080e-02],
[-9.6740e-02, 3.3653e-02, 1.2174e-01, -2.1966e-02, -6.8750e-02]],
[[-5.1564e-02, 9.9327e-02, -9.0770e-02, 1.1558e-01, -6.4334e-02],
[-3.3721e-02, 1.3854e-01, 7.1261e-02, 2.3334e-02, 6.2038e-03],
[ 4.1357e-02, 4.7015e-02, 1.1618e-01, 1.1209e-01, -1.5182e-02],
[-4.2263e-02, -5.2076e-02, 1.2271e-01, 6.1315e-02, -3.5989e-02],
[ 2.4518e-02, -7.3358e-02, 1.1719e-01, 6.8904e-02, 6.7703e-02]],
[[-9.5581e-02, -2.1765e-02, 2.3421e-03, 1.1011e-01, 2.8519e-02],
[ 9.2380e-02, -8.1339e-02, 4.1274e-02, 8.2023e-02, 4.4325e-02],
[ 1.2266e-01, -8.0220e-02, 1.5860e-02, 2.0438e-02, 9.9168e-02],
[-5.8377e-02, 3.0828e-02, -4.9333e-02, 2.4652e-02, -9.0897e-02],
[-2.0278e-03, -9.2374e-02, 7.0469e-02, 4.6639e-02, 1.4054e-02]],
[[ 6.7318e-02, -1.4048e-01, -6.6024e-02, -7.5259e-02, -7.3505e-03],
[-5.9789e-03, 8.2757e-02, -3.2293e-03, -6.3011e-02, -3.5699e-02],
[-4.6309e-03, 2.9589e-02, -8.0518e-02, -3.5417e-02, -8.1604e-02],
[-4.7370e-02, -1.8663e-02, 1.4938e-02, -2.6669e-02, 4.6225e-02],
[ 1.0167e-01, -4.0576e-03, -6.3230e-02, 1.9653e-02, 2.6441e-03]],
[[ 7.7611e-02, 9.2593e-02, 2.8293e-02, -8.5795e-02, 2.8690e-03],
[ 2.5574e-03, 1.1727e-01, -1.4177e-01, -8.9946e-02, -8.8752e-02],
[ 1.1321e-01, -3.1915e-03, -1.5451e-01, -1.1661e-01, 3.4048e-02],
[-5.9230e-02, 3.2312e-02, 6.5634e-02, 6.5836e-02, 6.9050e-02],
[ 2.9482e-02, 5.5532e-02, 1.4458e-02, -2.6399e-02, 1.5079e-01]],
[[ 8.1164e-02, -1.5860e-02, -6.3951e-02, -1.4953e-01, 2.6079e-02],
[ 2.9781e-03, 3.3932e-02, -7.3611e-02, -3.7866e-02, 6.0404e-03],
[-9.4234e-02, 6.2938e-03, -1.1837e-01, -2.1665e-02, 7.8637e-02],
[-9.3179e-02, -4.4013e-02, 4.6450e-02, 1.6881e-02, -7.4646e-03],
[-1.1145e-02, -3.3850e-02, -8.3531e-02, -7.9519e-02, 1.1401e-01]]],
[[[ 3.2203e-02, 2.0910e-02, -9.8143e-02, 4.6689e-02, -3.8075e-02],
[ 1.1719e-01, 5.9686e-03, 1.1960e-02, -3.0230e-03, -5.2190e-02],
[ 1.9934e-03, -4.0891e-02, -3.4388e-02, 1.7186e-01, -2.0041e-02],
[ 6.2537e-02, -4.8928e-02, 3.7143e-02, -7.1022e-02, -5.8980e-02],
[-1.8138e-02, -1.2401e-03, 9.8512e-04, 1.2971e-02, -1.6621e-02]],
[[-6.7650e-02, -1.2190e-01, -2.9892e-02, -7.9844e-03, 7.0830e-02],
[ 1.9891e-02, 7.0736e-02, -5.9314e-02, -8.9657e-02, 5.4045e-02],
[ 6.0116e-02, 1.1283e-01, 8.1388e-02, -3.8265e-03, -8.1494e-02],
[-1.2480e-01, 4.0550e-02, -1.2108e-02, 3.0710e-02, -6.6255e-03],
[-3.1892e-02, 4.0138e-02, 1.0084e-01, -1.1013e-01, 1.3274e-03]],
[[ 1.6757e-01, 4.7177e-02, 3.3340e-02, -3.8067e-02, 8.9873e-02],
[-3.3208e-02, -1.9158e-03, 2.2761e-02, -1.7572e-02, 1.7825e-01],
[-3.9615e-02, 3.0436e-02, -2.7566e-02, 2.7972e-02, 4.4888e-02],
[-1.1037e-01, 2.3430e-02, -5.5122e-02, -1.4218e-01, -3.7357e-02],
[ 2.8481e-02, 8.2003e-02, 7.9964e-02, -2.2991e-03, 2.2325e-02]],
[[-4.3186e-02, 3.3446e-02, 7.1732e-02, 3.7266e-02, 8.8012e-03],
[-1.2796e-02, 6.7978e-02, 7.0809e-02, 7.3238e-02, -4.9962e-02],
[-1.9777e-02, 2.5981e-02, -1.3383e-01, -4.4816e-02, -4.7046e-02],
[-2.1233e-02, 6.0056e-02, -1.1861e-01, 7.6347e-02, 5.1444e-02],
[ 2.1803e-03, 7.3725e-02, -5.2894e-02, -1.7344e-02, 9.9680e-02]],
[[-8.6873e-02, -8.9221e-02, -5.2683e-02, 1.4289e-01, -8.9839e-02],
[-7.3905e-03, 3.6452e-02, 2.7346e-02, -8.3526e-02, 3.3003e-02],
[-9.6205e-02, 1.0080e-01, 1.0827e-01, -1.0169e-01, -1.2534e-01],
[ 1.2297e-01, 4.5520e-02, -8.3773e-02, -1.2272e-02, 1.1171e-02],
[-2.2908e-02, -5.9739e-03, -3.1049e-02, 4.7375e-02, 3.4770e-03]],
[[-4.6835e-02, -1.4363e-01, -1.4266e-01, 4.5771e-02, -7.2627e-02],
[-1.0377e-01, 5.4805e-02, 2.4762e-02, 1.2288e-01, -6.7578e-02],
[-1.0513e-01, 8.9687e-02, 4.6560e-02, -2.5460e-02, -7.3801e-02],
[ 3.8986e-02, -3.2050e-02, 2.2372e-02, 3.5860e-02, 1.3424e-02],
[-4.8953e-02, -7.0792e-02, -3.7677e-02, -9.9571e-02, -3.7092e-02]]],
[[[ 2.7008e-02, 9.3211e-02, -2.5650e-02, -1.2120e-01, 8.9026e-02],
[ 3.3563e-02, 4.8096e-02, 2.4675e-02, -6.3760e-02, -1.0420e-02],
[ 4.8996e-02, 1.9733e-02, -6.8690e-03, 5.3957e-02, 2.5845e-02],
[-1.0207e-01, 5.9896e-02, -5.3257e-02, 6.9946e-02, 7.8127e-03],
[ 3.7401e-02, 2.6128e-03, -6.7030e-02, 1.2134e-01, 1.4503e-01]],
[[ 3.0161e-02, 1.6157e-02, 8.4664e-02, 4.0789e-02, -5.2141e-02],
[ 6.8237e-02, -1.7970e-02, 2.1938e-02, -3.8875e-02, -4.1224e-02],
[ 2.0714e-02, -2.9881e-02, -7.7453e-03, 1.2165e-01, 7.5875e-02],
[ 1.2965e-02, 1.2694e-02, 5.0254e-02, 3.9009e-02, -1.0212e-01],
[ 1.5359e-02, 3.5578e-03, 3.9632e-02, -9.0168e-02, 1.0299e-02]],
[[-2.0960e-02, 3.3475e-02, 4.8820e-02, -5.4706e-02, 5.5381e-02],
[ 2.4722e-02, -2.7748e-02, -2.8843e-02, -2.2392e-02, -6.9717e-02],
[-2.7829e-02, 1.4679e-03, 2.4757e-02, -8.7816e-02, 1.8764e-02],
[-1.0710e-01, 1.1481e-01, -5.6012e-02, 8.1546e-02, -9.8836e-02],
[ 2.9070e-02, -6.9273e-02, -5.0700e-02, -4.3182e-02, -8.8499e-02]],
[[ 1.5313e-02, -3.8199e-02, 3.5146e-03, -1.5000e-02, -4.4576e-02],
[-2.5990e-02, 5.2993e-03, 3.3307e-02, -1.1827e-01, -8.6286e-03],
[ 1.0449e-02, 5.4776e-02, 5.0992e-03, -1.5860e-01, -7.7440e-03],
[-3.6204e-03, -4.7962e-02, -4.2371e-02, 6.8431e-02, -5.3339e-02],
[ 4.4269e-02, 4.1502e-02, -1.2462e-02, -7.4518e-02, -7.7262e-02]],
[[ 5.1610e-03, -5.2510e-02, 5.5136e-02, -1.9081e-02, -9.9088e-02],
[ 8.0757e-02, -3.4951e-02, 4.9554e-02, -3.2164e-02, 5.6647e-02],
[-8.6125e-02, 6.2141e-02, -7.2851e-02, 6.7731e-02, 1.7170e-02],
[ 5.2774e-02, -1.6591e-01, 3.4490e-02, 5.7166e-03, -6.3123e-02],
[ 4.2355e-02, -5.6360e-02, 7.3484e-02, 3.0939e-02, 3.8095e-02]],
[[ 1.5646e-03, 1.7834e-02, 7.3988e-03, -1.5700e-02, -3.9332e-02],
[ 1.0043e-02, 5.1222e-02, -9.1322e-02, 5.1683e-02, 1.2517e-02],
[ 8.1686e-02, -1.2309e-02, -2.5122e-02, 7.3020e-02, 5.9552e-02],
[-5.6445e-02, 2.9868e-02, 2.9082e-02, -1.0408e-02, 1.5398e-01],
[ 5.2559e-02, -1.1959e-02, -1.4521e-02, 9.7013e-02, -9.5779e-02]]],
...,
[[[ 1.0793e-01, -2.4074e-02, 3.9709e-02, 1.0110e-01, -1.1722e-01],
[-1.4333e-02, -1.4521e-02, 2.3373e-03, 6.7708e-02, -4.9924e-02],
[ 8.0992e-02, 6.2219e-03, -9.5247e-02, -1.0383e-02, 1.2116e-01],
[-1.3052e-01, 3.1068e-02, -2.0858e-02, 1.0203e-01, -1.2493e-02],
[-4.8291e-02, -3.7350e-02, -2.5973e-02, -5.2889e-02, -1.4211e-01]],
[[ 1.4048e-02, -6.0941e-02, 1.4314e-01, -4.0112e-02, 2.0749e-03],
[-5.3904e-02, 4.6111e-02, -9.0083e-02, -6.0871e-02, 2.1673e-02],
[-6.1354e-02, -2.4384e-02, 3.1761e-02, 3.5867e-02, -5.3144e-02],
[ 9.2168e-02, 1.2534e-02, -2.1070e-02, 1.7463e-01, 4.6518e-02],
[-8.0494e-02, 8.2646e-02, -1.1117e-01, 2.9695e-02, -6.1277e-02]],
[[-6.9924e-02, -3.8576e-02, -8.7181e-02, -1.2090e-01, -7.3311e-02],
[ 6.9342e-02, 9.1480e-02, -8.6245e-02, 1.2549e-01, 1.5314e-02],
[-5.9550e-02, -6.7689e-02, -3.1209e-02, -6.1005e-02, -4.3274e-02],
[-8.6596e-02, 2.7538e-02, -8.3665e-03, 6.6909e-02, 4.2178e-02],
[ 4.2158e-02, 3.3340e-03, 9.0755e-03, 1.9898e-03, -2.1537e-02]],
[[-1.8037e-02, -4.7885e-03, 7.9153e-03, 3.5667e-03, 6.1726e-02],
[-9.7950e-02, 2.7148e-02, 5.6938e-02, 2.7169e-02, 5.3373e-02],
[-9.3417e-02, 7.2626e-02, -2.6297e-02, 4.8845e-03, 2.5102e-02],
[-1.8309e-02, -2.4018e-04, 7.9144e-02, 9.0640e-02, -2.4433e-02],
[ 1.6471e-01, 9.4758e-02, -1.0681e-01, -1.2206e-01, -4.2883e-02]],
[[-3.1497e-02, -4.8369e-03, -6.9565e-02, 3.7852e-02, 8.5132e-02],
[-1.4033e-02, 9.7230e-02, -6.4714e-02, 5.6117e-02, 8.1022e-02],
[ 5.0067e-03, -8.6493e-02, -2.8838e-02, 2.2062e-02, 1.9418e-02],
[-5.2541e-05, 1.3477e-01, 1.5361e-02, -3.6824e-02, 3.4045e-02],
[-1.7272e-01, 2.9210e-03, 1.5939e-02, -1.3572e-02, 4.1629e-02]],
[[ 4.4768e-02, -3.5056e-02, 3.7457e-02, -9.0991e-02, 4.8750e-02],
[-1.3768e-02, -5.2716e-02, 4.2374e-02, -6.0435e-02, 6.5938e-02],
[ 4.1509e-02, -5.4461e-02, -1.1085e-01, 1.2433e-02, -6.0321e-02],
[-8.3952e-03, -2.5181e-02, 8.1283e-03, -2.1729e-02, 1.3753e-01],
[ 3.2617e-03, 1.1116e-01, -5.8501e-02, 7.8371e-03, -4.6730e-03]]],
[[[ 7.5170e-03, -6.9893e-02, -7.2599e-02, -3.9132e-02, -8.8662e-02],
[ 8.1816e-02, 1.0796e-02, 4.0675e-02, -1.0288e-01, 1.6084e-02],
[ 5.3820e-02, -3.8369e-03, -3.4138e-02, -6.7229e-02, 3.4613e-02],
[ 3.6387e-02, -6.1193e-02, 1.1059e-01, 4.5436e-02, 2.8397e-02],
[-1.6659e-02, -9.3031e-02, 2.7387e-02, 9.7177e-02, -4.2562e-02]],
[[-2.0895e-02, 6.4056e-02, 9.3401e-02, 7.4680e-02, 4.2360e-02],
[ 3.6859e-02, 1.3341e-03, -1.7436e-02, 1.3251e-02, -4.4198e-02],
[-1.3440e-02, -5.0658e-02, -2.3869e-02, -1.0197e-03, 1.3315e-01],
[ 2.6638e-02, -1.4311e-01, -2.0728e-02, -3.4138e-02, -6.8038e-03],
[-2.3564e-02, -1.3118e-02, 6.4697e-02, 1.9203e-02, -5.7187e-02]],
[[ 7.9665e-02, -1.1841e-02, 8.7050e-03, 3.5707e-02, -7.4566e-02],
[ 1.0561e-01, 2.4171e-02, -5.8087e-02, -3.1631e-02, 1.4183e-01],
[-2.9622e-02, 6.5245e-02, -8.4141e-02, -9.6315e-03, 7.5537e-02],
[-1.1621e-02, -2.8678e-02, -1.8603e-02, -3.6197e-02, -6.8541e-02],
[-5.3699e-02, 8.2185e-05, 1.3339e-02, -5.6374e-02, 9.1662e-02]],
[[ 5.0422e-02, -3.0555e-02, -1.1994e-02, -2.8864e-02, -4.7944e-02],
[-9.2082e-02, -2.0114e-02, 1.1080e-03, -8.7728e-02, 8.1596e-03],
[-5.4058e-03, 4.7326e-02, -1.7694e-02, 1.1631e-02, -1.4565e-01],
[-6.5642e-02, -5.3809e-03, 4.6345e-02, -1.3860e-02, 5.7886e-02],
[-1.1082e-02, -3.7631e-02, 9.4582e-02, -9.6432e-02, -3.4886e-02]],
[[-8.1977e-03, 4.6430e-02, -7.8213e-04, 3.4898e-02, 2.3575e-03],
[ 1.0584e-01, 9.4443e-02, -1.8790e-02, -1.0291e-01, -1.2920e-01],
[ 2.9621e-02, 3.0077e-02, -9.9012e-02, -6.2667e-02, -1.1825e-01],
[-1.2209e-03, 1.8486e-02, -1.2279e-01, -8.0225e-02, -4.3227e-02],
[-4.5238e-02, 4.2571e-02, 3.5998e-02, 4.2050e-02, -1.2023e-01]],
[[ 4.7820e-02, 1.6323e-01, -9.7011e-02, 1.2791e-02, -7.0172e-02],
[ 4.1511e-02, 8.5733e-02, -5.9874e-03, -1.0885e-02, 6.1121e-02],
[-1.4934e-01, 7.1932e-02, 2.5868e-02, -1.3355e-02, 1.3271e-01],
[ 2.7480e-03, -2.4276e-02, 4.9873e-02, -1.1099e-02, -1.0676e-01],
[ 6.6207e-02, 1.6302e-02, -1.1455e-01, 1.0844e-02, -4.8347e-02]]],
[[[-9.9760e-02, 1.9150e-02, -1.4017e-01, -1.1585e-01, -3.8237e-03],
[-3.0495e-02, -1.0176e-01, -1.6690e-02, -2.9941e-02, -8.9302e-03],
[ 8.2867e-02, 2.4127e-02, 9.1499e-02, 7.4052e-02, 1.6828e-03],
[-1.1307e-01, 6.4954e-05, 1.1348e-01, 1.8891e-02, 2.4209e-03],
[-3.4641e-03, -1.0385e-01, -1.3858e-01, -8.9685e-02, 7.7434e-02]],
[[-4.5025e-02, 1.1394e-01, -6.5345e-02, -3.0052e-02, -4.3911e-02],
[ 6.8716e-02, -9.5503e-02, 5.0614e-02, 7.8437e-02, 1.0309e-01],
[-2.7321e-02, -8.1157e-02, 6.0459e-02, -8.1802e-02, -6.3444e-02],
[ 9.0615e-02, -1.3963e-01, 9.8714e-02, 1.2384e-01, -1.0084e-01],
[-1.5519e-02, -6.3194e-03, -3.4565e-02, 8.4971e-02, 3.9733e-02]],
[[-2.8205e-02, 7.0934e-02, 6.6078e-02, 7.2542e-02, -9.7110e-02],
[ 7.3296e-02, 3.7896e-03, -7.2323e-02, 1.6002e-02, 2.9101e-02],
[-7.3723e-02, 1.5889e-01, -1.8016e-02, 6.1716e-02, -1.9445e-02],
[-8.9607e-02, -7.3749e-02, 9.8301e-03, -3.4523e-02, -2.1987e-02],
[-9.3157e-02, 7.1487e-02, -4.6032e-02, 5.4644e-02, 5.0541e-02]],
[[ 6.0963e-02, 8.3188e-02, 4.2527e-02, 5.3783e-02, -3.3220e-02],
[-7.2273e-02, 8.5987e-02, 1.5571e-02, 3.0730e-02, 2.9485e-02],
[-8.2758e-03, -1.0446e-01, -5.6690e-02, -9.5882e-02, -1.0161e-01],
[ 8.2187e-02, -8.8814e-02, 1.6184e-02, 4.1667e-02, 2.8868e-02],
[ 2.5378e-02, -4.8991e-02, 2.9006e-02, -9.1864e-03, -9.6781e-02]],
[[ 2.9548e-02, -1.1967e-02, -2.0336e-02, 4.9937e-02, -5.2592e-02],
[-7.7833e-02, -6.7267e-03, 6.7678e-02, 1.2339e-02, 2.1119e-02],
[-3.2693e-02, -2.8428e-02, -7.3010e-02, 7.4197e-02, 6.0370e-02],
[ 1.1135e-01, 6.2439e-03, 6.6132e-02, 8.8876e-02, 6.7757e-02],
[ 8.4736e-03, 6.1874e-02, -5.5572e-03, -3.2001e-03, 3.9575e-02]],
[[ 9.4309e-02, 6.2423e-02, 1.6870e-01, -1.4243e-01, 2.1112e-02],
[-2.2676e-02, -5.8523e-02, 4.6869e-03, 1.1016e-01, 2.4293e-02],
[-1.0934e-01, -3.8908e-02, -3.0408e-02, -2.9326e-03, 1.8012e-02],
[-5.2711e-02, -9.1650e-02, 1.5264e-01, 3.1413e-02, 1.0127e-01],
[-5.4185e-02, -1.2086e-01, -5.9944e-02, -3.9325e-02, -5.0325e-02]]]],
grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([ 0.0382, -0.0121, -0.0262, 0.0070, 0.0597, 0.0315, -0.0379, -0.0624,
-0.0017, -0.0016, 0.0567, 0.0240, 0.0155, 0.0450, -0.0316, -0.0051],
requires_grad=True)
tensor: tensor([ 0.0458, -0.0146, 0.0629, 0.0288, 0.0599, 0.0193, -0.0135, -0.0289,
0.0030, -0.0048, -0.0006, -0.0199, 0.0055, 0.0746, -0.0526, 0.0112],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[[[0., 0., 0., 0., 0.],
[-0., -0., 0., 0., 0.],
[0., 0., -0., -0., -0.],
[-0., 0., 0., -0., -0.],
[-0., 0., 0., -0., -0.]],
[[-0., 0., -0., -0., -0.],
[-0., 0., 0., -0., 0.],
[0., -0., -0., 0., -0.],
[0., -0., 0., 0., -0.],
[0., -0., 0., 0., 0.]],
[[-0., -0., 0., 0., 0.],
[0., -0., 0., 0., -0.],
[0., -0., -0., 0., 0.],
[-0., 0., -0., 0., -0.],
[-0., -0., 0., -0., -0.]],
[[0., -0., 0., -0., 0.],
[-0., 0., 0., 0., -0.],
[-0., 0., -0., 0., 0.],
[-0., -0., 0., -0., 0.],
[0., -0., -0., 0., -0.]],
[[0., 0., 0., -0., -0.],
[-0., 0., -0., -0., -0.],
[0., -0., -0., -0., -0.],
[-0., -0., 0., 0., -0.],
[0., 0., 0., -0., 0.]],
[[0., -0., -0., -0., 0.],
[0., 0., -0., -0., -0.],
[-0., 0., -0., 0., 0.],
[-0., -0., 0., -0., -0.],
[-0., 0., -0., -0., 0.]]],
[[[0., 0., -0., -0., -0.],
[-0., -0., -0., 0., -0.],
[0., -0., -0., 0., -0.],
[-0., 0., 0., -0., 0.],
[0., -0., -0., 0., -0.]],
[[-0., -0., -0., 0., -0.],
[-0., 0., -0., -0., -0.],
[0., 0., 0., -0., -0.],
[-0., 0., -0., -0., -0.],
[-0., 0., 0., -0., -0.]],
[[0., 0., -0., -0., 0.],
[0., 0., -0., 0., 0.],
[-0., -0., -0., -0., 0.],
[-0., 0., -0., -0., 0.],
[0., 0., 0., -0., -0.]],
[[-0., 0., 0., -0., -0.],
[-0., 0., 0., 0., -0.],
[0., 0., -0., 0., -0.],
[-0., -0., -0., 0., 0.],
[0., 0., -0., -0., 0.]],
[[-0., -0., -0., 0., -0.],
[-0., 0., -0., -0., 0.],
[-0., 0., 0., -0., -0.],
[0., 0., 0., -0., 0.],
[-0., -0., 0., 0., 0.]],
[[-0., -0., -0., 0., -0.],
[-0., 0., 0., 0., -0.],
[-0., 0., 0., -0., -0.],
[-0., -0., -0., 0., -0.],
[-0., -0., 0., -0., -0.]]],
[[[0., 0., 0., -0., 0.],
[0., -0., 0., -0., -0.],
[0., 0., -0., -0., -0.],
[-0., -0., -0., 0., 0.],
[-0., 0., -0., 0., 0.]],
[[0., 0., 0., -0., -0.],
[0., -0., -0., -0., -0.],
[-0., -0., -0., 0., 0.],
[-0., 0., 0., 0., -0.],
[0., 0., -0., -0., 0.]],
[[-0., 0., 0., -0., 0.],
[-0., 0., -0., 0., -0.],
[0., 0., 0., -0., -0.],
[0., 0., -0., -0., -0.],
[0., -0., -0., -0., -0.]],
[[0., -0., -0., 0., 0.],
[-0., 0., -0., -0., -0.],
[-0., -0., 0., -0., 0.],
[-0., 0., 0., 0., -0.],
[0., 0., 0., -0., -0.]],
[[-0., -0., -0., -0., -0.],
[0., -0., 0., 0., -0.],
[-0., 0., -0., 0., 0.],
[0., -0., 0., -0., -0.],
[0., -0., 0., 0., 0.]],
[[-0., -0., 0., 0., -0.],
[0., 0., -0., 0., 0.],
[0., 0., -0., 0., 0.],
[-0., 0., 0., -0., 0.],
[-0., 0., -0., -0., -0.]]],
...,
[[[0., -0., 0., 0., -0.],
[-0., -0., 0., 0., -0.],
[-0., 0., -0., 0., 0.],
[-0., -0., 0., 0., -0.],
[0., -0., -0., 0., -0.]],
[[0., -0., 0., -0., -0.],
[0., 0., -0., -0., -0.],
[-0., -0., 0., 0., -0.],
[0., -0., -0., 0., 0.],
[-0., 0., -0., 0., 0.]],
[[0., -0., -0., -0., -0.],
[-0., 0., -0., 0., 0.],
[0., -0., -0., -0., -0.],
[-0., -0., 0., 0., 0.],
[0., -0., -0., -0., -0.]],
[[-0., -0., -0., 0., 0.],
[-0., 0., 0., 0., 0.],
[-0., 0., 0., 0., 0.],
[-0., 0., 0., 0., -0.],
[0., 0., -0., -0., -0.]],
[[0., 0., -0., -0., 0.],
[-0., 0., -0., 0., 0.],
[0., -0., 0., -0., 0.],
[-0., 0., 0., -0., 0.],
[-0., -0., 0., -0., 0.]],
[[0., -0., 0., -0., 0.],
[-0., -0., 0., -0., 0.],
[-0., -0., -0., -0., 0.],
[-0., -0., -0., -0., 0.],
[-0., 0., 0., 0., -0.]]],
[[[0., -0., -0., -0., -0.],
[0., 0., 0., -0., 0.],
[-0., -0., 0., 0., -0.],
[0., -0., 0., -0., -0.],
[-0., -0., 0., 0., -0.]],
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., -0.],
[-0., -0., -0., 0., 0.],
[0., -0., -0., -0., -0.],
[-0., 0., 0., 0., -0.]],
[[0., 0., 0., 0., 0.],
[-0., 0., -0., -0., 0.],
[-0., 0., -0., -0., 0.],
[-0., 0., -0., 0., -0.],
[-0., 0., -0., -0., 0.]],
[[0., 0., -0., -0., -0.],
[-0., 0., -0., -0., 0.],
[-0., 0., 0., 0., -0.],
[-0., -0., -0., -0., 0.],
[0., -0., 0., -0., -0.]],
[[-0., 0., -0., 0., 0.],
[0., 0., 0., -0., -0.],
[0., -0., -0., -0., -0.],
[0., 0., -0., -0., 0.],
[-0., 0., 0., 0., -0.]],
[[-0., 0., -0., -0., -0.],
[0., 0., -0., -0., 0.],
[-0., 0., -0., -0., 0.],
[0., -0., 0., -0., -0.],
[0., 0., -0., 0., -0.]]],
[[[-0., 0., -0., -0., -0.],
[0., -0., -0., -0., -0.],
[0., 0., -0., 0., 0.],
[-0., 0., 0., 0., 0.],
[-0., -0., -0., -0., 0.]],
[[-0., 0., -0., -0., -0.],
[0., -0., 0., 0., -0.],
[0., -0., 0., -0., 0.],
[0., -0., 0., 0., -0.],
[-0., 0., -0., 0., -0.]],
[[0., 0., 0., 0., -0.],
[0., 0., 0., 0., 0.],
[-0., 0., -0., -0., -0.],
[-0., -0., 0., -0., -0.],
[-0., 0., 0., -0., 0.]],
[[0., 0., 0., 0., -0.],
[-0., 0., -0., 0., 0.],
[0., -0., -0., -0., -0.],
[0., -0., 0., 0., -0.],
[0., -0., -0., -0., 0.]],
[[0., 0., -0., -0., 0.],
[-0., -0., -0., -0., 0.],
[-0., 0., -0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., -0., -0., -0.]],
[[0., 0., 0., -0., 0.],
[0., 0., 0., 0., 0.],
[-0., -0., 0., 0., 0.],
[-0., -0., 0., -0., 0.],
[-0., 0., -0., -0., -0.]]]])
scale: tensor([[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
...,
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[[[ 1.7231e-02, 4.3731e-02, 5.4681e-02, 3.4570e-02, 4.3578e-02],
[-2.7786e-02, -4.7320e-02, 3.1995e-02, 4.2537e-02, 7.5819e-02],
[ 8.9881e-03, 2.6359e-02, -8.3574e-03, -5.6418e-02, -3.9465e-02],
[-1.9657e-02, 5.7231e-03, 5.5066e-03, -1.4926e-02, -7.4856e-02],
[-6.2679e-02, 3.7226e-02, 8.0120e-02, -2.3663e-02, -5.3511e-02]],
[[-7.5590e-02, 6.4561e-02, -4.3921e-02, -7.7799e-03, -5.2721e-02],
[-1.0355e-02, 5.9574e-02, 1.1166e-02, -4.8245e-02, 4.1576e-02],
[ 5.6966e-02, -7.2947e-03, -1.9590e-03, 6.9857e-02, -2.1827e-02],
[ 3.2483e-02, -6.1154e-02, 7.6961e-02, 4.5617e-02, -5.6834e-02],
[ 3.4771e-02, -3.0784e-02, 4.8407e-02, 2.7358e-02, 6.5453e-02]],
[[-7.3062e-02, -2.7048e-02, 1.0655e-02, 7.5364e-02, 6.3020e-02],
[ 1.2004e-02, -3.4794e-02, 4.0299e-02, 6.4137e-02, -2.4964e-02],
[ 7.4683e-02, -2.4405e-02, -6.7041e-02, 2.7558e-02, 2.7585e-02],
[-2.3105e-02, 3.5924e-02, -1.3142e-02, 3.6770e-02, -5.1734e-02],
[-3.0248e-03, -7.3805e-02, 1.5975e-02, -3.8458e-02, -3.4536e-02]],
[[ 1.2704e-02, -6.5938e-02, 5.8697e-04, -4.0066e-03, 6.0744e-02],
[-6.6682e-02, 7.7161e-02, 8.1204e-02, 1.0834e-02, -9.7154e-03],
[-4.5248e-02, 4.1191e-02, -3.1746e-02, 1.4713e-02, 1.5446e-02],
[-6.1854e-02, -5.4387e-02, 3.8440e-02, -3.7393e-02, 7.6462e-02],
[ 6.8184e-02, -3.0699e-02, -6.0217e-02, 3.2648e-02, -2.2715e-02]],
[[ 7.6031e-02, 6.2038e-02, 2.1564e-02, -6.5555e-02, -6.4686e-02],
[-4.2324e-02, 3.6401e-02, -6.7974e-02, -3.9416e-02, -4.4166e-02],
[ 6.7690e-02, -6.9973e-02, -4.7069e-02, -7.6832e-03, -4.0938e-03],
[-2.6045e-02, -4.6521e-02, 3.4383e-03, 6.4282e-02, -2.5852e-02],
[ 3.2686e-02, 3.1602e-02, 3.5400e-02, -7.0873e-02, 7.1903e-02]],
[[ 6.0529e-02, -5.2981e-02, -3.9675e-02, -6.5714e-02, 2.8020e-02],
[ 3.5057e-02, 1.3588e-02, -3.5319e-02, -7.1891e-02, -6.0487e-02],
[-6.3268e-02, 1.4282e-02, -3.6912e-02, 3.1696e-02, 5.1382e-02],
[-7.2075e-02, -5.2067e-02, 5.7057e-02, -6.0579e-02, -5.5177e-02],
[-2.2326e-03, 6.6584e-02, -6.6819e-02, -5.9396e-02, 4.9889e-02]]],
[[[ 6.1514e-02, 3.9812e-02, -7.7776e-02, -4.9428e-02, -2.5424e-02],
[-1.3566e-02, -3.6820e-02, -1.2480e-02, 4.0800e-02, -7.4332e-02],
[ 2.8230e-02, -2.4562e-02, -2.9086e-02, 6.3475e-02, -1.5058e-02],
[-1.5993e-02, 2.6303e-02, 8.7087e-03, -5.3252e-02, 2.7997e-02],
[ 2.8034e-02, -3.0461e-02, -1.0604e-02, 4.4506e-02, -1.2910e-02]],
[[-7.8455e-02, -7.8485e-02, -5.5244e-02, 4.7309e-03, -3.7835e-02],
[-5.3291e-02, 3.5150e-02, -6.4878e-02, -7.4903e-02, -5.5693e-02],
[ 3.2479e-02, 6.5471e-02, 4.4178e-02, -5.4057e-02, -7.9336e-02],
[-6.8247e-02, 6.5587e-02, -8.7773e-04, -3.5329e-02, -3.8033e-02],
[-3.8157e-02, 4.2570e-02, 7.6967e-02, -4.9785e-02, -2.8498e-02]],
[[ 4.3747e-02, 5.4973e-02, -2.6460e-02, -1.5235e-02, 7.8260e-02],
[ 1.8068e-02, 1.9185e-02, -6.4915e-02, 7.4767e-03, 6.4246e-02],
[-6.3240e-02, -7.0417e-03, -6.4609e-02, -3.5764e-03, 1.0317e-02],
[-4.5486e-02, 7.3897e-02, -2.8131e-02, -8.0268e-02, 1.2427e-02],
[ 2.6636e-02, 5.3320e-02, 8.0113e-02, -2.4988e-02, -3.3729e-02]],
[[-3.0252e-02, 2.0032e-02, 2.2845e-02, -3.2871e-02, -7.2805e-02],
[-1.2501e-02, 1.2056e-02, 1.7524e-02, 3.1636e-02, -3.6845e-02],
[ 2.2584e-02, 1.5433e-02, -8.1113e-02, 4.8879e-03, -4.5628e-02],
[-6.5354e-03, -1.2652e-02, -6.9698e-02, 6.7797e-02, 6.2843e-02],
[ 1.2404e-02, 1.8414e-02, -4.5958e-02, -6.2785e-03, 4.0931e-02]],
[[-1.8430e-02, -2.1989e-02, -8.0820e-03, 7.2810e-02, -9.3746e-03],
[-7.0612e-02, 5.1172e-02, -7.0489e-03, -4.5064e-02, 3.5534e-02],
[-3.2953e-02, 3.0727e-02, 4.8882e-02, -3.9965e-02, -2.3211e-02],
[ 6.9945e-02, 2.6839e-02, 3.3160e-03, -2.7622e-02, 7.6837e-02],
[-1.7656e-02, -4.1508e-03, 3.8821e-02, 3.6178e-02, 3.2764e-02]],
[[-3.1492e-02, -7.9136e-02, -4.2417e-02, 3.8367e-02, -7.3788e-02],
[-4.9922e-02, 2.4759e-02, 1.3668e-03, 7.1963e-02, -6.7763e-02],
[-4.6869e-02, 6.6676e-02, 3.5155e-02, -1.8898e-02, -3.7238e-02],
[-1.2625e-02, -5.9552e-02, -6.8551e-02, 5.6212e-02, -3.5397e-02],
[-3.0192e-02, -5.3948e-02, 4.1616e-03, -1.9126e-02, -2.7772e-02]]],
[[[ 7.5304e-02, 4.8642e-02, 6.6816e-03, -4.8956e-02, 7.0525e-03],
[ 5.3506e-02, -6.3276e-02, 3.2309e-02, -6.3837e-02, -7.6400e-03],
[ 8.0993e-02, 5.0589e-02, -4.4620e-02, -1.7407e-02, -4.0847e-02],
[-7.9109e-02, -5.0230e-02, -6.1707e-02, 3.5480e-02, 9.4730e-03],
[-4.0758e-02, 3.8149e-02, -2.6608e-02, 6.5586e-02, 6.7719e-02]],
[[ 3.7999e-02, 5.1426e-02, 4.6990e-02, -8.4217e-04, -7.8228e-02],
[ 7.4991e-02, -1.3921e-03, -5.2225e-02, -3.7669e-02, -1.3416e-02],
[-6.8178e-02, -3.8191e-02, -4.8837e-02, 5.6142e-02, 8.0981e-02],
[-4.9641e-02, 2.5925e-02, 3.2090e-02, 1.2274e-02, -7.0899e-02],
[ 2.0209e-02, 7.2872e-02, -2.2630e-02, -6.1033e-02, 4.0857e-02]],
[[-6.5934e-02, 6.1233e-02, 3.0293e-02, -7.4631e-02, 7.8141e-02],
[-1.4982e-02, 2.9501e-02, -6.2282e-02, 2.6212e-02, -2.5934e-02],
[ 3.0596e-02, 2.5406e-02, 4.6264e-02, -2.3452e-02, -1.9204e-02],
[ 8.5599e-03, 7.3909e-02, -6.4926e-02, -8.6878e-03, -3.7090e-02],
[ 3.4250e-02, -5.0255e-02, -7.8042e-02, -5.1085e-02, -8.0715e-02]],
[[ 2.6317e-02, -3.6153e-02, -5.3455e-02, 2.1931e-03, 6.2524e-04],
[-4.6965e-02, 3.9974e-02, -1.2460e-02, -4.2091e-02, -4.6497e-02],
[-3.4564e-02, -1.1512e-02, 2.1072e-02, -6.0303e-02, 1.7830e-02],
[-6.8750e-02, 2.9228e-02, 2.5533e-02, 5.8319e-02, -2.6735e-02],
[ 3.2792e-02, 3.2718e-02, 5.9794e-02, -7.6940e-02, -4.0923e-02]],
[[-7.0395e-02, -7.5975e-02, -2.7344e-02, -2.3934e-02, -2.9200e-02],
[ 5.4427e-02, -6.5287e-02, 3.4746e-02, 1.0117e-02, -6.4013e-02],
[-7.9322e-02, 7.8159e-02, -7.0473e-02, 4.4684e-02, 2.1939e-02],
[ 5.5900e-02, -6.5708e-02, 3.3804e-02, -1.8570e-02, -7.0815e-02],
[ 9.0783e-03, -7.2442e-02, 7.7275e-02, 5.2036e-03, 2.7754e-02]],
[[-3.5030e-02, -1.2288e-02, 6.1587e-02, 1.0093e-02, -3.9849e-02],
[ 3.1310e-02, 3.2095e-02, -7.3972e-02, 6.0673e-02, 6.6248e-02],
[ 7.5880e-02, 7.9613e-02, -1.4684e-02, 6.6668e-02, 1.1897e-02],
[-6.6942e-02, 6.9498e-02, 7.3219e-02, -5.7326e-02, 7.9767e-02],
[-2.2455e-02, 2.7784e-02, -5.6006e-02, -2.8644e-02, -4.7902e-02]]],
...,
[[[ 5.4641e-02, -2.8195e-02, 1.2863e-02, 6.9362e-02, -6.5228e-02],
[-1.4326e-02, -2.4860e-02, 6.7933e-02, 6.0904e-03, -5.2127e-02],
[-3.5126e-02, 1.9526e-02, -4.2415e-02, 7.7512e-03, 6.7621e-02],
[-7.7149e-02, -3.9294e-02, 2.8953e-02, 2.8484e-02, -5.9953e-02],
[ 3.0683e-02, -7.1082e-02, -1.7986e-02, 1.3298e-02, -6.4964e-02]],
[[ 3.5600e-02, -6.4155e-02, 4.7039e-02, -2.0131e-02, -2.7915e-02],
[ 2.9918e-02, 6.5944e-03, -6.6870e-02, -6.3787e-02, -4.9677e-02],
[-3.6079e-02, -2.8304e-02, 2.9721e-02, 2.8190e-02, -5.0218e-02],
[ 6.4923e-02, -4.9635e-02, -3.6667e-03, 7.9379e-02, 2.5979e-02],
[-4.8337e-02, 7.7505e-02, -7.3112e-02, 2.4510e-02, 2.5683e-02]],
[[ 2.8887e-03, -2.1671e-02, -5.2347e-02, -6.3329e-02, -4.0586e-02],
[-1.2757e-02, 3.3395e-02, -7.8268e-02, 7.3369e-02, 5.0369e-04],
[ 2.6221e-02, -2.9271e-02, -6.5565e-02, -1.6796e-02, -4.9055e-02],
[-5.8221e-02, -4.2509e-02, 4.6818e-06, 2.6047e-05, 4.1964e-02],
[ 1.0361e-02, -1.0747e-02, -5.5872e-02, -4.5506e-02, -2.9223e-02]],
[[-1.9352e-02, -8.0087e-02, -3.3809e-03, 3.9983e-02, 6.5648e-02],
[-2.5674e-02, 5.8915e-02, 1.8416e-02, 5.8460e-02, 3.2707e-02],
[-6.6357e-02, 6.9795e-02, 8.6752e-03, 5.9294e-02, 1.7985e-02],
[-6.5379e-02, 2.3563e-02, 5.0532e-02, 3.5488e-03, -4.1146e-02],
[ 8.1383e-02, 5.7224e-02, -7.2400e-02, -4.0180e-02, -6.1370e-02]],
[[ 5.9150e-02, 5.6013e-02, -4.5474e-02, -4.0693e-02, 4.2932e-02],
[-1.3553e-02, 4.4707e-02, -2.7249e-02, 2.3061e-02, 1.9638e-02],
[ 2.1247e-02, -6.3221e-02, 5.1882e-02, -1.6282e-02, 6.9770e-02],
[-7.0485e-02, 7.4524e-02, 4.4509e-02, -6.5970e-02, 2.9617e-02],
[-8.1458e-02, -4.9716e-02, 1.2315e-02, -2.0425e-02, 3.0172e-02]],
[[ 5.2212e-02, -3.5905e-02, 2.5783e-02, -6.0258e-02, 3.5215e-02],
[-7.5870e-02, -1.5704e-02, 3.3627e-02, -4.0729e-02, 5.7335e-02],
[-3.5374e-02, -7.5164e-02, -7.3468e-02, -1.1014e-02, 1.6214e-02],
[-3.1993e-02, -2.4012e-02, -1.6525e-03, -6.9434e-02, 2.8824e-02],
[-2.4923e-02, 7.8550e-02, 4.5400e-02, 2.7779e-02, -6.5854e-02]]],
[[[ 2.3195e-02, -6.8559e-02, -1.9293e-02, -3.7088e-02, -7.3186e-02],
[ 7.8055e-02, 5.0381e-03, 3.0678e-02, -5.8232e-02, 2.8428e-02],
[-2.2133e-02, -1.6136e-03, 6.5804e-02, 3.9714e-03, -3.9261e-02],
[ 2.5493e-02, -1.4515e-02, 3.1299e-02, -1.6629e-02, -2.5878e-02],
[-2.4748e-02, -6.8695e-02, 4.8038e-02, 1.7510e-02, -2.4795e-02]],
[[ 2.7344e-02, 5.4179e-02, 1.8617e-02, 6.7468e-02, 4.8763e-02],
[ 3.7600e-02, 3.9927e-02, 5.1062e-02, 2.1710e-02, -3.2169e-02],
[-3.8513e-02, -4.6700e-02, -3.3343e-02, 5.7257e-02, 7.1398e-02],
[ 3.1596e-02, -6.1682e-02, -1.1294e-02, -4.6606e-02, -1.9235e-02],
[-5.1762e-02, 4.1756e-03, 5.5901e-02, 5.0582e-02, -3.5234e-02]],
[[ 6.8751e-02, 1.9294e-02, 9.9260e-04, 4.8577e-02, 1.1296e-02],
[-2.5931e-03, 5.6043e-02, -3.9379e-02, -1.5890e-02, 2.7560e-02],
[-6.1309e-02, 4.4243e-02, -6.8550e-02, -7.5816e-02, 6.4328e-02],
[-3.5933e-02, 1.5707e-02, -4.1360e-03, 2.3218e-02, -5.3996e-02],
[-5.8497e-02, 2.2945e-02, -1.6730e-02, -4.9801e-02, 4.5134e-02]],
[[ 4.3283e-02, 1.4086e-02, -7.7765e-03, -2.4735e-02, -6.1307e-02],
[-3.7167e-02, 1.0755e-05, -3.5806e-02, -2.9059e-02, 7.9799e-02],
[-2.3940e-02, 3.7716e-02, 2.2327e-02, 4.3423e-02, -4.9689e-02],
[-3.8642e-02, -4.3529e-02, -2.7830e-02, -4.9579e-02, 6.6404e-02],
[ 1.3762e-02, -2.3933e-02, 7.1659e-02, -1.3726e-02, -8.0242e-02]],
[[-1.5402e-02, 1.8298e-02, -5.1471e-02, 2.4580e-02, 9.6023e-03],
[ 6.1876e-02, 7.8261e-02, 2.6394e-02, -7.8227e-02, -7.6062e-02],
[ 6.7169e-02, -7.8952e-03, -7.6834e-02, -7.2395e-02, -8.1512e-02],
[ 2.4895e-02, 7.4719e-02, -6.9676e-02, -2.0183e-02, 4.7940e-02],
[-1.6931e-02, 6.4322e-02, 4.9096e-02, 6.7067e-02, -5.1128e-02]],
[[-4.2134e-03, 7.9587e-02, -7.7337e-02, -1.6919e-02, -4.4513e-02],
[ 4.5003e-02, 2.9848e-02, -3.2239e-02, -5.3997e-02, 3.4833e-02],
[-4.6084e-02, 7.7325e-02, -8.2341e-04, -2.9711e-02, 4.7059e-02],
[ 7.1990e-02, -1.8925e-02, 6.9833e-02, -3.8232e-02, -5.3586e-02],
[ 5.6777e-02, 5.4212e-02, -7.0351e-02, 7.8116e-02, -5.8073e-03]]],
[[[-3.9872e-02, 2.2878e-02, -5.4838e-02, -7.8741e-02, -2.4075e-02],
[ 5.5670e-02, -7.5194e-02, -2.3993e-02, -1.3565e-02, -7.6118e-02],
[ 5.9835e-02, 7.7078e-02, -1.9101e-02, 3.7423e-02, 5.8969e-02],
[-7.6931e-02, 5.4068e-02, 7.6462e-02, 6.0935e-02, 6.1393e-02],
[-3.0153e-02, -6.9821e-02, -7.9367e-02, -6.8787e-02, 4.8573e-02]],
[[-7.7210e-04, 1.2697e-02, -7.1333e-02, -4.3644e-02, -3.0627e-02],
[ 6.4280e-02, -6.3660e-02, 5.7517e-02, 3.5869e-02, -2.4693e-02],
[ 4.2786e-03, -8.1200e-02, 6.9931e-02, -2.5703e-04, 1.7692e-02],
[ 8.5987e-03, -1.2595e-02, 7.9205e-02, 3.0073e-02, -3.2985e-02],
[-6.3697e-02, 3.2692e-02, -1.9431e-02, 5.3542e-02, -2.0049e-02]],
[[ 4.7488e-02, 3.1486e-02, 4.3938e-02, 3.8207e-02, -6.3004e-02],
[ 7.6382e-02, 1.8666e-02, 1.0028e-02, 6.2085e-02, 5.3552e-02],
[-3.0010e-02, 2.7386e-02, -2.2148e-02, -5.4034e-02, -2.1415e-02],
[-3.2287e-02, -4.1362e-02, 1.2052e-02, -6.5838e-02, -4.6819e-02],
[-6.8102e-02, 5.9098e-02, 2.8529e-02, -5.3848e-02, 2.3559e-02]],
[[ 2.5513e-02, 3.7517e-02, 5.5636e-02, 4.3730e-02, -3.5048e-02],
[-5.4454e-02, 7.0706e-02, -5.7952e-02, 2.3890e-02, 3.0251e-02],
[ 2.0294e-02, -6.2255e-02, -7.7577e-02, -6.8416e-02, -4.8070e-02],
[ 5.3928e-02, -6.0171e-02, 4.9991e-02, 4.6665e-02, -1.5579e-02],
[ 1.9901e-02, -6.1094e-02, -1.4091e-02, -6.6292e-02, 1.2545e-02]],
[[ 7.3009e-02, 7.1030e-02, -3.5882e-02, -5.9879e-02, 2.1529e-02],
[-2.7738e-02, -4.8476e-02, -3.5715e-03, -5.2242e-03, 4.8341e-03],
[-4.1100e-02, 2.7022e-02, -5.5728e-02, 5.0925e-02, 2.2531e-02],
[ 6.5409e-02, 2.5243e-02, 5.5194e-02, 5.0815e-02, 3.7556e-02],
[ 4.0211e-02, 3.1016e-02, -2.9596e-02, -4.3925e-03, -4.0317e-02]],
[[ 4.9369e-02, 4.1262e-02, 7.0892e-02, -7.3260e-02, 4.2668e-02],
[ 3.7235e-02, 3.5402e-02, 3.3255e-02, 5.4474e-02, 5.3561e-02],
[-1.7112e-02, -1.1525e-02, 1.9306e-02, 9.0656e-04, 2.4812e-02],
[-1.0841e-02, -3.1940e-02, 6.6983e-02, -1.3595e-02, 7.1947e-02],
[-1.4649e-02, 3.0190e-02, -4.8740e-02, -8.1647e-02, -4.3005e-02]]]])
(bias): Normal:
loc: tensor([0., -0., -0., 0., 0., 0., -0., -0., -0., -0., 0., 0., 0., 0., -0., -0.])
scale: tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([ 0.0382, -0.0121, -0.0262, 0.0070, 0.0597, 0.0315, -0.0379, -0.0624,
-0.0017, -0.0016, 0.0567, 0.0240, 0.0155, 0.0450, -0.0316, -0.0051])
)
(observed): Observed()
)
(fc1): Linear(
in_features=400, out_features=120, bias=True
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
...,
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498]],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[-3.3777e-02, 4.7079e-03, -2.3445e-02, ..., -6.8087e-03,
-4.6413e-02, -4.1360e-02],
[-1.6733e-02, 5.1208e-03, 3.9803e-02, ..., -4.5116e-02,
1.8346e-03, -1.1031e-02],
[ 2.3320e-02, 4.0388e-03, -4.6767e-02, ..., -3.7066e-02,
3.7666e-02, -1.2776e-02],
...,
[-1.3560e-05, 3.7272e-02, -1.6224e-02, ..., -3.3796e-03,
3.3060e-02, 4.5754e-02],
[-9.2607e-03, -4.9655e-02, -3.0438e-02, ..., 1.7757e-02,
-4.1499e-02, -1.2796e-02],
[-3.5203e-02, -3.5148e-03, 4.2838e-03, ..., -2.5652e-02,
-7.0994e-03, -2.2834e-02]], requires_grad=True)
tensor: tensor([[-0.0664, 0.0043, 0.0788, ..., 0.0156, 0.0026, 0.0044],
[ 0.0383, -0.0106, 0.0170, ..., -0.0362, -0.0802, -0.0022],
[ 0.0117, 0.0738, -0.0664, ..., -0.1182, -0.0065, -0.0954],
...,
[ 0.1008, -0.0091, -0.0501, ..., -0.0091, -0.0095, 0.0133],
[-0.0029, 0.0151, 0.0108, ..., 0.0052, -0.0712, -0.0245],
[-0.0082, 0.0598, 0.0241, ..., -0.1353, 0.0168, -0.0179]],
grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([-0.0091, -0.0222, -0.0435, -0.0427, 0.0332, -0.0145, 0.0444, -0.0225,
-0.0252, -0.0384, -0.0325, -0.0306, 0.0136, -0.0440, 0.0226, 0.0197,
-0.0346, 0.0281, -0.0020, -0.0201, -0.0482, -0.0426, 0.0018, 0.0065,
0.0293, 0.0076, 0.0418, -0.0487, 0.0359, -0.0118, 0.0179, -0.0176,
0.0471, -0.0438, 0.0071, 0.0235, 0.0245, -0.0395, 0.0206, 0.0396,
-0.0399, -0.0174, 0.0419, 0.0254, -0.0254, -0.0495, -0.0108, 0.0297,
0.0295, 0.0209, -0.0487, 0.0192, 0.0433, -0.0472, -0.0043, -0.0372,
0.0491, -0.0487, 0.0278, 0.0408, 0.0268, -0.0455, 0.0474, -0.0153,
-0.0464, -0.0248, 0.0130, 0.0203, -0.0359, -0.0385, -0.0100, 0.0170,
0.0464, 0.0401, -0.0339, 0.0245, 0.0023, 0.0165, 0.0467, 0.0407,
-0.0495, 0.0334, 0.0381, 0.0333, -0.0411, -0.0329, 0.0327, -0.0023,
0.0222, 0.0137, -0.0227, -0.0370, 0.0094, 0.0082, -0.0387, -0.0141,
-0.0045, -0.0353, -0.0282, 0.0103, -0.0167, -0.0474, -0.0296, -0.0452,
-0.0296, -0.0141, -0.0436, -0.0171, -0.0035, 0.0408, 0.0441, -0.0025,
0.0083, 0.0330, -0.0323, 0.0105, 0.0155, 0.0186, -0.0094, -0.0043],
requires_grad=True)
tensor: tensor([-0.1194, 0.0018, -0.0382, -0.0899, 0.0034, 0.0102, 0.0553, -0.0548,
-0.0950, 0.0837, -0.0716, 0.0613, 0.0356, -0.1366, 0.0281, 0.0833,
-0.0585, 0.0916, -0.0311, -0.0853, -0.0854, -0.0659, 0.0470, -0.0908,
0.0431, 0.1137, 0.0463, -0.1019, -0.0060, -0.0466, -0.0754, 0.0031,
0.0433, -0.0394, 0.0111, -0.0469, 0.0161, -0.0257, -0.0163, 0.0087,
-0.0052, -0.0057, 0.0504, -0.0218, -0.0147, -0.0882, 0.0371, 0.0567,
0.0220, 0.0803, -0.0448, -0.0313, 0.0571, -0.1098, -0.0201, -0.0401,
0.0568, -0.0305, -0.0368, 0.0482, 0.0158, -0.0946, 0.0302, -0.0686,
-0.0144, 0.0005, 0.1154, 0.0250, 0.0130, -0.0541, -0.0267, 0.0535,
0.0697, 0.0690, -0.0767, 0.0776, 0.0745, 0.0335, 0.0614, 0.0944,
-0.0984, 0.0523, 0.0240, 0.0436, -0.0207, 0.0420, 0.1366, -0.0816,
0.0789, -0.0659, 0.0373, -0.0096, 0.0804, 0.0902, -0.1037, -0.0539,
0.0304, -0.1223, 0.0599, -0.0474, -0.0649, -0.1068, 0.0388, -0.0095,
-0.0560, -0.0583, -0.0698, -0.0469, 0.0231, 0.0534, 0.0424, 0.0062,
-0.0409, 0.0445, -0.0633, -0.0341, 0.0322, 0.0177, 0.0837, -0.0139],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[-0., 0., -0., ..., -0., -0., -0.],
[-0., 0., 0., ..., -0., 0., -0.],
[0., 0., -0., ..., -0., 0., -0.],
...,
[-0., 0., -0., ..., -0., 0., 0.],
[-0., -0., -0., ..., 0., -0., -0.],
[-0., -0., 0., ..., -0., -0., -0.]])
scale: tensor([[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
...,
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[-3.3777e-02, 4.7079e-03, -2.3445e-02, ..., -6.8087e-03,
-4.6413e-02, -4.1360e-02],
[-1.6733e-02, 5.1208e-03, 3.9803e-02, ..., -4.5116e-02,
1.8346e-03, -1.1031e-02],
[ 2.3320e-02, 4.0388e-03, -4.6767e-02, ..., -3.7066e-02,
3.7666e-02, -1.2776e-02],
...,
[-1.3560e-05, 3.7272e-02, -1.6224e-02, ..., -3.3796e-03,
3.3060e-02, 4.5754e-02],
[-9.2607e-03, -4.9655e-02, -3.0438e-02, ..., 1.7757e-02,
-4.1499e-02, -1.2796e-02],
[-3.5203e-02, -3.5148e-03, 4.2838e-03, ..., -2.5652e-02,
-7.0994e-03, -2.2834e-02]])
(bias): Normal:
loc: tensor([-0., -0., -0., -0., 0., -0., 0., -0., -0., -0., -0., -0., 0., -0., 0., 0., -0., 0., -0., -0., -0., -0., 0., 0.,
0., 0., 0., -0., 0., -0., 0., -0., 0., -0., 0., 0., 0., -0., 0., 0., -0., -0., 0., 0., -0., -0., -0., 0.,
0., 0., -0., 0., 0., -0., -0., -0., 0., -0., 0., 0., 0., -0., 0., -0., -0., -0., 0., 0., -0., -0., -0., 0.,
0., 0., -0., 0., 0., 0., 0., 0., -0., 0., 0., 0., -0., -0., 0., -0., 0., 0., -0., -0., 0., 0., -0., -0.,
-0., -0., -0., 0., -0., -0., -0., -0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., -0., 0., 0., 0., -0., -0.])
scale: tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([-0.0091, -0.0222, -0.0435, -0.0427, 0.0332, -0.0145, 0.0444, -0.0225,
-0.0252, -0.0384, -0.0325, -0.0306, 0.0136, -0.0440, 0.0226, 0.0197,
-0.0346, 0.0281, -0.0020, -0.0201, -0.0482, -0.0426, 0.0018, 0.0065,
0.0293, 0.0076, 0.0418, -0.0487, 0.0359, -0.0118, 0.0179, -0.0176,
0.0471, -0.0438, 0.0071, 0.0235, 0.0245, -0.0395, 0.0206, 0.0396,
-0.0399, -0.0174, 0.0419, 0.0254, -0.0254, -0.0495, -0.0108, 0.0297,
0.0295, 0.0209, -0.0487, 0.0192, 0.0433, -0.0472, -0.0043, -0.0372,
0.0491, -0.0487, 0.0278, 0.0408, 0.0268, -0.0455, 0.0474, -0.0153,
-0.0464, -0.0248, 0.0130, 0.0203, -0.0359, -0.0385, -0.0100, 0.0170,
0.0464, 0.0401, -0.0339, 0.0245, 0.0023, 0.0165, 0.0467, 0.0407,
-0.0495, 0.0334, 0.0381, 0.0333, -0.0411, -0.0329, 0.0327, -0.0023,
0.0222, 0.0137, -0.0227, -0.0370, 0.0094, 0.0082, -0.0387, -0.0141,
-0.0045, -0.0353, -0.0282, 0.0103, -0.0167, -0.0474, -0.0296, -0.0452,
-0.0296, -0.0141, -0.0436, -0.0171, -0.0035, 0.0408, 0.0441, -0.0025,
0.0083, 0.0330, -0.0323, 0.0105, 0.0155, 0.0186, -0.0094, -0.0043])
)
(observed): Observed()
)
(fc2): Linear(
in_features=120, out_features=10, bias=True
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
...,
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498]],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[ 0.0330, -0.0638, 0.0145, ..., -0.0799, 0.0038, -0.0028],
[ 0.0647, -0.0898, 0.0264, ..., 0.0438, 0.0560, 0.0141],
[-0.0616, -0.0866, -0.0050, ..., 0.0206, -0.0182, -0.0549],
...,
[-0.0021, 0.0008, 0.0321, ..., -0.0234, 0.0801, 0.0765],
[-0.0517, -0.0638, 0.0722, ..., -0.0724, -0.0263, -0.0710],
[-0.0566, -0.0889, -0.0738, ..., 0.0263, -0.0151, -0.0537]],
requires_grad=True)
tensor: tensor([[ 0.0217, -0.1135, -0.0030, ..., -0.0561, 0.0181, 0.0389],
[ 0.0759, -0.0429, 0.0090, ..., 0.0921, 0.1001, 0.0168],
[-0.0243, -0.0910, 0.0879, ..., 0.0186, 0.0450, -0.0486],
...,
[ 0.0081, 0.0176, -0.0045, ..., -0.0133, 0.1110, 0.0398],
[-0.0471, -0.0219, 0.0584, ..., -0.0652, 0.0176, -0.0921],
[-0.0964, -0.1910, -0.1096, ..., -0.0227, -0.0748, -0.1086]],
grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([ 0.0342, -0.0681, 0.0492, -0.0765, -0.0070, -0.0712, -0.0436, 0.0281,
0.0831, 0.0615], requires_grad=True)
tensor: tensor([ 0.0767, -0.0408, 0.1565, -0.0743, -0.0248, -0.0158, -0.0733, -0.0389,
-0.0016, 0.0762], grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[0., -0., 0., ..., -0., 0., -0.],
[0., -0., 0., ..., 0., 0., 0.],
[-0., -0., -0., ..., 0., -0., -0.],
...,
[-0., 0., 0., ..., -0., 0., 0.],
[-0., -0., 0., ..., -0., -0., -0.],
[-0., -0., -0., ..., 0., -0., -0.]])
scale: tensor([[0.0913, 0.0913, 0.0913, ..., 0.0913, 0.0913, 0.0913],
[0.0913, 0.0913, 0.0913, ..., 0.0913, 0.0913, 0.0913],
[0.0913, 0.0913, 0.0913, ..., 0.0913, 0.0913, 0.0913],
...,
[0.0913, 0.0913, 0.0913, ..., 0.0913, 0.0913, 0.0913],
[0.0913, 0.0913, 0.0913, ..., 0.0913, 0.0913, 0.0913],
[0.0913, 0.0913, 0.0913, ..., 0.0913, 0.0913, 0.0913]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[ 0.0330, -0.0638, 0.0145, ..., -0.0799, 0.0038, -0.0028],
[ 0.0647, -0.0898, 0.0264, ..., 0.0438, 0.0560, 0.0141],
[-0.0616, -0.0866, -0.0050, ..., 0.0206, -0.0182, -0.0549],
...,
[-0.0021, 0.0008, 0.0321, ..., -0.0234, 0.0801, 0.0765],
[-0.0517, -0.0638, 0.0722, ..., -0.0724, -0.0263, -0.0710],
[-0.0566, -0.0889, -0.0738, ..., 0.0263, -0.0151, -0.0537]])
(bias): Normal:
loc: tensor([0., -0., 0., -0., -0., -0., -0., 0., 0., 0.])
scale: tensor([0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([ 0.0342, -0.0681, 0.0492, -0.0765, -0.0070, -0.0712, -0.0436, 0.0281,
0.0831, 0.0615])
)
(observed): Observed()
)
)
You just have to define the forward function, and the backward
function (where gradients are computed) is automatically defined for you
using autograd.
You can use any of the Tensor operations in the forward function.
The learnable parameters of a model are returned by net.parameters()
params = list(net.parameters())
print(len(params))
print(params[0].size())
Out:
16
torch.Size([6, 1, 5, 5])
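The count of 16 is consistent with each of the four layers holding a weight and a bias, where each approximating Normal has two learnable tensors (loc and scale). The sketch below only does that arithmetic. The assumption that every Normal contributes exactly two parameters is read off the printed module above, not taken from the borch documentation.

```python
# Assumption: every weight/bias posterior is a Normal with a learnable loc
# and a learnable (unconstrained) scale, as the printed module suggests.
layers = ["conv1", "conv2", "fc1", "fc2"]
random_variables = ["weight", "bias"]
variational_parameters = ["loc", "scale"]

expected_count = (
    len(layers) * len(random_variables) * len(variational_parameters)
)
print(expected_count)
```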
Let's try a random 32x32 input:
input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)
Out:
tensor([[-0.3077, -0.5381, 0.6875, 0.3936, -1.1296, 0.1405, -0.0033, -1.0102,
0.4813, 0.6858]], grad_fn=<AddmmBackward0>)
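To see why a 32x32 input matches the 400 input features of fc1, the spatial size can be followed through the two convolution and pooling stages. This is a minimal sketch in plain Python; the helper names conv_out and pool_out are hypothetical and only mirror the standard output-size formulas for unpadded, stride-1 convolutions and non-overlapping pooling.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial size after a convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, window):
    """Spatial size after non-overlapping max pooling."""
    return size // window


size = 32
size = pool_out(conv_out(size, kernel=5), window=2)  # conv1 + pool: 32 -> 28 -> 14
size = pool_out(conv_out(size, kernel=5), window=2)  # conv2 + pool: 14 -> 10 -> 5

# conv2 has 16 output channels, so the flattened feature count is 16 * 5 * 5.
flat_features = 16 * size * size
print(size, flat_features)
```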
Zero the gradient buffers of all parameters and backprop with random gradients:
net.zero_grad()
out.backward(torch.randn(1, 10))
Note
borch.nn only supports mini-batches: the entire package expects inputs
that are a mini-batch of samples, not a single sample.
For example, nn.Conv2d will take in a 4D Tensor of
nSamples x nChannels x Height x Width.
If you have a single sample, just use input.unsqueeze(0) to add
a fake batch dimension.
Before proceeding further, let’s recap all the classes you’ve seen so far.
- Recap:
torch.Tensor - A multi-dimensional array with support for autograd operations like backward(). Also holds the gradient w.r.t. the tensor.
nn.Module - Neural network module. Convenient way of encapsulating parameters, with helpers for moving them to GPU, exporting, loading, etc.
nn.Parameter - A kind of Tensor that is automatically registered as a parameter when assigned as an attribute to a Module.
autograd.Function - Implements forward and backward definitions of an autograd operation. Every Tensor operation creates at least a single Function node that connects to the functions that created a Tensor and encodes its history.
- At this point, we covered:
Defining a neural network
Processing inputs and calling backward
- Still Left:
Computing the loss
Updating the weights of the network
Loss Function¶
A loss function takes the (output, target) pair of inputs, and computes a value that estimates how far away the output is from the target.
There are several different loss functions under the nn package.
A simple loss is nn.MSELoss, which computes the mean-squared error
between the input and the target. Such loss functions, however, only
correspond to a maximum likelihood approach in deep learning.
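As a reminder of what the mean-squared error computes, here is a minimal plain-Python sketch. The function mse_loss is a hypothetical stand-in written for illustration and is not the nn.MSELoss implementation.

```python
def mse_loss(output, target):
    """Mean of the squared element-wise differences."""
    if len(output) != len(target):
        raise ValueError("output and target must have the same length")
    return sum((o - t) ** 2 for o, t in zip(output, target)) / len(output)


# Illustrative values only.
output = [0.5, -1.0, 2.0]
target = [1.0, -1.0, 0.0]
print(mse_loss(output, target))
```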
In order to infer the posterior of the weights, and thus capture the uncertainty
of the weights as well, we have to use the infer package. In this example we
will use the infer.vi_loss function, which automatically creates the best loss function
for variational inference given the latent variables in your model.
Similar to how it’s done for random variables, we can also observe on the module using keyword arguments matching the names of the random variables we want to observe. This adds those random variables to the likelihood term, and we will not infer a distribution over them. For example:
target = torch.randint(10, (1,)) # a dummy target, for example
net.observe(classification=target)
borch.sample(net)
output = net(input)
loss = infer.vi_loss(**borch.pq_to_infer(net))
print(loss)
Out:
tensor(9293.3262, grad_fn=<AddBackward0>)
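A variational inference loss typically combines a negative log-likelihood term with a KL divergence between the approximating distribution and the prior. The sketch below shows the closed-form KL divergence between two univariate Normals, with values close to the first fc1 weight entry in the printout above. Whether infer.vi_loss uses a closed form or a sampled estimate is not stated here, so this illustrates the idea and does not describe borch internals. The likelihood value is hypothetical.

```python
import math


def kl_normal(q_loc, q_scale, p_loc, p_scale):
    """KL(q || p) for two univariate Normal distributions."""
    return (
        math.log(p_scale / q_scale)
        + (q_scale ** 2 + (q_loc - p_loc) ** 2) / (2 * p_scale ** 2)
        - 0.5
    )


# Posterior N(-0.0338, 0.0498) against prior N(0, 0.05).
kl = kl_normal(q_loc=-0.0338, q_scale=0.0498, p_loc=0.0, p_scale=0.05)

# Hypothetical negative log-likelihood of the observed target.
negative_log_likelihood = 2.3

# Negative ELBO for a single latent variable.
loss = negative_log_likelihood + kl
print(kl, loss)
```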
Now, if you follow loss in the backward direction, you will see a graph of
computations that looks like this:
input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d
-> view -> linear -> relu -> linear
-> loss
So, when we call loss.backward(), the whole graph is differentiated
w.r.t. the network's parameters, and all Tensors in the graph that have
requires_grad=True will have their .grad Tensor accumulated with the gradient.
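What backward computes can be illustrated on a one-parameter model by applying the chain rule by hand and comparing the result with a finite-difference estimate. This is a plain-Python sketch with hypothetical names and values; it does not use autograd.

```python
def loss_fn(w, x, target):
    """Squared error of the one-parameter model y = w * x."""
    y = w * x
    return (y - target) ** 2


def grad_w(w, x, target):
    """Chain rule: dloss/dw = dloss/dy * dy/dw = 2 * (y - target) * x."""
    return 2 * (w * x - target) * x


w, x, target = 0.5, 3.0, 2.0
analytic = grad_w(w, x, target)

# Central finite difference as an independent check.
eps = 1e-6
numeric = (loss_fn(w + eps, x, target) - loss_fn(w - eps, x, target)) / (2 * eps)
print(analytic, numeric)
```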
Backprop¶
To backpropagate the error, all we have to do is call loss.backward().
You need to clear the existing gradients though, else gradients will be
accumulated to existing gradients.
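The accumulation behaviour can be mimicked with a toy gradient buffer. The class below is a hypothetical stand-in written in plain Python, not torch or borch code. It only shows why a second backward pass without zeroing doubles the stored gradient.

```python
class ToyParameter:
    """Toy stand-in for a parameter with a gradient buffer."""

    def __init__(self, value):
        self.value = value
        self.grad = 0.0

    def backward(self, x):
        # For loss = value * x, dloss/dvalue = x; add it to the buffer.
        self.grad += x

    def zero_grad(self):
        self.grad = 0.0


p = ToyParameter(1.0)
p.backward(2.0)
p.backward(2.0)       # no zeroing in between: gradients add up
accumulated = p.grad

p.zero_grad()
p.backward(2.0)       # after zeroing: only the latest gradient remains
fresh = p.grad
print(accumulated, fresh)
```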
Now we shall call loss.backward(), and have a look at conv1’s bias
gradients before and after the backward.
net.zero_grad() # zeroes the gradient buffers of all parameters
The gradient of the loc parameter of the approximating distribution of
conv1.bias after zeroing the gradients is
print(net.conv1.posterior.bias.loc.grad)
loss.backward()
Out:
tensor([0., 0., 0., 0., 0., 0.])
After calling backward, the gradient is
print(net.conv1.posterior.bias.loc.grad)
Out:
tensor([-0.6064, -0.4782, 0.9458, 0.5497, 1.2112, -0.2230])
The only thing left to learn is:
Updating the weights of the network
Update the weights¶
The simplest update rule used in practice is the Stochastic Gradient Descent (SGD):
weight = weight - learning_rate * gradient
We can implement this using simple Python code:
learning_rate = 0.01
for f in net.parameters():
    f.data.sub_(f.grad.data * learning_rate)
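The same update rule can be tried on a one-dimensional problem without any framework. The sketch below minimises the hypothetical objective (w - 3)^2 with plain gradient descent. The objective, learning rate and step count are chosen only for illustration.

```python
def gradient(w):
    """Gradient of the objective (w - 3)^2 with respect to w."""
    return 2.0 * (w - 3.0)


learning_rate = 0.1
w = 0.0
for _ in range(100):
    # weight = weight - learning_rate * gradient
    w = w - learning_rate * gradient(w)

print(w)  # close to the minimiser w = 3
```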
However, as you use neural networks, you want to use various different
update rules such as SGD, Nesterov-SGD, Adam, RMSProp, etc.
To enable this, torch built a small package: torch.optim that
implements all these methods. Using it is very simple:
import torch.optim as optim
# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)
# in your training loop:
n_batch_epoch = 10 # number of batches per epoch usually len(dataloader)
optimizer.zero_grad() # zero the gradient buffers
borch.sample(net)
output = net(input)
loss = infer.vi_loss(**borch.pq_to_infer(net), kl_scaling=1 / n_batch_epoch)
loss.backward()
optimizer.step() # Does the update
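The reason for scaling by 1 / n_batch_epoch can be seen with simple arithmetic: if each mini-batch loss includes the KL term weighted this way, the KL term is counted exactly once when summed over an epoch. The sketch assumes kl_scaling multiplies the KL term, which is the usual convention but is not confirmed here. All numbers are hypothetical.

```python
n_batch_epoch = 10
kl_scaling = 1 / n_batch_epoch

full_kl = 50.0                        # hypothetical KL(q || p) over all weights
batch_nlls = [12.0] * n_batch_epoch   # hypothetical per-batch negative log-likelihoods

# Assumed form of the per-batch loss: likelihood term plus scaled KL term.
batch_losses = [nll + kl_scaling * full_kl for nll in batch_nlls]

# Summed over one epoch, the KL term contributes once in total.
epoch_total = sum(batch_losses)
print(epoch_total, sum(batch_nlls) + full_kl)
```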
Exercises¶
The neural network package contains various modules and loss functions that form the building blocks of deep neural networks. Have a look at the documentation to see what is available.
Try designing your own feed-forward networks with two different types of nonlinearities, e.g. relu.