Neural Networks
Neural networks can be constructed using the borch.nn package.
Now that you’ve had a glimpse of autograd, note that nn depends on autograd to define models and differentiate them. An nn.Module contains layers and a method forward(input) that returns the output.
For example, look at this network that classifies digit images:
[Figure: convnet]
It is a simple feed-forward network. It takes an input, feeds it through several layers one after the other, and then finally gives the output.
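To make "one layer after the other" concrete, the short sketch below traces the spatial size of a single 32x32 input image through the two 5x5 convolutions (no padding, stride 1) and the 2x2 max-pooling steps used in the network definition further down. It assumes the standard output-size formula for a convolution or pooling window, floor((n + 2p - k) / s) + 1. The helper `out_size` is hypothetical and not part of borch or PyTorch.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Spatial size after a convolution or pooling window (floor division)."""
    return (n + 2 * padding - kernel) // stride + 1


size = 32                                  # assumed 32x32 input image
size = out_size(size, kernel=5)            # conv1, 5x5 kernel     -> 28
size = out_size(size, kernel=2, stride=2)  # 2x2 max pooling       -> 14
size = out_size(size, kernel=5)            # conv2, 5x5 kernel     -> 10
size = out_size(size, kernel=2, stride=2)  # 2x2 max pooling       -> 5

channels = 16                              # conv2 has 16 output channels
flat_features = channels * size * size     # input size of the first linear layer
print(size, flat_features)
```

This is why the first fully connected layer takes 16 * 5 * 5 inputs: each layer only sees the shape produced by the one before it.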
A typical training procedure for a neural network is as follows:
- Define a network that has some learnable parameters and/or random variables.
- For each batch in a dataset:
  - Process the input data through the network.
  - Compute the loss (how far is the output from being correct?).
  - Propagate gradients back into the network’s parameters.
  - Update the weights of the network, typically using a simple update rule:
    weight = weight - learning_rate * gradient
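The steps above can be sketched without any framework. The toy setup here is an assumption made purely for illustration. It fits a single weight w in the model y = w * x to made-up data generated with w = 2 and uses mean squared error as the loss. The gradient is derived by hand rather than computed by autograd, and the plain update rule weight = weight - learning_rate * gradient is applied after every batch.

```python
# Toy data following y = 2 * x, split into two "batches" (made up for illustration)
batches = [
    ([1.0, 2.0], [2.0, 4.0]),
    ([3.0, 4.0], [6.0, 8.0]),
]

weight = 0.0          # the single learnable parameter
learning_rate = 0.01

for epoch in range(100):
    for xs, ys in batches:
        n = len(xs)
        # 1. Process the input through the "network"
        preds = [weight * x for x in xs]
        # 2. Compute the loss (mean squared error)
        loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n
        # 3. Gradient of the loss with respect to weight, derived by hand:
        #    d/dw mean((w*x - y)^2) = mean(2 * (w*x - y) * x)
        gradient = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n
        # 4. Simple update rule
        weight = weight - learning_rate * gradient

print(round(weight, 4))
```

In the tutorial itself, the network replaces the one-line model, and autograd computes the gradients of all parameters instead of the hand-derived formula.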
Define the network
Let’s define this network:
import torch
import torch.nn.functional as F
import borch
from borch import distributions, posterior, nn, infer


class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__(posterior=posterior.Automatic())
        # 1 input image channel, 6 output channels, 5x5 convolution kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        # 6 input channels, 16 output channels, 5x5 convolution kernel
        self.conv2 = nn.Conv2d(6, 16, 5)
        # An affine operation: y = Wx + b
        # NB: after two 5x5 convolutions (no padding), each followed by
        # 2x2 max pooling, an image with initial size 32x32 is reduced
        # to 5x5 spatially (with 16 channels)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the window is square, a single number is enough
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten all dimensions except the batch dimension
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Specify the likelihood function
        self.classification = distributions.Categorical(logits=x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
print(net)
Out:
Net(
(posterior): Automatic()
(prior): Module()
(observed): Observed()
(conv1): Conv2d(
1, 6, kernel_size=(5, 5), stride=(1, 1)
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[[[-0.0991, -0.0592, -0.1623, 0.0523, 0.1645],
[-0.1893, -0.0215, 0.1007, -0.1699, -0.1765],
[ 0.1281, -0.0728, -0.1345, -0.0970, 0.0558],
[-0.1093, 0.1952, -0.0354, 0.0050, 0.0558],
[-0.1692, 0.1254, -0.0014, 0.1250, 0.1932]]],
[[[ 0.0767, 0.0089, -0.1279, -0.1135, 0.1844],
[-0.1334, 0.0288, 0.1844, 0.0812, -0.1021],
[-0.1645, 0.1263, -0.1469, -0.0918, -0.1910],
[ 0.0580, 0.0896, -0.1576, 0.0805, -0.1268],
[ 0.1165, 0.0309, -0.0574, 0.0782, 0.0083]]],
[[[-0.1977, -0.0520, -0.0045, 0.1165, -0.0447],
[-0.0220, 0.1359, 0.0219, -0.1281, -0.0446],
[-0.1952, 0.0164, 0.1802, 0.0125, -0.1155],
[-0.1422, -0.1467, 0.0933, -0.1485, -0.1893],
[-0.0724, 0.0799, -0.1816, 0.1123, -0.0929]]],
[[[-0.0213, -0.1834, -0.0827, 0.0897, 0.1789],
[ 0.1426, -0.0666, 0.0739, -0.0140, -0.0703],
[ 0.1502, -0.0841, -0.0333, 0.0246, -0.0301],
[-0.0028, 0.1863, 0.1322, 0.0346, -0.1045],
[ 0.1198, 0.1824, 0.1617, -0.1592, -0.1136]]],
[[[-0.0311, 0.1311, -0.1360, 0.1307, 0.0552],
[ 0.1491, 0.1264, 0.0093, 0.0053, 0.1045],
[-0.1270, -0.1262, -0.1517, -0.1297, 0.1289],
[-0.0223, 0.0269, -0.1771, 0.0697, -0.0873],
[ 0.1871, 0.1114, -0.1391, 0.0208, -0.0346]]],
[[[-0.1517, 0.0542, 0.0801, -0.1400, -0.1205],
[-0.0011, -0.1382, -0.0382, -0.0469, 0.0242],
[ 0.0768, 0.0658, 0.0869, -0.1353, 0.0962],
[ 0.1768, 0.1289, 0.0925, 0.0991, -0.1173],
[ 0.0856, -0.0907, 0.1015, -0.1604, -0.0407]]]], requires_grad=True)
tensor: tensor([[[[-0.0614, 0.0123, -0.1081, 0.1236, 0.1454],
[-0.1888, 0.0317, 0.0477, -0.1962, -0.1923],
[ 0.1091, 0.0490, -0.1451, -0.2202, -0.0304],
[-0.1762, 0.1958, 0.0447, -0.0126, -0.0172],
[-0.2533, 0.0799, -0.0985, 0.0271, 0.2088]]],
[[[ 0.0831, 0.0297, -0.1840, -0.0885, 0.1750],
[-0.1555, 0.0938, 0.1390, 0.2179, -0.1170],
[-0.1460, 0.1144, -0.1382, -0.1290, -0.1353],
[ 0.1784, 0.0510, -0.1081, 0.0149, -0.1268],
[ 0.1006, 0.0944, -0.0335, 0.0982, 0.0166]]],
[[[-0.1809, 0.0275, 0.0079, 0.1430, -0.0050],
[-0.0176, 0.1644, 0.0566, -0.2271, -0.0801],
[-0.2015, 0.0562, 0.1753, 0.0234, -0.1854],
[-0.1811, -0.1378, 0.1392, -0.1488, -0.2823],
[-0.2067, -0.0193, -0.1639, 0.1127, -0.1008]]],
[[[ 0.0130, -0.2083, -0.1244, 0.0793, 0.2028],
[ 0.2372, -0.0538, 0.1101, 0.0460, -0.0670],
[ 0.1380, -0.0999, 0.0380, 0.1050, -0.0276],
[ 0.0336, 0.1392, 0.0764, 0.0153, -0.1006],
[ 0.1250, 0.1943, 0.2652, -0.0947, -0.1738]]],
[[[-0.0794, 0.0738, -0.0990, 0.1658, 0.1148],
[ 0.2231, 0.0902, 0.0110, 0.0146, 0.1010],
[-0.1262, -0.2206, -0.1288, -0.1480, 0.1115],
[-0.0217, 0.0519, -0.1776, 0.0722, -0.0266],
[ 0.1953, 0.0690, -0.1989, -0.0013, -0.0624]]],
[[[-0.1406, 0.0837, 0.1120, -0.1516, -0.0857],
[ 0.0686, -0.1454, -0.0464, -0.0448, 0.0909],
[ 0.1257, 0.0885, 0.0548, -0.1890, 0.0385],
[ 0.2158, 0.0925, 0.1184, 0.1180, -0.0726],
[ 0.0904, -0.0799, 0.1279, -0.1520, 0.0205]]]],
grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([-0.0725, -0.0897, 0.0513, -0.0100, -0.1734, 0.0447],
requires_grad=True)
tensor: tensor([-0.1067, -0.1109, 0.0153, -0.0572, -0.1683, -0.0096],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[[[-0., -0., -0., 0., 0.],
[-0., -0., 0., -0., -0.],
[0., -0., -0., -0., 0.],
[-0., 0., -0., 0., 0.],
[-0., 0., -0., 0., 0.]]],
[[[0., 0., -0., -0., 0.],
[-0., 0., 0., 0., -0.],
[-0., 0., -0., -0., -0.],
[0., 0., -0., 0., -0.],
[0., 0., -0., 0., 0.]]],
[[[-0., -0., -0., 0., -0.],
[-0., 0., 0., -0., -0.],
[-0., 0., 0., 0., -0.],
[-0., -0., 0., -0., -0.],
[-0., 0., -0., 0., -0.]]],
[[[-0., -0., -0., 0., 0.],
[0., -0., 0., -0., -0.],
[0., -0., -0., 0., -0.],
[-0., 0., 0., 0., -0.],
[0., 0., 0., -0., -0.]]],
[[[-0., 0., -0., 0., 0.],
[0., 0., 0., 0., 0.],
[-0., -0., -0., -0., 0.],
[-0., 0., -0., 0., -0.],
[0., 0., -0., 0., -0.]]],
[[[-0., 0., 0., -0., -0.],
[-0., -0., -0., -0., 0.],
[0., 0., 0., -0., 0.],
[0., 0., 0., 0., -0.],
[0., -0., 0., -0., -0.]]]])
scale: tensor([[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[[[-0.0991, -0.0592, -0.1623, 0.0523, 0.1645],
[-0.1893, -0.0215, 0.1007, -0.1699, -0.1765],
[ 0.1281, -0.0728, -0.1345, -0.0970, 0.0558],
[-0.1093, 0.1952, -0.0354, 0.0050, 0.0558],
[-0.1692, 0.1254, -0.0014, 0.1250, 0.1932]]],
[[[ 0.0767, 0.0089, -0.1279, -0.1135, 0.1844],
[-0.1334, 0.0288, 0.1844, 0.0812, -0.1021],
[-0.1645, 0.1263, -0.1469, -0.0918, -0.1910],
[ 0.0580, 0.0896, -0.1576, 0.0805, -0.1268],
[ 0.1165, 0.0309, -0.0574, 0.0782, 0.0083]]],
[[[-0.1977, -0.0520, -0.0045, 0.1165, -0.0447],
[-0.0220, 0.1359, 0.0219, -0.1281, -0.0446],
[-0.1952, 0.0164, 0.1802, 0.0125, -0.1155],
[-0.1422, -0.1467, 0.0933, -0.1485, -0.1893],
[-0.0724, 0.0799, -0.1816, 0.1123, -0.0929]]],
[[[-0.0213, -0.1834, -0.0827, 0.0897, 0.1789],
[ 0.1426, -0.0666, 0.0739, -0.0140, -0.0703],
[ 0.1502, -0.0841, -0.0333, 0.0246, -0.0301],
[-0.0028, 0.1863, 0.1322, 0.0346, -0.1045],
[ 0.1198, 0.1824, 0.1617, -0.1592, -0.1136]]],
[[[-0.0311, 0.1311, -0.1360, 0.1307, 0.0552],
[ 0.1491, 0.1264, 0.0093, 0.0053, 0.1045],
[-0.1270, -0.1262, -0.1517, -0.1297, 0.1289],
[-0.0223, 0.0269, -0.1771, 0.0697, -0.0873],
[ 0.1871, 0.1114, -0.1391, 0.0208, -0.0346]]],
[[[-0.1517, 0.0542, 0.0801, -0.1400, -0.1205],
[-0.0011, -0.1382, -0.0382, -0.0469, 0.0242],
[ 0.0768, 0.0658, 0.0869, -0.1353, 0.0962],
[ 0.1768, 0.1289, 0.0925, 0.0991, -0.1173],
[ 0.0856, -0.0907, 0.1015, -0.1604, -0.0407]]]])
(bias): Normal:
loc: tensor([-0., -0., 0., -0., -0., 0.])
scale: tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([-0.0725, -0.0897, 0.0513, -0.0100, -0.1734, 0.0447])
)
(observed): Observed()
)
(conv2): Conv2d(
6, 16, kernel_size=(5, 5), stride=(1, 1)
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
...,
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[[[ 2.4966e-03, -6.7341e-02, 6.2799e-02, -3.4370e-03, 4.5322e-02],
[ 1.4607e-02, 6.2478e-02, 6.8280e-02, -3.5654e-02, -1.8825e-02],
[ 7.1212e-02, 1.5050e-02, -1.8025e-02, 2.5491e-02, 2.8724e-02],
[-2.2547e-02, 6.0692e-02, 5.2667e-02, -5.7841e-03, 2.6376e-02],
[-4.3343e-02, -5.9860e-02, -1.2973e-02, 6.2583e-03, 8.0137e-02]],
[[-2.1569e-02, -7.1375e-02, 1.2631e-02, -1.8075e-02, -4.7912e-02],
[-2.8027e-02, -2.6252e-02, -5.9709e-02, 6.7074e-02, 2.6864e-02],
[-6.6509e-02, -8.1214e-02, 2.3467e-02, -6.6464e-02, -6.0669e-02],
[ 5.0579e-03, 1.1671e-02, -6.6733e-02, -4.9352e-02, -2.0956e-02],
[ 7.6966e-03, 2.2004e-02, -2.8989e-02, 5.1644e-02, -6.1621e-02]],
[[-5.4252e-02, -8.1522e-02, -6.2672e-02, -6.3098e-02, -5.2104e-02],
[ 2.5544e-02, 3.6878e-02, -4.4395e-02, -4.9102e-02, 4.8144e-02],
[-3.2576e-02, 8.0204e-02, -9.2122e-03, -7.7311e-03, 6.7384e-02],
[-1.5682e-02, 9.5757e-03, -6.5708e-02, -2.7689e-02, -6.5013e-02],
[-6.4995e-02, 2.7979e-02, -1.6400e-03, 4.6764e-02, -4.3071e-02]],
[[ 5.6996e-02, 5.9034e-03, -4.6686e-02, -6.9235e-02, 1.9127e-03],
[ 1.8545e-02, -1.9516e-02, -5.0705e-02, -3.5648e-02, 6.0718e-02],
[-6.7171e-02, -7.3959e-02, 5.4578e-02, -1.1379e-02, 4.7259e-02],
[ 1.8278e-02, 1.0828e-02, 2.9502e-02, -4.8256e-02, 1.5783e-02],
[ 3.0306e-02, -5.6080e-03, 1.1095e-02, -8.0521e-02, 6.7744e-02]],
[[ 6.4667e-03, -3.1020e-02, -5.5962e-02, -4.0058e-02, -3.4994e-02],
[-4.6501e-02, -5.0324e-02, -7.3864e-02, 1.8048e-02, -5.4226e-02],
[-7.8412e-02, -4.8003e-02, -1.1779e-02, -7.8603e-02, -3.8339e-02],
[-2.3014e-02, 5.4240e-02, -4.7970e-02, 2.5364e-02, 3.3201e-02],
[ 6.0505e-02, 6.4193e-02, 7.4750e-04, -2.1983e-02, 7.6744e-02]],
[[-6.7106e-02, 3.3650e-02, -7.8031e-02, 6.1907e-02, -7.8580e-02],
[-3.2164e-02, -4.9644e-02, 5.5871e-02, -1.9880e-02, -5.9899e-02],
[ 7.6375e-02, 1.6907e-02, -3.7618e-02, -4.6346e-02, 4.5166e-02],
[-6.2424e-02, 1.6883e-02, 3.9223e-02, -1.4790e-02, -6.1880e-02],
[ 3.8489e-02, -7.1589e-05, -1.7487e-02, -6.2649e-02, 6.8018e-02]]],
[[[-2.5034e-02, -4.6058e-02, -2.3337e-02, 2.2215e-02, -1.4850e-02],
[-4.3984e-02, 1.4747e-02, -7.9597e-03, -6.3069e-02, -2.2592e-02],
[ 4.3075e-02, 1.9039e-02, 3.0107e-02, 2.3445e-03, -6.9232e-02],
[ 6.4497e-03, -4.7578e-02, 7.6891e-02, -5.1616e-02, 1.2707e-02],
[-3.6706e-02, 7.6937e-02, -5.9667e-02, -8.1446e-02, 1.7545e-02]],
[[-4.4402e-02, 6.4599e-02, 4.1490e-02, -2.8644e-02, -4.8206e-02],
[ 2.6031e-02, 5.7371e-02, 2.0308e-02, -7.4061e-02, -4.4877e-02],
[-2.7835e-02, 5.2467e-02, 4.0908e-02, -8.1401e-02, -6.8110e-02],
[ 7.6093e-02, -9.4644e-03, -5.2426e-03, 7.7898e-03, -1.6856e-02],
[-7.1420e-02, 7.8023e-02, 8.0673e-02, 3.5987e-02, 7.1558e-02]],
[[-2.2731e-02, 6.9536e-03, -3.5987e-02, 7.5870e-02, -7.4783e-02],
[ 3.9431e-02, -3.5946e-02, 3.5435e-02, 3.6262e-02, 1.0886e-02],
[-5.3538e-02, 7.1992e-02, -3.2372e-04, -1.5915e-02, -7.1567e-02],
[ 2.4658e-02, -1.1047e-02, -5.5648e-02, 1.9212e-02, -1.2692e-02],
[-8.4647e-03, -4.5983e-02, -6.7419e-04, 9.8715e-03, -4.6148e-02]],
[[-2.0329e-02, 4.4603e-02, 6.4148e-02, 7.1504e-02, -4.5712e-02],
[ 4.6664e-02, 7.8892e-02, -3.1212e-02, -8.2775e-03, -6.5268e-02],
[-5.3630e-02, -3.4377e-02, 4.4915e-02, -3.5955e-02, 4.4886e-02],
[-6.4155e-02, -1.4317e-02, -6.8112e-02, -2.0220e-02, 3.0197e-02],
[-6.6971e-02, 7.1244e-02, 3.6331e-02, 4.9435e-02, 9.6099e-04]],
[[ 7.0543e-02, -2.7817e-02, 6.0078e-02, 4.7984e-02, -4.3316e-02],
[ 7.6346e-02, 1.2656e-02, -6.2175e-02, 7.6718e-02, 1.8006e-02],
[ 7.4139e-02, -5.7349e-02, -1.0675e-02, -2.4720e-02, -4.4425e-02],
[-7.0388e-02, -6.9265e-02, 2.1314e-02, 4.5109e-02, -7.1493e-02],
[-3.2351e-02, 3.5873e-02, -4.8951e-03, -7.2958e-02, -4.4238e-02]],
[[ 8.1172e-02, -5.8811e-02, 4.0991e-02, 6.9875e-02, -6.0488e-02],
[-6.3614e-02, 2.7591e-02, 4.7980e-02, 8.1415e-02, -3.3863e-02],
[ 2.1980e-02, -2.3815e-02, 1.8170e-02, -2.5447e-02, 2.7371e-02],
[ 8.1534e-02, -3.9964e-02, -7.7124e-02, 5.2303e-02, -2.0802e-02],
[ 6.0792e-02, -7.2300e-02, -2.8772e-02, 5.7216e-02, -3.4328e-02]]],
[[[-2.7990e-03, 1.8808e-02, 6.7455e-04, -6.6091e-02, 6.6982e-02],
[-6.9830e-02, 6.5788e-02, -4.8964e-02, -4.1140e-03, 3.1098e-03],
[-5.4660e-02, 7.1862e-02, 7.1153e-03, -6.5570e-02, -5.1447e-02],
[ 6.0571e-02, 2.0612e-03, 2.3722e-02, 1.5375e-02, -5.1639e-02],
[ 4.7830e-02, -2.5379e-02, 6.9856e-02, 4.5925e-02, 9.6594e-03]],
[[ 2.2661e-02, 6.6682e-03, -4.9346e-02, 2.0883e-03, 3.6717e-02],
[-2.9341e-02, -7.5506e-02, -2.3876e-02, -4.3604e-02, 6.0555e-02],
[-5.9189e-03, -9.2285e-03, -7.6425e-02, 5.5105e-02, -3.4522e-02],
[ 4.5115e-02, 6.4247e-03, 4.4812e-02, 1.9783e-02, 7.9421e-02],
[ 1.4393e-02, 6.9105e-02, -5.1842e-02, -5.6427e-04, -3.0378e-02]],
[[-2.9724e-02, 2.7430e-02, 2.0606e-02, -6.8593e-02, 7.5719e-02],
[-1.1574e-02, -6.1978e-02, 1.5036e-02, 6.5520e-02, -5.5500e-02],
[-5.4038e-03, 1.6463e-03, -1.6045e-02, 5.0524e-02, -6.1660e-02],
[ 7.9241e-02, -6.3365e-02, -7.6820e-02, -7.4036e-02, -3.8586e-02],
[-2.4314e-02, 8.9899e-03, -7.1113e-03, -3.9968e-02, -5.5614e-02]],
[[-2.1929e-02, -1.1871e-02, -4.8945e-02, 2.5946e-02, -6.4562e-03],
[-3.3630e-02, 5.2410e-02, 3.1839e-02, 3.7700e-02, 4.5499e-02],
[ 2.5421e-02, -6.7532e-02, 7.4772e-02, -4.7501e-03, -5.6375e-02],
[ 3.3231e-03, 2.4121e-02, 8.1080e-02, 1.0045e-02, 7.1572e-02],
[ 4.9069e-02, -3.3530e-02, 1.1358e-02, -5.2008e-04, -4.7222e-02]],
[[ 4.6323e-02, -2.6136e-04, -5.2806e-02, -2.4812e-02, 4.8923e-02],
[ 3.7666e-02, -5.7144e-02, -6.4298e-03, -1.7247e-02, 3.1968e-02],
[-6.0489e-02, 2.8763e-02, 2.8486e-02, 3.5758e-02, 1.7142e-02],
[ 8.0486e-02, -3.3536e-02, 3.6080e-02, -2.8437e-02, -1.9267e-03],
[-6.2648e-02, -5.2696e-02, -6.9294e-02, 3.6589e-02, 1.8661e-02]],
[[ 9.3734e-03, 1.9768e-02, -3.5296e-02, -2.7572e-02, 6.0065e-02],
[ 4.1112e-02, -3.4800e-02, -7.9605e-02, 5.3310e-02, -1.9664e-02],
[-1.3619e-02, 2.5598e-02, -2.6082e-02, 1.0369e-02, -6.4816e-02],
[-5.6846e-03, 2.9120e-02, 3.2444e-02, -1.8672e-02, -6.1514e-03],
[-4.0733e-02, 1.6849e-02, 3.8275e-02, -3.8825e-02, -5.8020e-02]]],
...,
[[[ 5.9329e-02, -6.5870e-02, 1.5859e-02, -1.5105e-02, -8.8580e-03],
[-5.0859e-02, 3.3974e-04, 3.0749e-02, 2.3093e-02, 4.2401e-02],
[-5.4646e-02, 7.5729e-02, -5.2260e-02, -4.6122e-03, 5.1098e-02],
[ 7.6842e-02, 1.2823e-02, 3.1960e-03, -3.9222e-02, -2.0589e-02],
[-6.3892e-03, 4.3896e-02, -2.7891e-02, -1.8322e-02, 3.2296e-02]],
[[-4.1220e-02, 8.0820e-02, -4.3842e-02, 7.9302e-02, -4.9983e-02],
[-7.3771e-03, -2.5980e-03, 3.6210e-03, -2.1433e-02, 7.3244e-02],
[-5.0760e-02, -2.5616e-03, 6.2678e-02, 7.3593e-02, 2.8301e-02],
[ 6.7007e-02, -1.9064e-02, 8.3973e-03, -5.9798e-02, -2.6974e-02],
[ 3.2876e-02, 3.8669e-02, -3.8719e-02, -1.0153e-02, 7.5301e-02]],
[[-5.0442e-02, 7.7989e-03, 3.9936e-02, 3.3326e-02, 7.9093e-02],
[-2.9757e-03, 5.9154e-02, 5.3348e-02, 7.3507e-02, -7.6732e-03],
[ 4.6368e-02, 2.1609e-02, 6.4193e-02, 6.0536e-02, -6.3037e-02],
[ 5.3429e-02, -2.5135e-02, -2.1193e-02, -5.1761e-04, 1.9996e-02],
[ 4.3215e-02, -2.7851e-02, -6.4092e-02, 6.4469e-03, 4.8372e-02]],
[[ 5.6178e-02, 3.5691e-02, -8.8206e-03, -3.0902e-02, 4.1369e-02],
[ 3.2037e-03, -7.3717e-02, -7.2683e-02, 3.9951e-02, 2.7550e-02],
[ 7.8738e-02, -7.8033e-02, -5.3960e-02, -3.4911e-02, 6.0757e-02],
[-1.3279e-02, 8.0203e-02, 5.5343e-02, -2.4808e-02, 7.3956e-02],
[ 3.9374e-02, 7.8239e-02, -7.7757e-02, 3.6571e-02, -2.9753e-02]],
[[ 1.4708e-02, -7.2628e-02, 6.9122e-02, -4.9947e-02, -2.0280e-02],
[-4.6593e-02, -7.6654e-02, 3.2037e-02, -1.1588e-02, -4.1092e-03],
[ 6.4820e-02, 1.2277e-02, -5.5943e-02, 3.8722e-02, -2.4028e-02],
[-5.1028e-02, 6.5714e-02, -1.7858e-02, 3.3266e-04, 7.5303e-02],
[-2.6313e-02, 2.8460e-02, 7.5677e-02, 6.7317e-02, 2.4816e-02]],
[[ 7.5268e-02, -1.0377e-02, 8.7081e-03, 2.6532e-02, -3.6530e-02],
[ 3.2920e-02, -5.4807e-03, 5.5366e-02, 6.0630e-02, 7.4726e-02],
[ 6.9652e-03, -4.8734e-02, -5.6073e-02, -6.8096e-02, -3.3923e-02],
[ 3.3389e-03, 5.3901e-02, 7.1177e-02, 5.2428e-02, -6.1103e-03],
[ 6.7079e-02, -2.1961e-02, 2.5835e-02, -7.2561e-02, -1.7097e-02]]],
[[[-1.4495e-02, -5.7205e-02, -6.9848e-02, 2.6035e-02, -6.8173e-02],
[ 7.7289e-03, 1.0050e-02, 2.6973e-02, -3.7391e-02, 4.0037e-02],
[-8.8926e-03, 2.8563e-02, -5.2584e-02, 7.0450e-02, 3.2828e-03],
[ 4.1858e-02, -5.4062e-03, 8.6848e-04, -2.5418e-02, -1.4656e-02],
[ 3.8785e-03, -1.3989e-02, 6.0276e-02, -5.7459e-02, -8.8772e-03]],
[[-2.5420e-02, -4.8067e-02, 4.2092e-02, -1.0464e-04, 2.9935e-03],
[-6.1568e-02, -2.6599e-02, 1.3947e-02, 6.6099e-02, 2.1827e-02],
[-3.0738e-02, 2.9112e-02, -3.2363e-02, 5.0961e-02, -5.7377e-02],
[-7.7591e-02, -1.8587e-03, 4.1255e-02, -7.4086e-02, 2.4900e-02],
[-5.8474e-02, -6.6314e-02, 3.9553e-02, 6.3980e-02, 3.6565e-02]],
[[ 1.8101e-02, 4.8489e-03, -7.8886e-02, -5.7999e-02, 7.0357e-02],
[-4.3515e-02, -4.7323e-02, -5.6187e-02, 2.3487e-02, -4.3308e-02],
[-4.8880e-02, -5.8023e-02, -7.3924e-02, -9.6583e-03, 6.2028e-02],
[ 5.2484e-02, -2.6062e-02, -4.3238e-03, -7.2771e-02, 3.2814e-02],
[ 5.0127e-02, 3.3270e-02, 6.3871e-02, -3.8169e-02, 7.1588e-02]],
[[-6.5380e-02, -7.1597e-02, 8.0692e-02, -5.3291e-02, 4.4471e-02],
[ 7.1552e-02, 4.8723e-02, -6.1756e-02, -4.6115e-02, 7.7099e-02],
[-1.9059e-02, 1.1879e-03, 1.7027e-02, -4.2909e-02, 5.6986e-02],
[-8.7295e-03, -4.8808e-02, -5.7670e-02, 2.4997e-04, -1.1744e-02],
[-1.3926e-02, 3.1644e-02, 1.1507e-02, 8.1024e-02, 7.2102e-03]],
[[ 2.1168e-02, -2.9686e-02, -4.1679e-02, -4.7661e-02, 3.7415e-02],
[ 2.2248e-02, -2.6836e-02, 3.3736e-02, 7.9188e-02, -5.4294e-02],
[-6.3778e-02, -2.1948e-02, -3.8511e-02, 3.4796e-02, 6.7318e-02],
[ 1.9141e-03, -1.8587e-02, 7.3516e-02, -5.1300e-02, 5.7174e-02],
[-3.4286e-02, -8.8658e-03, 2.8017e-02, -6.1345e-02, 7.5492e-02]],
[[-3.9274e-03, 5.5152e-02, -7.5673e-02, 2.3418e-02, 2.4453e-02],
[-4.7555e-02, 3.8749e-02, 7.6313e-02, 1.7370e-02, -3.5157e-02],
[ 2.5302e-03, -2.2002e-02, -3.2479e-02, -2.4231e-02, -3.4979e-02],
[ 5.0821e-02, 4.9669e-02, -3.9780e-03, -5.6009e-02, -7.5135e-02],
[-1.4819e-02, -3.7873e-02, 8.0048e-02, 6.6795e-02, 1.9579e-04]]],
[[[-6.7337e-03, -3.1325e-02, -3.8130e-04, 2.8326e-02, 5.0600e-02],
[ 2.1436e-02, -4.8592e-02, 6.1161e-02, -6.3545e-02, -6.6114e-02],
[ 2.6338e-03, 5.4122e-02, -4.5045e-02, 1.2242e-02, 8.8401e-03],
[-7.7559e-02, -1.9867e-02, 6.5748e-03, 4.9552e-02, -6.9732e-02],
[-3.9603e-02, -2.5227e-02, -7.4894e-02, -4.6245e-02, 3.7202e-02]],
[[ 7.4730e-02, 4.4270e-02, -1.6544e-02, -1.4239e-02, 5.9522e-03],
[ 2.7072e-03, -5.8596e-02, 4.8411e-02, 4.3153e-02, -4.5938e-02],
[ 7.8193e-02, -6.7088e-02, -6.9651e-02, -6.0037e-02, 2.6229e-02],
[-7.0300e-03, -3.4991e-02, -7.6215e-02, 9.9203e-03, 6.9492e-02],
[ 4.9921e-02, -2.0404e-02, 6.9527e-02, 3.0870e-02, 3.4420e-02]],
[[-1.8141e-02, 4.8283e-02, 1.9330e-02, -4.4328e-02, -1.4779e-02],
[-5.6117e-02, 7.9438e-02, 6.8914e-03, 2.9340e-02, 5.3000e-02],
[ 5.4209e-02, -2.0673e-02, 4.3754e-02, -3.3216e-02, -1.8343e-02],
[ 2.5629e-02, 3.6082e-02, 7.0708e-02, 7.3608e-02, 5.9628e-02],
[ 6.0885e-02, 1.6643e-02, 1.6415e-02, -1.0011e-02, 7.8816e-02]],
[[ 7.5320e-02, -5.7523e-02, -4.1853e-02, 1.0916e-02, -3.0991e-02],
[-6.2630e-02, 1.4596e-02, -1.0427e-02, 3.6875e-02, -1.6136e-03],
[-4.4583e-02, -3.0317e-02, -2.2016e-02, -6.6638e-02, 6.7848e-02],
[-5.1430e-02, -3.2587e-02, -1.1150e-03, 3.4666e-03, -3.3550e-02],
[ 3.4297e-02, -7.6208e-02, 7.9594e-02, 2.8101e-02, 5.8770e-02]],
[[-7.3166e-02, 6.3813e-02, 5.1557e-02, -8.4318e-03, 6.9600e-02],
[ 7.9715e-02, 1.2783e-03, -5.8935e-02, -4.6001e-02, 2.3448e-04],
[-2.1650e-02, 7.1989e-02, 5.8779e-02, 3.0808e-02, -5.0626e-02],
[-7.0207e-02, 6.4429e-02, 3.2764e-02, 4.4986e-02, -5.6941e-02],
[-3.9577e-02, 2.8508e-03, 1.5647e-02, -5.7797e-02, 3.3754e-02]],
[[-3.6933e-02, 6.9224e-02, 6.7209e-02, 1.0161e-02, 2.9785e-02],
[ 1.1000e-02, 5.7507e-02, 7.7336e-02, -2.3910e-02, 1.2587e-02],
[-3.8743e-02, 9.7108e-03, 4.9643e-02, 3.2226e-02, -6.9066e-02],
[ 1.7411e-02, -5.1872e-02, -4.3662e-02, -2.2543e-02, 4.4947e-02],
[ 5.7543e-02, -2.1366e-02, -2.5460e-02, 4.8669e-02, 4.8413e-02]]]],
requires_grad=True)
tensor: tensor([[[[ 0.0030, 0.0370, 0.0061, 0.0055, 0.1897],
[ 0.0061, 0.0060, -0.0193, 0.0377, -0.0564],
[-0.0500, -0.0086, 0.0329, 0.0670, 0.0101],
[-0.0847, 0.0830, 0.0771, -0.0188, 0.0843],
[-0.0097, -0.1136, -0.0237, 0.0625, 0.0676]],
[[-0.0559, -0.0980, -0.0725, -0.0112, -0.0243],
[ 0.0962, -0.0593, -0.1873, 0.0279, 0.0782],
[-0.0591, -0.0855, 0.0373, -0.1575, -0.0922],
[ 0.0329, 0.0644, -0.0940, -0.0186, -0.0622],
[-0.0482, 0.0230, 0.0336, -0.0282, -0.1814]],
[[-0.0127, -0.1343, -0.1221, -0.0167, -0.0065],
[ 0.0476, 0.0722, -0.0504, 0.0020, 0.0555],
[-0.0666, 0.1025, -0.0479, 0.0443, 0.0820],
[ 0.0718, 0.0899, -0.0440, -0.0267, -0.0442],
[-0.0810, 0.0078, 0.0804, 0.0138, 0.0176]],
[[ 0.0109, 0.0394, -0.0121, 0.0404, -0.0129],
[ 0.0802, -0.0914, -0.0956, 0.0313, 0.0754],
[-0.0038, -0.0588, -0.0027, -0.0544, -0.0149],
[-0.0293, 0.0581, 0.0832, 0.1311, 0.0171],
[ 0.1106, 0.0351, 0.0081, -0.0740, 0.0295]],
[[ 0.0067, 0.0062, -0.0142, 0.0262, -0.0578],
[-0.0709, 0.0171, -0.0674, -0.0588, 0.0029],
[-0.0295, -0.0527, -0.0683, -0.1257, -0.0551],
[ 0.0356, 0.0503, -0.0011, 0.0599, -0.0024],
[ 0.1453, 0.0082, -0.0518, -0.0353, 0.0338]],
[[ 0.0246, 0.0324, -0.1101, 0.0527, -0.1270],
[-0.1302, -0.0539, 0.0103, 0.0226, 0.0360],
[ 0.0771, 0.0761, -0.0768, -0.0466, 0.0357],
[-0.0788, -0.0133, 0.1923, 0.1031, -0.0516],
[-0.0493, 0.0171, -0.0626, -0.0484, 0.1355]]],
[[[ 0.0296, -0.0681, -0.0929, 0.0268, 0.0190],
[-0.0455, -0.0379, 0.0346, -0.0537, -0.0268],
[-0.0023, -0.0677, -0.0214, -0.0291, -0.1081],
[-0.0020, -0.0793, 0.1388, -0.0793, 0.1056],
[-0.0718, 0.1434, -0.1299, -0.0603, -0.0080]],
[[-0.1244, 0.0348, 0.0574, -0.0776, 0.0472],
[ 0.1165, 0.0356, 0.0074, 0.0251, -0.1098],
[ 0.0321, -0.0125, 0.0018, -0.1240, 0.0353],
[ 0.0994, -0.0156, 0.0453, -0.0251, -0.0393],
[ 0.0301, 0.1138, 0.1745, 0.1075, 0.1367]],
[[-0.0349, -0.0200, -0.0847, 0.0510, 0.0069],
[ 0.0411, 0.0056, 0.0714, 0.1483, 0.0639],
[-0.1383, 0.0720, -0.0054, 0.0541, -0.0916],
[-0.0191, -0.0420, -0.0035, 0.0365, -0.0419],
[-0.0620, -0.0209, 0.0207, 0.0278, -0.0520]],
[[-0.0598, -0.0202, 0.0199, 0.0228, 0.0366],
[ 0.0405, 0.0057, 0.0308, 0.0057, -0.0100],
[-0.0716, -0.0653, -0.0008, -0.0686, -0.0948],
[-0.0877, -0.0224, -0.0703, -0.0918, 0.0621],
[-0.1118, 0.0707, -0.0052, 0.0841, -0.0137]],
[[ 0.0930, 0.0124, -0.0211, 0.0541, -0.1070],
[ 0.0211, 0.0281, -0.0532, 0.0517, -0.0016],
[ 0.0772, 0.0063, -0.0234, 0.0156, -0.0652],
[-0.0644, -0.0647, -0.0120, 0.1149, -0.0564],
[ 0.0767, -0.0070, -0.0011, -0.1237, -0.0066]],
[[ 0.1717, -0.0064, 0.0110, 0.0309, 0.0137],
[-0.0550, 0.0663, 0.0573, 0.1322, -0.0082],
[ 0.0342, -0.0386, -0.0328, -0.0154, 0.0448],
[ 0.0547, -0.0263, -0.1192, 0.0387, -0.0164],
[ 0.0607, -0.0326, 0.1569, 0.0261, -0.0246]]],
[[[ 0.0702, 0.0203, -0.0141, -0.0446, 0.0751],
[-0.0347, 0.1131, -0.1338, -0.0294, 0.0571],
[-0.0677, 0.1090, -0.0009, -0.0625, -0.1291],
[ 0.0354, 0.0200, -0.0549, -0.0156, -0.0812],
[ 0.0049, -0.0936, 0.0989, 0.1016, 0.0297]],
[[ 0.0672, -0.0527, -0.0339, -0.0094, 0.0326],
[ 0.0930, -0.1736, 0.0808, -0.0811, 0.0135],
[-0.0180, -0.0720, -0.1406, -0.0214, -0.0106],
[-0.0289, -0.0049, -0.0305, 0.0357, 0.1551],
[-0.0046, 0.0355, -0.0194, 0.0280, -0.0185]],
[[-0.0964, 0.0292, 0.0972, -0.0208, 0.0972],
[-0.1380, -0.0010, 0.0390, 0.0697, -0.1097],
[-0.0065, -0.0035, -0.0658, 0.0803, -0.1039],
[ 0.0459, -0.0645, -0.0060, -0.0435, -0.0126],
[-0.0245, 0.0109, -0.0870, -0.0560, -0.0631]],
[[-0.0394, -0.0649, -0.0273, 0.0062, 0.0212],
[-0.0515, 0.0840, -0.0245, 0.0682, 0.0838],
[ 0.0320, -0.0481, 0.0371, -0.0384, -0.1137],
[ 0.0538, -0.0604, 0.0587, -0.0396, 0.0582],
[ 0.0630, -0.0448, 0.0376, 0.0443, -0.0243]],
[[ 0.0887, -0.0682, -0.0835, 0.0265, 0.0202],
[-0.0194, -0.0398, -0.0036, -0.0129, 0.0483],
[ 0.0099, 0.0179, -0.0808, 0.0102, -0.0288],
[ 0.0867, -0.0921, -0.1018, -0.0525, -0.0209],
[-0.0754, 0.0110, -0.0890, 0.1195, -0.0101]],
[[ 0.0732, 0.0196, -0.0186, -0.0783, 0.0396],
[ 0.0464, -0.0444, -0.1216, 0.0254, -0.0550],
[ 0.0327, 0.0962, -0.0902, 0.0010, -0.0489],
[-0.0611, 0.0270, 0.0678, 0.0009, 0.0287],
[-0.1257, -0.0570, 0.0649, -0.0547, -0.0710]]],
...,
[[[-0.0368, -0.1070, -0.0377, -0.0151, -0.0040],
[-0.0080, -0.0498, -0.0739, 0.0521, -0.0047],
[-0.1140, 0.1242, -0.0421, 0.0485, 0.1163],
[ 0.0760, 0.1164, 0.0229, -0.0359, 0.0471],
[-0.0256, -0.0217, 0.0007, -0.0934, -0.0431]],
[[-0.0146, 0.1180, -0.0894, 0.0662, -0.1869],
[ 0.0360, -0.0623, -0.0055, 0.0395, 0.0304],
[-0.0796, 0.0298, 0.1031, 0.1473, 0.0148],
[ 0.0487, -0.0528, 0.0374, -0.0222, -0.0536],
[ 0.0577, 0.0646, 0.0502, -0.0029, 0.0426]],
[[-0.0677, -0.0192, 0.0540, 0.0301, 0.0904],
[-0.0196, 0.1227, 0.1267, 0.0565, 0.0055],
[-0.0280, -0.0297, 0.0549, -0.0250, -0.1382],
[-0.1008, 0.0365, -0.0657, 0.0342, 0.0065],
[ 0.0315, 0.0172, -0.0969, 0.0013, 0.0104]],
[[ 0.0558, 0.0658, -0.0555, -0.0110, 0.0538],
[ 0.0588, -0.1003, -0.1523, -0.0112, 0.0897],
[ 0.0105, 0.0035, -0.0389, 0.0006, 0.1152],
[ 0.0710, 0.1802, 0.0374, -0.0865, 0.0942],
[ 0.0541, 0.0844, -0.0715, 0.0436, -0.0691]],
[[ 0.0120, -0.0954, 0.0155, 0.0152, -0.1001],
[-0.0493, -0.0088, 0.0529, -0.0759, -0.0161],
[ 0.0656, 0.0814, 0.0555, 0.0120, -0.0367],
[-0.0500, 0.1128, 0.0131, -0.0284, 0.0928],
[ 0.0287, 0.0125, 0.0348, 0.0482, -0.0143]],
[[ 0.0911, 0.0056, -0.0845, 0.0777, -0.0683],
[ 0.0142, -0.0229, 0.0458, 0.0650, 0.0809],
[ 0.0135, -0.0528, -0.1453, -0.0548, -0.0564],
[ 0.0639, 0.0091, 0.0875, 0.0836, -0.0286],
[ 0.0162, -0.0257, -0.0204, -0.1649, -0.0647]]],
[[[-0.0013, -0.0171, -0.1675, -0.0056, -0.0280],
[ 0.0808, 0.0189, 0.1015, -0.0198, 0.0340],
[-0.1070, -0.0596, -0.0489, 0.1452, -0.0884],
[-0.0348, -0.0879, -0.0137, 0.0157, 0.0132],
[ 0.0066, 0.0364, 0.1176, -0.0999, 0.0079]],
[[ 0.0413, -0.0700, 0.0778, 0.0368, 0.0522],
[-0.0291, -0.0256, 0.0869, 0.0585, 0.0262],
[-0.1031, -0.0223, -0.0122, 0.0745, 0.0264],
[-0.1007, 0.0369, -0.0398, -0.0670, 0.0970],
[-0.1686, -0.0153, -0.0234, 0.0754, 0.0909]],
[[ 0.0407, 0.0859, -0.1172, -0.0371, 0.0481],
[-0.0890, 0.0338, -0.0192, 0.0476, -0.0217],
[ 0.0182, -0.0838, -0.0416, -0.0070, -0.0653],
[-0.0378, -0.0439, 0.0132, -0.0796, -0.0450],
[ 0.0619, -0.0380, 0.0575, -0.0631, 0.1044]],
[[ 0.0278, -0.1035, 0.0012, -0.0706, 0.0558],
[ 0.0631, 0.1811, -0.0200, 0.0622, 0.0970],
[ 0.0673, -0.0611, 0.0469, 0.0018, 0.0879],
[-0.0202, -0.0397, -0.0025, 0.0565, -0.0940],
[-0.0053, 0.0735, 0.0329, 0.0304, 0.0019]],
[[-0.0521, -0.0473, -0.0060, -0.0507, 0.0482],
[-0.0048, -0.0338, -0.0170, 0.1147, -0.0254],
[-0.0622, -0.0311, -0.0690, 0.0485, 0.0562],
[ 0.0360, 0.0941, 0.1680, -0.0755, 0.0932],
[-0.0503, -0.0668, -0.0178, -0.1147, 0.0155]],
[[-0.0955, 0.0673, -0.1261, -0.0133, 0.0173],
[-0.0105, 0.0789, 0.0800, 0.0514, -0.0300],
[ 0.0112, 0.0047, 0.0450, -0.0191, -0.1006],
[ 0.1001, 0.0414, 0.0202, 0.0338, -0.0103],
[-0.1475, -0.1024, 0.0735, -0.0058, 0.0517]]],
[[[-0.0090, -0.0069, -0.0367, -0.0155, 0.0832],
[-0.0323, 0.0123, 0.0614, 0.0114, -0.0722],
[-0.0454, -0.0036, -0.0285, 0.1304, 0.0580],
[-0.0847, 0.0040, 0.0986, 0.0794, -0.0332],
[-0.0672, -0.0382, -0.0653, -0.0962, 0.1406]],
[[ 0.1578, 0.0159, -0.0342, 0.0398, -0.0326],
[ 0.0567, -0.1326, 0.0710, 0.0133, -0.0371],
[ 0.0955, -0.1701, -0.1302, -0.0422, -0.0215],
[-0.0294, -0.0221, -0.1544, 0.0118, 0.1777],
[ 0.0897, -0.1050, 0.0651, 0.0025, -0.0192]],
[[-0.0447, 0.0564, 0.0178, -0.0194, -0.0053],
[-0.0803, 0.1073, -0.0077, 0.0310, 0.1199],
[ 0.0350, 0.2188, 0.0568, -0.1098, -0.1173],
[ 0.0527, 0.1389, 0.0367, 0.1086, 0.0371],
[ 0.0656, 0.0234, -0.0382, -0.0206, 0.2017]],
[[ 0.0579, -0.0759, -0.0199, 0.0364, -0.0006],
[-0.0441, -0.0335, 0.0095, 0.0706, 0.0474],
[-0.0435, -0.0494, -0.0207, 0.0210, 0.0661],
[-0.1438, -0.0200, 0.0809, -0.0481, -0.0648],
[ 0.0745, -0.0318, 0.0249, 0.0515, 0.0518]],
[[-0.0142, 0.0744, -0.0306, -0.0203, 0.1105],
[ 0.0364, -0.0508, -0.0803, -0.0077, 0.0171],
[-0.0659, 0.1150, 0.0811, -0.0835, -0.0413],
[-0.0712, 0.0401, -0.0108, 0.0200, -0.0529],
[-0.0456, -0.0428, 0.0376, -0.0927, -0.0055]],
[[ 0.1133, 0.0363, 0.0195, 0.0877, 0.0687],
[-0.0174, 0.0007, 0.1192, -0.0410, -0.0328],
[-0.1145, -0.0187, 0.0342, 0.0645, -0.0715],
[-0.0164, -0.0221, -0.0441, 0.0314, -0.0280],
[ 0.0618, -0.0397, -0.0189, 0.0394, 0.0238]]]],
grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([ 0.0396, -0.0348, 0.0458, 0.0297, -0.0694, 0.0213, -0.0060, -0.0400,
-0.0470, 0.0165, 0.0250, 0.0796, 0.0477, 0.0254, 0.0719, 0.0583],
requires_grad=True)
tensor: tensor([ 0.1063, -0.0750, -0.0130, 0.1064, -0.0199, 0.0071, -0.0617, -0.0020,
-0.0503, -0.0594, 0.0255, 0.0815, -0.0090, 0.1249, 0.0318, 0.1622],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[[[0., -0., 0., -0., 0.],
[0., 0., 0., -0., -0.],
[0., 0., -0., 0., 0.],
[-0., 0., 0., -0., 0.],
[-0., -0., -0., 0., 0.]],
[[-0., -0., 0., -0., -0.],
[-0., -0., -0., 0., 0.],
[-0., -0., 0., -0., -0.],
[0., 0., -0., -0., -0.],
[0., 0., -0., 0., -0.]],
[[-0., -0., -0., -0., -0.],
[0., 0., -0., -0., 0.],
[-0., 0., -0., -0., 0.],
[-0., 0., -0., -0., -0.],
[-0., 0., -0., 0., -0.]],
[[0., 0., -0., -0., 0.],
[0., -0., -0., -0., 0.],
[-0., -0., 0., -0., 0.],
[0., 0., 0., -0., 0.],
[0., -0., 0., -0., 0.]],
[[0., -0., -0., -0., -0.],
[-0., -0., -0., 0., -0.],
[-0., -0., -0., -0., -0.],
[-0., 0., -0., 0., 0.],
[0., 0., 0., -0., 0.]],
[[-0., 0., -0., 0., -0.],
[-0., -0., 0., -0., -0.],
[0., 0., -0., -0., 0.],
[-0., 0., 0., -0., -0.],
[0., -0., -0., -0., 0.]]],
[[[-0., -0., -0., 0., -0.],
[-0., 0., -0., -0., -0.],
[0., 0., 0., 0., -0.],
[0., -0., 0., -0., 0.],
[-0., 0., -0., -0., 0.]],
[[-0., 0., 0., -0., -0.],
[0., 0., 0., -0., -0.],
[-0., 0., 0., -0., -0.],
[0., -0., -0., 0., -0.],
[-0., 0., 0., 0., 0.]],
[[-0., 0., -0., 0., -0.],
[0., -0., 0., 0., 0.],
[-0., 0., -0., -0., -0.],
[0., -0., -0., 0., -0.],
[-0., -0., -0., 0., -0.]],
[[-0., 0., 0., 0., -0.],
[0., 0., -0., -0., -0.],
[-0., -0., 0., -0., 0.],
[-0., -0., -0., -0., 0.],
[-0., 0., 0., 0., 0.]],
[[0., -0., 0., 0., -0.],
[0., 0., -0., 0., 0.],
[0., -0., -0., -0., -0.],
[-0., -0., 0., 0., -0.],
[-0., 0., -0., -0., -0.]],
[[0., -0., 0., 0., -0.],
[-0., 0., 0., 0., -0.],
[0., -0., 0., -0., 0.],
[0., -0., -0., 0., -0.],
[0., -0., -0., 0., -0.]]],
[[[-0., 0., 0., -0., 0.],
[-0., 0., -0., -0., 0.],
[-0., 0., 0., -0., -0.],
[0., 0., 0., 0., -0.],
[0., -0., 0., 0., 0.]],
[[0., 0., -0., 0., 0.],
[-0., -0., -0., -0., 0.],
[-0., -0., -0., 0., -0.],
[0., 0., 0., 0., 0.],
[0., 0., -0., -0., -0.]],
[[-0., 0., 0., -0., 0.],
[-0., -0., 0., 0., -0.],
[-0., 0., -0., 0., -0.],
[0., -0., -0., -0., -0.],
[-0., 0., -0., -0., -0.]],
[[-0., -0., -0., 0., -0.],
[-0., 0., 0., 0., 0.],
[0., -0., 0., -0., -0.],
[0., 0., 0., 0., 0.],
[0., -0., 0., -0., -0.]],
[[0., -0., -0., -0., 0.],
[0., -0., -0., -0., 0.],
[-0., 0., 0., 0., 0.],
[0., -0., 0., -0., -0.],
[-0., -0., -0., 0., 0.]],
[[0., 0., -0., -0., 0.],
[0., -0., -0., 0., -0.],
[-0., 0., -0., 0., -0.],
[-0., 0., 0., -0., -0.],
[-0., 0., 0., -0., -0.]]],
...,
[[[0., -0., 0., -0., -0.],
[-0., 0., 0., 0., 0.],
[-0., 0., -0., -0., 0.],
[0., 0., 0., -0., -0.],
[-0., 0., -0., -0., 0.]],
[[-0., 0., -0., 0., -0.],
[-0., -0., 0., -0., 0.],
[-0., -0., 0., 0., 0.],
[0., -0., 0., -0., -0.],
[0., 0., -0., -0., 0.]],
[[-0., 0., 0., 0., 0.],
[-0., 0., 0., 0., -0.],
[0., 0., 0., 0., -0.],
[0., -0., -0., -0., 0.],
[0., -0., -0., 0., 0.]],
[[0., 0., -0., -0., 0.],
[0., -0., -0., 0., 0.],
[0., -0., -0., -0., 0.],
[-0., 0., 0., -0., 0.],
[0., 0., -0., 0., -0.]],
[[0., -0., 0., -0., -0.],
[-0., -0., 0., -0., -0.],
[0., 0., -0., 0., -0.],
[-0., 0., -0., 0., 0.],
[-0., 0., 0., 0., 0.]],
[[0., -0., 0., 0., -0.],
[0., -0., 0., 0., 0.],
[0., -0., -0., -0., -0.],
[0., 0., 0., 0., -0.],
[0., -0., 0., -0., -0.]]],
[[[-0., -0., -0., 0., -0.],
[0., 0., 0., -0., 0.],
[-0., 0., -0., 0., 0.],
[0., -0., 0., -0., -0.],
[0., -0., 0., -0., -0.]],
[[-0., -0., 0., -0., 0.],
[-0., -0., 0., 0., 0.],
[-0., 0., -0., 0., -0.],
[-0., -0., 0., -0., 0.],
[-0., -0., 0., 0., 0.]],
[[0., 0., -0., -0., 0.],
[-0., -0., -0., 0., -0.],
[-0., -0., -0., -0., 0.],
[0., -0., -0., -0., 0.],
[0., 0., 0., -0., 0.]],
[[-0., -0., 0., -0., 0.],
[0., 0., -0., -0., 0.],
[-0., 0., 0., -0., 0.],
[-0., -0., -0., 0., -0.],
[-0., 0., 0., 0., 0.]],
[[0., -0., -0., -0., 0.],
[0., -0., 0., 0., -0.],
[-0., -0., -0., 0., 0.],
[0., -0., 0., -0., 0.],
[-0., -0., 0., -0., 0.]],
[[-0., 0., -0., 0., 0.],
[-0., 0., 0., 0., -0.],
[0., -0., -0., -0., -0.],
[0., 0., -0., -0., -0.],
[-0., -0., 0., 0., 0.]]],
[[[-0., -0., -0., 0., 0.],
[0., -0., 0., -0., -0.],
[0., 0., -0., 0., 0.],
[-0., -0., 0., 0., -0.],
[-0., -0., -0., -0., 0.]],
[[0., 0., -0., -0., 0.],
[0., -0., 0., 0., -0.],
[0., -0., -0., -0., 0.],
[-0., -0., -0., 0., 0.],
[0., -0., 0., 0., 0.]],
[[-0., 0., 0., -0., -0.],
[-0., 0., 0., 0., 0.],
[0., -0., 0., -0., -0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., -0., 0.]],
[[0., -0., -0., 0., -0.],
[-0., 0., -0., 0., -0.],
[-0., -0., -0., -0., 0.],
[-0., -0., -0., 0., -0.],
[0., -0., 0., 0., 0.]],
[[-0., 0., 0., -0., 0.],
[0., 0., -0., -0., 0.],
[-0., 0., 0., 0., -0.],
[-0., 0., 0., 0., -0.],
[-0., 0., 0., -0., 0.]],
[[-0., 0., 0., 0., 0.],
[0., 0., 0., -0., 0.],
[-0., 0., 0., 0., -0.],
[0., -0., -0., -0., 0.],
[0., -0., -0., 0., 0.]]]])
scale: tensor([[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
...,
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[[[ 2.4966e-03, -6.7341e-02, 6.2799e-02, -3.4370e-03, 4.5322e-02],
[ 1.4607e-02, 6.2478e-02, 6.8280e-02, -3.5654e-02, -1.8825e-02],
[ 7.1212e-02, 1.5050e-02, -1.8025e-02, 2.5491e-02, 2.8724e-02],
[-2.2547e-02, 6.0692e-02, 5.2667e-02, -5.7841e-03, 2.6376e-02],
[-4.3343e-02, -5.9860e-02, -1.2973e-02, 6.2583e-03, 8.0137e-02]],
[[-2.1569e-02, -7.1375e-02, 1.2631e-02, -1.8075e-02, -4.7912e-02],
[-2.8027e-02, -2.6252e-02, -5.9709e-02, 6.7074e-02, 2.6864e-02],
[-6.6509e-02, -8.1214e-02, 2.3467e-02, -6.6464e-02, -6.0669e-02],
[ 5.0579e-03, 1.1671e-02, -6.6733e-02, -4.9352e-02, -2.0956e-02],
[ 7.6966e-03, 2.2004e-02, -2.8989e-02, 5.1644e-02, -6.1621e-02]],
[[-5.4252e-02, -8.1522e-02, -6.2672e-02, -6.3098e-02, -5.2104e-02],
[ 2.5544e-02, 3.6878e-02, -4.4395e-02, -4.9102e-02, 4.8144e-02],
[-3.2576e-02, 8.0204e-02, -9.2122e-03, -7.7311e-03, 6.7384e-02],
[-1.5682e-02, 9.5757e-03, -6.5708e-02, -2.7689e-02, -6.5013e-02],
[-6.4995e-02, 2.7979e-02, -1.6400e-03, 4.6764e-02, -4.3071e-02]],
[[ 5.6996e-02, 5.9034e-03, -4.6686e-02, -6.9235e-02, 1.9127e-03],
[ 1.8545e-02, -1.9516e-02, -5.0705e-02, -3.5648e-02, 6.0718e-02],
[-6.7171e-02, -7.3959e-02, 5.4578e-02, -1.1379e-02, 4.7259e-02],
[ 1.8278e-02, 1.0828e-02, 2.9502e-02, -4.8256e-02, 1.5783e-02],
[ 3.0306e-02, -5.6080e-03, 1.1095e-02, -8.0521e-02, 6.7744e-02]],
[[ 6.4667e-03, -3.1020e-02, -5.5962e-02, -4.0058e-02, -3.4994e-02],
[-4.6501e-02, -5.0324e-02, -7.3864e-02, 1.8048e-02, -5.4226e-02],
[-7.8412e-02, -4.8003e-02, -1.1779e-02, -7.8603e-02, -3.8339e-02],
[-2.3014e-02, 5.4240e-02, -4.7970e-02, 2.5364e-02, 3.3201e-02],
[ 6.0505e-02, 6.4193e-02, 7.4750e-04, -2.1983e-02, 7.6744e-02]],
[[-6.7106e-02, 3.3650e-02, -7.8031e-02, 6.1907e-02, -7.8580e-02],
[-3.2164e-02, -4.9644e-02, 5.5871e-02, -1.9880e-02, -5.9899e-02],
[ 7.6375e-02, 1.6907e-02, -3.7618e-02, -4.6346e-02, 4.5166e-02],
[-6.2424e-02, 1.6883e-02, 3.9223e-02, -1.4790e-02, -6.1880e-02],
[ 3.8489e-02, -7.1589e-05, -1.7487e-02, -6.2649e-02, 6.8018e-02]]],
[[[-2.5034e-02, -4.6058e-02, -2.3337e-02, 2.2215e-02, -1.4850e-02],
[-4.3984e-02, 1.4747e-02, -7.9597e-03, -6.3069e-02, -2.2592e-02],
[ 4.3075e-02, 1.9039e-02, 3.0107e-02, 2.3445e-03, -6.9232e-02],
[ 6.4497e-03, -4.7578e-02, 7.6891e-02, -5.1616e-02, 1.2707e-02],
[-3.6706e-02, 7.6937e-02, -5.9667e-02, -8.1446e-02, 1.7545e-02]],
[[-4.4402e-02, 6.4599e-02, 4.1490e-02, -2.8644e-02, -4.8206e-02],
[ 2.6031e-02, 5.7371e-02, 2.0308e-02, -7.4061e-02, -4.4877e-02],
[-2.7835e-02, 5.2467e-02, 4.0908e-02, -8.1401e-02, -6.8110e-02],
[ 7.6093e-02, -9.4644e-03, -5.2426e-03, 7.7898e-03, -1.6856e-02],
[-7.1420e-02, 7.8023e-02, 8.0673e-02, 3.5987e-02, 7.1558e-02]],
[[-2.2731e-02, 6.9536e-03, -3.5987e-02, 7.5870e-02, -7.4783e-02],
[ 3.9431e-02, -3.5946e-02, 3.5435e-02, 3.6262e-02, 1.0886e-02],
[-5.3538e-02, 7.1992e-02, -3.2372e-04, -1.5915e-02, -7.1567e-02],
[ 2.4658e-02, -1.1047e-02, -5.5648e-02, 1.9212e-02, -1.2692e-02],
[-8.4647e-03, -4.5983e-02, -6.7419e-04, 9.8715e-03, -4.6148e-02]],
[[-2.0329e-02, 4.4603e-02, 6.4148e-02, 7.1504e-02, -4.5712e-02],
[ 4.6664e-02, 7.8892e-02, -3.1212e-02, -8.2775e-03, -6.5268e-02],
[-5.3630e-02, -3.4377e-02, 4.4915e-02, -3.5955e-02, 4.4886e-02],
[-6.4155e-02, -1.4317e-02, -6.8112e-02, -2.0220e-02, 3.0197e-02],
[-6.6971e-02, 7.1244e-02, 3.6331e-02, 4.9435e-02, 9.6099e-04]],
[[ 7.0543e-02, -2.7817e-02, 6.0078e-02, 4.7984e-02, -4.3316e-02],
[ 7.6346e-02, 1.2656e-02, -6.2175e-02, 7.6718e-02, 1.8006e-02],
[ 7.4139e-02, -5.7349e-02, -1.0675e-02, -2.4720e-02, -4.4425e-02],
[-7.0388e-02, -6.9265e-02, 2.1314e-02, 4.5109e-02, -7.1493e-02],
[-3.2351e-02, 3.5873e-02, -4.8951e-03, -7.2958e-02, -4.4238e-02]],
[[ 8.1172e-02, -5.8811e-02, 4.0991e-02, 6.9875e-02, -6.0488e-02],
[-6.3614e-02, 2.7591e-02, 4.7980e-02, 8.1415e-02, -3.3863e-02],
[ 2.1980e-02, -2.3815e-02, 1.8170e-02, -2.5447e-02, 2.7371e-02],
[ 8.1534e-02, -3.9964e-02, -7.7124e-02, 5.2303e-02, -2.0802e-02],
[ 6.0792e-02, -7.2300e-02, -2.8772e-02, 5.7216e-02, -3.4328e-02]]],
[[[-2.7990e-03, 1.8808e-02, 6.7455e-04, -6.6091e-02, 6.6982e-02],
[-6.9830e-02, 6.5788e-02, -4.8964e-02, -4.1140e-03, 3.1098e-03],
[-5.4660e-02, 7.1862e-02, 7.1153e-03, -6.5570e-02, -5.1447e-02],
[ 6.0571e-02, 2.0612e-03, 2.3722e-02, 1.5375e-02, -5.1639e-02],
[ 4.7830e-02, -2.5379e-02, 6.9856e-02, 4.5925e-02, 9.6594e-03]],
[[ 2.2661e-02, 6.6682e-03, -4.9346e-02, 2.0883e-03, 3.6717e-02],
[-2.9341e-02, -7.5506e-02, -2.3876e-02, -4.3604e-02, 6.0555e-02],
[-5.9189e-03, -9.2285e-03, -7.6425e-02, 5.5105e-02, -3.4522e-02],
[ 4.5115e-02, 6.4247e-03, 4.4812e-02, 1.9783e-02, 7.9421e-02],
[ 1.4393e-02, 6.9105e-02, -5.1842e-02, -5.6427e-04, -3.0378e-02]],
[[-2.9724e-02, 2.7430e-02, 2.0606e-02, -6.8593e-02, 7.5719e-02],
[-1.1574e-02, -6.1978e-02, 1.5036e-02, 6.5520e-02, -5.5500e-02],
[-5.4038e-03, 1.6463e-03, -1.6045e-02, 5.0524e-02, -6.1660e-02],
[ 7.9241e-02, -6.3365e-02, -7.6820e-02, -7.4036e-02, -3.8586e-02],
[-2.4314e-02, 8.9899e-03, -7.1113e-03, -3.9968e-02, -5.5614e-02]],
[[-2.1929e-02, -1.1871e-02, -4.8945e-02, 2.5946e-02, -6.4562e-03],
[-3.3630e-02, 5.2410e-02, 3.1839e-02, 3.7700e-02, 4.5499e-02],
[ 2.5421e-02, -6.7532e-02, 7.4772e-02, -4.7501e-03, -5.6375e-02],
[ 3.3231e-03, 2.4121e-02, 8.1080e-02, 1.0045e-02, 7.1572e-02],
[ 4.9069e-02, -3.3530e-02, 1.1358e-02, -5.2008e-04, -4.7222e-02]],
[[ 4.6323e-02, -2.6136e-04, -5.2806e-02, -2.4812e-02, 4.8923e-02],
[ 3.7666e-02, -5.7144e-02, -6.4298e-03, -1.7247e-02, 3.1968e-02],
[-6.0489e-02, 2.8763e-02, 2.8486e-02, 3.5758e-02, 1.7142e-02],
[ 8.0486e-02, -3.3536e-02, 3.6080e-02, -2.8437e-02, -1.9267e-03],
[-6.2648e-02, -5.2696e-02, -6.9294e-02, 3.6589e-02, 1.8661e-02]],
[[ 9.3734e-03, 1.9768e-02, -3.5296e-02, -2.7572e-02, 6.0065e-02],
[ 4.1112e-02, -3.4800e-02, -7.9605e-02, 5.3310e-02, -1.9664e-02],
[-1.3619e-02, 2.5598e-02, -2.6082e-02, 1.0369e-02, -6.4816e-02],
[-5.6846e-03, 2.9120e-02, 3.2444e-02, -1.8672e-02, -6.1514e-03],
[-4.0733e-02, 1.6849e-02, 3.8275e-02, -3.8825e-02, -5.8020e-02]]],
...,
[[[ 5.9329e-02, -6.5870e-02, 1.5859e-02, -1.5105e-02, -8.8580e-03],
[-5.0859e-02, 3.3974e-04, 3.0749e-02, 2.3093e-02, 4.2401e-02],
[-5.4646e-02, 7.5729e-02, -5.2260e-02, -4.6122e-03, 5.1098e-02],
[ 7.6842e-02, 1.2823e-02, 3.1960e-03, -3.9222e-02, -2.0589e-02],
[-6.3892e-03, 4.3896e-02, -2.7891e-02, -1.8322e-02, 3.2296e-02]],
[[-4.1220e-02, 8.0820e-02, -4.3842e-02, 7.9302e-02, -4.9983e-02],
[-7.3771e-03, -2.5980e-03, 3.6210e-03, -2.1433e-02, 7.3244e-02],
[-5.0760e-02, -2.5616e-03, 6.2678e-02, 7.3593e-02, 2.8301e-02],
[ 6.7007e-02, -1.9064e-02, 8.3973e-03, -5.9798e-02, -2.6974e-02],
[ 3.2876e-02, 3.8669e-02, -3.8719e-02, -1.0153e-02, 7.5301e-02]],
[[-5.0442e-02, 7.7989e-03, 3.9936e-02, 3.3326e-02, 7.9093e-02],
[-2.9757e-03, 5.9154e-02, 5.3348e-02, 7.3507e-02, -7.6732e-03],
[ 4.6368e-02, 2.1609e-02, 6.4193e-02, 6.0536e-02, -6.3037e-02],
[ 5.3429e-02, -2.5135e-02, -2.1193e-02, -5.1761e-04, 1.9996e-02],
[ 4.3215e-02, -2.7851e-02, -6.4092e-02, 6.4469e-03, 4.8372e-02]],
[[ 5.6178e-02, 3.5691e-02, -8.8206e-03, -3.0902e-02, 4.1369e-02],
[ 3.2037e-03, -7.3717e-02, -7.2683e-02, 3.9951e-02, 2.7550e-02],
[ 7.8738e-02, -7.8033e-02, -5.3960e-02, -3.4911e-02, 6.0757e-02],
[-1.3279e-02, 8.0203e-02, 5.5343e-02, -2.4808e-02, 7.3956e-02],
[ 3.9374e-02, 7.8239e-02, -7.7757e-02, 3.6571e-02, -2.9753e-02]],
[[ 1.4708e-02, -7.2628e-02, 6.9122e-02, -4.9947e-02, -2.0280e-02],
[-4.6593e-02, -7.6654e-02, 3.2037e-02, -1.1588e-02, -4.1092e-03],
[ 6.4820e-02, 1.2277e-02, -5.5943e-02, 3.8722e-02, -2.4028e-02],
[-5.1028e-02, 6.5714e-02, -1.7858e-02, 3.3266e-04, 7.5303e-02],
[-2.6313e-02, 2.8460e-02, 7.5677e-02, 6.7317e-02, 2.4816e-02]],
[[ 7.5268e-02, -1.0377e-02, 8.7081e-03, 2.6532e-02, -3.6530e-02],
[ 3.2920e-02, -5.4807e-03, 5.5366e-02, 6.0630e-02, 7.4726e-02],
[ 6.9652e-03, -4.8734e-02, -5.6073e-02, -6.8096e-02, -3.3923e-02],
[ 3.3389e-03, 5.3901e-02, 7.1177e-02, 5.2428e-02, -6.1103e-03],
[ 6.7079e-02, -2.1961e-02, 2.5835e-02, -7.2561e-02, -1.7097e-02]]],
[[[-1.4495e-02, -5.7205e-02, -6.9848e-02, 2.6035e-02, -6.8173e-02],
[ 7.7289e-03, 1.0050e-02, 2.6973e-02, -3.7391e-02, 4.0037e-02],
[-8.8926e-03, 2.8563e-02, -5.2584e-02, 7.0450e-02, 3.2828e-03],
[ 4.1858e-02, -5.4062e-03, 8.6848e-04, -2.5418e-02, -1.4656e-02],
[ 3.8785e-03, -1.3989e-02, 6.0276e-02, -5.7459e-02, -8.8772e-03]],
[[-2.5420e-02, -4.8067e-02, 4.2092e-02, -1.0464e-04, 2.9935e-03],
[-6.1568e-02, -2.6599e-02, 1.3947e-02, 6.6099e-02, 2.1827e-02],
[-3.0738e-02, 2.9112e-02, -3.2363e-02, 5.0961e-02, -5.7377e-02],
[-7.7591e-02, -1.8587e-03, 4.1255e-02, -7.4086e-02, 2.4900e-02],
[-5.8474e-02, -6.6314e-02, 3.9553e-02, 6.3980e-02, 3.6565e-02]],
[[ 1.8101e-02, 4.8489e-03, -7.8886e-02, -5.7999e-02, 7.0357e-02],
[-4.3515e-02, -4.7323e-02, -5.6187e-02, 2.3487e-02, -4.3308e-02],
[-4.8880e-02, -5.8023e-02, -7.3924e-02, -9.6583e-03, 6.2028e-02],
[ 5.2484e-02, -2.6062e-02, -4.3238e-03, -7.2771e-02, 3.2814e-02],
[ 5.0127e-02, 3.3270e-02, 6.3871e-02, -3.8169e-02, 7.1588e-02]],
[[-6.5380e-02, -7.1597e-02, 8.0692e-02, -5.3291e-02, 4.4471e-02],
[ 7.1552e-02, 4.8723e-02, -6.1756e-02, -4.6115e-02, 7.7099e-02],
[-1.9059e-02, 1.1879e-03, 1.7027e-02, -4.2909e-02, 5.6986e-02],
[-8.7295e-03, -4.8808e-02, -5.7670e-02, 2.4997e-04, -1.1744e-02],
[-1.3926e-02, 3.1644e-02, 1.1507e-02, 8.1024e-02, 7.2102e-03]],
[[ 2.1168e-02, -2.9686e-02, -4.1679e-02, -4.7661e-02, 3.7415e-02],
[ 2.2248e-02, -2.6836e-02, 3.3736e-02, 7.9188e-02, -5.4294e-02],
[-6.3778e-02, -2.1948e-02, -3.8511e-02, 3.4796e-02, 6.7318e-02],
[ 1.9141e-03, -1.8587e-02, 7.3516e-02, -5.1300e-02, 5.7174e-02],
[-3.4286e-02, -8.8658e-03, 2.8017e-02, -6.1345e-02, 7.5492e-02]],
[[-3.9274e-03, 5.5152e-02, -7.5673e-02, 2.3418e-02, 2.4453e-02],
[-4.7555e-02, 3.8749e-02, 7.6313e-02, 1.7370e-02, -3.5157e-02],
[ 2.5302e-03, -2.2002e-02, -3.2479e-02, -2.4231e-02, -3.4979e-02],
[ 5.0821e-02, 4.9669e-02, -3.9780e-03, -5.6009e-02, -7.5135e-02],
[-1.4819e-02, -3.7873e-02, 8.0048e-02, 6.6795e-02, 1.9579e-04]]],
[[[-6.7337e-03, -3.1325e-02, -3.8130e-04, 2.8326e-02, 5.0600e-02],
[ 2.1436e-02, -4.8592e-02, 6.1161e-02, -6.3545e-02, -6.6114e-02],
[ 2.6338e-03, 5.4122e-02, -4.5045e-02, 1.2242e-02, 8.8401e-03],
[-7.7559e-02, -1.9867e-02, 6.5748e-03, 4.9552e-02, -6.9732e-02],
[-3.9603e-02, -2.5227e-02, -7.4894e-02, -4.6245e-02, 3.7202e-02]],
[[ 7.4730e-02, 4.4270e-02, -1.6544e-02, -1.4239e-02, 5.9522e-03],
[ 2.7072e-03, -5.8596e-02, 4.8411e-02, 4.3153e-02, -4.5938e-02],
[ 7.8193e-02, -6.7088e-02, -6.9651e-02, -6.0037e-02, 2.6229e-02],
[-7.0300e-03, -3.4991e-02, -7.6215e-02, 9.9203e-03, 6.9492e-02],
[ 4.9921e-02, -2.0404e-02, 6.9527e-02, 3.0870e-02, 3.4420e-02]],
[[-1.8141e-02, 4.8283e-02, 1.9330e-02, -4.4328e-02, -1.4779e-02],
[-5.6117e-02, 7.9438e-02, 6.8914e-03, 2.9340e-02, 5.3000e-02],
[ 5.4209e-02, -2.0673e-02, 4.3754e-02, -3.3216e-02, -1.8343e-02],
[ 2.5629e-02, 3.6082e-02, 7.0708e-02, 7.3608e-02, 5.9628e-02],
[ 6.0885e-02, 1.6643e-02, 1.6415e-02, -1.0011e-02, 7.8816e-02]],
[[ 7.5320e-02, -5.7523e-02, -4.1853e-02, 1.0916e-02, -3.0991e-02],
[-6.2630e-02, 1.4596e-02, -1.0427e-02, 3.6875e-02, -1.6136e-03],
[-4.4583e-02, -3.0317e-02, -2.2016e-02, -6.6638e-02, 6.7848e-02],
[-5.1430e-02, -3.2587e-02, -1.1150e-03, 3.4666e-03, -3.3550e-02],
[ 3.4297e-02, -7.6208e-02, 7.9594e-02, 2.8101e-02, 5.8770e-02]],
[[-7.3166e-02, 6.3813e-02, 5.1557e-02, -8.4318e-03, 6.9600e-02],
[ 7.9715e-02, 1.2783e-03, -5.8935e-02, -4.6001e-02, 2.3448e-04],
[-2.1650e-02, 7.1989e-02, 5.8779e-02, 3.0808e-02, -5.0626e-02],
[-7.0207e-02, 6.4429e-02, 3.2764e-02, 4.4986e-02, -5.6941e-02],
[-3.9577e-02, 2.8508e-03, 1.5647e-02, -5.7797e-02, 3.3754e-02]],
[[-3.6933e-02, 6.9224e-02, 6.7209e-02, 1.0161e-02, 2.9785e-02],
[ 1.1000e-02, 5.7507e-02, 7.7336e-02, -2.3910e-02, 1.2587e-02],
[-3.8743e-02, 9.7108e-03, 4.9643e-02, 3.2226e-02, -6.9066e-02],
[ 1.7411e-02, -5.1872e-02, -4.3662e-02, -2.2543e-02, 4.4947e-02],
[ 5.7543e-02, -2.1366e-02, -2.5460e-02, 4.8669e-02, 4.8413e-02]]]])
(bias): Normal:
loc: tensor([0., -0., 0., 0., -0., 0., -0., -0., -0., 0., 0., 0., 0., 0., 0., 0.])
scale: tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([ 0.0396, -0.0348, 0.0458, 0.0297, -0.0694, 0.0213, -0.0060, -0.0400,
-0.0470, 0.0165, 0.0250, 0.0796, 0.0477, 0.0254, 0.0719, 0.0583])
)
(observed): Observed()
)
(fc1): Linear(
in_features=400, out_features=120, bias=True
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
...,
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498]],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[-0.0091, 0.0136, -0.0051, ..., -0.0231, 0.0462, -0.0289],
[-0.0265, -0.0483, -0.0282, ..., -0.0270, 0.0013, 0.0135],
[-0.0033, -0.0453, 0.0016, ..., -0.0309, -0.0074, -0.0022],
...,
[ 0.0203, 0.0274, 0.0125, ..., -0.0139, 0.0251, 0.0211],
[-0.0178, 0.0492, -0.0229, ..., -0.0063, -0.0401, 0.0024],
[-0.0033, -0.0229, 0.0245, ..., 0.0420, 0.0423, -0.0412]],
requires_grad=True)
tensor: tensor([[ 0.0163, 0.0648, 0.0315, ..., -0.0154, 0.1702, -0.0326],
[-0.0761, -0.0296, -0.0542, ..., -0.0453, 0.0576, -0.0302],
[ 0.0159, -0.0278, -0.0281, ..., -0.0069, 0.0057, -0.0484],
...,
[-0.0042, 0.0008, 0.0761, ..., -0.0470, 0.0703, 0.0368],
[-0.0550, 0.1441, 0.0373, ..., 0.0340, -0.0733, 0.0047],
[-0.1526, -0.0554, -0.0007, ..., 0.0040, 0.0825, -0.0662]],
grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([-0.0058, -0.0242, 0.0295, -0.0301, -0.0095, -0.0331, -0.0259, -0.0153,
0.0069, 0.0327, 0.0171, -0.0467, -0.0340, 0.0456, 0.0205, 0.0115,
0.0437, -0.0400, -0.0337, 0.0402, 0.0005, 0.0498, 0.0019, -0.0428,
0.0204, -0.0429, -0.0188, -0.0454, -0.0102, 0.0484, -0.0102, 0.0098,
-0.0489, -0.0116, 0.0008, 0.0365, 0.0245, 0.0127, 0.0432, 0.0170,
0.0064, 0.0087, -0.0159, -0.0008, 0.0249, 0.0176, 0.0213, 0.0199,
-0.0411, -0.0185, 0.0361, 0.0165, -0.0366, 0.0011, 0.0231, -0.0365,
-0.0065, 0.0200, -0.0402, -0.0289, 0.0112, 0.0014, -0.0048, 0.0469,
0.0315, -0.0378, 0.0392, -0.0483, 0.0469, 0.0200, -0.0357, 0.0011,
0.0031, 0.0180, 0.0359, 0.0056, 0.0218, -0.0161, -0.0221, -0.0196,
-0.0212, 0.0186, 0.0415, -0.0463, -0.0143, -0.0351, 0.0418, 0.0125,
-0.0339, 0.0268, 0.0286, -0.0028, 0.0439, -0.0283, -0.0468, 0.0295,
-0.0397, -0.0159, 0.0498, -0.0395, 0.0458, -0.0338, 0.0416, -0.0442,
-0.0284, 0.0464, 0.0410, -0.0138, 0.0295, -0.0449, -0.0248, 0.0475,
-0.0433, -0.0157, 0.0214, 0.0104, -0.0227, -0.0410, 0.0298, 0.0196],
requires_grad=True)
tensor: tensor([-1.3555e-01, 2.0939e-02, -1.0978e-02, 2.2384e-02, -3.5042e-02,
-2.3534e-02, 1.4926e-02, 1.2443e-02, 9.3025e-04, -9.5893e-03,
-6.1436e-03, -2.0465e-02, -6.8778e-03, 4.9478e-02, 3.5926e-02,
-4.3168e-02, 5.8348e-02, -9.2632e-02, -8.5259e-02, 6.1411e-02,
3.5284e-02, 9.3958e-03, 2.7893e-02, -3.3541e-02, -7.2413e-03,
-2.7861e-02, -2.7886e-02, -2.5454e-02, 2.0168e-02, 5.8621e-02,
3.8074e-03, -3.6508e-02, -6.4902e-02, -2.8883e-02, -7.6834e-04,
-7.5272e-03, 6.1758e-03, 7.4719e-02, 2.9627e-02, -1.8554e-02,
5.9529e-02, 1.5179e-01, 6.4342e-02, 7.4196e-04, -8.1407e-03,
8.9840e-03, 6.9436e-02, -3.7395e-02, -9.3132e-03, -5.0251e-03,
6.5105e-02, -1.0527e-02, 3.9077e-02, -1.3772e-02, -4.9026e-02,
1.2038e-02, -4.4850e-02, 9.1619e-02, -1.1016e-02, -6.6605e-04,
4.6487e-02, -9.2120e-03, -4.6353e-02, 9.6699e-03, -3.4516e-02,
-1.3860e-02, 5.3031e-02, -5.4605e-02, 6.1908e-02, 4.8374e-02,
-2.9948e-02, 6.3221e-02, 1.0043e-02, -1.6848e-02, -3.4269e-02,
3.2234e-03, 4.0958e-03, -1.0895e-04, -7.2229e-02, 1.4132e-03,
-1.5699e-02, -1.1046e-02, -3.4086e-02, -2.1041e-02, 2.7204e-02,
-8.7042e-02, 7.2495e-02, 6.3488e-02, -2.4440e-02, 6.2661e-02,
2.7276e-02, 7.7142e-02, -9.2377e-03, -8.9469e-02, -8.0198e-02,
-5.0707e-02, -1.9137e-02, -4.0547e-02, 1.1058e-01, -1.1626e-01,
1.4978e-02, -8.0421e-02, 2.4994e-02, -8.8205e-02, -3.4078e-02,
4.4075e-02, 1.1147e-03, -1.4480e-02, 6.0224e-02, -6.7068e-02,
-1.2512e-01, 7.6667e-02, 3.1673e-02, -6.0719e-02, 5.4334e-02,
-3.1811e-03, -1.9614e-01, -2.6915e-02, -7.9046e-03, -3.3319e-02],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[-0., 0., -0., ..., -0., 0., -0.],
[-0., -0., -0., ..., -0., 0., 0.],
[-0., -0., 0., ..., -0., -0., -0.],
...,
[0., 0., 0., ..., -0., 0., 0.],
[-0., 0., -0., ..., -0., -0., 0.],
[-0., -0., 0., ..., 0., 0., -0.]])
scale: tensor([[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
...,
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[-0.0091, 0.0136, -0.0051, ..., -0.0231, 0.0462, -0.0289],
[-0.0265, -0.0483, -0.0282, ..., -0.0270, 0.0013, 0.0135],
[-0.0033, -0.0453, 0.0016, ..., -0.0309, -0.0074, -0.0022],
...,
[ 0.0203, 0.0274, 0.0125, ..., -0.0139, 0.0251, 0.0211],
[-0.0178, 0.0492, -0.0229, ..., -0.0063, -0.0401, 0.0024],
[-0.0033, -0.0229, 0.0245, ..., 0.0420, 0.0423, -0.0412]])
(bias): Normal:
loc: tensor([-0., -0., 0., -0., -0., -0., -0., -0., 0., 0., 0., -0., -0., 0., 0., 0., 0., -0., -0., 0., 0., 0., 0., -0.,
0., -0., -0., -0., -0., 0., -0., 0., -0., -0., 0., 0., 0., 0., 0., 0., 0., 0., -0., -0., 0., 0., 0., 0.,
-0., -0., 0., 0., -0., 0., 0., -0., -0., 0., -0., -0., 0., 0., -0., 0., 0., -0., 0., -0., 0., 0., -0., 0.,
0., 0., 0., 0., 0., -0., -0., -0., -0., 0., 0., -0., -0., -0., 0., 0., -0., 0., 0., -0., 0., -0., -0., 0.,
-0., -0., 0., -0., 0., -0., 0., -0., -0., 0., 0., -0., 0., -0., -0., 0., -0., -0., 0., 0., -0., -0., 0., 0.])
scale: tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([-0.0058, -0.0242, 0.0295, -0.0301, -0.0095, -0.0331, -0.0259, -0.0153,
0.0069, 0.0327, 0.0171, -0.0467, -0.0340, 0.0456, 0.0205, 0.0115,
0.0437, -0.0400, -0.0337, 0.0402, 0.0005, 0.0498, 0.0019, -0.0428,
0.0204, -0.0429, -0.0188, -0.0454, -0.0102, 0.0484, -0.0102, 0.0098,
-0.0489, -0.0116, 0.0008, 0.0365, 0.0245, 0.0127, 0.0432, 0.0170,
0.0064, 0.0087, -0.0159, -0.0008, 0.0249, 0.0176, 0.0213, 0.0199,
-0.0411, -0.0185, 0.0361, 0.0165, -0.0366, 0.0011, 0.0231, -0.0365,
-0.0065, 0.0200, -0.0402, -0.0289, 0.0112, 0.0014, -0.0048, 0.0469,
0.0315, -0.0378, 0.0392, -0.0483, 0.0469, 0.0200, -0.0357, 0.0011,
0.0031, 0.0180, 0.0359, 0.0056, 0.0218, -0.0161, -0.0221, -0.0196,
-0.0212, 0.0186, 0.0415, -0.0463, -0.0143, -0.0351, 0.0418, 0.0125,
-0.0339, 0.0268, 0.0286, -0.0028, 0.0439, -0.0283, -0.0468, 0.0295,
-0.0397, -0.0159, 0.0498, -0.0395, 0.0458, -0.0338, 0.0416, -0.0442,
-0.0284, 0.0464, 0.0410, -0.0138, 0.0295, -0.0449, -0.0248, 0.0475,
-0.0433, -0.0157, 0.0214, 0.0104, -0.0227, -0.0410, 0.0298, 0.0196])
)
(observed): Observed()
)
(fc2): Linear(
in_features=120, out_features=10, bias=True
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
...,
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498]],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[ 0.0398, 0.0313, -0.0450, ..., 0.0049, -0.0503, -0.0522],
[ 0.0619, 0.0350, -0.0323, ..., 0.0821, -0.0089, -0.0656],
[-0.0563, 0.0827, -0.0371, ..., -0.0067, -0.0121, -0.0045],
...,
[ 0.0112, 0.0345, 0.0310, ..., -0.0570, -0.0379, 0.0196],
[ 0.0261, 0.0156, 0.0901, ..., 0.0250, 0.0211, 0.0911],
[ 0.0547, 0.0208, 0.0781, ..., 0.0858, 0.0083, -0.0799]],
requires_grad=True)
tensor: tensor([[ 8.0554e-03, -4.1716e-05, 2.3532e-02, ..., 3.5975e-02,
-5.7889e-02, -6.4972e-02],
[ 7.9123e-02, 1.1317e-01, -1.3441e-02, ..., 8.0784e-02,
-5.1207e-02, -1.2873e-01],
[-9.2812e-03, 1.8911e-02, 3.3624e-02, ..., -1.5495e-02,
6.0509e-02, -1.7716e-02],
...,
[ 4.3931e-02, 8.3967e-02, 5.6798e-02, ..., -1.7927e-02,
3.2743e-03, 7.3954e-02],
[-1.2302e-03, -3.4438e-02, 1.3092e-01, ..., 6.6330e-03,
3.5277e-02, 9.5786e-02],
[ 8.5259e-02, 2.0223e-02, -1.3733e-02, ..., 6.2607e-02,
-2.0986e-02, -8.0120e-02]], grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([-0.0608, -0.0649, -0.0132, 0.0874, -0.0895, -0.0182, -0.0491, -0.0399,
0.0076, -0.0610], requires_grad=True)
tensor: tensor([-0.0337, 0.0438, -0.0347, 0.1284, -0.1277, -0.0252, -0.0512, -0.0522,
0.0990, -0.1157], grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[0., 0., -0., ..., 0., -0., -0.],
[0., 0., -0., ..., 0., -0., -0.],
[-0., 0., -0., ..., -0., -0., -0.],
...,
[0., 0., 0., ..., -0., -0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., -0.]])
scale: tensor([[0.0913, 0.0913, 0.0913, ..., 0.0913, 0.0913, 0.0913],
[0.0913, 0.0913, 0.0913, ..., 0.0913, 0.0913, 0.0913],
[0.0913, 0.0913, 0.0913, ..., 0.0913, 0.0913, 0.0913],
...,
[0.0913, 0.0913, 0.0913, ..., 0.0913, 0.0913, 0.0913],
[0.0913, 0.0913, 0.0913, ..., 0.0913, 0.0913, 0.0913],
[0.0913, 0.0913, 0.0913, ..., 0.0913, 0.0913, 0.0913]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[ 0.0398, 0.0313, -0.0450, ..., 0.0049, -0.0503, -0.0522],
[ 0.0619, 0.0350, -0.0323, ..., 0.0821, -0.0089, -0.0656],
[-0.0563, 0.0827, -0.0371, ..., -0.0067, -0.0121, -0.0045],
...,
[ 0.0112, 0.0345, 0.0310, ..., -0.0570, -0.0379, 0.0196],
[ 0.0261, 0.0156, 0.0901, ..., 0.0250, 0.0211, 0.0911],
[ 0.0547, 0.0208, 0.0781, ..., 0.0858, 0.0083, -0.0799]])
(bias): Normal:
loc: tensor([-0., -0., -0., 0., -0., -0., -0., -0., 0., -0.])
scale: tensor([0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([-0.0608, -0.0649, -0.0132, 0.0874, -0.0895, -0.0182, -0.0491, -0.0399,
0.0076, -0.0610])
)
(observed): Observed()
)
)
You just have to define the forward function, and the backward
function (where gradients are computed) is automatically defined for you
using autograd.
You can use any of the Tensor operations in the forward function.
The learnable parameters of a model are returned by net.parameters():
params = list(net.parameters())
print(len(params))
print(params[0].size())
Out:
16
torch.Size([6, 1, 5, 5])
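For comparison, here is a sketch that builds plain PyTorch (non-Bayesian) counterparts of the four layers directly with torch.nn rather than borch.nn. Each plain layer contributes one weight tensor and one bias tensor. The printed module above suggests that borch's Automatic posterior instead keeps a loc and a scale for every weight and bias, which would account for the 16 tensors reported by net.parameters().

```python
import torch
from torch import nn

# Plain-PyTorch (non-Bayesian) counterparts of the four layers in Net
layers = nn.ModuleList([
    nn.Conv2d(1, 6, 5),
    nn.Conv2d(6, 16, 5),
    nn.Linear(16 * 5 * 5, 120),
    nn.Linear(120, 10),
])

torch_params = list(layers.parameters())
print(len(torch_params))        # 8: one weight and one bias per layer
print(torch_params[0].size())   # torch.Size([6, 1, 5, 5]), the conv1 weight
# Total number of scalar weights and biases
print(sum(p.numel() for p in torch_params))  # 51902
```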
Let's try a random 32x32 input:
input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)
Out:
tensor([[-0.3673, 0.0775, -0.4730, 0.1129, -0.8138, -0.3398, 0.2442, 0.6617,
-0.4352, -0.4799]], grad_fn=<AddmmBackward0>)
Zero the gradient buffers of all parameters and backpropagate with random gradients:
net.zero_grad()
out.backward(torch.randn(1, 10))
Note
borch.nn only supports mini-batches. The entire borch.nn
package only supports inputs that are a mini-batch of samples, and not
a single sample.
For example, nn.Conv2d will take in a 4D Tensor of
nSamples x nChannels x Height x Width.
If you have a single sample, just use input.unsqueeze(0) to add
a fake batch dimension.
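A minimal sketch of adding the fake batch dimension, using plain torch tensors and torch.nn.Conv2d as a stand-in for borch's nn.Conv2d (the shapes match the single-channel 32x32 input used above):

```python
import torch

# A single 1-channel 32x32 image without a batch dimension: (C, H, W)
single = torch.randn(1, 32, 32)

# unsqueeze(0) inserts a new leading dimension of size 1: (N, C, H, W)
batch = single.unsqueeze(0)
print(single.shape)  # torch.Size([1, 32, 32])
print(batch.shape)   # torch.Size([1, 1, 32, 32])

# The batched tensor can now go through a 5x5 convolution like conv1
conv = torch.nn.Conv2d(1, 6, 5)
out = conv(batch)
print(out.shape)     # torch.Size([1, 6, 28, 28]), since 32 - 5 + 1 = 28
```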
Before proceeding further, let’s recap all the classes you’ve seen so far.
- Recap:
torch.Tensor - A multi-dimensional array with support for autograd operations like backward(). Also holds the gradient w.r.t. the tensor.
nn.Module - Neural network module. Convenient way of encapsulating parameters, with helpers for moving them to GPU, exporting, loading, etc.
nn.Parameter - A kind of Tensor that is automatically registered as a parameter when assigned as an attribute to a Module.
autograd.Function - Implements forward and backward definitions of an autograd operation. Every Tensor operation creates at least a single Function node that connects to the functions that created a Tensor and encodes its history.
- At this point, we covered:
Defining a neural network
Processing inputs and calling backward
- Still Left:
Computing the loss
Updating the weights of the network
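The nn.Parameter and autograd.Function entries of the recap can be seen in a few lines of plain PyTorch. In this sketch, Scale is a hypothetical module defined only for illustration, and it subclasses torch.nn.Module rather than borch's nn.Module:

```python
import torch
from torch import nn


class Scale(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigning an nn.Parameter as an attribute registers it automatically
        self.weight = nn.Parameter(torch.ones(3))
        # A plain tensor attribute is NOT registered as a parameter
        self.offset = torch.zeros(3)

    def forward(self, x):
        return self.weight * x + self.offset


m = Scale()
names = [name for name, _ in m.named_parameters()]
print(names)  # ['weight']

y = m(torch.randn(3)).sum()
# Every operation records a Function node; grad_fn points at the last one
print(type(y.grad_fn).__name__)  # e.g. SumBackward0
y.backward()
print(m.weight.grad.shape)  # torch.Size([3])
```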
Loss Function
A loss function takes the (output, target) pair of inputs, and computes a value that estimates how far away the output is from the target.
There are several different loss functions under the nn package.
A simple loss is nn.MSELoss, which computes the mean-squared error
between the input and the target. Such losses, however, only correspond
to a maximum likelihood approach: they yield a single point estimate of
each weight.
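The link to maximum likelihood can be checked numerically: with unit variance, the negative log-likelihood of the target under a Gaussian N(output, 1) equals half the squared error plus the constant 0.5 * log(2π), so minimizing nn.MSELoss maximizes that likelihood. A small sketch with torch.distributions, where random tensors stand in for real outputs and targets:

```python
import math

import torch
from torch import nn
from torch.distributions import Normal

torch.manual_seed(0)
output = torch.randn(10)
target = torch.randn(10)

mse = nn.MSELoss()(output, target)
# Mean negative log-likelihood of the target under N(output, 1)
nll = -Normal(output, 1.0).log_prob(target).mean()

# With unit variance: NLL = 0.5 * MSE + 0.5 * log(2 * pi)
matches = torch.allclose(nll, 0.5 * mse + 0.5 * math.log(2 * math.pi))
print(matches)  # True
```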
In order to infer the posterior of the weights, and thus also capture
their uncertainty, we have to use the infer package. In this example we
will use the infer.vi_loss function, which automatically constructs a
suitable loss for variational inference given the latent variables in your model.
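To give an idea of what a variational-inference loss involves, here is a minimal single-sample negative ELBO (expected negative log-likelihood plus the KL divergence from the approximating distribution to the prior) for one Gaussian weight, written with plain torch.distributions. It illustrates the general technique under toy assumptions (made-up data, a fixed noise scale of 0.1, an N(0, 1) prior) and is not borch's vi_loss implementation.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
# Toy data generated with a true weight of 2
x = torch.randn(50)
y = 2.0 * x + 0.1 * torch.randn(50)

# Approximating distribution q(w) = N(loc, exp(log_scale)); prior p(w) = N(0, 1)
loc = torch.zeros((), requires_grad=True)
log_scale = torch.zeros((), requires_grad=True)
prior = Normal(torch.tensor(0.0), torch.tensor(1.0))


def negative_elbo():
    q = Normal(loc, log_scale.exp())
    w = q.rsample()  # reparameterised sample, so gradients reach loc and log_scale
    log_lik = Normal(w * x, 0.1).log_prob(y).sum()
    kl = kl_divergence(q, prior)
    return -(log_lik - kl)


optimizer = torch.optim.Adam([loc, log_scale], lr=0.05)
for _ in range(500):
    optimizer.zero_grad()
    loss = negative_elbo()
    loss.backward()
    optimizer.step()

print(loc.item(), log_scale.exp().item())  # loc should end up close to 2.0
```

Unlike a point estimate, the fitted log_scale also gives a (small) spread for the weight, which is the uncertainty the paragraph above refers to.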
Similar to how it’s done for random variables, we can also call observe on the module, using keyword arguments that match the names of the random variables we want to observe. This adds those random variables to the likelihood term, and we will not infer a distribution over them. For example:
target = torch.randint(10, (1,)) # a dummy target, for example
net.observe(classification=target)
borch.sample(net)
output = net(input)
loss = infer.vi_loss(**borch.pq_to_infer(net))
print(loss)
Out:
tensor(9575.6943, grad_fn=<AddBackward0>)
Now, if you would follow loss in the backward direction you will see a graph of
computations that looks like this:
input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d
-> view -> linear -> relu -> linear ->
-> loss
So, when we call loss.backward(), the whole graph is differentiated
w.r.t. the loss, and all Tensors in the graph that have requires_grad=True
will have their .grad Tensor accumulated with the gradient.
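To inspect such a graph yourself, you can start from a tensor's grad_fn attribute and follow its next_functions. Below is a minimal sketch using a small plain-PyTorch model instead of the borch network (the layer sizes are arbitrary, and walk is a hypothetical helper); the exact node names printed depend on the PyTorch version.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
fc1 = nn.Linear(4, 3)
fc2 = nn.Linear(3, 2)
x = torch.randn(1, 4)
target = torch.randn(1, 2)

loss = F.mse_loss(fc2(F.relu(fc1(x))), target)


def walk(fn, depth=0, names=None):
    """Recursively print the autograd nodes that produced a tensor."""
    names = [] if names is None else names
    if fn is not None:  # inputs that do not require grad show up as None
        names.append(type(fn).__name__)
        print("  " * depth + names[-1])
        for next_fn, _ in fn.next_functions:
            walk(next_fn, depth + 1, names)
    return names


names = walk(loss.grad_fn)
# AccumulateGrad nodes are where gradients land in the leaf parameters
print(names.count("AccumulateGrad"))  # 4: two weights and two biases
```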
Backprop
To backpropagate the error, all we have to do is call loss.backward().
You need to clear the existing gradients first, though; otherwise the new
gradients will be added to the existing ones.
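The accumulation behaviour is easy to see on a single leaf tensor. A minimal sketch in plain PyTorch, where w is a hypothetical stand-in for a network parameter:

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

# d/dw of sum(3 * w) is 3 for every element
(3 * w).sum().backward()
print(w.grad)  # tensor([3., 3.])

# A second backward without clearing adds to the stored gradient
(3 * w).sum().backward()
print(w.grad)  # tensor([6., 6.])

# Clearing the buffer first (the role net.zero_grad() plays for every parameter)
w.grad.zero_()
(3 * w).sum().backward()
print(w.grad)  # tensor([3., 3.])
```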
Now we shall call loss.backward() and have a look at the gradients of conv1’s bias
before and after the backward pass.
net.zero_grad() # zeroes the gradient buffers of all parameters
The gradient of the loc parameter of the approximating distribution of
conv1.bias after zeroing the gradients is:
print(net.conv1.posterior.bias.loc.grad)
loss.backward()
Out:
tensor([0., 0., 0., 0., 0., 0.])
After calling backward, the gradient is:
print(net.conv1.posterior.bias.loc.grad)
Out:
tensor([-0.5059, 0.1341, 0.3560, -0.1543, -0.8424, 0.7598])
The only thing left to learn is:
Updating the weights of the network
Update the weights
The simplest update rule used in practice is Stochastic Gradient Descent (SGD):
weight = weight - learning_rate * gradient
We can implement this using simple Python code:
learning_rate = 0.01
for f in net.parameters():
f.data.sub_(f.grad.data * learning_rate)
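The same rule works on any tensor that requires gradients. Here is a minimal sketch on a toy one-parameter regression problem (the data and the names w, x, y are illustrative assumptions); it performs the update inside torch.no_grad(), the usual modern alternative to modifying .data:

```python
import torch

torch.manual_seed(0)
x = torch.randn(100)
y = 3.0 * x  # the true weight is 3

w = torch.zeros((), requires_grad=True)
learning_rate = 0.1

for step in range(100):
    loss = ((w * x - y) ** 2).mean()
    loss.backward()
    with torch.no_grad():
        # weight = weight - learning_rate * gradient
        w -= learning_rate * w.grad
    w.grad.zero_()  # clear the gradient before the next step

print(w.item())  # close to 3.0
```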
However, as you use neural networks, you will want to use various different
update rules such as SGD, Nesterov-SGD, Adam, RMSProp, etc.
To enable this, PyTorch provides a small package, torch.optim, that
implements all these methods. Using it is very simple:
import torch.optim as optim
# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)
# in your training loop:
n_batch_epoch = 10  # number of batches per epoch, usually len(dataloader)
optimizer.zero_grad() # zero the gradient buffers
borch.sample(net)
output = net(input)
loss = infer.vi_loss(**borch.pq_to_infer(net), kl_scaling=1 / n_batch_epoch)
loss.backward()
optimizer.step() # Does the update
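The kl_scaling=1 / n_batch_epoch argument can be motivated with a little arithmetic. Assuming kl_scaling simply multiplies the KL (prior) term of the loss, scaling it by 1 / n_batch_epoch means that, summed over one epoch, the KL term is counted exactly once, as in the full-data objective, while the data term is summed over every batch. The numbers below are hypothetical:

```python
# Hypothetical per-batch negative log-likelihoods and a single KL value
batch_nll = [12.0, 9.5, 11.0, 10.5]
kl = 400.0
n_batch_epoch = len(batch_nll)

# Per-batch loss with the KL term scaled by 1 / n_batch_epoch
batch_losses = [nll + kl / n_batch_epoch for nll in batch_nll]

# Summed over the epoch, the KL term appears exactly once
epoch_loss = sum(batch_losses)
full_loss = sum(batch_nll) + kl
print(epoch_loss, full_loss)  # 443.0 443.0
```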
Exercises
The neural network package contains various modules and loss functions that form the building blocks of deep neural networks. Have a look at the documentation to see what is available.
Try designing your own feed-forward networks with two different types of non-linearities, e.g. relu.
Total running time of the script: ( 0 minutes 0.144 seconds)