Neural Networks
Neural networks can be constructed using the borch.nn package.
Now that you have had a glimpse of autograd, note that nn relies on autograd to define models and differentiate them. An nn.Module contains layers and a forward(input) method that returns the output.
For example, look at this network that classifies digit images:
It is a simple feed-forward network. It takes an input, feeds it through several layers one after the other, and then finally gives the output.
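To make "one layer after the other" concrete, the sketch below traces the shape of a single image through the convolution and pooling layers of the network defined below. It assumes a 1x32x32 input (the size stated in the code comments) and uses the standard output-size formula floor((size + 2 * padding - kernel) / stride) + 1, with max pooling using a stride equal to its window. The helper names are hypothetical illustrations, not borch functions.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool2d_out(size, window):
    """Spatial output size of max pooling whose stride equals its window."""
    return (size - window) // window + 1


def trace_shapes(size=32):
    """Trace (layer, channels, height/width) through the layers in order."""
    shapes = [("input", 1, size)]
    size = conv2d_out(size, kernel=5)   # conv1: 1 -> 6 channels, 5x5 kernel
    shapes.append(("conv1", 6, size))
    size = pool2d_out(size, window=2)   # 2x2 max pooling
    shapes.append(("pool1", 6, size))
    size = conv2d_out(size, kernel=5)   # conv2: 6 -> 16 channels, 5x5 kernel
    shapes.append(("conv2", 16, size))
    size = pool2d_out(size, window=2)   # 2x2 max pooling
    shapes.append(("pool2", 16, size))
    return shapes


shapes = trace_shapes(32)
for name, channels, size in shapes:
    print(f"{name:6s} {channels:2d} x {size} x {size}")

# Length of the flattened vector fed into fc1 (16 * 5 * 5 in the network code)
_, channels, size = shapes[-1]
flat = channels * size * size
print("flattened:", flat)
```

The spatial size goes 32, 28, 14, 10, 5, which is why the first fully connected layer expects 16 * 5 * 5 = 400 inputs.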
A typical training procedure for a neural network is as follows:
1. Define a network that has some learnable parameters and/or random variables.
2. For each batch in a dataset:
   - Process the input data through the network.
   - Compute the loss (how far is the output from being correct?).
   - Propagate gradients back into the network's parameters.
   - Update the weights of the network, typically using a simple update rule:
     weight = weight - learning_rate * gradient
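The update rule above is plain mini-batch gradient descent. The framework-free sketch below walks through the listed steps for a deliberately tiny, hypothetical model: a single weight w in y = w * x with a mean squared error loss, chosen so the gradient can be written by hand instead of being computed by autograd. It is an illustration of the procedure, not how borch trains the network.

```python
import random


def train(data, learning_rate=0.05, epochs=200, batch_size=4, seed=0):
    """Fit y = w * x with mini-batch gradient descent on a squared-error loss."""
    rng = random.Random(seed)
    data = list(data)  # copy so shuffling does not mutate the caller's list
    weight = 0.0       # the single learnable parameter
    loss = None
    for _ in range(epochs):
        rng.shuffle(data)
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            # Process the input data through the "network"
            predictions = [weight * x for x, _ in batch]
            # Compute the loss: mean squared error over the batch
            loss = sum((p - y) ** 2 for p, (_, y) in zip(predictions, batch)) / len(batch)
            # Gradient of the loss with respect to the weight, derived by hand
            gradient = sum(2 * (p - y) * x for p, (x, y) in zip(predictions, batch)) / len(batch)
            # Update rule: weight = weight - learning_rate * gradient
            weight = weight - learning_rate * gradient
    return weight, loss


# Noise-free toy data generated with a true weight of 3.0
data = [(x / 10, 3.0 * x / 10) for x in range(-10, 11)]
weight, last_loss = train(data)
print(f"learned weight: {weight:.4f}, last batch loss: {last_loss:.2e}")
```

In PyTorch-style code, the hand-written gradient line is what autograd's backward() computes, and the update line is what an optimizer such as torch.optim.SGD applies for every parameter.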
Define the network
Let’s define this network:
import torch
import torch.nn.functional as F
import borch
from borch import distributions, posterior, nn, infer


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__(posterior=posterior.Automatic())
        # 1 input image channel, 6 output channels, 5x5 convolution kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        # 6 input channels, 16 output channels, 5x5 convolution kernel
        self.conv2 = nn.Conv2d(6, 16, 5)
        # An affine operation: y = Wx + b
        # NB: after two 5x5 convolutions with no padding, each followed by
        # 2x2 max pooling, an image with initial size 32x32 is reduced to
        # 5x5 spatially (with 16 channels)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the window is square, you can specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten all dimensions except the batch dimension
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Specify the likelihood function
        self.classification = distributions.Categorical(logits=x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
print(net)
Out:
Net(
(posterior): Automatic()
(prior): Module()
(observed): Observed()
(conv1): Conv2d(
1, 6, kernel_size=(5, 5), stride=(1, 1)
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[[[ 0.1680, 0.1718, 0.0098, -0.0428, 0.0831],
[ 0.0618, 0.1618, -0.0325, -0.1448, -0.1646],
[-0.0362, -0.0103, -0.1824, 0.1346, 0.0135],
[-0.0389, 0.0558, -0.0680, -0.1515, 0.1680],
[-0.1748, 0.1496, -0.1459, 0.0181, -0.1115]]],
[[[ 0.1081, -0.0935, -0.1567, 0.1787, 0.0230],
[-0.1890, -0.0129, 0.0133, 0.0472, 0.0424],
[-0.0680, -0.0616, -0.0199, 0.0837, 0.1960],
[ 0.0917, 0.0248, 0.0921, -0.0139, 0.0422],
[-0.0997, -0.0637, 0.1865, -0.1982, 0.1430]]],
[[[-0.1463, -0.0137, 0.1728, -0.0899, -0.0427],
[-0.0654, -0.0375, -0.1964, -0.1289, -0.0483],
[ 0.0495, 0.0236, 0.1467, 0.1145, -0.0054],
[-0.0240, -0.0058, -0.0281, 0.1004, -0.0038],
[ 0.0126, 0.0019, -0.1089, -0.1572, -0.1472]]],
[[[-0.0682, -0.0487, 0.0365, -0.1739, -0.0511],
[-0.0746, 0.1603, 0.0149, 0.1684, 0.0557],
[-0.1502, -0.1176, -0.0681, 0.1512, -0.0974],
[-0.0197, -0.1146, 0.0302, -0.0089, -0.0336],
[ 0.1591, 0.1478, -0.1726, -0.0903, 0.0923]]],
[[[-0.1594, 0.1311, 0.0003, -0.0176, -0.1664],
[ 0.0809, 0.1795, 0.1855, 0.1246, -0.0944],
[-0.1488, -0.0044, 0.0512, -0.1854, 0.0012],
[-0.0820, 0.1209, -0.0562, 0.0859, -0.1175],
[ 0.0652, 0.1324, -0.1412, 0.1655, 0.1654]]],
[[[-0.0019, -0.0864, 0.1593, 0.1857, -0.1915],
[ 0.1527, -0.0412, -0.0173, 0.0072, -0.0933],
[-0.1378, -0.1514, 0.1343, -0.0660, -0.0785],
[ 0.1723, 0.0924, 0.1815, -0.0553, 0.1033],
[ 0.0797, 0.0611, 0.0834, -0.1489, 0.1146]]]], requires_grad=True)
tensor: tensor([[[[ 0.1549, 0.1702, 0.0258, -0.0510, 0.1557],
[ 0.0103, 0.1476, 0.0522, -0.2040, -0.1602],
[-0.0236, 0.0437, -0.2519, 0.0928, -0.0489],
[-0.0862, 0.1147, -0.0815, -0.1318, 0.2126],
[-0.2679, 0.1714, -0.1526, 0.0422, -0.1586]]],
[[[ 0.1015, -0.1919, -0.1668, 0.2028, -0.0101],
[-0.1396, -0.0280, 0.0852, 0.0446, 0.0493],
[-0.0921, -0.0273, -0.0788, 0.0889, 0.2471],
[ 0.0824, 0.0427, 0.1677, -0.0809, -0.0029],
[-0.0677, -0.1122, 0.1575, -0.2243, 0.1292]]],
[[[-0.1460, -0.0464, 0.1325, -0.0497, -0.0142],
[-0.0287, 0.0196, -0.2510, -0.1181, 0.0341],
[ 0.0990, 0.0839, 0.1291, 0.1391, -0.0971],
[-0.0203, 0.0117, 0.0436, 0.1213, -0.0132],
[ 0.0082, -0.0122, -0.1402, -0.2093, -0.1193]]],
[[[-0.0340, -0.0688, 0.0707, -0.1087, -0.0648],
[-0.0878, 0.2453, 0.0498, 0.1437, 0.0903],
[-0.1033, -0.1412, -0.0248, 0.1735, -0.0668],
[ 0.0400, -0.0714, 0.0677, 0.0780, -0.0020],
[ 0.1382, 0.1072, -0.1638, -0.0979, 0.0906]]],
[[[-0.2221, 0.1444, -0.0130, -0.0024, -0.1279],
[ 0.0786, 0.2368, 0.2252, 0.1285, -0.0604],
[-0.0719, -0.0685, 0.0014, -0.2441, -0.0323],
[-0.0953, 0.1052, 0.0052, 0.0528, -0.1838],
[ 0.1155, 0.1910, -0.2823, 0.1724, 0.1009]]],
[[[-0.0248, -0.0756, 0.0753, 0.1658, -0.1588],
[ 0.0979, -0.0104, -0.0249, 0.0274, -0.1429],
[-0.1627, -0.1360, 0.1719, -0.1276, -0.0804],
[ 0.1876, 0.0270, 0.1819, -0.0160, 0.1198],
[ 0.1410, 0.0349, 0.1013, -0.1741, 0.2346]]]],
grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([ 0.0894, 0.1877, -0.0535, 0.0308, 0.1593, 0.0810],
requires_grad=True)
tensor: tensor([ 0.1004, 0.2790, -0.0837, -0.0029, 0.1159, 0.0025],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[[[0., 0., 0., -0., 0.],
[0., 0., -0., -0., -0.],
[-0., -0., -0., 0., 0.],
[-0., 0., -0., -0., 0.],
[-0., 0., -0., 0., -0.]]],
[[[0., -0., -0., 0., 0.],
[-0., -0., 0., 0., 0.],
[-0., -0., -0., 0., 0.],
[0., 0., 0., -0., 0.],
[-0., -0., 0., -0., 0.]]],
[[[-0., -0., 0., -0., -0.],
[-0., -0., -0., -0., -0.],
[0., 0., 0., 0., -0.],
[-0., -0., -0., 0., -0.],
[0., 0., -0., -0., -0.]]],
[[[-0., -0., 0., -0., -0.],
[-0., 0., 0., 0., 0.],
[-0., -0., -0., 0., -0.],
[-0., -0., 0., -0., -0.],
[0., 0., -0., -0., 0.]]],
[[[-0., 0., 0., -0., -0.],
[0., 0., 0., 0., -0.],
[-0., -0., 0., -0., 0.],
[-0., 0., -0., 0., -0.],
[0., 0., -0., 0., 0.]]],
[[[-0., -0., 0., 0., -0.],
[0., -0., -0., 0., -0.],
[-0., -0., 0., -0., -0.],
[0., 0., 0., -0., 0.],
[0., 0., 0., -0., 0.]]]])
scale: tensor([[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[[[ 0.1680, 0.1718, 0.0098, -0.0428, 0.0831],
[ 0.0618, 0.1618, -0.0325, -0.1448, -0.1646],
[-0.0362, -0.0103, -0.1824, 0.1346, 0.0135],
[-0.0389, 0.0558, -0.0680, -0.1515, 0.1680],
[-0.1748, 0.1496, -0.1459, 0.0181, -0.1115]]],
[[[ 0.1081, -0.0935, -0.1567, 0.1787, 0.0230],
[-0.1890, -0.0129, 0.0133, 0.0472, 0.0424],
[-0.0680, -0.0616, -0.0199, 0.0837, 0.1960],
[ 0.0917, 0.0248, 0.0921, -0.0139, 0.0422],
[-0.0997, -0.0637, 0.1865, -0.1982, 0.1430]]],
[[[-0.1463, -0.0137, 0.1728, -0.0899, -0.0427],
[-0.0654, -0.0375, -0.1964, -0.1289, -0.0483],
[ 0.0495, 0.0236, 0.1467, 0.1145, -0.0054],
[-0.0240, -0.0058, -0.0281, 0.1004, -0.0038],
[ 0.0126, 0.0019, -0.1089, -0.1572, -0.1472]]],
[[[-0.0682, -0.0487, 0.0365, -0.1739, -0.0511],
[-0.0746, 0.1603, 0.0149, 0.1684, 0.0557],
[-0.1502, -0.1176, -0.0681, 0.1512, -0.0974],
[-0.0197, -0.1146, 0.0302, -0.0089, -0.0336],
[ 0.1591, 0.1478, -0.1726, -0.0903, 0.0923]]],
[[[-0.1594, 0.1311, 0.0003, -0.0176, -0.1664],
[ 0.0809, 0.1795, 0.1855, 0.1246, -0.0944],
[-0.1488, -0.0044, 0.0512, -0.1854, 0.0012],
[-0.0820, 0.1209, -0.0562, 0.0859, -0.1175],
[ 0.0652, 0.1324, -0.1412, 0.1655, 0.1654]]],
[[[-0.0019, -0.0864, 0.1593, 0.1857, -0.1915],
[ 0.1527, -0.0412, -0.0173, 0.0072, -0.0933],
[-0.1378, -0.1514, 0.1343, -0.0660, -0.0785],
[ 0.1723, 0.0924, 0.1815, -0.0553, 0.1033],
[ 0.0797, 0.0611, 0.0834, -0.1489, 0.1146]]]])
(bias): Normal:
loc: tensor([0., 0., -0., 0., 0., 0.])
scale: tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([ 0.0894, 0.1877, -0.0535, 0.0308, 0.1593, 0.0810])
)
(observed): Observed()
)
(conv2): Conv2d(
6, 16, kernel_size=(5, 5), stride=(1, 1)
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
...,
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[[[ 4.2191e-02, 7.6840e-02, -7.4705e-02, -2.1865e-02, 6.1931e-02],
[ 4.5672e-02, -4.5500e-02, 6.8328e-02, -1.9356e-02, 6.2420e-03],
[-9.1323e-03, 6.4114e-02, -6.2778e-02, -2.9379e-02, 9.7347e-03],
[ 1.5921e-02, -2.0886e-02, -6.3837e-02, -4.3951e-02, 5.5913e-04],
[-1.2263e-02, -4.8414e-02, -4.4003e-02, -2.4253e-02, 3.2497e-02]],
[[ 2.3382e-02, 3.9259e-02, -7.7953e-02, -3.2520e-02, 1.8270e-02],
[-4.3315e-02, 7.9060e-02, -5.8311e-02, -2.8759e-02, -8.7089e-03],
[ 5.5321e-02, 4.8438e-02, 6.5735e-02, -4.9760e-02, -3.5508e-02],
[ 4.1005e-02, 4.6544e-02, 1.5205e-02, -3.6013e-02, 6.9640e-03],
[-6.8387e-02, -4.9177e-02, 4.0450e-02, 7.0279e-02, 7.5872e-02]],
[[-4.6557e-03, 1.8276e-02, -5.8220e-02, -2.1548e-02, 5.6864e-02],
[-1.3140e-02, 3.1058e-02, 1.7065e-02, -6.0197e-02, -9.2233e-03],
[-5.6885e-02, -6.7734e-02, 2.2703e-02, 7.8996e-02, -2.3009e-02],
[-3.2973e-02, 3.9733e-02, -4.7892e-02, 1.3657e-02, -6.1810e-02],
[ 6.4573e-02, -2.9753e-02, -5.5327e-02, -9.4562e-03, 2.1779e-02]],
[[ 7.1124e-02, 4.0486e-02, 6.9724e-02, 7.6075e-02, 4.1246e-03],
[ 8.3202e-03, 2.3893e-02, -6.6808e-03, -4.9518e-02, 3.8584e-02],
[-5.2714e-02, 5.0159e-02, -6.6929e-02, -7.0058e-02, -4.7074e-02],
[ 7.5636e-02, 2.7516e-02, -5.3383e-02, -7.9865e-02, 7.8899e-02],
[ 6.7598e-02, 2.2288e-02, -4.3938e-02, -2.2255e-02, 3.8803e-02]],
[[-1.8988e-02, 6.3276e-02, -6.4471e-02, -3.8395e-02, -5.1800e-02],
[ 2.6698e-02, 5.7260e-03, -2.2055e-02, 2.9274e-02, 1.7332e-02],
[-1.0218e-02, -2.7324e-02, -2.9080e-02, 1.3222e-02, 7.9084e-02],
[ 4.3772e-02, 7.7684e-02, 4.8257e-02, -6.8435e-02, -2.0343e-02],
[ 3.0093e-02, 3.1836e-03, 2.6288e-03, -8.2295e-03, -2.5020e-02]],
[[-3.8859e-02, 5.2245e-02, -7.8403e-02, -4.2871e-02, 6.1650e-02],
[-4.8324e-02, 5.1745e-02, 5.0752e-02, 7.3504e-02, -2.3351e-02],
[ 7.5441e-02, -5.1618e-02, 8.0429e-02, 7.3450e-02, 7.8664e-02],
[ 4.9288e-02, -1.9266e-02, 2.9147e-02, -8.0626e-02, -3.1476e-02],
[ 6.6063e-02, 1.7808e-02, 2.6498e-02, 4.6047e-03, 7.2748e-02]]],
[[[ 2.7147e-02, -6.4806e-02, -1.7023e-02, 4.9091e-03, -6.6140e-03],
[-7.7151e-02, -1.2995e-02, -4.5008e-02, -1.8290e-02, 5.0563e-02],
[-7.6304e-02, -8.0019e-02, -1.1984e-02, -3.8425e-02, -4.5299e-03],
[-1.4411e-02, 5.2870e-02, 7.4521e-02, 4.9850e-02, -2.3332e-02],
[ 3.2249e-03, 4.7371e-02, -5.4040e-02, 7.8740e-02, 6.2129e-02]],
[[-5.5597e-02, -7.9082e-03, 5.7858e-03, -1.7938e-02, 3.2981e-02],
[-5.7644e-02, -7.8341e-02, -5.2250e-02, -7.0176e-02, 9.1169e-03],
[ 4.8751e-03, 2.7214e-02, -7.2253e-02, -7.9531e-02, 2.2115e-02],
[-8.0066e-02, 6.1143e-03, 9.5127e-03, 7.6738e-02, -5.4762e-02],
[-4.0489e-02, -4.4740e-02, 7.4822e-02, 5.8303e-02, 3.4314e-02]],
[[ 3.2335e-02, -7.6859e-03, 7.2277e-02, 6.4942e-02, -2.1889e-02],
[ 6.7059e-02, -2.5517e-02, -7.2487e-02, 6.3138e-02, 1.8800e-02],
[-4.1668e-02, 3.3260e-02, 2.0425e-03, -9.8570e-03, -5.2738e-02],
[ 5.0590e-02, -1.0069e-02, -1.0559e-02, 1.4182e-02, 5.3266e-02],
[-1.0744e-02, -2.3853e-02, 1.9477e-02, 1.2789e-02, -1.9533e-03]],
[[-7.1060e-03, -4.1691e-02, 3.4899e-02, -7.2925e-02, -5.2534e-02],
[-4.6110e-02, 7.5498e-02, 6.1524e-02, 1.1099e-02, 7.3727e-02],
[ 5.4388e-02, 3.9116e-02, -6.6412e-02, -2.6689e-02, -2.0881e-02],
[ 2.6913e-02, 3.5247e-02, -2.6048e-02, 1.4267e-02, 2.6323e-02],
[ 6.5432e-03, -7.3995e-03, -4.3528e-02, 5.0431e-02, 6.9511e-02]],
[[-2.8527e-02, 6.1429e-02, 6.0427e-02, -4.1141e-02, -7.3619e-02],
[ 2.1947e-02, -3.8028e-02, -6.7364e-02, 1.5408e-02, 5.2976e-02],
[ 5.8108e-02, -6.1758e-02, 1.4653e-02, -7.6558e-02, -3.1448e-02],
[ 6.2695e-02, 6.4849e-02, 3.1678e-02, 8.7257e-03, 8.9366e-03],
[ 4.3787e-02, -2.6563e-02, 7.6010e-02, -4.9455e-02, 4.6068e-02]],
[[-4.6879e-02, -1.0211e-02, -5.7937e-02, -6.0173e-02, -5.6010e-03],
[ 6.7886e-02, 4.8736e-02, 7.5209e-02, -2.4597e-02, 1.2666e-02],
[ 2.7294e-02, 5.9385e-02, 8.1056e-02, 7.2644e-02, -7.7046e-02],
[ 1.7380e-02, -3.2567e-02, -2.1987e-02, 7.8235e-02, -5.2908e-02],
[-2.0811e-02, -2.9994e-02, -4.6838e-02, -4.2671e-02, 3.8727e-02]]],
[[[ 3.5134e-02, 5.2745e-02, -6.1149e-02, -7.6414e-02, -2.6176e-02],
[-6.1641e-02, -8.1041e-02, -2.5100e-02, -7.2410e-02, 6.1282e-02],
[-3.0138e-03, 5.0077e-02, -8.1627e-02, -4.3763e-02, -3.0137e-03],
[-4.0017e-02, -4.3606e-02, -5.0771e-02, -1.0728e-02, 6.0225e-02],
[-6.3614e-02, 5.9101e-02, -1.1370e-02, 6.4711e-02, -1.0095e-02]],
[[-5.0975e-02, 6.0192e-02, 4.8900e-02, 8.0230e-02, -1.5046e-02],
[ 7.2609e-02, 3.1814e-02, 6.7249e-02, 5.9835e-02, 7.7418e-02],
[ 2.9254e-02, 3.7293e-02, 1.1221e-02, 3.3947e-03, 4.2694e-02],
[-6.1252e-02, 3.9850e-02, 8.3872e-03, 6.4936e-02, 5.5738e-02],
[-3.8509e-02, -5.6070e-03, 6.8840e-02, 6.8610e-02, -4.3579e-02]],
[[ 1.8686e-02, 2.1468e-02, 5.9046e-02, 5.2732e-02, 5.1538e-02],
[ 7.0772e-02, 4.9593e-02, 4.1163e-02, 6.8360e-02, -3.9729e-02],
[-4.0441e-02, -1.5649e-02, 7.2554e-02, -2.2384e-02, 3.3869e-02],
[-1.1148e-02, 7.9215e-02, 1.0318e-02, -4.9182e-02, 9.6277e-03],
[ 4.1858e-02, 3.0790e-02, -3.0381e-02, 6.6102e-02, -5.1207e-02]],
[[-6.6738e-03, 6.4776e-02, -3.3882e-02, 7.1115e-02, 3.1006e-02],
[-2.2906e-02, -2.9046e-02, 3.5319e-02, 5.3543e-02, 2.1489e-02],
[ 3.3100e-02, -3.5833e-02, -1.8264e-02, 6.3019e-03, 3.8628e-02],
[-3.4829e-02, -2.0159e-02, -4.5294e-02, 1.4057e-02, -7.9188e-02],
[-3.1353e-02, -1.8218e-02, -8.0638e-02, -5.0035e-02, -3.4570e-02]],
[[ 6.1931e-03, -3.6366e-03, 2.6574e-02, -1.3003e-02, 6.1312e-02],
[ 4.6450e-02, -4.6889e-02, -2.5151e-02, -1.6860e-02, 6.3430e-02],
[-6.3583e-02, 3.9752e-02, -4.2166e-02, -2.7013e-02, -3.3751e-03],
[ 5.6220e-02, 8.1038e-02, 7.8622e-02, 5.7729e-02, 3.2215e-02],
[-3.7003e-02, -6.6767e-02, 6.3296e-02, -7.1348e-02, -3.1742e-03]],
[[ 7.3035e-02, 4.1458e-02, -4.6261e-02, 3.0164e-02, -1.8932e-03],
[ 3.8696e-02, 5.4963e-02, 4.5004e-02, -5.6026e-02, 1.2815e-02],
[-8.0065e-02, 1.8459e-02, 7.6491e-02, 7.3781e-02, 5.5816e-02],
[ 4.5870e-02, -8.1061e-02, -4.7139e-02, 5.6221e-02, 6.0546e-02],
[-6.1282e-02, -2.2380e-02, -6.3779e-02, -6.8776e-02, -4.3537e-02]]],
...,
[[[ 5.2454e-02, -4.7911e-02, -7.2297e-02, -2.0057e-02, -5.5949e-02],
[-1.7783e-02, -8.5585e-03, 2.1769e-02, 7.7195e-02, -2.6911e-02],
[-5.5650e-02, -7.1240e-02, -4.3191e-02, 7.2941e-02, 7.0596e-04],
[-1.1619e-02, 9.3917e-03, -2.8626e-02, -2.5418e-02, -1.0257e-02],
[ 5.1407e-02, 4.8896e-02, 6.0359e-02, 3.2419e-02, -1.9229e-02]],
[[ 1.5427e-02, 4.6055e-02, 3.8077e-02, 1.9711e-02, -4.7648e-02],
[-2.7328e-02, -7.1221e-02, 5.6410e-02, -2.5372e-02, -4.0767e-02],
[-3.1127e-02, 6.7051e-02, -7.9245e-02, 6.6682e-02, -2.6589e-03],
[ 6.9368e-02, -4.1969e-02, 5.7844e-02, 4.0387e-02, -5.6537e-02],
[ 1.3153e-02, 4.0302e-02, 7.1897e-02, 2.1402e-02, 5.5573e-02]],
[[-4.6272e-02, 2.1553e-02, 1.4332e-03, -4.3438e-02, -2.7539e-03],
[-3.4410e-02, -7.9327e-02, -3.9320e-02, 7.7251e-02, 2.2694e-02],
[-7.6279e-02, -3.1001e-02, 6.8287e-02, -7.7937e-02, -6.5913e-02],
[-5.2818e-02, -3.7519e-02, -2.3019e-02, -3.6675e-02, 7.9680e-02],
[-6.1808e-02, 5.1876e-02, -7.8567e-02, 8.1565e-02, -6.8310e-02]],
[[ 5.5309e-02, 8.7365e-03, -5.6561e-03, 7.1133e-02, -4.6917e-02],
[ 2.1503e-03, 8.4254e-03, 1.1892e-02, 7.9317e-02, -7.8111e-03],
[-5.6072e-02, -6.5883e-02, 3.5538e-02, -7.8521e-02, -6.9720e-02],
[-1.9393e-02, -7.3491e-02, 8.1700e-03, -1.9778e-03, 2.8227e-02],
[-1.6722e-02, 5.5947e-02, 6.2404e-03, -5.1478e-03, -2.5255e-02]],
[[-7.8407e-02, 7.2449e-02, -6.2823e-02, -6.6556e-03, -1.6768e-02],
[-3.4428e-02, -1.3495e-02, 6.7206e-04, 2.5644e-02, -3.1432e-02],
[-7.0893e-02, 2.9846e-02, -5.0699e-03, 2.9505e-02, -1.8095e-02],
[ 6.7203e-02, 5.6675e-03, -3.5811e-02, 1.9892e-02, 5.2260e-02],
[-6.8506e-02, -7.1065e-02, 3.3085e-02, 4.3750e-02, 1.6927e-02]],
[[-6.5512e-02, -5.6648e-02, 4.6444e-02, 7.4754e-02, 1.1064e-02],
[ 6.3824e-02, 6.5449e-02, -4.9581e-02, -1.4683e-02, -4.2090e-02],
[-4.7219e-02, -3.5248e-02, 3.6657e-04, -2.6233e-02, 7.6344e-02],
[-6.2184e-02, -6.0576e-02, 8.0155e-02, -4.5559e-02, 5.8270e-02],
[ 3.0984e-02, 5.9146e-02, 5.2023e-02, -8.1066e-02, 1.9684e-02]]],
[[[ 7.3765e-02, 3.5141e-03, -2.1752e-03, 7.4100e-02, 7.4641e-02],
[ 4.8591e-02, 7.6598e-02, 2.9565e-02, 2.1044e-02, -5.0159e-02],
[-6.1025e-02, -2.8593e-02, 2.4891e-02, -6.0556e-02, 6.2090e-02],
[-1.5755e-02, 4.0377e-02, 3.5295e-03, 5.3803e-02, -6.6197e-02],
[ 4.4328e-02, -1.8031e-02, -7.4217e-02, -9.6550e-03, -6.4711e-03]],
[[-7.8584e-02, -4.6375e-02, -3.3496e-02, -6.2696e-02, 3.7588e-02],
[-1.4592e-02, -2.4390e-02, -5.0144e-02, -5.0049e-02, 3.2448e-02],
[ 7.5607e-02, -6.8160e-02, -1.8825e-02, 6.0839e-02, 2.1488e-02],
[-6.6120e-02, 6.7261e-02, -5.8349e-02, 4.2199e-02, -5.2458e-02],
[-4.8494e-02, -5.4509e-02, 9.2273e-05, -3.4371e-02, 5.9357e-04]],
[[-7.8414e-02, 7.0503e-02, -3.8403e-02, -4.4648e-02, 7.5146e-02],
[-5.6907e-02, 3.4213e-05, -6.8368e-02, -5.0295e-02, -4.3787e-02],
[-5.8034e-02, -6.0247e-02, -5.8353e-02, -2.5280e-02, 6.9298e-02],
[-1.4764e-03, 4.3781e-02, -2.5240e-02, -4.3813e-02, 2.8583e-02],
[-4.1418e-02, -4.6101e-02, 2.7372e-03, -3.2267e-02, 5.5665e-02]],
[[-7.9642e-02, 5.8555e-02, 5.8883e-02, -6.3063e-02, -3.0692e-02],
[ 5.0561e-02, -6.8027e-02, 4.7893e-02, -6.2822e-02, 1.0364e-02],
[-6.9147e-02, -7.2420e-02, 7.8318e-02, 6.8844e-02, 5.5427e-03],
[ 6.0117e-02, -3.5221e-02, 7.4160e-02, -1.2965e-02, 6.2557e-02],
[ 4.4865e-04, -6.0832e-02, 7.6485e-03, 6.0057e-02, 2.3943e-02]],
[[-2.4230e-02, 6.7864e-02, 7.2121e-02, 3.4157e-02, -7.9538e-02],
[-3.3560e-02, 1.1223e-02, -1.7517e-02, -5.4311e-02, 6.5203e-02],
[-7.7116e-02, -6.3548e-02, 3.0147e-02, -5.6751e-02, -4.5155e-02],
[ 7.9671e-02, -6.0086e-02, -6.9947e-02, 1.2981e-02, -2.5239e-02],
[-2.4339e-03, 7.9699e-02, 2.4683e-02, -4.5459e-02, 9.2444e-03]],
[[ 4.2156e-02, -5.2864e-02, 7.0858e-02, 7.3412e-02, -1.2266e-02],
[ 4.9417e-02, 4.1968e-02, -6.1616e-02, -1.7818e-02, -2.4427e-02],
[ 2.1473e-02, -1.1306e-02, 4.7061e-02, -2.3690e-02, -6.6106e-02],
[ 8.0653e-02, -6.1394e-02, -9.6741e-03, 3.2104e-02, -3.5300e-02],
[-3.7291e-02, -1.6968e-02, 1.3973e-02, 5.5290e-02, 4.0622e-03]]],
[[[ 2.2983e-02, 1.0195e-02, 4.0193e-02, -7.5025e-02, -1.5421e-03],
[ 6.0510e-02, 1.8363e-02, -3.8094e-02, 4.5445e-02, -2.9622e-03],
[-5.3075e-02, 3.4295e-02, 5.0654e-02, -1.9342e-02, 3.2361e-02],
[ 7.3380e-02, -1.0884e-02, -5.3324e-02, -5.0394e-02, 3.1872e-02],
[-4.9773e-02, -2.8900e-02, -2.6879e-02, -2.8097e-02, 6.6398e-02]],
[[-2.5804e-02, -4.6905e-02, -4.2523e-02, -5.0381e-02, -2.0208e-02],
[-4.8815e-02, 2.2532e-02, 6.9881e-02, 2.1225e-02, 4.3858e-02],
[ 3.8482e-02, 7.5890e-03, 1.4969e-02, 1.8850e-02, -2.9226e-02],
[ 3.2008e-02, 2.3823e-02, 7.4640e-02, -3.2508e-02, -5.9983e-02],
[ 2.0176e-02, -2.4570e-02, -4.5569e-02, 6.8979e-02, -4.4682e-02]],
[[-4.6043e-02, -3.2025e-02, 2.8151e-02, -4.8214e-02, 7.1796e-02],
[ 6.4546e-02, -2.4392e-03, 6.5753e-02, 1.0233e-02, 3.1131e-02],
[-2.9636e-02, 7.3704e-02, -7.7781e-02, 8.0873e-02, -8.5714e-04],
[ 3.0164e-02, -6.3718e-02, 8.1516e-02, 1.4934e-02, 6.1368e-02],
[ 4.9876e-02, 9.5079e-03, 7.7987e-02, -3.8619e-02, 5.9578e-02]],
[[-4.2667e-03, -4.3751e-02, -2.8233e-02, -4.1592e-02, -6.0769e-02],
[ 4.8450e-02, -1.2896e-02, -4.6601e-03, -6.4206e-03, -1.1431e-02],
[-7.7503e-02, -1.0773e-02, -8.0490e-03, -3.1181e-03, 1.3285e-02],
[ 6.8160e-02, 6.7861e-02, 1.4774e-02, -1.3599e-02, -2.4826e-02],
[ 4.7705e-02, -7.3378e-02, -1.3825e-02, -3.1077e-02, -3.7972e-02]],
[[ 5.6202e-02, 4.1837e-02, -2.8773e-02, 7.2132e-02, 2.5515e-02],
[-3.3901e-02, 7.0663e-03, 6.8551e-02, -9.9631e-03, 3.0268e-02],
[-7.2736e-02, 2.3806e-02, 2.9459e-02, 6.0356e-02, -4.3899e-03],
[-2.8373e-02, -1.1198e-02, -1.2081e-02, 1.7582e-02, -6.8985e-02],
[-7.2981e-02, -5.7756e-02, -2.6205e-02, 5.1297e-02, -5.9565e-02]],
[[ 6.5016e-03, 6.3833e-02, 1.0038e-02, -7.4870e-02, -8.0244e-02],
[-1.2594e-02, 2.0456e-02, -4.8056e-02, 1.1776e-02, -7.6218e-02],
[ 5.1599e-02, -9.6947e-03, -1.6420e-02, -5.6519e-02, -7.0490e-02],
[-5.8097e-02, 7.7906e-03, -3.3099e-03, 1.3623e-02, -5.2314e-02],
[-1.3419e-02, 6.9299e-02, 7.2203e-02, 4.8437e-02, -4.1129e-02]]]],
requires_grad=True)
tensor: tensor([[[[-4.6757e-02, 1.3288e-01, -1.0601e-01, 4.9272e-03, 2.7013e-02],
[ 2.6809e-02, 2.9295e-02, 3.6850e-02, -7.2724e-03, 4.7542e-02],
[-9.0670e-02, 9.5702e-02, -3.9494e-03, -1.5364e-02, 6.0347e-03],
[-3.4580e-02, -3.7414e-02, -1.1697e-01, -2.5932e-02, 1.2895e-01],
[ 2.3890e-02, -9.9844e-02, -8.2125e-02, -2.4534e-02, -1.7775e-02]],
[[-6.7097e-03, 1.0264e-01, -3.6296e-02, -7.4812e-03, -2.2599e-02],
[-3.5118e-02, 6.8040e-02, -1.5194e-01, -6.6822e-02, 3.5745e-02],
[ 1.6445e-01, 9.8256e-02, 1.2114e-01, -4.1316e-02, -1.6552e-02],
[ 2.3123e-02, 3.2617e-02, -6.9864e-02, 7.1308e-04, -1.5938e-02],
[-1.2007e-01, -4.5737e-02, -3.4723e-02, 1.0539e-01, 7.5069e-02]],
[[ 5.1764e-03, -4.7447e-02, -9.1823e-02, 1.9314e-02, -1.2899e-02],
[ 2.8299e-02, 9.8341e-02, 4.5279e-02, -7.5509e-02, -6.5473e-02],
[-8.1838e-02, -1.3693e-01, 5.8116e-02, 7.9483e-02, 1.4864e-02],
[ 8.3024e-03, -2.8969e-02, 1.6709e-02, 8.6332e-02, -1.2912e-01],
[ 9.1989e-02, -1.0204e-01, -3.5730e-02, 2.1878e-02, 3.9299e-02]],
[[-8.4570e-03, 7.0740e-02, 9.5664e-02, 1.1943e-01, 7.8352e-03],
[-5.2236e-02, 3.5574e-02, -6.2162e-02, -2.6292e-02, 4.3690e-02],
[ 4.4340e-03, 4.6677e-02, -7.7338e-02, -1.2434e-01, 1.8030e-02],
[ 5.7909e-02, 9.0265e-02, -2.7924e-03, -1.1392e-01, 1.0558e-01],
[ 5.5833e-02, 5.4252e-02, -9.2561e-02, -2.0060e-02, 4.0309e-02]],
[[-4.0318e-02, 5.0439e-02, -1.5511e-01, -7.2652e-02, -7.1681e-02],
[ 3.3902e-02, -1.9146e-02, -2.6171e-02, 2.6969e-02, 1.1980e-01],
[-1.7600e-02, -3.5479e-02, -1.3453e-01, 3.0411e-02, -1.8185e-02],
[ 3.5787e-02, 9.1613e-02, 3.4231e-02, -8.5526e-02, 8.9608e-02],
[ 8.0859e-02, -5.1546e-02, -7.0114e-02, 1.7646e-02, -2.3049e-03]],
[[-3.7738e-02, 5.5325e-02, -1.2705e-01, -1.4088e-01, -3.4257e-03],
[-4.7597e-02, -6.1689e-02, 1.0258e-02, 8.6211e-02, 5.5421e-02],
[ 1.0676e-01, -1.0346e-01, 6.0882e-02, 5.1857e-02, 1.1405e-02],
[-4.7731e-02, -6.7521e-02, 7.2658e-02, -5.9263e-02, -9.3708e-02],
[ 6.4228e-02, 7.7473e-02, 3.8897e-02, -2.7823e-02, 3.3962e-02]]],
[[[ 3.2504e-03, -6.7740e-02, -3.0733e-02, 3.1548e-02, 8.1641e-02],
[-2.0448e-01, -1.4400e-02, -4.9462e-02, -3.7033e-02, 7.9739e-02],
[-8.0209e-02, -8.2590e-02, -1.0695e-02, 4.0322e-02, -7.3355e-02],
[-2.8050e-02, 2.3734e-02, 9.8712e-02, 7.6730e-02, -1.4703e-01],
[-5.6083e-02, 6.0224e-02, -8.6158e-02, 8.4946e-02, 2.6570e-02]],
[[ 1.3426e-02, -4.9485e-03, 2.7120e-02, -4.2500e-02, 3.8941e-02],
[-6.6332e-02, -9.4494e-02, -9.7996e-02, -8.4159e-02, 3.6172e-03],
[ 9.4459e-02, -1.8795e-02, -1.0849e-01, -1.5571e-01, -4.1801e-02],
[-7.7880e-02, 1.8528e-02, -2.2428e-03, 7.7880e-02, -7.2803e-02],
[-7.4748e-02, -1.1186e-01, 4.1013e-02, 3.6887e-02, 8.4037e-02]],
[[ 6.5337e-02, 4.5208e-03, 1.5117e-01, 8.8387e-02, -2.8306e-02],
[ 3.3486e-02, -5.9216e-02, -6.5075e-02, 6.9757e-02, -1.8149e-02],
[-8.1113e-02, 1.6794e-01, 1.2086e-02, 1.0645e-02, 3.0739e-03],
[ 1.4324e-02, -7.6908e-03, -6.0393e-02, 8.6236e-03, 7.2454e-03],
[ 2.7266e-02, -4.2846e-04, 1.1333e-02, 1.0053e-01, 2.4300e-02]],
[[ 3.1538e-02, -6.9510e-02, 1.1773e-01, 5.9435e-02, -1.8069e-02],
[-1.0187e-01, 5.6848e-02, 1.0995e-01, -4.2845e-02, -4.2273e-02],
[ 1.0056e-01, 1.0755e-01, -4.1334e-02, -6.9923e-02, -7.8262e-02],
[ 3.8990e-02, -1.0840e-02, 2.0285e-02, -7.4898e-02, 1.0637e-01],
[-2.6133e-02, 2.1303e-03, -5.4463e-02, 5.1494e-02, 2.1713e-02]],
[[-5.6810e-02, 6.5028e-02, 8.5436e-03, -5.1192e-03, -3.0413e-02],
[-2.1736e-02, -4.8212e-02, -6.9891e-02, 1.9763e-02, -6.8881e-02],
[ 1.3820e-01, -4.0725e-02, 6.3567e-03, 1.8803e-02, -6.3437e-03],
[ 1.3227e-01, -4.8645e-02, 1.0116e-01, 3.3921e-02, 3.4794e-03],
[ 1.1039e-01, 3.6667e-02, 2.6818e-02, -1.5001e-01, 1.5748e-01]],
[[-1.5231e-01, -8.2144e-02, -4.9698e-02, -1.3508e-01, -4.8011e-03],
[ 5.2455e-02, -5.3886e-05, 8.3577e-02, 1.7951e-02, -1.8239e-02],
[ 7.6571e-02, 5.1307e-02, 1.5132e-01, 1.0276e-01, 2.4194e-03],
[ 7.5639e-02, -1.4870e-02, 5.7076e-02, 3.3924e-02, -3.2090e-02],
[-1.3698e-01, 1.3499e-02, -9.5915e-02, -1.4111e-01, 5.8600e-02]]],
[[[ 2.0662e-02, 6.4280e-02, 1.9997e-02, -4.7385e-02, -1.3667e-02],
[-8.8869e-02, -7.1822e-02, -1.2350e-01, -7.7392e-02, 1.3802e-01],
[-9.8538e-02, 4.7742e-02, -8.0551e-02, 1.5743e-02, -2.6337e-02],
[-1.6897e-02, -9.6953e-02, -1.0730e-01, 8.2337e-02, 5.0256e-02],
[-3.2799e-02, 5.4430e-02, -4.8982e-02, 2.5748e-02, 1.6828e-02]],
[[-1.7873e-01, -1.4042e-02, 9.7526e-02, 6.6798e-02, -9.6164e-02],
[ 9.1119e-02, 3.2205e-02, 1.5595e-01, 5.7516e-02, 1.5295e-01],
[-2.6976e-02, -4.3605e-02, -3.6130e-02, 2.6217e-03, 5.0976e-02],
[-5.5043e-02, 1.9190e-01, -5.0940e-03, -3.5246e-02, 1.2800e-01],
[-4.0770e-02, -5.8901e-02, 1.1812e-01, 4.8643e-02, -7.3327e-02]],
[[ 1.2682e-01, 1.2839e-01, 6.2409e-02, 1.0035e-01, 1.0129e-01],
[ 7.2187e-02, 6.0281e-02, 1.3148e-01, 7.6282e-02, -1.5700e-02],
[ 6.6575e-04, -1.3292e-01, 1.0657e-02, -3.7122e-02, 8.5044e-02],
[ 3.2755e-03, 9.3671e-02, 3.9127e-03, 9.0736e-02, 9.7792e-02],
[ 1.3001e-01, 3.8121e-02, 5.0893e-02, 1.9933e-02, -7.5944e-02]],
[[-9.5800e-02, 1.5913e-02, -4.8295e-02, 3.4044e-02, 5.9260e-02],
[-5.6601e-02, -4.2292e-04, -4.7735e-02, -7.0158e-03, 4.3759e-02],
[ 3.3761e-02, -1.2779e-01, -3.9022e-02, 6.8614e-02, -4.8467e-03],
[ 3.8670e-02, -2.9852e-02, -6.6050e-02, -7.3340e-02, -7.6644e-02],
[ 1.1798e-02, -3.0740e-02, -5.0641e-02, -1.2520e-03, -7.6266e-02]],
[[ 5.0593e-02, 2.0795e-02, 6.8589e-02, 2.9682e-02, 2.4102e-02],
[ 9.1989e-02, -4.7258e-02, 2.0969e-02, 1.2252e-02, 8.5911e-02],
[-4.1193e-02, -2.9648e-02, -7.0664e-02, -6.0782e-03, 8.7604e-02],
[ 9.0460e-02, 4.0291e-02, 2.0270e-01, 2.7138e-02, 6.1062e-02],
[-5.7255e-02, -1.2864e-01, 5.3703e-02, 2.1063e-02, 8.3378e-03]],
[[ 4.5253e-02, -1.7370e-02, -8.8588e-02, 3.3718e-02, 1.2469e-02],
[-2.0238e-02, 8.9196e-02, 4.8577e-02, 2.2773e-02, 4.3565e-02],
[-2.2077e-02, -7.2846e-02, 2.2011e-02, 2.7994e-02, 7.1498e-02],
[ 9.7228e-02, -8.7686e-02, -1.9475e-02, -3.4963e-02, 1.0454e-02],
[-9.2156e-02, -2.6965e-02, -1.1747e-01, -4.5681e-02, -7.0946e-02]]],
...,
[[[-2.4093e-02, -4.6103e-02, -6.8426e-02, 7.7020e-02, -7.0480e-02],
[-7.8094e-02, -1.7899e-03, 1.5691e-02, 7.5024e-02, -6.8291e-02],
[-4.2985e-02, -6.2913e-02, -8.8126e-02, -5.1507e-03, -5.0252e-02],
[ 1.4346e-02, -5.1735e-02, 1.4216e-01, -1.6072e-02, -1.9065e-02],
[ 4.0965e-02, 2.8348e-02, 8.5792e-02, 4.9581e-02, 4.2866e-03]],
[[ 7.8131e-02, 6.6378e-03, 5.6797e-02, -4.0113e-02, 4.6595e-02],
[-4.8508e-02, -1.2878e-02, 3.0165e-02, -1.2445e-01, -8.1256e-02],
[-2.0725e-02, 1.1674e-01, -1.5972e-01, 6.4756e-02, -2.6744e-03],
[ 1.1154e-01, -9.4811e-02, 9.5269e-02, 4.3903e-02, -1.0204e-01],
[ 9.8562e-02, 1.2404e-03, 4.1380e-02, 3.1553e-02, 5.3186e-02]],
[[ 2.6055e-02, 2.7388e-02, -6.0488e-02, -9.5511e-02, -1.2093e-01],
[-2.4814e-02, -1.2639e-01, -1.5237e-02, 5.5561e-02, 3.0209e-02],
[-8.2053e-02, -4.5935e-02, 6.1119e-02, -1.3363e-01, -7.1992e-02],
[-1.5287e-01, -6.1410e-02, 6.0375e-03, -8.7979e-02, 9.1407e-02],
[-1.0431e-02, 6.0618e-02, -1.5103e-01, 1.1385e-01, -4.2957e-02]],
[[ 5.2269e-02, -8.1624e-05, -6.7309e-02, 2.2620e-02, -8.8212e-02],
[-4.9433e-02, -1.1957e-02, -5.3011e-02, 5.0615e-02, 3.2292e-02],
[-8.3719e-02, -4.0361e-02, 1.0135e-01, -4.5393e-02, -1.9368e-01],
[-7.6934e-02, -9.0911e-02, -3.0937e-02, 8.6783e-02, 1.7257e-02],
[-2.0637e-03, 1.0265e-01, -6.3748e-02, 3.5307e-02, -4.9159e-02]],
[[-5.7169e-02, 1.0791e-01, 6.7873e-02, -5.5937e-02, 5.8612e-03],
[-4.9908e-02, -1.9537e-02, 3.3519e-02, 4.3909e-02, -5.1788e-02],
[-4.6378e-02, -2.4492e-02, 1.7248e-02, 2.1012e-02, 8.3702e-03],
[ 4.1057e-02, -6.2135e-02, -2.2248e-02, -1.6534e-02, 1.2725e-02],
[-9.7706e-02, -4.1293e-02, 2.0015e-02, 6.2570e-02, 1.2564e-01]],
[[ 3.4413e-02, -1.4170e-01, -2.1922e-02, 1.0131e-01, -1.7902e-03],
[ 9.0206e-02, 9.9272e-02, -1.1902e-01, 1.9933e-03, -3.6138e-02],
[-6.4406e-02, -1.4185e-02, -1.7267e-02, -4.1037e-02, 2.5467e-02],
[-1.9643e-02, -1.5773e-01, 1.4423e-01, -5.9313e-02, 9.5439e-02],
[ 4.9707e-02, 5.7449e-02, 3.8146e-02, -7.0734e-02, -2.7587e-03]]],
[[[ 7.7226e-02, 5.9220e-02, -4.4861e-02, 9.5649e-02, 4.2409e-02],
[ 7.6165e-02, 1.4427e-01, -1.8052e-02, -4.2129e-03, -1.1270e-01],
[ 7.9877e-03, -1.1160e-01, -8.1328e-03, -6.2670e-02, 5.2545e-02],
[ 3.0828e-02, 3.9485e-02, 2.2733e-02, 3.3326e-02, -6.9556e-02],
[-1.6844e-02, 1.1067e-02, -1.1267e-01, 3.9306e-03, -2.8748e-02]],
[[-1.2918e-01, -9.7690e-03, 3.9843e-02, -3.0822e-02, 8.3621e-02],
[-3.9798e-02, 1.2642e-03, 5.3833e-02, -7.6645e-02, 1.5902e-02],
[ 7.7658e-02, -6.3720e-02, -5.3355e-02, 1.1662e-01, 1.0570e-01],
[-8.8436e-02, 2.8177e-02, -2.8857e-02, 5.6628e-02, -3.3440e-02],
[-4.2123e-02, -8.4260e-02, 5.4268e-02, -3.7032e-02, 6.7248e-03]],
[[-3.7819e-02, 5.5044e-02, -7.9974e-02, -8.9877e-02, 2.6892e-02],
[ 6.9266e-02, -2.9285e-02, -3.6309e-02, -1.1215e-01, -6.7951e-02],
[-3.6242e-03, -7.9844e-02, -4.0420e-02, 4.0198e-03, -3.4907e-03],
[-1.0090e-03, -2.3102e-02, -8.5352e-02, -4.3441e-02, 1.3924e-02],
[-8.5443e-02, -3.8716e-02, 3.6130e-02, -4.8093e-02, 1.1988e-01]],
[[-4.3373e-02, 9.9957e-02, 1.4781e-01, -1.3750e-01, 2.6857e-02],
[ 4.4002e-02, -7.9690e-02, -3.7093e-02, 5.4835e-03, -8.9189e-03],
[-8.6401e-02, -1.4609e-01, 1.1692e-01, -8.3849e-03, 6.4011e-02],
[ 7.7343e-02, -4.4404e-02, 1.3719e-01, -7.4191e-02, 7.1682e-02],
[ 4.9977e-02, -7.5457e-02, 2.9231e-02, 1.8797e-01, 1.4974e-01]],
[[ 2.4268e-02, 6.1654e-02, 6.3272e-02, -1.0978e-02, -1.6786e-01],
[ 3.4399e-02, 3.4448e-02, -4.8970e-02, -6.7921e-03, 6.8282e-02],
[-1.0370e-01, -8.1308e-02, 9.1539e-03, -2.5933e-02, -3.0929e-02],
[ 7.3665e-02, -1.6387e-01, -9.2283e-02, -8.2550e-02, -4.3681e-02],
[-6.1899e-02, 1.7591e-01, 2.2806e-03, -1.0136e-01, -3.4592e-02]],
[[ 1.2506e-02, -1.0966e-01, -1.2902e-02, 4.1343e-02, -1.6163e-02],
[ 6.2248e-02, 4.3803e-02, -1.4526e-01, -1.5219e-01, 1.8110e-03],
[ 9.2129e-02, -7.6849e-02, 1.3811e-02, 3.2184e-02, -1.6774e-02],
[ 1.0622e-01, -1.3869e-02, -1.4605e-02, -2.6013e-02, -6.4145e-02],
[-2.6536e-02, -4.1028e-02, -2.1912e-03, 8.6890e-02, 4.4823e-02]]],
[[[ 3.6491e-02, -6.4714e-02, -6.0897e-03, -4.6173e-02, 2.3836e-02],
[-1.0014e-02, 7.0377e-03, -2.4878e-02, 6.7624e-02, -6.6327e-04],
[-9.7066e-02, 4.3811e-03, 1.2715e-01, -8.9454e-02, -7.0118e-02],
[ 6.5933e-02, -4.7004e-03, -1.1758e-01, -1.6999e-02, 5.3587e-02],
[ 2.3134e-02, -8.7674e-02, 6.0083e-03, -1.8805e-02, 1.6294e-01]],
[[-5.5460e-02, -1.1708e-02, -4.1244e-02, -3.7981e-02, -2.7246e-03],
[-6.9340e-02, 1.9413e-02, 1.3383e-01, 4.9228e-02, 8.6531e-02],
[ 6.4961e-02, -1.3002e-03, 2.9986e-02, 5.1293e-02, -1.2678e-02],
[ 7.0438e-02, 1.1205e-01, 1.7375e-01, 3.3223e-02, -7.5450e-02],
[ 4.1298e-02, -1.6560e-02, -8.0311e-02, 1.0686e-02, 3.0853e-02]],
[[-2.6702e-02, -6.8132e-02, -1.7867e-02, -5.7511e-02, 5.7227e-02],
[-3.6242e-02, -4.1175e-02, 9.6871e-02, -6.2403e-02, 3.0870e-02],
[ 3.5389e-02, 1.6592e-02, -1.9059e-01, 8.0625e-02, 3.5002e-02],
[ 7.4076e-03, -8.5312e-02, 2.7993e-02, 1.2907e-01, 5.8467e-02],
[ 6.3744e-02, 2.3969e-02, 5.5898e-02, -4.3199e-02, 1.6609e-02]],
[[-5.8074e-03, -5.6821e-02, -6.2274e-02, -7.0582e-02, 2.7614e-02],
[ 2.2282e-02, -4.4088e-02, -8.4117e-03, -3.5089e-02, -1.1400e-01],
[-3.2177e-02, -2.5463e-02, 1.3630e-02, -7.4737e-02, 5.4067e-02],
[ 1.3170e-02, 4.7218e-02, -3.8321e-03, -3.4776e-02, -7.6640e-02],
[ 1.1996e-02, -6.3880e-02, -4.6459e-02, -3.3479e-02, -5.4155e-03]],
[[ 1.4449e-02, 4.5700e-02, -9.0163e-02, 5.7394e-03, 4.8105e-02],
[-5.1704e-02, -9.8552e-02, 1.1284e-01, 5.1767e-02, 7.1735e-02],
[-1.8952e-02, 1.6876e-01, 6.1535e-02, 9.8161e-02, 1.1692e-02],
[-5.5695e-02, -9.7229e-02, 3.5393e-02, 3.6695e-02, -5.0597e-02],
[-8.2985e-02, -2.6743e-03, -8.6268e-02, 2.9535e-02, -6.2120e-02]],
[[ 9.9467e-02, 2.2627e-01, -5.1615e-02, -7.7695e-03, -1.3050e-01],
[-1.4781e-02, 1.0199e-01, -1.3160e-01, 5.6579e-03, -5.0954e-02],
[ 1.3510e-02, 1.2833e-01, -5.0731e-02, -5.0934e-02, -7.4834e-02],
[-1.9434e-03, 7.5700e-02, 4.0693e-02, 6.2565e-02, -1.0121e-01],
[-1.0599e-01, -3.2444e-03, 2.9434e-02, -2.1970e-02, -1.1739e-02]]]],
grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([-0.0365, 0.0072, 0.0093, 0.0025, 0.0342, 0.0735, 0.0308, 0.0723,
-0.0758, 0.0133, 0.0423, 0.0314, -0.0622, 0.0366, 0.0547, -0.0157],
requires_grad=True)
tensor: tensor([ 0.0190, -0.0501, 0.0052, -0.1015, -0.0062, 0.0528, 0.0634, 0.0589,
-0.1459, -0.0041, 0.0811, 0.0519, -0.0044, -0.0696, 0.0491, -0.0273],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[[[0., 0., -0., -0., 0.],
[0., -0., 0., -0., 0.],
[-0., 0., -0., -0., 0.],
[0., -0., -0., -0., 0.],
[-0., -0., -0., -0., 0.]],
[[0., 0., -0., -0., 0.],
[-0., 0., -0., -0., -0.],
[0., 0., 0., -0., -0.],
[0., 0., 0., -0., 0.],
[-0., -0., 0., 0., 0.]],
[[-0., 0., -0., -0., 0.],
[-0., 0., 0., -0., -0.],
[-0., -0., 0., 0., -0.],
[-0., 0., -0., 0., -0.],
[0., -0., -0., -0., 0.]],
[[0., 0., 0., 0., 0.],
[0., 0., -0., -0., 0.],
[-0., 0., -0., -0., -0.],
[0., 0., -0., -0., 0.],
[0., 0., -0., -0., 0.]],
[[-0., 0., -0., -0., -0.],
[0., 0., -0., 0., 0.],
[-0., -0., -0., 0., 0.],
[0., 0., 0., -0., -0.],
[0., 0., 0., -0., -0.]],
[[-0., 0., -0., -0., 0.],
[-0., 0., 0., 0., -0.],
[0., -0., 0., 0., 0.],
[0., -0., 0., -0., -0.],
[0., 0., 0., 0., 0.]]],
[[[0., -0., -0., 0., -0.],
[-0., -0., -0., -0., 0.],
[-0., -0., -0., -0., -0.],
[-0., 0., 0., 0., -0.],
[0., 0., -0., 0., 0.]],
[[-0., -0., 0., -0., 0.],
[-0., -0., -0., -0., 0.],
[0., 0., -0., -0., 0.],
[-0., 0., 0., 0., -0.],
[-0., -0., 0., 0., 0.]],
[[0., -0., 0., 0., -0.],
[0., -0., -0., 0., 0.],
[-0., 0., 0., -0., -0.],
[0., -0., -0., 0., 0.],
[-0., -0., 0., 0., -0.]],
[[-0., -0., 0., -0., -0.],
[-0., 0., 0., 0., 0.],
[0., 0., -0., -0., -0.],
[0., 0., -0., 0., 0.],
[0., -0., -0., 0., 0.]],
[[-0., 0., 0., -0., -0.],
[0., -0., -0., 0., 0.],
[0., -0., 0., -0., -0.],
[0., 0., 0., 0., 0.],
[0., -0., 0., -0., 0.]],
[[-0., -0., -0., -0., -0.],
[0., 0., 0., -0., 0.],
[0., 0., 0., 0., -0.],
[0., -0., -0., 0., -0.],
[-0., -0., -0., -0., 0.]]],
[[[0., 0., -0., -0., -0.],
[-0., -0., -0., -0., 0.],
[-0., 0., -0., -0., -0.],
[-0., -0., -0., -0., 0.],
[-0., 0., -0., 0., -0.]],
[[-0., 0., 0., 0., -0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[-0., 0., 0., 0., 0.],
[-0., -0., 0., 0., -0.]],
[[0., 0., 0., 0., 0.],
[0., 0., 0., 0., -0.],
[-0., -0., 0., -0., 0.],
[-0., 0., 0., -0., 0.],
[0., 0., -0., 0., -0.]],
[[-0., 0., -0., 0., 0.],
[-0., -0., 0., 0., 0.],
[0., -0., -0., 0., 0.],
[-0., -0., -0., 0., -0.],
[-0., -0., -0., -0., -0.]],
[[0., -0., 0., -0., 0.],
[0., -0., -0., -0., 0.],
[-0., 0., -0., -0., -0.],
[0., 0., 0., 0., 0.],
[-0., -0., 0., -0., -0.]],
[[0., 0., -0., 0., -0.],
[0., 0., 0., -0., 0.],
[-0., 0., 0., 0., 0.],
[0., -0., -0., 0., 0.],
[-0., -0., -0., -0., -0.]]],
...,
[[[0., -0., -0., -0., -0.],
[-0., -0., 0., 0., -0.],
[-0., -0., -0., 0., 0.],
[-0., 0., -0., -0., -0.],
[0., 0., 0., 0., -0.]],
[[0., 0., 0., 0., -0.],
[-0., -0., 0., -0., -0.],
[-0., 0., -0., 0., -0.],
[0., -0., 0., 0., -0.],
[0., 0., 0., 0., 0.]],
[[-0., 0., 0., -0., -0.],
[-0., -0., -0., 0., 0.],
[-0., -0., 0., -0., -0.],
[-0., -0., -0., -0., 0.],
[-0., 0., -0., 0., -0.]],
[[0., 0., -0., 0., -0.],
[0., 0., 0., 0., -0.],
[-0., -0., 0., -0., -0.],
[-0., -0., 0., -0., 0.],
[-0., 0., 0., -0., -0.]],
[[-0., 0., -0., -0., -0.],
[-0., -0., 0., 0., -0.],
[-0., 0., -0., 0., -0.],
[0., 0., -0., 0., 0.],
[-0., -0., 0., 0., 0.]],
[[-0., -0., 0., 0., 0.],
[0., 0., -0., -0., -0.],
[-0., -0., 0., -0., 0.],
[-0., -0., 0., -0., 0.],
[0., 0., 0., -0., 0.]]],
[[[0., 0., -0., 0., 0.],
[0., 0., 0., 0., -0.],
[-0., -0., 0., -0., 0.],
[-0., 0., 0., 0., -0.],
[0., -0., -0., -0., -0.]],
[[-0., -0., -0., -0., 0.],
[-0., -0., -0., -0., 0.],
[0., -0., -0., 0., 0.],
[-0., 0., -0., 0., -0.],
[-0., -0., 0., -0., 0.]],
[[-0., 0., -0., -0., 0.],
[-0., 0., -0., -0., -0.],
[-0., -0., -0., -0., 0.],
[-0., 0., -0., -0., 0.],
[-0., -0., 0., -0., 0.]],
[[-0., 0., 0., -0., -0.],
[0., -0., 0., -0., 0.],
[-0., -0., 0., 0., 0.],
[0., -0., 0., -0., 0.],
[0., -0., 0., 0., 0.]],
[[-0., 0., 0., 0., -0.],
[-0., 0., -0., -0., 0.],
[-0., -0., 0., -0., -0.],
[0., -0., -0., 0., -0.],
[-0., 0., 0., -0., 0.]],
[[0., -0., 0., 0., -0.],
[0., 0., -0., -0., -0.],
[0., -0., 0., -0., -0.],
[0., -0., -0., 0., -0.],
[-0., -0., 0., 0., 0.]]],
[[[0., 0., 0., -0., -0.],
[0., 0., -0., 0., -0.],
[-0., 0., 0., -0., 0.],
[0., -0., -0., -0., 0.],
[-0., -0., -0., -0., 0.]],
[[-0., -0., -0., -0., -0.],
[-0., 0., 0., 0., 0.],
[0., 0., 0., 0., -0.],
[0., 0., 0., -0., -0.],
[0., -0., -0., 0., -0.]],
[[-0., -0., 0., -0., 0.],
[0., -0., 0., 0., 0.],
[-0., 0., -0., 0., -0.],
[0., -0., 0., 0., 0.],
[0., 0., 0., -0., 0.]],
[[-0., -0., -0., -0., -0.],
[0., -0., -0., -0., -0.],
[-0., -0., -0., -0., 0.],
[0., 0., 0., -0., -0.],
[0., -0., -0., -0., -0.]],
[[0., 0., -0., 0., 0.],
[-0., 0., 0., -0., 0.],
[-0., 0., 0., 0., -0.],
[-0., -0., -0., 0., -0.],
[-0., -0., -0., 0., -0.]],
[[0., 0., 0., -0., -0.],
[-0., 0., -0., 0., -0.],
[0., -0., -0., -0., -0.],
[-0., 0., -0., 0., -0.],
[-0., 0., 0., 0., -0.]]]])
scale: tensor([[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
...,
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[[[ 4.2191e-02, 7.6840e-02, -7.4705e-02, -2.1865e-02, 6.1931e-02],
[ 4.5672e-02, -4.5500e-02, 6.8328e-02, -1.9356e-02, 6.2420e-03],
[-9.1323e-03, 6.4114e-02, -6.2778e-02, -2.9379e-02, 9.7347e-03],
[ 1.5921e-02, -2.0886e-02, -6.3837e-02, -4.3951e-02, 5.5913e-04],
[-1.2263e-02, -4.8414e-02, -4.4003e-02, -2.4253e-02, 3.2497e-02]],
[[ 2.3382e-02, 3.9259e-02, -7.7953e-02, -3.2520e-02, 1.8270e-02],
[-4.3315e-02, 7.9060e-02, -5.8311e-02, -2.8759e-02, -8.7089e-03],
[ 5.5321e-02, 4.8438e-02, 6.5735e-02, -4.9760e-02, -3.5508e-02],
[ 4.1005e-02, 4.6544e-02, 1.5205e-02, -3.6013e-02, 6.9640e-03],
[-6.8387e-02, -4.9177e-02, 4.0450e-02, 7.0279e-02, 7.5872e-02]],
[[-4.6557e-03, 1.8276e-02, -5.8220e-02, -2.1548e-02, 5.6864e-02],
[-1.3140e-02, 3.1058e-02, 1.7065e-02, -6.0197e-02, -9.2233e-03],
[-5.6885e-02, -6.7734e-02, 2.2703e-02, 7.8996e-02, -2.3009e-02],
[-3.2973e-02, 3.9733e-02, -4.7892e-02, 1.3657e-02, -6.1810e-02],
[ 6.4573e-02, -2.9753e-02, -5.5327e-02, -9.4562e-03, 2.1779e-02]],
[[ 7.1124e-02, 4.0486e-02, 6.9724e-02, 7.6075e-02, 4.1246e-03],
[ 8.3202e-03, 2.3893e-02, -6.6808e-03, -4.9518e-02, 3.8584e-02],
[-5.2714e-02, 5.0159e-02, -6.6929e-02, -7.0058e-02, -4.7074e-02],
[ 7.5636e-02, 2.7516e-02, -5.3383e-02, -7.9865e-02, 7.8899e-02],
[ 6.7598e-02, 2.2288e-02, -4.3938e-02, -2.2255e-02, 3.8803e-02]],
[[-1.8988e-02, 6.3276e-02, -6.4471e-02, -3.8395e-02, -5.1800e-02],
[ 2.6698e-02, 5.7260e-03, -2.2055e-02, 2.9274e-02, 1.7332e-02],
[-1.0218e-02, -2.7324e-02, -2.9080e-02, 1.3222e-02, 7.9084e-02],
[ 4.3772e-02, 7.7684e-02, 4.8257e-02, -6.8435e-02, -2.0343e-02],
[ 3.0093e-02, 3.1836e-03, 2.6288e-03, -8.2295e-03, -2.5020e-02]],
[[-3.8859e-02, 5.2245e-02, -7.8403e-02, -4.2871e-02, 6.1650e-02],
[-4.8324e-02, 5.1745e-02, 5.0752e-02, 7.3504e-02, -2.3351e-02],
[ 7.5441e-02, -5.1618e-02, 8.0429e-02, 7.3450e-02, 7.8664e-02],
[ 4.9288e-02, -1.9266e-02, 2.9147e-02, -8.0626e-02, -3.1476e-02],
[ 6.6063e-02, 1.7808e-02, 2.6498e-02, 4.6047e-03, 7.2748e-02]]],
[[[ 2.7147e-02, -6.4806e-02, -1.7023e-02, 4.9091e-03, -6.6140e-03],
[-7.7151e-02, -1.2995e-02, -4.5008e-02, -1.8290e-02, 5.0563e-02],
[-7.6304e-02, -8.0019e-02, -1.1984e-02, -3.8425e-02, -4.5299e-03],
[-1.4411e-02, 5.2870e-02, 7.4521e-02, 4.9850e-02, -2.3332e-02],
[ 3.2249e-03, 4.7371e-02, -5.4040e-02, 7.8740e-02, 6.2129e-02]],
[[-5.5597e-02, -7.9082e-03, 5.7858e-03, -1.7938e-02, 3.2981e-02],
[-5.7644e-02, -7.8341e-02, -5.2250e-02, -7.0176e-02, 9.1169e-03],
[ 4.8751e-03, 2.7214e-02, -7.2253e-02, -7.9531e-02, 2.2115e-02],
[-8.0066e-02, 6.1143e-03, 9.5127e-03, 7.6738e-02, -5.4762e-02],
[-4.0489e-02, -4.4740e-02, 7.4822e-02, 5.8303e-02, 3.4314e-02]],
[[ 3.2335e-02, -7.6859e-03, 7.2277e-02, 6.4942e-02, -2.1889e-02],
[ 6.7059e-02, -2.5517e-02, -7.2487e-02, 6.3138e-02, 1.8800e-02],
[-4.1668e-02, 3.3260e-02, 2.0425e-03, -9.8570e-03, -5.2738e-02],
[ 5.0590e-02, -1.0069e-02, -1.0559e-02, 1.4182e-02, 5.3266e-02],
[-1.0744e-02, -2.3853e-02, 1.9477e-02, 1.2789e-02, -1.9533e-03]],
[[-7.1060e-03, -4.1691e-02, 3.4899e-02, -7.2925e-02, -5.2534e-02],
[-4.6110e-02, 7.5498e-02, 6.1524e-02, 1.1099e-02, 7.3727e-02],
[ 5.4388e-02, 3.9116e-02, -6.6412e-02, -2.6689e-02, -2.0881e-02],
[ 2.6913e-02, 3.5247e-02, -2.6048e-02, 1.4267e-02, 2.6323e-02],
[ 6.5432e-03, -7.3995e-03, -4.3528e-02, 5.0431e-02, 6.9511e-02]],
[[-2.8527e-02, 6.1429e-02, 6.0427e-02, -4.1141e-02, -7.3619e-02],
[ 2.1947e-02, -3.8028e-02, -6.7364e-02, 1.5408e-02, 5.2976e-02],
[ 5.8108e-02, -6.1758e-02, 1.4653e-02, -7.6558e-02, -3.1448e-02],
[ 6.2695e-02, 6.4849e-02, 3.1678e-02, 8.7257e-03, 8.9366e-03],
[ 4.3787e-02, -2.6563e-02, 7.6010e-02, -4.9455e-02, 4.6068e-02]],
[[-4.6879e-02, -1.0211e-02, -5.7937e-02, -6.0173e-02, -5.6010e-03],
[ 6.7886e-02, 4.8736e-02, 7.5209e-02, -2.4597e-02, 1.2666e-02],
[ 2.7294e-02, 5.9385e-02, 8.1056e-02, 7.2644e-02, -7.7046e-02],
[ 1.7380e-02, -3.2567e-02, -2.1987e-02, 7.8235e-02, -5.2908e-02],
[-2.0811e-02, -2.9994e-02, -4.6838e-02, -4.2671e-02, 3.8727e-02]]],
[[[ 3.5134e-02, 5.2745e-02, -6.1149e-02, -7.6414e-02, -2.6176e-02],
[-6.1641e-02, -8.1041e-02, -2.5100e-02, -7.2410e-02, 6.1282e-02],
[-3.0138e-03, 5.0077e-02, -8.1627e-02, -4.3763e-02, -3.0137e-03],
[-4.0017e-02, -4.3606e-02, -5.0771e-02, -1.0728e-02, 6.0225e-02],
[-6.3614e-02, 5.9101e-02, -1.1370e-02, 6.4711e-02, -1.0095e-02]],
[[-5.0975e-02, 6.0192e-02, 4.8900e-02, 8.0230e-02, -1.5046e-02],
[ 7.2609e-02, 3.1814e-02, 6.7249e-02, 5.9835e-02, 7.7418e-02],
[ 2.9254e-02, 3.7293e-02, 1.1221e-02, 3.3947e-03, 4.2694e-02],
[-6.1252e-02, 3.9850e-02, 8.3872e-03, 6.4936e-02, 5.5738e-02],
[-3.8509e-02, -5.6070e-03, 6.8840e-02, 6.8610e-02, -4.3579e-02]],
[[ 1.8686e-02, 2.1468e-02, 5.9046e-02, 5.2732e-02, 5.1538e-02],
[ 7.0772e-02, 4.9593e-02, 4.1163e-02, 6.8360e-02, -3.9729e-02],
[-4.0441e-02, -1.5649e-02, 7.2554e-02, -2.2384e-02, 3.3869e-02],
[-1.1148e-02, 7.9215e-02, 1.0318e-02, -4.9182e-02, 9.6277e-03],
[ 4.1858e-02, 3.0790e-02, -3.0381e-02, 6.6102e-02, -5.1207e-02]],
[[-6.6738e-03, 6.4776e-02, -3.3882e-02, 7.1115e-02, 3.1006e-02],
[-2.2906e-02, -2.9046e-02, 3.5319e-02, 5.3543e-02, 2.1489e-02],
[ 3.3100e-02, -3.5833e-02, -1.8264e-02, 6.3019e-03, 3.8628e-02],
[-3.4829e-02, -2.0159e-02, -4.5294e-02, 1.4057e-02, -7.9188e-02],
[-3.1353e-02, -1.8218e-02, -8.0638e-02, -5.0035e-02, -3.4570e-02]],
[[ 6.1931e-03, -3.6366e-03, 2.6574e-02, -1.3003e-02, 6.1312e-02],
[ 4.6450e-02, -4.6889e-02, -2.5151e-02, -1.6860e-02, 6.3430e-02],
[-6.3583e-02, 3.9752e-02, -4.2166e-02, -2.7013e-02, -3.3751e-03],
[ 5.6220e-02, 8.1038e-02, 7.8622e-02, 5.7729e-02, 3.2215e-02],
[-3.7003e-02, -6.6767e-02, 6.3296e-02, -7.1348e-02, -3.1742e-03]],
[[ 7.3035e-02, 4.1458e-02, -4.6261e-02, 3.0164e-02, -1.8932e-03],
[ 3.8696e-02, 5.4963e-02, 4.5004e-02, -5.6026e-02, 1.2815e-02],
[-8.0065e-02, 1.8459e-02, 7.6491e-02, 7.3781e-02, 5.5816e-02],
[ 4.5870e-02, -8.1061e-02, -4.7139e-02, 5.6221e-02, 6.0546e-02],
[-6.1282e-02, -2.2380e-02, -6.3779e-02, -6.8776e-02, -4.3537e-02]]],
...,
[[[ 5.2454e-02, -4.7911e-02, -7.2297e-02, -2.0057e-02, -5.5949e-02],
[-1.7783e-02, -8.5585e-03, 2.1769e-02, 7.7195e-02, -2.6911e-02],
[-5.5650e-02, -7.1240e-02, -4.3191e-02, 7.2941e-02, 7.0596e-04],
[-1.1619e-02, 9.3917e-03, -2.8626e-02, -2.5418e-02, -1.0257e-02],
[ 5.1407e-02, 4.8896e-02, 6.0359e-02, 3.2419e-02, -1.9229e-02]],
[[ 1.5427e-02, 4.6055e-02, 3.8077e-02, 1.9711e-02, -4.7648e-02],
[-2.7328e-02, -7.1221e-02, 5.6410e-02, -2.5372e-02, -4.0767e-02],
[-3.1127e-02, 6.7051e-02, -7.9245e-02, 6.6682e-02, -2.6589e-03],
[ 6.9368e-02, -4.1969e-02, 5.7844e-02, 4.0387e-02, -5.6537e-02],
[ 1.3153e-02, 4.0302e-02, 7.1897e-02, 2.1402e-02, 5.5573e-02]],
[[-4.6272e-02, 2.1553e-02, 1.4332e-03, -4.3438e-02, -2.7539e-03],
[-3.4410e-02, -7.9327e-02, -3.9320e-02, 7.7251e-02, 2.2694e-02],
[-7.6279e-02, -3.1001e-02, 6.8287e-02, -7.7937e-02, -6.5913e-02],
[-5.2818e-02, -3.7519e-02, -2.3019e-02, -3.6675e-02, 7.9680e-02],
[-6.1808e-02, 5.1876e-02, -7.8567e-02, 8.1565e-02, -6.8310e-02]],
[[ 5.5309e-02, 8.7365e-03, -5.6561e-03, 7.1133e-02, -4.6917e-02],
[ 2.1503e-03, 8.4254e-03, 1.1892e-02, 7.9317e-02, -7.8111e-03],
[-5.6072e-02, -6.5883e-02, 3.5538e-02, -7.8521e-02, -6.9720e-02],
[-1.9393e-02, -7.3491e-02, 8.1700e-03, -1.9778e-03, 2.8227e-02],
[-1.6722e-02, 5.5947e-02, 6.2404e-03, -5.1478e-03, -2.5255e-02]],
[[-7.8407e-02, 7.2449e-02, -6.2823e-02, -6.6556e-03, -1.6768e-02],
[-3.4428e-02, -1.3495e-02, 6.7206e-04, 2.5644e-02, -3.1432e-02],
[-7.0893e-02, 2.9846e-02, -5.0699e-03, 2.9505e-02, -1.8095e-02],
[ 6.7203e-02, 5.6675e-03, -3.5811e-02, 1.9892e-02, 5.2260e-02],
[-6.8506e-02, -7.1065e-02, 3.3085e-02, 4.3750e-02, 1.6927e-02]],
[[-6.5512e-02, -5.6648e-02, 4.6444e-02, 7.4754e-02, 1.1064e-02],
[ 6.3824e-02, 6.5449e-02, -4.9581e-02, -1.4683e-02, -4.2090e-02],
[-4.7219e-02, -3.5248e-02, 3.6657e-04, -2.6233e-02, 7.6344e-02],
[-6.2184e-02, -6.0576e-02, 8.0155e-02, -4.5559e-02, 5.8270e-02],
[ 3.0984e-02, 5.9146e-02, 5.2023e-02, -8.1066e-02, 1.9684e-02]]],
[[[ 7.3765e-02, 3.5141e-03, -2.1752e-03, 7.4100e-02, 7.4641e-02],
[ 4.8591e-02, 7.6598e-02, 2.9565e-02, 2.1044e-02, -5.0159e-02],
[-6.1025e-02, -2.8593e-02, 2.4891e-02, -6.0556e-02, 6.2090e-02],
[-1.5755e-02, 4.0377e-02, 3.5295e-03, 5.3803e-02, -6.6197e-02],
[ 4.4328e-02, -1.8031e-02, -7.4217e-02, -9.6550e-03, -6.4711e-03]],
[[-7.8584e-02, -4.6375e-02, -3.3496e-02, -6.2696e-02, 3.7588e-02],
[-1.4592e-02, -2.4390e-02, -5.0144e-02, -5.0049e-02, 3.2448e-02],
[ 7.5607e-02, -6.8160e-02, -1.8825e-02, 6.0839e-02, 2.1488e-02],
[-6.6120e-02, 6.7261e-02, -5.8349e-02, 4.2199e-02, -5.2458e-02],
[-4.8494e-02, -5.4509e-02, 9.2273e-05, -3.4371e-02, 5.9357e-04]],
[[-7.8414e-02, 7.0503e-02, -3.8403e-02, -4.4648e-02, 7.5146e-02],
[-5.6907e-02, 3.4213e-05, -6.8368e-02, -5.0295e-02, -4.3787e-02],
[-5.8034e-02, -6.0247e-02, -5.8353e-02, -2.5280e-02, 6.9298e-02],
[-1.4764e-03, 4.3781e-02, -2.5240e-02, -4.3813e-02, 2.8583e-02],
[-4.1418e-02, -4.6101e-02, 2.7372e-03, -3.2267e-02, 5.5665e-02]],
[[-7.9642e-02, 5.8555e-02, 5.8883e-02, -6.3063e-02, -3.0692e-02],
[ 5.0561e-02, -6.8027e-02, 4.7893e-02, -6.2822e-02, 1.0364e-02],
[-6.9147e-02, -7.2420e-02, 7.8318e-02, 6.8844e-02, 5.5427e-03],
[ 6.0117e-02, -3.5221e-02, 7.4160e-02, -1.2965e-02, 6.2557e-02],
[ 4.4865e-04, -6.0832e-02, 7.6485e-03, 6.0057e-02, 2.3943e-02]],
[[-2.4230e-02, 6.7864e-02, 7.2121e-02, 3.4157e-02, -7.9538e-02],
[-3.3560e-02, 1.1223e-02, -1.7517e-02, -5.4311e-02, 6.5203e-02],
[-7.7116e-02, -6.3548e-02, 3.0147e-02, -5.6751e-02, -4.5155e-02],
[ 7.9671e-02, -6.0086e-02, -6.9947e-02, 1.2981e-02, -2.5239e-02],
[-2.4339e-03, 7.9699e-02, 2.4683e-02, -4.5459e-02, 9.2444e-03]],
[[ 4.2156e-02, -5.2864e-02, 7.0858e-02, 7.3412e-02, -1.2266e-02],
[ 4.9417e-02, 4.1968e-02, -6.1616e-02, -1.7818e-02, -2.4427e-02],
[ 2.1473e-02, -1.1306e-02, 4.7061e-02, -2.3690e-02, -6.6106e-02],
[ 8.0653e-02, -6.1394e-02, -9.6741e-03, 3.2104e-02, -3.5300e-02],
[-3.7291e-02, -1.6968e-02, 1.3973e-02, 5.5290e-02, 4.0622e-03]]],
[[[ 2.2983e-02, 1.0195e-02, 4.0193e-02, -7.5025e-02, -1.5421e-03],
[ 6.0510e-02, 1.8363e-02, -3.8094e-02, 4.5445e-02, -2.9622e-03],
[-5.3075e-02, 3.4295e-02, 5.0654e-02, -1.9342e-02, 3.2361e-02],
[ 7.3380e-02, -1.0884e-02, -5.3324e-02, -5.0394e-02, 3.1872e-02],
[-4.9773e-02, -2.8900e-02, -2.6879e-02, -2.8097e-02, 6.6398e-02]],
[[-2.5804e-02, -4.6905e-02, -4.2523e-02, -5.0381e-02, -2.0208e-02],
[-4.8815e-02, 2.2532e-02, 6.9881e-02, 2.1225e-02, 4.3858e-02],
[ 3.8482e-02, 7.5890e-03, 1.4969e-02, 1.8850e-02, -2.9226e-02],
[ 3.2008e-02, 2.3823e-02, 7.4640e-02, -3.2508e-02, -5.9983e-02],
[ 2.0176e-02, -2.4570e-02, -4.5569e-02, 6.8979e-02, -4.4682e-02]],
[[-4.6043e-02, -3.2025e-02, 2.8151e-02, -4.8214e-02, 7.1796e-02],
[ 6.4546e-02, -2.4392e-03, 6.5753e-02, 1.0233e-02, 3.1131e-02],
[-2.9636e-02, 7.3704e-02, -7.7781e-02, 8.0873e-02, -8.5714e-04],
[ 3.0164e-02, -6.3718e-02, 8.1516e-02, 1.4934e-02, 6.1368e-02],
[ 4.9876e-02, 9.5079e-03, 7.7987e-02, -3.8619e-02, 5.9578e-02]],
[[-4.2667e-03, -4.3751e-02, -2.8233e-02, -4.1592e-02, -6.0769e-02],
[ 4.8450e-02, -1.2896e-02, -4.6601e-03, -6.4206e-03, -1.1431e-02],
[-7.7503e-02, -1.0773e-02, -8.0490e-03, -3.1181e-03, 1.3285e-02],
[ 6.8160e-02, 6.7861e-02, 1.4774e-02, -1.3599e-02, -2.4826e-02],
[ 4.7705e-02, -7.3378e-02, -1.3825e-02, -3.1077e-02, -3.7972e-02]],
[[ 5.6202e-02, 4.1837e-02, -2.8773e-02, 7.2132e-02, 2.5515e-02],
[-3.3901e-02, 7.0663e-03, 6.8551e-02, -9.9631e-03, 3.0268e-02],
[-7.2736e-02, 2.3806e-02, 2.9459e-02, 6.0356e-02, -4.3899e-03],
[-2.8373e-02, -1.1198e-02, -1.2081e-02, 1.7582e-02, -6.8985e-02],
[-7.2981e-02, -5.7756e-02, -2.6205e-02, 5.1297e-02, -5.9565e-02]],
[[ 6.5016e-03, 6.3833e-02, 1.0038e-02, -7.4870e-02, -8.0244e-02],
[-1.2594e-02, 2.0456e-02, -4.8056e-02, 1.1776e-02, -7.6218e-02],
[ 5.1599e-02, -9.6947e-03, -1.6420e-02, -5.6519e-02, -7.0490e-02],
[-5.8097e-02, 7.7906e-03, -3.3099e-03, 1.3623e-02, -5.2314e-02],
[-1.3419e-02, 6.9299e-02, 7.2203e-02, 4.8437e-02, -4.1129e-02]]]])
(bias): Normal:
loc: tensor([-0., 0., 0., 0., 0., 0., 0., 0., -0., 0., 0., 0., -0., 0., 0., -0.])
scale: tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([-0.0365, 0.0072, 0.0093, 0.0025, 0.0342, 0.0735, 0.0308, 0.0723,
-0.0758, 0.0133, 0.0423, 0.0314, -0.0622, 0.0366, 0.0547, -0.0157])
)
(observed): Observed()
)
(fc1): Linear(
in_features=400, out_features=120, bias=True
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
...,
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498]],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[ 0.0290, 0.0450, 0.0432, ..., -0.0188, 0.0460, 0.0460],
[-0.0440, 0.0107, -0.0302, ..., -0.0457, 0.0447, -0.0215],
[-0.0185, 0.0321, -0.0076, ..., -0.0138, 0.0309, 0.0219],
...,
[ 0.0011, 0.0479, 0.0438, ..., 0.0297, -0.0310, 0.0051],
[ 0.0461, -0.0489, 0.0107, ..., -0.0152, 0.0484, -0.0418],
[ 0.0300, -0.0080, 0.0244, ..., -0.0223, 0.0077, -0.0147]],
requires_grad=True)
tensor: tensor([[-0.0771, 0.0682, 0.0952, ..., -0.0025, 0.0747, 0.0905],
[-0.0600, -0.0432, -0.0064, ..., -0.0666, 0.0176, 0.0168],
[ 0.0137, 0.0625, 0.0067, ..., 0.0625, 0.0137, 0.1347],
...,
[ 0.0379, -0.0104, 0.0276, ..., 0.0230, -0.0285, 0.0612],
[ 0.0619, -0.0173, 0.0673, ..., 0.0148, 0.0073, 0.0058],
[ 0.1118, -0.0446, -0.0601, ..., 0.0168, -0.0052, -0.0763]],
grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([ 3.4268e-02, -4.9700e-02, 1.3390e-02, -2.0985e-02, -1.6808e-02,
2.4641e-02, -4.8857e-02, 3.6003e-02, 1.0424e-02, -2.8776e-02,
4.2286e-02, -2.7017e-02, 4.5771e-03, 2.5405e-02, 1.3040e-02,
4.0760e-02, 3.4614e-02, 9.6058e-03, -1.0389e-02, 1.7939e-02,
2.7261e-02, 1.0442e-02, -4.0359e-02, 6.0608e-04, -2.8432e-02,
-2.3294e-02, 2.5094e-02, -2.7444e-02, -3.9087e-02, 4.8333e-02,
-3.8067e-02, 4.3731e-02, -1.1337e-02, 6.5628e-03, -3.9661e-02,
-3.0479e-02, 4.2353e-02, -2.1659e-02, 1.7937e-02, 4.8377e-02,
-7.3253e-03, -2.8690e-02, -2.0960e-02, -2.0081e-02, -4.1321e-02,
-2.5133e-02, -3.6166e-02, 3.5816e-02, 3.0482e-03, -4.1686e-02,
3.5028e-02, 3.1139e-02, -2.5572e-03, -8.7303e-03, 4.4109e-02,
1.0638e-02, 3.5676e-02, 1.9173e-02, 2.6669e-02, -3.7786e-02,
-1.8936e-02, -1.0854e-05, 3.6708e-02, -2.0134e-02, -4.2009e-02,
3.9515e-02, -3.7057e-02, 2.5732e-02, -2.7906e-02, 3.4639e-02,
4.3312e-03, 1.9998e-02, -2.2635e-02, 2.4380e-02, -1.2083e-02,
3.3281e-02, 4.6888e-02, 3.9643e-02, 1.1001e-02, 4.8635e-02,
-1.2259e-03, 3.2456e-02, -1.0144e-02, -6.7254e-03, 1.5385e-03,
1.6456e-02, -3.1296e-02, -1.5280e-02, 3.8949e-02, -1.9992e-02,
-2.8284e-02, 4.0517e-02, -9.7160e-03, 7.7505e-04, 1.6642e-02,
4.0081e-02, 9.7422e-04, 3.7670e-02, 4.7919e-02, 4.4317e-02,
1.0160e-02, 1.9730e-02, -8.4861e-03, 3.3960e-02, -2.3660e-02,
-4.4850e-02, -3.0455e-02, -1.5874e-03, 1.4935e-02, 2.1578e-02,
4.7099e-02, -3.9308e-02, 1.2687e-02, -3.1501e-02, 2.6750e-02,
1.3101e-02, -4.7801e-02, 4.5735e-02, 2.1969e-02, 1.0239e-02],
requires_grad=True)
tensor: tensor([ 0.0461, -0.0686, -0.0018, -0.0174, -0.1039, 0.0644, 0.0129, -0.0319,
0.0540, -0.0665, 0.0386, -0.0862, -0.0050, -0.0043, 0.0145, 0.0309,
0.1006, 0.0050, -0.0034, 0.0625, -0.0306, 0.0265, -0.1264, -0.0380,
0.0116, -0.0690, 0.0944, -0.0293, -0.0990, 0.1073, 0.0308, 0.0658,
-0.0371, 0.0542, -0.0284, -0.0286, -0.0326, -0.0041, 0.0036, 0.0200,
-0.0654, -0.0131, -0.0922, -0.0521, -0.0524, 0.0307, 0.0191, 0.0887,
0.0523, -0.0082, 0.0263, 0.0368, -0.0741, 0.0028, 0.0159, 0.0315,
-0.0049, 0.0991, 0.0387, -0.0024, -0.0229, 0.1443, -0.0206, -0.0440,
0.0354, 0.0254, -0.0601, 0.0427, -0.0854, 0.0387, 0.0435, 0.0373,
0.0068, 0.0281, 0.0257, -0.0699, -0.0028, 0.0847, 0.0240, -0.0674,
-0.0045, 0.1192, 0.0438, -0.0321, 0.0262, -0.0189, -0.0056, -0.0746,
0.0501, 0.0481, -0.1232, -0.0209, 0.0074, -0.0143, -0.0851, 0.1050,
0.0118, 0.0893, 0.0022, -0.0285, -0.0378, -0.0167, -0.0383, 0.0616,
-0.0673, -0.0795, -0.0137, 0.0920, 0.0490, 0.0294, 0.1110, 0.0153,
0.0588, -0.0804, 0.0791, 0.1213, 0.0279, 0.0537, 0.0901, -0.0692],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[0., 0., 0., ..., -0., 0., 0.],
[-0., 0., -0., ..., -0., 0., -0.],
[-0., 0., -0., ..., -0., 0., 0.],
...,
[0., 0., 0., ..., 0., -0., 0.],
[0., -0., 0., ..., -0., 0., -0.],
[0., -0., 0., ..., -0., 0., -0.]])
scale: tensor([[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
...,
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[ 0.0290, 0.0450, 0.0432, ..., -0.0188, 0.0460, 0.0460],
[-0.0440, 0.0107, -0.0302, ..., -0.0457, 0.0447, -0.0215],
[-0.0185, 0.0321, -0.0076, ..., -0.0138, 0.0309, 0.0219],
...,
[ 0.0011, 0.0479, 0.0438, ..., 0.0297, -0.0310, 0.0051],
[ 0.0461, -0.0489, 0.0107, ..., -0.0152, 0.0484, -0.0418],
[ 0.0300, -0.0080, 0.0244, ..., -0.0223, 0.0077, -0.0147]])
(bias): Normal:
loc: tensor([0., -0., 0., -0., -0., 0., -0., 0., 0., -0., 0., -0., 0., 0., 0., 0., 0., 0., -0., 0., 0., 0., -0., 0.,
-0., -0., 0., -0., -0., 0., -0., 0., -0., 0., -0., -0., 0., -0., 0., 0., -0., -0., -0., -0., -0., -0., -0., 0.,
0., -0., 0., 0., -0., -0., 0., 0., 0., 0., 0., -0., -0., -0., 0., -0., -0., 0., -0., 0., -0., 0., 0., 0.,
-0., 0., -0., 0., 0., 0., 0., 0., -0., 0., -0., -0., 0., 0., -0., -0., 0., -0., -0., 0., -0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., -0., 0., -0., -0., -0., -0., 0., 0., 0., -0., 0., -0., 0., 0., -0., 0., 0., 0.])
scale: tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([ 3.4268e-02, -4.9700e-02, 1.3390e-02, -2.0985e-02, -1.6808e-02,
2.4641e-02, -4.8857e-02, 3.6003e-02, 1.0424e-02, -2.8776e-02,
4.2286e-02, -2.7017e-02, 4.5771e-03, 2.5405e-02, 1.3040e-02,
4.0760e-02, 3.4614e-02, 9.6058e-03, -1.0389e-02, 1.7939e-02,
2.7261e-02, 1.0442e-02, -4.0359e-02, 6.0608e-04, -2.8432e-02,
-2.3294e-02, 2.5094e-02, -2.7444e-02, -3.9087e-02, 4.8333e-02,
-3.8067e-02, 4.3731e-02, -1.1337e-02, 6.5628e-03, -3.9661e-02,
-3.0479e-02, 4.2353e-02, -2.1659e-02, 1.7937e-02, 4.8377e-02,
-7.3253e-03, -2.8690e-02, -2.0960e-02, -2.0081e-02, -4.1321e-02,
-2.5133e-02, -3.6166e-02, 3.5816e-02, 3.0482e-03, -4.1686e-02,
3.5028e-02, 3.1139e-02, -2.5572e-03, -8.7303e-03, 4.4109e-02,
1.0638e-02, 3.5676e-02, 1.9173e-02, 2.6669e-02, -3.7786e-02,
-1.8936e-02, -1.0854e-05, 3.6708e-02, -2.0134e-02, -4.2009e-02,
3.9515e-02, -3.7057e-02, 2.5732e-02, -2.7906e-02, 3.4639e-02,
4.3312e-03, 1.9998e-02, -2.2635e-02, 2.4380e-02, -1.2083e-02,
3.3281e-02, 4.6888e-02, 3.9643e-02, 1.1001e-02, 4.8635e-02,
-1.2259e-03, 3.2456e-02, -1.0144e-02, -6.7254e-03, 1.5385e-03,
1.6456e-02, -3.1296e-02, -1.5280e-02, 3.8949e-02, -1.9992e-02,
-2.8284e-02, 4.0517e-02, -9.7160e-03, 7.7505e-04, 1.6642e-02,
4.0081e-02, 9.7422e-04, 3.7670e-02, 4.7919e-02, 4.4317e-02,
1.0160e-02, 1.9730e-02, -8.4861e-03, 3.3960e-02, -2.3660e-02,
-4.4850e-02, -3.0455e-02, -1.5874e-03, 1.4935e-02, 2.1578e-02,
4.7099e-02, -3.9308e-02, 1.2687e-02, -3.1501e-02, 2.6750e-02,
1.3101e-02, -4.7801e-02, 4.5735e-02, 2.1969e-02, 1.0239e-02])
)
(observed): Observed()
)
(fc2): Linear(
in_features=120, out_features=10, bias=True
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
...,
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498]],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[ 0.0853, -0.0135, 0.0683, ..., 0.0101, -0.0386, 0.0218],
[ 0.0366, -0.0514, 0.0796, ..., -0.0416, -0.0514, 0.0824],
[-0.0907, 0.0680, 0.0275, ..., -0.0242, 0.0592, -0.0864],
...,
[-0.0862, 0.0863, -0.0316, ..., 0.0718, 0.0438, 0.0558],
[-0.0711, -0.0183, -0.0767, ..., 0.0480, 0.0798, -0.0622],
[ 0.0360, -0.0011, -0.0885, ..., 0.0526, 0.0213, -0.0500]],
requires_grad=True)
tensor: tensor([[ 0.0446, -0.0119, 0.0613, ..., 0.0736, -0.1029, -0.0459],
[ 0.0366, -0.1579, 0.0943, ..., 0.0078, -0.0376, 0.0493],
[-0.1436, 0.0732, 0.0015, ..., -0.0362, -0.0015, -0.0596],
...,
[-0.0071, 0.0492, 0.0063, ..., 0.0985, 0.0816, 0.0473],
[-0.1276, 0.0231, -0.0409, ..., 0.0153, 0.0586, -0.0618],
[ 0.0311, -0.0013, -0.0697, ..., 0.0642, -0.0082, -0.0596]],
grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([ 0.0895, 0.0306, -0.0073, 0.0383, 0.0428, -0.0883, 0.0808, -0.0065,
0.0228, 0.0389], requires_grad=True)
tensor: tensor([-0.0167, 0.0773, 0.0665, 0.0994, 0.0772, -0.0356, -0.0222, -0.0223,
-0.0587, 0.0343], grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[0., -0., 0., ..., 0., -0., 0.],
[0., -0., 0., ..., -0., -0., 0.],
[-0., 0., 0., ..., -0., 0., -0.],
...,
[-0., 0., -0., ..., 0., 0., 0.],
[-0., -0., -0., ..., 0., 0., -0.],
[0., -0., -0., ..., 0., 0., -0.]])
scale: tensor([[0.0913, 0.0913, 0.0913, ..., 0.0913, 0.0913, 0.0913],
[0.0913, 0.0913, 0.0913, ..., 0.0913, 0.0913, 0.0913],
[0.0913, 0.0913, 0.0913, ..., 0.0913, 0.0913, 0.0913],
...,
[0.0913, 0.0913, 0.0913, ..., 0.0913, 0.0913, 0.0913],
[0.0913, 0.0913, 0.0913, ..., 0.0913, 0.0913, 0.0913],
[0.0913, 0.0913, 0.0913, ..., 0.0913, 0.0913, 0.0913]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[ 0.0853, -0.0135, 0.0683, ..., 0.0101, -0.0386, 0.0218],
[ 0.0366, -0.0514, 0.0796, ..., -0.0416, -0.0514, 0.0824],
[-0.0907, 0.0680, 0.0275, ..., -0.0242, 0.0592, -0.0864],
...,
[-0.0862, 0.0863, -0.0316, ..., 0.0718, 0.0438, 0.0558],
[-0.0711, -0.0183, -0.0767, ..., 0.0480, 0.0798, -0.0622],
[ 0.0360, -0.0011, -0.0885, ..., 0.0526, 0.0213, -0.0500]])
(bias): Normal:
loc: tensor([0., 0., -0., 0., 0., -0., 0., -0., 0., 0.])
scale: tensor([0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([ 0.0895, 0.0306, -0.0073, 0.0383, 0.0428, -0.0883, 0.0808, -0.0065,
0.0228, 0.0389])
)
(observed): Observed()
)
)
You just have to define the forward function, and the backward function (where gradients are computed) is automatically defined for you using autograd. You can use any of the Tensor operations in the forward function.
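The following plain PyTorch sketch shows this in action. The function name forward and the tensor values are made up for the illustration. A computation built only from tensor operations gets its gradient from autograd, and that gradient agrees with the hand-derived derivative:

```python
import torch

def forward(w, x):
    # Only ordinary tensor operations; no backward is written by hand.
    return torch.tanh(w * x).sum()

w = torch.tensor([0.5, -1.0, 0.25], requires_grad=True)
x = torch.tensor([1.0, 2.0, 3.0])

y = forward(w, x)
y.backward()  # autograd derives and runs the backward pass

# Hand-derived gradient: d/dw sum(tanh(w * x)) = x * (1 - tanh(w * x) ** 2)
expected = x * (1 - torch.tanh(w.detach() * x) ** 2)
print(torch.allclose(w.grad, expected, atol=1e-6))  # True
```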
The learnable parameters of a model are returned by net.parameters():
params = list(net.parameters())
print(len(params))
print(params[0].size())
Out:
16
torch.Size([6, 1, 5, 5])
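For comparison, here is a plain PyTorch module with the same four layers. PlainNet is a hypothetical name used only for this sketch. It has 8 parameter tensors: one weight and one bias per layer. The 16 reported above for the borch network is consistent with each weight and bias being represented by an approximating distribution with its own loc and scale parameters, as the printout of net suggests:

```python
import torch
from torch import nn

# Hypothetical deterministic counterpart of Net: same layers, no posterior.
class PlainNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 10)

plain = PlainNet()
plain_params = list(plain.parameters())
print(len(plain_params))       # 8: one weight and one bias per layer
print(plain_params[0].size())  # torch.Size([6, 1, 5, 5]), i.e. conv1.weight
```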
Let's try a random 32x32 input:
input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)
Out:
tensor([[ 0.4980, 0.6893, 0.0096, -0.4008, -0.1188, -0.2803, 0.3551, 0.4623,
-0.2278, -0.0520]], grad_fn=<AddmmBackward0>)
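The 32x32 input size is not arbitrary, because fc1 expects 16 * 5 * 5 = 400 input features. The shape trace below shows where 400 comes from. It is a plain PyTorch sketch, and the random kernels only share the shapes of conv1 and conv2; they are not the network's weights:

```python
import torch
import torch.nn.functional as F

# Random kernels with the same shapes as conv1 and conv2 (illustration only).
w1 = torch.randn(6, 1, 5, 5)
w2 = torch.randn(16, 6, 5, 5)

x = torch.randn(1, 1, 32, 32)
shapes = [tuple(x.shape)]
x = F.max_pool2d(F.relu(F.conv2d(x, w1)), 2)  # 32 -> 28 (5x5 conv) -> 14 (2x2 pool)
shapes.append(tuple(x.shape))
x = F.max_pool2d(F.relu(F.conv2d(x, w2)), 2)  # 14 -> 10 (5x5 conv) -> 5 (2x2 pool)
shapes.append(tuple(x.shape))
flat = x.view(x.size(0), -1)                  # 16 channels * 5 * 5 = 400 features
shapes.append(tuple(flat.shape))
print(shapes)  # [(1, 1, 32, 32), (1, 6, 14, 14), (1, 16, 5, 5), (1, 400)]
```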
Zero the gradient buffers of all parameters and backprop with random gradients:
net.zero_grad()
out.backward(torch.randn(1, 10))
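Because out is not a scalar, out.backward needs a gradient tensor with the same shape as out. Autograd then computes the vector-Jacobian product v^T J rather than a full Jacobian. The small plain PyTorch example below, with made-up values, shows that for y = A x the result stored in x.grad is A^T v:

```python
import torch

A = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
x = torch.tensor([1.0, -1.0], requires_grad=True)
y = A @ x                      # non-scalar output, shape (3,)

v = torch.tensor([0.1, 0.2, 0.3])
y.backward(v)                  # computes v^T J, where J = dy/dx = A

print(x.grad)                  # tensor([2.2000, 2.8000]), i.e. A.T @ v
```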
Note
borch.nn only supports mini-batches. The entire borch.nn package only supports inputs that are a mini-batch of samples, not a single sample.
For example, nn.Conv2d takes a 4D Tensor of shape nSamples x nChannels x Height x Width.
If you have a single sample, just use input.unsqueeze(0) to add a fake batch dimension.
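The snippet below shows what unsqueeze(0) does to a single 1x32x32 image. It is a plain PyTorch sketch, and the torch.nn.Conv2d layer stands in for the borch layer only to show that the batched shape is what a 2D convolution consumes:

```python
import torch
from torch import nn

single = torch.randn(1, 32, 32)   # one image: nChannels x Height x Width
batched = single.unsqueeze(0)     # add a leading batch dimension of size 1
print(batched.shape)              # torch.Size([1, 1, 32, 32])

conv = nn.Conv2d(1, 6, 5)         # plain torch.nn layer, used for illustration
print(conv(batched).shape)        # torch.Size([1, 6, 28, 28])
```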
Before proceeding further, let’s recap all the classes you’ve seen so far.
- Recap:
  - torch.Tensor - A multi-dimensional array with support for autograd operations like backward(). Also holds the gradient w.r.t. the tensor.
  - nn.Module - Neural network module. A convenient way of encapsulating parameters, with helpers for moving them to GPU, exporting, loading, etc.
  - nn.Parameter - A kind of Tensor that is automatically registered as a parameter when assigned as an attribute to a Module.
  - autograd.Function - Implements forward and backward definitions of an autograd operation. Every Tensor operation creates at least a single Function node that connects to the functions that created a Tensor and encodes its history.
- At this point, we covered:
  - Defining a neural network
  - Processing inputs and calling backward
- Still left:
  - Computing the loss
  - Updating the weights of the network
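Two of the recap items can be seen directly in plain PyTorch. Toy is a hypothetical module used only for this sketch. Assigning an nn.Parameter as an attribute registers it, while a plain tensor attribute is not registered. In addition, every tensor produced by an operation on a parameter carries a grad_fn node recording its history:

```python
import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(3))  # registered automatically
        self.offset = torch.zeros(3)              # plain tensor: not registered

toy = Toy()
names = [name for name, _ in toy.named_parameters()]
print(names)                       # ['scale']

out = (toy.scale * 2).sum()
print(type(out.grad_fn).__name__)  # SumBackward0: the node that produced out
```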
Loss Function
A loss function takes the (output, target) pair of inputs, and computes a value that estimates how far away the output is from the target.
There are several different loss functions under the nn package. A simple loss is nn.MSELoss, which computes the mean-squared error between the input and the target. Such losses, however, only correspond to the maximum likelihood approach commonly used in deep learning: they give point estimates of the weights rather than a distribution over them.
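The link between the mean-squared error and maximum likelihood can be checked numerically. The sketch below uses plain torch.nn.MSELoss and torch.distributions with made-up values. The negative log-likelihood of the targets under a unit-variance Gaussian centred on the predictions equals half the summed squared error plus a constant. Both therefore have the same minimiser:

```python
import math
import torch
from torch import nn
from torch.distributions import Normal

pred = torch.tensor([0.5, 1.5, -0.2])    # made-up predictions
y_true = torch.tensor([1.0, 1.0, 0.0])   # made-up targets

sse = nn.MSELoss(reduction="sum")(pred, y_true)
# Negative log-likelihood of the targets under N(pred, 1)
nll = -Normal(pred, 1.0).log_prob(y_true).sum()

n = y_true.numel()
# nll = 0.5 * SSE + n * 0.5 * log(2 * pi), so both are minimised by the same pred
print(torch.allclose(nll, 0.5 * sse + 0.5 * n * math.log(2 * math.pi)))  # True
```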
In order to infer the posterior of the weights, and thus also capture the uncertainty in the weights, we have to use the infer package. In this example we will use the infer.vi_loss function, which automatically constructs a loss function for variational inference from the latent variables in your model.
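As a rough picture of what a variational inference loss computes, here is a generic single-weight sketch built with torch.distributions. It is not borch's implementation of infer.vi_loss, and the toy regression data, prior and optimiser settings are assumptions made for the illustration. The loss is a one-sample estimate of the negative ELBO: the expected negative log-likelihood under the approximating distribution q(w), plus KL(q || p) to the prior:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
x = torch.linspace(-1, 1, 20)
y = 2.0 * x + 0.1 * torch.randn(20)        # synthetic data with slope 2

# Approximating distribution q(w) = Normal(q_loc, exp(q_log_scale)) and prior p(w)
q_loc = torch.tensor(0.0, requires_grad=True)
q_log_scale = torch.tensor(-1.0, requires_grad=True)
prior = Normal(0.0, 1.0)

def negative_elbo():
    q = Normal(q_loc, q_log_scale.exp())
    w = q.rsample()                        # reparameterised sample keeps gradients
    nll = -Normal(w * x, 0.1).log_prob(y).sum()
    return nll + kl_divergence(q, prior)   # likelihood term + KL term

opt = torch.optim.Adam([q_loc, q_log_scale], lr=0.05)
for _ in range(500):
    opt.zero_grad()
    elbo_loss = negative_elbo()
    elbo_loss.backward()
    opt.step()

# q_loc should end up near the true slope, with a small posterior scale
print(q_loc.item(), q_log_scale.exp().item())
```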
Similar to how it’s done for random variables, we can also call observe on the module, using keyword arguments matching the names of the random variables we want to observe. This adds those random variables to the likelihood term, and we will not infer a distribution over them. For example:
target = torch.randint(10, (1,)) # a dummy target, for example
net.observe(classification=target)
borch.sample(net)
output = net(input)
loss = infer.vi_loss(**borch.pq_to_infer(net))
print(loss)
Out:
tensor(9549.6875, grad_fn=<AddBackward0>)
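To see what observing a class label contributes, assume that classification is a categorical distribution over the 10 output logits. This is an assumption made for this sketch; the exact likelihood is defined in the network's forward. The sketch uses plain torch.distributions with random stand-in logits. Under that assumption, the negative log-likelihood term for the observed label is the familiar cross-entropy loss:

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

logits = torch.randn(1, 10)          # stands in for the network output
label = torch.randint(10, (1,))      # observed class label

# Log-likelihood contributed by the observed label under a categorical model
log_lik = Categorical(logits=logits).log_prob(label)

# Its negative equals the usual cross-entropy loss
ce = F.cross_entropy(logits, label, reduction="none")
print(torch.allclose(-log_lik, ce))  # True
```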
Now, if you follow loss in the backward direction, you will see a graph of computations that looks like this:
input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d
      -> view -> linear -> relu -> linear
      -> loss
So, when we call loss.backward(), the whole graph is differentiated w.r.t. the loss, and all Tensors in the graph that have requires_grad=True will have their .grad Tensor accumulated with the gradient.
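You can walk this graph yourself through the grad_fn attribute and each node's next_functions. The plain PyTorch sketch below uses a small made-up graph (linear, relu, sum) rather than the network above. backward_chain is a hypothetical helper that only follows the first parent of each node, so it prints one path back from the loss, not the whole graph:

```python
import torch
import torch.nn.functional as F

def backward_chain(t):
    """Follow the first non-None parent of each autograd node, starting at t.grad_fn."""
    names = []
    fn = t.grad_fn
    while fn is not None:
        names.append(type(fn).__name__)
        parents = [f for f, _ in fn.next_functions if f is not None]
        fn = parents[0] if parents else None
    return names

inp = torch.randn(1, 4)
w = torch.randn(3, 4, requires_grad=True)
b = torch.randn(3, requires_grad=True)
toy_loss = F.relu(F.linear(inp, w, b)).sum()

chain = backward_chain(toy_loss)
print(chain)  # starts at the sum node, ends at an AccumulateGrad node for a leaf
```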
Backprop
To backpropagate the error, all we have to do is call loss.backward(). You need to clear the existing gradients first, though; otherwise the new gradients will be accumulated into the existing ones.
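The accumulation behaviour is easy to see on a single tensor. The plain PyTorch sketch below uses made-up values. A second backward pass adds to .grad instead of overwriting it, and zeroing the gradient first restores the expected value:

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

(w * 3).sum().backward()
first = w.grad.clone()
print(first)            # tensor([3., 3.])

(w * 3).sum().backward()
second = w.grad.clone()
print(second)           # tensor([6., 6.]): added to the old gradient, not overwritten

w.grad.zero_()          # clear the accumulated gradient before the next backward pass
(w * 3).sum().backward()
print(w.grad)           # tensor([3., 3.])
```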
Now we shall call loss.backward() and have a look at the gradients of conv1’s bias before and after the backward pass.
net.zero_grad() # zeroes the gradient buffers of all parameters
The gradient of the loc parameter of the approximating distribution of conv1.bias after zeroing the gradients (and before calling backward) is:
print(net.conv1.posterior.bias.loc.grad)
loss.backward()
Out:
tensor([0., 0., 0., 0., 0., 0.])
After calling backward, the gradient is:
print(net.conv1.posterior.bias.loc.grad)
Out:
tensor([ 0.6376, 1.3013, -0.0847, 0.4150, 1.0615, 0.2077])
The only thing left to learn is:
Updating the weights of the network
Update the weights
The simplest update rule used in practice is Stochastic Gradient Descent (SGD):
weight = weight - learning_rate * gradient
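We can implement this using simple Python code:
learning_rate = 0.01
with torch.no_grad():  # parameter updates should not be tracked by autograd
    for f in net.parameters():
        if f.grad is not None:  # skip parameters that received no gradient
            f.sub_(learning_rate * f.grad)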
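For plain SGD with no momentum and no weight decay, the hand-written rule performs the same update as torch.optim.SGD with its default settings. The check below uses a made-up parameter tensor and f(p) = sum(p ** 2), both chosen purely for illustration:

```python
import torch

torch.manual_seed(0)
p_manual = torch.randn(5, requires_grad=True)
p_optim = p_manual.detach().clone().requires_grad_(True)
lr = 0.01

for p in (p_manual, p_optim):
    (p ** 2).sum().backward()          # gradient is 2 * p

with torch.no_grad():
    p_manual -= lr * p_manual.grad     # hand-written SGD step

opt = torch.optim.SGD([p_optim], lr=lr)
opt.step()                             # library SGD step (defaults: no momentum)

print(torch.allclose(p_manual, p_optim))  # True
```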
However, as you use neural networks, you will want to use various different update rules such as SGD, Nesterov-SGD, Adam, RMSProp, etc. To enable this, PyTorch provides a small package, torch.optim, that implements all these methods. Using it is very simple:
import torch.optim as optim
# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)
# in your training loop:
n_batch_epoch = 10  # number of batches per epoch, usually len(dataloader)
optimizer.zero_grad() # zero the gradient buffers
borch.sample(net)
output = net(input)
loss = infer.vi_loss(**borch.pq_to_infer(net), kl_scaling=1 / n_batch_epoch)
loss.backward()
optimizer.step() # Does the update
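One way to read kl_scaling=1 / n_batch_epoch is as a way of spreading the KL term over the minibatches of an epoch. This reading is our interpretation, not a statement of borch's documented behaviour. The arithmetic below uses made-up stand-in numbers rather than real model terms. If every minibatch loss carries 1/n_batch_epoch of the KL, then over one epoch the KL is counted exactly once, as in the full-data objective:

```python
import torch

torch.manual_seed(0)
n_batch_epoch = 10
nll_per_batch = torch.rand(n_batch_epoch)  # stand-in per-batch negative log-likelihoods
kl = torch.tensor(3.7)                     # stand-in KL(q || p) for the whole model

# Each minibatch loss carries 1/n_batch_epoch of the KL term...
batch_losses = nll_per_batch + kl / n_batch_epoch

# ...so summed over the epoch, the KL appears exactly once
print(torch.allclose(batch_losses.sum(), nll_per_batch.sum() + kl))  # True
```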
Exercises
The neural network package contains various modules and loss functions that form the building blocks of deep neural networks. Have a look at the documentation to see what is available.
Try designing your own feed-forward networks with two different types of non-linearities, e.g. relu.