Note
Click here to download the full example code
Training an Image Classifier
You will learn the basics of how to create an image classifier using the borch.nn package and fit it using the infer package.
Let's start off by importing what we need.
import torch
from torch.utils.data import TensorDataset, DataLoader
import borch
from borch import infer, distributions
import torch.nn.functional as F
The module borch.nn provides implementations of neural network modules for
deep probabilistic programming. Its interface is almost identical to that of
the torch.nn modules, and in many cases it is possible to simply switch
from torch import nn
to
from borch import nn
Data
In this example we use simulated data and do not run the fitting until convergence;
instead we show how the model is set up and how one can construct the training loop.
We simply generate some random data, where `data` represents the images and
`labels` the classes.
# Twenty random single-channel 32x32 "images".
data = torch.randn((20, 1, 32, 32))
# A random ordering of the two classes {0, 1}, tiled so every image gets a label.
labels = torch.randperm(2).repeat(data.shape[0] // 2)
data_set = TensorDataset(data, labels)
# The whole (tiny) dataset is served as one single batch.
loader = DataLoader(dataset=data_set, batch_size=len(data_set))
Model
Let's set up the model.
To use infer and borch to the fullest, we need to select a
likelihood distribution. For classification, distributions.Categorical
is suitable.
class Net(borch.Module):
    """LeNet-style convolutional classifier with two output classes.

    All layers come from ``borch.nn`` and the module is constructed with an
    ``Automatic`` posterior. The printed model shows that each layer's weight
    and bias carry a ``Normal`` posterior and a ``Normal`` prior.
    ``forward`` returns a ``Categorical`` likelihood over the two classes
    rather than raw logits.

    The input size of ``fc1`` (``16 * 5 * 5``) assumes 1x32x32 images:
    32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5.
    """

    def __init__(self):
        super().__init__(posterior=borch.posterior.Automatic())
        # 1 input image channel, 6 output channels, 5x5 square convolution kernel.
        self.conv1 = nn.Conv2d(1, 6, 5)
        # 6 input channels, 16 output channels, 5x5 square convolution kernel.
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Affine operations: y = Wx + b.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 2)

    def forward(self, x):
        """Compute the class likelihood for a batch of images.

        Args:
            x: image batch; assumed to be (batch, 1, 32, 32) as in the
                simulated data above (TODO confirm for other inputs).

        Returns:
            ``distributions.Categorical`` parameterised by the two-class logits.
        """
        # Max pooling over a (2, 2) window.
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # For a square window a single number is enough.
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Specify the likelihood. It is stored as an attribute, presumably so
        # borch can track it as the model's likelihood / observed variable
        # (NOTE(review): confirm against the borch documentation).
        self.classification = distributions.Categorical(logits=x)
        return self.classification

    def num_flat_features(self, x):
        """Return the product of all non-batch dimensions of ``x``.

        This is the number of features per sample after flattening. It is 1
        when ``x`` has only a batch dimension.
        """
        return x.size()[1:].numel()
# Build the model and show its structure: every layer is listed together with
# the posterior, prior and observed containers attached to it (see output below).
net = Net()
print(str(net))
Out:
Net(
(posterior): Automatic()
(prior): Module()
(observed): Observed()
(conv1): Conv2d(
1, 6, kernel_size=(5, 5), stride=(1, 1)
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[[[-2.3901e-02, 1.7009e-01, -4.0790e-02, 1.9047e-01, 1.1854e-01],
[-1.0204e-01, -1.4715e-01, 9.0185e-03, 8.0316e-02, -1.4466e-01],
[ 1.7187e-01, 8.4639e-02, 1.6665e-01, -1.5608e-01, -9.5238e-02],
[ 1.9682e-01, -9.3068e-02, -1.9482e-02, 1.6220e-01, -5.5923e-02],
[-1.5368e-01, -1.2494e-01, 3.1866e-02, -1.7428e-01, -6.5961e-02]]],
[[[-5.4801e-02, 5.0452e-02, -1.5372e-01, -1.1482e-01, 1.5138e-01],
[ 1.8247e-03, 4.6463e-02, -1.7931e-01, -8.4841e-02, -6.3566e-02],
[-1.9791e-02, -1.0920e-01, 1.2796e-01, -6.0495e-02, 1.2142e-01],
[-1.0610e-01, 2.8335e-02, 2.0862e-02, -1.2132e-01, -5.6049e-03],
[ 2.7007e-02, 1.5627e-01, 7.0422e-02, -2.1336e-03, -5.9226e-02]]],
[[[-1.6497e-01, -9.5347e-02, 9.7235e-02, 1.7565e-01, -1.4118e-01],
[-1.1203e-02, -5.6668e-02, 9.0249e-02, 1.9961e-01, -2.0049e-02],
[-4.5493e-02, -2.0235e-02, -1.9463e-01, 1.5131e-01, 1.6076e-01],
[-1.9071e-01, -1.6333e-01, 1.0380e-01, -7.2042e-02, 1.0249e-01],
[ 1.7660e-01, 1.8708e-02, -1.0379e-01, 8.4113e-02, -1.3492e-01]]],
[[[ 1.3284e-01, -1.7679e-02, -9.9538e-02, 1.5133e-01, 1.0864e-01],
[-1.9522e-01, 1.0066e-01, -1.0742e-01, -1.1599e-01, 1.6930e-01],
[-1.0281e-01, -1.4473e-01, 1.6300e-01, -7.4540e-02, -6.5797e-02],
[ 1.5015e-01, 7.7701e-03, 7.3404e-02, 9.5653e-02, -1.2661e-01],
[-3.2228e-02, -7.9872e-02, 1.9932e-01, 6.0159e-02, 1.3894e-01]]],
[[[-1.7235e-01, 5.8651e-06, 8.9371e-02, -1.5355e-01, -1.2702e-01],
[ 5.9223e-02, 1.1539e-01, -5.5243e-03, -6.4484e-02, -1.3380e-01],
[ 9.6366e-03, -1.0979e-01, -1.1570e-01, 7.1673e-03, 8.9918e-02],
[ 2.4720e-02, 5.8142e-02, -1.0872e-01, -1.4363e-01, -1.1776e-01],
[-8.8460e-02, -1.7740e-01, -7.1380e-02, -1.1692e-01, -1.7076e-02]]],
[[[-1.8614e-01, -1.2378e-01, -1.3271e-01, -1.5860e-02, -9.4571e-02],
[-7.8788e-03, 4.7546e-02, 1.4185e-01, 8.6187e-02, -1.0654e-01],
[-6.4892e-02, -7.4628e-02, 5.9973e-02, 5.0245e-02, 1.1456e-01],
[-4.4647e-02, 9.8620e-02, -1.2445e-01, -6.6966e-02, 3.5321e-02],
[ 1.9966e-01, -3.8652e-03, 3.5615e-03, -1.3894e-02, 3.5925e-02]]]],
requires_grad=True)
tensor: tensor([[[[-0.0311, 0.0859, -0.0472, 0.2721, -0.0104],
[-0.0725, -0.1375, 0.0592, 0.0649, -0.1419],
[ 0.1607, 0.0185, 0.2350, -0.1482, -0.0662],
[ 0.2044, -0.0331, 0.0378, 0.2124, -0.0684],
[-0.1102, -0.1244, 0.0971, -0.1783, -0.1476]]],
[[[ 0.0512, 0.0096, -0.1303, -0.0858, 0.1533],
[-0.0701, 0.0477, -0.0650, -0.0767, -0.1122],
[-0.0093, -0.0608, 0.1847, -0.0801, 0.1446],
[-0.1251, 0.0316, 0.0093, -0.1505, -0.0172],
[-0.0315, 0.0236, 0.1091, 0.0313, -0.0320]]],
[[[-0.1591, -0.1604, 0.0829, 0.1024, -0.1755],
[-0.0662, -0.1338, 0.0292, 0.2482, -0.0993],
[ 0.0087, 0.0244, -0.2304, 0.2126, 0.1096],
[-0.1508, -0.0875, 0.1158, -0.0840, 0.1429],
[ 0.1580, 0.0260, -0.0853, 0.0094, -0.1336]]],
[[[ 0.1035, -0.0026, -0.0805, 0.1173, 0.0393],
[-0.2130, 0.0323, -0.1763, -0.1442, 0.1475],
[-0.1606, -0.0697, 0.1412, -0.1098, -0.0839],
[ 0.1098, -0.0369, 0.0569, 0.1697, -0.0874],
[-0.0676, -0.1037, 0.1526, 0.1673, 0.0780]]],
[[[-0.1714, 0.0032, 0.0813, -0.1082, -0.0881],
[-0.0233, 0.1377, -0.0773, -0.0227, -0.0461],
[ 0.0277, -0.1762, -0.0500, -0.0424, 0.0864],
[ 0.0196, 0.0805, -0.1020, -0.2220, -0.1493],
[-0.0397, -0.2590, -0.0630, -0.1079, -0.0202]]],
[[[-0.1821, -0.1409, -0.1910, -0.0148, -0.0221],
[ 0.0964, 0.1303, 0.2632, 0.0626, -0.0929],
[-0.1410, -0.0631, 0.0544, 0.0835, 0.0102],
[-0.0231, 0.0293, -0.1860, 0.0420, 0.1065],
[ 0.1520, -0.0128, -0.0434, 0.0051, 0.0573]]]],
grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([-0.1531, -0.0004, 0.1924, -0.0493, 0.0953, 0.0265],
requires_grad=True)
tensor: tensor([-0.1049, 0.0772, 0.1882, -0.0163, 0.1152, 0.1543],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[[[-0., 0., -0., 0., 0.],
[-0., -0., 0., 0., -0.],
[0., 0., 0., -0., -0.],
[0., -0., -0., 0., -0.],
[-0., -0., 0., -0., -0.]]],
[[[-0., 0., -0., -0., 0.],
[0., 0., -0., -0., -0.],
[-0., -0., 0., -0., 0.],
[-0., 0., 0., -0., -0.],
[0., 0., 0., -0., -0.]]],
[[[-0., -0., 0., 0., -0.],
[-0., -0., 0., 0., -0.],
[-0., -0., -0., 0., 0.],
[-0., -0., 0., -0., 0.],
[0., 0., -0., 0., -0.]]],
[[[0., -0., -0., 0., 0.],
[-0., 0., -0., -0., 0.],
[-0., -0., 0., -0., -0.],
[0., 0., 0., 0., -0.],
[-0., -0., 0., 0., 0.]]],
[[[-0., 0., 0., -0., -0.],
[0., 0., -0., -0., -0.],
[0., -0., -0., 0., 0.],
[0., 0., -0., -0., -0.],
[-0., -0., -0., -0., -0.]]],
[[[-0., -0., -0., -0., -0.],
[-0., 0., 0., 0., -0.],
[-0., -0., 0., 0., 0.],
[-0., 0., -0., -0., 0.],
[0., -0., 0., -0., 0.]]]])
scale: tensor([[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[[[-2.3901e-02, 1.7009e-01, -4.0790e-02, 1.9047e-01, 1.1854e-01],
[-1.0204e-01, -1.4715e-01, 9.0185e-03, 8.0316e-02, -1.4466e-01],
[ 1.7187e-01, 8.4639e-02, 1.6665e-01, -1.5608e-01, -9.5238e-02],
[ 1.9682e-01, -9.3068e-02, -1.9482e-02, 1.6220e-01, -5.5923e-02],
[-1.5368e-01, -1.2494e-01, 3.1866e-02, -1.7428e-01, -6.5961e-02]]],
[[[-5.4801e-02, 5.0452e-02, -1.5372e-01, -1.1482e-01, 1.5138e-01],
[ 1.8247e-03, 4.6463e-02, -1.7931e-01, -8.4841e-02, -6.3566e-02],
[-1.9791e-02, -1.0920e-01, 1.2796e-01, -6.0495e-02, 1.2142e-01],
[-1.0610e-01, 2.8335e-02, 2.0862e-02, -1.2132e-01, -5.6049e-03],
[ 2.7007e-02, 1.5627e-01, 7.0422e-02, -2.1336e-03, -5.9226e-02]]],
[[[-1.6497e-01, -9.5347e-02, 9.7235e-02, 1.7565e-01, -1.4118e-01],
[-1.1203e-02, -5.6668e-02, 9.0249e-02, 1.9961e-01, -2.0049e-02],
[-4.5493e-02, -2.0235e-02, -1.9463e-01, 1.5131e-01, 1.6076e-01],
[-1.9071e-01, -1.6333e-01, 1.0380e-01, -7.2042e-02, 1.0249e-01],
[ 1.7660e-01, 1.8708e-02, -1.0379e-01, 8.4113e-02, -1.3492e-01]]],
[[[ 1.3284e-01, -1.7679e-02, -9.9538e-02, 1.5133e-01, 1.0864e-01],
[-1.9522e-01, 1.0066e-01, -1.0742e-01, -1.1599e-01, 1.6930e-01],
[-1.0281e-01, -1.4473e-01, 1.6300e-01, -7.4540e-02, -6.5797e-02],
[ 1.5015e-01, 7.7701e-03, 7.3404e-02, 9.5653e-02, -1.2661e-01],
[-3.2228e-02, -7.9872e-02, 1.9932e-01, 6.0159e-02, 1.3894e-01]]],
[[[-1.7235e-01, 5.8651e-06, 8.9371e-02, -1.5355e-01, -1.2702e-01],
[ 5.9223e-02, 1.1539e-01, -5.5243e-03, -6.4484e-02, -1.3380e-01],
[ 9.6366e-03, -1.0979e-01, -1.1570e-01, 7.1673e-03, 8.9918e-02],
[ 2.4720e-02, 5.8142e-02, -1.0872e-01, -1.4363e-01, -1.1776e-01],
[-8.8460e-02, -1.7740e-01, -7.1380e-02, -1.1692e-01, -1.7076e-02]]],
[[[-1.8614e-01, -1.2378e-01, -1.3271e-01, -1.5860e-02, -9.4571e-02],
[-7.8788e-03, 4.7546e-02, 1.4185e-01, 8.6187e-02, -1.0654e-01],
[-6.4892e-02, -7.4628e-02, 5.9973e-02, 5.0245e-02, 1.1456e-01],
[-4.4647e-02, 9.8620e-02, -1.2445e-01, -6.6966e-02, 3.5321e-02],
[ 1.9966e-01, -3.8652e-03, 3.5615e-03, -1.3894e-02, 3.5925e-02]]]])
(bias): Normal:
loc: tensor([-0., -0., 0., -0., 0., 0.])
scale: tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([-0.1531, -0.0004, 0.1924, -0.0493, 0.0953, 0.0265])
)
(observed): Observed()
)
(conv2): Conv2d(
6, 16, kernel_size=(5, 5), stride=(1, 1)
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
...,
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],
[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],
[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[[[-2.9409e-02, 6.0003e-02, 5.3245e-02, 2.8255e-02, 5.1865e-02],
[-1.3120e-03, 6.9924e-02, -2.1958e-02, 4.5196e-02, 5.5987e-02],
[ 7.2948e-02, -1.2874e-02, -1.4010e-02, -7.2656e-02, -3.8272e-02],
[-6.1391e-02, 2.6227e-02, 2.2223e-02, 5.1530e-02, 7.8572e-02],
[ 1.1986e-02, -3.5154e-02, 2.2590e-02, -2.8268e-02, -4.1092e-02]],
[[-7.7353e-02, -3.3001e-03, 2.2446e-02, 3.6329e-02, -8.4312e-03],
[ 5.9898e-02, 3.3921e-02, -4.0756e-02, 5.6990e-02, -7.9915e-02],
[ 1.1058e-02, -7.5468e-02, 3.1362e-02, -1.2730e-02, 7.0498e-02],
[-6.7726e-03, -3.9453e-02, 6.8676e-02, 2.8066e-02, 4.1992e-02],
[ 8.0445e-02, 6.9594e-02, 2.9287e-02, 4.4932e-02, -7.2432e-02]],
[[-6.8457e-02, -7.5792e-02, 3.8843e-02, 3.2058e-02, 2.1625e-02],
[ 7.5601e-02, -5.7455e-02, 1.2939e-02, -3.5165e-02, -5.6556e-02],
[ 8.3346e-03, 2.2498e-02, 6.8399e-03, -2.3131e-02, -2.1235e-03],
[-6.8738e-02, -6.5464e-02, -4.5512e-02, -2.1878e-02, -6.3732e-02],
[ 4.4843e-02, 1.4243e-02, -7.8475e-02, 2.0007e-02, 6.8106e-02]],
[[-7.1642e-03, -4.0418e-02, 1.6455e-03, -7.4808e-03, 3.0932e-02],
[-5.6909e-02, -6.4950e-02, -4.3916e-02, -6.4650e-02, 7.4918e-02],
[ 2.7178e-03, 7.6559e-02, 2.3546e-02, 1.0888e-02, -3.2943e-02],
[-3.1639e-02, -4.4183e-02, -1.1435e-02, 1.8024e-02, 7.9902e-02],
[-1.6372e-02, 6.6337e-03, -5.5328e-02, 1.6804e-02, -1.1563e-02]],
[[ 3.4681e-02, 3.3633e-02, -5.8334e-02, -4.5784e-02, -2.0550e-02],
[ 8.0547e-02, 4.9685e-02, 7.5298e-02, -2.9211e-02, 2.2224e-02],
[-2.4409e-02, 7.0003e-02, 2.1558e-04, -4.4977e-02, 2.3133e-02],
[-7.4964e-02, -6.4867e-02, -4.5671e-02, 3.1878e-02, 3.5104e-02],
[ 2.2541e-02, 3.2949e-02, -3.5216e-02, -4.0980e-02, -1.9989e-02]],
[[-7.0267e-02, -2.8274e-03, -3.7755e-02, 3.8392e-02, -1.3160e-05],
[-2.4893e-02, 2.6938e-02, -6.6667e-02, -6.4968e-02, -7.5396e-02],
[ 4.5711e-02, -3.3377e-02, 2.4278e-02, -4.1101e-03, 1.8952e-02],
[ 2.0236e-02, 6.4522e-03, -5.7711e-02, -4.5909e-02, 7.4404e-02],
[-3.7771e-02, -3.1536e-03, 6.4748e-02, -4.4604e-02, -2.2752e-02]]],
[[[-2.9177e-02, 5.0791e-02, -5.5935e-03, 4.7169e-02, -1.1366e-02],
[ 2.2291e-02, -2.1153e-02, -1.4060e-02, -5.3690e-02, 6.6127e-02],
[-1.1869e-03, 7.9363e-03, 6.0124e-02, 8.7514e-04, -2.6833e-02],
[-4.4174e-02, 1.9975e-02, 2.7939e-02, 2.7791e-02, -1.7919e-02],
[-2.9191e-02, 2.5500e-02, 3.6869e-03, 3.1063e-02, -3.7672e-02]],
[[-6.7084e-02, 5.3586e-02, 1.8485e-02, -7.8099e-03, 4.6288e-02],
[ 4.9055e-02, -1.4634e-02, 3.2799e-02, 3.9726e-02, -7.0032e-02],
[ 7.0897e-02, 3.4370e-02, -1.4814e-02, -3.9030e-02, 3.0867e-02],
[ 3.5541e-02, -7.2574e-02, -1.5650e-03, -8.1162e-02, 2.6245e-02],
[ 5.5721e-02, -2.2033e-02, -7.2623e-02, -5.1459e-02, 3.3599e-02]],
[[ 1.1310e-02, -2.9816e-02, 6.3727e-02, 3.9850e-02, 1.3761e-02],
[ 3.0453e-02, -4.8504e-02, 5.3189e-02, 1.9425e-02, 4.7484e-02],
[ 6.1376e-02, -7.8290e-02, -6.8859e-02, 1.8497e-02, -1.1496e-02],
[-7.8178e-02, -4.5904e-02, 7.3181e-02, 2.9441e-02, 4.6967e-02],
[ 7.6978e-02, 7.2934e-02, 5.6798e-02, 5.8828e-02, 4.4637e-02]],
[[ 6.0281e-02, 7.7289e-02, 7.9016e-02, -4.1437e-02, 3.1101e-02],
[-5.0620e-02, 3.3108e-02, -5.8687e-02, 2.7694e-02, 5.4294e-02],
[ 2.1156e-02, 1.7004e-03, 2.4742e-02, 6.9593e-02, 5.7699e-02],
[ 6.8876e-02, 3.2239e-02, 3.3322e-02, 2.9973e-02, 7.4267e-02],
[-2.4736e-03, 3.8454e-02, -2.5898e-02, 2.0443e-02, 6.0816e-02]],
[[ 1.6487e-02, -6.8994e-03, 5.2835e-02, 5.7784e-02, -1.7036e-02],
[-7.1883e-02, 3.1576e-04, -5.5839e-02, 1.4949e-02, -1.9834e-03],
[-2.7395e-02, 3.9861e-04, -1.0588e-02, 9.9140e-03, 5.1499e-02],
[ 7.4060e-02, -1.0655e-02, 1.1668e-02, 4.9183e-02, 5.2846e-02],
[ 2.8634e-02, 4.2678e-02, -1.4281e-02, 1.3904e-03, 7.6289e-02]],
[[-5.1256e-02, -2.2514e-02, -7.2964e-02, -4.4120e-02, -5.8914e-02],
[-4.1579e-02, 2.8281e-02, 3.9429e-02, 7.5058e-03, 6.5170e-03],
[-3.4494e-02, 7.5710e-02, 4.1078e-02, 4.4451e-02, 4.2661e-02],
[-5.4398e-02, 5.1592e-02, -2.6367e-02, -3.2980e-02, 5.3860e-02],
[-4.2436e-02, -8.4286e-03, 7.5331e-02, -6.6725e-02, 4.9887e-02]]],
[[[-1.5211e-02, -6.2506e-03, 3.0621e-03, 3.0725e-02, -7.0877e-03],
[ 1.1974e-02, -5.2611e-02, -2.7415e-02, 4.3479e-02, -4.2108e-02],
[ 3.3816e-02, 6.1523e-02, -9.9011e-03, -3.7770e-02, 6.5915e-04],
[ 5.3678e-03, 5.9921e-02, -3.4530e-02, 5.1942e-02, 5.3762e-02],
[-4.7293e-02, -6.2274e-02, -7.5059e-02, 8.1645e-02, 2.1149e-02]],
[[-1.4459e-02, -2.7155e-02, -2.5730e-02, 7.6751e-02, -1.6932e-02],
[-5.3342e-02, -2.6885e-02, 4.3476e-02, -7.9174e-02, -3.5761e-02],
[ 4.2970e-02, 2.5516e-02, -6.6640e-02, -2.9457e-03, -8.2757e-03],
[-2.5080e-02, 4.1672e-02, 4.2424e-02, -4.8704e-02, -6.0434e-02],
[ 1.2884e-02, -7.9950e-02, -7.0913e-02, -8.0863e-02, -5.4536e-02]],
[[-5.4303e-02, 6.7885e-02, -5.3922e-02, 6.5582e-02, -5.2617e-03],
[-8.4440e-03, 8.0911e-02, -3.8667e-02, -5.6241e-04, 7.0876e-02],
[ 5.4673e-02, 1.3465e-02, -2.7178e-02, -3.6691e-02, -2.6519e-02],
[-2.8238e-02, -5.0765e-02, -3.4076e-02, 3.1219e-02, -1.8919e-02],
[-4.6076e-02, -7.6516e-02, 3.1247e-02, 4.1743e-02, 7.5575e-02]],
[[ 3.8787e-03, -8.1731e-03, 4.7381e-03, 5.8261e-02, 4.6416e-02],
[ 7.7171e-02, -6.5924e-02, 1.5769e-02, 2.6777e-02, 7.7365e-02],
[-7.3126e-02, 4.7624e-02, -7.0620e-02, 7.3309e-02, 7.7585e-03],
[-7.9208e-02, -7.3783e-03, -5.3142e-02, 3.4386e-02, 5.6230e-03],
[-4.4492e-02, 7.5403e-02, -2.8887e-02, 2.6937e-02, 7.0698e-03]],
[[-1.7104e-02, -6.6983e-02, -5.7655e-02, 5.1198e-02, -8.0137e-02],
[-7.2406e-02, -3.7856e-02, -7.2086e-02, 7.2641e-02, -7.0749e-02],
[ 6.6330e-02, 2.8797e-02, 6.2197e-02, -4.6068e-03, 1.7392e-03],
[ 4.2384e-02, 5.9572e-02, -4.5953e-02, 6.6345e-03, 7.3979e-02],
[-4.8313e-02, -1.3063e-02, 1.7648e-02, -5.0903e-02, -7.1852e-02]],
[[ 3.2147e-02, -8.1585e-02, 6.7507e-02, -7.7056e-02, 1.7667e-02],
[ 2.0535e-03, -7.4221e-02, 1.0343e-02, 4.3018e-02, 9.7351e-03],
[-2.1410e-02, 5.4089e-02, -2.4102e-02, -4.0551e-02, -3.6118e-03],
[ 4.9847e-02, 6.9608e-02, 3.6233e-03, 5.7025e-02, 6.3206e-02],
[ 1.4611e-02, -2.9885e-02, 5.6140e-02, -6.4338e-02, 8.5266e-03]]],
...,
[[[ 5.7051e-02, 3.4026e-02, -3.7723e-02, 1.4372e-02, -4.4266e-03],
[-8.0557e-02, 1.1810e-02, -6.9374e-02, 3.4264e-02, -3.9068e-02],
[ 3.2814e-02, -4.9334e-02, -3.2234e-02, 3.7901e-02, 9.9268e-03],
[ 1.2846e-03, -5.9199e-02, -5.6303e-02, 1.2189e-03, 7.8874e-02],
[ 7.6858e-04, 2.4341e-02, 4.0423e-02, -7.7602e-02, -3.6388e-02]],
[[ 5.9494e-02, 4.4230e-02, -5.9128e-02, -1.6639e-02, -6.4884e-02],
[-2.3457e-02, -7.5842e-03, -3.3986e-02, -2.0435e-02, -4.2466e-02],
[-6.8915e-02, 1.6417e-02, -9.0384e-03, -5.6058e-02, 1.1540e-02],
[ 3.9632e-02, 3.8881e-02, 5.5834e-02, 7.5591e-02, 1.8463e-02],
[ 4.5034e-02, -6.4665e-02, 6.7883e-02, 7.1108e-02, 8.0694e-02]],
[[-7.4652e-02, -2.9270e-02, -7.4301e-02, -1.4067e-02, -6.0331e-02],
[-7.9629e-02, -2.5316e-03, -3.4649e-02, 7.9736e-02, 2.4963e-02],
[-1.4102e-02, 3.0896e-02, -5.4594e-02, 5.7641e-02, 7.8276e-02],
[ 3.3722e-02, 1.6397e-02, 6.6251e-02, 2.5637e-02, 4.0073e-02],
[ 1.9370e-02, -1.4960e-02, -4.0503e-02, -3.6491e-02, -6.9970e-02]],
[[ 3.9434e-02, 6.7049e-02, 7.1627e-02, 6.9307e-02, -5.7508e-03],
[ 2.8151e-02, 7.9890e-02, -6.4687e-02, -6.8959e-02, 6.8179e-02],
[ 1.2583e-02, 6.6052e-02, 6.7770e-02, 1.0853e-02, 6.3935e-02],
[ 4.4214e-02, -5.4527e-02, -6.3199e-02, -2.4454e-02, -8.0348e-02],
[-1.1810e-04, 6.2292e-02, -2.1831e-02, -4.1282e-02, 3.4718e-02]],
[[-8.9495e-03, -3.5923e-02, -4.9030e-02, 1.7068e-02, 5.7835e-02],
[-6.2950e-02, 6.9258e-02, 1.4909e-02, -3.9252e-02, 3.0917e-02],
[-5.0831e-02, -2.6109e-02, -4.2526e-02, 4.9180e-03, -6.7907e-02],
[-1.4867e-02, 8.3498e-03, -6.3780e-02, -6.3819e-02, -7.7414e-02],
[ 6.5369e-02, 3.5118e-02, -3.5070e-02, 3.1514e-02, -1.7773e-02]],
[[-1.9000e-02, 4.8772e-02, -4.0550e-02, 5.7766e-02, -4.8687e-02],
[ 7.0112e-02, 7.4851e-02, -5.0324e-02, 4.2522e-02, 6.6367e-02],
[-6.6793e-02, -6.3487e-02, -6.3574e-02, 7.3530e-02, -6.7062e-02],
[ 1.9297e-02, 3.9876e-02, 7.0333e-03, 3.6541e-02, 3.0865e-02],
[ 6.9009e-02, 2.7737e-03, -6.0400e-02, 1.5249e-03, -1.5177e-03]]],
[[[-4.9098e-02, -1.2656e-02, 3.0326e-02, 4.6450e-02, 4.1143e-02],
[ 6.5180e-02, -4.5543e-02, -6.0194e-02, -8.1101e-02, 7.3691e-02],
[-5.2880e-02, -5.3283e-02, -4.6874e-02, 2.0506e-02, 1.4432e-02],
[ 5.3466e-05, 6.1875e-02, -5.2208e-02, -2.1149e-02, -6.5709e-02],
[-7.2209e-02, -2.8706e-02, 6.6109e-02, 5.8108e-02, -1.8114e-02]],
[[-5.8877e-02, 3.5183e-02, 6.5460e-02, 5.2934e-02, 3.5997e-02],
[-6.5718e-02, 2.7700e-02, 7.1110e-02, -5.7825e-02, 6.1866e-03],
[-5.3281e-03, -7.6189e-02, -6.9421e-02, 6.4743e-02, -1.1912e-02],
[ 7.6864e-02, -5.8819e-03, -2.0277e-02, -1.6263e-02, 4.5729e-02],
[ 2.0473e-03, -3.1893e-02, -3.0088e-02, 6.1322e-02, -1.3287e-02]],
[[ 4.0654e-02, -3.8251e-03, -5.8287e-02, -6.9760e-03, -4.9954e-02],
[-3.1949e-02, -6.5679e-02, -9.8746e-04, -5.5646e-02, -1.6937e-03],
[-5.0579e-02, 5.1921e-02, 4.0006e-02, -5.3846e-02, 3.6710e-03],
[-5.5284e-03, 5.2453e-02, 3.5617e-02, -4.4475e-02, 2.7835e-02],
[ 3.6465e-02, 2.2936e-02, 4.9494e-02, -6.8768e-02, -6.8512e-02]],
[[-1.5606e-02, -5.8101e-02, -4.8349e-02, -5.4572e-03, -8.1381e-02],
[ 3.3837e-02, 6.9886e-02, 2.5937e-03, -4.4428e-02, -6.1442e-03],
[-3.3799e-02, 7.6725e-02, 1.5202e-02, 2.7467e-02, -7.2112e-02],
[-5.3887e-02, 5.3134e-02, -5.5426e-02, 8.1476e-02, 1.0773e-02],
[-1.9578e-02, 1.7628e-02, -2.2382e-02, 6.7076e-02, -1.3475e-02]],
[[ 3.7281e-02, 2.7106e-02, -7.8289e-03, -6.1201e-02, -4.5366e-02],
[-5.1809e-02, -1.0889e-02, 4.4019e-02, -4.0099e-02, -6.2939e-02],
[ 7.8826e-02, 1.4336e-02, -7.8953e-02, -4.1699e-03, 2.1759e-02],
[ 4.3422e-02, 6.1053e-02, -5.1035e-02, 2.5170e-02, 8.1194e-02],
[-3.5907e-02, 3.5084e-02, 5.4858e-02, 5.7819e-02, -6.8527e-02]],
[[ 6.0340e-02, -4.5873e-02, 4.5307e-02, -1.8559e-02, -5.9891e-02],
[ 7.1101e-02, 5.7979e-03, -2.1455e-02, -5.7839e-02, -2.6964e-02],
[ 4.5972e-02, 4.6237e-02, -1.8353e-02, 5.5372e-03, 5.8802e-02],
[-8.0939e-02, 2.2098e-03, -2.7943e-03, 6.9556e-02, 3.5299e-03],
[-2.4275e-02, -6.1490e-02, -2.4350e-02, -5.8685e-02, -7.6820e-02]]],
[[[-5.8326e-02, 4.3804e-02, 5.4642e-02, 2.9479e-02, 5.5766e-02],
[-6.2955e-02, 4.9442e-02, -1.7882e-02, -6.4492e-02, -3.5590e-02],
[ 7.8974e-02, 1.8189e-02, -4.3076e-02, -4.6822e-02, -5.9352e-02],
[ 1.1472e-02, 6.9467e-02, -3.5045e-02, -1.3463e-03, -7.0617e-02],
[-5.7437e-02, -5.7150e-02, 4.9108e-02, 2.2168e-02, -5.4964e-02]],
[[-3.2895e-02, -2.2746e-03, 6.8428e-02, -7.4781e-02, 6.5675e-02],
[-8.0232e-02, -2.6468e-02, -2.1136e-02, 2.1449e-02, 6.4572e-02],
[ 2.9930e-03, 1.1987e-02, 4.8122e-03, 3.4183e-02, -7.8918e-02],
[ 6.3749e-02, -2.5083e-02, 1.1253e-02, -4.4485e-02, 3.3380e-02],
[ 5.0096e-03, -1.7321e-02, 8.0185e-02, -2.3853e-02, -2.9333e-03]],
[[ 2.6648e-02, -7.6799e-02, 3.2204e-03, -7.7476e-02, -4.4615e-03],
[ 5.7110e-02, 7.8575e-02, 5.3204e-02, -7.8592e-02, 4.1383e-03],
[ 1.6194e-02, 2.5400e-02, 7.4070e-02, -3.9092e-03, -2.9417e-02],
[-7.9407e-02, 2.5042e-02, -3.8854e-02, 2.8143e-02, 2.8485e-03],
[-3.3828e-02, -7.5645e-02, 7.8511e-02, -4.4048e-02, 6.0887e-02]],
[[-6.4552e-02, -3.1646e-02, 6.5499e-02, -6.8577e-02, -5.1529e-02],
[ 6.1176e-02, -4.8461e-02, 4.7687e-02, -3.0069e-02, -1.7665e-02],
[ 7.7632e-02, -1.7017e-02, -6.2812e-02, -1.8810e-02, -4.1500e-02],
[ 6.1360e-02, -1.9826e-02, -6.4593e-02, 3.5071e-02, -5.9178e-02],
[-6.6739e-02, 2.6098e-02, -5.5998e-02, 8.1334e-02, 3.7472e-02]],
[[-5.5207e-02, 1.4355e-02, -2.2037e-02, -2.4025e-02, 7.2631e-02],
[-1.0448e-02, 1.9105e-03, -5.5223e-02, 4.6377e-02, -6.8534e-02],
[-2.4292e-02, 7.5258e-02, -8.0224e-02, -6.6001e-02, -4.6628e-02],
[ 4.5334e-02, -2.3274e-02, -4.3572e-02, 4.3487e-03, -4.6057e-02],
[-5.3757e-02, -2.0336e-02, -5.2245e-02, 2.2213e-02, -6.7578e-03]],
[[ 5.7154e-02, 6.9033e-02, -2.7450e-02, -5.9039e-02, 3.0233e-02],
[ 5.5904e-02, 5.2798e-02, -2.2586e-02, 2.8411e-02, -6.8010e-03],
[ 5.1257e-02, -4.3710e-02, 8.7161e-03, 1.9411e-02, -3.5285e-03],
[-8.0450e-02, 6.1012e-02, -7.7756e-02, -2.1472e-02, 4.7537e-02],
[-4.7231e-02, 3.7300e-02, 2.7754e-02, -2.4025e-02, 1.0065e-02]]]],
requires_grad=True)
tensor: tensor([[[[-9.3244e-02, 7.2930e-02, -1.4004e-02, 4.0448e-02, 4.7394e-02],
[ 3.4764e-03, 3.3045e-03, -6.4835e-02, 4.5686e-02, 9.0246e-02],
[ 2.8020e-02, 1.8739e-02, -5.1935e-02, -3.1779e-02, -6.5367e-02],
[-1.0352e-01, 6.2800e-02, 3.3458e-02, 1.6324e-01, 7.5646e-02],
[-1.4166e-02, -7.0325e-02, -2.7509e-02, -1.0614e-02, -7.1825e-02]],
[[-1.2719e-01, 8.5667e-02, 3.7926e-02, 3.3971e-04, -1.0475e-02],
[ 7.9021e-02, 3.0413e-03, -1.6426e-01, 1.3341e-02, -2.5737e-02],
[ 2.6697e-02, -1.3385e-01, 6.9021e-03, 3.0326e-02, 6.2550e-02],
[-1.3212e-02, -6.8929e-02, 1.4790e-01, -2.7115e-02, 2.0803e-02],
[ 7.5669e-02, 1.1428e-01, -5.9129e-02, 7.8492e-02, -3.0813e-02]],
[[-6.6693e-02, -1.3390e-01, 3.4061e-02, -1.3415e-02, 6.6936e-03],
[ 8.5119e-02, -8.8495e-02, 5.1307e-03, 5.9909e-02, -7.1546e-02],
[-3.1164e-02, 2.3079e-02, 4.8668e-02, 2.0298e-02, -1.6369e-02],
[-1.4313e-01, -5.0021e-02, -5.2223e-02, 1.3387e-02, -1.0466e-01],
[ 5.5770e-02, 1.9499e-02, -1.1884e-03, 1.0549e-02, 7.3979e-02]],
[[-1.7652e-03, -8.5820e-03, 3.8955e-02, -5.1522e-04, 3.0524e-02],
[-7.6496e-02, -1.0952e-01, 1.2580e-02, -1.0383e-01, 1.3392e-02],
[ 5.0957e-02, 7.2177e-02, 6.7416e-02, 4.0250e-02, -6.1816e-02],
[ 7.9146e-02, -2.6732e-02, 3.3583e-02, -5.8278e-02, 1.7562e-01],
[-7.7750e-03, 1.5841e-02, -1.4416e-01, 7.1971e-02, -3.2296e-02]],
[[ 3.1765e-02, 7.7612e-03, -9.2562e-03, -7.5009e-02, -7.6771e-02],
[ 5.2192e-02, 8.9829e-02, 8.0666e-02, -7.5012e-02, 8.2851e-02],
[-2.9721e-02, 5.4033e-02, -5.8591e-02, -8.4672e-02, -3.5956e-02],
[ 4.6556e-03, -8.2645e-02, -2.9310e-02, 2.2807e-03, 6.9149e-02],
[ 5.1460e-03, -5.1742e-02, -9.3557e-02, -2.8508e-03, 6.1673e-02]],
[[-7.7168e-02, 7.8402e-02, -1.1065e-01, 3.5423e-02, 2.2741e-02],
[-3.0649e-02, 5.4842e-02, -5.2164e-02, -7.5771e-02, -3.3854e-02],
[-5.6595e-03, -2.0763e-02, 5.1015e-02, 6.8125e-04, -4.5188e-02],
[-6.0728e-02, 1.5548e-02, 1.1190e-02, -3.3515e-02, 9.3201e-02],
[-8.1993e-02, 5.8329e-02, 4.3559e-02, -1.3500e-01, -6.2694e-02]]],
[[[-2.5201e-03, 1.0502e-01, -1.6741e-02, 7.6163e-02, -1.7975e-02],
[ 7.7132e-02, -4.8635e-02, 2.0458e-02, -1.4585e-01, 2.5789e-02],
[-2.3901e-02, -3.0057e-02, 1.0268e-01, -9.8991e-03, -1.2776e-02],
[-3.0604e-02, -1.9044e-02, -1.0034e-01, -3.4933e-03, -1.0588e-01],
[-4.3267e-02, 4.4136e-02, -1.4653e-02, 8.0375e-02, -4.3509e-03]],
[[-7.2069e-02, 7.5176e-02, 3.5200e-02, -1.4117e-02, 1.4985e-01],
[ 9.3505e-02, -1.5317e-02, -2.7400e-02, 7.2958e-02, -3.5223e-02],
[ 1.0651e-01, 4.8038e-03, -4.8003e-03, -2.4901e-02, 5.0562e-02],
[ 8.1387e-02, -1.0440e-01, 9.1612e-03, -1.0838e-01, 3.4357e-02],
[ 3.7520e-02, 1.0998e-02, -1.2370e-01, -1.4700e-01, 2.2564e-02]],
[[-2.6191e-02, 5.5962e-03, 5.2655e-02, 3.7240e-02, -9.0327e-04],
[ 5.6995e-03, -5.7081e-03, 2.0693e-02, 1.2346e-02, 4.0957e-03],
[ 9.9342e-02, -9.1494e-02, -7.2406e-02, -2.3302e-03, -4.5766e-02],
[-9.5160e-02, -4.8321e-02, 5.4787e-02, 7.6552e-02, 6.8904e-02],
[ 8.1214e-02, 4.4162e-02, 3.9048e-02, 2.1962e-01, 7.6550e-02]],
[[ 4.3722e-02, 1.4723e-01, 9.8580e-02, 5.3083e-02, 1.7335e-02],
[ 1.6842e-02, 9.2504e-02, -1.2577e-01, 4.3781e-02, 4.5560e-02],
[-2.2627e-03, 4.4276e-02, -3.1316e-02, -3.0269e-03, 1.1357e-01],
[ 1.1986e-01, -5.7009e-02, -1.8555e-02, 3.8186e-03, 1.8096e-01],
[-2.2767e-02, 1.2045e-01, -2.4305e-02, 5.8296e-02, 9.8688e-02]],
[[ 1.5614e-02, -5.0428e-02, 9.2436e-02, 1.5314e-02, 3.8010e-02],
[-3.0040e-02, -4.5455e-02, -8.2217e-02, -4.5022e-03, 1.6350e-02],
[-3.2035e-02, -7.1502e-02, 6.1119e-02, 1.0737e-02, 1.0597e-01],
[ 1.1230e-01, 2.2913e-03, 2.2935e-02, 4.4280e-02, 4.6554e-02],
[ 6.3816e-02, 5.0896e-02, -8.1320e-03, -5.9140e-02, -1.3816e-02]],
[[ 2.7798e-02, 2.0180e-02, -6.8479e-02, -3.8563e-02, -2.1970e-02],
[-1.1560e-02, 1.0920e-01, 4.0527e-02, 1.5456e-02, -3.0192e-02],
[-3.4818e-02, 2.9930e-02, 3.8703e-02, 6.9933e-02, 1.3544e-01],
[-1.1075e-02, 1.0404e-01, 2.5879e-03, -3.4677e-02, 1.3333e-01],
[-9.6145e-02, 4.9468e-02, 2.3339e-02, -2.3944e-02, 2.2928e-02]]],
[[[ 1.6603e-03, -7.8303e-02, 1.2174e-01, 1.7890e-02, 4.8606e-02],
[-5.6957e-03, -1.3151e-01, -5.8330e-02, 2.0716e-02, -1.2664e-01],
[ 1.0507e-01, 7.2393e-02, 3.5766e-02, -4.7793e-03, -1.0268e-02],
[ 2.1223e-03, 2.5655e-02, 1.9555e-02, 1.4210e-01, 1.0301e-02],
[-1.4564e-01, -1.8514e-02, -4.7687e-02, 1.5104e-01, -3.0745e-02]],
[[-3.2338e-04, 2.1239e-02, -5.0826e-02, 6.4949e-02, 5.0461e-04],
[-8.2764e-02, -1.5075e-02, 1.5609e-02, -1.3402e-01, -1.1833e-02],
[ 4.4014e-02, -5.9378e-03, -1.2508e-01, -7.1479e-03, 6.9006e-02],
[-8.8556e-03, -7.4991e-02, 1.4296e-02, -6.0615e-02, 3.2985e-02],
[ 1.5668e-01, -8.7098e-02, -5.0354e-02, -1.2569e-01, -1.5741e-01]],
[[-1.0422e-01, 1.4817e-01, -9.1310e-02, 5.6767e-02, -1.2098e-02],
[-2.7485e-02, 1.3706e-01, -3.2490e-03, -6.4062e-02, 1.3419e-01],
[ 1.0312e-02, -8.3901e-02, -1.7662e-02, -1.5195e-01, -8.3465e-02],
[-5.6543e-04, -4.3555e-02, -5.6245e-02, 7.3568e-02, 2.1678e-02],
[ 7.6391e-03, -4.8854e-03, 5.4257e-02, 7.8329e-03, 1.4124e-01]],
[[-4.9119e-02, -5.2545e-02, 2.3897e-02, 3.6596e-04, 4.2827e-02],
[ 2.3076e-02, -1.8356e-02, 2.9836e-02, 1.2485e-01, 1.2366e-01],
[-7.8947e-02, 1.4472e-01, -4.3746e-02, 1.0565e-02, 4.6012e-02],
[-5.7214e-02, -4.1273e-03, -5.6967e-02, -4.2106e-03, 7.4851e-02],
[-9.0327e-02, 8.1366e-03, -5.6940e-02, 8.7756e-02, -1.0730e-01]],
[[-2.5770e-02, -1.0837e-01, -6.5197e-02, 1.0767e-01, -1.8528e-02],
[ 3.6587e-02, -1.0597e-01, -6.8699e-02, 2.4231e-02, -1.2268e-01],
[ 1.2846e-01, -1.1661e-02, 5.9071e-02, 3.8168e-02, 8.9368e-02],
[ 7.8908e-02, 3.6826e-02, -1.2949e-02, -5.2112e-02, 6.9411e-02],
[-1.5190e-02, -6.7784e-02, 4.7824e-02, -4.4707e-02, -1.0409e-01]],
[[ 6.1472e-02, -8.7888e-02, 9.1761e-02, -9.8433e-02, -1.2520e-02],
[-3.4953e-02, -1.2213e-01, 5.5129e-02, -2.4383e-03, -3.0046e-02],
[-2.7596e-02, 5.2428e-02, 1.3167e-01, -3.1463e-02, 1.8369e-02],
[ 7.4151e-03, 9.0555e-02, -1.2976e-02, 4.0466e-02, 1.1523e-01],
[ 1.1646e-01, -2.0131e-02, 7.8028e-02, -6.9566e-02, -5.8961e-04]]],
...,
[[[ 1.8656e-02, 3.4279e-02, -1.0672e-01, -1.0023e-02, 2.2085e-02],
[-4.6725e-02, 7.7616e-02, -1.4231e-02, 4.5028e-02, -2.9861e-02],
[ 9.5183e-03, -1.7526e-03, 5.1253e-02, 7.1305e-02, 2.6603e-02],
[-3.0437e-02, -7.2302e-02, -1.0535e-01, 1.4485e-02, 6.3855e-02],
[-5.4683e-03, 5.6009e-02, 7.2122e-03, -1.0415e-01, 3.5188e-02]],
[[ 1.2157e-01, 6.7376e-02, 2.1207e-02, -1.0825e-01, -1.5351e-01],
[-7.0602e-02, -9.9218e-03, 3.5316e-02, -1.9505e-02, -8.3313e-02],
[-1.7510e-02, 9.2322e-02, -6.4570e-02, -8.1725e-02, 1.3421e-02],
[ 2.8115e-02, -4.7646e-03, 8.4901e-02, 3.6822e-02, -4.3886e-02],
[-1.4940e-02, -7.2755e-02, 1.3097e-01, 4.6871e-02, 5.6414e-03]],
[[-1.4081e-01, 4.8099e-02, -1.5321e-01, 3.9479e-02, -7.0164e-02],
[-7.3554e-02, -1.1309e-02, -6.2749e-02, 1.2635e-01, -4.7698e-02],
[ 1.8018e-02, 1.0312e-01, -4.4310e-02, 7.4449e-02, 2.4631e-02],
[ 3.6606e-02, -2.5128e-02, 3.0890e-02, 3.3290e-03, 3.1577e-02],
[-1.8735e-02, -1.2198e-01, -1.8037e-02, -9.1278e-02, -1.3295e-01]],
[[-2.3376e-02, 6.9800e-02, 2.6887e-02, -1.7161e-02, -8.7163e-02],
[ 1.5594e-02, 5.3034e-02, -1.1990e-01, 7.9062e-02, 1.0965e-01],
[ 1.7833e-02, 8.8824e-02, 3.1692e-02, 8.7611e-02, 6.3664e-02],
[-1.8869e-03, -1.1539e-01, -4.1857e-02, -8.9586e-02, -7.2488e-02],
[ 1.9608e-02, 1.3425e-01, -8.0856e-02, -8.3590e-02, -4.5203e-02]],
[[ 1.0801e-02, 1.5272e-02, -1.4688e-01, -3.7206e-02, 5.1565e-02],
[-2.0180e-02, -5.6527e-03, -8.9918e-02, -7.9719e-02, 2.5003e-02],
[ 4.3385e-02, -4.1702e-03, 6.6365e-02, -8.6634e-03, -1.1395e-01],
[ 1.2979e-02, 7.9360e-02, -3.8312e-02, -5.8354e-02, -7.3427e-02],
[ 9.6145e-03, -1.4716e-02, -1.0418e-02, -4.8538e-02, -4.1533e-02]],
[[ 4.0994e-02, -6.5618e-03, -3.9249e-02, 1.4940e-01, 6.1836e-02],
[ 1.4003e-01, 1.0897e-01, -9.9539e-02, 7.6101e-02, 4.7514e-02],
[ 1.8425e-02, -4.1603e-02, -3.0696e-02, 3.7560e-02, -5.6115e-02],
[ 1.1114e-01, -1.9242e-04, 1.1626e-02, 5.5963e-02, -4.1088e-02],
[ 1.1746e-01, -1.4505e-02, -3.8933e-02, -9.7828e-03, -4.2146e-02]]],
[[[ 2.7058e-02, -5.4692e-02, -1.6676e-02, 1.3220e-01, 6.2700e-02],
[ 3.9145e-02, -6.8409e-02, -7.5723e-03, 6.6492e-03, 3.2803e-02],
[-5.9434e-02, -1.6559e-02, -6.5511e-02, -1.8650e-02, 4.8143e-02],
[-7.7847e-03, 9.0031e-02, -6.6468e-02, -5.4535e-02, -1.9140e-02],
[-9.0769e-02, -1.1950e-01, 6.6162e-02, 7.2317e-02, -2.5983e-02]],
[[-1.8429e-02, 3.7081e-02, 3.3936e-03, 8.1601e-02, 7.5250e-02],
[-1.6394e-01, -1.9931e-02, 9.0637e-02, -1.5919e-01, 6.6636e-02],
[ 3.0836e-02, -4.4388e-02, -1.4237e-01, 6.4615e-02, -1.9187e-02],
[ 2.0550e-01, 5.2410e-02, -8.6873e-02, 9.8989e-03, 1.0537e-02],
[ 9.6366e-02, -4.6873e-02, 2.8437e-02, 9.9133e-02, -1.2426e-02]],
[[ 4.3380e-02, -1.6536e-02, -1.1810e-01, -1.5150e-02, -3.4028e-02],
[-5.2022e-02, -4.7191e-02, 3.4113e-02, 2.2352e-02, -1.6375e-02],
[-2.4627e-03, 6.9880e-02, 3.8121e-02, -1.2208e-01, 4.7194e-02],
[ 3.0015e-02, 8.7410e-03, -2.2237e-02, 7.8091e-02, 7.9454e-02],
[-5.1503e-02, 2.3619e-02, 8.4588e-02, -1.0277e-01, 8.3154e-03]],
[[-6.7868e-02, -1.2149e-01, -1.7491e-02, 1.0022e-01, -4.3358e-02],
[ 3.1464e-02, 7.4058e-02, -3.2289e-03, -2.6757e-02, -8.5261e-02],
[ 6.9376e-02, 1.2054e-01, 3.5335e-02, 9.2568e-02, 2.1358e-02],
[-4.7479e-02, 2.5571e-02, -4.7574e-02, 1.1774e-02, -2.6935e-02],
[-4.9366e-02, 1.1098e-02, -4.2080e-02, 8.9565e-02, 4.6191e-03]],
[[-1.7267e-02, 5.0607e-02, -4.9050e-02, -1.0601e-02, -5.8215e-02],
[-1.5839e-02, -2.6607e-02, -1.8453e-02, -7.0404e-02, -1.2247e-01],
[ 1.5462e-01, 4.8792e-03, -1.8811e-01, 5.1346e-03, 3.5262e-02],
[ 6.8772e-02, 6.6209e-02, -8.8063e-02, 1.0676e-01, 1.2461e-01],
[ 4.1533e-03, -4.0087e-02, 6.4295e-02, -8.9270e-03, -4.2519e-02]],
[[ 4.6680e-02, -2.2108e-02, 4.5739e-02, -6.1995e-02, -7.3645e-02],
[ 1.2634e-01, 5.2019e-02, -1.0544e-01, -1.1378e-02, 4.0753e-03],
[ 7.9469e-02, 3.9514e-02, -1.6562e-02, 3.0074e-02, 3.8365e-02],
[-1.0846e-01, -1.3977e-01, -4.5255e-03, -7.7981e-03, -2.7929e-02],
[ 1.2333e-02, -7.6110e-02, -1.5976e-02, 2.1001e-02, -2.0793e-02]]],
[[[ 3.5135e-02, 3.3776e-02, -2.1454e-02, 4.0599e-02, 7.6321e-02],
[-2.2444e-02, 5.9005e-02, -9.7919e-02, -5.0292e-02, -2.5280e-02],
[ 8.2182e-02, 1.1689e-02, -4.0238e-02, -1.5465e-02, -1.0631e-01],
[-7.9225e-02, 3.7144e-02, -2.1691e-02, 7.6563e-02, -2.9466e-02],
[-6.8622e-02, -8.2420e-02, 1.0508e-01, 5.7378e-02, -5.2815e-02]],
[[-8.5097e-02, 8.7198e-02, 1.3588e-01, -1.0821e-02, 6.6252e-02],
[-8.5424e-02, -1.2942e-01, -3.9451e-02, 2.3167e-03, 9.8789e-02],
[ 4.4885e-02, 7.8130e-02, 6.4348e-02, 1.4060e-01, -6.9407e-02],
[ 6.7809e-02, -1.0070e-01, 8.3850e-03, 2.3277e-02, 7.3318e-02],
[ 1.9004e-02, 3.3662e-04, 1.2365e-01, 3.7783e-02, -4.4722e-02]],
[[-3.6487e-03, -3.5660e-02, -3.1783e-03, -1.0376e-01, -2.5887e-02],
[ 8.2903e-02, 1.1622e-01, 7.4534e-02, -6.7094e-02, 2.8624e-02],
[ 5.6662e-03, 7.5438e-02, 2.3080e-02, -2.7497e-02, -4.4862e-02],
[-1.0469e-01, -1.1870e-03, -1.0996e-01, 1.3349e-01, 4.3062e-02],
[-5.9718e-02, -1.3859e-02, -5.2602e-03, 4.7957e-02, 6.4259e-02]],
[[ 1.6269e-02, -5.9308e-02, 4.2326e-02, -1.7657e-01, -3.3243e-02],
[ 2.8233e-02, -3.1127e-02, 2.9193e-02, -9.8563e-02, -2.3034e-02],
[ 6.9031e-02, -3.6136e-02, -6.7969e-02, -4.8913e-02, -5.7119e-02],
[ 1.7338e-02, -1.3542e-02, -8.9504e-02, 7.9912e-02, -6.6703e-02],
[-4.8714e-02, -1.3528e-02, -9.2340e-02, 1.3472e-01, 1.0488e-01]],
[[-8.7143e-02, 4.2698e-02, -6.4632e-02, 1.7840e-02, 5.3781e-02],
[ 1.1379e-02, 3.3266e-02, -1.4602e-02, 7.8038e-02, -1.2929e-01],
[-1.1283e-01, 9.1412e-02, -1.4805e-02, -2.7218e-02, -1.6207e-01],
[-4.6992e-04, -6.5583e-02, -3.9846e-02, -5.5303e-02, -8.9434e-02],
[ 1.8681e-03, -7.4996e-02, 5.9021e-02, 2.6332e-02, 2.7155e-03]],
[[ 4.9158e-02, 7.0112e-02, -1.1107e-01, -1.1723e-01, 8.6018e-02],
[ 1.0823e-02, 6.9816e-02, -3.1992e-02, 4.0998e-02, -8.9671e-02],
[ 2.2780e-02, -3.4835e-02, 5.5046e-03, 4.8949e-02, -4.9500e-02],
[-6.3235e-02, -4.4342e-02, -2.1533e-01, -3.0609e-02, 4.5277e-02],
[-5.5601e-02, -3.0792e-02, 4.2926e-02, 2.9293e-02, -2.3960e-02]]]],
grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([ 0.0215, -0.0800, -0.0787, -0.0173, -0.0345, 0.0684, 0.0584, -0.0804,
0.0098, -0.0490, -0.0535, 0.0145, 0.0056, 0.0082, -0.0256, 0.0140],
requires_grad=True)
tensor: tensor([ 0.0483, -0.0138, -0.0666, -0.0158, 0.0085, 0.0382, 0.0643, -0.0354,
0.1643, 0.0064, -0.0694, 0.0054, 0.0250, 0.0626, -0.1082, 0.0414],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[[[-0., 0., 0., 0., 0.],
[-0., 0., -0., 0., 0.],
[0., -0., -0., -0., -0.],
[-0., 0., 0., 0., 0.],
[0., -0., 0., -0., -0.]],
[[-0., -0., 0., 0., -0.],
[0., 0., -0., 0., -0.],
[0., -0., 0., -0., 0.],
[-0., -0., 0., 0., 0.],
[0., 0., 0., 0., -0.]],
[[-0., -0., 0., 0., 0.],
[0., -0., 0., -0., -0.],
[0., 0., 0., -0., -0.],
[-0., -0., -0., -0., -0.],
[0., 0., -0., 0., 0.]],
[[-0., -0., 0., -0., 0.],
[-0., -0., -0., -0., 0.],
[0., 0., 0., 0., -0.],
[-0., -0., -0., 0., 0.],
[-0., 0., -0., 0., -0.]],
[[0., 0., -0., -0., -0.],
[0., 0., 0., -0., 0.],
[-0., 0., 0., -0., 0.],
[-0., -0., -0., 0., 0.],
[0., 0., -0., -0., -0.]],
[[-0., -0., -0., 0., -0.],
[-0., 0., -0., -0., -0.],
[0., -0., 0., -0., 0.],
[0., 0., -0., -0., 0.],
[-0., -0., 0., -0., -0.]]],
[[[-0., 0., -0., 0., -0.],
[0., -0., -0., -0., 0.],
[-0., 0., 0., 0., -0.],
[-0., 0., 0., 0., -0.],
[-0., 0., 0., 0., -0.]],
[[-0., 0., 0., -0., 0.],
[0., -0., 0., 0., -0.],
[0., 0., -0., -0., 0.],
[0., -0., -0., -0., 0.],
[0., -0., -0., -0., 0.]],
[[0., -0., 0., 0., 0.],
[0., -0., 0., 0., 0.],
[0., -0., -0., 0., -0.],
[-0., -0., 0., 0., 0.],
[0., 0., 0., 0., 0.]],
[[0., 0., 0., -0., 0.],
[-0., 0., -0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[-0., 0., -0., 0., 0.]],
[[0., -0., 0., 0., -0.],
[-0., 0., -0., 0., -0.],
[-0., 0., -0., 0., 0.],
[0., -0., 0., 0., 0.],
[0., 0., -0., 0., 0.]],
[[-0., -0., -0., -0., -0.],
[-0., 0., 0., 0., 0.],
[-0., 0., 0., 0., 0.],
[-0., 0., -0., -0., 0.],
[-0., -0., 0., -0., 0.]]],
[[[-0., -0., 0., 0., -0.],
[0., -0., -0., 0., -0.],
[0., 0., -0., -0., 0.],
[0., 0., -0., 0., 0.],
[-0., -0., -0., 0., 0.]],
[[-0., -0., -0., 0., -0.],
[-0., -0., 0., -0., -0.],
[0., 0., -0., -0., -0.],
[-0., 0., 0., -0., -0.],
[0., -0., -0., -0., -0.]],
[[-0., 0., -0., 0., -0.],
[-0., 0., -0., -0., 0.],
[0., 0., -0., -0., -0.],
[-0., -0., -0., 0., -0.],
[-0., -0., 0., 0., 0.]],
[[0., -0., 0., 0., 0.],
[0., -0., 0., 0., 0.],
[-0., 0., -0., 0., 0.],
[-0., -0., -0., 0., 0.],
[-0., 0., -0., 0., 0.]],
[[-0., -0., -0., 0., -0.],
[-0., -0., -0., 0., -0.],
[0., 0., 0., -0., 0.],
[0., 0., -0., 0., 0.],
[-0., -0., 0., -0., -0.]],
[[0., -0., 0., -0., 0.],
[0., -0., 0., 0., 0.],
[-0., 0., -0., -0., -0.],
[0., 0., 0., 0., 0.],
[0., -0., 0., -0., 0.]]],
...,
[[[0., 0., -0., 0., -0.],
[-0., 0., -0., 0., -0.],
[0., -0., -0., 0., 0.],
[0., -0., -0., 0., 0.],
[0., 0., 0., -0., -0.]],
[[0., 0., -0., -0., -0.],
[-0., -0., -0., -0., -0.],
[-0., 0., -0., -0., 0.],
[0., 0., 0., 0., 0.],
[0., -0., 0., 0., 0.]],
[[-0., -0., -0., -0., -0.],
[-0., -0., -0., 0., 0.],
[-0., 0., -0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., -0., -0., -0., -0.]],
[[0., 0., 0., 0., -0.],
[0., 0., -0., -0., 0.],
[0., 0., 0., 0., 0.],
[0., -0., -0., -0., -0.],
[-0., 0., -0., -0., 0.]],
[[-0., -0., -0., 0., 0.],
[-0., 0., 0., -0., 0.],
[-0., -0., -0., 0., -0.],
[-0., 0., -0., -0., -0.],
[0., 0., -0., 0., -0.]],
[[-0., 0., -0., 0., -0.],
[0., 0., -0., 0., 0.],
[-0., -0., -0., 0., -0.],
[0., 0., 0., 0., 0.],
[0., 0., -0., 0., -0.]]],
[[[-0., -0., 0., 0., 0.],
[0., -0., -0., -0., 0.],
[-0., -0., -0., 0., 0.],
[0., 0., -0., -0., -0.],
[-0., -0., 0., 0., -0.]],
[[-0., 0., 0., 0., 0.],
[-0., 0., 0., -0., 0.],
[-0., -0., -0., 0., -0.],
[0., -0., -0., -0., 0.],
[0., -0., -0., 0., -0.]],
[[0., -0., -0., -0., -0.],
[-0., -0., -0., -0., -0.],
[-0., 0., 0., -0., 0.],
[-0., 0., 0., -0., 0.],
[0., 0., 0., -0., -0.]],
[[-0., -0., -0., -0., -0.],
[0., 0., 0., -0., -0.],
[-0., 0., 0., 0., -0.],
[-0., 0., -0., 0., 0.],
[-0., 0., -0., 0., -0.]],
[[0., 0., -0., -0., -0.],
[-0., -0., 0., -0., -0.],
[0., 0., -0., -0., 0.],
[0., 0., -0., 0., 0.],
[-0., 0., 0., 0., -0.]],
[[0., -0., 0., -0., -0.],
[0., 0., -0., -0., -0.],
[0., 0., -0., 0., 0.],
[-0., 0., -0., 0., 0.],
[-0., -0., -0., -0., -0.]]],
[[[-0., 0., 0., 0., 0.],
[-0., 0., -0., -0., -0.],
[0., 0., -0., -0., -0.],
[0., 0., -0., -0., -0.],
[-0., -0., 0., 0., -0.]],
[[-0., -0., 0., -0., 0.],
[-0., -0., -0., 0., 0.],
[0., 0., 0., 0., -0.],
[0., -0., 0., -0., 0.],
[0., -0., 0., -0., -0.]],
[[0., -0., 0., -0., -0.],
[0., 0., 0., -0., 0.],
[0., 0., 0., -0., -0.],
[-0., 0., -0., 0., 0.],
[-0., -0., 0., -0., 0.]],
[[-0., -0., 0., -0., -0.],
[0., -0., 0., -0., -0.],
[0., -0., -0., -0., -0.],
[0., -0., -0., 0., -0.],
[-0., 0., -0., 0., 0.]],
[[-0., 0., -0., -0., 0.],
[-0., 0., -0., 0., -0.],
[-0., 0., -0., -0., -0.],
[0., -0., -0., 0., -0.],
[-0., -0., -0., 0., -0.]],
[[0., 0., -0., -0., 0.],
[0., 0., -0., 0., -0.],
[0., -0., 0., 0., -0.],
[-0., 0., -0., -0., 0.],
[-0., 0., 0., -0., 0.]]]])
scale: tensor([[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
...,
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[[[-2.9409e-02, 6.0003e-02, 5.3245e-02, 2.8255e-02, 5.1865e-02],
[-1.3120e-03, 6.9924e-02, -2.1958e-02, 4.5196e-02, 5.5987e-02],
[ 7.2948e-02, -1.2874e-02, -1.4010e-02, -7.2656e-02, -3.8272e-02],
[-6.1391e-02, 2.6227e-02, 2.2223e-02, 5.1530e-02, 7.8572e-02],
[ 1.1986e-02, -3.5154e-02, 2.2590e-02, -2.8268e-02, -4.1092e-02]],
[[-7.7353e-02, -3.3001e-03, 2.2446e-02, 3.6329e-02, -8.4312e-03],
[ 5.9898e-02, 3.3921e-02, -4.0756e-02, 5.6990e-02, -7.9915e-02],
[ 1.1058e-02, -7.5468e-02, 3.1362e-02, -1.2730e-02, 7.0498e-02],
[-6.7726e-03, -3.9453e-02, 6.8676e-02, 2.8066e-02, 4.1992e-02],
[ 8.0445e-02, 6.9594e-02, 2.9287e-02, 4.4932e-02, -7.2432e-02]],
[[-6.8457e-02, -7.5792e-02, 3.8843e-02, 3.2058e-02, 2.1625e-02],
[ 7.5601e-02, -5.7455e-02, 1.2939e-02, -3.5165e-02, -5.6556e-02],
[ 8.3346e-03, 2.2498e-02, 6.8399e-03, -2.3131e-02, -2.1235e-03],
[-6.8738e-02, -6.5464e-02, -4.5512e-02, -2.1878e-02, -6.3732e-02],
[ 4.4843e-02, 1.4243e-02, -7.8475e-02, 2.0007e-02, 6.8106e-02]],
[[-7.1642e-03, -4.0418e-02, 1.6455e-03, -7.4808e-03, 3.0932e-02],
[-5.6909e-02, -6.4950e-02, -4.3916e-02, -6.4650e-02, 7.4918e-02],
[ 2.7178e-03, 7.6559e-02, 2.3546e-02, 1.0888e-02, -3.2943e-02],
[-3.1639e-02, -4.4183e-02, -1.1435e-02, 1.8024e-02, 7.9902e-02],
[-1.6372e-02, 6.6337e-03, -5.5328e-02, 1.6804e-02, -1.1563e-02]],
[[ 3.4681e-02, 3.3633e-02, -5.8334e-02, -4.5784e-02, -2.0550e-02],
[ 8.0547e-02, 4.9685e-02, 7.5298e-02, -2.9211e-02, 2.2224e-02],
[-2.4409e-02, 7.0003e-02, 2.1558e-04, -4.4977e-02, 2.3133e-02],
[-7.4964e-02, -6.4867e-02, -4.5671e-02, 3.1878e-02, 3.5104e-02],
[ 2.2541e-02, 3.2949e-02, -3.5216e-02, -4.0980e-02, -1.9989e-02]],
[[-7.0267e-02, -2.8274e-03, -3.7755e-02, 3.8392e-02, -1.3160e-05],
[-2.4893e-02, 2.6938e-02, -6.6667e-02, -6.4968e-02, -7.5396e-02],
[ 4.5711e-02, -3.3377e-02, 2.4278e-02, -4.1101e-03, 1.8952e-02],
[ 2.0236e-02, 6.4522e-03, -5.7711e-02, -4.5909e-02, 7.4404e-02],
[-3.7771e-02, -3.1536e-03, 6.4748e-02, -4.4604e-02, -2.2752e-02]]],
[[[-2.9177e-02, 5.0791e-02, -5.5935e-03, 4.7169e-02, -1.1366e-02],
[ 2.2291e-02, -2.1153e-02, -1.4060e-02, -5.3690e-02, 6.6127e-02],
[-1.1869e-03, 7.9363e-03, 6.0124e-02, 8.7514e-04, -2.6833e-02],
[-4.4174e-02, 1.9975e-02, 2.7939e-02, 2.7791e-02, -1.7919e-02],
[-2.9191e-02, 2.5500e-02, 3.6869e-03, 3.1063e-02, -3.7672e-02]],
[[-6.7084e-02, 5.3586e-02, 1.8485e-02, -7.8099e-03, 4.6288e-02],
[ 4.9055e-02, -1.4634e-02, 3.2799e-02, 3.9726e-02, -7.0032e-02],
[ 7.0897e-02, 3.4370e-02, -1.4814e-02, -3.9030e-02, 3.0867e-02],
[ 3.5541e-02, -7.2574e-02, -1.5650e-03, -8.1162e-02, 2.6245e-02],
[ 5.5721e-02, -2.2033e-02, -7.2623e-02, -5.1459e-02, 3.3599e-02]],
[[ 1.1310e-02, -2.9816e-02, 6.3727e-02, 3.9850e-02, 1.3761e-02],
[ 3.0453e-02, -4.8504e-02, 5.3189e-02, 1.9425e-02, 4.7484e-02],
[ 6.1376e-02, -7.8290e-02, -6.8859e-02, 1.8497e-02, -1.1496e-02],
[-7.8178e-02, -4.5904e-02, 7.3181e-02, 2.9441e-02, 4.6967e-02],
[ 7.6978e-02, 7.2934e-02, 5.6798e-02, 5.8828e-02, 4.4637e-02]],
[[ 6.0281e-02, 7.7289e-02, 7.9016e-02, -4.1437e-02, 3.1101e-02],
[-5.0620e-02, 3.3108e-02, -5.8687e-02, 2.7694e-02, 5.4294e-02],
[ 2.1156e-02, 1.7004e-03, 2.4742e-02, 6.9593e-02, 5.7699e-02],
[ 6.8876e-02, 3.2239e-02, 3.3322e-02, 2.9973e-02, 7.4267e-02],
[-2.4736e-03, 3.8454e-02, -2.5898e-02, 2.0443e-02, 6.0816e-02]],
[[ 1.6487e-02, -6.8994e-03, 5.2835e-02, 5.7784e-02, -1.7036e-02],
[-7.1883e-02, 3.1576e-04, -5.5839e-02, 1.4949e-02, -1.9834e-03],
[-2.7395e-02, 3.9861e-04, -1.0588e-02, 9.9140e-03, 5.1499e-02],
[ 7.4060e-02, -1.0655e-02, 1.1668e-02, 4.9183e-02, 5.2846e-02],
[ 2.8634e-02, 4.2678e-02, -1.4281e-02, 1.3904e-03, 7.6289e-02]],
[[-5.1256e-02, -2.2514e-02, -7.2964e-02, -4.4120e-02, -5.8914e-02],
[-4.1579e-02, 2.8281e-02, 3.9429e-02, 7.5058e-03, 6.5170e-03],
[-3.4494e-02, 7.5710e-02, 4.1078e-02, 4.4451e-02, 4.2661e-02],
[-5.4398e-02, 5.1592e-02, -2.6367e-02, -3.2980e-02, 5.3860e-02],
[-4.2436e-02, -8.4286e-03, 7.5331e-02, -6.6725e-02, 4.9887e-02]]],
[[[-1.5211e-02, -6.2506e-03, 3.0621e-03, 3.0725e-02, -7.0877e-03],
[ 1.1974e-02, -5.2611e-02, -2.7415e-02, 4.3479e-02, -4.2108e-02],
[ 3.3816e-02, 6.1523e-02, -9.9011e-03, -3.7770e-02, 6.5915e-04],
[ 5.3678e-03, 5.9921e-02, -3.4530e-02, 5.1942e-02, 5.3762e-02],
[-4.7293e-02, -6.2274e-02, -7.5059e-02, 8.1645e-02, 2.1149e-02]],
[[-1.4459e-02, -2.7155e-02, -2.5730e-02, 7.6751e-02, -1.6932e-02],
[-5.3342e-02, -2.6885e-02, 4.3476e-02, -7.9174e-02, -3.5761e-02],
[ 4.2970e-02, 2.5516e-02, -6.6640e-02, -2.9457e-03, -8.2757e-03],
[-2.5080e-02, 4.1672e-02, 4.2424e-02, -4.8704e-02, -6.0434e-02],
[ 1.2884e-02, -7.9950e-02, -7.0913e-02, -8.0863e-02, -5.4536e-02]],
[[-5.4303e-02, 6.7885e-02, -5.3922e-02, 6.5582e-02, -5.2617e-03],
[-8.4440e-03, 8.0911e-02, -3.8667e-02, -5.6241e-04, 7.0876e-02],
[ 5.4673e-02, 1.3465e-02, -2.7178e-02, -3.6691e-02, -2.6519e-02],
[-2.8238e-02, -5.0765e-02, -3.4076e-02, 3.1219e-02, -1.8919e-02],
[-4.6076e-02, -7.6516e-02, 3.1247e-02, 4.1743e-02, 7.5575e-02]],
[[ 3.8787e-03, -8.1731e-03, 4.7381e-03, 5.8261e-02, 4.6416e-02],
[ 7.7171e-02, -6.5924e-02, 1.5769e-02, 2.6777e-02, 7.7365e-02],
[-7.3126e-02, 4.7624e-02, -7.0620e-02, 7.3309e-02, 7.7585e-03],
[-7.9208e-02, -7.3783e-03, -5.3142e-02, 3.4386e-02, 5.6230e-03],
[-4.4492e-02, 7.5403e-02, -2.8887e-02, 2.6937e-02, 7.0698e-03]],
[[-1.7104e-02, -6.6983e-02, -5.7655e-02, 5.1198e-02, -8.0137e-02],
[-7.2406e-02, -3.7856e-02, -7.2086e-02, 7.2641e-02, -7.0749e-02],
[ 6.6330e-02, 2.8797e-02, 6.2197e-02, -4.6068e-03, 1.7392e-03],
[ 4.2384e-02, 5.9572e-02, -4.5953e-02, 6.6345e-03, 7.3979e-02],
[-4.8313e-02, -1.3063e-02, 1.7648e-02, -5.0903e-02, -7.1852e-02]],
[[ 3.2147e-02, -8.1585e-02, 6.7507e-02, -7.7056e-02, 1.7667e-02],
[ 2.0535e-03, -7.4221e-02, 1.0343e-02, 4.3018e-02, 9.7351e-03],
[-2.1410e-02, 5.4089e-02, -2.4102e-02, -4.0551e-02, -3.6118e-03],
[ 4.9847e-02, 6.9608e-02, 3.6233e-03, 5.7025e-02, 6.3206e-02],
[ 1.4611e-02, -2.9885e-02, 5.6140e-02, -6.4338e-02, 8.5266e-03]]],
...,
[[[ 5.7051e-02, 3.4026e-02, -3.7723e-02, 1.4372e-02, -4.4266e-03],
[-8.0557e-02, 1.1810e-02, -6.9374e-02, 3.4264e-02, -3.9068e-02],
[ 3.2814e-02, -4.9334e-02, -3.2234e-02, 3.7901e-02, 9.9268e-03],
[ 1.2846e-03, -5.9199e-02, -5.6303e-02, 1.2189e-03, 7.8874e-02],
[ 7.6858e-04, 2.4341e-02, 4.0423e-02, -7.7602e-02, -3.6388e-02]],
[[ 5.9494e-02, 4.4230e-02, -5.9128e-02, -1.6639e-02, -6.4884e-02],
[-2.3457e-02, -7.5842e-03, -3.3986e-02, -2.0435e-02, -4.2466e-02],
[-6.8915e-02, 1.6417e-02, -9.0384e-03, -5.6058e-02, 1.1540e-02],
[ 3.9632e-02, 3.8881e-02, 5.5834e-02, 7.5591e-02, 1.8463e-02],
[ 4.5034e-02, -6.4665e-02, 6.7883e-02, 7.1108e-02, 8.0694e-02]],
[[-7.4652e-02, -2.9270e-02, -7.4301e-02, -1.4067e-02, -6.0331e-02],
[-7.9629e-02, -2.5316e-03, -3.4649e-02, 7.9736e-02, 2.4963e-02],
[-1.4102e-02, 3.0896e-02, -5.4594e-02, 5.7641e-02, 7.8276e-02],
[ 3.3722e-02, 1.6397e-02, 6.6251e-02, 2.5637e-02, 4.0073e-02],
[ 1.9370e-02, -1.4960e-02, -4.0503e-02, -3.6491e-02, -6.9970e-02]],
[[ 3.9434e-02, 6.7049e-02, 7.1627e-02, 6.9307e-02, -5.7508e-03],
[ 2.8151e-02, 7.9890e-02, -6.4687e-02, -6.8959e-02, 6.8179e-02],
[ 1.2583e-02, 6.6052e-02, 6.7770e-02, 1.0853e-02, 6.3935e-02],
[ 4.4214e-02, -5.4527e-02, -6.3199e-02, -2.4454e-02, -8.0348e-02],
[-1.1810e-04, 6.2292e-02, -2.1831e-02, -4.1282e-02, 3.4718e-02]],
[[-8.9495e-03, -3.5923e-02, -4.9030e-02, 1.7068e-02, 5.7835e-02],
[-6.2950e-02, 6.9258e-02, 1.4909e-02, -3.9252e-02, 3.0917e-02],
[-5.0831e-02, -2.6109e-02, -4.2526e-02, 4.9180e-03, -6.7907e-02],
[-1.4867e-02, 8.3498e-03, -6.3780e-02, -6.3819e-02, -7.7414e-02],
[ 6.5369e-02, 3.5118e-02, -3.5070e-02, 3.1514e-02, -1.7773e-02]],
[[-1.9000e-02, 4.8772e-02, -4.0550e-02, 5.7766e-02, -4.8687e-02],
[ 7.0112e-02, 7.4851e-02, -5.0324e-02, 4.2522e-02, 6.6367e-02],
[-6.6793e-02, -6.3487e-02, -6.3574e-02, 7.3530e-02, -6.7062e-02],
[ 1.9297e-02, 3.9876e-02, 7.0333e-03, 3.6541e-02, 3.0865e-02],
[ 6.9009e-02, 2.7737e-03, -6.0400e-02, 1.5249e-03, -1.5177e-03]]],
[[[-4.9098e-02, -1.2656e-02, 3.0326e-02, 4.6450e-02, 4.1143e-02],
[ 6.5180e-02, -4.5543e-02, -6.0194e-02, -8.1101e-02, 7.3691e-02],
[-5.2880e-02, -5.3283e-02, -4.6874e-02, 2.0506e-02, 1.4432e-02],
[ 5.3466e-05, 6.1875e-02, -5.2208e-02, -2.1149e-02, -6.5709e-02],
[-7.2209e-02, -2.8706e-02, 6.6109e-02, 5.8108e-02, -1.8114e-02]],
[[-5.8877e-02, 3.5183e-02, 6.5460e-02, 5.2934e-02, 3.5997e-02],
[-6.5718e-02, 2.7700e-02, 7.1110e-02, -5.7825e-02, 6.1866e-03],
[-5.3281e-03, -7.6189e-02, -6.9421e-02, 6.4743e-02, -1.1912e-02],
[ 7.6864e-02, -5.8819e-03, -2.0277e-02, -1.6263e-02, 4.5729e-02],
[ 2.0473e-03, -3.1893e-02, -3.0088e-02, 6.1322e-02, -1.3287e-02]],
[[ 4.0654e-02, -3.8251e-03, -5.8287e-02, -6.9760e-03, -4.9954e-02],
[-3.1949e-02, -6.5679e-02, -9.8746e-04, -5.5646e-02, -1.6937e-03],
[-5.0579e-02, 5.1921e-02, 4.0006e-02, -5.3846e-02, 3.6710e-03],
[-5.5284e-03, 5.2453e-02, 3.5617e-02, -4.4475e-02, 2.7835e-02],
[ 3.6465e-02, 2.2936e-02, 4.9494e-02, -6.8768e-02, -6.8512e-02]],
[[-1.5606e-02, -5.8101e-02, -4.8349e-02, -5.4572e-03, -8.1381e-02],
[ 3.3837e-02, 6.9886e-02, 2.5937e-03, -4.4428e-02, -6.1442e-03],
[-3.3799e-02, 7.6725e-02, 1.5202e-02, 2.7467e-02, -7.2112e-02],
[-5.3887e-02, 5.3134e-02, -5.5426e-02, 8.1476e-02, 1.0773e-02],
[-1.9578e-02, 1.7628e-02, -2.2382e-02, 6.7076e-02, -1.3475e-02]],
[[ 3.7281e-02, 2.7106e-02, -7.8289e-03, -6.1201e-02, -4.5366e-02],
[-5.1809e-02, -1.0889e-02, 4.4019e-02, -4.0099e-02, -6.2939e-02],
[ 7.8826e-02, 1.4336e-02, -7.8953e-02, -4.1699e-03, 2.1759e-02],
[ 4.3422e-02, 6.1053e-02, -5.1035e-02, 2.5170e-02, 8.1194e-02],
[-3.5907e-02, 3.5084e-02, 5.4858e-02, 5.7819e-02, -6.8527e-02]],
[[ 6.0340e-02, -4.5873e-02, 4.5307e-02, -1.8559e-02, -5.9891e-02],
[ 7.1101e-02, 5.7979e-03, -2.1455e-02, -5.7839e-02, -2.6964e-02],
[ 4.5972e-02, 4.6237e-02, -1.8353e-02, 5.5372e-03, 5.8802e-02],
[-8.0939e-02, 2.2098e-03, -2.7943e-03, 6.9556e-02, 3.5299e-03],
[-2.4275e-02, -6.1490e-02, -2.4350e-02, -5.8685e-02, -7.6820e-02]]],
[[[-5.8326e-02, 4.3804e-02, 5.4642e-02, 2.9479e-02, 5.5766e-02],
[-6.2955e-02, 4.9442e-02, -1.7882e-02, -6.4492e-02, -3.5590e-02],
[ 7.8974e-02, 1.8189e-02, -4.3076e-02, -4.6822e-02, -5.9352e-02],
[ 1.1472e-02, 6.9467e-02, -3.5045e-02, -1.3463e-03, -7.0617e-02],
[-5.7437e-02, -5.7150e-02, 4.9108e-02, 2.2168e-02, -5.4964e-02]],
[[-3.2895e-02, -2.2746e-03, 6.8428e-02, -7.4781e-02, 6.5675e-02],
[-8.0232e-02, -2.6468e-02, -2.1136e-02, 2.1449e-02, 6.4572e-02],
[ 2.9930e-03, 1.1987e-02, 4.8122e-03, 3.4183e-02, -7.8918e-02],
[ 6.3749e-02, -2.5083e-02, 1.1253e-02, -4.4485e-02, 3.3380e-02],
[ 5.0096e-03, -1.7321e-02, 8.0185e-02, -2.3853e-02, -2.9333e-03]],
[[ 2.6648e-02, -7.6799e-02, 3.2204e-03, -7.7476e-02, -4.4615e-03],
[ 5.7110e-02, 7.8575e-02, 5.3204e-02, -7.8592e-02, 4.1383e-03],
[ 1.6194e-02, 2.5400e-02, 7.4070e-02, -3.9092e-03, -2.9417e-02],
[-7.9407e-02, 2.5042e-02, -3.8854e-02, 2.8143e-02, 2.8485e-03],
[-3.3828e-02, -7.5645e-02, 7.8511e-02, -4.4048e-02, 6.0887e-02]],
[[-6.4552e-02, -3.1646e-02, 6.5499e-02, -6.8577e-02, -5.1529e-02],
[ 6.1176e-02, -4.8461e-02, 4.7687e-02, -3.0069e-02, -1.7665e-02],
[ 7.7632e-02, -1.7017e-02, -6.2812e-02, -1.8810e-02, -4.1500e-02],
[ 6.1360e-02, -1.9826e-02, -6.4593e-02, 3.5071e-02, -5.9178e-02],
[-6.6739e-02, 2.6098e-02, -5.5998e-02, 8.1334e-02, 3.7472e-02]],
[[-5.5207e-02, 1.4355e-02, -2.2037e-02, -2.4025e-02, 7.2631e-02],
[-1.0448e-02, 1.9105e-03, -5.5223e-02, 4.6377e-02, -6.8534e-02],
[-2.4292e-02, 7.5258e-02, -8.0224e-02, -6.6001e-02, -4.6628e-02],
[ 4.5334e-02, -2.3274e-02, -4.3572e-02, 4.3487e-03, -4.6057e-02],
[-5.3757e-02, -2.0336e-02, -5.2245e-02, 2.2213e-02, -6.7578e-03]],
[[ 5.7154e-02, 6.9033e-02, -2.7450e-02, -5.9039e-02, 3.0233e-02],
[ 5.5904e-02, 5.2798e-02, -2.2586e-02, 2.8411e-02, -6.8010e-03],
[ 5.1257e-02, -4.3710e-02, 8.7161e-03, 1.9411e-02, -3.5285e-03],
[-8.0450e-02, 6.1012e-02, -7.7756e-02, -2.1472e-02, 4.7537e-02],
[-4.7231e-02, 3.7300e-02, 2.7754e-02, -2.4025e-02, 1.0065e-02]]]])
(bias): Normal:
loc: tensor([0., -0., -0., -0., -0., 0., 0., -0., 0., -0., -0., 0., 0., 0., -0., 0.])
scale: tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([ 0.0215, -0.0800, -0.0787, -0.0173, -0.0345, 0.0684, 0.0584, -0.0804,
0.0098, -0.0490, -0.0535, 0.0145, 0.0056, 0.0082, -0.0256, 0.0140])
)
(observed): Observed()
)
(fc1): Linear(
in_features=400, out_features=120, bias=True
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
...,
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, ..., 0.0498, 0.0498, 0.0498]],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[-0.0489, -0.0457, 0.0358, ..., 0.0488, -0.0310, -0.0318],
[-0.0198, 0.0492, -0.0495, ..., 0.0437, -0.0228, -0.0161],
[ 0.0042, 0.0213, -0.0018, ..., -0.0004, 0.0377, 0.0324],
...,
[ 0.0020, -0.0197, 0.0377, ..., -0.0133, -0.0496, 0.0166],
[ 0.0128, -0.0165, 0.0298, ..., -0.0352, 0.0281, 0.0219],
[ 0.0448, -0.0166, -0.0012, ..., -0.0042, -0.0289, -0.0339]],
requires_grad=True)
tensor: tensor([[-0.0611, -0.0210, -0.0065, ..., 0.1009, -0.0345, -0.0676],
[-0.0331, 0.0645, 0.0254, ..., 0.0572, -0.0789, 0.0253],
[-0.0835, 0.0054, -0.0547, ..., 0.0489, -0.0221, -0.0500],
...,
[-0.0506, -0.0202, 0.0091, ..., -0.0377, -0.0570, 0.0214],
[ 0.0031, 0.0121, 0.0345, ..., -0.0832, 0.0463, 0.1483],
[ 0.0786, -0.0678, 0.0186, ..., -0.0506, -0.0724, -0.0756]],
grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([ 0.0350, 0.0458, -0.0330, 0.0477, 0.0383, -0.0136, -0.0182, 0.0285,
0.0197, -0.0431, 0.0120, -0.0445, 0.0171, -0.0019, -0.0141, -0.0021,
-0.0429, -0.0159, 0.0028, 0.0272, -0.0290, 0.0047, 0.0452, -0.0022,
0.0279, 0.0323, -0.0433, 0.0049, 0.0063, -0.0388, 0.0090, -0.0233,
0.0251, 0.0375, 0.0274, -0.0337, 0.0122, 0.0217, 0.0230, -0.0405,
-0.0476, -0.0063, -0.0021, 0.0267, 0.0014, 0.0228, -0.0130, -0.0471,
-0.0170, -0.0349, -0.0472, 0.0116, 0.0002, -0.0426, -0.0129, 0.0492,
0.0117, -0.0143, -0.0025, 0.0040, -0.0466, -0.0037, 0.0341, -0.0261,
0.0327, -0.0433, 0.0025, 0.0201, 0.0211, -0.0235, 0.0472, -0.0291,
0.0431, -0.0314, -0.0255, 0.0108, -0.0499, -0.0164, -0.0294, -0.0290,
-0.0305, -0.0172, 0.0238, 0.0029, -0.0029, 0.0172, 0.0227, 0.0006,
0.0120, -0.0068, -0.0043, -0.0289, 0.0060, 0.0199, 0.0122, 0.0423,
-0.0015, -0.0034, -0.0201, -0.0374, 0.0159, -0.0258, 0.0075, -0.0097,
-0.0048, 0.0477, -0.0470, 0.0045, 0.0128, -0.0441, 0.0218, 0.0365,
-0.0206, 0.0348, -0.0249, 0.0256, 0.0222, 0.0019, 0.0289, 0.0248],
requires_grad=True)
tensor: tensor([-0.0037, 0.0363, -0.0302, 0.0283, -0.0190, 0.0309, 0.0084, -0.0720,
0.0065, -0.0307, 0.0175, -0.0668, 0.0494, 0.0374, 0.0391, 0.0511,
-0.0019, -0.0578, -0.0897, 0.0769, -0.0574, 0.0433, 0.0668, 0.0505,
-0.0497, 0.0800, -0.1190, 0.0599, 0.0376, -0.0340, 0.0294, -0.0257,
0.0688, 0.0540, 0.0491, -0.0268, 0.1529, -0.0314, 0.0392, -0.1326,
-0.0932, -0.0667, 0.0950, 0.0431, 0.1215, -0.0887, 0.0444, -0.0129,
0.0161, -0.0260, 0.0330, 0.0153, -0.0312, -0.1064, 0.0864, 0.0355,
0.0347, -0.0217, -0.0082, 0.0588, -0.0612, -0.0228, 0.1342, -0.1155,
0.0447, 0.0436, -0.0388, -0.0279, 0.0129, -0.0139, 0.0634, -0.0067,
0.0634, -0.0554, 0.0294, -0.0108, -0.0582, -0.1314, -0.0447, -0.0445,
0.0491, -0.0629, 0.1620, 0.0590, -0.0344, -0.1214, -0.0177, -0.0116,
0.0295, -0.0703, -0.0219, -0.1209, 0.0475, 0.0840, -0.1216, 0.0770,
0.0231, -0.0501, 0.0136, -0.0599, 0.0445, -0.0822, -0.0009, 0.1143,
0.0308, -0.0435, -0.0758, -0.0836, -0.0294, 0.0312, 0.0669, 0.0201,
0.0623, 0.0240, -0.0735, 0.0121, 0.0403, 0.0362, 0.0488, 0.0317],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[-0., -0., 0., ..., 0., -0., -0.],
[-0., 0., -0., ..., 0., -0., -0.],
[0., 0., -0., ..., -0., 0., 0.],
...,
[0., -0., 0., ..., -0., -0., 0.],
[0., -0., 0., ..., -0., 0., 0.],
[0., -0., -0., ..., -0., -0., -0.]])
scale: tensor([[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
...,
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[-0.0489, -0.0457, 0.0358, ..., 0.0488, -0.0310, -0.0318],
[-0.0198, 0.0492, -0.0495, ..., 0.0437, -0.0228, -0.0161],
[ 0.0042, 0.0213, -0.0018, ..., -0.0004, 0.0377, 0.0324],
...,
[ 0.0020, -0.0197, 0.0377, ..., -0.0133, -0.0496, 0.0166],
[ 0.0128, -0.0165, 0.0298, ..., -0.0352, 0.0281, 0.0219],
[ 0.0448, -0.0166, -0.0012, ..., -0.0042, -0.0289, -0.0339]])
(bias): Normal:
loc: tensor([0., 0., -0., 0., 0., -0., -0., 0., 0., -0., 0., -0., 0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., -0.,
0., 0., -0., 0., 0., -0., 0., -0., 0., 0., 0., -0., 0., 0., 0., -0., -0., -0., -0., 0., 0., 0., -0., -0.,
-0., -0., -0., 0., 0., -0., -0., 0., 0., -0., -0., 0., -0., -0., 0., -0., 0., -0., 0., 0., 0., -0., 0., -0.,
0., -0., -0., 0., -0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., 0., 0., -0., -0., -0., 0., 0., 0., 0.,
-0., -0., -0., -0., 0., -0., 0., -0., -0., 0., -0., 0., 0., -0., 0., 0., -0., 0., -0., 0., 0., 0., 0., 0.])
scale: tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([ 0.0350, 0.0458, -0.0330, 0.0477, 0.0383, -0.0136, -0.0182, 0.0285,
0.0197, -0.0431, 0.0120, -0.0445, 0.0171, -0.0019, -0.0141, -0.0021,
-0.0429, -0.0159, 0.0028, 0.0272, -0.0290, 0.0047, 0.0452, -0.0022,
0.0279, 0.0323, -0.0433, 0.0049, 0.0063, -0.0388, 0.0090, -0.0233,
0.0251, 0.0375, 0.0274, -0.0337, 0.0122, 0.0217, 0.0230, -0.0405,
-0.0476, -0.0063, -0.0021, 0.0267, 0.0014, 0.0228, -0.0130, -0.0471,
-0.0170, -0.0349, -0.0472, 0.0116, 0.0002, -0.0426, -0.0129, 0.0492,
0.0117, -0.0143, -0.0025, 0.0040, -0.0466, -0.0037, 0.0341, -0.0261,
0.0327, -0.0433, 0.0025, 0.0201, 0.0211, -0.0235, 0.0472, -0.0291,
0.0431, -0.0314, -0.0255, 0.0108, -0.0499, -0.0164, -0.0294, -0.0290,
-0.0305, -0.0172, 0.0238, 0.0029, -0.0029, 0.0172, 0.0227, 0.0006,
0.0120, -0.0068, -0.0043, -0.0289, 0.0060, 0.0199, 0.0122, 0.0423,
-0.0015, -0.0034, -0.0201, -0.0374, 0.0159, -0.0258, 0.0075, -0.0097,
-0.0048, 0.0477, -0.0470, 0.0045, 0.0128, -0.0441, 0.0218, 0.0365,
-0.0206, 0.0348, -0.0249, 0.0256, 0.0222, 0.0019, 0.0289, 0.0248])
)
(observed): Observed()
)
(fc2): Linear(
in_features=120, out_features=2, bias=True
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498, 0.0498, 0.0498]], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[ 0.0760, -0.0099, -0.0237, -0.0145, -0.0806, 0.0095, -0.0646, 0.0145,
-0.0077, -0.0512, -0.0072, -0.0008, 0.0279, -0.0383, -0.0224, 0.0434,
-0.0912, -0.0509, 0.0756, -0.0889, 0.0356, -0.0862, -0.0046, 0.0507,
-0.0356, -0.0616, -0.0509, -0.0035, 0.0123, 0.0190, -0.0453, 0.0815,
-0.0149, 0.0448, -0.0308, -0.0292, -0.0423, 0.0691, 0.0686, -0.0398,
-0.0657, 0.0157, -0.0508, 0.0847, -0.0897, 0.0655, 0.0407, 0.0535,
-0.0541, -0.0812, 0.0122, -0.0665, -0.0799, 0.0247, -0.0409, 0.0105,
-0.0471, -0.0825, -0.0042, 0.0652, -0.0086, -0.0002, -0.0784, -0.0430,
0.0104, -0.0905, -0.0506, -0.0340, -0.0407, -0.0163, -0.0497, -0.0516,
0.0852, 0.0711, 0.0833, 0.0214, 0.0743, 0.0575, 0.0583, -0.0007,
0.0814, 0.0736, -0.0248, 0.0284, 0.0873, -0.0174, 0.0206, -0.0740,
0.0276, -0.0414, 0.0508, -0.0087, -0.0581, 0.0255, 0.0058, 0.0142,
-0.0266, 0.0067, -0.0468, 0.0654, 0.0305, -0.0043, -0.0613, 0.0733,
0.0400, -0.0446, -0.0243, -0.0434, -0.0616, 0.0371, 0.0253, 0.0681,
0.0847, -0.0068, 0.0176, -0.0169, -0.0387, 0.0219, -0.0046, -0.0663],
[-0.0324, -0.0686, 0.0105, 0.0805, 0.0090, 0.0304, 0.0097, -0.0191,
-0.0591, 0.0876, -0.0748, 0.0383, 0.0680, 0.0441, 0.0479, 0.0484,
0.0302, -0.0039, 0.0855, -0.0066, 0.0661, -0.0492, 0.0843, -0.0566,
0.0517, 0.0880, 0.0308, -0.0874, -0.0144, 0.0143, -0.0663, -0.0484,
-0.0368, 0.0709, 0.0610, 0.0495, -0.0031, -0.0503, 0.0562, -0.0030,
0.0753, 0.0173, 0.0221, -0.0259, 0.0145, 0.0206, -0.0740, 0.0226,
-0.0414, 0.0712, -0.0427, -0.0477, -0.0386, -0.0709, -0.0451, -0.0469,
0.0882, 0.0519, 0.0840, 0.0558, 0.0087, 0.0270, 0.0901, 0.0010,
0.0620, 0.0696, 0.0825, 0.0557, -0.0043, -0.0531, -0.0447, 0.0474,
0.0724, 0.0483, -0.0868, 0.0503, -0.0060, 0.0524, -0.0355, 0.0002,
-0.0195, -0.0888, -0.0211, -0.0551, -0.0292, -0.0041, -0.0416, 0.0861,
0.0530, 0.0840, -0.0316, -0.0839, -0.0451, -0.0664, 0.0725, 0.0301,
0.0456, -0.0145, 0.0455, -0.0850, 0.0010, -0.0722, -0.0800, 0.0512,
-0.0753, -0.0348, -0.0249, 0.0067, 0.0063, -0.0506, -0.0310, 0.0592,
0.0374, -0.0612, -0.0298, -0.0912, -0.0550, -0.0416, -0.0526, -0.0344]],
requires_grad=True)
tensor: tensor([[ 0.1279, -0.0866, -0.0686, 0.0555, -0.0638, 0.0050, -0.0615, -0.0530,
-0.0351, -0.0392, -0.0150, -0.0328, 0.0167, 0.0180, -0.0615, 0.2103,
-0.0772, -0.0698, 0.1113, -0.0363, 0.0195, -0.1669, 0.1130, 0.0848,
-0.0937, -0.1075, -0.0521, 0.0558, -0.0513, -0.0210, 0.0153, 0.0361,
-0.0155, 0.0751, -0.0595, -0.0718, -0.1218, 0.1084, -0.0045, -0.0181,
-0.0626, 0.0720, -0.0465, 0.0618, -0.0506, 0.0861, 0.0443, 0.0697,
-0.0507, -0.0497, -0.0155, -0.0383, -0.0927, 0.0032, 0.0311, 0.0651,
-0.0475, -0.0090, -0.0013, 0.0427, 0.0172, -0.0191, -0.0302, 0.0091,
0.0124, -0.1253, -0.0366, -0.0554, -0.0467, -0.0135, -0.1149, -0.0916,
0.1359, 0.0563, 0.0745, 0.0458, 0.0777, 0.0731, 0.1130, -0.0496,
0.0771, 0.1293, -0.0316, 0.0355, 0.0950, 0.0435, 0.0511, -0.0068,
-0.0193, -0.0411, -0.0321, -0.0270, -0.0606, 0.0143, 0.1399, 0.1026,
-0.0679, 0.0816, -0.0428, 0.1694, -0.0943, -0.0838, -0.0730, 0.0811,
-0.0070, -0.1130, -0.0120, -0.0559, -0.0140, 0.0470, 0.0229, 0.0403,
0.1721, -0.0774, 0.0857, 0.0897, -0.0025, -0.0037, -0.1085, -0.0464],
[-0.0931, -0.0503, 0.0025, 0.0769, 0.0082, -0.0647, 0.0253, 0.0304,
-0.1109, 0.1148, -0.0474, 0.0097, -0.0118, 0.0804, 0.1731, 0.0798,
0.1452, 0.0859, 0.0719, -0.0387, 0.0086, 0.0264, 0.1407, -0.0077,
0.0467, 0.0441, 0.0155, -0.1639, 0.1011, 0.0373, -0.0116, -0.0501,
-0.0860, 0.0579, -0.0236, 0.0648, 0.0227, -0.1241, 0.1053, 0.0215,
0.1890, 0.0845, 0.0857, -0.1136, 0.0768, -0.0305, -0.0581, 0.0310,
-0.0614, 0.1498, -0.0297, 0.0213, -0.0608, -0.0596, -0.0909, 0.0021,
0.1340, 0.0689, 0.0496, 0.0734, -0.0441, -0.0336, 0.1958, -0.0212,
0.1678, 0.0891, 0.1948, 0.0640, -0.0615, 0.0174, -0.0600, 0.0613,
0.0016, 0.1543, -0.1042, 0.0458, 0.0164, 0.1102, -0.0672, -0.0036,
-0.0017, -0.0501, 0.0129, -0.0794, -0.0849, -0.0259, -0.1499, 0.0300,
0.0810, 0.1579, -0.0678, -0.1053, 0.0361, -0.0487, 0.0602, -0.0358,
0.0447, -0.0971, 0.0473, -0.1580, -0.0150, -0.0541, -0.1103, 0.0850,
-0.0783, -0.0095, -0.0338, -0.0945, -0.0547, -0.0184, -0.1288, 0.0039,
-0.0115, -0.0988, -0.0244, -0.1131, 0.0605, -0.0540, -0.0119, -0.0788]],
grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([-0.0576, 0.0882], requires_grad=True)
tensor: tensor([-0.0505, 0.0428], grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[0., -0., -0., -0., -0., 0., -0., 0., -0., -0., -0., -0., 0., -0., -0., 0., -0., -0., 0., -0., 0., -0., -0., 0.,
-0., -0., -0., -0., 0., 0., -0., 0., -0., 0., -0., -0., -0., 0., 0., -0., -0., 0., -0., 0., -0., 0., 0., 0.,
-0., -0., 0., -0., -0., 0., -0., 0., -0., -0., -0., 0., -0., -0., -0., -0., 0., -0., -0., -0., -0., -0., -0., -0.,
0., 0., 0., 0., 0., 0., 0., -0., 0., 0., -0., 0., 0., -0., 0., -0., 0., -0., 0., -0., -0., 0., 0., 0.,
-0., 0., -0., 0., 0., -0., -0., 0., 0., -0., -0., -0., -0., 0., 0., 0., 0., -0., 0., -0., -0., 0., -0., -0.],
[-0., -0., 0., 0., 0., 0., 0., -0., -0., 0., -0., 0., 0., 0., 0., 0., 0., -0., 0., -0., 0., -0., 0., -0.,
0., 0., 0., -0., -0., 0., -0., -0., -0., 0., 0., 0., -0., -0., 0., -0., 0., 0., 0., -0., 0., 0., -0., 0.,
-0., 0., -0., -0., -0., -0., -0., -0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -0., -0., -0., 0.,
0., 0., -0., 0., -0., 0., -0., 0., -0., -0., -0., -0., -0., -0., -0., 0., 0., 0., -0., -0., -0., -0., 0., 0.,
0., -0., 0., -0., 0., -0., -0., 0., -0., -0., -0., 0., 0., -0., -0., 0., 0., -0., -0., -0., -0., -0., -0., -0.]])
scale: tensor([[0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913],
[0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[ 0.0760, -0.0099, -0.0237, -0.0145, -0.0806, 0.0095, -0.0646, 0.0145,
-0.0077, -0.0512, -0.0072, -0.0008, 0.0279, -0.0383, -0.0224, 0.0434,
-0.0912, -0.0509, 0.0756, -0.0889, 0.0356, -0.0862, -0.0046, 0.0507,
-0.0356, -0.0616, -0.0509, -0.0035, 0.0123, 0.0190, -0.0453, 0.0815,
-0.0149, 0.0448, -0.0308, -0.0292, -0.0423, 0.0691, 0.0686, -0.0398,
-0.0657, 0.0157, -0.0508, 0.0847, -0.0897, 0.0655, 0.0407, 0.0535,
-0.0541, -0.0812, 0.0122, -0.0665, -0.0799, 0.0247, -0.0409, 0.0105,
-0.0471, -0.0825, -0.0042, 0.0652, -0.0086, -0.0002, -0.0784, -0.0430,
0.0104, -0.0905, -0.0506, -0.0340, -0.0407, -0.0163, -0.0497, -0.0516,
0.0852, 0.0711, 0.0833, 0.0214, 0.0743, 0.0575, 0.0583, -0.0007,
0.0814, 0.0736, -0.0248, 0.0284, 0.0873, -0.0174, 0.0206, -0.0740,
0.0276, -0.0414, 0.0508, -0.0087, -0.0581, 0.0255, 0.0058, 0.0142,
-0.0266, 0.0067, -0.0468, 0.0654, 0.0305, -0.0043, -0.0613, 0.0733,
0.0400, -0.0446, -0.0243, -0.0434, -0.0616, 0.0371, 0.0253, 0.0681,
0.0847, -0.0068, 0.0176, -0.0169, -0.0387, 0.0219, -0.0046, -0.0663],
[-0.0324, -0.0686, 0.0105, 0.0805, 0.0090, 0.0304, 0.0097, -0.0191,
-0.0591, 0.0876, -0.0748, 0.0383, 0.0680, 0.0441, 0.0479, 0.0484,
0.0302, -0.0039, 0.0855, -0.0066, 0.0661, -0.0492, 0.0843, -0.0566,
0.0517, 0.0880, 0.0308, -0.0874, -0.0144, 0.0143, -0.0663, -0.0484,
-0.0368, 0.0709, 0.0610, 0.0495, -0.0031, -0.0503, 0.0562, -0.0030,
0.0753, 0.0173, 0.0221, -0.0259, 0.0145, 0.0206, -0.0740, 0.0226,
-0.0414, 0.0712, -0.0427, -0.0477, -0.0386, -0.0709, -0.0451, -0.0469,
0.0882, 0.0519, 0.0840, 0.0558, 0.0087, 0.0270, 0.0901, 0.0010,
0.0620, 0.0696, 0.0825, 0.0557, -0.0043, -0.0531, -0.0447, 0.0474,
0.0724, 0.0483, -0.0868, 0.0503, -0.0060, 0.0524, -0.0355, 0.0002,
-0.0195, -0.0888, -0.0211, -0.0551, -0.0292, -0.0041, -0.0416, 0.0861,
0.0530, 0.0840, -0.0316, -0.0839, -0.0451, -0.0664, 0.0725, 0.0301,
0.0456, -0.0145, 0.0455, -0.0850, 0.0010, -0.0722, -0.0800, 0.0512,
-0.0753, -0.0348, -0.0249, 0.0067, 0.0063, -0.0506, -0.0310, 0.0592,
0.0374, -0.0612, -0.0298, -0.0912, -0.0550, -0.0416, -0.0526, -0.0344]])
(bias): Normal:
loc: tensor([-0., 0.])
scale: tensor([0.7071, 0.7071])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([-0.0576, 0.0882])
)
(observed): Observed()
)
)
Fit the model¶
Finally, we can set up the training loop.
# Variational-inference training loop (a single epoch here; raise the
# range for real training).
optim = torch.optim.Adam(net.parameters())
# Same value as the original per-iteration `1 / len(loader)`; presumably it
# scales the KL term per batch so that one epoch counts it once in total.
kl_scaling = 1 / len(loader)
for i in range(1):
    for data, target in loader:
        # Condition the `classification` likelihood on the observed labels.
        net.observe(classification=target)
        # Resample the model's random variables before the forward pass.
        borch.sample(net)
        net(data)
        pq = borch.pq_to_infer(net)
        loss = infer.vi_loss(**pq, kl_scaling=kl_scaling)
        loss.backward()
        optim.step()
        optim.zero_grad()
Now we can check the accuracy. Note that one should stop conditioning on the target by calling net.observe(None).
# Stop conditioning on the labels before evaluating.
net.observe(None)
# Accumulate correct predictions and sample counts over the whole loader.
# Averaging per-batch accuracies would over-weight a smaller final batch,
# and dividing by a leaked loop index breaks on an empty loader.
n_correct = 0
n_seen = 0
with torch.no_grad():
    for data, target in loader:
        borch.sample(net)
        # NOTE(review): assumes net(data) returns predicted class indices
        # comparable to `target` (the sampled Categorical) - confirm.
        out = net(data)
        n_correct += int((target == out).sum())
        n_seen += target.shape[0]
# Percentage of correctly classified samples (0.0 if the loader is empty).
tot_acc = 100 * n_correct / n_seen if n_seen else 0.0
print(tot_acc)
Out:
69.9999988079071
The accuracy is essentially random; this is expected, because we are fitting white noise.
If you have trouble reaching a higher accuracy, consider running for more epochs, setting up an augmentation pipeline (see the data loading tutorial), or changing the posterior. The posterior can be changed using
# `Automatic` is the posterior Net was built with; pass a different
# borch.posterior class to set_posteriors to swap it out. net.apply walks the
# module tree (as torch.nn.Module.apply does) and calls the setter on each
# module.
posterior_setter = borch.set_posteriors(borch.posterior.Automatic)
net.apply(posterior_setter)
Out:
Net(
(posterior): Automatic()
(prior): Module(
(classification): Categorical:
logits: tensor([[-0.1017, -0.1535],
[ 0.1836, -0.1202],
[ 0.2641, -0.0718],
[ 0.1435, 0.0717],
[ 0.1556, 0.0155],
[ 0.0704, 0.0826],
[-0.0427, -0.1238],
[-0.1261, -0.1772],
[ 0.0826, 0.0842],
[ 0.0605, 0.0630],
[ 0.1126, -0.0491],
[ 0.1652, -0.0307],
[ 0.2373, -0.1125],
[ 0.2617, 0.0919],
[ 0.0945, -0.0213],
[-0.1082, -0.0242],
[ 0.2161, 0.0436],
[ 0.2820, -0.0276],
[ 0.1257, -0.0373],
[ 0.2219, 0.0406]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([])
)
(observed): Observed()
(conv1): Conv2d(
1, 6, kernel_size=(5, 5), stride=(1, 1)
(posterior): Automatic(
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([-0., -0., 0., -0., 0., 0.], requires_grad=True)
tensor: tensor([-0.5569, -0.0093, -0.1423, -0.3891, -0.2014, -0.4750],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[[[-0., 0., -0., 0., 0.],
[-0., -0., 0., 0., -0.],
[0., 0., 0., -0., -0.],
[0., -0., -0., 0., -0.],
[-0., -0., 0., -0., -0.]]],
[[[-0., 0., -0., -0., 0.],
[0., 0., -0., -0., -0.],
[-0., -0., 0., -0., 0.],
[-0., 0., 0., -0., -0.],
[0., 0., 0., -0., -0.]]],
[[[-0., -0., 0., 0., -0.],
[-0., -0., 0., 0., -0.],
[-0., -0., -0., 0., 0.],
[-0., -0., 0., -0., 0.],
[0., 0., -0., 0., -0.]]],
[[[0., -0., -0., 0., 0.],
[-0., 0., -0., -0., 0.],
[-0., -0., 0., -0., -0.],
[0., 0., 0., 0., -0.],
[-0., -0., 0., 0., 0.]]],
[[[-0., 0., 0., -0., -0.],
[0., 0., -0., -0., -0.],
[0., -0., -0., 0., 0.],
[0., 0., -0., -0., -0.],
[-0., -0., -0., -0., -0.]]],
[[[-0., -0., -0., -0., -0.],
[-0., 0., 0., 0., -0.],
[-0., -0., 0., 0., 0.],
[-0., 0., -0., -0., 0.],
[0., -0., 0., -0., 0.]]]])
scale: tensor([[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],
[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
[0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[[[-2.3901e-02, 1.7009e-01, -4.0790e-02, 1.9047e-01, 1.1854e-01],
[-1.0204e-01, -1.4715e-01, 9.0185e-03, 8.0316e-02, -1.4466e-01],
[ 1.7187e-01, 8.4639e-02, 1.6665e-01, -1.5608e-01, -9.5238e-02],
[ 1.9682e-01, -9.3068e-02, -1.9482e-02, 1.6220e-01, -5.5923e-02],
[-1.5368e-01, -1.2494e-01, 3.1866e-02, -1.7428e-01, -6.5961e-02]]],
[[[-5.4801e-02, 5.0452e-02, -1.5372e-01, -1.1482e-01, 1.5138e-01],
[ 1.8247e-03, 4.6463e-02, -1.7931e-01, -8.4841e-02, -6.3566e-02],
[-1.9791e-02, -1.0920e-01, 1.2796e-01, -6.0495e-02, 1.2142e-01],
[-1.0610e-01, 2.8335e-02, 2.0862e-02, -1.2132e-01, -5.6049e-03],
[ 2.7007e-02, 1.5627e-01, 7.0422e-02, -2.1336e-03, -5.9226e-02]]],
[[[-1.6497e-01, -9.5347e-02, 9.7235e-02, 1.7565e-01, -1.4118e-01],
[-1.1203e-02, -5.6668e-02, 9.0249e-02, 1.9961e-01, -2.0049e-02],
[-4.5493e-02, -2.0235e-02, -1.9463e-01, 1.5131e-01, 1.6076e-01],
[-1.9071e-01, -1.6333e-01, 1.0380e-01, -7.2042e-02, 1.0249e-01],
[ 1.7660e-01, 1.8708e-02, -1.0379e-01, 8.4113e-02, -1.3492e-01]]],
[[[ 1.3284e-01, -1.7679e-02, -9.9538e-02, 1.5133e-01, 1.0864e-01],
[-1.9522e-01, 1.0066e-01, -1.0742e-01, -1.1599e-01, 1.6930e-01],
[-1.0281e-01, -1.4473e-01, 1.6300e-01, -7.4540e-02, -6.5797e-02],
[ 1.5015e-01, 7.7701e-03, 7.3404e-02, 9.5653e-02, -1.2661e-01],
[-3.2228e-02, -7.9872e-02, 1.9932e-01, 6.0159e-02, 1.3894e-01]]],
[[[-1.7235e-01, 5.8651e-06, 8.9371e-02, -1.5355e-01, -1.2702e-01],
[ 5.9223e-02, 1.1539e-01, -5.5243e-03, -6.4484e-02, -1.3380e-01],
[ 9.6366e-03, -1.0979e-01, -1.1570e-01, 7.1673e-03, 8.9918e-02],
[ 2.4720e-02, 5.8142e-02, -1.0872e-01, -1.4363e-01, -1.1776e-01],
[-8.8460e-02, -1.7740e-01, -7.1380e-02, -1.1692e-01, -1.7076e-02]]],
[[[-1.8614e-01, -1.2378e-01, -1.3271e-01, -1.5860e-02, -9.4571e-02],
[-7.8788e-03, 4.7546e-02, 1.4185e-01, 8.6187e-02, -1.0654e-01],
[-6.4892e-02, -7.4628e-02, 5.9973e-02, 5.0245e-02, 1.1456e-01],
[-4.4647e-02, 9.8620e-02, -1.2445e-01, -6.6966e-02, 3.5321e-02],
[ 1.9966e-01, -3.8652e-03, 3.5615e-03, -1.3894e-02, 3.5925e-02]]]])
(bias): Normal:
loc: tensor([-0., -0., 0., -0., 0., 0.])
scale: tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([-0.1531, -0.0004, 0.1924, -0.0493, 0.0953, 0.0265])
)
(observed): Observed()
)
(conv2): Conv2d(
6, 16, kernel_size=(5, 5), stride=(1, 1)
(posterior): Automatic(
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([0., -0., -0., -0., -0., 0., 0., -0., 0., -0., -0., 0., 0., 0., -0., 0.],
requires_grad=True)
tensor: tensor([ 0.3142, 0.4254, -0.2881, -0.2760, 0.0442, 0.1406, -0.1709, -0.1576,
-0.0890, 0.1179, 0.0289, 0.0774, -0.6501, -0.1156, -0.0588, 0.1319],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[[[-0., 0., 0., 0., 0.],
[-0., 0., -0., 0., 0.],
[0., -0., -0., -0., -0.],
[-0., 0., 0., 0., 0.],
[0., -0., 0., -0., -0.]],
[[-0., -0., 0., 0., -0.],
[0., 0., -0., 0., -0.],
[0., -0., 0., -0., 0.],
[-0., -0., 0., 0., 0.],
[0., 0., 0., 0., -0.]],
[[-0., -0., 0., 0., 0.],
[0., -0., 0., -0., -0.],
[0., 0., 0., -0., -0.],
[-0., -0., -0., -0., -0.],
[0., 0., -0., 0., 0.]],
[[-0., -0., 0., -0., 0.],
[-0., -0., -0., -0., 0.],
[0., 0., 0., 0., -0.],
[-0., -0., -0., 0., 0.],
[-0., 0., -0., 0., -0.]],
[[0., 0., -0., -0., -0.],
[0., 0., 0., -0., 0.],
[-0., 0., 0., -0., 0.],
[-0., -0., -0., 0., 0.],
[0., 0., -0., -0., -0.]],
[[-0., -0., -0., 0., -0.],
[-0., 0., -0., -0., -0.],
[0., -0., 0., -0., 0.],
[0., 0., -0., -0., 0.],
[-0., -0., 0., -0., -0.]]],
[[[-0., 0., -0., 0., -0.],
[0., -0., -0., -0., 0.],
[-0., 0., 0., 0., -0.],
[-0., 0., 0., 0., -0.],
[-0., 0., 0., 0., -0.]],
[[-0., 0., 0., -0., 0.],
[0., -0., 0., 0., -0.],
[0., 0., -0., -0., 0.],
[0., -0., -0., -0., 0.],
[0., -0., -0., -0., 0.]],
[[0., -0., 0., 0., 0.],
[0., -0., 0., 0., 0.],
[0., -0., -0., 0., -0.],
[-0., -0., 0., 0., 0.],
[0., 0., 0., 0., 0.]],
[[0., 0., 0., -0., 0.],
[-0., 0., -0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[-0., 0., -0., 0., 0.]],
[[0., -0., 0., 0., -0.],
[-0., 0., -0., 0., -0.],
[-0., 0., -0., 0., 0.],
[0., -0., 0., 0., 0.],
[0., 0., -0., 0., 0.]],
[[-0., -0., -0., -0., -0.],
[-0., 0., 0., 0., 0.],
[-0., 0., 0., 0., 0.],
[-0., 0., -0., -0., 0.],
[-0., -0., 0., -0., 0.]]],
[[[-0., -0., 0., 0., -0.],
[0., -0., -0., 0., -0.],
[0., 0., -0., -0., 0.],
[0., 0., -0., 0., 0.],
[-0., -0., -0., 0., 0.]],
[[-0., -0., -0., 0., -0.],
[-0., -0., 0., -0., -0.],
[0., 0., -0., -0., -0.],
[-0., 0., 0., -0., -0.],
[0., -0., -0., -0., -0.]],
[[-0., 0., -0., 0., -0.],
[-0., 0., -0., -0., 0.],
[0., 0., -0., -0., -0.],
[-0., -0., -0., 0., -0.],
[-0., -0., 0., 0., 0.]],
[[0., -0., 0., 0., 0.],
[0., -0., 0., 0., 0.],
[-0., 0., -0., 0., 0.],
[-0., -0., -0., 0., 0.],
[-0., 0., -0., 0., 0.]],
[[-0., -0., -0., 0., -0.],
[-0., -0., -0., 0., -0.],
[0., 0., 0., -0., 0.],
[0., 0., -0., 0., 0.],
[-0., -0., 0., -0., -0.]],
[[0., -0., 0., -0., 0.],
[0., -0., 0., 0., 0.],
[-0., 0., -0., -0., -0.],
[0., 0., 0., 0., 0.],
[0., -0., 0., -0., 0.]]],
...,
[[[0., 0., -0., 0., -0.],
[-0., 0., -0., 0., -0.],
[0., -0., -0., 0., 0.],
[0., -0., -0., 0., 0.],
[0., 0., 0., -0., -0.]],
[[0., 0., -0., -0., -0.],
[-0., -0., -0., -0., -0.],
[-0., 0., -0., -0., 0.],
[0., 0., 0., 0., 0.],
[0., -0., 0., 0., 0.]],
[[-0., -0., -0., -0., -0.],
[-0., -0., -0., 0., 0.],
[-0., 0., -0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., -0., -0., -0., -0.]],
[[0., 0., 0., 0., -0.],
[0., 0., -0., -0., 0.],
[0., 0., 0., 0., 0.],
[0., -0., -0., -0., -0.],
[-0., 0., -0., -0., 0.]],
[[-0., -0., -0., 0., 0.],
[-0., 0., 0., -0., 0.],
[-0., -0., -0., 0., -0.],
[-0., 0., -0., -0., -0.],
[0., 0., -0., 0., -0.]],
[[-0., 0., -0., 0., -0.],
[0., 0., -0., 0., 0.],
[-0., -0., -0., 0., -0.],
[0., 0., 0., 0., 0.],
[0., 0., -0., 0., -0.]]],
[[[-0., -0., 0., 0., 0.],
[0., -0., -0., -0., 0.],
[-0., -0., -0., 0., 0.],
[0., 0., -0., -0., -0.],
[-0., -0., 0., 0., -0.]],
[[-0., 0., 0., 0., 0.],
[-0., 0., 0., -0., 0.],
[-0., -0., -0., 0., -0.],
[0., -0., -0., -0., 0.],
[0., -0., -0., 0., -0.]],
[[0., -0., -0., -0., -0.],
[-0., -0., -0., -0., -0.],
[-0., 0., 0., -0., 0.],
[-0., 0., 0., -0., 0.],
[0., 0., 0., -0., -0.]],
[[-0., -0., -0., -0., -0.],
[0., 0., 0., -0., -0.],
[-0., 0., 0., 0., -0.],
[-0., 0., -0., 0., 0.],
[-0., 0., -0., 0., -0.]],
[[0., 0., -0., -0., -0.],
[-0., -0., 0., -0., -0.],
[0., 0., -0., -0., 0.],
[0., 0., -0., 0., 0.],
[-0., 0., 0., 0., -0.]],
[[0., -0., 0., -0., -0.],
[0., 0., -0., -0., -0.],
[0., 0., -0., 0., 0.],
[-0., 0., -0., 0., 0.],
[-0., -0., -0., -0., -0.]]],
[[[-0., 0., 0., 0., 0.],
[-0., 0., -0., -0., -0.],
[0., 0., -0., -0., -0.],
[0., 0., -0., -0., -0.],
[-0., -0., 0., 0., -0.]],
[[-0., -0., 0., -0., 0.],
[-0., -0., -0., 0., 0.],
[0., 0., 0., 0., -0.],
[0., -0., 0., -0., 0.],
[0., -0., 0., -0., -0.]],
[[0., -0., 0., -0., -0.],
[0., 0., 0., -0., 0.],
[0., 0., 0., -0., -0.],
[-0., 0., -0., 0., 0.],
[-0., -0., 0., -0., 0.]],
[[-0., -0., 0., -0., -0.],
[0., -0., 0., -0., -0.],
[0., -0., -0., -0., -0.],
[0., -0., -0., 0., -0.],
[-0., 0., -0., 0., 0.]],
[[-0., 0., -0., -0., 0.],
[-0., 0., -0., 0., -0.],
[-0., 0., -0., -0., -0.],
[0., -0., -0., 0., -0.],
[-0., -0., -0., 0., -0.]],
[[0., 0., -0., -0., 0.],
[0., 0., -0., 0., -0.],
[0., -0., 0., 0., -0.],
[-0., 0., -0., -0., 0.],
[-0., 0., 0., -0., 0.]]]])
scale: tensor([[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
...,
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],
[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],
[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
[0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[[[-2.9409e-02, 6.0003e-02, 5.3245e-02, 2.8255e-02, 5.1865e-02],
[-1.3120e-03, 6.9924e-02, -2.1958e-02, 4.5196e-02, 5.5987e-02],
[ 7.2948e-02, -1.2874e-02, -1.4010e-02, -7.2656e-02, -3.8272e-02],
[-6.1391e-02, 2.6227e-02, 2.2223e-02, 5.1530e-02, 7.8572e-02],
[ 1.1986e-02, -3.5154e-02, 2.2590e-02, -2.8268e-02, -4.1092e-02]],
[[-7.7353e-02, -3.3001e-03, 2.2446e-02, 3.6329e-02, -8.4312e-03],
[ 5.9898e-02, 3.3921e-02, -4.0756e-02, 5.6990e-02, -7.9915e-02],
[ 1.1058e-02, -7.5468e-02, 3.1362e-02, -1.2730e-02, 7.0498e-02],
[-6.7726e-03, -3.9453e-02, 6.8676e-02, 2.8066e-02, 4.1992e-02],
[ 8.0445e-02, 6.9594e-02, 2.9287e-02, 4.4932e-02, -7.2432e-02]],
[[-6.8457e-02, -7.5792e-02, 3.8843e-02, 3.2058e-02, 2.1625e-02],
[ 7.5601e-02, -5.7455e-02, 1.2939e-02, -3.5165e-02, -5.6556e-02],
[ 8.3346e-03, 2.2498e-02, 6.8399e-03, -2.3131e-02, -2.1235e-03],
[-6.8738e-02, -6.5464e-02, -4.5512e-02, -2.1878e-02, -6.3732e-02],
[ 4.4843e-02, 1.4243e-02, -7.8475e-02, 2.0007e-02, 6.8106e-02]],
[[-7.1642e-03, -4.0418e-02, 1.6455e-03, -7.4808e-03, 3.0932e-02],
[-5.6909e-02, -6.4950e-02, -4.3916e-02, -6.4650e-02, 7.4918e-02],
[ 2.7178e-03, 7.6559e-02, 2.3546e-02, 1.0888e-02, -3.2943e-02],
[-3.1639e-02, -4.4183e-02, -1.1435e-02, 1.8024e-02, 7.9902e-02],
[-1.6372e-02, 6.6337e-03, -5.5328e-02, 1.6804e-02, -1.1563e-02]],
[[ 3.4681e-02, 3.3633e-02, -5.8334e-02, -4.5784e-02, -2.0550e-02],
[ 8.0547e-02, 4.9685e-02, 7.5298e-02, -2.9211e-02, 2.2224e-02],
[-2.4409e-02, 7.0003e-02, 2.1558e-04, -4.4977e-02, 2.3133e-02],
[-7.4964e-02, -6.4867e-02, -4.5671e-02, 3.1878e-02, 3.5104e-02],
[ 2.2541e-02, 3.2949e-02, -3.5216e-02, -4.0980e-02, -1.9989e-02]],
[[-7.0267e-02, -2.8274e-03, -3.7755e-02, 3.8392e-02, -1.3160e-05],
[-2.4893e-02, 2.6938e-02, -6.6667e-02, -6.4968e-02, -7.5396e-02],
[ 4.5711e-02, -3.3377e-02, 2.4278e-02, -4.1101e-03, 1.8952e-02],
[ 2.0236e-02, 6.4522e-03, -5.7711e-02, -4.5909e-02, 7.4404e-02],
[-3.7771e-02, -3.1536e-03, 6.4748e-02, -4.4604e-02, -2.2752e-02]]],
[[[-2.9177e-02, 5.0791e-02, -5.5935e-03, 4.7169e-02, -1.1366e-02],
[ 2.2291e-02, -2.1153e-02, -1.4060e-02, -5.3690e-02, 6.6127e-02],
[-1.1869e-03, 7.9363e-03, 6.0124e-02, 8.7514e-04, -2.6833e-02],
[-4.4174e-02, 1.9975e-02, 2.7939e-02, 2.7791e-02, -1.7919e-02],
[-2.9191e-02, 2.5500e-02, 3.6869e-03, 3.1063e-02, -3.7672e-02]],
[[-6.7084e-02, 5.3586e-02, 1.8485e-02, -7.8099e-03, 4.6288e-02],
[ 4.9055e-02, -1.4634e-02, 3.2799e-02, 3.9726e-02, -7.0032e-02],
[ 7.0897e-02, 3.4370e-02, -1.4814e-02, -3.9030e-02, 3.0867e-02],
[ 3.5541e-02, -7.2574e-02, -1.5650e-03, -8.1162e-02, 2.6245e-02],
[ 5.5721e-02, -2.2033e-02, -7.2623e-02, -5.1459e-02, 3.3599e-02]],
[[ 1.1310e-02, -2.9816e-02, 6.3727e-02, 3.9850e-02, 1.3761e-02],
[ 3.0453e-02, -4.8504e-02, 5.3189e-02, 1.9425e-02, 4.7484e-02],
[ 6.1376e-02, -7.8290e-02, -6.8859e-02, 1.8497e-02, -1.1496e-02],
[-7.8178e-02, -4.5904e-02, 7.3181e-02, 2.9441e-02, 4.6967e-02],
[ 7.6978e-02, 7.2934e-02, 5.6798e-02, 5.8828e-02, 4.4637e-02]],
[[ 6.0281e-02, 7.7289e-02, 7.9016e-02, -4.1437e-02, 3.1101e-02],
[-5.0620e-02, 3.3108e-02, -5.8687e-02, 2.7694e-02, 5.4294e-02],
[ 2.1156e-02, 1.7004e-03, 2.4742e-02, 6.9593e-02, 5.7699e-02],
[ 6.8876e-02, 3.2239e-02, 3.3322e-02, 2.9973e-02, 7.4267e-02],
[-2.4736e-03, 3.8454e-02, -2.5898e-02, 2.0443e-02, 6.0816e-02]],
[[ 1.6487e-02, -6.8994e-03, 5.2835e-02, 5.7784e-02, -1.7036e-02],
[-7.1883e-02, 3.1576e-04, -5.5839e-02, 1.4949e-02, -1.9834e-03],
[-2.7395e-02, 3.9861e-04, -1.0588e-02, 9.9140e-03, 5.1499e-02],
[ 7.4060e-02, -1.0655e-02, 1.1668e-02, 4.9183e-02, 5.2846e-02],
[ 2.8634e-02, 4.2678e-02, -1.4281e-02, 1.3904e-03, 7.6289e-02]],
[[-5.1256e-02, -2.2514e-02, -7.2964e-02, -4.4120e-02, -5.8914e-02],
[-4.1579e-02, 2.8281e-02, 3.9429e-02, 7.5058e-03, 6.5170e-03],
[-3.4494e-02, 7.5710e-02, 4.1078e-02, 4.4451e-02, 4.2661e-02],
[-5.4398e-02, 5.1592e-02, -2.6367e-02, -3.2980e-02, 5.3860e-02],
[-4.2436e-02, -8.4286e-03, 7.5331e-02, -6.6725e-02, 4.9887e-02]]],
[[[-1.5211e-02, -6.2506e-03, 3.0621e-03, 3.0725e-02, -7.0877e-03],
[ 1.1974e-02, -5.2611e-02, -2.7415e-02, 4.3479e-02, -4.2108e-02],
[ 3.3816e-02, 6.1523e-02, -9.9011e-03, -3.7770e-02, 6.5915e-04],
[ 5.3678e-03, 5.9921e-02, -3.4530e-02, 5.1942e-02, 5.3762e-02],
[-4.7293e-02, -6.2274e-02, -7.5059e-02, 8.1645e-02, 2.1149e-02]],
[[-1.4459e-02, -2.7155e-02, -2.5730e-02, 7.6751e-02, -1.6932e-02],
[-5.3342e-02, -2.6885e-02, 4.3476e-02, -7.9174e-02, -3.5761e-02],
[ 4.2970e-02, 2.5516e-02, -6.6640e-02, -2.9457e-03, -8.2757e-03],
[-2.5080e-02, 4.1672e-02, 4.2424e-02, -4.8704e-02, -6.0434e-02],
[ 1.2884e-02, -7.9950e-02, -7.0913e-02, -8.0863e-02, -5.4536e-02]],
[[-5.4303e-02, 6.7885e-02, -5.3922e-02, 6.5582e-02, -5.2617e-03],
[-8.4440e-03, 8.0911e-02, -3.8667e-02, -5.6241e-04, 7.0876e-02],
[ 5.4673e-02, 1.3465e-02, -2.7178e-02, -3.6691e-02, -2.6519e-02],
[-2.8238e-02, -5.0765e-02, -3.4076e-02, 3.1219e-02, -1.8919e-02],
[-4.6076e-02, -7.6516e-02, 3.1247e-02, 4.1743e-02, 7.5575e-02]],
[[ 3.8787e-03, -8.1731e-03, 4.7381e-03, 5.8261e-02, 4.6416e-02],
[ 7.7171e-02, -6.5924e-02, 1.5769e-02, 2.6777e-02, 7.7365e-02],
[-7.3126e-02, 4.7624e-02, -7.0620e-02, 7.3309e-02, 7.7585e-03],
[-7.9208e-02, -7.3783e-03, -5.3142e-02, 3.4386e-02, 5.6230e-03],
[-4.4492e-02, 7.5403e-02, -2.8887e-02, 2.6937e-02, 7.0698e-03]],
[[-1.7104e-02, -6.6983e-02, -5.7655e-02, 5.1198e-02, -8.0137e-02],
[-7.2406e-02, -3.7856e-02, -7.2086e-02, 7.2641e-02, -7.0749e-02],
[ 6.6330e-02, 2.8797e-02, 6.2197e-02, -4.6068e-03, 1.7392e-03],
[ 4.2384e-02, 5.9572e-02, -4.5953e-02, 6.6345e-03, 7.3979e-02],
[-4.8313e-02, -1.3063e-02, 1.7648e-02, -5.0903e-02, -7.1852e-02]],
[[ 3.2147e-02, -8.1585e-02, 6.7507e-02, -7.7056e-02, 1.7667e-02],
[ 2.0535e-03, -7.4221e-02, 1.0343e-02, 4.3018e-02, 9.7351e-03],
[-2.1410e-02, 5.4089e-02, -2.4102e-02, -4.0551e-02, -3.6118e-03],
[ 4.9847e-02, 6.9608e-02, 3.6233e-03, 5.7025e-02, 6.3206e-02],
[ 1.4611e-02, -2.9885e-02, 5.6140e-02, -6.4338e-02, 8.5266e-03]]],
...,
[[[ 5.7051e-02, 3.4026e-02, -3.7723e-02, 1.4372e-02, -4.4266e-03],
[-8.0557e-02, 1.1810e-02, -6.9374e-02, 3.4264e-02, -3.9068e-02],
[ 3.2814e-02, -4.9334e-02, -3.2234e-02, 3.7901e-02, 9.9268e-03],
[ 1.2846e-03, -5.9199e-02, -5.6303e-02, 1.2189e-03, 7.8874e-02],
[ 7.6858e-04, 2.4341e-02, 4.0423e-02, -7.7602e-02, -3.6388e-02]],
[[ 5.9494e-02, 4.4230e-02, -5.9128e-02, -1.6639e-02, -6.4884e-02],
[-2.3457e-02, -7.5842e-03, -3.3986e-02, -2.0435e-02, -4.2466e-02],
[-6.8915e-02, 1.6417e-02, -9.0384e-03, -5.6058e-02, 1.1540e-02],
[ 3.9632e-02, 3.8881e-02, 5.5834e-02, 7.5591e-02, 1.8463e-02],
[ 4.5034e-02, -6.4665e-02, 6.7883e-02, 7.1108e-02, 8.0694e-02]],
[[-7.4652e-02, -2.9270e-02, -7.4301e-02, -1.4067e-02, -6.0331e-02],
[-7.9629e-02, -2.5316e-03, -3.4649e-02, 7.9736e-02, 2.4963e-02],
[-1.4102e-02, 3.0896e-02, -5.4594e-02, 5.7641e-02, 7.8276e-02],
[ 3.3722e-02, 1.6397e-02, 6.6251e-02, 2.5637e-02, 4.0073e-02],
[ 1.9370e-02, -1.4960e-02, -4.0503e-02, -3.6491e-02, -6.9970e-02]],
[[ 3.9434e-02, 6.7049e-02, 7.1627e-02, 6.9307e-02, -5.7508e-03],
[ 2.8151e-02, 7.9890e-02, -6.4687e-02, -6.8959e-02, 6.8179e-02],
[ 1.2583e-02, 6.6052e-02, 6.7770e-02, 1.0853e-02, 6.3935e-02],
[ 4.4214e-02, -5.4527e-02, -6.3199e-02, -2.4454e-02, -8.0348e-02],
[-1.1810e-04, 6.2292e-02, -2.1831e-02, -4.1282e-02, 3.4718e-02]],
[[-8.9495e-03, -3.5923e-02, -4.9030e-02, 1.7068e-02, 5.7835e-02],
[-6.2950e-02, 6.9258e-02, 1.4909e-02, -3.9252e-02, 3.0917e-02],
[-5.0831e-02, -2.6109e-02, -4.2526e-02, 4.9180e-03, -6.7907e-02],
[-1.4867e-02, 8.3498e-03, -6.3780e-02, -6.3819e-02, -7.7414e-02],
[ 6.5369e-02, 3.5118e-02, -3.5070e-02, 3.1514e-02, -1.7773e-02]],
[[-1.9000e-02, 4.8772e-02, -4.0550e-02, 5.7766e-02, -4.8687e-02],
[ 7.0112e-02, 7.4851e-02, -5.0324e-02, 4.2522e-02, 6.6367e-02],
[-6.6793e-02, -6.3487e-02, -6.3574e-02, 7.3530e-02, -6.7062e-02],
[ 1.9297e-02, 3.9876e-02, 7.0333e-03, 3.6541e-02, 3.0865e-02],
[ 6.9009e-02, 2.7737e-03, -6.0400e-02, 1.5249e-03, -1.5177e-03]]],
[[[-4.9098e-02, -1.2656e-02, 3.0326e-02, 4.6450e-02, 4.1143e-02],
[ 6.5180e-02, -4.5543e-02, -6.0194e-02, -8.1101e-02, 7.3691e-02],
[-5.2880e-02, -5.3283e-02, -4.6874e-02, 2.0506e-02, 1.4432e-02],
[ 5.3466e-05, 6.1875e-02, -5.2208e-02, -2.1149e-02, -6.5709e-02],
[-7.2209e-02, -2.8706e-02, 6.6109e-02, 5.8108e-02, -1.8114e-02]],
[[-5.8877e-02, 3.5183e-02, 6.5460e-02, 5.2934e-02, 3.5997e-02],
[-6.5718e-02, 2.7700e-02, 7.1110e-02, -5.7825e-02, 6.1866e-03],
[-5.3281e-03, -7.6189e-02, -6.9421e-02, 6.4743e-02, -1.1912e-02],
[ 7.6864e-02, -5.8819e-03, -2.0277e-02, -1.6263e-02, 4.5729e-02],
[ 2.0473e-03, -3.1893e-02, -3.0088e-02, 6.1322e-02, -1.3287e-02]],
[[ 4.0654e-02, -3.8251e-03, -5.8287e-02, -6.9760e-03, -4.9954e-02],
[-3.1949e-02, -6.5679e-02, -9.8746e-04, -5.5646e-02, -1.6937e-03],
[-5.0579e-02, 5.1921e-02, 4.0006e-02, -5.3846e-02, 3.6710e-03],
[-5.5284e-03, 5.2453e-02, 3.5617e-02, -4.4475e-02, 2.7835e-02],
[ 3.6465e-02, 2.2936e-02, 4.9494e-02, -6.8768e-02, -6.8512e-02]],
[[-1.5606e-02, -5.8101e-02, -4.8349e-02, -5.4572e-03, -8.1381e-02],
[ 3.3837e-02, 6.9886e-02, 2.5937e-03, -4.4428e-02, -6.1442e-03],
[-3.3799e-02, 7.6725e-02, 1.5202e-02, 2.7467e-02, -7.2112e-02],
[-5.3887e-02, 5.3134e-02, -5.5426e-02, 8.1476e-02, 1.0773e-02],
[-1.9578e-02, 1.7628e-02, -2.2382e-02, 6.7076e-02, -1.3475e-02]],
[[ 3.7281e-02, 2.7106e-02, -7.8289e-03, -6.1201e-02, -4.5366e-02],
[-5.1809e-02, -1.0889e-02, 4.4019e-02, -4.0099e-02, -6.2939e-02],
[ 7.8826e-02, 1.4336e-02, -7.8953e-02, -4.1699e-03, 2.1759e-02],
[ 4.3422e-02, 6.1053e-02, -5.1035e-02, 2.5170e-02, 8.1194e-02],
[-3.5907e-02, 3.5084e-02, 5.4858e-02, 5.7819e-02, -6.8527e-02]],
[[ 6.0340e-02, -4.5873e-02, 4.5307e-02, -1.8559e-02, -5.9891e-02],
[ 7.1101e-02, 5.7979e-03, -2.1455e-02, -5.7839e-02, -2.6964e-02],
[ 4.5972e-02, 4.6237e-02, -1.8353e-02, 5.5372e-03, 5.8802e-02],
[-8.0939e-02, 2.2098e-03, -2.7943e-03, 6.9556e-02, 3.5299e-03],
[-2.4275e-02, -6.1490e-02, -2.4350e-02, -5.8685e-02, -7.6820e-02]]],
[[[-5.8326e-02, 4.3804e-02, 5.4642e-02, 2.9479e-02, 5.5766e-02],
[-6.2955e-02, 4.9442e-02, -1.7882e-02, -6.4492e-02, -3.5590e-02],
[ 7.8974e-02, 1.8189e-02, -4.3076e-02, -4.6822e-02, -5.9352e-02],
[ 1.1472e-02, 6.9467e-02, -3.5045e-02, -1.3463e-03, -7.0617e-02],
[-5.7437e-02, -5.7150e-02, 4.9108e-02, 2.2168e-02, -5.4964e-02]],
[[-3.2895e-02, -2.2746e-03, 6.8428e-02, -7.4781e-02, 6.5675e-02],
[-8.0232e-02, -2.6468e-02, -2.1136e-02, 2.1449e-02, 6.4572e-02],
[ 2.9930e-03, 1.1987e-02, 4.8122e-03, 3.4183e-02, -7.8918e-02],
[ 6.3749e-02, -2.5083e-02, 1.1253e-02, -4.4485e-02, 3.3380e-02],
[ 5.0096e-03, -1.7321e-02, 8.0185e-02, -2.3853e-02, -2.9333e-03]],
[[ 2.6648e-02, -7.6799e-02, 3.2204e-03, -7.7476e-02, -4.4615e-03],
[ 5.7110e-02, 7.8575e-02, 5.3204e-02, -7.8592e-02, 4.1383e-03],
[ 1.6194e-02, 2.5400e-02, 7.4070e-02, -3.9092e-03, -2.9417e-02],
[-7.9407e-02, 2.5042e-02, -3.8854e-02, 2.8143e-02, 2.8485e-03],
[-3.3828e-02, -7.5645e-02, 7.8511e-02, -4.4048e-02, 6.0887e-02]],
[[-6.4552e-02, -3.1646e-02, 6.5499e-02, -6.8577e-02, -5.1529e-02],
[ 6.1176e-02, -4.8461e-02, 4.7687e-02, -3.0069e-02, -1.7665e-02],
[ 7.7632e-02, -1.7017e-02, -6.2812e-02, -1.8810e-02, -4.1500e-02],
[ 6.1360e-02, -1.9826e-02, -6.4593e-02, 3.5071e-02, -5.9178e-02],
[-6.6739e-02, 2.6098e-02, -5.5998e-02, 8.1334e-02, 3.7472e-02]],
[[-5.5207e-02, 1.4355e-02, -2.2037e-02, -2.4025e-02, 7.2631e-02],
[-1.0448e-02, 1.9105e-03, -5.5223e-02, 4.6377e-02, -6.8534e-02],
[-2.4292e-02, 7.5258e-02, -8.0224e-02, -6.6001e-02, -4.6628e-02],
[ 4.5334e-02, -2.3274e-02, -4.3572e-02, 4.3487e-03, -4.6057e-02],
[-5.3757e-02, -2.0336e-02, -5.2245e-02, 2.2213e-02, -6.7578e-03]],
[[ 5.7154e-02, 6.9033e-02, -2.7450e-02, -5.9039e-02, 3.0233e-02],
[ 5.5904e-02, 5.2798e-02, -2.2586e-02, 2.8411e-02, -6.8010e-03],
[ 5.1257e-02, -4.3710e-02, 8.7161e-03, 1.9411e-02, -3.5285e-03],
[-8.0450e-02, 6.1012e-02, -7.7756e-02, -2.1472e-02, 4.7537e-02],
[-4.7231e-02, 3.7300e-02, 2.7754e-02, -2.4025e-02, 1.0065e-02]]]])
(bias): Normal:
loc: tensor([0., -0., -0., -0., -0., 0., 0., -0., 0., -0., -0., 0., 0., 0., -0., 0.])
scale: tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([ 0.0215, -0.0800, -0.0787, -0.0173, -0.0345, 0.0684, 0.0584, -0.0804,
0.0098, -0.0490, -0.0535, 0.0145, 0.0056, 0.0082, -0.0256, 0.0140])
)
(observed): Observed()
)
(fc1): Linear(
in_features=400, out_features=120, bias=True
(posterior): Automatic(
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([0., 0., -0., 0., 0., -0., -0., 0., 0., -0., 0., -0., 0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., -0.,
0., 0., -0., 0., 0., -0., 0., -0., 0., 0., 0., -0., 0., 0., 0., -0., -0., -0., -0., 0., 0., 0., -0., -0.,
-0., -0., -0., 0., 0., -0., -0., 0., 0., -0., -0., 0., -0., -0., 0., -0., 0., -0., 0., 0., 0., -0., 0., -0.,
0., -0., -0., 0., -0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., 0., 0., -0., -0., -0., 0., 0., 0., 0.,
-0., -0., -0., -0., 0., -0., 0., -0., -0., 0., -0., 0., 0., -0., 0., 0., -0., 0., -0., 0., 0., 0., 0., 0.],
requires_grad=True)
tensor: tensor([ 0.0928, -0.0557, -0.0987, -0.0677, 0.0549, -0.0778, 0.0566, -0.0783,
-0.1079, -0.0651, 0.1489, -0.0263, -0.0720, 0.0551, -0.1069, 0.0144,
-0.0300, 0.0989, 0.0729, -0.0451, 0.0721, 0.0009, 0.1521, 0.1494,
0.1114, -0.0930, -0.1384, -0.0729, -0.0837, -0.0453, 0.0476, 0.0383,
-0.0459, -0.0405, -0.1977, 0.0385, -0.0533, -0.0039, 0.1198, 0.0202,
-0.1197, -0.0355, -0.0581, -0.1910, 0.0959, -0.0030, 0.0761, 0.1127,
-0.0603, -0.1291, 0.0822, 0.1164, 0.0096, -0.1714, 0.0553, -0.1927,
-0.1611, -0.1322, -0.0809, -0.0486, 0.1202, -0.0975, -0.0132, 0.0918,
0.1541, -0.1653, 0.0311, 0.0700, 0.0002, -0.0480, -0.0241, -0.2082,
-0.1500, 0.0072, 0.0052, 0.0860, 0.0694, 0.1631, -0.0141, -0.1648,
-0.1694, -0.0393, 0.0137, -0.0039, 0.0152, 0.0567, 0.0944, 0.0417,
0.0136, -0.1908, -0.0396, 0.1616, 0.1286, 0.2245, -0.0121, -0.1299,
-0.1069, 0.0543, 0.0613, -0.1372, -0.0757, -0.0744, -0.0156, -0.0350,
-0.1974, -0.0417, 0.1595, -0.1018, -0.0879, -0.0230, -0.1762, 0.0350,
-0.1005, -0.0702, -0.1774, 0.0834, 0.0309, 0.1576, 0.0500, 0.1578],
grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[-0., -0., 0., ..., 0., -0., -0.],
[-0., 0., -0., ..., 0., -0., -0.],
[0., 0., -0., ..., -0., 0., 0.],
...,
[0., -0., 0., ..., -0., -0., 0.],
[0., -0., 0., ..., -0., 0., 0.],
[0., -0., -0., ..., -0., -0., -0.]])
scale: tensor([[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
...,
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500],
[0.0500, 0.0500, 0.0500, ..., 0.0500, 0.0500, 0.0500]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[-0.0489, -0.0457, 0.0358, ..., 0.0488, -0.0310, -0.0318],
[-0.0198, 0.0492, -0.0495, ..., 0.0437, -0.0228, -0.0161],
[ 0.0042, 0.0213, -0.0018, ..., -0.0004, 0.0377, 0.0324],
...,
[ 0.0020, -0.0197, 0.0377, ..., -0.0133, -0.0496, 0.0166],
[ 0.0128, -0.0165, 0.0298, ..., -0.0352, 0.0281, 0.0219],
[ 0.0448, -0.0166, -0.0012, ..., -0.0042, -0.0289, -0.0339]])
(bias): Normal:
loc: tensor([0., 0., -0., 0., 0., -0., -0., 0., 0., -0., 0., -0., 0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., -0.,
0., 0., -0., 0., 0., -0., 0., -0., 0., 0., 0., -0., 0., 0., 0., -0., -0., -0., -0., 0., 0., 0., -0., -0.,
-0., -0., -0., 0., 0., -0., -0., 0., 0., -0., -0., 0., -0., -0., 0., -0., 0., -0., 0., 0., 0., -0., 0., -0.,
0., -0., -0., 0., -0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., 0., 0., -0., -0., -0., 0., 0., 0., 0.,
-0., -0., -0., -0., 0., -0., 0., -0., -0., 0., -0., 0., 0., -0., 0., 0., -0., 0., -0., 0., 0., 0., 0., 0.])
scale: tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([ 0.0350, 0.0458, -0.0330, 0.0477, 0.0383, -0.0136, -0.0182, 0.0285,
0.0197, -0.0431, 0.0120, -0.0445, 0.0171, -0.0019, -0.0141, -0.0021,
-0.0429, -0.0159, 0.0028, 0.0272, -0.0290, 0.0047, 0.0452, -0.0022,
0.0279, 0.0323, -0.0433, 0.0049, 0.0063, -0.0388, 0.0090, -0.0233,
0.0251, 0.0375, 0.0274, -0.0337, 0.0122, 0.0217, 0.0230, -0.0405,
-0.0476, -0.0063, -0.0021, 0.0267, 0.0014, 0.0228, -0.0130, -0.0471,
-0.0170, -0.0349, -0.0472, 0.0116, 0.0002, -0.0426, -0.0129, 0.0492,
0.0117, -0.0143, -0.0025, 0.0040, -0.0466, -0.0037, 0.0341, -0.0261,
0.0327, -0.0433, 0.0025, 0.0201, 0.0211, -0.0235, 0.0472, -0.0291,
0.0431, -0.0314, -0.0255, 0.0108, -0.0499, -0.0164, -0.0294, -0.0290,
-0.0305, -0.0172, 0.0238, 0.0029, -0.0029, 0.0172, 0.0227, 0.0006,
0.0120, -0.0068, -0.0043, -0.0289, 0.0060, 0.0199, 0.0122, 0.0423,
-0.0015, -0.0034, -0.0201, -0.0374, 0.0159, -0.0258, 0.0075, -0.0097,
-0.0048, 0.0477, -0.0470, 0.0045, 0.0128, -0.0441, 0.0218, 0.0365,
-0.0206, 0.0348, -0.0249, 0.0256, 0.0222, 0.0019, 0.0289, 0.0248])
)
(observed): Observed()
)
(fc2): Linear(
in_features=120, out_features=2, bias=True
(posterior): Automatic(
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.7071, 0.7071], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([-0., 0.], requires_grad=True)
tensor: tensor([-0.4664, -0.5310], grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[0., -0., -0., -0., -0., 0., -0., 0., -0., -0., -0., -0., 0., -0., -0., 0., -0., -0., 0., -0., 0., -0., -0., 0.,
-0., -0., -0., -0., 0., 0., -0., 0., -0., 0., -0., -0., -0., 0., 0., -0., -0., 0., -0., 0., -0., 0., 0., 0.,
-0., -0., 0., -0., -0., 0., -0., 0., -0., -0., -0., 0., -0., -0., -0., -0., 0., -0., -0., -0., -0., -0., -0., -0.,
0., 0., 0., 0., 0., 0., 0., -0., 0., 0., -0., 0., 0., -0., 0., -0., 0., -0., 0., -0., -0., 0., 0., 0.,
-0., 0., -0., 0., 0., -0., -0., 0., 0., -0., -0., -0., -0., 0., 0., 0., 0., -0., 0., -0., -0., 0., -0., -0.],
[-0., -0., 0., 0., 0., 0., 0., -0., -0., 0., -0., 0., 0., 0., 0., 0., 0., -0., 0., -0., 0., -0., 0., -0.,
0., 0., 0., -0., -0., 0., -0., -0., -0., 0., 0., 0., -0., -0., 0., -0., 0., 0., 0., -0., 0., 0., -0., 0.,
-0., 0., -0., -0., -0., -0., -0., -0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -0., -0., -0., 0.,
0., 0., -0., 0., -0., 0., -0., 0., -0., -0., -0., -0., -0., -0., -0., 0., 0., 0., -0., -0., -0., -0., 0., 0.,
0., -0., 0., -0., 0., -0., -0., 0., -0., -0., -0., 0., 0., -0., -0., 0., 0., -0., -0., -0., -0., -0., -0., -0.]])
scale: tensor([[0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913],
[0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
0.0913, 0.0913, 0.0913]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[ 0.0760, -0.0099, -0.0237, -0.0145, -0.0806, 0.0095, -0.0646, 0.0145,
-0.0077, -0.0512, -0.0072, -0.0008, 0.0279, -0.0383, -0.0224, 0.0434,
-0.0912, -0.0509, 0.0756, -0.0889, 0.0356, -0.0862, -0.0046, 0.0507,
-0.0356, -0.0616, -0.0509, -0.0035, 0.0123, 0.0190, -0.0453, 0.0815,
-0.0149, 0.0448, -0.0308, -0.0292, -0.0423, 0.0691, 0.0686, -0.0398,
-0.0657, 0.0157, -0.0508, 0.0847, -0.0897, 0.0655, 0.0407, 0.0535,
-0.0541, -0.0812, 0.0122, -0.0665, -0.0799, 0.0247, -0.0409, 0.0105,
-0.0471, -0.0825, -0.0042, 0.0652, -0.0086, -0.0002, -0.0784, -0.0430,
0.0104, -0.0905, -0.0506, -0.0340, -0.0407, -0.0163, -0.0497, -0.0516,
0.0852, 0.0711, 0.0833, 0.0214, 0.0743, 0.0575, 0.0583, -0.0007,
0.0814, 0.0736, -0.0248, 0.0284, 0.0873, -0.0174, 0.0206, -0.0740,
0.0276, -0.0414, 0.0508, -0.0087, -0.0581, 0.0255, 0.0058, 0.0142,
-0.0266, 0.0067, -0.0468, 0.0654, 0.0305, -0.0043, -0.0613, 0.0733,
0.0400, -0.0446, -0.0243, -0.0434, -0.0616, 0.0371, 0.0253, 0.0681,
0.0847, -0.0068, 0.0176, -0.0169, -0.0387, 0.0219, -0.0046, -0.0663],
[-0.0324, -0.0686, 0.0105, 0.0805, 0.0090, 0.0304, 0.0097, -0.0191,
-0.0591, 0.0876, -0.0748, 0.0383, 0.0680, 0.0441, 0.0479, 0.0484,
0.0302, -0.0039, 0.0855, -0.0066, 0.0661, -0.0492, 0.0843, -0.0566,
0.0517, 0.0880, 0.0308, -0.0874, -0.0144, 0.0143, -0.0663, -0.0484,
-0.0368, 0.0709, 0.0610, 0.0495, -0.0031, -0.0503, 0.0562, -0.0030,
0.0753, 0.0173, 0.0221, -0.0259, 0.0145, 0.0206, -0.0740, 0.0226,
-0.0414, 0.0712, -0.0427, -0.0477, -0.0386, -0.0709, -0.0451, -0.0469,
0.0882, 0.0519, 0.0840, 0.0558, 0.0087, 0.0270, 0.0901, 0.0010,
0.0620, 0.0696, 0.0825, 0.0557, -0.0043, -0.0531, -0.0447, 0.0474,
0.0724, 0.0483, -0.0868, 0.0503, -0.0060, 0.0524, -0.0355, 0.0002,
-0.0195, -0.0888, -0.0211, -0.0551, -0.0292, -0.0041, -0.0416, 0.0861,
0.0530, 0.0840, -0.0316, -0.0839, -0.0451, -0.0664, 0.0725, 0.0301,
0.0456, -0.0145, 0.0455, -0.0850, 0.0010, -0.0722, -0.0800, 0.0512,
-0.0753, -0.0348, -0.0249, 0.0067, 0.0063, -0.0506, -0.0310, 0.0592,
0.0374, -0.0612, -0.0298, -0.0912, -0.0550, -0.0416, -0.0526, -0.0344]])
(bias): Normal:
loc: tensor([-0., 0.])
scale: tensor([0.7071, 0.7071])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([-0.0576, 0.0882])
)
(observed): Observed()
)
)
One can also set the posterior when creating the module:
nn.Linear(10, 10, posterior=borch.posterior.Normal(log_scale=-3))
Out:
Linear(
in_features=10, out_features=10, bias=True
(posterior): Normal(
(weight): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498],
[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498]], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([[-0., 0., 0., -0., 0., -0., 0., 0., 0., -0.],
[0., 0., -0., -0., -0., -0., -0., -0., 0., 0.],
[0., 0., 0., -0., -0., -0., 0., 0., -0., 0.],
[0., -0., 0., -0., 0., 0., -0., -0., -0., 0.],
[-0., 0., 0., -0., -0., -0., -0., -0., 0., 0.],
[-0., -0., 0., -0., 0., 0., 0., -0., 0., -0.],
[0., -0., 0., 0., 0., -0., 0., 0., -0., -0.],
[0., -0., 0., -0., -0., -0., -0., 0., 0., -0.],
[-0., -0., 0., 0., -0., 0., 0., -0., -0., 0.],
[-0., 0., 0., -0., -0., 0., -0., 0., -0., 0.]], requires_grad=True)
tensor: tensor([[-0.0003, 0.0373, -0.0517, -0.0143, 0.0863, 0.0142, -0.0411, 0.0303,
0.0511, -0.0203],
[ 0.0079, -0.0920, -0.0046, -0.0104, -0.0095, 0.0203, 0.0310, 0.0202,
-0.0121, -0.0181],
[-0.0519, -0.0255, 0.0884, -0.0048, 0.0630, 0.0475, -0.0290, -0.0855,
0.0418, -0.0381],
[-0.1063, 0.0282, 0.0966, -0.0082, -0.0365, 0.0427, -0.0479, 0.0883,
0.0667, 0.0052],
[-0.0698, -0.0686, 0.0423, 0.0692, 0.1213, 0.0422, 0.0058, 0.0344,
-0.0401, -0.0121],
[-0.0083, 0.0183, 0.0247, 0.0571, -0.0163, 0.1211, 0.0445, -0.1065,
-0.0504, 0.0108],
[-0.0505, -0.0583, -0.0083, -0.0533, 0.0170, -0.0274, -0.1463, -0.0559,
0.0722, 0.0532],
[-0.0191, -0.0246, 0.0021, 0.1291, -0.0634, 0.0648, -0.0016, 0.0228,
-0.0267, 0.0685],
[ 0.0373, 0.0184, 0.0326, -0.1323, 0.0879, -0.0471, -0.0354, -0.0547,
-0.0245, 0.0197],
[-0.0411, 0.0050, 0.0211, -0.0370, 0.0525, -0.0592, 0.0245, 0.0202,
-0.0753, 0.0099]], grad_fn=<AddBackward0>)
(bias): Normal:
posterior: Automatic()
prior: Module()
observed: Observed()
scale: Transform:
tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
0.0498], grad_fn=<ExpBackward0>)
loc: Parameter containing:
tensor([-0., -0., -0., 0., 0., -0., -0., 0., 0., -0.], requires_grad=True)
tensor: tensor([-0.0311, -0.0460, -0.1053, 0.0139, 0.0346, -0.0215, 0.0681, 0.0781,
0.0452, 0.0135], grad_fn=<AddBackward0>)
)
(prior): Module(
(weight): Normal:
loc: tensor([[-0., 0., 0., -0., 0., -0., 0., 0., 0., -0.],
[0., 0., -0., -0., -0., -0., -0., -0., 0., 0.],
[0., 0., 0., -0., -0., -0., 0., 0., -0., 0.],
[0., -0., 0., -0., 0., 0., -0., -0., -0., 0.],
[-0., 0., 0., -0., -0., -0., -0., -0., 0., 0.],
[-0., -0., 0., -0., 0., 0., 0., -0., 0., -0.],
[0., -0., 0., 0., 0., -0., 0., 0., -0., -0.],
[0., -0., 0., -0., -0., -0., -0., 0., 0., -0.],
[-0., -0., 0., 0., -0., 0., 0., -0., -0., 0.],
[-0., 0., 0., -0., -0., 0., -0., 0., -0., 0.]])
scale: tensor([[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162],
[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162],
[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162],
[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162],
[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162],
[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162],
[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162],
[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162],
[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162],
[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162]])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([[-0.2634, 0.2061, 0.3041, -0.1349, 0.0784, -0.1657, 0.1957, 0.2490,
0.3160, -0.0849],
[ 0.0335, 0.2293, -0.0953, -0.1795, -0.1057, -0.0171, -0.1661, -0.0754,
0.0487, 0.0602],
[ 0.2953, 0.0108, 0.1470, -0.1304, -0.2690, -0.3157, 0.2241, 0.0583,
-0.1642, 0.2801],
[ 0.1391, -0.0884, 0.2268, -0.0267, 0.1603, 0.0974, -0.0735, -0.3121,
-0.0606, 0.2517],
[-0.2838, 0.1884, 0.2694, -0.1517, -0.0660, -0.2486, -0.0599, -0.1401,
0.2265, 0.2869],
[-0.2059, -0.0081, 0.2682, -0.1052, 0.1061, 0.2965, 0.1716, -0.0551,
0.1544, -0.1494],
[ 0.1081, -0.1222, 0.0729, 0.0693, 0.2599, -0.2775, 0.0092, 0.0497,
-0.2638, -0.1386],
[ 0.0835, -0.2367, 0.0503, -0.1869, -0.0921, -0.2095, -0.2027, 0.0749,
0.1702, -0.2820],
[-0.2977, -0.0958, 0.0508, 0.1128, -0.2777, 0.1770, 0.1128, -0.2343,
-0.0854, 0.0976],
[-0.1118, 0.3101, 0.1880, -0.0396, -0.1929, 0.1096, -0.1623, 0.1923,
-0.2513, 0.1232]])
(bias): Normal:
loc: tensor([-0., -0., -0., 0., 0., -0., -0., 0., 0., -0.])
scale: tensor([0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
0.3162])
posterior: Automatic()
prior: Module()
observed: Observed()
tensor: tensor([-0.2393, -0.2382, -0.0876, 0.1571, 0.1483, -0.3024, -0.2600, 0.2103,
0.3014, -0.0652])
)
(observed): Observed()
)
See the borch.posterior documentation for the other available posteriors and the parameters they accept.
Note that not all posteriors work with all parameters, but you can use different posteriors
for the different borch.Modules in your network.
Exercises¶
Use what you have learned to train an image classifier for MNIST; you should achieve an accuracy above 98 %. Note: you can access MNIST using torchvision.datasets.MNIST.
Fit the same model architecture with normal torch and compare the likelihood with that of the borch network. What are the differences, and why?
Port the model to CIFAR and see how you can improve the accuracy.
Show how the Categorical distribution is related to the cross-entropy loss function that is commonly used in frequentist deep learning.
Total running time of the script: ( 0 minutes 0.185 seconds)