Training an image classifier

You will learn the basics of how to create an image classifier with the borch.nn package and fit it with the borch.infer package.

Let's start off by importing what we need.

import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import borch
from borch import infer, distributions, nn  # borch.nn replaces torch.nn for the layers

The module borch.nn provides neural network modules for deep probabilistic programming. Its interface is almost identical to that of torch.nn, so in many cases it is enough to switch

from torch import nn

to

from borch import nn

Data

In this example we use simulated data and do not run the fitting until convergence; the aim is to show how the model is set up and how the training loop can be constructed. We simply generate random data, where data holds the images and labels holds the class of each image.

data = torch.randn(20, 1, 32, 32)
labels = torch.randperm(2).repeat(10)
data_set = TensorDataset(data, labels)
loader = DataLoader(data_set, batch_size=20)
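To see what the labels look like, the pure-Python sketch below mimics torch.randperm(2).repeat(10): a random ordering of the two class ids, tiled ten times, so the 20 labels alternate between the classes and each class appears exactly 10 times. The helper simulate_labels is hypothetical and only for illustration. Note also that batch_size=20 equals the size of the dataset, so loader yields a single batch containing all 20 images.

```python
import random


def simulate_labels(num_classes=2, repeats=10, seed=None):
    """Hypothetical pure-Python stand-in for torch.randperm(num_classes).repeat(repeats)."""
    rng = random.Random(seed)
    perm = rng.sample(range(num_classes), num_classes)  # random permutation of the class ids
    return perm * repeats  # list repetition tiles the permutation end to end


labels = simulate_labels(seed=0)
print(len(labels))                       # 20
print(labels.count(0), labels.count(1))  # 10 10
```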

Model

Let's set up the model. To get the most out of infer and borch, we need to select a likelihood distribution. For classification, distributions.Categorical is suitable.
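As an illustration of what the likelihood contributes, the pure-Python sketch below computes log p(label | logits) the way PyTorch's torch.distributions.Categorical does when given logits: normalise with a log-softmax, then take the entry of the observed class. It assumes borch's distributions.Categorical follows the same convention, and the helper names are hypothetical, not part of borch.

```python
import math


def categorical_log_prob(logits, label):
    """Log-probability of `label` under a categorical distribution with unnormalised logits."""
    m = max(logits)  # subtract the max before exponentiating, for numerical stability
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))  # logsumexp(logits)
    return logits[label] - log_norm  # log-softmax evaluated at the observed class


def negative_log_likelihood(batch_logits, labels):
    """Sum of -log p(label | logits) over a batch (the likelihood part of a fitting objective)."""
    return -sum(categorical_log_prob(z, y) for z, y in zip(batch_logits, labels))


# Equal logits for two classes give each class probability 1/2
print(categorical_log_prob([0.0, 0.0], 1))  # log(0.5), about -0.6931
```

Because the two output units of the network are passed as logits, the model never needs an explicit softmax layer; the distribution performs the normalisation.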

class Net(borch.Module):
    def __init__(self):
        super().__init__(posterior=borch.posterior.Automatic())
        # 1 input image channel, 6 output channels, 5x5 square convolution kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        # 6 input channels, 16 output channels, 5x5 kernel
        self.conv2 = nn.Conv2d(6, 16, 5)
        # affine operations: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 16 feature maps of size 5x5
        self.fc2 = nn.Linear(120, 2)  # one logit per class

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the window is square, a single number is enough
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Specify the likelihood: a categorical distribution over the two classes
        self.classification = distributions.Categorical(logits=x)
        return self.classification

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
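The 16 * 5 * 5 input size of fc1 is what num_flat_features returns for a 1 x 32 x 32 input. The sketch below traces the spatial size through the layers using the standard output-size formulas for a convolution (stride 1, no padding, as in conv1 and conv2) and a 2 x 2 max pool whose stride equals its kernel size (the max_pool2d default). The helper names are hypothetical.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a 2D convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


def max_pool2d_out(size, kernel):
    """Spatial output size of max pooling with stride equal to the kernel size."""
    return (size - kernel) // kernel + 1


size = 32                       # input images are 1 x 32 x 32
size = conv2d_out(size, 5)      # conv1: 32 -> 28
size = max_pool2d_out(size, 2)  # pool:  28 -> 14
size = conv2d_out(size, 5)      # conv2: 14 -> 10
size = max_pool2d_out(size, 2)  # pool:  10 -> 5

flat_features = 16 * size * size  # conv2 produces 16 channels
print(size, flat_features)        # 5 400
```

If the input resolution changes, the first argument of fc1 has to change with it; rerunning this arithmetic is a quick way to find the new value.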


net = Net()
print(net)

Out:

Net(
  (posterior): Automatic()
  (prior): Module()
  (observed): Observed()
  (conv1): Conv2d(
    1, 6, kernel_size=(5, 5), stride=(1, 1)
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[[[-0.1927,  0.1399, -0.1354, -0.1729, -0.1509],
                [-0.1643,  0.1672,  0.1231,  0.0760,  0.1983],
                [-0.1877,  0.1498,  0.1842, -0.1852, -0.0757],
                [-0.0739, -0.0365,  0.1345,  0.0660, -0.1410],
                [ 0.0395,  0.1924, -0.0889,  0.1817, -0.1028]]],


              [[[ 0.0161, -0.0283,  0.0567,  0.0753,  0.1422],
                [ 0.1557, -0.0744,  0.1324,  0.1368,  0.1860],
                [ 0.1084, -0.0066, -0.1755,  0.0982, -0.1038],
                [ 0.0692, -0.1975,  0.1228,  0.1460, -0.1969],
                [-0.0408,  0.0335, -0.1200,  0.0135, -0.0343]]],


              [[[ 0.1090, -0.1523, -0.0839, -0.1336,  0.1845],
                [-0.0021, -0.1854, -0.0692,  0.0818,  0.0268],
                [ 0.0554,  0.0253, -0.0801,  0.0925, -0.1053],
                [-0.0562,  0.0456,  0.1366, -0.1447, -0.1639],
                [-0.0741, -0.1671,  0.1945, -0.1811, -0.1519]]],


              [[[ 0.0950,  0.1374,  0.1735,  0.1682,  0.0029],
                [ 0.0662, -0.0615, -0.1451,  0.1452, -0.1408],
                [-0.1634,  0.0675,  0.1090, -0.1899, -0.0123],
                [ 0.0842, -0.0821,  0.1183, -0.0658,  0.1601],
                [-0.1688,  0.1212, -0.0177, -0.1106,  0.1500]]],


              [[[-0.0711,  0.0887,  0.0110, -0.0312,  0.1235],
                [ 0.1727, -0.1303, -0.1418,  0.0785,  0.0870],
                [ 0.0772,  0.0361, -0.1826, -0.0933,  0.0847],
                [ 0.0795,  0.1156, -0.1951, -0.1324, -0.1288],
                [-0.1276, -0.1009, -0.1230, -0.1850,  0.0189]]],


              [[[-0.0376,  0.1030,  0.1858,  0.1619,  0.1681],
                [ 0.0571, -0.1018, -0.1616,  0.1122,  0.1009],
                [ 0.1114,  0.0628, -0.0361,  0.0023,  0.0281],
                [ 0.0725,  0.1015, -0.0013,  0.1214, -0.1805],
                [ 0.1711,  0.1865, -0.1692,  0.0962,  0.0706]]]], requires_grad=True)
       tensor: tensor([[[[-0.2157,  0.0700, -0.0908, -0.1634, -0.1610],
                [-0.1348,  0.2526,  0.0959, -0.0100,  0.2411],
                [-0.0937,  0.2073,  0.1696, -0.1477, -0.1274],
                [-0.0645,  0.0310,  0.1207,  0.1160, -0.2358],
                [ 0.1197,  0.2355, -0.0854,  0.1648, -0.1267]]],


              [[[-0.0590,  0.0077,  0.0397,  0.0469,  0.1285],
                [ 0.0683, -0.1121,  0.1573,  0.1348,  0.1819],
                [ 0.1375, -0.0533, -0.2035,  0.1420, -0.1399],
                [ 0.0028, -0.0762,  0.1180,  0.0648, -0.2013],
                [-0.0807,  0.0769, -0.1508, -0.0319, -0.0546]]],


              [[[ 0.0979, -0.1779, -0.0642, -0.0850,  0.1955],
                [-0.0440, -0.2337, -0.0903,  0.0791,  0.0162],
                [ 0.0858, -0.0131, -0.0561,  0.1656, -0.0313],
                [-0.0810,  0.0303,  0.1008, -0.0657, -0.1811],
                [-0.1711, -0.0730,  0.1278, -0.1982, -0.1405]]],


              [[[ 0.0481,  0.1565,  0.2030,  0.1710,  0.0092],
                [ 0.0199, -0.0892, -0.2083,  0.1719, -0.0671],
                [-0.2068,  0.0169,  0.0941, -0.1937,  0.0171],
                [ 0.0881, -0.0840,  0.0991, -0.0349,  0.1863],
                [-0.2759,  0.0587, -0.0132, -0.1172,  0.2117]]],


              [[[-0.0500,  0.0174,  0.0018, -0.0023,  0.1109],
                [ 0.2215, -0.1161, -0.1446,  0.1037,  0.1574],
                [ 0.1490,  0.0843, -0.2119, -0.1595,  0.1073],
                [ 0.1108,  0.1599, -0.1039, -0.0907, -0.1746],
                [-0.0869, -0.0473, -0.0885, -0.1409,  0.0183]]],


              [[[-0.0328,  0.2260,  0.2178,  0.0954,  0.1086],
                [ 0.0526, -0.1322, -0.1627,  0.1323,  0.1208],
                [ 0.1250,  0.0943,  0.0322,  0.0166,  0.0406],
                [-0.0006,  0.0562, -0.0055,  0.1553, -0.2777],
                [ 0.2405,  0.2811, -0.2061,  0.1368,  0.0990]]]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-0.1597,  0.1243,  0.0051, -0.0381,  0.0524,  0.0078],
             requires_grad=True)
       tensor: tensor([-0.1479,  0.0487, -0.0111, -0.0327,  0.0722,  0.0004],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[[[-0., 0., -0., -0., -0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., 0., -0., -0.],
                [-0., -0., 0., 0., -0.],
                [0., 0., -0., 0., -0.]]],


              [[[0., -0., 0., 0., 0.],
                [0., -0., 0., 0., 0.],
                [0., -0., -0., 0., -0.],
                [0., -0., 0., 0., -0.],
                [-0., 0., -0., 0., -0.]]],


              [[[0., -0., -0., -0., 0.],
                [-0., -0., -0., 0., 0.],
                [0., 0., -0., 0., -0.],
                [-0., 0., 0., -0., -0.],
                [-0., -0., 0., -0., -0.]]],


              [[[0., 0., 0., 0., 0.],
                [0., -0., -0., 0., -0.],
                [-0., 0., 0., -0., -0.],
                [0., -0., 0., -0., 0.],
                [-0., 0., -0., -0., 0.]]],


              [[[-0., 0., 0., -0., 0.],
                [0., -0., -0., 0., 0.],
                [0., 0., -0., -0., 0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., -0., -0., 0.]]],


              [[[-0., 0., 0., 0., 0.],
                [0., -0., -0., 0., 0.],
                [0., 0., -0., 0., 0.],
                [0., 0., -0., 0., -0.],
                [0., 0., -0., 0., 0.]]]])
       scale: tensor([[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[[[-0.1927,  0.1399, -0.1354, -0.1729, -0.1509],
                [-0.1643,  0.1672,  0.1231,  0.0760,  0.1983],
                [-0.1877,  0.1498,  0.1842, -0.1852, -0.0757],
                [-0.0739, -0.0365,  0.1345,  0.0660, -0.1410],
                [ 0.0395,  0.1924, -0.0889,  0.1817, -0.1028]]],


              [[[ 0.0161, -0.0283,  0.0567,  0.0753,  0.1422],
                [ 0.1557, -0.0744,  0.1324,  0.1368,  0.1860],
                [ 0.1084, -0.0066, -0.1755,  0.0982, -0.1038],
                [ 0.0692, -0.1975,  0.1228,  0.1460, -0.1969],
                [-0.0408,  0.0335, -0.1200,  0.0135, -0.0343]]],


              [[[ 0.1090, -0.1523, -0.0839, -0.1336,  0.1845],
                [-0.0021, -0.1854, -0.0692,  0.0818,  0.0268],
                [ 0.0554,  0.0253, -0.0801,  0.0925, -0.1053],
                [-0.0562,  0.0456,  0.1366, -0.1447, -0.1639],
                [-0.0741, -0.1671,  0.1945, -0.1811, -0.1519]]],


              [[[ 0.0950,  0.1374,  0.1735,  0.1682,  0.0029],
                [ 0.0662, -0.0615, -0.1451,  0.1452, -0.1408],
                [-0.1634,  0.0675,  0.1090, -0.1899, -0.0123],
                [ 0.0842, -0.0821,  0.1183, -0.0658,  0.1601],
                [-0.1688,  0.1212, -0.0177, -0.1106,  0.1500]]],


              [[[-0.0711,  0.0887,  0.0110, -0.0312,  0.1235],
                [ 0.1727, -0.1303, -0.1418,  0.0785,  0.0870],
                [ 0.0772,  0.0361, -0.1826, -0.0933,  0.0847],
                [ 0.0795,  0.1156, -0.1951, -0.1324, -0.1288],
                [-0.1276, -0.1009, -0.1230, -0.1850,  0.0189]]],


              [[[-0.0376,  0.1030,  0.1858,  0.1619,  0.1681],
                [ 0.0571, -0.1018, -0.1616,  0.1122,  0.1009],
                [ 0.1114,  0.0628, -0.0361,  0.0023,  0.0281],
                [ 0.0725,  0.1015, -0.0013,  0.1214, -0.1805],
                [ 0.1711,  0.1865, -0.1692,  0.0962,  0.0706]]]])
      (bias): Normal:
       loc: tensor([-0., 0., 0., -0., 0., 0.])
       scale: tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-0.1597,  0.1243,  0.0051, -0.0381,  0.0524,  0.0078])
    )
    (observed): Observed()
  )
  (conv2): Conv2d(
    6, 16, kernel_size=(5, 5), stride=(1, 1)
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              ...,


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[[[ 0.0228, -0.0484,  0.0637, -0.0249, -0.0302],
                [ 0.0105,  0.0733,  0.0576,  0.0393,  0.0094],
                [-0.0590, -0.0747, -0.0506, -0.0495,  0.0402],
                [-0.0647,  0.0789, -0.0096,  0.0477,  0.0700],
                [-0.0450,  0.0176,  0.0443, -0.0176,  0.0407]],

               [[-0.0148,  0.0075, -0.0554,  0.0808,  0.0812],
                [ 0.0636, -0.0776,  0.0617,  0.0497,  0.0760],
                [-0.0350,  0.0633,  0.0318,  0.0562,  0.0742],
                [-0.0229, -0.0075,  0.0338,  0.0709, -0.0382],
                [ 0.0655,  0.0249,  0.0556,  0.0329, -0.0734]],

               [[-0.0531,  0.0393,  0.0355, -0.0358,  0.0298],
                [-0.0772, -0.0365,  0.0397, -0.0093, -0.0571],
                [ 0.0087,  0.0256,  0.0568,  0.0123,  0.0098],
                [ 0.0070,  0.0767,  0.0580,  0.0573,  0.0128],
                [ 0.0476, -0.0090,  0.0190,  0.0597, -0.0108]],

               [[-0.0097, -0.0755,  0.0251,  0.0349,  0.0423],
                [ 0.0800,  0.0594,  0.0019,  0.0419, -0.0641],
                [-0.0501, -0.0639, -0.0783,  0.0276,  0.0425],
                [ 0.0793, -0.0161,  0.0031,  0.0700, -0.0077],
                [ 0.0717,  0.0564,  0.0130,  0.0126,  0.0055]],

               [[-0.0711, -0.0004, -0.0350, -0.0729, -0.0567],
                [-0.0351,  0.0119,  0.0163,  0.0542,  0.0558],
                [-0.0308,  0.0337,  0.0634,  0.0556,  0.0051],
                [-0.0305, -0.0103, -0.0515, -0.0746,  0.0701],
                [ 0.0014, -0.0249,  0.0724,  0.0282, -0.0773]],

               [[ 0.0647,  0.0735, -0.0460,  0.0321,  0.0060],
                [-0.0624, -0.0262,  0.0131, -0.0270, -0.0768],
                [ 0.0053,  0.0301,  0.0262, -0.0012,  0.0230],
                [ 0.0758, -0.0297, -0.0011, -0.0424,  0.0659],
                [ 0.0155, -0.0027, -0.0785, -0.0715,  0.0603]]],


              [[[-0.0680,  0.0027, -0.0140,  0.0451,  0.0738],
                [-0.0323, -0.0084, -0.0281,  0.0522, -0.0210],
                [-0.0525, -0.0399, -0.0072, -0.0577,  0.0280],
                [-0.0374, -0.0055,  0.0143, -0.0258,  0.0404],
                [-0.0804,  0.0563, -0.0096, -0.0485, -0.0052]],

               [[ 0.0772, -0.0417, -0.0619, -0.0375, -0.0217],
                [ 0.0312, -0.0186, -0.0225,  0.0433,  0.0327],
                [ 0.0751, -0.0210,  0.0098,  0.0275,  0.0606],
                [ 0.0308, -0.0484, -0.0063,  0.0799, -0.0184],
                [-0.0195,  0.0014, -0.0263, -0.0651, -0.0203]],

               [[-0.0441, -0.0229, -0.0547,  0.0813, -0.0195],
                [-0.0343,  0.0006,  0.0058,  0.0113,  0.0668],
                [-0.0674,  0.0276, -0.0072,  0.0360, -0.0707],
                [-0.0468,  0.0240, -0.0658,  0.0561,  0.0396],
                [-0.0704, -0.0553,  0.0372, -0.0310,  0.0525]],

               [[-0.0081, -0.0563,  0.0125,  0.0304,  0.0486],
                [-0.0515,  0.0539, -0.0165,  0.0263,  0.0241],
                [-0.0189,  0.0789, -0.0254,  0.0502,  0.0336],
                [-0.0288, -0.0308,  0.0397, -0.0238, -0.0365],
                [ 0.0195,  0.0625, -0.0496,  0.0805,  0.0320]],

               [[ 0.0451,  0.0478, -0.0347, -0.0325,  0.0203],
                [-0.0460, -0.0406,  0.0251,  0.0666,  0.0718],
                [-0.0109, -0.0145,  0.0313,  0.0103,  0.0525],
                [-0.0299, -0.0280, -0.0036,  0.0186,  0.0350],
                [ 0.0073,  0.0463,  0.0292, -0.0732, -0.0485]],

               [[ 0.0059, -0.0394,  0.0745,  0.0484,  0.0053],
                [-0.0341, -0.0371,  0.0608,  0.0058, -0.0537],
                [-0.0605, -0.0791, -0.0253,  0.0762,  0.0762],
                [-0.0059,  0.0153,  0.0580,  0.0198, -0.0173],
                [-0.0174,  0.0578, -0.0368, -0.0736, -0.0631]]],


              [[[ 0.0033, -0.0336,  0.0105,  0.0191,  0.0074],
                [ 0.0714, -0.0613, -0.0475, -0.0452, -0.0098],
                [-0.0617, -0.0515,  0.0353, -0.0375,  0.0360],
                [ 0.0542, -0.0326, -0.0236,  0.0674, -0.0347],
                [ 0.0394,  0.0746, -0.0467,  0.0522,  0.0424]],

               [[ 0.0503, -0.0693,  0.0626, -0.0424, -0.0488],
                [-0.0617, -0.0619, -0.0342,  0.0127, -0.0253],
                [ 0.0319,  0.0270, -0.0089,  0.0175, -0.0308],
                [-0.0551,  0.0107,  0.0212,  0.0026,  0.0019],
                [ 0.0115,  0.0343,  0.0572,  0.0811, -0.0354]],

               [[ 0.0632, -0.0197,  0.0072, -0.0372,  0.0581],
                [-0.0492, -0.0602, -0.0201,  0.0465, -0.0482],
                [ 0.0403, -0.0086, -0.0292, -0.0711, -0.0294],
                [-0.0228, -0.0102,  0.0450,  0.0804,  0.0032],
                [-0.0009,  0.0815,  0.0113,  0.0315, -0.0110]],

               [[-0.0627,  0.0313, -0.0443, -0.0353,  0.0797],
                [-0.0039, -0.0053, -0.0678, -0.0335, -0.0764],
                [-0.0202,  0.0296,  0.0548,  0.0185, -0.0028],
                [ 0.0044,  0.0496,  0.0734, -0.0418, -0.0792],
                [ 0.0025,  0.0171, -0.0681, -0.0224,  0.0077]],

               [[ 0.0424,  0.0476,  0.0079, -0.0167,  0.0523],
                [-0.0335,  0.0003, -0.0560, -0.0448, -0.0173],
                [ 0.0426, -0.0002,  0.0133, -0.0705,  0.0085],
                [ 0.0724,  0.0664,  0.0455,  0.0461, -0.0607],
                [-0.0546,  0.0696, -0.0658,  0.0364, -0.0562]],

               [[-0.0390,  0.0704,  0.0786,  0.0033,  0.0295],
                [-0.0145, -0.0198,  0.0638,  0.0332, -0.0197],
                [ 0.0663,  0.0625, -0.0292,  0.0036, -0.0408],
                [ 0.0201,  0.0786, -0.0615, -0.0266, -0.0627],
                [-0.0323,  0.0509, -0.0516, -0.0624, -0.0025]]],


              ...,


              [[[ 0.0777,  0.0133, -0.0427,  0.0800, -0.0246],
                [-0.0430,  0.0370,  0.0550,  0.0002, -0.0453],
                [-0.0782,  0.0747, -0.0467,  0.0130, -0.0217],
                [ 0.0417,  0.0371,  0.0066,  0.0761,  0.0344],
                [ 0.0590,  0.0188, -0.0681,  0.0579,  0.0579]],

               [[ 0.0803,  0.0284, -0.0485,  0.0127,  0.0668],
                [-0.0049, -0.0313, -0.0693, -0.0646, -0.0111],
                [ 0.0460,  0.0529,  0.0564, -0.0790, -0.0583],
                [-0.0102,  0.0005,  0.0437,  0.0308, -0.0437],
                [ 0.0205,  0.0394, -0.0644,  0.0770,  0.0368]],

               [[ 0.0157, -0.0298, -0.0073,  0.0653, -0.0469],
                [-0.0527,  0.0319, -0.0647,  0.0084, -0.0241],
                [ 0.0004,  0.0514, -0.0025,  0.0630, -0.0294],
                [ 0.0689,  0.0191,  0.0546,  0.0365, -0.0539],
                [-0.0045, -0.0745,  0.0544, -0.0686,  0.0816]],

               [[ 0.0061, -0.0117,  0.0211,  0.0131, -0.0410],
                [ 0.0168, -0.0675, -0.0497,  0.0285, -0.0718],
                [-0.0401, -0.0345,  0.0003, -0.0376, -0.0554],
                [ 0.0065,  0.0611, -0.0163,  0.0418, -0.0605],
                [-0.0081,  0.0791,  0.0381, -0.0200, -0.0402]],

               [[ 0.0350,  0.0182, -0.0750,  0.0471,  0.0045],
                [ 0.0441,  0.0266, -0.0399, -0.0425, -0.0318],
                [-0.0428, -0.0136, -0.0466, -0.0151,  0.0100],
                [-0.0644,  0.0688,  0.0709,  0.0026,  0.0585],
                [-0.0269, -0.0130,  0.0715, -0.0413, -0.0011]],

               [[-0.0537,  0.0477, -0.0147, -0.0553,  0.0459],
                [ 0.0619,  0.0005, -0.0194,  0.0200,  0.0296],
                [-0.0205, -0.0382,  0.0277,  0.0322, -0.0201],
                [-0.0251, -0.0238, -0.0758,  0.0534,  0.0487],
                [-0.0023,  0.0328, -0.0237,  0.0328, -0.0399]]],


              [[[ 0.0501, -0.0288, -0.0173,  0.0392, -0.0050],
                [ 0.0156,  0.0807,  0.0459, -0.0056,  0.0346],
                [-0.0486,  0.0271,  0.0530, -0.0753, -0.0678],
                [ 0.0070,  0.0007, -0.0208, -0.0359,  0.0587],
                [ 0.0728, -0.0304,  0.0334, -0.0355,  0.0323]],

               [[-0.0734,  0.0403,  0.0763,  0.0570, -0.0033],
                [-0.0664,  0.0105, -0.0366,  0.0042,  0.0546],
                [-0.0745,  0.0639, -0.0417, -0.0033,  0.0495],
                [-0.0550,  0.0292, -0.0198, -0.0241,  0.0367],
                [ 0.0107, -0.0213, -0.0517,  0.0664,  0.0738]],

               [[ 0.0123,  0.0034, -0.0046,  0.0620, -0.0356],
                [-0.0580, -0.0810,  0.0235,  0.0637, -0.0323],
                [ 0.0501,  0.0611,  0.0696, -0.0725, -0.0392],
                [-0.0327,  0.0616, -0.0098,  0.0648,  0.0049],
                [-0.0644,  0.0434,  0.0465,  0.0378,  0.0250]],

               [[ 0.0816,  0.0159,  0.0255,  0.0219, -0.0652],
                [ 0.0783,  0.0748, -0.0595,  0.0515, -0.0486],
                [-0.0709, -0.0491, -0.0587, -0.0085, -0.0437],
                [ 0.0395, -0.0117,  0.0683,  0.0806, -0.0066],
                [-0.0332, -0.0257, -0.0023, -0.0359,  0.0064]],

               [[-0.0328, -0.0616, -0.0107, -0.0231, -0.0393],
                [ 0.0030,  0.0048, -0.0813, -0.0253, -0.0723],
                [-0.0680, -0.0350,  0.0409,  0.0464,  0.0235],
                [-0.0085,  0.0688, -0.0767, -0.0011, -0.0570],
                [ 0.0553, -0.0721,  0.0039, -0.0811, -0.0608]],

               [[ 0.0617,  0.0303, -0.0521,  0.0155, -0.0364],
                [-0.0589, -0.0223, -0.0112, -0.0599, -0.0590],
                [ 0.0729,  0.0326,  0.0761, -0.0415, -0.0048],
                [ 0.0036, -0.0197, -0.0393, -0.0060,  0.0785],
                [-0.0679, -0.0750,  0.0671,  0.0385, -0.0260]]],


              [[[-0.0456,  0.0197,  0.0548,  0.0420, -0.0569],
                [ 0.0518, -0.0172, -0.0758, -0.0328,  0.0196],
                [-0.0712, -0.0446,  0.0593, -0.0403, -0.0250],
                [-0.0142, -0.0058, -0.0283,  0.0783,  0.0075],
                [ 0.0755,  0.0161,  0.0319, -0.0562,  0.0378]],

               [[ 0.0572,  0.0773,  0.0243,  0.0638, -0.0472],
                [-0.0081,  0.0225,  0.0298, -0.0442, -0.0075],
                [-0.0814,  0.0369, -0.0680, -0.0471,  0.0187],
                [ 0.0290, -0.0338,  0.0786,  0.0685, -0.0263],
                [-0.0453, -0.0716, -0.0462,  0.0556,  0.0159]],

               [[ 0.0359, -0.0511,  0.0707, -0.0696,  0.0407],
                [-0.0717,  0.0521,  0.0813,  0.0335, -0.0515],
                [-0.0295, -0.0124, -0.0406, -0.0247,  0.0162],
                [-0.0252,  0.0105,  0.0624, -0.0701,  0.0153],
                [-0.0490,  0.0815,  0.0331,  0.0130, -0.0174]],

               [[-0.0801, -0.0390, -0.0246,  0.0187,  0.0752],
                [-0.0654, -0.0242,  0.0666,  0.0303, -0.0114],
                [ 0.0783, -0.0565, -0.0200, -0.0462, -0.0119],
                [ 0.0788,  0.0656,  0.0623,  0.0350,  0.0254],
                [-0.0227,  0.0380, -0.0172,  0.0293, -0.0065]],

               [[-0.0086, -0.0572,  0.0217, -0.0286,  0.0476],
                [ 0.0695, -0.0679,  0.0714,  0.0371,  0.0638],
                [ 0.0099, -0.0652, -0.0545, -0.0068,  0.0805],
                [-0.0506, -0.0737, -0.0110, -0.0198,  0.0047],
                [-0.0288,  0.0730,  0.0794, -0.0033,  0.0242]],

               [[-0.0616, -0.0632, -0.0110, -0.0658, -0.0470],
                [ 0.0425, -0.0136,  0.0665, -0.0201, -0.0727],
                [ 0.0189,  0.0189,  0.0641, -0.0384,  0.0180],
                [ 0.0002, -0.0737,  0.0365,  0.0311, -0.0378],
                [ 0.0789,  0.0037, -0.0582,  0.0148,  0.0323]]]], requires_grad=True)
       tensor: tensor([[[[ 4.9998e-02, -9.8260e-02,  1.4967e-01, -4.4207e-03, -8.7300e-02],
                [ 7.5572e-02,  4.8974e-02,  1.6032e-02, -1.6826e-02, -1.1976e-01],
                [-6.9849e-02, -2.5324e-02, -5.0632e-02, -1.1284e-01,  4.8002e-02],
                [-1.0279e-01,  7.0627e-03,  1.4877e-01,  8.8626e-02,  7.5717e-02],
                [-5.4043e-02,  1.1418e-02,  6.2735e-02, -1.8255e-02,  5.9029e-03]],

               [[-2.7467e-02,  2.2406e-02, -6.6786e-02,  8.7846e-02,  1.1628e-02],
                [ 6.6246e-02, -2.0746e-02,  1.5433e-01,  4.1994e-02,  8.7076e-02],
                [-4.2903e-02,  1.0213e-01,  7.2365e-02,  2.7206e-02,  1.3627e-01],
                [ 6.4725e-03, -6.9502e-02,  8.2167e-02,  3.2131e-02, -4.3159e-02],
                [ 1.0807e-01, -8.2906e-03,  1.3195e-01,  7.2410e-02, -5.3980e-02]],

               [[-2.5737e-02,  1.0729e-01,  6.7881e-02,  1.2729e-02,  2.2233e-02],
                [-5.4383e-02, -3.2342e-02,  1.1222e-01,  2.9954e-02, -5.6793e-02],
                [ 6.1569e-02, -5.0857e-03, -7.2801e-03,  2.1612e-02,  5.6074e-02],
                [ 8.1350e-02,  1.3603e-01,  6.5937e-02,  5.7080e-02, -3.3051e-02],
                [ 9.2686e-02,  8.6042e-03,  5.0341e-02,  1.2102e-01, -4.2846e-02]],

               [[ 3.1385e-02, -1.8880e-02, -4.3763e-02,  9.5897e-02,  2.5644e-03],
                [ 1.0728e-01,  1.2668e-01, -2.4991e-02,  1.2733e-01, -1.1972e-02],
                [-1.4936e-02, -8.9163e-02, -6.3422e-02, -5.8102e-02,  8.8998e-02],
                [-1.2742e-03,  1.0116e-02, -1.3688e-03,  8.9744e-02,  3.2801e-02],
                [ 5.3857e-02,  3.3760e-02,  9.7652e-02, -1.2658e-02, -4.1213e-02]],

               [[-1.0549e-01,  1.1089e-02,  2.0750e-02, -8.4346e-02, -4.5414e-02],
                [-7.8882e-02,  4.7766e-02,  3.7996e-02, -3.4279e-03,  1.0021e-01],
                [-1.2175e-02,  8.0692e-02,  1.6572e-01,  1.4422e-01, -6.1436e-02],
                [-6.2487e-02, -1.2147e-02, -5.6277e-02, -4.4224e-04, -7.0803e-03],
                [-3.5251e-02, -6.1638e-02,  7.6515e-02, -4.3100e-02, -9.2633e-02]],

               [[ 8.4949e-02,  1.0109e-01, -5.7688e-02,  7.8953e-02,  1.1039e-02],
                [-1.8447e-03,  6.6093e-03,  5.5160e-02, -2.1061e-03, -7.6814e-02],
                [-4.7536e-02,  5.4149e-03,  4.1168e-02, -3.2414e-03, -1.2026e-02],
                [ 6.3581e-02,  2.1284e-03,  4.8441e-02, -2.8984e-02,  3.7449e-02],
                [ 6.2353e-02,  1.7654e-02,  6.6143e-03, -2.6222e-02,  1.9975e-02]]],


              [[[-9.0780e-03,  7.4168e-04,  5.0728e-02,  9.6808e-02,  4.4157e-02],
                [-4.6064e-02, -4.7032e-02, -3.6292e-02,  1.0961e-01, -5.1720e-02],
                [-2.7391e-02, -8.0438e-02,  2.3853e-02, -1.0298e-01,  1.1040e-02],
                [-6.7849e-03,  1.1938e-03,  2.2602e-02,  5.5898e-02,  7.7367e-02],
                [ 4.9093e-03,  8.2457e-02, -1.4296e-02, -3.9243e-02,  2.8431e-02]],

               [[ 1.2188e-01, -9.1309e-02, -2.5300e-02,  3.8637e-02, -1.1966e-02],
                [ 7.2569e-02,  3.6117e-02, -5.2285e-02,  5.4047e-02,  4.2884e-03],
                [ 3.4672e-02, -7.6760e-04, -2.4154e-02, -3.2986e-02,  6.2598e-02],
                [ 3.8618e-02,  5.3779e-02,  5.0678e-02,  2.6810e-02, -1.5022e-02],
                [ 5.0676e-02,  9.1586e-02,  1.7718e-02, -4.8524e-02, -2.0987e-02]],

               [[ 1.8777e-02, -3.6430e-03, -5.1505e-02,  2.5505e-02, -9.2066e-02],
                [-8.3652e-02,  3.8952e-02,  5.1191e-02,  1.6772e-02,  2.0135e-01],
                [-8.1641e-02,  7.3640e-02,  7.7579e-02,  1.5251e-02, -2.0884e-02],
                [-1.3105e-01,  1.5262e-02, -5.6016e-02, -5.0386e-02,  7.2419e-02],
                [-1.2733e-01, -3.3231e-02,  8.8184e-02,  6.6656e-02,  6.6930e-02]],

               [[ 4.0774e-02, -1.3535e-01,  1.3731e-02,  2.4058e-02,  1.0693e-02],
                [-5.5981e-02,  7.2441e-02, -3.5613e-03, -1.2205e-02,  4.7171e-02],
                [-1.1775e-02,  2.7184e-02, -7.1557e-02,  1.0347e-01,  5.8896e-03],
                [ 2.0610e-02, -8.3879e-02, -7.2425e-02,  2.1839e-02, -4.3554e-02],
                [ 6.8138e-02,  3.1565e-02, -1.4528e-01,  1.7535e-01,  3.2564e-02]],

               [[ 1.8692e-02,  3.7771e-02,  5.5205e-02, -1.3458e-02,  4.8685e-02],
                [ 2.7660e-02, -1.7240e-02, -4.8733e-02,  9.4653e-02,  5.5204e-02],
                [ 1.2270e-01,  6.1749e-03,  5.6012e-02,  1.0850e-01,  9.6627e-02],
                [-6.4985e-02, -5.4565e-02,  7.3250e-02,  4.8888e-03,  1.9920e-02],
                [ 3.1991e-02,  1.5198e-02, -2.6509e-02, -1.4226e-01, -1.3095e-01]],

               [[ 1.2591e-02, -5.1089e-02,  7.1670e-02,  9.0602e-02, -2.2682e-02],
                [-1.1734e-01, -1.1738e-01,  2.7109e-02, -1.0356e-01, -3.0563e-02],
                [-1.0117e-01, -1.4831e-01,  7.1119e-02,  2.7001e-02,  1.9252e-01],
                [ 1.5424e-02,  1.2361e-01,  1.2723e-01,  2.6140e-02, -1.6384e-02],
                [-4.4956e-02,  4.8875e-02, -3.7166e-02, -1.5698e-02, -5.9875e-02]]],


              [[[-4.0060e-02, -7.5376e-02,  9.0522e-02, -1.3705e-02,  5.0330e-02],
                [ 1.2410e-01, -6.8057e-04, -1.6394e-02, -4.6147e-02,  7.4656e-03],
                [-2.4582e-02, -4.8235e-02,  1.1284e-02, -2.6863e-02,  4.7994e-02],
                [ 1.7968e-02,  1.4713e-02, -4.4666e-02,  8.4956e-02, -1.1188e-01],
                [ 6.1636e-02,  5.7068e-02, -7.3509e-02,  1.2213e-02,  7.9951e-02]],

               [[ 1.3779e-02, -5.6942e-02,  6.5871e-02, -9.4319e-02, -1.0116e-01],
                [-8.5877e-02, -1.4712e-01,  2.4106e-02,  4.5554e-02, -1.7046e-01],
                [ 7.1870e-02,  8.6280e-02,  2.3830e-02,  1.2538e-01, -6.0586e-02],
                [-5.0712e-02, -5.6716e-02,  2.2811e-02,  4.6917e-03,  2.1000e-02],
                [ 1.4774e-02, -3.2031e-02,  1.3210e-01,  1.1982e-01, -8.9682e-02]],

               [[ 1.6284e-01, -1.0248e-01,  2.2695e-02, -5.2900e-02,  6.1035e-02],
                [ 2.0388e-02, -6.5181e-02, -1.0114e-01,  4.6024e-02, -4.4779e-02],
                [-3.0573e-04, -1.1372e-02, -3.3065e-02,  2.6070e-02, -5.5357e-02],
                [-1.0599e-02, -5.1236e-02,  9.6448e-02,  7.5023e-02, -7.7667e-02],
                [-7.2784e-02,  1.3918e-01, -1.4998e-02,  4.7092e-02,  5.4265e-02]],

               [[-3.7319e-02,  3.8862e-02, -6.2182e-02,  6.0732e-02,  1.1152e-01],
                [-6.8488e-02, -4.2621e-02, -1.0851e-01,  3.8274e-02, -2.8389e-02],
                [-2.4795e-03,  6.5596e-02, -4.4869e-03,  4.4756e-02, -6.8641e-02],
                [-1.9861e-02,  2.4411e-02,  1.5822e-01, -8.8753e-02, -4.8671e-02],
                [-3.9062e-02, -4.8847e-02, -7.2248e-02, -4.5459e-04,  2.7606e-02]],

               [[-3.5982e-03,  6.8191e-02,  9.6101e-02,  5.5497e-02,  7.6344e-02],
                [-1.9417e-02, -1.5556e-01, -1.9490e-02,  5.0801e-02, -3.7765e-02],
                [-1.1910e-02, -9.1878e-02,  4.2141e-02, -9.9299e-02, -5.6762e-02],
                [ 7.9792e-02,  1.8574e-01,  4.4123e-02, -6.9109e-03, -6.1165e-02],
                [-1.2580e-01,  3.9366e-02, -1.0186e-01,  3.6060e-02, -6.7405e-02]],

               [[-6.8059e-04,  7.9329e-02,  4.2778e-02,  4.9050e-03, -3.9437e-02],
                [-3.1596e-02,  5.7778e-02,  1.1185e-01,  4.8292e-02, -4.3491e-02],
                [ 9.8594e-02,  6.9557e-02,  1.0506e-02, -1.4311e-01, -7.4874e-02],
                [ 4.5525e-02,  1.2678e-01, -5.5705e-04, -3.9545e-02, -5.2146e-02],
                [-6.7104e-02,  2.1735e-02, -2.9231e-02,  1.9563e-03,  3.3551e-02]]],


              ...,


              [[[ 7.9785e-03,  9.6521e-02, -1.3083e-02,  8.3893e-02, -2.8526e-02],
                [-1.5271e-01,  6.7064e-02, -1.3156e-02,  1.5758e-02, -1.6052e-02],
                [-2.0791e-02,  1.1361e-01, -1.4968e-01,  3.5289e-02, -4.0983e-02],
                [-3.4127e-02,  7.5340e-02,  2.4503e-02,  1.1301e-01,  1.6968e-01],
                [-1.2136e-01,  1.4160e-02, -1.2052e-01,  3.4367e-02,  4.9111e-02]],

               [[ 9.0677e-02, -5.5500e-02, -9.1093e-02,  4.9036e-02,  4.8991e-02],
                [ 8.5656e-02, -1.3216e-01, -6.3505e-02, -3.2264e-02,  7.5033e-03],
                [ 1.0007e-01,  3.5466e-02,  5.0721e-02, -5.8828e-02, -1.1460e-01],
                [-3.6455e-02,  1.1234e-02,  2.6212e-02, -8.2686e-02, -6.2735e-02],
                [ 2.6178e-02, -3.6884e-03, -1.3710e-01,  1.6569e-02,  2.8620e-02]],

               [[-5.0623e-02, -4.4185e-02,  4.6047e-02,  3.9004e-02, -7.5805e-03],
                [-4.6561e-02,  4.3014e-02,  1.0764e-02, -4.7813e-03, -4.7869e-02],
                [-1.0108e-01, -1.7958e-03,  1.0872e-02,  1.2845e-01, -1.5940e-02],
                [ 4.3679e-02,  7.6629e-02,  2.1028e-02,  5.5007e-02, -7.4609e-02],
                [ 9.3900e-02, -8.2633e-02,  2.7914e-02,  2.5687e-02, -2.3510e-02]],

               [[ 7.7326e-03,  1.1611e-02,  2.9614e-02,  1.0284e-02,  2.9104e-02],
                [ 3.0220e-02, -8.0378e-02, -2.7949e-02,  2.2446e-02,  8.0459e-03],
                [ 6.5281e-02, -5.7804e-02,  8.3168e-02,  5.4988e-02, -6.6446e-02],
                [-7.8295e-02,  3.3191e-02, -9.8877e-02,  1.0647e-01, -3.7064e-03],
                [ 9.2358e-02,  1.5699e-01,  7.2006e-02, -3.4939e-02, -2.7969e-02]],

               [[ 4.8673e-02, -1.7830e-02, -7.4592e-02,  1.0337e-01,  4.3368e-02],
                [ 3.7064e-02,  5.9052e-02, -6.1762e-02,  5.2548e-03,  1.1638e-02],
                [-3.2117e-02, -1.3252e-01,  5.2480e-02, -5.5170e-02, -7.9090e-03],
                [-9.8518e-02,  9.4210e-02,  1.4956e-02,  6.6029e-02,  3.8304e-02],
                [-3.7070e-02, -7.0037e-02,  1.9132e-01, -4.1338e-02, -2.5241e-02]],

               [[-4.6659e-02,  4.2209e-03, -4.2241e-02,  1.9754e-02,  7.7949e-02],
                [ 6.5216e-02,  2.0466e-02,  1.5637e-02,  3.5587e-04,  6.2204e-02],
                [-1.2770e-01,  1.4639e-03,  5.8013e-02,  3.8882e-02, -3.6296e-02],
                [-5.8678e-02,  4.7554e-02, -4.9292e-02,  8.2163e-02, -2.8183e-02],
                [ 3.1189e-02, -4.4150e-02, -6.2355e-02,  4.8382e-02,  4.5446e-02]]],


              [[[ 9.5645e-02, -1.0301e-01,  4.6273e-02,  7.3074e-02, -6.2531e-03],
                [-1.1468e-02, -1.7400e-02,  3.8642e-03,  1.0186e-01,  6.0620e-02],
                [-3.4799e-03,  9.2991e-02,  1.3765e-01, -1.7145e-01, -8.0353e-02],
                [ 4.2481e-02,  4.6607e-02,  1.6294e-02,  3.1393e-02,  2.1176e-01],
                [ 1.1362e-01,  8.9815e-03,  6.5340e-02, -4.6683e-02,  6.6436e-02]],

               [[-4.0611e-02,  7.9461e-02,  1.3987e-01,  7.9224e-02,  1.3151e-02],
                [ 5.1524e-03, -1.9823e-02, -2.6767e-02,  7.3505e-02,  7.4152e-02],
                [-9.5245e-02,  5.9597e-02,  3.4118e-02, -3.6963e-02, -1.3391e-02],
                [-2.4452e-02, -4.9684e-03, -7.2417e-02, -2.5737e-02,  9.2375e-03],
                [ 5.6024e-02, -2.5624e-02, -2.9548e-02,  1.0268e-01,  6.6207e-02]],

               [[-2.8858e-02, -1.6285e-03,  2.3269e-03,  5.0814e-02, -1.1348e-01],
                [-1.0934e-01, -5.2677e-02,  5.9965e-02, -5.4806e-02,  1.5476e-02],
                [ 1.4079e-01, -1.4167e-03,  3.5588e-02, -8.0859e-02, -3.6366e-02],
                [ 1.9536e-02,  1.0368e-01, -7.7792e-02,  1.5817e-01,  5.5855e-02],
                [-8.3879e-02,  4.9135e-02,  4.2749e-02,  4.6981e-02, -3.3361e-02]],

               [[ 8.6765e-02, -7.7088e-03, -4.8883e-02,  3.6100e-02, -6.2219e-02],
                [ 4.5523e-02,  5.6251e-02, -5.2107e-02,  1.2930e-01, -5.1350e-02],
                [ 6.5248e-03, -1.0238e-01, -6.8020e-02, -1.1667e-01, -1.9077e-02],
                [ 6.1764e-02, -1.6211e-02,  1.3773e-01,  9.2896e-02, -5.1193e-02],
                [-3.0970e-02,  3.7340e-02,  3.6440e-02,  2.1100e-02,  1.5138e-03]],

               [[-3.9962e-02, -1.1714e-01,  5.2935e-02, -9.2095e-02,  5.2764e-02],
                [ 5.3439e-02,  8.3302e-02, -1.9304e-02,  3.3452e-02, -1.0472e-01],
                [ 5.2374e-03, -7.3624e-02, -2.3803e-02,  5.9410e-02,  2.6254e-02],
                [ 4.3904e-02,  7.5527e-02, -5.5117e-02, -3.0872e-02, -7.4420e-02],
                [ 5.9820e-02, -1.4020e-01, -3.1974e-02, -1.3241e-01, -8.4654e-02]],

               [[ 2.9300e-02,  8.5097e-02, -4.9078e-02,  1.3321e-02,  2.6851e-02],
                [-9.7055e-02,  1.2338e-01, -5.9273e-02, -1.1743e-01, -6.2792e-02],
                [-8.2991e-03,  3.5198e-02,  1.1417e-01, -7.3450e-02,  1.5150e-02],
                [ 2.6375e-02,  2.8174e-03, -3.1688e-02, -6.8431e-02,  1.1147e-01],
                [-2.8619e-02, -8.2053e-03,  5.5010e-02,  1.2504e-01, -1.0463e-03]]],


              [[[-1.1390e-01,  1.1178e-01,  2.0661e-02,  7.5986e-02, -3.5510e-02],
                [ 6.9884e-02,  7.9008e-03, -3.1892e-02, -6.2933e-02, -1.8630e-02],
                [-1.6785e-01, -1.3450e-01,  9.2579e-02, -9.7790e-02,  5.4691e-02],
                [-4.4550e-02, -8.8956e-03, -5.8851e-02,  1.7682e-01, -4.1232e-02],
                [ 1.1308e-01, -1.0594e-01,  8.4209e-02, -9.3546e-02,  4.6904e-02]],

               [[ 8.0923e-03,  5.9262e-02, -3.0235e-02,  7.7266e-02, -4.6861e-02],
                [ 9.1123e-03,  1.6089e-02,  2.7371e-02,  2.4351e-02,  4.7679e-03],
                [-1.2141e-01,  7.4713e-02, -1.1535e-01,  2.6721e-02,  1.0784e-01],
                [ 9.8689e-02, -6.4300e-02,  1.3670e-01,  7.1682e-02, -3.5123e-02],
                [-1.6633e-02, -1.0464e-01, -4.0829e-02,  5.4382e-02,  7.7213e-03]],

               [[ 3.5077e-02,  4.2135e-02,  3.0175e-02, -4.3647e-02,  1.5406e-01],
                [-1.6523e-02,  4.2702e-02,  3.7327e-03, -2.4782e-02, -1.9466e-02],
                [-9.2885e-05,  4.0268e-02, -2.1415e-02, -2.9881e-02, -5.3303e-02],
                [-2.6457e-02,  7.5620e-03,  2.6807e-02, -1.3104e-01,  1.0279e-02],
                [-5.7137e-02, -4.6889e-02, -5.5574e-02, -9.9751e-02, -3.5737e-02]],

               [[-9.9814e-02, -9.5435e-02, -2.6537e-02,  3.1237e-02,  1.5104e-02],
                [-9.9790e-02,  2.9716e-03,  5.0457e-02, -8.9010e-02, -3.9197e-02],
                [ 3.2547e-02,  6.0830e-03, -2.2157e-02, -7.0033e-02, -1.0172e-01],
                [ 4.9774e-02,  5.4871e-02,  5.6303e-02,  4.8631e-02, -3.4986e-02],
                [ 1.2138e-02, -2.2211e-02, -3.3582e-02, -1.7598e-02, -1.0158e-02]],

               [[-7.2373e-03, -4.2170e-02,  5.9605e-02, -8.4486e-02,  6.3807e-02],
                [ 1.0146e-01, -4.4409e-02,  6.9227e-02,  2.7258e-02, -4.4921e-02],
                [ 2.0162e-02, -1.3525e-01, -3.3158e-02, -4.9422e-02,  6.1210e-02],
                [-6.3034e-02,  8.9230e-02,  2.8164e-02, -1.8634e-02,  1.1193e-02],
                [-1.5343e-01,  5.8297e-02,  4.0476e-02, -8.3992e-03, -2.8954e-02]],

               [[-1.4177e-02, -1.4059e-01,  3.4488e-02, -1.1158e-01,  2.7263e-02],
                [ 9.5575e-02, -1.5587e-03,  6.7742e-02, -9.6203e-03, -5.8986e-02],
                [-1.8375e-03, -4.2862e-02,  7.6257e-02,  5.3290e-02, -5.9313e-02],
                [-1.1387e-02, -1.2381e-01,  2.8557e-02,  1.3121e-02,  1.9561e-02],
                [ 9.1249e-02, -3.3576e-03, -3.5484e-02,  2.8714e-02, -7.3947e-03]]]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-3.9858e-02, -6.2551e-02, -1.4050e-02,  5.5231e-02,  1.7201e-02,
              -2.6817e-02,  2.8532e-02, -2.6857e-02, -2.8279e-02,  9.3339e-04,
              -5.8307e-02, -6.2624e-02,  7.1998e-05,  1.1212e-02, -2.0352e-02,
              -5.8423e-02], requires_grad=True)
       tensor: tensor([-0.0379, -0.0506,  0.0362,  0.0304, -0.0709, -0.0483, -0.0155, -0.0590,
              -0.0637,  0.0148, -0.0911, -0.1465,  0.0204,  0.0746, -0.0178, -0.1370],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[[[0., -0., 0., -0., -0.],
                [0., 0., 0., 0., 0.],
                [-0., -0., -0., -0., 0.],
                [-0., 0., -0., 0., 0.],
                [-0., 0., 0., -0., 0.]],

               [[-0., 0., -0., 0., 0.],
                [0., -0., 0., 0., 0.],
                [-0., 0., 0., 0., 0.],
                [-0., -0., 0., 0., -0.],
                [0., 0., 0., 0., -0.]],

               [[-0., 0., 0., -0., 0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [0., -0., 0., 0., -0.]],

               [[-0., -0., 0., 0., 0.],
                [0., 0., 0., 0., -0.],
                [-0., -0., -0., 0., 0.],
                [0., -0., 0., 0., -0.],
                [0., 0., 0., 0., 0.]],

               [[-0., -0., -0., -0., -0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., 0., 0., 0.],
                [-0., -0., -0., -0., 0.],
                [0., -0., 0., 0., -0.]],

               [[0., 0., -0., 0., 0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., 0., -0., 0.],
                [0., -0., -0., -0., 0.],
                [0., -0., -0., -0., 0.]]],


              [[[-0., 0., -0., 0., 0.],
                [-0., -0., -0., 0., -0.],
                [-0., -0., -0., -0., 0.],
                [-0., -0., 0., -0., 0.],
                [-0., 0., -0., -0., -0.]],

               [[0., -0., -0., -0., -0.],
                [0., -0., -0., 0., 0.],
                [0., -0., 0., 0., 0.],
                [0., -0., -0., 0., -0.],
                [-0., 0., -0., -0., -0.]],

               [[-0., -0., -0., 0., -0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., -0., 0., -0.],
                [-0., 0., -0., 0., 0.],
                [-0., -0., 0., -0., 0.]],

               [[-0., -0., 0., 0., 0.],
                [-0., 0., -0., 0., 0.],
                [-0., 0., -0., 0., 0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., -0., 0., 0.]],

               [[0., 0., -0., -0., 0.],
                [-0., -0., 0., 0., 0.],
                [-0., -0., 0., 0., 0.],
                [-0., -0., -0., 0., 0.],
                [0., 0., 0., -0., -0.]],

               [[0., -0., 0., 0., 0.],
                [-0., -0., 0., 0., -0.],
                [-0., -0., -0., 0., 0.],
                [-0., 0., 0., 0., -0.],
                [-0., 0., -0., -0., -0.]]],


              [[[0., -0., 0., 0., 0.],
                [0., -0., -0., -0., -0.],
                [-0., -0., 0., -0., 0.],
                [0., -0., -0., 0., -0.],
                [0., 0., -0., 0., 0.]],

               [[0., -0., 0., -0., -0.],
                [-0., -0., -0., 0., -0.],
                [0., 0., -0., 0., -0.],
                [-0., 0., 0., 0., 0.],
                [0., 0., 0., 0., -0.]],

               [[0., -0., 0., -0., 0.],
                [-0., -0., -0., 0., -0.],
                [0., -0., -0., -0., -0.],
                [-0., -0., 0., 0., 0.],
                [-0., 0., 0., 0., -0.]],

               [[-0., 0., -0., -0., 0.],
                [-0., -0., -0., -0., -0.],
                [-0., 0., 0., 0., -0.],
                [0., 0., 0., -0., -0.],
                [0., 0., -0., -0., 0.]],

               [[0., 0., 0., -0., 0.],
                [-0., 0., -0., -0., -0.],
                [0., -0., 0., -0., 0.],
                [0., 0., 0., 0., -0.],
                [-0., 0., -0., 0., -0.]],

               [[-0., 0., 0., 0., 0.],
                [-0., -0., 0., 0., -0.],
                [0., 0., -0., 0., -0.],
                [0., 0., -0., -0., -0.],
                [-0., 0., -0., -0., -0.]]],


              ...,


              [[[0., 0., -0., 0., -0.],
                [-0., 0., 0., 0., -0.],
                [-0., 0., -0., 0., -0.],
                [0., 0., 0., 0., 0.],
                [0., 0., -0., 0., 0.]],

               [[0., 0., -0., 0., 0.],
                [-0., -0., -0., -0., -0.],
                [0., 0., 0., -0., -0.],
                [-0., 0., 0., 0., -0.],
                [0., 0., -0., 0., 0.]],

               [[0., -0., -0., 0., -0.],
                [-0., 0., -0., 0., -0.],
                [0., 0., -0., 0., -0.],
                [0., 0., 0., 0., -0.],
                [-0., -0., 0., -0., 0.]],

               [[0., -0., 0., 0., -0.],
                [0., -0., -0., 0., -0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., -0., 0., -0.],
                [-0., 0., 0., -0., -0.]],

               [[0., 0., -0., 0., 0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., -0., -0., 0.],
                [-0., 0., 0., 0., 0.],
                [-0., -0., 0., -0., -0.]],

               [[-0., 0., -0., -0., 0.],
                [0., 0., -0., 0., 0.],
                [-0., -0., 0., 0., -0.],
                [-0., -0., -0., 0., 0.],
                [-0., 0., -0., 0., -0.]]],


              [[[0., -0., -0., 0., -0.],
                [0., 0., 0., -0., 0.],
                [-0., 0., 0., -0., -0.],
                [0., 0., -0., -0., 0.],
                [0., -0., 0., -0., 0.]],

               [[-0., 0., 0., 0., -0.],
                [-0., 0., -0., 0., 0.],
                [-0., 0., -0., -0., 0.],
                [-0., 0., -0., -0., 0.],
                [0., -0., -0., 0., 0.]],

               [[0., 0., -0., 0., -0.],
                [-0., -0., 0., 0., -0.],
                [0., 0., 0., -0., -0.],
                [-0., 0., -0., 0., 0.],
                [-0., 0., 0., 0., 0.]],

               [[0., 0., 0., 0., -0.],
                [0., 0., -0., 0., -0.],
                [-0., -0., -0., -0., -0.],
                [0., -0., 0., 0., -0.],
                [-0., -0., -0., -0., 0.]],

               [[-0., -0., -0., -0., -0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., 0., 0., 0.],
                [-0., 0., -0., -0., -0.],
                [0., -0., 0., -0., -0.]],

               [[0., 0., -0., 0., -0.],
                [-0., -0., -0., -0., -0.],
                [0., 0., 0., -0., -0.],
                [0., -0., -0., -0., 0.],
                [-0., -0., 0., 0., -0.]]],


              [[[-0., 0., 0., 0., -0.],
                [0., -0., -0., -0., 0.],
                [-0., -0., 0., -0., -0.],
                [-0., -0., -0., 0., 0.],
                [0., 0., 0., -0., 0.]],

               [[0., 0., 0., 0., -0.],
                [-0., 0., 0., -0., -0.],
                [-0., 0., -0., -0., 0.],
                [0., -0., 0., 0., -0.],
                [-0., -0., -0., 0., 0.]],

               [[0., -0., 0., -0., 0.],
                [-0., 0., 0., 0., -0.],
                [-0., -0., -0., -0., 0.],
                [-0., 0., 0., -0., 0.],
                [-0., 0., 0., 0., -0.]],

               [[-0., -0., -0., 0., 0.],
                [-0., -0., 0., 0., -0.],
                [0., -0., -0., -0., -0.],
                [0., 0., 0., 0., 0.],
                [-0., 0., -0., 0., -0.]],

               [[-0., -0., 0., -0., 0.],
                [0., -0., 0., 0., 0.],
                [0., -0., -0., -0., 0.],
                [-0., -0., -0., -0., 0.],
                [-0., 0., 0., -0., 0.]],

               [[-0., -0., -0., -0., -0.],
                [0., -0., 0., -0., -0.],
                [0., 0., 0., -0., 0.],
                [0., -0., 0., 0., -0.],
                [0., 0., -0., 0., 0.]]]])
       scale: tensor([[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              ...,


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[[[ 0.0228, -0.0484,  0.0637, -0.0249, -0.0302],
                [ 0.0105,  0.0733,  0.0576,  0.0393,  0.0094],
                [-0.0590, -0.0747, -0.0506, -0.0495,  0.0402],
                [-0.0647,  0.0789, -0.0096,  0.0477,  0.0700],
                [-0.0450,  0.0176,  0.0443, -0.0176,  0.0407]],

               [[-0.0148,  0.0075, -0.0554,  0.0808,  0.0812],
                [ 0.0636, -0.0776,  0.0617,  0.0497,  0.0760],
                [-0.0350,  0.0633,  0.0318,  0.0562,  0.0742],
                [-0.0229, -0.0075,  0.0338,  0.0709, -0.0382],
                [ 0.0655,  0.0249,  0.0556,  0.0329, -0.0734]],

               [[-0.0531,  0.0393,  0.0355, -0.0358,  0.0298],
                [-0.0772, -0.0365,  0.0397, -0.0093, -0.0571],
                [ 0.0087,  0.0256,  0.0568,  0.0123,  0.0098],
                [ 0.0070,  0.0767,  0.0580,  0.0573,  0.0128],
                [ 0.0476, -0.0090,  0.0190,  0.0597, -0.0108]],

               [[-0.0097, -0.0755,  0.0251,  0.0349,  0.0423],
                [ 0.0800,  0.0594,  0.0019,  0.0419, -0.0641],
                [-0.0501, -0.0639, -0.0783,  0.0276,  0.0425],
                [ 0.0793, -0.0161,  0.0031,  0.0700, -0.0077],
                [ 0.0717,  0.0564,  0.0130,  0.0126,  0.0055]],

               [[-0.0711, -0.0004, -0.0350, -0.0729, -0.0567],
                [-0.0351,  0.0119,  0.0163,  0.0542,  0.0558],
                [-0.0308,  0.0337,  0.0634,  0.0556,  0.0051],
                [-0.0305, -0.0103, -0.0515, -0.0746,  0.0701],
                [ 0.0014, -0.0249,  0.0724,  0.0282, -0.0773]],

               [[ 0.0647,  0.0735, -0.0460,  0.0321,  0.0060],
                [-0.0624, -0.0262,  0.0131, -0.0270, -0.0768],
                [ 0.0053,  0.0301,  0.0262, -0.0012,  0.0230],
                [ 0.0758, -0.0297, -0.0011, -0.0424,  0.0659],
                [ 0.0155, -0.0027, -0.0785, -0.0715,  0.0603]]],


              [[[-0.0680,  0.0027, -0.0140,  0.0451,  0.0738],
                [-0.0323, -0.0084, -0.0281,  0.0522, -0.0210],
                [-0.0525, -0.0399, -0.0072, -0.0577,  0.0280],
                [-0.0374, -0.0055,  0.0143, -0.0258,  0.0404],
                [-0.0804,  0.0563, -0.0096, -0.0485, -0.0052]],

               [[ 0.0772, -0.0417, -0.0619, -0.0375, -0.0217],
                [ 0.0312, -0.0186, -0.0225,  0.0433,  0.0327],
                [ 0.0751, -0.0210,  0.0098,  0.0275,  0.0606],
                [ 0.0308, -0.0484, -0.0063,  0.0799, -0.0184],
                [-0.0195,  0.0014, -0.0263, -0.0651, -0.0203]],

               [[-0.0441, -0.0229, -0.0547,  0.0813, -0.0195],
                [-0.0343,  0.0006,  0.0058,  0.0113,  0.0668],
                [-0.0674,  0.0276, -0.0072,  0.0360, -0.0707],
                [-0.0468,  0.0240, -0.0658,  0.0561,  0.0396],
                [-0.0704, -0.0553,  0.0372, -0.0310,  0.0525]],

               [[-0.0081, -0.0563,  0.0125,  0.0304,  0.0486],
                [-0.0515,  0.0539, -0.0165,  0.0263,  0.0241],
                [-0.0189,  0.0789, -0.0254,  0.0502,  0.0336],
                [-0.0288, -0.0308,  0.0397, -0.0238, -0.0365],
                [ 0.0195,  0.0625, -0.0496,  0.0805,  0.0320]],

               [[ 0.0451,  0.0478, -0.0347, -0.0325,  0.0203],
                [-0.0460, -0.0406,  0.0251,  0.0666,  0.0718],
                [-0.0109, -0.0145,  0.0313,  0.0103,  0.0525],
                [-0.0299, -0.0280, -0.0036,  0.0186,  0.0350],
                [ 0.0073,  0.0463,  0.0292, -0.0732, -0.0485]],

               [[ 0.0059, -0.0394,  0.0745,  0.0484,  0.0053],
                [-0.0341, -0.0371,  0.0608,  0.0058, -0.0537],
                [-0.0605, -0.0791, -0.0253,  0.0762,  0.0762],
                [-0.0059,  0.0153,  0.0580,  0.0198, -0.0173],
                [-0.0174,  0.0578, -0.0368, -0.0736, -0.0631]]],


              [[[ 0.0033, -0.0336,  0.0105,  0.0191,  0.0074],
                [ 0.0714, -0.0613, -0.0475, -0.0452, -0.0098],
                [-0.0617, -0.0515,  0.0353, -0.0375,  0.0360],
                [ 0.0542, -0.0326, -0.0236,  0.0674, -0.0347],
                [ 0.0394,  0.0746, -0.0467,  0.0522,  0.0424]],

               [[ 0.0503, -0.0693,  0.0626, -0.0424, -0.0488],
                [-0.0617, -0.0619, -0.0342,  0.0127, -0.0253],
                [ 0.0319,  0.0270, -0.0089,  0.0175, -0.0308],
                [-0.0551,  0.0107,  0.0212,  0.0026,  0.0019],
                [ 0.0115,  0.0343,  0.0572,  0.0811, -0.0354]],

               [[ 0.0632, -0.0197,  0.0072, -0.0372,  0.0581],
                [-0.0492, -0.0602, -0.0201,  0.0465, -0.0482],
                [ 0.0403, -0.0086, -0.0292, -0.0711, -0.0294],
                [-0.0228, -0.0102,  0.0450,  0.0804,  0.0032],
                [-0.0009,  0.0815,  0.0113,  0.0315, -0.0110]],

               [[-0.0627,  0.0313, -0.0443, -0.0353,  0.0797],
                [-0.0039, -0.0053, -0.0678, -0.0335, -0.0764],
                [-0.0202,  0.0296,  0.0548,  0.0185, -0.0028],
                [ 0.0044,  0.0496,  0.0734, -0.0418, -0.0792],
                [ 0.0025,  0.0171, -0.0681, -0.0224,  0.0077]],

               [[ 0.0424,  0.0476,  0.0079, -0.0167,  0.0523],
                [-0.0335,  0.0003, -0.0560, -0.0448, -0.0173],
                [ 0.0426, -0.0002,  0.0133, -0.0705,  0.0085],
                [ 0.0724,  0.0664,  0.0455,  0.0461, -0.0607],
                [-0.0546,  0.0696, -0.0658,  0.0364, -0.0562]],

               [[-0.0390,  0.0704,  0.0786,  0.0033,  0.0295],
                [-0.0145, -0.0198,  0.0638,  0.0332, -0.0197],
                [ 0.0663,  0.0625, -0.0292,  0.0036, -0.0408],
                [ 0.0201,  0.0786, -0.0615, -0.0266, -0.0627],
                [-0.0323,  0.0509, -0.0516, -0.0624, -0.0025]]],


              ...,


              [[[ 0.0777,  0.0133, -0.0427,  0.0800, -0.0246],
                [-0.0430,  0.0370,  0.0550,  0.0002, -0.0453],
                [-0.0782,  0.0747, -0.0467,  0.0130, -0.0217],
                [ 0.0417,  0.0371,  0.0066,  0.0761,  0.0344],
                [ 0.0590,  0.0188, -0.0681,  0.0579,  0.0579]],

               [[ 0.0803,  0.0284, -0.0485,  0.0127,  0.0668],
                [-0.0049, -0.0313, -0.0693, -0.0646, -0.0111],
                [ 0.0460,  0.0529,  0.0564, -0.0790, -0.0583],
                [-0.0102,  0.0005,  0.0437,  0.0308, -0.0437],
                [ 0.0205,  0.0394, -0.0644,  0.0770,  0.0368]],

               [[ 0.0157, -0.0298, -0.0073,  0.0653, -0.0469],
                [-0.0527,  0.0319, -0.0647,  0.0084, -0.0241],
                [ 0.0004,  0.0514, -0.0025,  0.0630, -0.0294],
                [ 0.0689,  0.0191,  0.0546,  0.0365, -0.0539],
                [-0.0045, -0.0745,  0.0544, -0.0686,  0.0816]],

               [[ 0.0061, -0.0117,  0.0211,  0.0131, -0.0410],
                [ 0.0168, -0.0675, -0.0497,  0.0285, -0.0718],
                [-0.0401, -0.0345,  0.0003, -0.0376, -0.0554],
                [ 0.0065,  0.0611, -0.0163,  0.0418, -0.0605],
                [-0.0081,  0.0791,  0.0381, -0.0200, -0.0402]],

               [[ 0.0350,  0.0182, -0.0750,  0.0471,  0.0045],
                [ 0.0441,  0.0266, -0.0399, -0.0425, -0.0318],
                [-0.0428, -0.0136, -0.0466, -0.0151,  0.0100],
                [-0.0644,  0.0688,  0.0709,  0.0026,  0.0585],
                [-0.0269, -0.0130,  0.0715, -0.0413, -0.0011]],

               [[-0.0537,  0.0477, -0.0147, -0.0553,  0.0459],
                [ 0.0619,  0.0005, -0.0194,  0.0200,  0.0296],
                [-0.0205, -0.0382,  0.0277,  0.0322, -0.0201],
                [-0.0251, -0.0238, -0.0758,  0.0534,  0.0487],
                [-0.0023,  0.0328, -0.0237,  0.0328, -0.0399]]],


              [[[ 0.0501, -0.0288, -0.0173,  0.0392, -0.0050],
                [ 0.0156,  0.0807,  0.0459, -0.0056,  0.0346],
                [-0.0486,  0.0271,  0.0530, -0.0753, -0.0678],
                [ 0.0070,  0.0007, -0.0208, -0.0359,  0.0587],
                [ 0.0728, -0.0304,  0.0334, -0.0355,  0.0323]],

               [[-0.0734,  0.0403,  0.0763,  0.0570, -0.0033],
                [-0.0664,  0.0105, -0.0366,  0.0042,  0.0546],
                [-0.0745,  0.0639, -0.0417, -0.0033,  0.0495],
                [-0.0550,  0.0292, -0.0198, -0.0241,  0.0367],
                [ 0.0107, -0.0213, -0.0517,  0.0664,  0.0738]],

               [[ 0.0123,  0.0034, -0.0046,  0.0620, -0.0356],
                [-0.0580, -0.0810,  0.0235,  0.0637, -0.0323],
                [ 0.0501,  0.0611,  0.0696, -0.0725, -0.0392],
                [-0.0327,  0.0616, -0.0098,  0.0648,  0.0049],
                [-0.0644,  0.0434,  0.0465,  0.0378,  0.0250]],

               [[ 0.0816,  0.0159,  0.0255,  0.0219, -0.0652],
                [ 0.0783,  0.0748, -0.0595,  0.0515, -0.0486],
                [-0.0709, -0.0491, -0.0587, -0.0085, -0.0437],
                [ 0.0395, -0.0117,  0.0683,  0.0806, -0.0066],
                [-0.0332, -0.0257, -0.0023, -0.0359,  0.0064]],

               [[-0.0328, -0.0616, -0.0107, -0.0231, -0.0393],
                [ 0.0030,  0.0048, -0.0813, -0.0253, -0.0723],
                [-0.0680, -0.0350,  0.0409,  0.0464,  0.0235],
                [-0.0085,  0.0688, -0.0767, -0.0011, -0.0570],
                [ 0.0553, -0.0721,  0.0039, -0.0811, -0.0608]],

               [[ 0.0617,  0.0303, -0.0521,  0.0155, -0.0364],
                [-0.0589, -0.0223, -0.0112, -0.0599, -0.0590],
                [ 0.0729,  0.0326,  0.0761, -0.0415, -0.0048],
                [ 0.0036, -0.0197, -0.0393, -0.0060,  0.0785],
                [-0.0679, -0.0750,  0.0671,  0.0385, -0.0260]]],


              [[[-0.0456,  0.0197,  0.0548,  0.0420, -0.0569],
                [ 0.0518, -0.0172, -0.0758, -0.0328,  0.0196],
                [-0.0712, -0.0446,  0.0593, -0.0403, -0.0250],
                [-0.0142, -0.0058, -0.0283,  0.0783,  0.0075],
                [ 0.0755,  0.0161,  0.0319, -0.0562,  0.0378]],

               [[ 0.0572,  0.0773,  0.0243,  0.0638, -0.0472],
                [-0.0081,  0.0225,  0.0298, -0.0442, -0.0075],
                [-0.0814,  0.0369, -0.0680, -0.0471,  0.0187],
                [ 0.0290, -0.0338,  0.0786,  0.0685, -0.0263],
                [-0.0453, -0.0716, -0.0462,  0.0556,  0.0159]],

               [[ 0.0359, -0.0511,  0.0707, -0.0696,  0.0407],
                [-0.0717,  0.0521,  0.0813,  0.0335, -0.0515],
                [-0.0295, -0.0124, -0.0406, -0.0247,  0.0162],
                [-0.0252,  0.0105,  0.0624, -0.0701,  0.0153],
                [-0.0490,  0.0815,  0.0331,  0.0130, -0.0174]],

               [[-0.0801, -0.0390, -0.0246,  0.0187,  0.0752],
                [-0.0654, -0.0242,  0.0666,  0.0303, -0.0114],
                [ 0.0783, -0.0565, -0.0200, -0.0462, -0.0119],
                [ 0.0788,  0.0656,  0.0623,  0.0350,  0.0254],
                [-0.0227,  0.0380, -0.0172,  0.0293, -0.0065]],

               [[-0.0086, -0.0572,  0.0217, -0.0286,  0.0476],
                [ 0.0695, -0.0679,  0.0714,  0.0371,  0.0638],
                [ 0.0099, -0.0652, -0.0545, -0.0068,  0.0805],
                [-0.0506, -0.0737, -0.0110, -0.0198,  0.0047],
                [-0.0288,  0.0730,  0.0794, -0.0033,  0.0242]],

               [[-0.0616, -0.0632, -0.0110, -0.0658, -0.0470],
                [ 0.0425, -0.0136,  0.0665, -0.0201, -0.0727],
                [ 0.0189,  0.0189,  0.0641, -0.0384,  0.0180],
                [ 0.0002, -0.0737,  0.0365,  0.0311, -0.0378],
                [ 0.0789,  0.0037, -0.0582,  0.0148,  0.0323]]]])
      (bias): Normal:
       loc: tensor([-0., -0., -0., 0., 0., -0., 0., -0., -0., 0., -0., -0., 0., 0., -0., -0.])
       scale: tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
              0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-3.9858e-02, -6.2551e-02, -1.4050e-02,  5.5231e-02,  1.7201e-02,
              -2.6817e-02,  2.8532e-02, -2.6857e-02, -2.8279e-02,  9.3339e-04,
              -5.8307e-02, -6.2624e-02,  7.1998e-05,  1.1212e-02, -2.0352e-02,
              -5.8423e-02])
    )
    (observed): Observed()
  )
  (fc1): Linear(
    in_features=400, out_features=120, bias=True
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              ...,
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498]],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[ 0.0337, -0.0484,  0.0253,  ..., -0.0286, -0.0485, -0.0091],
              [ 0.0469,  0.0292,  0.0258,  ...,  0.0047, -0.0409,  0.0409],
              [ 0.0367,  0.0313,  0.0040,  ...,  0.0234, -0.0487, -0.0428],
              ...,
              [-0.0139, -0.0203,  0.0175,  ..., -0.0324, -0.0387,  0.0258],
              [ 0.0096,  0.0293,  0.0120,  ...,  0.0264,  0.0297,  0.0347],
              [-0.0126, -0.0436,  0.0311,  ..., -0.0195,  0.0352,  0.0191]],
             requires_grad=True)
       tensor: tensor([[ 0.0732, -0.0255, -0.0578,  ..., -0.0882, -0.0816, -0.0213],
              [ 0.1146,  0.0466, -0.0112,  ...,  0.0476,  0.0129,  0.0057],
              [ 0.0800, -0.0415, -0.0323,  ...,  0.0578, -0.0181, -0.0136],
              ...,
              [-0.0417, -0.0255,  0.0027,  ..., -0.0518, -0.0193,  0.0991],
              [-0.0082,  0.0043,  0.0350,  ...,  0.0453,  0.0136,  0.0155],
              [ 0.0230, -0.0360, -0.0075,  ..., -0.0218,  0.0702,  0.0851]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([ 0.0227, -0.0239, -0.0015,  0.0243,  0.0179,  0.0240,  0.0462,  0.0295,
              -0.0409, -0.0290, -0.0125, -0.0128, -0.0085,  0.0326, -0.0219,  0.0423,
               0.0489, -0.0179, -0.0246, -0.0301,  0.0485, -0.0077,  0.0185,  0.0481,
              -0.0032, -0.0149, -0.0211, -0.0451,  0.0282, -0.0297,  0.0339, -0.0216,
              -0.0034, -0.0181, -0.0157,  0.0106, -0.0460, -0.0273, -0.0288, -0.0371,
              -0.0023, -0.0038,  0.0074,  0.0185, -0.0443,  0.0172, -0.0478,  0.0460,
               0.0363,  0.0282, -0.0387, -0.0109, -0.0231,  0.0013, -0.0196,  0.0272,
               0.0049,  0.0265, -0.0020, -0.0435,  0.0185, -0.0403,  0.0289,  0.0379,
              -0.0112, -0.0080, -0.0427,  0.0491, -0.0431, -0.0402, -0.0102, -0.0105,
              -0.0474, -0.0193, -0.0236, -0.0411,  0.0166,  0.0335, -0.0161, -0.0324,
              -0.0196,  0.0304, -0.0400,  0.0024,  0.0160,  0.0193,  0.0080, -0.0252,
               0.0398, -0.0498,  0.0386,  0.0138,  0.0152,  0.0196,  0.0355, -0.0123,
              -0.0179,  0.0390,  0.0361, -0.0140, -0.0484,  0.0458,  0.0205, -0.0043,
              -0.0300,  0.0102, -0.0160, -0.0108, -0.0162,  0.0330,  0.0324, -0.0429,
               0.0008,  0.0134,  0.0364, -0.0246,  0.0498,  0.0140, -0.0339,  0.0392],
             requires_grad=True)
       tensor: tensor([ 1.1447e-03, -5.1819e-02,  1.6640e-02,  6.5499e-02, -8.0907e-02,
              -4.8519e-02,  5.6677e-02, -4.8956e-02, -7.7639e-03, -5.7330e-03,
              -7.1541e-02, -8.8738e-03,  3.3044e-03,  3.5168e-02, -5.3030e-06,
               1.1116e-01,  5.3860e-02,  1.8072e-02, -4.6139e-02, -4.0794e-02,
               2.5714e-02, -9.4141e-03, -4.2178e-04, -2.5774e-02, -7.7567e-02,
              -8.2163e-02,  2.3858e-02, -8.7571e-02,  6.2566e-02, -8.3659e-02,
               1.3654e-01, -4.0236e-03,  9.7985e-02,  9.0484e-02,  5.9550e-02,
               2.4722e-02, -3.9010e-02, -2.2623e-02,  5.7798e-02, -3.9112e-02,
              -3.8512e-02,  7.7203e-02, -3.1446e-02, -2.6466e-02, -4.2992e-02,
               3.5493e-03, -2.0020e-02,  5.2569e-03, -2.8117e-02,  7.0782e-02,
               9.1462e-04, -4.1758e-02, -1.9813e-03,  3.4395e-02,  6.1325e-02,
              -1.3560e-02, -8.4248e-02, -4.5469e-02, -1.5659e-02, -1.4572e-02,
              -1.8009e-02, -1.7755e-02,  5.8105e-03,  2.0910e-02,  5.4390e-03,
              -5.3495e-02, -1.9020e-02,  6.4683e-02, -5.4073e-02, -3.9146e-02,
               5.6359e-02, -2.8833e-02, -1.3513e-02,  4.1304e-02, -8.1053e-02,
              -4.9787e-02, -8.1154e-03, -6.9614e-02, -8.6335e-02, -8.2642e-02,
               1.1239e-02, -2.9094e-02, -7.9066e-02,  8.6259e-02,  2.8128e-02,
              -1.7309e-02, -3.9473e-02,  3.1097e-03,  6.6451e-02,  2.0752e-02,
               6.5958e-02,  1.3649e-02,  1.2680e-01,  1.0900e-02, -2.6149e-02,
               5.0733e-02,  4.0548e-02,  7.8464e-02,  4.5867e-02,  3.9911e-02,
              -5.9929e-02,  3.6826e-02, -5.9146e-02, -3.4829e-02, -8.1041e-02,
               5.6606e-02, -2.2417e-02, -4.2888e-03, -3.5870e-02,  2.8688e-02,
               1.1745e-01,  1.3533e-03,  4.2071e-02, -3.4773e-02,  1.8792e-02,
              -3.4782e-02,  7.3807e-02, -2.7561e-02, -1.0201e-01,  9.2736e-02],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[0., -0., 0.,  ..., -0., -0., -0.],
              [0., 0., 0.,  ..., 0., -0., 0.],
              [0., 0., 0.,  ..., 0., -0., -0.],
              ...,
              [-0., -0., 0.,  ..., -0., -0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              [-0., -0., 0.,  ..., -0., 0., 0.]])
       scale: tensor([[0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              ...,
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[ 0.0337, -0.0484,  0.0253,  ..., -0.0286, -0.0485, -0.0091],
              [ 0.0469,  0.0292,  0.0258,  ...,  0.0047, -0.0409,  0.0409],
              [ 0.0367,  0.0313,  0.0040,  ...,  0.0234, -0.0487, -0.0428],
              ...,
              [-0.0139, -0.0203,  0.0175,  ..., -0.0324, -0.0387,  0.0258],
              [ 0.0096,  0.0293,  0.0120,  ...,  0.0264,  0.0297,  0.0347],
              [-0.0126, -0.0436,  0.0311,  ..., -0.0195,  0.0352,  0.0191]])
      (bias): Normal:
       loc: tensor([0., -0., -0., 0., 0., 0., 0., 0., -0., -0., -0., -0., -0., 0., -0., 0., 0., -0., -0., -0., 0., -0., 0., 0.,
              -0., -0., -0., -0., 0., -0., 0., -0., -0., -0., -0., 0., -0., -0., -0., -0., -0., -0., 0., 0., -0., 0., -0., 0.,
              0., 0., -0., -0., -0., 0., -0., 0., 0., 0., -0., -0., 0., -0., 0., 0., -0., -0., -0., 0., -0., -0., -0., -0.,
              -0., -0., -0., -0., 0., 0., -0., -0., -0., 0., -0., 0., 0., 0., 0., -0., 0., -0., 0., 0., 0., 0., 0., -0.,
              -0., 0., 0., -0., -0., 0., 0., -0., -0., 0., -0., -0., -0., 0., 0., -0., 0., 0., 0., -0., 0., 0., -0., 0.])
       scale: tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([ 0.0227, -0.0239, -0.0015,  0.0243,  0.0179,  0.0240,  0.0462,  0.0295,
              -0.0409, -0.0290, -0.0125, -0.0128, -0.0085,  0.0326, -0.0219,  0.0423,
               0.0489, -0.0179, -0.0246, -0.0301,  0.0485, -0.0077,  0.0185,  0.0481,
              -0.0032, -0.0149, -0.0211, -0.0451,  0.0282, -0.0297,  0.0339, -0.0216,
              -0.0034, -0.0181, -0.0157,  0.0106, -0.0460, -0.0273, -0.0288, -0.0371,
              -0.0023, -0.0038,  0.0074,  0.0185, -0.0443,  0.0172, -0.0478,  0.0460,
               0.0363,  0.0282, -0.0387, -0.0109, -0.0231,  0.0013, -0.0196,  0.0272,
               0.0049,  0.0265, -0.0020, -0.0435,  0.0185, -0.0403,  0.0289,  0.0379,
              -0.0112, -0.0080, -0.0427,  0.0491, -0.0431, -0.0402, -0.0102, -0.0105,
              -0.0474, -0.0193, -0.0236, -0.0411,  0.0166,  0.0335, -0.0161, -0.0324,
              -0.0196,  0.0304, -0.0400,  0.0024,  0.0160,  0.0193,  0.0080, -0.0252,
               0.0398, -0.0498,  0.0386,  0.0138,  0.0152,  0.0196,  0.0355, -0.0123,
              -0.0179,  0.0390,  0.0361, -0.0140, -0.0484,  0.0458,  0.0205, -0.0043,
              -0.0300,  0.0102, -0.0160, -0.0108, -0.0162,  0.0330,  0.0324, -0.0429,
               0.0008,  0.0134,  0.0364, -0.0246,  0.0498,  0.0140, -0.0339,  0.0392])
    )
    (observed): Observed()
  )
  (fc2): Linear(
    in_features=120, out_features=2, bias=True
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498]], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[ 0.0552, -0.0814, -0.0594, -0.0877,  0.0260,  0.0764,  0.0603, -0.0690,
               -0.0617,  0.0816,  0.0810,  0.0872, -0.0560,  0.0101,  0.0851, -0.0210,
               -0.0042,  0.0030,  0.0784, -0.0529,  0.0262,  0.0469,  0.0635,  0.0703,
               -0.0497, -0.0707, -0.0536, -0.0544,  0.0331, -0.0068,  0.0086, -0.0108,
               -0.0055, -0.0076,  0.0049, -0.0805,  0.0471,  0.0136, -0.0250, -0.0345,
                0.0100, -0.0639,  0.0511,  0.0064, -0.0230,  0.0270, -0.0317, -0.0259,
               -0.0307, -0.0506, -0.0617, -0.0262, -0.0688, -0.0900,  0.0328,  0.0166,
               -0.0635,  0.0072,  0.0536, -0.0793, -0.0726,  0.0021, -0.0797,  0.0413,
                0.0666,  0.0185, -0.0274, -0.0572, -0.0811, -0.0831, -0.0043, -0.0517,
               -0.0639,  0.0666, -0.0211, -0.0832, -0.0887,  0.0568, -0.0423,  0.0110,
               -0.0736,  0.0800,  0.0822,  0.0823,  0.0557,  0.0646, -0.0729,  0.0857,
                0.0337, -0.0079,  0.0632, -0.0813, -0.0178, -0.0147,  0.0018,  0.0151,
                0.0909, -0.0895,  0.0524, -0.0362, -0.0328,  0.0500,  0.0494, -0.0450,
               -0.0204, -0.0412,  0.0766, -0.0161, -0.0584, -0.0680,  0.0278,  0.0007,
               -0.0566,  0.0467,  0.0536, -0.0230,  0.0731,  0.0413, -0.0785, -0.0119],
              [ 0.0685, -0.0193,  0.0604,  0.0138,  0.0828,  0.0634, -0.0749,  0.0419,
               -0.0212, -0.0736, -0.0009,  0.0376, -0.0012,  0.0102, -0.0813, -0.0153,
               -0.0003, -0.0116,  0.0483, -0.0689, -0.0361, -0.0136,  0.0256, -0.0330,
                0.0639,  0.0103, -0.0673,  0.0759,  0.0751,  0.0812,  0.0010, -0.0302,
                0.0461, -0.0861, -0.0432, -0.0070,  0.0077, -0.0529,  0.0748,  0.0328,
               -0.0610,  0.0524,  0.0129,  0.0665,  0.0886, -0.0200,  0.0524,  0.0696,
               -0.0629, -0.0878,  0.0708,  0.0556, -0.0741,  0.0161, -0.0897,  0.0735,
                0.0793, -0.0354,  0.0309, -0.0521,  0.0006,  0.0265, -0.0274,  0.0792,
                0.0860, -0.0430, -0.0282,  0.0335, -0.0274,  0.0322,  0.0616, -0.0157,
               -0.0142,  0.0187,  0.0102, -0.0078,  0.0554, -0.0854, -0.0591,  0.0875,
               -0.0630, -0.0741, -0.0793,  0.0149,  0.0818, -0.0127, -0.0881,  0.0015,
               -0.0594,  0.0065, -0.0806,  0.0295, -0.0144,  0.0879,  0.0663,  0.0900,
                0.0258, -0.0155,  0.0731, -0.0102, -0.0611, -0.0473, -0.0538,  0.0004,
               -0.0415, -0.0457, -0.0481, -0.0759,  0.0621, -0.0188,  0.0160,  0.0484,
               -0.0819, -0.0031,  0.0558,  0.0735,  0.0219, -0.0744, -0.0153, -0.0397]],
             requires_grad=True)
       tensor: tensor([[ 0.0155, -0.1381, -0.0459, -0.1173,  0.0346,  0.0912,  0.0721, -0.0664,
                0.0061,  0.0576,  0.1026,  0.0628, -0.0534, -0.0191,  0.0756, -0.0014,
                0.0800,  0.0257,  0.0461, -0.0446, -0.0176,  0.1071,  0.1174, -0.0118,
               -0.0709, -0.1369, -0.0317, -0.0239,  0.0649,  0.0462, -0.0075,  0.0183,
                0.0289,  0.0655,  0.0824, -0.0964,  0.1179,  0.0729, -0.0761,  0.0014,
                0.1068, -0.0576,  0.0095, -0.0066, -0.0316, -0.0846,  0.0223, -0.0321,
               -0.1153, -0.0251,  0.0454,  0.0507, -0.1412, -0.1396,  0.0252,  0.0494,
               -0.1182, -0.0545,  0.0342, -0.1580, -0.0585,  0.0391, -0.1098,  0.0674,
                0.0748,  0.0277, -0.1225, -0.0614, -0.0111, -0.0965, -0.0249, -0.0885,
               -0.0442,  0.0160,  0.0275,  0.0262, -0.1012, -0.0106, -0.1147, -0.0267,
               -0.0619,  0.1101,  0.0542,  0.0399,  0.0702,  0.0073, -0.0637,  0.0863,
                0.0255,  0.0294,  0.0963, -0.1078,  0.0166, -0.0889, -0.0432,  0.0194,
                0.1233, -0.1050,  0.0582, -0.0849, -0.0124,  0.0670, -0.0239, -0.0239,
               -0.0879, -0.0510,  0.0547,  0.0301, -0.1166, -0.1147, -0.0102,  0.0029,
               -0.0347,  0.1045,  0.0452, -0.0121,  0.1167,  0.0179, -0.0259, -0.0612],
              [ 0.0958, -0.0319,  0.0668,  0.0619,  0.1232,  0.0888, -0.0047,  0.1079,
               -0.0665, -0.0406, -0.1337,  0.0833,  0.0152,  0.0016, -0.1006, -0.0306,
               -0.0050, -0.0641, -0.0259, -0.0763, -0.0317, -0.0147,  0.0674, -0.0595,
                0.0330, -0.0659, -0.0305,  0.0856,  0.0609,  0.0719, -0.0262, -0.0806,
                0.1531, -0.1720,  0.0591, -0.0311,  0.0744, -0.0318,  0.0379, -0.0081,
               -0.0684,  0.0417,  0.0077,  0.0655,  0.0082, -0.1232, -0.0634,  0.1215,
               -0.0478, -0.1038,  0.0839,  0.1241, -0.0877, -0.0262, -0.0386,  0.1171,
                0.1270, -0.0735,  0.0073, -0.0233, -0.0936,  0.0021, -0.0512,  0.0938,
                0.0540, -0.1652, -0.0778,  0.0516,  0.0366,  0.0138,  0.0603, -0.0893,
                0.0044,  0.0672,  0.0058,  0.0546, -0.0217, -0.0781, -0.0321,  0.0029,
               -0.0905, -0.1680, -0.0874, -0.0058,  0.1265, -0.0179, -0.0598,  0.0163,
               -0.0342,  0.0565, -0.0393,  0.0613, -0.0383,  0.0677,  0.1313,  0.1180,
                0.0259, -0.0481,  0.1126,  0.0305, -0.0701, -0.0525, -0.0936,  0.0749,
               -0.0364, -0.0313, -0.0123,  0.0132,  0.0251,  0.0448, -0.0225,  0.1201,
               -0.1090, -0.0413,  0.1151,  0.1027,  0.0813, -0.0764, -0.0066, -0.0852]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([ 0.0430, -0.0439], requires_grad=True)
       tensor: tensor([ 0.1120, -0.0376], grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[0., -0., -0., -0., 0., 0., 0., -0., -0., 0., 0., 0., -0., 0., 0., -0., -0., 0., 0., -0., 0., 0., 0., 0.,
               -0., -0., -0., -0., 0., -0., 0., -0., -0., -0., 0., -0., 0., 0., -0., -0., 0., -0., 0., 0., -0., 0., -0., -0.,
               -0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., -0., -0., 0., -0., 0., 0., 0., -0., -0., -0., -0., -0., -0.,
               -0., 0., -0., -0., -0., 0., -0., 0., -0., 0., 0., 0., 0., 0., -0., 0., 0., -0., 0., -0., -0., -0., 0., 0.,
               0., -0., 0., -0., -0., 0., 0., -0., -0., -0., 0., -0., -0., -0., 0., 0., -0., 0., 0., -0., 0., 0., -0., -0.],
              [0., -0., 0., 0., 0., 0., -0., 0., -0., -0., -0., 0., -0., 0., -0., -0., -0., -0., 0., -0., -0., -0., 0., -0.,
               0., 0., -0., 0., 0., 0., 0., -0., 0., -0., -0., -0., 0., -0., 0., 0., -0., 0., 0., 0., 0., -0., 0., 0.,
               -0., -0., 0., 0., -0., 0., -0., 0., 0., -0., 0., -0., 0., 0., -0., 0., 0., -0., -0., 0., -0., 0., 0., -0.,
               -0., 0., 0., -0., 0., -0., -0., 0., -0., -0., -0., 0., 0., -0., -0., 0., -0., 0., -0., 0., -0., 0., 0., 0.,
               0., -0., 0., -0., -0., -0., -0., 0., -0., -0., -0., -0., 0., -0., 0., 0., -0., -0., 0., 0., 0., -0., -0., -0.]])
       scale: tensor([[0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913],
              [0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[ 0.0552, -0.0814, -0.0594, -0.0877,  0.0260,  0.0764,  0.0603, -0.0690,
               -0.0617,  0.0816,  0.0810,  0.0872, -0.0560,  0.0101,  0.0851, -0.0210,
               -0.0042,  0.0030,  0.0784, -0.0529,  0.0262,  0.0469,  0.0635,  0.0703,
               -0.0497, -0.0707, -0.0536, -0.0544,  0.0331, -0.0068,  0.0086, -0.0108,
               -0.0055, -0.0076,  0.0049, -0.0805,  0.0471,  0.0136, -0.0250, -0.0345,
                0.0100, -0.0639,  0.0511,  0.0064, -0.0230,  0.0270, -0.0317, -0.0259,
               -0.0307, -0.0506, -0.0617, -0.0262, -0.0688, -0.0900,  0.0328,  0.0166,
               -0.0635,  0.0072,  0.0536, -0.0793, -0.0726,  0.0021, -0.0797,  0.0413,
                0.0666,  0.0185, -0.0274, -0.0572, -0.0811, -0.0831, -0.0043, -0.0517,
               -0.0639,  0.0666, -0.0211, -0.0832, -0.0887,  0.0568, -0.0423,  0.0110,
               -0.0736,  0.0800,  0.0822,  0.0823,  0.0557,  0.0646, -0.0729,  0.0857,
                0.0337, -0.0079,  0.0632, -0.0813, -0.0178, -0.0147,  0.0018,  0.0151,
                0.0909, -0.0895,  0.0524, -0.0362, -0.0328,  0.0500,  0.0494, -0.0450,
               -0.0204, -0.0412,  0.0766, -0.0161, -0.0584, -0.0680,  0.0278,  0.0007,
               -0.0566,  0.0467,  0.0536, -0.0230,  0.0731,  0.0413, -0.0785, -0.0119],
              [ 0.0685, -0.0193,  0.0604,  0.0138,  0.0828,  0.0634, -0.0749,  0.0419,
               -0.0212, -0.0736, -0.0009,  0.0376, -0.0012,  0.0102, -0.0813, -0.0153,
               -0.0003, -0.0116,  0.0483, -0.0689, -0.0361, -0.0136,  0.0256, -0.0330,
                0.0639,  0.0103, -0.0673,  0.0759,  0.0751,  0.0812,  0.0010, -0.0302,
                0.0461, -0.0861, -0.0432, -0.0070,  0.0077, -0.0529,  0.0748,  0.0328,
               -0.0610,  0.0524,  0.0129,  0.0665,  0.0886, -0.0200,  0.0524,  0.0696,
               -0.0629, -0.0878,  0.0708,  0.0556, -0.0741,  0.0161, -0.0897,  0.0735,
                0.0793, -0.0354,  0.0309, -0.0521,  0.0006,  0.0265, -0.0274,  0.0792,
                0.0860, -0.0430, -0.0282,  0.0335, -0.0274,  0.0322,  0.0616, -0.0157,
               -0.0142,  0.0187,  0.0102, -0.0078,  0.0554, -0.0854, -0.0591,  0.0875,
               -0.0630, -0.0741, -0.0793,  0.0149,  0.0818, -0.0127, -0.0881,  0.0015,
               -0.0594,  0.0065, -0.0806,  0.0295, -0.0144,  0.0879,  0.0663,  0.0900,
                0.0258, -0.0155,  0.0731, -0.0102, -0.0611, -0.0473, -0.0538,  0.0004,
               -0.0415, -0.0457, -0.0481, -0.0759,  0.0621, -0.0188,  0.0160,  0.0484,
               -0.0819, -0.0031,  0.0558,  0.0735,  0.0219, -0.0744, -0.0153, -0.0397]])
      (bias): Normal:
       loc: tensor([0., -0.])
       scale: tensor([0.7071, 0.7071])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([ 0.0430, -0.0439])
    )
    (observed): Observed()
  )
)

Fit the model

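Finally, we can set up the training loop. Each iteration conditions the model on the labels, draws a fresh sample from the posterior, runs the forward pass and takes one optimizer step on the variational loss.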

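optim = torch.optim.Adam(net.parameters())
for epoch in range(1):  # a single epoch for illustration; real training needs many more
    for data, target in loader:
        # condition the classification random variable on the true labels
        net.observe(classification=target)
        # draw a new sample for every random variable from its posterior
        borch.sample(net)
        net(data)
        # variational loss; kl_scaling spreads the KL term over the batches of an epoch
        loss = infer.vi_loss(**borch.pq_to_infer(net), kl_scaling=1 / len(loader))
        loss.backward()
        optim.step()
        optim.zero_grad()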
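The kl_scaling=1 / len(loader) argument keeps mini-batch training consistent with the full-data objective. Assuming infer.vi_loss returns an estimate of the negative ELBO in which kl_scaling multiplies the KL part (our reading of the call above, not a description of borch internals), the loss for batch $b$ out of $B$ = len(loader) batches is

$$
\mathcal{L}_b = \mathbb{E}_{q(\theta)}\big[-\log p(y_b \mid x_b, \theta)\big] + \frac{1}{B}\,\mathrm{KL}\big(q(\theta)\,\|\,p(\theta)\big),
$$

where the expectation is approximated with the single sample drawn by borch.sample(net). If the batches partition the data and the likelihood factorises over examples, summing over one epoch gives

$$
\sum_{b=1}^{B} \mathcal{L}_b = \mathbb{E}_{q(\theta)}\big[-\log p(y \mid x, \theta)\big] + \mathrm{KL}\big(q(\theta)\,\|\,p(\theta)\big) = -\mathrm{ELBO},
$$

so the KL term is counted exactly once per epoch. With 20 images and batch_size=20 there is only one batch, so $B = 1$ and the scaling has no effect in this example.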

Now we can check the accuracy. Before doing so, stop conditioning on the target by calling net.observe(None).

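net.observe(None)  # stop conditioning on the labels
tot_acc = 0
with torch.no_grad():
    for i, (data, target) in enumerate(loader):
        borch.sample(net)  # one posterior sample per batch
        out = net(data)  # predicted class for each image
        acc = float((target == out).sum().float() / target.shape[0]) * 100
        tot_acc += acc
    tot_acc /= i + 1  # mean of the per-batch accuracies
print(tot_acc)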

Out:

64.99999761581421
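The evaluation loop above scores a single posterior sample per batch and averages the per-batch accuracies, which over-weights a smaller final batch (harmless here, since all 20 images fit in one batch). The plain-Python sketch below shows two common refinements: combining several sampled predictions per image by majority vote, and weighting accuracy by batch size. It assumes, as the comparison target == out suggests, that each net(data) call returns sampled class labels, which you would convert to lists (for example with out.tolist()) after repeated borch.sample(net) calls. majority_vote and accuracy are hypothetical helper names.

```python
from collections import Counter


def majority_vote(sampled_predictions):
    """Combine S label lists (one per posterior sample) into one label per example.

    Ties go to the label seen first, following Counter.most_common.
    """
    return [Counter(votes).most_common(1)[0][0] for votes in zip(*sampled_predictions)]


def accuracy(batches):
    """Percentage of correct predictions over (targets, predictions) pairs,
    weighted by batch size rather than averaged per batch."""
    correct = total = 0
    for targets, predictions in batches:
        correct += sum(t == p for t, p in zip(targets, predictions))
        total += len(targets)
    return 100.0 * correct / total


# Three sampled predictions (hypothetical) for a batch of four images.
samples = [[0, 1, 1, 0], [0, 1, 0, 0], [1, 1, 1, 0]]
voted = majority_vote(samples)  # [0, 1, 1, 0]
print(accuracy([([0, 1, 0, 0], voted), ([1, 0], [1, 0])]))  # 5 of 6 correct
```

With the per-batch average used in the tutorial, the same two batches would score (75 + 100) / 2 = 87.5% instead of 83.3%.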

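The printed accuracy of about 65% is close to chance level, which is 50% for two balanced classes. This is expected: the images are white noise with no relation to the labels, and the model was trained for a single epoch.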
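The printed 65% corresponds to 13 correct predictions out of 20. A short check in plain Python shows how often pure guessing does at least that well. It treats each prediction of a model that has learned nothing as an independent fair coin flip, which is a simplifying assumption; prob_at_least is a hypothetical helper name.

```python
from math import comb


def prob_at_least(k, n, p=0.5):
    """P(X >= k) for X ~ Binomial(n, p): the chance of k or more correct guesses."""
    return sum(comb(n, i) * p**i * (1 - p) ** (n - i) for i in range(k, n + 1))


n_images = 20
n_correct = round(0.65 * n_images)  # 65% of 20 images -> 13 correct
print(n_correct, prob_at_least(n_correct, n_images))
```

The probability comes out at about 0.13, so on a dataset this small, random guessing reaches 65% or more roughly 13% of the time.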

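If you have trouble reaching a higher accuracy on real data, consider training for more epochs, setting up an augmentation pipeline (see the data loading tutorial), or changing the posterior. The posterior of every submodule can be changed with net.apply, which calls the given function on each submodule and returns the network itself, hence the printout below. Here it is set to borch.posterior.Automatic, the posterior Net already uses; pass a different class from borch.posterior to actually change it: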

net.apply(borch.set_posteriors(borch.posterior.Automatic))

Out:

Net(
  (posterior): Automatic()
  (prior): Module(
    (classification): Categorical:
     logits: tensor([[0.4066, 0.9223],
            [0.3959, 0.5796],
            [0.2345, 1.0376],
            [0.4777, 0.7572],
            [0.3595, 0.5583],
            [0.5373, 0.4655],
            [0.3681, 0.9276],
            [0.4100, 0.7398],
            [0.4180, 0.5060],
            [0.4787, 0.8265],
            [0.5473, 0.8474],
            [0.5303, 0.7285],
            [0.4437, 0.7785],
            [0.3840, 0.7335],
            [0.6166, 0.8975],
            [0.1428, 0.6157],
            [0.2379, 0.7260],
            [0.3817, 0.7180],
            [0.5401, 0.6208],
            [0.3308, 0.8017]])
     posterior: Automatic()
     prior: Module()
     observed: Observed()
     tensor: tensor([])
  )
  (observed): Observed()
  (conv1): Conv2d(
    1, 6, kernel_size=(5, 5), stride=(1, 1)
    (posterior): Automatic(
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-0., 0., 0., -0., 0., 0.], requires_grad=True)
       tensor: tensor([-0.1332, -0.5036,  0.1559, -0.2012, -0.6483,  0.5744],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[[[-0., 0., -0., -0., -0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., 0., -0., -0.],
                [-0., -0., 0., 0., -0.],
                [0., 0., -0., 0., -0.]]],


              [[[0., -0., 0., 0., 0.],
                [0., -0., 0., 0., 0.],
                [0., -0., -0., 0., -0.],
                [0., -0., 0., 0., -0.],
                [-0., 0., -0., 0., -0.]]],


              [[[0., -0., -0., -0., 0.],
                [-0., -0., -0., 0., 0.],
                [0., 0., -0., 0., -0.],
                [-0., 0., 0., -0., -0.],
                [-0., -0., 0., -0., -0.]]],


              [[[0., 0., 0., 0., 0.],
                [0., -0., -0., 0., -0.],
                [-0., 0., 0., -0., -0.],
                [0., -0., 0., -0., 0.],
                [-0., 0., -0., -0., 0.]]],


              [[[-0., 0., 0., -0., 0.],
                [0., -0., -0., 0., 0.],
                [0., 0., -0., -0., 0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., -0., -0., 0.]]],


              [[[-0., 0., 0., 0., 0.],
                [0., -0., -0., 0., 0.],
                [0., 0., -0., 0., 0.],
                [0., 0., -0., 0., -0.],
                [0., 0., -0., 0., 0.]]]])
       scale: tensor([[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[[[-0.1927,  0.1399, -0.1354, -0.1729, -0.1509],
                [-0.1643,  0.1672,  0.1231,  0.0760,  0.1983],
                [-0.1877,  0.1498,  0.1842, -0.1852, -0.0757],
                [-0.0739, -0.0365,  0.1345,  0.0660, -0.1410],
                [ 0.0395,  0.1924, -0.0889,  0.1817, -0.1028]]],


              [[[ 0.0161, -0.0283,  0.0567,  0.0753,  0.1422],
                [ 0.1557, -0.0744,  0.1324,  0.1368,  0.1860],
                [ 0.1084, -0.0066, -0.1755,  0.0982, -0.1038],
                [ 0.0692, -0.1975,  0.1228,  0.1460, -0.1969],
                [-0.0408,  0.0335, -0.1200,  0.0135, -0.0343]]],


              [[[ 0.1090, -0.1523, -0.0839, -0.1336,  0.1845],
                [-0.0021, -0.1854, -0.0692,  0.0818,  0.0268],
                [ 0.0554,  0.0253, -0.0801,  0.0925, -0.1053],
                [-0.0562,  0.0456,  0.1366, -0.1447, -0.1639],
                [-0.0741, -0.1671,  0.1945, -0.1811, -0.1519]]],


              [[[ 0.0950,  0.1374,  0.1735,  0.1682,  0.0029],
                [ 0.0662, -0.0615, -0.1451,  0.1452, -0.1408],
                [-0.1634,  0.0675,  0.1090, -0.1899, -0.0123],
                [ 0.0842, -0.0821,  0.1183, -0.0658,  0.1601],
                [-0.1688,  0.1212, -0.0177, -0.1106,  0.1500]]],


              [[[-0.0711,  0.0887,  0.0110, -0.0312,  0.1235],
                [ 0.1727, -0.1303, -0.1418,  0.0785,  0.0870],
                [ 0.0772,  0.0361, -0.1826, -0.0933,  0.0847],
                [ 0.0795,  0.1156, -0.1951, -0.1324, -0.1288],
                [-0.1276, -0.1009, -0.1230, -0.1850,  0.0189]]],


              [[[-0.0376,  0.1030,  0.1858,  0.1619,  0.1681],
                [ 0.0571, -0.1018, -0.1616,  0.1122,  0.1009],
                [ 0.1114,  0.0628, -0.0361,  0.0023,  0.0281],
                [ 0.0725,  0.1015, -0.0013,  0.1214, -0.1805],
                [ 0.1711,  0.1865, -0.1692,  0.0962,  0.0706]]]])
      (bias): Normal:
       loc: tensor([-0., 0., 0., -0., 0., 0.])
       scale: tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-0.1597,  0.1243,  0.0051, -0.0381,  0.0524,  0.0078])
    )
    (observed): Observed()
  )
  (conv2): Conv2d(
    6, 16, kernel_size=(5, 5), stride=(1, 1)
    (posterior): Automatic(
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
              0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-0., -0., -0., 0., 0., -0., 0., -0., -0., 0., -0., -0., 0., 0., -0., -0.],
             requires_grad=True)
       tensor: tensor([ 0.3275, -0.0609,  0.1431, -0.5142,  0.4180, -0.0330, -0.0895,  0.5254,
               0.2556,  0.1222, -0.0713,  0.4351,  0.2290, -0.3585, -0.0342,  0.0564],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[[[0., -0., 0., -0., -0.],
                [0., 0., 0., 0., 0.],
                [-0., -0., -0., -0., 0.],
                [-0., 0., -0., 0., 0.],
                [-0., 0., 0., -0., 0.]],

               [[-0., 0., -0., 0., 0.],
                [0., -0., 0., 0., 0.],
                [-0., 0., 0., 0., 0.],
                [-0., -0., 0., 0., -0.],
                [0., 0., 0., 0., -0.]],

               [[-0., 0., 0., -0., 0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [0., -0., 0., 0., -0.]],

               [[-0., -0., 0., 0., 0.],
                [0., 0., 0., 0., -0.],
                [-0., -0., -0., 0., 0.],
                [0., -0., 0., 0., -0.],
                [0., 0., 0., 0., 0.]],

               [[-0., -0., -0., -0., -0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., 0., 0., 0.],
                [-0., -0., -0., -0., 0.],
                [0., -0., 0., 0., -0.]],

               [[0., 0., -0., 0., 0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., 0., -0., 0.],
                [0., -0., -0., -0., 0.],
                [0., -0., -0., -0., 0.]]],


              [[[-0., 0., -0., 0., 0.],
                [-0., -0., -0., 0., -0.],
                [-0., -0., -0., -0., 0.],
                [-0., -0., 0., -0., 0.],
                [-0., 0., -0., -0., -0.]],

               [[0., -0., -0., -0., -0.],
                [0., -0., -0., 0., 0.],
                [0., -0., 0., 0., 0.],
                [0., -0., -0., 0., -0.],
                [-0., 0., -0., -0., -0.]],

               [[-0., -0., -0., 0., -0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., -0., 0., -0.],
                [-0., 0., -0., 0., 0.],
                [-0., -0., 0., -0., 0.]],

               [[-0., -0., 0., 0., 0.],
                [-0., 0., -0., 0., 0.],
                [-0., 0., -0., 0., 0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., -0., 0., 0.]],

               [[0., 0., -0., -0., 0.],
                [-0., -0., 0., 0., 0.],
                [-0., -0., 0., 0., 0.],
                [-0., -0., -0., 0., 0.],
                [0., 0., 0., -0., -0.]],

               [[0., -0., 0., 0., 0.],
                [-0., -0., 0., 0., -0.],
                [-0., -0., -0., 0., 0.],
                [-0., 0., 0., 0., -0.],
                [-0., 0., -0., -0., -0.]]],


              [[[0., -0., 0., 0., 0.],
                [0., -0., -0., -0., -0.],
                [-0., -0., 0., -0., 0.],
                [0., -0., -0., 0., -0.],
                [0., 0., -0., 0., 0.]],

               [[0., -0., 0., -0., -0.],
                [-0., -0., -0., 0., -0.],
                [0., 0., -0., 0., -0.],
                [-0., 0., 0., 0., 0.],
                [0., 0., 0., 0., -0.]],

               [[0., -0., 0., -0., 0.],
                [-0., -0., -0., 0., -0.],
                [0., -0., -0., -0., -0.],
                [-0., -0., 0., 0., 0.],
                [-0., 0., 0., 0., -0.]],

               [[-0., 0., -0., -0., 0.],
                [-0., -0., -0., -0., -0.],
                [-0., 0., 0., 0., -0.],
                [0., 0., 0., -0., -0.],
                [0., 0., -0., -0., 0.]],

               [[0., 0., 0., -0., 0.],
                [-0., 0., -0., -0., -0.],
                [0., -0., 0., -0., 0.],
                [0., 0., 0., 0., -0.],
                [-0., 0., -0., 0., -0.]],

               [[-0., 0., 0., 0., 0.],
                [-0., -0., 0., 0., -0.],
                [0., 0., -0., 0., -0.],
                [0., 0., -0., -0., -0.],
                [-0., 0., -0., -0., -0.]]],


              ...,


              [[[0., 0., -0., 0., -0.],
                [-0., 0., 0., 0., -0.],
                [-0., 0., -0., 0., -0.],
                [0., 0., 0., 0., 0.],
                [0., 0., -0., 0., 0.]],

               [[0., 0., -0., 0., 0.],
                [-0., -0., -0., -0., -0.],
                [0., 0., 0., -0., -0.],
                [-0., 0., 0., 0., -0.],
                [0., 0., -0., 0., 0.]],

               [[0., -0., -0., 0., -0.],
                [-0., 0., -0., 0., -0.],
                [0., 0., -0., 0., -0.],
                [0., 0., 0., 0., -0.],
                [-0., -0., 0., -0., 0.]],

               [[0., -0., 0., 0., -0.],
                [0., -0., -0., 0., -0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., -0., 0., -0.],
                [-0., 0., 0., -0., -0.]],

               [[0., 0., -0., 0., 0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., -0., -0., 0.],
                [-0., 0., 0., 0., 0.],
                [-0., -0., 0., -0., -0.]],

               [[-0., 0., -0., -0., 0.],
                [0., 0., -0., 0., 0.],
                [-0., -0., 0., 0., -0.],
                [-0., -0., -0., 0., 0.],
                [-0., 0., -0., 0., -0.]]],


              [[[0., -0., -0., 0., -0.],
                [0., 0., 0., -0., 0.],
                [-0., 0., 0., -0., -0.],
                [0., 0., -0., -0., 0.],
                [0., -0., 0., -0., 0.]],

               [[-0., 0., 0., 0., -0.],
                [-0., 0., -0., 0., 0.],
                [-0., 0., -0., -0., 0.],
                [-0., 0., -0., -0., 0.],
                [0., -0., -0., 0., 0.]],

               [[0., 0., -0., 0., -0.],
                [-0., -0., 0., 0., -0.],
                [0., 0., 0., -0., -0.],
                [-0., 0., -0., 0., 0.],
                [-0., 0., 0., 0., 0.]],

               [[0., 0., 0., 0., -0.],
                [0., 0., -0., 0., -0.],
                [-0., -0., -0., -0., -0.],
                [0., -0., 0., 0., -0.],
                [-0., -0., -0., -0., 0.]],

               [[-0., -0., -0., -0., -0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., 0., 0., 0.],
                [-0., 0., -0., -0., -0.],
                [0., -0., 0., -0., -0.]],

               [[0., 0., -0., 0., -0.],
                [-0., -0., -0., -0., -0.],
                [0., 0., 0., -0., -0.],
                [0., -0., -0., -0., 0.],
                [-0., -0., 0., 0., -0.]]],


              [[[-0., 0., 0., 0., -0.],
                [0., -0., -0., -0., 0.],
                [-0., -0., 0., -0., -0.],
                [-0., -0., -0., 0., 0.],
                [0., 0., 0., -0., 0.]],

               [[0., 0., 0., 0., -0.],
                [-0., 0., 0., -0., -0.],
                [-0., 0., -0., -0., 0.],
                [0., -0., 0., 0., -0.],
                [-0., -0., -0., 0., 0.]],

               [[0., -0., 0., -0., 0.],
                [-0., 0., 0., 0., -0.],
                [-0., -0., -0., -0., 0.],
                [-0., 0., 0., -0., 0.],
                [-0., 0., 0., 0., -0.]],

               [[-0., -0., -0., 0., 0.],
                [-0., -0., 0., 0., -0.],
                [0., -0., -0., -0., -0.],
                [0., 0., 0., 0., 0.],
                [-0., 0., -0., 0., -0.]],

               [[-0., -0., 0., -0., 0.],
                [0., -0., 0., 0., 0.],
                [0., -0., -0., -0., 0.],
                [-0., -0., -0., -0., 0.],
                [-0., 0., 0., -0., 0.]],

               [[-0., -0., -0., -0., -0.],
                [0., -0., 0., -0., -0.],
                [0., 0., 0., -0., 0.],
                [0., -0., 0., 0., -0.],
                [0., 0., -0., 0., 0.]]]])
       scale: tensor([[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              ...,


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[[[ 0.0228, -0.0484,  0.0637, -0.0249, -0.0302],
                [ 0.0105,  0.0733,  0.0576,  0.0393,  0.0094],
                [-0.0590, -0.0747, -0.0506, -0.0495,  0.0402],
                [-0.0647,  0.0789, -0.0096,  0.0477,  0.0700],
                [-0.0450,  0.0176,  0.0443, -0.0176,  0.0407]],

               [[-0.0148,  0.0075, -0.0554,  0.0808,  0.0812],
                [ 0.0636, -0.0776,  0.0617,  0.0497,  0.0760],
                [-0.0350,  0.0633,  0.0318,  0.0562,  0.0742],
                [-0.0229, -0.0075,  0.0338,  0.0709, -0.0382],
                [ 0.0655,  0.0249,  0.0556,  0.0329, -0.0734]],

               [[-0.0531,  0.0393,  0.0355, -0.0358,  0.0298],
                [-0.0772, -0.0365,  0.0397, -0.0093, -0.0571],
                [ 0.0087,  0.0256,  0.0568,  0.0123,  0.0098],
                [ 0.0070,  0.0767,  0.0580,  0.0573,  0.0128],
                [ 0.0476, -0.0090,  0.0190,  0.0597, -0.0108]],

               [[-0.0097, -0.0755,  0.0251,  0.0349,  0.0423],
                [ 0.0800,  0.0594,  0.0019,  0.0419, -0.0641],
                [-0.0501, -0.0639, -0.0783,  0.0276,  0.0425],
                [ 0.0793, -0.0161,  0.0031,  0.0700, -0.0077],
                [ 0.0717,  0.0564,  0.0130,  0.0126,  0.0055]],

               [[-0.0711, -0.0004, -0.0350, -0.0729, -0.0567],
                [-0.0351,  0.0119,  0.0163,  0.0542,  0.0558],
                [-0.0308,  0.0337,  0.0634,  0.0556,  0.0051],
                [-0.0305, -0.0103, -0.0515, -0.0746,  0.0701],
                [ 0.0014, -0.0249,  0.0724,  0.0282, -0.0773]],

               [[ 0.0647,  0.0735, -0.0460,  0.0321,  0.0060],
                [-0.0624, -0.0262,  0.0131, -0.0270, -0.0768],
                [ 0.0053,  0.0301,  0.0262, -0.0012,  0.0230],
                [ 0.0758, -0.0297, -0.0011, -0.0424,  0.0659],
                [ 0.0155, -0.0027, -0.0785, -0.0715,  0.0603]]],


              [[[-0.0680,  0.0027, -0.0140,  0.0451,  0.0738],
                [-0.0323, -0.0084, -0.0281,  0.0522, -0.0210],
                [-0.0525, -0.0399, -0.0072, -0.0577,  0.0280],
                [-0.0374, -0.0055,  0.0143, -0.0258,  0.0404],
                [-0.0804,  0.0563, -0.0096, -0.0485, -0.0052]],

               [[ 0.0772, -0.0417, -0.0619, -0.0375, -0.0217],
                [ 0.0312, -0.0186, -0.0225,  0.0433,  0.0327],
                [ 0.0751, -0.0210,  0.0098,  0.0275,  0.0606],
                [ 0.0308, -0.0484, -0.0063,  0.0799, -0.0184],
                [-0.0195,  0.0014, -0.0263, -0.0651, -0.0203]],

               [[-0.0441, -0.0229, -0.0547,  0.0813, -0.0195],
                [-0.0343,  0.0006,  0.0058,  0.0113,  0.0668],
                [-0.0674,  0.0276, -0.0072,  0.0360, -0.0707],
                [-0.0468,  0.0240, -0.0658,  0.0561,  0.0396],
                [-0.0704, -0.0553,  0.0372, -0.0310,  0.0525]],

               [[-0.0081, -0.0563,  0.0125,  0.0304,  0.0486],
                [-0.0515,  0.0539, -0.0165,  0.0263,  0.0241],
                [-0.0189,  0.0789, -0.0254,  0.0502,  0.0336],
                [-0.0288, -0.0308,  0.0397, -0.0238, -0.0365],
                [ 0.0195,  0.0625, -0.0496,  0.0805,  0.0320]],

               [[ 0.0451,  0.0478, -0.0347, -0.0325,  0.0203],
                [-0.0460, -0.0406,  0.0251,  0.0666,  0.0718],
                [-0.0109, -0.0145,  0.0313,  0.0103,  0.0525],
                [-0.0299, -0.0280, -0.0036,  0.0186,  0.0350],
                [ 0.0073,  0.0463,  0.0292, -0.0732, -0.0485]],

               [[ 0.0059, -0.0394,  0.0745,  0.0484,  0.0053],
                [-0.0341, -0.0371,  0.0608,  0.0058, -0.0537],
                [-0.0605, -0.0791, -0.0253,  0.0762,  0.0762],
                [-0.0059,  0.0153,  0.0580,  0.0198, -0.0173],
                [-0.0174,  0.0578, -0.0368, -0.0736, -0.0631]]],


              [[[ 0.0033, -0.0336,  0.0105,  0.0191,  0.0074],
                [ 0.0714, -0.0613, -0.0475, -0.0452, -0.0098],
                [-0.0617, -0.0515,  0.0353, -0.0375,  0.0360],
                [ 0.0542, -0.0326, -0.0236,  0.0674, -0.0347],
                [ 0.0394,  0.0746, -0.0467,  0.0522,  0.0424]],

               [[ 0.0503, -0.0693,  0.0626, -0.0424, -0.0488],
                [-0.0617, -0.0619, -0.0342,  0.0127, -0.0253],
                [ 0.0319,  0.0270, -0.0089,  0.0175, -0.0308],
                [-0.0551,  0.0107,  0.0212,  0.0026,  0.0019],
                [ 0.0115,  0.0343,  0.0572,  0.0811, -0.0354]],

               [[ 0.0632, -0.0197,  0.0072, -0.0372,  0.0581],
                [-0.0492, -0.0602, -0.0201,  0.0465, -0.0482],
                [ 0.0403, -0.0086, -0.0292, -0.0711, -0.0294],
                [-0.0228, -0.0102,  0.0450,  0.0804,  0.0032],
                [-0.0009,  0.0815,  0.0113,  0.0315, -0.0110]],

               [[-0.0627,  0.0313, -0.0443, -0.0353,  0.0797],
                [-0.0039, -0.0053, -0.0678, -0.0335, -0.0764],
                [-0.0202,  0.0296,  0.0548,  0.0185, -0.0028],
                [ 0.0044,  0.0496,  0.0734, -0.0418, -0.0792],
                [ 0.0025,  0.0171, -0.0681, -0.0224,  0.0077]],

               [[ 0.0424,  0.0476,  0.0079, -0.0167,  0.0523],
                [-0.0335,  0.0003, -0.0560, -0.0448, -0.0173],
                [ 0.0426, -0.0002,  0.0133, -0.0705,  0.0085],
                [ 0.0724,  0.0664,  0.0455,  0.0461, -0.0607],
                [-0.0546,  0.0696, -0.0658,  0.0364, -0.0562]],

               [[-0.0390,  0.0704,  0.0786,  0.0033,  0.0295],
                [-0.0145, -0.0198,  0.0638,  0.0332, -0.0197],
                [ 0.0663,  0.0625, -0.0292,  0.0036, -0.0408],
                [ 0.0201,  0.0786, -0.0615, -0.0266, -0.0627],
                [-0.0323,  0.0509, -0.0516, -0.0624, -0.0025]]],


              ...,


              [[[ 0.0777,  0.0133, -0.0427,  0.0800, -0.0246],
                [-0.0430,  0.0370,  0.0550,  0.0002, -0.0453],
                [-0.0782,  0.0747, -0.0467,  0.0130, -0.0217],
                [ 0.0417,  0.0371,  0.0066,  0.0761,  0.0344],
                [ 0.0590,  0.0188, -0.0681,  0.0579,  0.0579]],

               [[ 0.0803,  0.0284, -0.0485,  0.0127,  0.0668],
                [-0.0049, -0.0313, -0.0693, -0.0646, -0.0111],
                [ 0.0460,  0.0529,  0.0564, -0.0790, -0.0583],
                [-0.0102,  0.0005,  0.0437,  0.0308, -0.0437],
                [ 0.0205,  0.0394, -0.0644,  0.0770,  0.0368]],

               [[ 0.0157, -0.0298, -0.0073,  0.0653, -0.0469],
                [-0.0527,  0.0319, -0.0647,  0.0084, -0.0241],
                [ 0.0004,  0.0514, -0.0025,  0.0630, -0.0294],
                [ 0.0689,  0.0191,  0.0546,  0.0365, -0.0539],
                [-0.0045, -0.0745,  0.0544, -0.0686,  0.0816]],

               [[ 0.0061, -0.0117,  0.0211,  0.0131, -0.0410],
                [ 0.0168, -0.0675, -0.0497,  0.0285, -0.0718],
                [-0.0401, -0.0345,  0.0003, -0.0376, -0.0554],
                [ 0.0065,  0.0611, -0.0163,  0.0418, -0.0605],
                [-0.0081,  0.0791,  0.0381, -0.0200, -0.0402]],

               [[ 0.0350,  0.0182, -0.0750,  0.0471,  0.0045],
                [ 0.0441,  0.0266, -0.0399, -0.0425, -0.0318],
                [-0.0428, -0.0136, -0.0466, -0.0151,  0.0100],
                [-0.0644,  0.0688,  0.0709,  0.0026,  0.0585],
                [-0.0269, -0.0130,  0.0715, -0.0413, -0.0011]],

               [[-0.0537,  0.0477, -0.0147, -0.0553,  0.0459],
                [ 0.0619,  0.0005, -0.0194,  0.0200,  0.0296],
                [-0.0205, -0.0382,  0.0277,  0.0322, -0.0201],
                [-0.0251, -0.0238, -0.0758,  0.0534,  0.0487],
                [-0.0023,  0.0328, -0.0237,  0.0328, -0.0399]]],


              [[[ 0.0501, -0.0288, -0.0173,  0.0392, -0.0050],
                [ 0.0156,  0.0807,  0.0459, -0.0056,  0.0346],
                [-0.0486,  0.0271,  0.0530, -0.0753, -0.0678],
                [ 0.0070,  0.0007, -0.0208, -0.0359,  0.0587],
                [ 0.0728, -0.0304,  0.0334, -0.0355,  0.0323]],

               [[-0.0734,  0.0403,  0.0763,  0.0570, -0.0033],
                [-0.0664,  0.0105, -0.0366,  0.0042,  0.0546],
                [-0.0745,  0.0639, -0.0417, -0.0033,  0.0495],
                [-0.0550,  0.0292, -0.0198, -0.0241,  0.0367],
                [ 0.0107, -0.0213, -0.0517,  0.0664,  0.0738]],

               [[ 0.0123,  0.0034, -0.0046,  0.0620, -0.0356],
                [-0.0580, -0.0810,  0.0235,  0.0637, -0.0323],
                [ 0.0501,  0.0611,  0.0696, -0.0725, -0.0392],
                [-0.0327,  0.0616, -0.0098,  0.0648,  0.0049],
                [-0.0644,  0.0434,  0.0465,  0.0378,  0.0250]],

               [[ 0.0816,  0.0159,  0.0255,  0.0219, -0.0652],
                [ 0.0783,  0.0748, -0.0595,  0.0515, -0.0486],
                [-0.0709, -0.0491, -0.0587, -0.0085, -0.0437],
                [ 0.0395, -0.0117,  0.0683,  0.0806, -0.0066],
                [-0.0332, -0.0257, -0.0023, -0.0359,  0.0064]],

               [[-0.0328, -0.0616, -0.0107, -0.0231, -0.0393],
                [ 0.0030,  0.0048, -0.0813, -0.0253, -0.0723],
                [-0.0680, -0.0350,  0.0409,  0.0464,  0.0235],
                [-0.0085,  0.0688, -0.0767, -0.0011, -0.0570],
                [ 0.0553, -0.0721,  0.0039, -0.0811, -0.0608]],

               [[ 0.0617,  0.0303, -0.0521,  0.0155, -0.0364],
                [-0.0589, -0.0223, -0.0112, -0.0599, -0.0590],
                [ 0.0729,  0.0326,  0.0761, -0.0415, -0.0048],
                [ 0.0036, -0.0197, -0.0393, -0.0060,  0.0785],
                [-0.0679, -0.0750,  0.0671,  0.0385, -0.0260]]],


              [[[-0.0456,  0.0197,  0.0548,  0.0420, -0.0569],
                [ 0.0518, -0.0172, -0.0758, -0.0328,  0.0196],
                [-0.0712, -0.0446,  0.0593, -0.0403, -0.0250],
                [-0.0142, -0.0058, -0.0283,  0.0783,  0.0075],
                [ 0.0755,  0.0161,  0.0319, -0.0562,  0.0378]],

               [[ 0.0572,  0.0773,  0.0243,  0.0638, -0.0472],
                [-0.0081,  0.0225,  0.0298, -0.0442, -0.0075],
                [-0.0814,  0.0369, -0.0680, -0.0471,  0.0187],
                [ 0.0290, -0.0338,  0.0786,  0.0685, -0.0263],
                [-0.0453, -0.0716, -0.0462,  0.0556,  0.0159]],

               [[ 0.0359, -0.0511,  0.0707, -0.0696,  0.0407],
                [-0.0717,  0.0521,  0.0813,  0.0335, -0.0515],
                [-0.0295, -0.0124, -0.0406, -0.0247,  0.0162],
                [-0.0252,  0.0105,  0.0624, -0.0701,  0.0153],
                [-0.0490,  0.0815,  0.0331,  0.0130, -0.0174]],

               [[-0.0801, -0.0390, -0.0246,  0.0187,  0.0752],
                [-0.0654, -0.0242,  0.0666,  0.0303, -0.0114],
                [ 0.0783, -0.0565, -0.0200, -0.0462, -0.0119],
                [ 0.0788,  0.0656,  0.0623,  0.0350,  0.0254],
                [-0.0227,  0.0380, -0.0172,  0.0293, -0.0065]],

               [[-0.0086, -0.0572,  0.0217, -0.0286,  0.0476],
                [ 0.0695, -0.0679,  0.0714,  0.0371,  0.0638],
                [ 0.0099, -0.0652, -0.0545, -0.0068,  0.0805],
                [-0.0506, -0.0737, -0.0110, -0.0198,  0.0047],
                [-0.0288,  0.0730,  0.0794, -0.0033,  0.0242]],

               [[-0.0616, -0.0632, -0.0110, -0.0658, -0.0470],
                [ 0.0425, -0.0136,  0.0665, -0.0201, -0.0727],
                [ 0.0189,  0.0189,  0.0641, -0.0384,  0.0180],
                [ 0.0002, -0.0737,  0.0365,  0.0311, -0.0378],
                [ 0.0789,  0.0037, -0.0582,  0.0148,  0.0323]]]])
      (bias): Normal:
       loc: tensor([-0., -0., -0., 0., 0., -0., 0., -0., -0., 0., -0., -0., 0., 0., -0., -0.])
       scale: tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
              0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-3.9858e-02, -6.2551e-02, -1.4050e-02,  5.5231e-02,  1.7201e-02,
              -2.6817e-02,  2.8532e-02, -2.6857e-02, -2.8279e-02,  9.3339e-04,
              -5.8307e-02, -6.2624e-02,  7.1998e-05,  1.1212e-02, -2.0352e-02,
              -5.8423e-02])
    )
    (observed): Observed()
  )
  (fc1): Linear(
    in_features=400, out_features=120, bias=True
    (posterior): Automatic(
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([0., -0., -0., 0., 0., 0., 0., 0., -0., -0., -0., -0., -0., 0., -0., 0., 0., -0., -0., -0., 0., -0., 0., 0.,
              -0., -0., -0., -0., 0., -0., 0., -0., -0., -0., -0., 0., -0., -0., -0., -0., -0., -0., 0., 0., -0., 0., -0., 0.,
              0., 0., -0., -0., -0., 0., -0., 0., 0., 0., -0., -0., 0., -0., 0., 0., -0., -0., -0., 0., -0., -0., -0., -0.,
              -0., -0., -0., -0., 0., 0., -0., -0., -0., 0., -0., 0., 0., 0., 0., -0., 0., -0., 0., 0., 0., 0., 0., -0.,
              -0., 0., 0., -0., -0., 0., 0., -0., -0., 0., -0., -0., -0., 0., 0., -0., 0., 0., 0., -0., 0., 0., -0., 0.],
             requires_grad=True)
       tensor: tensor([-0.0551, -0.0427, -0.0985, -0.0071,  0.0669, -0.0145, -0.1063,  0.0711,
               0.0262, -0.0387, -0.0015, -0.0757,  0.0055, -0.0922, -0.1753, -0.0271,
              -0.0179,  0.1418,  0.0163, -0.0432, -0.0603,  0.0581,  0.0135, -0.1230,
               0.0199,  0.0857,  0.1111, -0.0646, -0.0754,  0.0303,  0.0280,  0.0519,
               0.1128, -0.1034,  0.0422, -0.0920, -0.0389,  0.0287, -0.1931,  0.0913,
              -0.1261,  0.0266,  0.1357, -0.0125,  0.0096,  0.0660, -0.0504,  0.0232,
               0.1478,  0.0955, -0.0457, -0.2297,  0.0295, -0.0228,  0.0511,  0.0642,
              -0.2107,  0.0015,  0.1223,  0.0589,  0.0315,  0.0577, -0.0832, -0.1663,
              -0.0293, -0.0221,  0.0139, -0.0254,  0.1410, -0.0805,  0.0196,  0.1361,
               0.1635,  0.0949,  0.0235,  0.0668,  0.0409, -0.0313, -0.1028,  0.0487,
               0.2261, -0.0205,  0.0892,  0.1026, -0.0642,  0.0138,  0.1022, -0.0512,
               0.1710,  0.1139,  0.1485,  0.0546, -0.0279,  0.0331, -0.0193,  0.0144,
               0.0365, -0.0167, -0.0164,  0.1642, -0.0179, -0.0689,  0.0483,  0.1306,
              -0.0528,  0.0161,  0.0134,  0.1056, -0.0197, -0.0272,  0.0707,  0.2466,
               0.1049, -0.0313, -0.0176,  0.2206,  0.0065, -0.1687,  0.0019,  0.0558],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[0., -0., 0.,  ..., -0., -0., -0.],
              [0., 0., 0.,  ..., 0., -0., 0.],
              [0., 0., 0.,  ..., 0., -0., -0.],
              ...,
              [-0., -0., 0.,  ..., -0., -0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              [-0., -0., 0.,  ..., -0., 0., 0.]])
       scale: tensor([[0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              ...,
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[ 0.0337, -0.0484,  0.0253,  ..., -0.0286, -0.0485, -0.0091],
              [ 0.0469,  0.0292,  0.0258,  ...,  0.0047, -0.0409,  0.0409],
              [ 0.0367,  0.0313,  0.0040,  ...,  0.0234, -0.0487, -0.0428],
              ...,
              [-0.0139, -0.0203,  0.0175,  ..., -0.0324, -0.0387,  0.0258],
              [ 0.0096,  0.0293,  0.0120,  ...,  0.0264,  0.0297,  0.0347],
              [-0.0126, -0.0436,  0.0311,  ..., -0.0195,  0.0352,  0.0191]])
      (bias): Normal:
       loc: tensor([0., -0., -0., 0., 0., 0., 0., 0., -0., -0., -0., -0., -0., 0., -0., 0., 0., -0., -0., -0., 0., -0., 0., 0.,
              -0., -0., -0., -0., 0., -0., 0., -0., -0., -0., -0., 0., -0., -0., -0., -0., -0., -0., 0., 0., -0., 0., -0., 0.,
              0., 0., -0., -0., -0., 0., -0., 0., 0., 0., -0., -0., 0., -0., 0., 0., -0., -0., -0., 0., -0., -0., -0., -0.,
              -0., -0., -0., -0., 0., 0., -0., -0., -0., 0., -0., 0., 0., 0., 0., -0., 0., -0., 0., 0., 0., 0., 0., -0.,
              -0., 0., 0., -0., -0., 0., 0., -0., -0., 0., -0., -0., -0., 0., 0., -0., 0., 0., 0., -0., 0., 0., -0., 0.])
       scale: tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([ 0.0227, -0.0239, -0.0015,  0.0243,  0.0179,  0.0240,  0.0462,  0.0295,
              -0.0409, -0.0290, -0.0125, -0.0128, -0.0085,  0.0326, -0.0219,  0.0423,
               0.0489, -0.0179, -0.0246, -0.0301,  0.0485, -0.0077,  0.0185,  0.0481,
              -0.0032, -0.0149, -0.0211, -0.0451,  0.0282, -0.0297,  0.0339, -0.0216,
              -0.0034, -0.0181, -0.0157,  0.0106, -0.0460, -0.0273, -0.0288, -0.0371,
              -0.0023, -0.0038,  0.0074,  0.0185, -0.0443,  0.0172, -0.0478,  0.0460,
               0.0363,  0.0282, -0.0387, -0.0109, -0.0231,  0.0013, -0.0196,  0.0272,
               0.0049,  0.0265, -0.0020, -0.0435,  0.0185, -0.0403,  0.0289,  0.0379,
              -0.0112, -0.0080, -0.0427,  0.0491, -0.0431, -0.0402, -0.0102, -0.0105,
              -0.0474, -0.0193, -0.0236, -0.0411,  0.0166,  0.0335, -0.0161, -0.0324,
              -0.0196,  0.0304, -0.0400,  0.0024,  0.0160,  0.0193,  0.0080, -0.0252,
               0.0398, -0.0498,  0.0386,  0.0138,  0.0152,  0.0196,  0.0355, -0.0123,
              -0.0179,  0.0390,  0.0361, -0.0140, -0.0484,  0.0458,  0.0205, -0.0043,
              -0.0300,  0.0102, -0.0160, -0.0108, -0.0162,  0.0330,  0.0324, -0.0429,
               0.0008,  0.0134,  0.0364, -0.0246,  0.0498,  0.0140, -0.0339,  0.0392])
    )
    (observed): Observed()
  )
  (fc2): Linear(
    in_features=120, out_features=2, bias=True
    (posterior): Automatic(
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.7071, 0.7071], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([0., -0.], requires_grad=True)
       tensor: tensor([-0.8370,  0.1714], grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[0., -0., -0., -0., 0., 0., 0., -0., -0., 0., 0., 0., -0., 0., 0., -0., -0., 0., 0., -0., 0., 0., 0., 0.,
               -0., -0., -0., -0., 0., -0., 0., -0., -0., -0., 0., -0., 0., 0., -0., -0., 0., -0., 0., 0., -0., 0., -0., -0.,
               -0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., -0., -0., 0., -0., 0., 0., 0., -0., -0., -0., -0., -0., -0.,
               -0., 0., -0., -0., -0., 0., -0., 0., -0., 0., 0., 0., 0., 0., -0., 0., 0., -0., 0., -0., -0., -0., 0., 0.,
               0., -0., 0., -0., -0., 0., 0., -0., -0., -0., 0., -0., -0., -0., 0., 0., -0., 0., 0., -0., 0., 0., -0., -0.],
              [0., -0., 0., 0., 0., 0., -0., 0., -0., -0., -0., 0., -0., 0., -0., -0., -0., -0., 0., -0., -0., -0., 0., -0.,
               0., 0., -0., 0., 0., 0., 0., -0., 0., -0., -0., -0., 0., -0., 0., 0., -0., 0., 0., 0., 0., -0., 0., 0.,
               -0., -0., 0., 0., -0., 0., -0., 0., 0., -0., 0., -0., 0., 0., -0., 0., 0., -0., -0., 0., -0., 0., 0., -0.,
               -0., 0., 0., -0., 0., -0., -0., 0., -0., -0., -0., 0., 0., -0., -0., 0., -0., 0., -0., 0., -0., 0., 0., 0.,
               0., -0., 0., -0., -0., -0., -0., 0., -0., -0., -0., -0., 0., -0., 0., 0., -0., -0., 0., 0., 0., -0., -0., -0.]])
       scale: tensor([[0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913],
              [0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[ 0.0552, -0.0814, -0.0594, -0.0877,  0.0260,  0.0764,  0.0603, -0.0690,
               -0.0617,  0.0816,  0.0810,  0.0872, -0.0560,  0.0101,  0.0851, -0.0210,
               -0.0042,  0.0030,  0.0784, -0.0529,  0.0262,  0.0469,  0.0635,  0.0703,
               -0.0497, -0.0707, -0.0536, -0.0544,  0.0331, -0.0068,  0.0086, -0.0108,
               -0.0055, -0.0076,  0.0049, -0.0805,  0.0471,  0.0136, -0.0250, -0.0345,
                0.0100, -0.0639,  0.0511,  0.0064, -0.0230,  0.0270, -0.0317, -0.0259,
               -0.0307, -0.0506, -0.0617, -0.0262, -0.0688, -0.0900,  0.0328,  0.0166,
               -0.0635,  0.0072,  0.0536, -0.0793, -0.0726,  0.0021, -0.0797,  0.0413,
                0.0666,  0.0185, -0.0274, -0.0572, -0.0811, -0.0831, -0.0043, -0.0517,
               -0.0639,  0.0666, -0.0211, -0.0832, -0.0887,  0.0568, -0.0423,  0.0110,
               -0.0736,  0.0800,  0.0822,  0.0823,  0.0557,  0.0646, -0.0729,  0.0857,
                0.0337, -0.0079,  0.0632, -0.0813, -0.0178, -0.0147,  0.0018,  0.0151,
                0.0909, -0.0895,  0.0524, -0.0362, -0.0328,  0.0500,  0.0494, -0.0450,
               -0.0204, -0.0412,  0.0766, -0.0161, -0.0584, -0.0680,  0.0278,  0.0007,
               -0.0566,  0.0467,  0.0536, -0.0230,  0.0731,  0.0413, -0.0785, -0.0119],
              [ 0.0685, -0.0193,  0.0604,  0.0138,  0.0828,  0.0634, -0.0749,  0.0419,
               -0.0212, -0.0736, -0.0009,  0.0376, -0.0012,  0.0102, -0.0813, -0.0153,
               -0.0003, -0.0116,  0.0483, -0.0689, -0.0361, -0.0136,  0.0256, -0.0330,
                0.0639,  0.0103, -0.0673,  0.0759,  0.0751,  0.0812,  0.0010, -0.0302,
                0.0461, -0.0861, -0.0432, -0.0070,  0.0077, -0.0529,  0.0748,  0.0328,
               -0.0610,  0.0524,  0.0129,  0.0665,  0.0886, -0.0200,  0.0524,  0.0696,
               -0.0629, -0.0878,  0.0708,  0.0556, -0.0741,  0.0161, -0.0897,  0.0735,
                0.0793, -0.0354,  0.0309, -0.0521,  0.0006,  0.0265, -0.0274,  0.0792,
                0.0860, -0.0430, -0.0282,  0.0335, -0.0274,  0.0322,  0.0616, -0.0157,
               -0.0142,  0.0187,  0.0102, -0.0078,  0.0554, -0.0854, -0.0591,  0.0875,
               -0.0630, -0.0741, -0.0793,  0.0149,  0.0818, -0.0127, -0.0881,  0.0015,
               -0.0594,  0.0065, -0.0806,  0.0295, -0.0144,  0.0879,  0.0663,  0.0900,
                0.0258, -0.0155,  0.0731, -0.0102, -0.0611, -0.0473, -0.0538,  0.0004,
               -0.0415, -0.0457, -0.0481, -0.0759,  0.0621, -0.0188,  0.0160,  0.0484,
               -0.0819, -0.0031,  0.0558,  0.0735,  0.0219, -0.0744, -0.0153, -0.0397]])
      (bias): Normal:
       loc: tensor([0., -0.])
       scale: tensor([0.7071, 0.7071])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([ 0.0430, -0.0439])
    )
    (observed): Observed()
  )
)

You can also set the posterior when creating a module:

nn.Linear(10, 10, posterior=borch.posterior.Normal(log_scale=-3))

Out:

Linear(
  in_features=10, out_features=10, bias=True
  (posterior): Normal(
    (weight): Normal:
     posterior: Automatic()
     prior: Module()
     observed: Observed()
     scale: Transform:
     tensor([[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498]], grad_fn=<ExpBackward0>)
     loc: Parameter containing:
    tensor([[0., -0., 0., 0., 0., 0., -0., -0., -0., -0.],
            [0., -0., 0., 0., -0., 0., 0., 0., -0., 0.],
            [0., -0., -0., -0., 0., -0., 0., 0., 0., 0.],
            [0., -0., 0., 0., 0., -0., 0., 0., 0., -0.],
            [-0., 0., 0., 0., 0., -0., -0., 0., -0., -0.],
            [0., 0., -0., 0., -0., 0., 0., -0., -0., -0.],
            [0., -0., 0., 0., 0., -0., -0., -0., 0., 0.],
            [-0., -0., -0., 0., -0., 0., 0., -0., -0., -0.],
            [0., 0., 0., 0., 0., -0., -0., 0., 0., 0.],
            [-0., -0., 0., 0., 0., -0., 0., -0., -0., -0.]], requires_grad=True)
     tensor: tensor([[-1.7528e-02,  4.8843e-02, -4.4339e-03,  2.3555e-03,  9.4015e-02,
             -3.7064e-02, -7.4909e-02,  3.4505e-02, -8.3809e-02, -2.4774e-02],
            [ 3.2571e-02, -4.0296e-02, -3.0393e-02, -4.4918e-02,  9.7954e-05,
              5.2636e-02, -8.8278e-03,  3.9501e-02,  6.7966e-02, -2.8009e-02],
            [-4.1282e-02, -6.1663e-02,  4.1791e-02,  1.3551e-02,  1.2286e-04,
             -7.7635e-02, -4.3325e-03,  6.4397e-03,  1.2994e-02,  1.6747e-02],
            [-2.2902e-02,  1.1335e-02,  3.2603e-02, -9.1162e-02, -1.1745e-02,
              1.8527e-02,  1.5697e-02, -2.6060e-02, -5.6732e-02, -1.4607e-02],
            [-1.3031e-03,  1.1841e-02,  9.7400e-02, -1.0226e-01, -6.9819e-02,
             -5.1301e-03,  1.1550e-02, -3.9957e-03,  1.5816e-02, -1.1080e-02],
            [-3.9156e-02, -6.4962e-02,  1.4351e-02, -2.4294e-02, -1.5933e-02,
              5.4919e-02, -9.4059e-02,  4.6750e-02, -1.4138e-02,  1.1048e-02],
            [-6.7411e-02,  3.9682e-02, -2.6654e-02, -1.5400e-02,  1.3210e-02,
             -1.7574e-02, -9.3503e-02,  9.6496e-02,  1.4712e-02, -3.1450e-02],
            [ 6.4931e-02,  3.1781e-02, -2.1908e-02, -7.8900e-02,  2.1049e-02,
             -5.4553e-02,  4.7148e-02,  3.0550e-03, -8.0232e-02,  7.5974e-02],
            [-4.0818e-02, -3.0395e-02, -7.0800e-02,  3.9479e-02,  6.9372e-02,
              1.2751e-04,  1.8037e-02, -6.9233e-02, -6.7140e-02,  6.6097e-03],
            [ 3.6661e-02, -9.0289e-02, -3.2585e-02, -7.5074e-02, -2.3557e-02,
             -4.3972e-02, -6.3312e-03, -9.9403e-03,  3.9945e-02,  2.5346e-02]],
           grad_fn=<AddBackward0>)
    (bias): Normal:
     posterior: Automatic()
     prior: Module()
     observed: Observed()
     scale: Transform:
     tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
            0.0498], grad_fn=<ExpBackward0>)
     loc: Parameter containing:
    tensor([-0., 0., -0., -0., -0., -0., 0., 0., 0., -0.], requires_grad=True)
     tensor: tensor([-0.0109,  0.0057, -0.0320, -0.0700, -0.0560,  0.0503, -0.0742, -0.1038,
            -0.0172, -0.0039], grad_fn=<AddBackward0>)
  )
  (prior): Module(
    (weight): Normal:
     loc: tensor([[0., -0., 0., 0., 0., 0., -0., -0., -0., -0.],
            [0., -0., 0., 0., -0., 0., 0., 0., -0., 0.],
            [0., -0., -0., -0., 0., -0., 0., 0., 0., 0.],
            [0., -0., 0., 0., 0., -0., 0., 0., 0., -0.],
            [-0., 0., 0., 0., 0., -0., -0., 0., -0., -0.],
            [0., 0., -0., 0., -0., 0., 0., -0., -0., -0.],
            [0., -0., 0., 0., 0., -0., -0., -0., 0., 0.],
            [-0., -0., -0., 0., -0., 0., 0., -0., -0., -0.],
            [0., 0., 0., 0., 0., -0., -0., 0., 0., 0.],
            [-0., -0., 0., 0., 0., -0., 0., -0., -0., -0.]])
     scale: tensor([[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162]])
     posterior: Automatic()
     prior: Module()
     observed: Observed()
     tensor: tensor([[ 0.1768, -0.0492,  0.0335,  0.2470,  0.2241,  0.2600, -0.0639, -0.0142,
             -0.1177, -0.2890],
            [ 0.0065, -0.1177,  0.1118,  0.3136, -0.1325,  0.2423,  0.1595,  0.0413,
             -0.2334,  0.1953],
            [ 0.0350, -0.2088, -0.3060, -0.2115,  0.0493, -0.0122,  0.1962,  0.0292,
              0.1152,  0.1241],
            [ 0.0738, -0.0087,  0.2692,  0.3067,  0.0663, -0.1652,  0.0351,  0.0828,
              0.0910, -0.1150],
            [-0.2252,  0.2810,  0.0442,  0.0448,  0.3011, -0.0721, -0.0053,  0.0346,
             -0.2713, -0.1514],
            [ 0.1203,  0.0920, -0.1798,  0.2911, -0.1815,  0.0126,  0.0843, -0.2110,
             -0.2804, -0.1401],
            [ 0.1194, -0.1758,  0.1153,  0.1420,  0.0124, -0.0216, -0.1932, -0.0692,
              0.0776,  0.2156],
            [-0.0955, -0.1280, -0.2442,  0.2149, -0.1078,  0.2254,  0.0832, -0.2987,
             -0.0894, -0.0538],
            [ 0.0035,  0.1922,  0.1726,  0.1055,  0.0854, -0.1505, -0.2445,  0.0725,
              0.1223,  0.1330],
            [-0.1026, -0.2841,  0.2254,  0.2714,  0.1426, -0.2921,  0.0770, -0.0995,
             -0.2960, -0.2719]])
    (bias): Normal:
     loc: tensor([-0., 0., -0., -0., -0., -0., 0., 0., 0., -0.])
     scale: tensor([0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
            0.3162])
     posterior: Automatic()
     prior: Module()
     observed: Observed()
     tensor: tensor([-0.1022,  0.2749, -0.0217, -0.0210, -0.0487, -0.2076,  0.0649,  0.1562,
             0.2436, -0.1976])
  )
  (observed): Observed()
)
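The numbers in this output can be checked by hand. The posterior `scale` of 0.0498 is consistent with `scale = exp(log_scale)` for `log_scale=-3`. The prior weight scales are 0.3162 for `nn.Linear(10, 10)` and 0.0913 for the 120-input layer printed earlier, and both are consistent with `1 / sqrt(in_features)`. The snippet below is only an arithmetic check of that reading of the printed values and does not call borch. The `1 / sqrt(in_features)` rule is inferred from the output, not taken from the borch documentation.

```python
import math

# Posterior scale: the repr shows 0.0498 for log_scale=-3, which matches exp(-3).
log_scale = -3.0
posterior_scale = math.exp(log_scale)

# Prior weight scales: 0.3162 for nn.Linear(10, 10) and 0.0913 for a layer with
# 120 inputs. Both match 1 / sqrt(in_features); this rule is inferred from the
# printed output, not from the borch documentation.
prior_scale_10 = 1 / math.sqrt(10)
prior_scale_120 = 1 / math.sqrt(120)

print(f"{posterior_scale:.4f}")  # 0.0498
print(f"{prior_scale_10:.4f}")   # 0.3162
print(f"{prior_scale_120:.4f}")  # 0.0913
```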

See the borch.posterior documentation for the other available posteriors and the parameters each one accepts. Note that not every posterior works with every parameter, but you can use different posteriors for different borch.Module instances in your network.

Exercises

  1. Use what you have learned to train an image classifier on MNIST; you should be able to reach an accuracy above 98%. Note: MNIST is available as torchvision.datasets.MNIST.

  2. Fit the same model architecture with plain torch and compare its likelihood with that of the borch network. What are the differences, and why?

  3. Port the model to CIFAR and explore how you can improve the accuracy.

  4. Show how the Categorical distribution is related to the cross-entropy loss function commonly used in frequentist deep learning.
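As a starting point for exercise 4, plain PyTorch can provide a quick numerical check, though not a derivation. Here `torch.distributions.Categorical` stands in for `borch.distributions.Categorical`, on the assumption that both define the same likelihood. The negative mean log-probability of the labels under a Categorical parameterised by the network's logits should equal `F.cross_entropy` on the same logits. The logits and labels are random placeholders.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

torch.manual_seed(0)
logits = torch.randn(20, 2)            # stand-in network outputs: 20 images, 2 classes
targets = torch.randint(0, 2, (20,))   # stand-in integer class labels

# Negative mean log-likelihood of the labels under a Categorical likelihood
nll = -Categorical(logits=logits).log_prob(targets).mean()

# Standard cross-entropy loss (log_softmax + negative log-likelihood, mean-reduced)
ce = F.cross_entropy(logits, targets)

print(torch.allclose(nll, ce))  # True
```

Both quantities reduce to `-log softmax(logits)[target]` averaged over the batch. Minimising cross-entropy is therefore the same as maximising a Categorical log-likelihood.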

Total running time of the script: (0 minutes 0.226 seconds)
