Training an image classifier

In this tutorial you will learn the basics of creating an image classifier with the borch.nn package and fitting it with the borch.infer package.

Let's start off by importing what we need:

import torch
from torch.utils.data import TensorDataset, DataLoader
import borch
from borch import infer, distributions
import torch.nn.functional as F

The borch.nn module provides neural network layers for deep probabilistic programming. Its interface is almost identical to that of torch.nn, so in many cases it is enough to change the import

from torch import nn

to

from borch import nn

Data

In this example we use simulated data and do not run the fitting until convergence; the goal is to show how the model is set up and how the training loop can be constructed. We simply generate random data: data holds 20 single-channel 32x32 images and labels holds the class (0 or 1) of each image.

data = torch.randn(20, 1, 32, 32)
labels = torch.randperm(2).repeat(10)
data_set = TensorDataset(data, labels)
loader = DataLoader(data_set, batch_size=20)
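The label construction is easy to misread, so here is a rough pure-Python sketch that mimics torch.randperm(2).repeat(10) with the standard random module. The helper randperm_repeat is hypothetical and not part of torch or borch. It assumes the torch call draws a single random permutation of the class indices and tiles it, which means the 20 labels alternate between the two classes and are perfectly balanced.

```python
import random


def randperm_repeat(num_classes, repeats, seed=None):
    """Hypothetical pure-Python analogue of torch.randperm(num_classes).repeat(repeats).

    One random permutation of the class indices is drawn and then tiled
    `repeats` times, so every class appears exactly `repeats` times.
    """
    rng = random.Random(seed)
    perm = rng.sample(range(num_classes), num_classes)  # a random ordering of 0..num_classes-1
    return perm * repeats  # list repetition tiles the permutation end to end


labels = randperm_repeat(2, 10, seed=0)
print(labels)  # 20 labels alternating between 0 and 1
print({c: labels.count(c) for c in sorted(set(labels))})  # {0: 10, 1: 10}
```

Because the permutation is drawn only once, the only randomness is whether the sequence starts with class 0 or class 1. With batch_size=20, the DataLoader above serves all 20 samples as a single batch.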

Model

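Let's set up the model. To get the most out of borch and its infer package, we need to choose a likelihood distribution. For classification, distributions.Categorical is a suitable choice: the final layer of the network outputs one logit per class, and these logits parameterise the Categorical distribution.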
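The role of the likelihood can be illustrated without borch. Assuming, as in torch.distributions.Categorical, that logits are unnormalised log-probabilities, the log-likelihood of a label is the log-softmax of the logits at that label's index. Summed over a batch, the negative log-likelihood is the familiar cross-entropy loss. The functions below are a hypothetical sketch of that math, not borch's code; a variational fit would also add terms involving the priors and posteriors.

```python
import math


def categorical_log_prob(logits, label):
    """Log-probability of `label` under a categorical distribution
    parameterised by unnormalised `logits` (i.e. log-softmax at `label`)."""
    m = max(logits)  # subtract the max before exponentiating for numerical stability
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return logits[label] - log_norm


def negative_log_likelihood(batch_logits, labels):
    """Sum of -log p(label | logits) over a batch, i.e. the cross-entropy."""
    return -sum(categorical_log_prob(z, y) for z, y in zip(batch_logits, labels))


# Equal logits give probability 0.5 to each of the two classes.
print(categorical_log_prob([0.0, 0.0], 1))  # -log(2) ≈ -0.6931
# Two images whose largest logit matches the label give a small loss.
print(negative_log_likelihood([[2.0, 0.0], [0.0, 2.0]], [0, 1]))  # ≈ 0.2539
```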

class Net(borch.Module):
    def __init__(self):
        super(Net, self).__init__(posterior=borch.posterior.Automatic())
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 2)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the pooling window is square, you can specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Specify the likelihood: a Categorical distribution over the two classes
        self.classification = distributions.Categorical(logits=x)
        return self.classification

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
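The input size of fc1, 16 * 5 * 5, follows from how the two convolution and pooling stages shrink a 32x32 image, and it is also the value num_flat_features returns for each image. The sketch below reproduces that arithmetic in plain Python, using the standard output-size formula for an unpadded, stride-1 convolution and for max pooling whose stride equals its window (the settings used in the model). The helper names conv_out and pool_out are illustrative only.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, window):
    """Spatial output size of max pooling with stride equal to the window."""
    return (size - window) // window + 1


size = 32                              # input images are 32x32
size = pool_out(conv_out(size, 5), 2)  # conv1: 32 -> 28, pool: 28 -> 14
size = pool_out(conv_out(size, 5), 2)  # conv2: 14 -> 10, pool: 10 -> 5
flat_features = 16 * size * size       # conv2 produces 16 channels
print(size, flat_features)             # 5 400
```

If you change the input resolution, redo this arithmetic and update the first argument of fc1 accordingly.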


net = Net()
print(net)

Out:

Net(
  (posterior): Automatic()
  (prior): Module()
  (observed): Observed()
  (conv1): Conv2d(
    1, 6, kernel_size=(5, 5), stride=(1, 1)
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[[[-0.1567,  0.0194,  0.0769, -0.0673, -0.0263],
                [-0.1055, -0.0411, -0.0689, -0.0195, -0.0636],
                [-0.1941,  0.1049, -0.0666,  0.1288,  0.1553],
                [-0.0577,  0.0011,  0.0145, -0.0518, -0.1235],
                [-0.1586,  0.1436, -0.1999, -0.0364, -0.0872]]],


              [[[-0.0281,  0.1596, -0.0365, -0.0633, -0.1375],
                [ 0.0499, -0.0787,  0.1923, -0.1312,  0.0520],
                [ 0.0351, -0.0756, -0.0147,  0.1359,  0.1740],
                [ 0.0815,  0.0631,  0.0970,  0.0557, -0.1079],
                [-0.0794, -0.0029, -0.1620, -0.1147,  0.0705]]],


              [[[ 0.1613,  0.0524, -0.1935,  0.0940, -0.0838],
                [-0.1790, -0.0222, -0.0098, -0.1600, -0.1524],
                [ 0.1950, -0.1447, -0.1897,  0.1429,  0.0565],
                [ 0.1752,  0.0891, -0.0210, -0.0461,  0.1512],
                [ 0.1282, -0.1387, -0.1025, -0.1408, -0.0373]]],


              [[[ 0.1134,  0.0826, -0.1147, -0.1228, -0.0040],
                [-0.1923,  0.1244,  0.1793,  0.0183, -0.0433],
                [-0.1840, -0.1691,  0.1184,  0.0151, -0.1467],
                [-0.1828,  0.0363,  0.1660, -0.0660, -0.1319],
                [-0.1644,  0.1835, -0.0681, -0.0800,  0.0668]]],


              [[[ 0.0005,  0.1977,  0.1792,  0.0681,  0.1410],
                [ 0.0081, -0.0216, -0.0456, -0.1985,  0.0049],
                [-0.0568, -0.1405,  0.0604, -0.1020,  0.0084],
                [-0.1587, -0.1756,  0.1148,  0.0872, -0.1307],
                [ 0.1244,  0.0193,  0.0314,  0.0787, -0.1329]]],


              [[[-0.1363,  0.0538,  0.0294,  0.0635, -0.0873],
                [-0.0751, -0.0357, -0.0701,  0.0053, -0.1896],
                [-0.0187, -0.0438, -0.0541, -0.1959,  0.0103],
                [-0.1991,  0.0912,  0.1853, -0.0772,  0.0883],
                [-0.1368,  0.0420,  0.0322, -0.1162, -0.0013]]]], requires_grad=True)
       tensor: tensor([[[[-0.1464,  0.0288,  0.1006, -0.0143, -0.0394],
                [-0.0963, -0.0275, -0.0248, -0.0250, -0.0185],
                [-0.1034,  0.1739, -0.1335,  0.1078,  0.1196],
                [-0.0705, -0.0096, -0.0125, -0.0423,  0.0023],
                [-0.1915,  0.1488, -0.1672, -0.0354, -0.1115]]],


              [[[-0.0702,  0.1708, -0.0110, -0.0176, -0.1317],
                [ 0.1053, -0.0288,  0.2829, -0.0449,  0.0364],
                [ 0.0569, -0.1476, -0.0458,  0.1452,  0.1287],
                [ 0.1165,  0.0803,  0.0561,  0.0671, -0.1254],
                [-0.0583, -0.0479, -0.0769, -0.0862,  0.0184]]],


              [[[ 0.0326,  0.0916, -0.1446,  0.1062, -0.1284],
                [-0.1355, -0.0286,  0.0212, -0.0864, -0.1912],
                [ 0.2453, -0.0990, -0.2416,  0.1289,  0.0360],
                [ 0.1220,  0.2028, -0.0364, -0.0095,  0.2029],
                [ 0.0586, -0.1530, -0.1777, -0.1552, -0.0960]]],


              [[[ 0.1117,  0.1044, -0.1346, -0.0786, -0.0155],
                [-0.1331,  0.1088,  0.1338,  0.0570, -0.0485],
                [-0.1223, -0.1886,  0.1482,  0.0753, -0.1289],
                [-0.1578,  0.0648,  0.1323,  0.0027, -0.1692],
                [-0.1471,  0.2222, -0.0199, -0.1038, -0.0085]]],


              [[[-0.0257,  0.2232,  0.1758, -0.0237,  0.1489],
                [-0.0155,  0.1008, -0.0155, -0.2084,  0.0323],
                [-0.0005, -0.1623,  0.0254, -0.1075, -0.0103],
                [-0.1946, -0.1818,  0.2106,  0.0592, -0.1350],
                [ 0.1469, -0.0297,  0.0166,  0.0733, -0.1201]]],


              [[[-0.1150, -0.0068, -0.0089,  0.0525, -0.0929],
                [-0.0980,  0.0597, -0.0527,  0.0507, -0.2522],
                [-0.0187, -0.0277, -0.0706, -0.2390, -0.0100],
                [-0.1499, -0.0335,  0.1549, -0.1291,  0.0113],
                [-0.0331,  0.1316,  0.0597, -0.1865, -0.0048]]]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-0.1428, -0.1690, -0.0722,  0.0667,  0.0933, -0.0337],
             requires_grad=True)
       tensor: tensor([-0.0361, -0.1412, -0.1430,  0.1021,  0.1712,  0.0340],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[[[-0., 0., 0., -0., -0.],
                [-0., -0., -0., -0., -0.],
                [-0., 0., -0., 0., 0.],
                [-0., 0., 0., -0., -0.],
                [-0., 0., -0., -0., -0.]]],


              [[[-0., 0., -0., -0., -0.],
                [0., -0., 0., -0., 0.],
                [0., -0., -0., 0., 0.],
                [0., 0., 0., 0., -0.],
                [-0., -0., -0., -0., 0.]]],


              [[[0., 0., -0., 0., -0.],
                [-0., -0., -0., -0., -0.],
                [0., -0., -0., 0., 0.],
                [0., 0., -0., -0., 0.],
                [0., -0., -0., -0., -0.]]],


              [[[0., 0., -0., -0., -0.],
                [-0., 0., 0., 0., -0.],
                [-0., -0., 0., 0., -0.],
                [-0., 0., 0., -0., -0.],
                [-0., 0., -0., -0., 0.]]],


              [[[0., 0., 0., 0., 0.],
                [0., -0., -0., -0., 0.],
                [-0., -0., 0., -0., 0.],
                [-0., -0., 0., 0., -0.],
                [0., 0., 0., 0., -0.]]],


              [[[-0., 0., 0., 0., -0.],
                [-0., -0., -0., 0., -0.],
                [-0., -0., -0., -0., 0.],
                [-0., 0., 0., -0., 0.],
                [-0., 0., 0., -0., -0.]]]])
       scale: tensor([[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[[[-0.1567,  0.0194,  0.0769, -0.0673, -0.0263],
                [-0.1055, -0.0411, -0.0689, -0.0195, -0.0636],
                [-0.1941,  0.1049, -0.0666,  0.1288,  0.1553],
                [-0.0577,  0.0011,  0.0145, -0.0518, -0.1235],
                [-0.1586,  0.1436, -0.1999, -0.0364, -0.0872]]],


              [[[-0.0281,  0.1596, -0.0365, -0.0633, -0.1375],
                [ 0.0499, -0.0787,  0.1923, -0.1312,  0.0520],
                [ 0.0351, -0.0756, -0.0147,  0.1359,  0.1740],
                [ 0.0815,  0.0631,  0.0970,  0.0557, -0.1079],
                [-0.0794, -0.0029, -0.1620, -0.1147,  0.0705]]],


              [[[ 0.1613,  0.0524, -0.1935,  0.0940, -0.0838],
                [-0.1790, -0.0222, -0.0098, -0.1600, -0.1524],
                [ 0.1950, -0.1447, -0.1897,  0.1429,  0.0565],
                [ 0.1752,  0.0891, -0.0210, -0.0461,  0.1512],
                [ 0.1282, -0.1387, -0.1025, -0.1408, -0.0373]]],


              [[[ 0.1134,  0.0826, -0.1147, -0.1228, -0.0040],
                [-0.1923,  0.1244,  0.1793,  0.0183, -0.0433],
                [-0.1840, -0.1691,  0.1184,  0.0151, -0.1467],
                [-0.1828,  0.0363,  0.1660, -0.0660, -0.1319],
                [-0.1644,  0.1835, -0.0681, -0.0800,  0.0668]]],


              [[[ 0.0005,  0.1977,  0.1792,  0.0681,  0.1410],
                [ 0.0081, -0.0216, -0.0456, -0.1985,  0.0049],
                [-0.0568, -0.1405,  0.0604, -0.1020,  0.0084],
                [-0.1587, -0.1756,  0.1148,  0.0872, -0.1307],
                [ 0.1244,  0.0193,  0.0314,  0.0787, -0.1329]]],


              [[[-0.1363,  0.0538,  0.0294,  0.0635, -0.0873],
                [-0.0751, -0.0357, -0.0701,  0.0053, -0.1896],
                [-0.0187, -0.0438, -0.0541, -0.1959,  0.0103],
                [-0.1991,  0.0912,  0.1853, -0.0772,  0.0883],
                [-0.1368,  0.0420,  0.0322, -0.1162, -0.0013]]]])
      (bias): Normal:
       loc: tensor([-0., -0., -0., 0., 0., -0.])
       scale: tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-0.1428, -0.1690, -0.0722,  0.0667,  0.0933, -0.0337])
    )
    (observed): Observed()
  )
  (conv2): Conv2d(
    6, 16, kernel_size=(5, 5), stride=(1, 1)
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              ...,


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[[[-1.6434e-02, -1.3113e-02, -4.1464e-02,  1.6377e-02,  6.2508e-02],
                [ 4.0202e-02,  4.9382e-02,  1.0253e-02,  3.4932e-02, -2.9341e-02],
                [-6.4074e-02,  5.2063e-02, -3.8546e-02,  1.5865e-02, -4.8314e-02],
                [ 2.0698e-02,  8.0888e-02,  7.9982e-02,  5.6686e-02,  1.1419e-02],
                [-6.1541e-02,  6.8319e-02, -7.6261e-02, -2.5828e-03, -5.7380e-02]],

               [[ 8.0672e-02, -7.8271e-02,  7.9289e-02,  2.1951e-02,  5.9238e-02],
                [ 1.1237e-02, -5.2598e-02,  4.1551e-02,  4.3546e-02, -5.8132e-02],
                [ 3.0288e-02, -7.5402e-03,  2.8025e-02,  1.5065e-02, -3.7260e-02],
                [-2.1855e-02,  4.4424e-02, -6.7250e-02,  4.3303e-02, -4.0667e-02],
                [ 4.9321e-02, -2.0541e-02,  5.6232e-02, -3.8466e-02,  5.1026e-02]],

               [[-2.6513e-02,  4.9569e-02,  1.8676e-02,  6.9054e-02, -2.6767e-02],
                [-5.7368e-02, -7.8838e-03, -7.9114e-02, -1.8557e-02,  2.8787e-04],
                [ 1.3084e-02,  4.6926e-02, -1.3894e-02,  2.7116e-02, -7.9155e-02],
                [ 7.5190e-02, -2.2396e-02,  5.5935e-02,  6.8270e-02, -3.2401e-02],
                [-1.0770e-02,  3.6370e-02, -7.0476e-02,  6.4846e-02, -6.7540e-02]],

               [[-3.7310e-02,  5.9737e-02,  5.9160e-05,  3.5063e-02,  7.1916e-02],
                [ 1.7257e-02, -6.6353e-03, -6.3498e-03, -1.9595e-02,  4.9856e-02],
                [-8.0101e-02, -4.1773e-02, -1.9761e-02, -1.6521e-02, -3.6788e-02],
                [ 9.3012e-04, -2.4376e-02, -1.1352e-03,  1.1929e-02, -1.3466e-02],
                [ 5.6119e-02,  2.2037e-02,  6.6074e-02,  3.7063e-02,  4.2302e-02]],

               [[-3.6769e-02,  3.7521e-02, -9.7236e-03, -6.8646e-02, -3.6374e-02],
                [-2.8501e-02, -8.0159e-02, -2.8122e-02,  1.8005e-02, -3.7241e-02],
                [ 4.2443e-02, -4.8021e-02, -3.0217e-02, -5.2788e-02,  2.2794e-02],
                [-6.3064e-02,  6.8354e-02,  4.3337e-02, -4.3442e-02, -6.5524e-02],
                [-5.6791e-02,  7.4918e-02,  5.9263e-02, -1.2248e-04, -3.5795e-02]],

               [[-2.1802e-02, -9.4069e-03,  7.6862e-02,  1.4458e-02, -7.0315e-02],
                [-5.9874e-02,  5.0751e-02,  3.9215e-02,  7.7538e-02,  3.2766e-02],
                [-4.4479e-02,  5.2775e-02, -1.9053e-03,  3.2342e-02, -4.6321e-02],
                [ 3.8624e-02,  2.0750e-02,  4.6860e-02, -5.3962e-02,  6.4275e-02],
                [-6.5864e-02, -1.3428e-02,  1.7864e-02, -7.0872e-02, -4.7788e-02]]],


              [[[-6.8434e-03,  1.4991e-02, -4.0166e-02, -4.2922e-02, -3.7892e-02],
                [-4.6419e-03,  7.9605e-02,  2.4481e-02, -1.8153e-02, -6.5677e-02],
                [-4.4087e-02, -3.3966e-02,  7.7397e-02, -2.4133e-02, -2.1281e-02],
                [ 7.2108e-02,  5.7730e-02,  1.8671e-02, -6.6956e-02, -7.8317e-02],
                [ 5.1377e-02,  4.0833e-02,  1.3915e-02, -7.3671e-02, -6.4171e-02]],

               [[-2.2992e-02, -1.5135e-02, -1.6468e-02,  3.4729e-02, -5.0686e-02],
                [ 3.3976e-02,  4.7649e-02,  9.0078e-03,  6.1119e-02,  7.7885e-02],
                [ 2.0299e-02,  4.7323e-02,  5.7702e-02,  4.8393e-02, -4.6623e-02],
                [ 4.0535e-02,  5.1355e-02, -3.0590e-02,  7.3194e-02, -6.9639e-02],
                [-2.0004e-02, -3.7198e-02, -7.3253e-02, -3.4263e-02,  1.5673e-02]],

               [[-2.4125e-02, -2.1872e-02, -4.1021e-02,  1.5590e-02,  7.0267e-02],
                [-1.2325e-02,  4.8418e-02,  3.7418e-02,  5.6973e-02,  1.5516e-02],
                [-5.0112e-02,  4.1789e-02,  5.5392e-02, -3.5548e-02,  2.8206e-02],
                [ 5.9003e-02,  7.3764e-03,  1.5419e-02, -1.6909e-02, -4.0654e-02],
                [ 4.1070e-02, -2.2652e-02, -4.4021e-02, -8.1407e-03, -7.5206e-02]],

               [[-6.3301e-02, -6.5342e-02, -3.3752e-03, -6.6840e-02, -1.8425e-02],
                [-1.6499e-02,  5.8059e-02,  3.5353e-02,  9.1365e-03, -2.8343e-02],
                [ 4.6293e-02,  3.0543e-02, -7.3024e-02, -6.8207e-02, -2.7875e-02],
                [ 8.0904e-02,  1.1077e-02,  4.2119e-02,  5.4343e-02,  2.1999e-02],
                [ 6.2393e-02,  2.9035e-02,  2.6746e-02, -2.8318e-02, -4.9989e-02]],

               [[-6.2379e-02, -4.8560e-02, -5.2721e-02, -3.9246e-02, -6.8517e-03],
                [-7.4939e-03,  9.3801e-03,  1.3410e-02,  5.7410e-02,  4.4898e-03],
                [-1.7655e-02,  3.3112e-02, -6.4055e-02, -1.3577e-02,  6.3291e-02],
                [-3.9472e-02,  1.4227e-02, -4.6944e-02, -1.8557e-02, -3.4740e-02],
                [-6.3974e-02, -6.9448e-02, -2.1668e-02,  2.4177e-02,  3.0717e-02]],

               [[ 2.8530e-02,  2.3247e-02, -3.8633e-02, -5.7804e-02, -1.9294e-02],
                [ 5.4420e-02, -1.2094e-02,  2.1143e-02, -6.5897e-02, -2.8639e-02],
                [ 4.7260e-02,  4.2903e-02,  7.4266e-02,  6.9400e-02, -6.4887e-02],
                [-7.5132e-02, -5.4750e-02,  4.6103e-03,  3.0465e-02, -6.0162e-02],
                [ 3.2707e-02, -3.7524e-02, -6.7505e-02,  2.1123e-02, -1.5651e-02]]],


              [[[ 6.0358e-02,  5.9896e-02, -3.1081e-02,  5.8683e-02,  1.5452e-02],
                [ 6.0256e-02, -7.1520e-02,  7.8586e-02, -3.8772e-02, -7.1890e-02],
                [-5.1354e-02, -2.9084e-03, -3.9233e-02, -2.1499e-02, -2.9419e-02],
                [-3.2572e-02, -6.9616e-02,  2.9291e-02,  7.2235e-02,  2.5144e-02],
                [ 7.6527e-02,  3.1913e-03, -1.8299e-02,  1.5759e-02,  6.3982e-02]],

               [[-6.2217e-02,  2.4136e-02,  6.5684e-02, -5.2996e-02, -8.5318e-03],
                [ 5.4878e-02, -7.0118e-02,  5.6222e-02, -5.8217e-02, -1.2457e-02],
                [-9.0102e-03,  4.0819e-02, -5.2410e-02, -4.3693e-02, -8.1261e-03],
                [ 3.9352e-02,  4.2597e-02,  6.4178e-02, -1.6116e-02,  4.4007e-02],
                [-4.6907e-02,  5.0872e-02,  1.4034e-02, -7.7642e-02, -3.1652e-02]],

               [[ 1.8691e-02,  3.9128e-02, -1.8538e-03,  6.9222e-02,  4.7985e-02],
                [ 2.5163e-02, -3.2308e-02,  5.8934e-02,  6.4200e-02,  7.5079e-02],
                [-3.8752e-02, -6.2834e-02, -1.3630e-02,  4.7745e-02, -3.6710e-02],
                [-6.5912e-02, -7.5509e-02,  2.0538e-03,  6.1806e-02,  4.7332e-02],
                [ 5.2663e-02, -6.0765e-02, -1.1656e-02, -1.2399e-02,  6.5297e-02]],

               [[-4.0377e-02,  2.2776e-02,  2.0396e-03,  6.3307e-02,  7.7342e-02],
                [-8.6686e-03,  6.8417e-02,  4.9833e-02, -5.8394e-02, -6.8530e-02],
                [-6.4711e-02, -6.4908e-02, -3.2846e-02, -3.7337e-02, -3.2760e-02],
                [-7.8387e-02,  1.2714e-03, -3.3095e-02, -1.5624e-03, -1.5552e-02],
                [-6.5617e-02,  5.9709e-02,  7.6255e-02,  7.4220e-02, -3.9595e-02]],

               [[-2.4471e-02, -1.4723e-02, -4.3525e-03, -5.7851e-02,  1.1639e-02],
                [ 4.5532e-02,  4.3314e-02,  2.7463e-02, -3.2127e-02, -6.0824e-02],
                [ 4.7108e-02,  1.2112e-02, -1.6862e-02, -5.4160e-02,  4.8685e-03],
                [-2.5893e-02,  6.4832e-02,  3.3282e-02,  4.9884e-02,  6.3713e-02],
                [-1.9860e-02,  1.2712e-02,  7.0452e-04,  4.6135e-02, -5.4728e-02]],

               [[-5.1852e-02, -1.5589e-02, -3.1799e-02,  4.5747e-02, -2.9827e-02],
                [ 6.4932e-02, -5.3074e-02, -4.9272e-02,  1.8426e-02, -4.6095e-02],
                [ 4.1712e-02, -7.9372e-02, -7.9577e-02,  1.2126e-02,  4.2022e-02],
                [ 5.8650e-02,  1.3046e-02, -5.0546e-02,  7.1611e-02,  5.8748e-02],
                [ 4.8559e-02, -8.1399e-02,  6.8672e-02,  7.3071e-02, -4.6508e-02]]],


              ...,


              [[[ 7.2944e-02, -5.6815e-02, -7.2140e-02, -8.0878e-02,  8.1223e-02],
                [-2.6174e-02,  4.4648e-02, -1.7627e-05,  6.8991e-02,  9.3131e-04],
                [-1.9168e-02, -1.1712e-02, -3.9127e-02, -6.5451e-02, -1.9835e-02],
                [-2.9851e-02, -7.2093e-02,  6.1742e-03, -6.0501e-02, -3.0240e-02],
                [-3.6711e-02, -3.9918e-02,  1.8570e-02,  2.7867e-02, -3.9091e-02]],

               [[ 4.4294e-02, -6.2949e-02, -4.9712e-02, -6.0654e-02,  2.6511e-02],
                [-1.1918e-02,  1.9399e-02,  1.9778e-03,  4.6715e-02,  7.9662e-02],
                [-3.9779e-02, -3.2971e-02,  2.6502e-03,  6.0599e-02,  6.1761e-02],
                [-7.2092e-02, -3.8731e-02, -1.5203e-02,  3.3408e-02, -7.3232e-02],
                [ 2.1354e-02,  4.9467e-02,  6.6561e-02,  6.4517e-02, -4.9400e-02]],

               [[ 2.5279e-02,  7.5811e-02, -5.6423e-02,  7.1795e-03, -8.0469e-02],
                [ 6.3054e-02,  1.5441e-02, -7.9545e-02,  6.0103e-02,  4.7542e-02],
                [-2.6974e-02,  6.4899e-02, -7.6267e-02,  3.2200e-02,  9.7143e-03],
                [ 4.1850e-02, -1.8550e-02, -5.0626e-02,  3.7149e-02,  7.6128e-02],
                [-2.8989e-02,  2.5409e-03, -3.2850e-02, -1.1957e-02,  4.2580e-04]],

               [[ 4.6736e-02,  7.7891e-02,  2.2977e-02,  3.5759e-02, -7.9195e-02],
                [-4.3826e-02, -7.9846e-02, -5.2120e-02, -3.8209e-03,  2.4057e-02],
                [-6.7396e-03,  2.7530e-02,  1.1896e-03, -1.6895e-02, -5.0218e-02],
                [ 5.6456e-02, -7.6683e-02,  2.4498e-02, -5.4710e-02,  5.6294e-02],
                [-3.0637e-03, -2.6177e-02, -3.8865e-02, -3.9652e-02,  4.4595e-02]],

               [[-8.0799e-02, -7.9691e-02, -2.4048e-02, -6.6943e-02,  5.5213e-02],
                [ 1.1116e-02, -3.8443e-02,  4.3369e-02, -7.6902e-02,  5.1385e-02],
                [ 5.6263e-02, -5.4902e-02,  8.0991e-02,  1.6011e-02,  6.6421e-02],
                [-2.4895e-02, -4.4881e-02,  4.6953e-02, -4.1781e-02, -4.2947e-02],
                [ 8.0550e-02, -7.2696e-02, -4.6141e-02,  6.7832e-03, -1.6691e-03]],

               [[ 2.6609e-02, -3.9203e-02, -7.8157e-03,  2.2936e-04, -2.7554e-02],
                [ 4.0520e-02,  1.1102e-02, -2.2165e-02,  6.4671e-02,  1.1872e-02],
                [ 2.5477e-02,  3.2211e-02,  5.6317e-02,  5.1697e-02,  5.5899e-02],
                [-3.0296e-02, -3.9487e-02, -2.5797e-02,  5.7478e-02, -4.8781e-03],
                [-6.3375e-02, -4.3827e-02,  3.5311e-03,  4.7217e-02,  6.8362e-02]]],


              [[[-4.5381e-02, -7.7842e-02, -6.9001e-02, -7.6422e-03,  6.8520e-02],
                [ 2.3377e-02, -5.9736e-03, -6.8239e-02,  7.2911e-02, -6.6242e-02],
                [-1.5282e-02,  1.7386e-02,  3.9979e-02, -6.8327e-03, -1.7662e-03],
                [ 5.4649e-02, -4.8377e-03,  7.7069e-02, -8.0424e-02, -2.7894e-02],
                [-6.3750e-02, -2.7770e-02,  5.7462e-02, -1.8159e-02,  5.8960e-02]],

               [[ 1.5038e-02, -8.0078e-02,  1.0708e-02,  2.2493e-02,  2.2514e-02],
                [-2.7322e-02,  4.5916e-02,  7.1295e-02, -5.6998e-02, -5.2429e-02],
                [-2.4198e-02, -4.0081e-02, -7.5517e-02, -6.0738e-02, -1.9848e-02],
                [-8.0915e-02,  1.1733e-02,  7.0872e-02,  4.2211e-02,  3.7455e-03],
                [ 5.6451e-02, -2.0291e-02,  5.9699e-02, -3.8810e-02,  9.7062e-03]],

               [[ 7.0948e-02,  7.7596e-02, -5.9511e-02, -2.7747e-03, -2.9197e-02],
                [ 5.6304e-02, -5.9313e-02, -3.6894e-03, -3.4498e-02, -3.1743e-02],
                [ 6.2984e-02,  7.1278e-02,  1.8568e-02, -8.1057e-02, -7.4301e-02],
                [-1.7063e-02,  3.7341e-02, -1.7987e-02, -6.2014e-03, -1.3535e-02],
                [ 3.3733e-02,  3.2608e-02, -1.8692e-02,  6.1727e-02,  1.0257e-02]],

               [[ 4.3113e-04, -6.9241e-02,  2.2611e-02,  4.1913e-02, -6.6395e-02],
                [ 7.5128e-03, -7.3346e-02,  8.0353e-02,  1.2347e-02,  5.5333e-02],
                [-9.7800e-03,  4.5897e-02,  2.8835e-02, -3.6708e-02,  3.9655e-02],
                [ 2.7716e-02, -7.1659e-02,  7.1108e-03,  1.1511e-02, -4.8559e-02],
                [-3.0865e-02,  7.5560e-02,  2.8310e-02,  7.4005e-02, -5.0888e-02]],

               [[ 7.5087e-02,  6.3344e-02,  5.9466e-02,  1.0437e-02,  9.3939e-03],
                [-1.4452e-03, -5.0765e-02, -3.6996e-02, -6.8923e-02, -7.4329e-02],
                [ 1.1036e-02, -2.6916e-02, -6.9722e-02,  5.9740e-02,  4.6108e-02],
                [ 2.0379e-02,  3.6167e-02,  4.8153e-02, -3.0691e-02, -5.5250e-02],
                [ 3.5924e-02,  4.5421e-02, -4.7335e-02,  6.4587e-02, -5.7064e-02]],

               [[-1.6970e-03, -7.8021e-02, -6.0369e-02, -8.0641e-02,  7.1452e-02],
                [ 1.6848e-02, -7.5881e-02,  2.5285e-02,  2.5364e-02, -1.0818e-02],
                [-3.0854e-02, -2.4429e-02, -6.4815e-02,  8.1414e-03, -7.9674e-02],
                [-6.2038e-02,  7.4582e-02, -1.7759e-02,  2.3795e-02, -1.5795e-02],
                [ 3.7823e-02, -3.3319e-04,  7.1363e-03,  7.7572e-02,  4.3771e-02]]],


              [[[-4.8656e-03,  4.3062e-02, -3.2547e-02,  9.7140e-03, -5.3167e-02],
                [ 4.2759e-02, -4.1656e-02,  6.4357e-02,  3.5642e-02, -7.8376e-02],
                [-1.2937e-02,  6.4533e-02,  1.5182e-02,  1.1444e-02, -7.4220e-02],
                [ 6.3483e-02, -1.1542e-02, -4.0774e-02, -1.2172e-02, -2.7794e-02],
                [-8.1438e-03, -5.5991e-02, -2.9966e-02, -8.0014e-03, -5.2937e-02]],

               [[ 9.9251e-03, -2.7150e-02, -1.5934e-02, -3.4809e-02,  2.2487e-02],
                [-2.9249e-03,  6.8871e-02,  4.3621e-03,  2.6227e-02,  4.3713e-02],
                [ 8.1283e-02, -3.1387e-02, -6.9915e-02, -1.7858e-02, -2.1714e-02],
                [-3.5359e-02, -1.3766e-02,  3.6173e-02,  9.1202e-03, -3.9747e-02],
                [-7.2135e-02, -7.3420e-02,  6.0504e-02,  3.1594e-02,  7.6891e-02]],

               [[ 7.2759e-02, -6.5420e-02,  6.7763e-02,  7.2741e-02, -7.4671e-02],
                [-5.5163e-02, -7.5269e-02,  1.3287e-02,  1.8645e-02, -3.4054e-02],
                [ 6.5525e-02, -4.1262e-03,  4.6500e-02, -6.6291e-02,  5.8884e-02],
                [ 3.0486e-02,  3.6131e-03,  1.1222e-02, -3.3646e-02, -6.5889e-02],
                [ 4.7762e-02,  3.6352e-02,  9.7470e-03,  7.7495e-03, -5.5064e-02]],

               [[ 4.2110e-02,  5.1736e-02,  4.9755e-02,  1.8245e-02,  3.1093e-02],
                [-5.8074e-02, -4.1158e-02, -5.9566e-03,  6.2394e-02,  1.6582e-02],
                [-7.2003e-02,  1.4616e-02, -3.5987e-02,  3.0575e-02, -4.4705e-02],
                [-4.7500e-02, -1.9091e-02,  1.2661e-02,  2.4751e-02, -7.1824e-02],
                [ 4.3771e-02,  4.9023e-02,  7.2368e-02, -2.3195e-02, -3.0777e-02]],

               [[-8.2526e-03, -1.3523e-02, -6.9580e-02,  2.5552e-02, -1.5779e-02],
                [ 2.2318e-03,  2.7111e-02, -8.7496e-03, -2.3582e-02, -6.8521e-02],
                [ 7.4568e-02, -4.6680e-02,  7.4333e-02, -6.5834e-02,  8.0266e-02],
                [ 1.0070e-02,  5.4708e-02, -1.4732e-03,  1.9077e-02, -2.5033e-02],
                [-2.7357e-02,  1.9236e-02, -4.7921e-02, -5.5013e-02, -7.4643e-02]],

               [[ 4.0488e-02, -7.1390e-02, -2.3527e-02,  1.3764e-02, -1.5115e-02],
                [-3.7438e-02, -7.9287e-02, -6.0580e-02, -3.2224e-02, -2.1884e-02],
                [-7.0937e-02, -5.7632e-02, -1.2339e-02,  5.2566e-02, -5.8696e-02],
                [-6.4373e-02,  5.0876e-02, -6.8186e-02,  6.8750e-02,  4.6615e-02],
                [ 3.0661e-02, -6.8377e-02,  1.7900e-02, -8.8543e-03,  1.4958e-02]]]],
             requires_grad=True)
       tensor: tensor([[[[-7.9040e-02, -3.7379e-03,  3.7772e-02,  3.2235e-03,  5.9213e-02],
                [ 1.1029e-01,  5.3078e-02, -5.1533e-02, -1.9493e-02, -2.0278e-02],
                [-3.1871e-02,  1.9686e-02, -2.9023e-02,  6.5129e-02, -1.2462e-01],
                [ 1.5340e-02,  6.1715e-02,  6.1034e-02,  4.7931e-02,  4.5956e-02],
                [-1.4301e-01,  7.0018e-02, -9.7622e-02,  1.0660e-01, -1.3165e-02]],

               [[ 4.6854e-02, -7.3574e-02,  8.1291e-02, -8.2834e-02,  5.4429e-02],
                [-2.5799e-02, -1.1986e-01,  9.7622e-02, -1.7025e-02,  5.7727e-02],
                [ 1.7480e-02,  1.1082e-02, -2.8427e-02,  3.6260e-02, -3.2276e-02],
                [-3.9838e-02,  2.2212e-02, -9.0553e-02,  8.0943e-02, -1.1631e-02],
                [-1.0137e-02, -8.9539e-02,  8.1702e-02, -6.7811e-02,  1.3221e-01]],

               [[-4.7101e-02, -6.6853e-02, -1.6752e-02,  8.6421e-02, -7.5081e-02],
                [-9.1753e-03, -8.0882e-02, -1.0052e-01, -5.2712e-02,  6.5347e-03],
                [ 1.2138e-01,  8.8738e-02,  2.7613e-02,  1.6467e-02, -2.6266e-02],
                [ 1.4057e-01, -4.7242e-02, -2.1786e-03,  4.5640e-02, -1.3707e-01],
                [-4.3512e-02,  6.7946e-02,  5.3988e-02,  5.0796e-02, -8.4169e-02]],

               [[-3.3406e-02,  8.4323e-02,  1.2269e-02, -8.0027e-03,  6.7707e-02],
                [ 3.4774e-02,  4.1956e-02,  4.8764e-02, -3.4235e-02,  8.3637e-02],
                [-1.4833e-01, -2.8633e-02, -1.3108e-01,  2.1766e-03,  1.2428e-01],
                [-3.0024e-02, -1.8942e-02, -3.2634e-02,  2.2048e-02,  7.3094e-03],
                [ 2.8312e-02,  2.7016e-02,  1.1275e-01, -3.3630e-02,  5.8374e-02]],

               [[-9.6861e-02,  5.6462e-02, -1.3434e-02,  1.3635e-02,  1.2531e-02],
                [ 3.7592e-02, -4.0006e-02, -5.6685e-02,  4.3706e-02, -1.9282e-02],
                [ 6.0806e-02, -1.0441e-01, -9.5292e-02, -8.2102e-03,  3.6336e-02],
                [-7.0097e-03,  2.9375e-02,  2.9469e-02, -1.4813e-02, -1.4279e-01],
                [-6.3974e-02,  1.3383e-01,  9.9203e-02,  3.4012e-02, -3.2517e-02]],

               [[-7.2753e-03,  7.4082e-02,  1.8057e-03, -1.9648e-02, -6.6572e-02],
                [-9.5007e-03,  1.5475e-01, -1.7310e-02,  1.0670e-01,  1.0902e-01],
                [-2.5831e-02,  6.5105e-02, -2.5116e-03,  6.3226e-02, -8.0030e-02],
                [ 9.8436e-02,  1.7270e-02,  2.2965e-03,  1.3588e-02,  9.2589e-02],
                [-1.8574e-01, -4.7728e-02, -5.0609e-02, -4.6957e-02, -1.1836e-01]]],


              [[[ 4.2598e-02, -4.7736e-03, -1.5925e-02, -1.3042e-01, -3.4655e-02],
                [ 1.0344e-01, -8.2122e-02, -6.3094e-02, -1.3021e-02, -1.4217e-01],
                [-2.2429e-02, -8.9024e-02, -2.6338e-02, -4.2198e-02, -3.3970e-02],
                [ 6.1737e-02,  1.4096e-01,  8.1511e-02, -1.1718e-01, -7.9483e-02],
                [ 1.3986e-01,  1.0246e-01,  1.3764e-02, -1.1664e-01, -1.2733e-02]],

               [[ 4.0471e-02, -5.8450e-02, -9.7472e-02, -1.1759e-03, -2.6700e-02],
                [-2.9101e-02, -1.7837e-02, -3.6606e-02,  7.5730e-02,  4.6381e-02],
                [-6.4833e-02,  1.0151e-01,  4.6630e-02, -1.2530e-02,  1.4527e-02],
                [-7.4575e-02,  4.4769e-02,  3.6406e-02,  3.2990e-02, -7.2557e-02],
                [-3.3549e-02, -2.6702e-02, -5.1471e-02, -7.6490e-02,  6.8940e-02]],

               [[-3.6364e-02, -6.5542e-03,  1.8009e-02,  5.2121e-02,  1.3648e-02],
                [ 7.4661e-02,  5.3473e-02,  1.0069e-01,  4.3583e-02, -5.3611e-02],
                [ 5.8234e-02,  1.0366e-01,  3.8814e-02, -1.0320e-01,  1.0476e-01],
                [-5.2903e-03,  9.4834e-02,  8.0787e-02,  2.3429e-02, -7.9876e-02],
                [ 1.9255e-02, -7.6146e-02,  1.4325e-02,  1.1370e-01, -6.8911e-02]],

               [[ 2.2207e-02, -9.1128e-02, -4.1531e-02, -6.9817e-02,  6.1896e-02],
                [-9.3361e-02,  4.2035e-02,  4.2046e-02,  3.0422e-02,  2.4328e-02],
                [ 3.8398e-02, -5.3092e-02, -1.9244e-02, -5.9219e-02, -4.7959e-02],
                [ 1.1761e-01,  8.1741e-03,  9.8403e-02,  9.8197e-02,  3.1676e-03],
                [ 3.4369e-02, -4.2303e-02,  8.5386e-02, -3.3772e-02, -3.5091e-02]],

               [[-2.3272e-02, -4.7407e-02, -7.4221e-02, -4.5022e-02,  6.2653e-02],
                [-3.3159e-02,  8.0423e-02, -3.4854e-02,  1.0360e-01, -6.0444e-02],
                [-6.4923e-02,  1.8660e-03, -5.3824e-02, -3.3233e-02,  5.3356e-02],
                [ 2.6758e-02,  5.6524e-02,  1.4273e-02,  2.0408e-02, -6.0961e-03],
                [-7.6553e-02, -1.5248e-01, -1.6458e-01,  6.0101e-02, -6.5330e-02]],

               [[ 7.4326e-02,  5.1909e-02, -1.0165e-01, -5.6569e-02,  3.2163e-03],
                [ 2.2902e-02,  4.9499e-02,  1.2251e-01, -1.2251e-02, -8.8281e-02],
                [ 7.0423e-02,  7.8330e-03,  1.1058e-01,  4.6035e-02,  3.0721e-02],
                [-1.3987e-01, -2.5666e-02,  6.7374e-03,  1.6375e-02, -1.1733e-01],
                [ 7.9805e-02, -8.5537e-02,  5.2054e-02,  2.3272e-02,  1.0554e-01]]],


              [[[ 7.6109e-02,  1.1910e-01,  5.8049e-03,  1.1300e-01,  4.7126e-02],
                [ 4.7814e-02, -7.1319e-02,  9.9547e-02, -3.3557e-02, -1.1612e-01],
                [-6.4900e-02, -3.6347e-02, -9.5536e-02, -3.6929e-02,  7.2433e-03],
                [-8.8529e-02, -5.4179e-02,  1.1549e-01,  5.7148e-02,  2.8013e-02],
                [ 1.2830e-01,  4.1872e-02, -3.4477e-02,  8.0472e-02,  6.5879e-02]],

               [[-4.7669e-02, -9.9669e-03,  7.2031e-02, -7.7567e-02, -8.9658e-02],
                [ 1.5114e-01, -9.9826e-02,  2.3776e-02, -2.5093e-02, -1.1878e-01],
                [-4.9512e-02,  1.5042e-01, -5.2417e-02, -3.5301e-03, -5.2818e-02],
                [ 9.4058e-03,  1.0676e-02,  8.0632e-02,  2.8295e-02,  6.7808e-02],
                [-5.5620e-02,  1.2470e-01,  7.8352e-02,  1.2027e-02, -4.2022e-02]],

               [[ 1.0013e-01,  1.1369e-01, -1.0589e-01, -2.8328e-02,  1.3427e-01],
                [ 2.5859e-02,  7.2207e-02,  1.0146e-02, -4.0904e-02, -1.6968e-02],
                [-1.2773e-02, -6.0562e-02, -1.5470e-02,  7.2821e-02,  1.7777e-02],
                [-8.7816e-02,  1.0456e-03,  7.3066e-02, -9.0929e-03,  4.0644e-03],
                [ 1.4177e-02, -4.2387e-02,  1.8873e-02, -2.0339e-02,  1.3091e-03]],

               [[-1.1817e-01,  7.3392e-03,  2.8750e-04,  1.5765e-01,  1.0838e-01],
                [-8.5690e-02,  7.2779e-03,  1.7669e-01,  3.7189e-03, -7.1858e-02],
                [-1.0172e-01,  2.2364e-02,  8.7263e-03, -6.2377e-02, -6.5125e-02],
                [ 1.8006e-02, -9.0645e-02,  4.0923e-02, -8.8384e-02, -3.2109e-02],
                [-1.1282e-01,  7.5887e-02, -2.0569e-02,  6.8537e-02, -9.1945e-02]],

               [[ 8.7547e-03, -7.3805e-02, -2.4725e-02, -9.7951e-02,  1.4958e-02],
                [ 1.5039e-03,  3.5538e-02,  6.9676e-03, -3.4485e-02, -7.6818e-02],
                [ 5.0455e-02, -1.6215e-02, -4.2158e-02, -4.7410e-02,  1.1642e-01],
                [ 1.3294e-02,  1.3026e-01, -1.5100e-02,  8.2574e-02,  7.0877e-02],
                [-8.0356e-02,  5.9577e-02, -2.4652e-02,  4.9606e-03, -5.2893e-02]],

               [[-2.6808e-02, -4.1903e-02,  1.0307e-01,  4.8184e-02,  5.5327e-02],
                [ 1.1018e-01, -1.9608e-02, -7.4902e-02,  4.5365e-02,  6.8104e-03],
                [ 1.6191e-02, -1.0323e-01, -1.4568e-01,  1.1337e-01,  1.1910e-01],
                [ 2.0553e-02,  2.9168e-02, -3.5671e-02,  9.2641e-02, -4.0162e-02],
                [ 8.6601e-02, -1.6433e-01,  1.2801e-01,  3.8804e-02, -9.1625e-02]]],


              ...,


              [[[ 6.2546e-02, -1.4037e-02, -6.0003e-02, -1.2287e-01,  1.0268e-01],
                [-6.6390e-02,  1.0820e-01,  3.2333e-02,  9.1314e-02, -2.1904e-02],
                [-7.6083e-02, -8.3289e-02,  2.6048e-02,  3.8830e-02, -7.6226e-02],
                [-3.3048e-02, -3.1125e-02, -1.9828e-02, -7.9846e-02, -6.0296e-02],
                [-5.7660e-02,  2.9841e-02,  4.6691e-02, -5.9227e-02, -8.6609e-03]],

               [[ 1.9611e-02, -6.0364e-02, -6.5186e-02, -1.4345e-01,  6.7694e-02],
                [ 2.5685e-03, -2.5887e-02, -3.7028e-02,  8.4842e-02,  5.7676e-02],
                [-8.0969e-02,  8.7602e-03,  2.1688e-02,  1.3805e-01,  8.9147e-02],
                [-1.3020e-01,  1.3120e-02, -2.7499e-02,  1.6288e-02, -8.0340e-02],
                [ 4.6616e-02,  5.5087e-02,  1.2293e-01,  3.0804e-02, -3.4891e-02]],

               [[-2.7551e-02,  1.1154e-01, -1.7306e-01,  6.3198e-02, -8.9437e-02],
                [ 1.0023e-01,  7.0259e-02, -1.0258e-01, -5.0549e-02,  1.6971e-02],
                [ 2.5675e-02, -1.1538e-02, -3.6890e-02,  1.8445e-02,  2.4550e-02],
                [ 6.6892e-02,  7.9879e-02, -2.2027e-02,  7.8564e-03,  1.3598e-01],
                [-1.3008e-02,  5.2810e-02, -1.6301e-01, -4.4813e-02,  2.5996e-02]],

               [[ 8.5186e-02,  9.1058e-02, -1.3829e-02,  3.8593e-02, -4.9633e-02],
                [-7.4777e-02, -1.3623e-01, -4.2233e-02, -2.4843e-02,  4.7866e-02],
                [ 1.1859e-01, -3.9103e-02,  6.5936e-02, -6.2660e-03, -8.1931e-02],
                [ 6.5915e-02, -3.5195e-02, -8.3154e-02,  2.5878e-02,  1.0228e-01],
                [-2.0687e-02,  4.4530e-02, -9.8919e-03,  2.0240e-02,  1.0023e-01]],

               [[-3.2076e-02,  3.3811e-02, -4.0684e-02, -1.0620e-01,  4.5157e-02],
                [ 1.5433e-02, -5.7362e-02, -2.0055e-02, -1.7903e-01,  3.5422e-02],
                [ 5.4437e-02, -6.6113e-02,  5.3977e-02, -1.0462e-02,  6.6393e-02],
                [-4.2926e-02, -9.2950e-02,  7.8664e-03, -1.4528e-01, -7.3950e-02],
                [ 1.1710e-01, -7.6187e-03, -3.6296e-02, -3.4596e-02, -8.7469e-02]],

               [[-1.5697e-02,  1.8503e-03, -1.0732e-02,  4.4367e-02,  6.3441e-02],
                [ 6.3576e-02, -7.5327e-02, -2.1156e-02,  1.0805e-01,  5.9616e-03],
                [-5.1490e-02,  6.6770e-02,  3.8325e-02,  9.7658e-02,  6.7618e-02],
                [-9.1422e-03, -2.5269e-02, -6.2489e-02, -3.4347e-02, -3.2664e-02],
                [-7.2520e-02, -8.9599e-02,  6.2424e-02,  7.9191e-02,  7.9724e-02]]],


              [[[-3.9622e-02, -4.6573e-02, -1.0472e-01,  1.1712e-02,  6.0660e-02],
                [-1.1623e-01, -5.1286e-03, -1.4487e-03,  1.1639e-01, -1.7682e-01],
                [-8.0438e-03, -2.5883e-02, -6.0686e-02, -4.3806e-02, -5.9083e-02],
                [ 6.1406e-02,  2.7634e-02,  8.6463e-02, -2.1471e-02,  1.3222e-03],
                [-3.6193e-02, -1.3771e-01, -3.6332e-02, -1.2857e-02,  4.2172e-02]],

               [[ 6.7233e-02, -1.2348e-01, -1.0973e-02,  9.5008e-02, -2.0650e-02],
                [-5.6706e-02,  6.6555e-02,  7.6050e-02, -1.5699e-01, -6.1664e-02],
                [ 1.7243e-02, -5.0797e-02, -1.3130e-01, -6.8116e-02, -2.0173e-02],
                [-1.7038e-01,  6.6691e-02,  1.3256e-02,  6.9085e-02, -2.6799e-02],
                [ 6.1442e-02, -3.6632e-02,  3.3934e-02, -9.6146e-02,  1.9848e-02]],

               [[ 5.0492e-02,  3.8202e-02,  8.3198e-03, -4.8258e-02,  1.4369e-01],
                [ 1.2627e-01, -1.1377e-02,  2.7590e-02, -5.6298e-02, -1.0798e-01],
                [ 9.1512e-02,  3.4963e-02, -3.0616e-02,  7.3307e-02, -8.8587e-02],
                [-4.4653e-02,  9.5914e-02, -3.9202e-02,  1.3692e-02, -4.4345e-02],
                [ 6.6634e-02,  5.8251e-02,  1.8750e-03,  2.9827e-02, -2.3503e-02]],

               [[-2.7796e-02, -6.2559e-02,  1.1090e-01, -2.0772e-02, -6.1026e-02],
                [ 3.4603e-06, -8.2480e-02,  1.2860e-01, -8.0894e-02,  1.0519e-01],
                [-9.9585e-02,  1.1226e-01,  8.5616e-02, -3.4187e-02, -2.4726e-02],
                [ 7.1098e-02, -8.6928e-02,  1.1157e-02,  5.2506e-02, -1.0666e-01],
                [-1.0207e-01,  3.2715e-02,  2.3789e-02,  1.2132e-02, -5.6706e-02]],

               [[ 1.4586e-01,  7.7736e-02,  8.8786e-02,  6.8187e-02, -3.1638e-02],
                [ 3.9524e-03, -1.0926e-01, -1.1489e-01, -8.6353e-02, -1.1245e-01],
                [ 2.9333e-02, -8.5913e-03, -5.3009e-02,  9.1673e-02,  8.5318e-02],
                [ 6.1508e-02,  1.0515e-01,  1.2023e-01, -1.4283e-02,  4.9522e-02],
                [ 3.0185e-02,  1.1706e-01, -2.6336e-02,  4.6261e-02, -5.2317e-02]],

               [[-1.7293e-02,  2.8523e-02, -2.8706e-02, -1.2312e-01,  1.0700e-04],
                [-2.4223e-02, -7.1036e-02,  1.6568e-02,  7.3946e-02,  3.1011e-02],
                [-2.9600e-02, -4.9499e-02, -9.9627e-02,  5.0107e-02, -8.2558e-02],
                [-8.7629e-02,  7.3977e-02, -1.1190e-02,  5.6551e-02,  5.6759e-03],
                [ 7.6054e-02,  2.6358e-02,  5.8867e-03,  4.6918e-02,  8.0736e-02]]],


              [[[ 7.6387e-02,  3.5164e-02, -1.0085e-01,  2.5791e-02,  1.2621e-02],
                [ 4.0133e-02, -6.2389e-02, -7.0186e-03, -1.5747e-02, -5.8938e-02],
                [-2.1233e-02,  5.9552e-02, -1.6451e-02,  5.8491e-02, -1.4009e-01],
                [ 7.7164e-02,  7.3244e-02,  3.9004e-02, -5.1173e-02,  1.8552e-02],
                [-9.8993e-03, -3.2891e-03, -1.2678e-02,  3.2900e-02, -1.6397e-02]],

               [[ 4.8623e-02, -1.6285e-01,  2.6418e-02, -4.4003e-02,  1.0522e-02],
                [ 6.5607e-02, -4.3938e-02, -4.4469e-02,  2.8729e-02,  7.2077e-02],
                [ 8.6301e-02, -5.3798e-02,  2.2424e-03, -8.6657e-02,  2.3395e-02],
                [-5.0502e-02,  2.6603e-02,  6.7704e-02, -7.1994e-02, -1.0376e-01],
                [-1.1587e-01, -4.1993e-02,  4.0082e-02,  8.4338e-03,  5.3974e-02]],

               [[ 9.3545e-03, -3.9613e-02,  1.1917e-01,  5.5167e-02, -5.5481e-02],
                [-7.8236e-02, -3.6044e-02, -6.9676e-02,  2.0344e-02,  8.6851e-02],
                [ 2.0640e-02, -2.4200e-02,  9.8858e-02, -1.3632e-02,  3.1430e-02],
                [ 5.8873e-02,  9.5151e-03,  1.6263e-02,  4.1727e-02, -1.2870e-01],
                [-4.9029e-02,  3.8735e-02,  2.3564e-02, -4.3188e-02, -6.6108e-02]],

               [[ 4.8084e-02,  7.8817e-02,  3.1851e-02,  4.9276e-02,  3.8264e-02],
                [-9.3058e-02, -6.3315e-02, -1.7968e-02,  7.5934e-02,  3.9757e-02],
                [-1.2371e-01,  4.7944e-02, -2.0174e-02, -3.3382e-02, -4.8501e-02],
                [-9.3338e-02, -1.7285e-03,  8.4074e-04, -1.8147e-02, -1.1448e-01],
                [ 6.5431e-03,  1.3819e-01,  5.7124e-02, -5.3819e-02,  2.9704e-02]],

               [[-3.9365e-02,  8.3383e-02, -5.7147e-02, -1.6916e-02,  8.8240e-02],
                [-2.1344e-02,  6.3516e-02, -6.2374e-03, -6.7964e-02, -6.1580e-02],
                [ 3.6108e-02, -8.6570e-02, -5.2997e-02, -6.8391e-02,  7.4299e-02],
                [ 1.0596e-01,  1.0751e-02,  6.2705e-02, -2.0067e-02, -8.1794e-02],
                [-3.2689e-02,  9.9929e-03, -6.4664e-02, -1.9183e-02, -1.2894e-01]],

               [[ 5.6596e-02,  7.1493e-03,  2.1555e-02,  3.2155e-02, -4.1433e-02],
                [-1.1520e-01, -1.6532e-02, -2.5351e-02, -1.0983e-01, -2.3312e-03],
                [-1.0072e-01, -1.6166e-02, -4.6829e-03,  1.3726e-01, -3.1124e-02],
                [-1.3071e-02,  1.0966e-01, -1.1188e-01, -2.2783e-02,  9.0083e-02],
                [ 8.3369e-03, -9.1546e-02,  5.8801e-02, -2.5113e-03, -5.2018e-02]]]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([ 0.0482,  0.0700, -0.0483,  0.0543,  0.0427, -0.0160,  0.0652,  0.0702,
              -0.0693, -0.0685,  0.0211,  0.0540,  0.0356,  0.0235,  0.0208,  0.0305],
             requires_grad=True)
       tensor: tensor([-0.0018,  0.0913, -0.0619,  0.0713,  0.0755, -0.0322,  0.0876,  0.1089,
              -0.1069, -0.0383,  0.0331,  0.0550,  0.0829,  0.0486, -0.0027, -0.0946],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[[[-0., -0., -0., 0., 0.],
                [0., 0., 0., 0., -0.],
                [-0., 0., -0., 0., -0.],
                [0., 0., 0., 0., 0.],
                [-0., 0., -0., -0., -0.]],

               [[0., -0., 0., 0., 0.],
                [0., -0., 0., 0., -0.],
                [0., -0., 0., 0., -0.],
                [-0., 0., -0., 0., -0.],
                [0., -0., 0., -0., 0.]],

               [[-0., 0., 0., 0., -0.],
                [-0., -0., -0., -0., 0.],
                [0., 0., -0., 0., -0.],
                [0., -0., 0., 0., -0.],
                [-0., 0., -0., 0., -0.]],

               [[-0., 0., 0., 0., 0.],
                [0., -0., -0., -0., 0.],
                [-0., -0., -0., -0., -0.],
                [0., -0., -0., 0., -0.],
                [0., 0., 0., 0., 0.]],

               [[-0., 0., -0., -0., -0.],
                [-0., -0., -0., 0., -0.],
                [0., -0., -0., -0., 0.],
                [-0., 0., 0., -0., -0.],
                [-0., 0., 0., -0., -0.]],

               [[-0., -0., 0., 0., -0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., -0., 0., -0.],
                [0., 0., 0., -0., 0.],
                [-0., -0., 0., -0., -0.]]],


              [[[-0., 0., -0., -0., -0.],
                [-0., 0., 0., -0., -0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., 0., -0., -0.],
                [0., 0., 0., -0., -0.]],

               [[-0., -0., -0., 0., -0.],
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., -0.],
                [0., 0., -0., 0., -0.],
                [-0., -0., -0., -0., 0.]],

               [[-0., -0., -0., 0., 0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., 0., -0., 0.],
                [0., 0., 0., -0., -0.],
                [0., -0., -0., -0., -0.]],

               [[-0., -0., -0., -0., -0.],
                [-0., 0., 0., 0., -0.],
                [0., 0., -0., -0., -0.],
                [0., 0., 0., 0., 0.],
                [0., 0., 0., -0., -0.]],

               [[-0., -0., -0., -0., -0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., -0., -0., 0.],
                [-0., 0., -0., -0., -0.],
                [-0., -0., -0., 0., 0.]],

               [[0., 0., -0., -0., -0.],
                [0., -0., 0., -0., -0.],
                [0., 0., 0., 0., -0.],
                [-0., -0., 0., 0., -0.],
                [0., -0., -0., 0., -0.]]],


              [[[0., 0., -0., 0., 0.],
                [0., -0., 0., -0., -0.],
                [-0., -0., -0., -0., -0.],
                [-0., -0., 0., 0., 0.],
                [0., 0., -0., 0., 0.]],

               [[-0., 0., 0., -0., -0.],
                [0., -0., 0., -0., -0.],
                [-0., 0., -0., -0., -0.],
                [0., 0., 0., -0., 0.],
                [-0., 0., 0., -0., -0.]],

               [[0., 0., -0., 0., 0.],
                [0., -0., 0., 0., 0.],
                [-0., -0., -0., 0., -0.],
                [-0., -0., 0., 0., 0.],
                [0., -0., -0., -0., 0.]],

               [[-0., 0., 0., 0., 0.],
                [-0., 0., 0., -0., -0.],
                [-0., -0., -0., -0., -0.],
                [-0., 0., -0., -0., -0.],
                [-0., 0., 0., 0., -0.]],

               [[-0., -0., -0., -0., 0.],
                [0., 0., 0., -0., -0.],
                [0., 0., -0., -0., 0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., 0., 0., -0.]],

               [[-0., -0., -0., 0., -0.],
                [0., -0., -0., 0., -0.],
                [0., -0., -0., 0., 0.],
                [0., 0., -0., 0., 0.],
                [0., -0., 0., 0., -0.]]],


              ...,


              [[[0., -0., -0., -0., 0.],
                [-0., 0., -0., 0., 0.],
                [-0., -0., -0., -0., -0.],
                [-0., -0., 0., -0., -0.],
                [-0., -0., 0., 0., -0.]],

               [[0., -0., -0., -0., 0.],
                [-0., 0., 0., 0., 0.],
                [-0., -0., 0., 0., 0.],
                [-0., -0., -0., 0., -0.],
                [0., 0., 0., 0., -0.]],

               [[0., 0., -0., 0., -0.],
                [0., 0., -0., 0., 0.],
                [-0., 0., -0., 0., 0.],
                [0., -0., -0., 0., 0.],
                [-0., 0., -0., -0., 0.]],

               [[0., 0., 0., 0., -0.],
                [-0., -0., -0., -0., 0.],
                [-0., 0., 0., -0., -0.],
                [0., -0., 0., -0., 0.],
                [-0., -0., -0., -0., 0.]],

               [[-0., -0., -0., -0., 0.],
                [0., -0., 0., -0., 0.],
                [0., -0., 0., 0., 0.],
                [-0., -0., 0., -0., -0.],
                [0., -0., -0., 0., -0.]],

               [[0., -0., -0., 0., -0.],
                [0., 0., -0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [-0., -0., -0., 0., -0.],
                [-0., -0., 0., 0., 0.]]],


              [[[-0., -0., -0., -0., 0.],
                [0., -0., -0., 0., -0.],
                [-0., 0., 0., -0., -0.],
                [0., -0., 0., -0., -0.],
                [-0., -0., 0., -0., 0.]],

               [[0., -0., 0., 0., 0.],
                [-0., 0., 0., -0., -0.],
                [-0., -0., -0., -0., -0.],
                [-0., 0., 0., 0., 0.],
                [0., -0., 0., -0., 0.]],

               [[0., 0., -0., -0., -0.],
                [0., -0., -0., -0., -0.],
                [0., 0., 0., -0., -0.],
                [-0., 0., -0., -0., -0.],
                [0., 0., -0., 0., 0.]],

               [[0., -0., 0., 0., -0.],
                [0., -0., 0., 0., 0.],
                [-0., 0., 0., -0., 0.],
                [0., -0., 0., 0., -0.],
                [-0., 0., 0., 0., -0.]],

               [[0., 0., 0., 0., 0.],
                [-0., -0., -0., -0., -0.],
                [0., -0., -0., 0., 0.],
                [0., 0., 0., -0., -0.],
                [0., 0., -0., 0., -0.]],

               [[-0., -0., -0., -0., 0.],
                [0., -0., 0., 0., -0.],
                [-0., -0., -0., 0., -0.],
                [-0., 0., -0., 0., -0.],
                [0., -0., 0., 0., 0.]]],


              [[[-0., 0., -0., 0., -0.],
                [0., -0., 0., 0., -0.],
                [-0., 0., 0., 0., -0.],
                [0., -0., -0., -0., -0.],
                [-0., -0., -0., -0., -0.]],

               [[0., -0., -0., -0., 0.],
                [-0., 0., 0., 0., 0.],
                [0., -0., -0., -0., -0.],
                [-0., -0., 0., 0., -0.],
                [-0., -0., 0., 0., 0.]],

               [[0., -0., 0., 0., -0.],
                [-0., -0., 0., 0., -0.],
                [0., -0., 0., -0., 0.],
                [0., 0., 0., -0., -0.],
                [0., 0., 0., 0., -0.]],

               [[0., 0., 0., 0., 0.],
                [-0., -0., -0., 0., 0.],
                [-0., 0., -0., 0., -0.],
                [-0., -0., 0., 0., -0.],
                [0., 0., 0., -0., -0.]],

               [[-0., -0., -0., 0., -0.],
                [0., 0., -0., -0., -0.],
                [0., -0., 0., -0., 0.],
                [0., 0., -0., 0., -0.],
                [-0., 0., -0., -0., -0.]],

               [[0., -0., -0., 0., -0.],
                [-0., -0., -0., -0., -0.],
                [-0., -0., -0., 0., -0.],
                [-0., 0., -0., 0., 0.],
                [0., -0., 0., -0., 0.]]]])
       scale: tensor([[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              ...,


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[[[-1.6434e-02, -1.3113e-02, -4.1464e-02,  1.6377e-02,  6.2508e-02],
                [ 4.0202e-02,  4.9382e-02,  1.0253e-02,  3.4932e-02, -2.9341e-02],
                [-6.4074e-02,  5.2063e-02, -3.8546e-02,  1.5865e-02, -4.8314e-02],
                [ 2.0698e-02,  8.0888e-02,  7.9982e-02,  5.6686e-02,  1.1419e-02],
                [-6.1541e-02,  6.8319e-02, -7.6261e-02, -2.5828e-03, -5.7380e-02]],

               [[ 8.0672e-02, -7.8271e-02,  7.9289e-02,  2.1951e-02,  5.9238e-02],
                [ 1.1237e-02, -5.2598e-02,  4.1551e-02,  4.3546e-02, -5.8132e-02],
                [ 3.0288e-02, -7.5402e-03,  2.8025e-02,  1.5065e-02, -3.7260e-02],
                [-2.1855e-02,  4.4424e-02, -6.7250e-02,  4.3303e-02, -4.0667e-02],
                [ 4.9321e-02, -2.0541e-02,  5.6232e-02, -3.8466e-02,  5.1026e-02]],

               [[-2.6513e-02,  4.9569e-02,  1.8676e-02,  6.9054e-02, -2.6767e-02],
                [-5.7368e-02, -7.8838e-03, -7.9114e-02, -1.8557e-02,  2.8787e-04],
                [ 1.3084e-02,  4.6926e-02, -1.3894e-02,  2.7116e-02, -7.9155e-02],
                [ 7.5190e-02, -2.2396e-02,  5.5935e-02,  6.8270e-02, -3.2401e-02],
                [-1.0770e-02,  3.6370e-02, -7.0476e-02,  6.4846e-02, -6.7540e-02]],

               [[-3.7310e-02,  5.9737e-02,  5.9160e-05,  3.5063e-02,  7.1916e-02],
                [ 1.7257e-02, -6.6353e-03, -6.3498e-03, -1.9595e-02,  4.9856e-02],
                [-8.0101e-02, -4.1773e-02, -1.9761e-02, -1.6521e-02, -3.6788e-02],
                [ 9.3012e-04, -2.4376e-02, -1.1352e-03,  1.1929e-02, -1.3466e-02],
                [ 5.6119e-02,  2.2037e-02,  6.6074e-02,  3.7063e-02,  4.2302e-02]],

               [[-3.6769e-02,  3.7521e-02, -9.7236e-03, -6.8646e-02, -3.6374e-02],
                [-2.8501e-02, -8.0159e-02, -2.8122e-02,  1.8005e-02, -3.7241e-02],
                [ 4.2443e-02, -4.8021e-02, -3.0217e-02, -5.2788e-02,  2.2794e-02],
                [-6.3064e-02,  6.8354e-02,  4.3337e-02, -4.3442e-02, -6.5524e-02],
                [-5.6791e-02,  7.4918e-02,  5.9263e-02, -1.2248e-04, -3.5795e-02]],

               [[-2.1802e-02, -9.4069e-03,  7.6862e-02,  1.4458e-02, -7.0315e-02],
                [-5.9874e-02,  5.0751e-02,  3.9215e-02,  7.7538e-02,  3.2766e-02],
                [-4.4479e-02,  5.2775e-02, -1.9053e-03,  3.2342e-02, -4.6321e-02],
                [ 3.8624e-02,  2.0750e-02,  4.6860e-02, -5.3962e-02,  6.4275e-02],
                [-6.5864e-02, -1.3428e-02,  1.7864e-02, -7.0872e-02, -4.7788e-02]]],


              [[[-6.8434e-03,  1.4991e-02, -4.0166e-02, -4.2922e-02, -3.7892e-02],
                [-4.6419e-03,  7.9605e-02,  2.4481e-02, -1.8153e-02, -6.5677e-02],
                [-4.4087e-02, -3.3966e-02,  7.7397e-02, -2.4133e-02, -2.1281e-02],
                [ 7.2108e-02,  5.7730e-02,  1.8671e-02, -6.6956e-02, -7.8317e-02],
                [ 5.1377e-02,  4.0833e-02,  1.3915e-02, -7.3671e-02, -6.4171e-02]],

               [[-2.2992e-02, -1.5135e-02, -1.6468e-02,  3.4729e-02, -5.0686e-02],
                [ 3.3976e-02,  4.7649e-02,  9.0078e-03,  6.1119e-02,  7.7885e-02],
                [ 2.0299e-02,  4.7323e-02,  5.7702e-02,  4.8393e-02, -4.6623e-02],
                [ 4.0535e-02,  5.1355e-02, -3.0590e-02,  7.3194e-02, -6.9639e-02],
                [-2.0004e-02, -3.7198e-02, -7.3253e-02, -3.4263e-02,  1.5673e-02]],

               [[-2.4125e-02, -2.1872e-02, -4.1021e-02,  1.5590e-02,  7.0267e-02],
                [-1.2325e-02,  4.8418e-02,  3.7418e-02,  5.6973e-02,  1.5516e-02],
                [-5.0112e-02,  4.1789e-02,  5.5392e-02, -3.5548e-02,  2.8206e-02],
                [ 5.9003e-02,  7.3764e-03,  1.5419e-02, -1.6909e-02, -4.0654e-02],
                [ 4.1070e-02, -2.2652e-02, -4.4021e-02, -8.1407e-03, -7.5206e-02]],

               [[-6.3301e-02, -6.5342e-02, -3.3752e-03, -6.6840e-02, -1.8425e-02],
                [-1.6499e-02,  5.8059e-02,  3.5353e-02,  9.1365e-03, -2.8343e-02],
                [ 4.6293e-02,  3.0543e-02, -7.3024e-02, -6.8207e-02, -2.7875e-02],
                [ 8.0904e-02,  1.1077e-02,  4.2119e-02,  5.4343e-02,  2.1999e-02],
                [ 6.2393e-02,  2.9035e-02,  2.6746e-02, -2.8318e-02, -4.9989e-02]],

               [[-6.2379e-02, -4.8560e-02, -5.2721e-02, -3.9246e-02, -6.8517e-03],
                [-7.4939e-03,  9.3801e-03,  1.3410e-02,  5.7410e-02,  4.4898e-03],
                [-1.7655e-02,  3.3112e-02, -6.4055e-02, -1.3577e-02,  6.3291e-02],
                [-3.9472e-02,  1.4227e-02, -4.6944e-02, -1.8557e-02, -3.4740e-02],
                [-6.3974e-02, -6.9448e-02, -2.1668e-02,  2.4177e-02,  3.0717e-02]],

               [[ 2.8530e-02,  2.3247e-02, -3.8633e-02, -5.7804e-02, -1.9294e-02],
                [ 5.4420e-02, -1.2094e-02,  2.1143e-02, -6.5897e-02, -2.8639e-02],
                [ 4.7260e-02,  4.2903e-02,  7.4266e-02,  6.9400e-02, -6.4887e-02],
                [-7.5132e-02, -5.4750e-02,  4.6103e-03,  3.0465e-02, -6.0162e-02],
                [ 3.2707e-02, -3.7524e-02, -6.7505e-02,  2.1123e-02, -1.5651e-02]]],


              [[[ 6.0358e-02,  5.9896e-02, -3.1081e-02,  5.8683e-02,  1.5452e-02],
                [ 6.0256e-02, -7.1520e-02,  7.8586e-02, -3.8772e-02, -7.1890e-02],
                [-5.1354e-02, -2.9084e-03, -3.9233e-02, -2.1499e-02, -2.9419e-02],
                [-3.2572e-02, -6.9616e-02,  2.9291e-02,  7.2235e-02,  2.5144e-02],
                [ 7.6527e-02,  3.1913e-03, -1.8299e-02,  1.5759e-02,  6.3982e-02]],

               [[-6.2217e-02,  2.4136e-02,  6.5684e-02, -5.2996e-02, -8.5318e-03],
                [ 5.4878e-02, -7.0118e-02,  5.6222e-02, -5.8217e-02, -1.2457e-02],
                [-9.0102e-03,  4.0819e-02, -5.2410e-02, -4.3693e-02, -8.1261e-03],
                [ 3.9352e-02,  4.2597e-02,  6.4178e-02, -1.6116e-02,  4.4007e-02],
                [-4.6907e-02,  5.0872e-02,  1.4034e-02, -7.7642e-02, -3.1652e-02]],

               [[ 1.8691e-02,  3.9128e-02, -1.8538e-03,  6.9222e-02,  4.7985e-02],
                [ 2.5163e-02, -3.2308e-02,  5.8934e-02,  6.4200e-02,  7.5079e-02],
                [-3.8752e-02, -6.2834e-02, -1.3630e-02,  4.7745e-02, -3.6710e-02],
                [-6.5912e-02, -7.5509e-02,  2.0538e-03,  6.1806e-02,  4.7332e-02],
                [ 5.2663e-02, -6.0765e-02, -1.1656e-02, -1.2399e-02,  6.5297e-02]],

               [[-4.0377e-02,  2.2776e-02,  2.0396e-03,  6.3307e-02,  7.7342e-02],
                [-8.6686e-03,  6.8417e-02,  4.9833e-02, -5.8394e-02, -6.8530e-02],
                [-6.4711e-02, -6.4908e-02, -3.2846e-02, -3.7337e-02, -3.2760e-02],
                [-7.8387e-02,  1.2714e-03, -3.3095e-02, -1.5624e-03, -1.5552e-02],
                [-6.5617e-02,  5.9709e-02,  7.6255e-02,  7.4220e-02, -3.9595e-02]],

               [[-2.4471e-02, -1.4723e-02, -4.3525e-03, -5.7851e-02,  1.1639e-02],
                [ 4.5532e-02,  4.3314e-02,  2.7463e-02, -3.2127e-02, -6.0824e-02],
                [ 4.7108e-02,  1.2112e-02, -1.6862e-02, -5.4160e-02,  4.8685e-03],
                [-2.5893e-02,  6.4832e-02,  3.3282e-02,  4.9884e-02,  6.3713e-02],
                [-1.9860e-02,  1.2712e-02,  7.0452e-04,  4.6135e-02, -5.4728e-02]],

               [[-5.1852e-02, -1.5589e-02, -3.1799e-02,  4.5747e-02, -2.9827e-02],
                [ 6.4932e-02, -5.3074e-02, -4.9272e-02,  1.8426e-02, -4.6095e-02],
                [ 4.1712e-02, -7.9372e-02, -7.9577e-02,  1.2126e-02,  4.2022e-02],
                [ 5.8650e-02,  1.3046e-02, -5.0546e-02,  7.1611e-02,  5.8748e-02],
                [ 4.8559e-02, -8.1399e-02,  6.8672e-02,  7.3071e-02, -4.6508e-02]]],


              ...,


              [[[ 7.2944e-02, -5.6815e-02, -7.2140e-02, -8.0878e-02,  8.1223e-02],
                [-2.6174e-02,  4.4648e-02, -1.7627e-05,  6.8991e-02,  9.3131e-04],
                [-1.9168e-02, -1.1712e-02, -3.9127e-02, -6.5451e-02, -1.9835e-02],
                [-2.9851e-02, -7.2093e-02,  6.1742e-03, -6.0501e-02, -3.0240e-02],
                [-3.6711e-02, -3.9918e-02,  1.8570e-02,  2.7867e-02, -3.9091e-02]],

               [[ 4.4294e-02, -6.2949e-02, -4.9712e-02, -6.0654e-02,  2.6511e-02],
                [-1.1918e-02,  1.9399e-02,  1.9778e-03,  4.6715e-02,  7.9662e-02],
                [-3.9779e-02, -3.2971e-02,  2.6502e-03,  6.0599e-02,  6.1761e-02],
                [-7.2092e-02, -3.8731e-02, -1.5203e-02,  3.3408e-02, -7.3232e-02],
                [ 2.1354e-02,  4.9467e-02,  6.6561e-02,  6.4517e-02, -4.9400e-02]],

               [[ 2.5279e-02,  7.5811e-02, -5.6423e-02,  7.1795e-03, -8.0469e-02],
                [ 6.3054e-02,  1.5441e-02, -7.9545e-02,  6.0103e-02,  4.7542e-02],
                [-2.6974e-02,  6.4899e-02, -7.6267e-02,  3.2200e-02,  9.7143e-03],
                [ 4.1850e-02, -1.8550e-02, -5.0626e-02,  3.7149e-02,  7.6128e-02],
                [-2.8989e-02,  2.5409e-03, -3.2850e-02, -1.1957e-02,  4.2580e-04]],

               [[ 4.6736e-02,  7.7891e-02,  2.2977e-02,  3.5759e-02, -7.9195e-02],
                [-4.3826e-02, -7.9846e-02, -5.2120e-02, -3.8209e-03,  2.4057e-02],
                [-6.7396e-03,  2.7530e-02,  1.1896e-03, -1.6895e-02, -5.0218e-02],
                [ 5.6456e-02, -7.6683e-02,  2.4498e-02, -5.4710e-02,  5.6294e-02],
                [-3.0637e-03, -2.6177e-02, -3.8865e-02, -3.9652e-02,  4.4595e-02]],

               [[-8.0799e-02, -7.9691e-02, -2.4048e-02, -6.6943e-02,  5.5213e-02],
                [ 1.1116e-02, -3.8443e-02,  4.3369e-02, -7.6902e-02,  5.1385e-02],
                [ 5.6263e-02, -5.4902e-02,  8.0991e-02,  1.6011e-02,  6.6421e-02],
                [-2.4895e-02, -4.4881e-02,  4.6953e-02, -4.1781e-02, -4.2947e-02],
                [ 8.0550e-02, -7.2696e-02, -4.6141e-02,  6.7832e-03, -1.6691e-03]],

               [[ 2.6609e-02, -3.9203e-02, -7.8157e-03,  2.2936e-04, -2.7554e-02],
                [ 4.0520e-02,  1.1102e-02, -2.2165e-02,  6.4671e-02,  1.1872e-02],
                [ 2.5477e-02,  3.2211e-02,  5.6317e-02,  5.1697e-02,  5.5899e-02],
                [-3.0296e-02, -3.9487e-02, -2.5797e-02,  5.7478e-02, -4.8781e-03],
                [-6.3375e-02, -4.3827e-02,  3.5311e-03,  4.7217e-02,  6.8362e-02]]],


              [[[-4.5381e-02, -7.7842e-02, -6.9001e-02, -7.6422e-03,  6.8520e-02],
                [ 2.3377e-02, -5.9736e-03, -6.8239e-02,  7.2911e-02, -6.6242e-02],
                [-1.5282e-02,  1.7386e-02,  3.9979e-02, -6.8327e-03, -1.7662e-03],
                [ 5.4649e-02, -4.8377e-03,  7.7069e-02, -8.0424e-02, -2.7894e-02],
                [-6.3750e-02, -2.7770e-02,  5.7462e-02, -1.8159e-02,  5.8960e-02]],

               [[ 1.5038e-02, -8.0078e-02,  1.0708e-02,  2.2493e-02,  2.2514e-02],
                [-2.7322e-02,  4.5916e-02,  7.1295e-02, -5.6998e-02, -5.2429e-02],
                [-2.4198e-02, -4.0081e-02, -7.5517e-02, -6.0738e-02, -1.9848e-02],
                [-8.0915e-02,  1.1733e-02,  7.0872e-02,  4.2211e-02,  3.7455e-03],
                [ 5.6451e-02, -2.0291e-02,  5.9699e-02, -3.8810e-02,  9.7062e-03]],

               [[ 7.0948e-02,  7.7596e-02, -5.9511e-02, -2.7747e-03, -2.9197e-02],
                [ 5.6304e-02, -5.9313e-02, -3.6894e-03, -3.4498e-02, -3.1743e-02],
                [ 6.2984e-02,  7.1278e-02,  1.8568e-02, -8.1057e-02, -7.4301e-02],
                [-1.7063e-02,  3.7341e-02, -1.7987e-02, -6.2014e-03, -1.3535e-02],
                [ 3.3733e-02,  3.2608e-02, -1.8692e-02,  6.1727e-02,  1.0257e-02]],

               [[ 4.3113e-04, -6.9241e-02,  2.2611e-02,  4.1913e-02, -6.6395e-02],
                [ 7.5128e-03, -7.3346e-02,  8.0353e-02,  1.2347e-02,  5.5333e-02],
                [-9.7800e-03,  4.5897e-02,  2.8835e-02, -3.6708e-02,  3.9655e-02],
                [ 2.7716e-02, -7.1659e-02,  7.1108e-03,  1.1511e-02, -4.8559e-02],
                [-3.0865e-02,  7.5560e-02,  2.8310e-02,  7.4005e-02, -5.0888e-02]],

               [[ 7.5087e-02,  6.3344e-02,  5.9466e-02,  1.0437e-02,  9.3939e-03],
                [-1.4452e-03, -5.0765e-02, -3.6996e-02, -6.8923e-02, -7.4329e-02],
                [ 1.1036e-02, -2.6916e-02, -6.9722e-02,  5.9740e-02,  4.6108e-02],
                [ 2.0379e-02,  3.6167e-02,  4.8153e-02, -3.0691e-02, -5.5250e-02],
                [ 3.5924e-02,  4.5421e-02, -4.7335e-02,  6.4587e-02, -5.7064e-02]],

               [[-1.6970e-03, -7.8021e-02, -6.0369e-02, -8.0641e-02,  7.1452e-02],
                [ 1.6848e-02, -7.5881e-02,  2.5285e-02,  2.5364e-02, -1.0818e-02],
                [-3.0854e-02, -2.4429e-02, -6.4815e-02,  8.1414e-03, -7.9674e-02],
                [-6.2038e-02,  7.4582e-02, -1.7759e-02,  2.3795e-02, -1.5795e-02],
                [ 3.7823e-02, -3.3319e-04,  7.1363e-03,  7.7572e-02,  4.3771e-02]]],


              [[[-4.8656e-03,  4.3062e-02, -3.2547e-02,  9.7140e-03, -5.3167e-02],
                [ 4.2759e-02, -4.1656e-02,  6.4357e-02,  3.5642e-02, -7.8376e-02],
                [-1.2937e-02,  6.4533e-02,  1.5182e-02,  1.1444e-02, -7.4220e-02],
                [ 6.3483e-02, -1.1542e-02, -4.0774e-02, -1.2172e-02, -2.7794e-02],
                [-8.1438e-03, -5.5991e-02, -2.9966e-02, -8.0014e-03, -5.2937e-02]],

               [[ 9.9251e-03, -2.7150e-02, -1.5934e-02, -3.4809e-02,  2.2487e-02],
                [-2.9249e-03,  6.8871e-02,  4.3621e-03,  2.6227e-02,  4.3713e-02],
                [ 8.1283e-02, -3.1387e-02, -6.9915e-02, -1.7858e-02, -2.1714e-02],
                [-3.5359e-02, -1.3766e-02,  3.6173e-02,  9.1202e-03, -3.9747e-02],
                [-7.2135e-02, -7.3420e-02,  6.0504e-02,  3.1594e-02,  7.6891e-02]],

               [[ 7.2759e-02, -6.5420e-02,  6.7763e-02,  7.2741e-02, -7.4671e-02],
                [-5.5163e-02, -7.5269e-02,  1.3287e-02,  1.8645e-02, -3.4054e-02],
                [ 6.5525e-02, -4.1262e-03,  4.6500e-02, -6.6291e-02,  5.8884e-02],
                [ 3.0486e-02,  3.6131e-03,  1.1222e-02, -3.3646e-02, -6.5889e-02],
                [ 4.7762e-02,  3.6352e-02,  9.7470e-03,  7.7495e-03, -5.5064e-02]],

               [[ 4.2110e-02,  5.1736e-02,  4.9755e-02,  1.8245e-02,  3.1093e-02],
                [-5.8074e-02, -4.1158e-02, -5.9566e-03,  6.2394e-02,  1.6582e-02],
                [-7.2003e-02,  1.4616e-02, -3.5987e-02,  3.0575e-02, -4.4705e-02],
                [-4.7500e-02, -1.9091e-02,  1.2661e-02,  2.4751e-02, -7.1824e-02],
                [ 4.3771e-02,  4.9023e-02,  7.2368e-02, -2.3195e-02, -3.0777e-02]],

               [[-8.2526e-03, -1.3523e-02, -6.9580e-02,  2.5552e-02, -1.5779e-02],
                [ 2.2318e-03,  2.7111e-02, -8.7496e-03, -2.3582e-02, -6.8521e-02],
                [ 7.4568e-02, -4.6680e-02,  7.4333e-02, -6.5834e-02,  8.0266e-02],
                [ 1.0070e-02,  5.4708e-02, -1.4732e-03,  1.9077e-02, -2.5033e-02],
                [-2.7357e-02,  1.9236e-02, -4.7921e-02, -5.5013e-02, -7.4643e-02]],

               [[ 4.0488e-02, -7.1390e-02, -2.3527e-02,  1.3764e-02, -1.5115e-02],
                [-3.7438e-02, -7.9287e-02, -6.0580e-02, -3.2224e-02, -2.1884e-02],
                [-7.0937e-02, -5.7632e-02, -1.2339e-02,  5.2566e-02, -5.8696e-02],
                [-6.4373e-02,  5.0876e-02, -6.8186e-02,  6.8750e-02,  4.6615e-02],
                [ 3.0661e-02, -6.8377e-02,  1.7900e-02, -8.8543e-03,  1.4958e-02]]]])
      (bias): Normal:
       loc: tensor([0., 0., -0., 0., 0., -0., 0., 0., -0., -0., 0., 0., 0., 0., 0., 0.])
       scale: tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
              0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([ 0.0482,  0.0700, -0.0483,  0.0543,  0.0427, -0.0160,  0.0652,  0.0702,
              -0.0693, -0.0685,  0.0211,  0.0540,  0.0356,  0.0235,  0.0208,  0.0305])
    )
    (observed): Observed()
  )
  (fc1): Linear(
    in_features=400, out_features=120, bias=True
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              ...,
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498]],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[-0.0222,  0.0364, -0.0196,  ..., -0.0034,  0.0235,  0.0133],
              [ 0.0133,  0.0200, -0.0305,  ..., -0.0283,  0.0068,  0.0075],
              [ 0.0367, -0.0009,  0.0018,  ...,  0.0404, -0.0319, -0.0467],
              ...,
              [-0.0015, -0.0323,  0.0344,  ...,  0.0226, -0.0282,  0.0235],
              [ 0.0195,  0.0296,  0.0472,  ..., -0.0476,  0.0057, -0.0356],
              [-0.0482,  0.0452,  0.0364,  ...,  0.0029,  0.0322,  0.0217]],
             requires_grad=True)
       tensor: tensor([[-3.8304e-02,  5.6343e-02, -1.1556e-01,  ...,  8.4700e-02,
               -2.3077e-02, -9.6969e-02],
              [-6.9593e-02,  4.2074e-02,  2.1176e-03,  ..., -5.1170e-02,
               -1.7429e-02,  6.8395e-02],
              [-1.5431e-02,  4.9186e-03, -4.7218e-02,  ...,  4.2745e-02,
               -3.6945e-02, -8.5793e-02],
              ...,
              [ 5.5533e-02,  2.4006e-02,  2.3173e-02,  ...,  2.2089e-02,
               -4.7982e-03, -3.0608e-02],
              [-9.1586e-06,  6.9304e-02,  6.9180e-02,  ..., -1.3455e-02,
                4.2434e-02, -3.7071e-02],
              [-2.3114e-02,  5.0826e-02, -7.1701e-02,  ...,  5.6384e-02,
                2.2557e-02, -4.0097e-02]], grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-0.0030, -0.0106, -0.0147,  0.0311, -0.0320,  0.0124, -0.0280,  0.0140,
               0.0088, -0.0196,  0.0488, -0.0001,  0.0003,  0.0401, -0.0014,  0.0427,
              -0.0403, -0.0008,  0.0406, -0.0468, -0.0297, -0.0245, -0.0370, -0.0124,
               0.0320, -0.0158, -0.0113, -0.0198,  0.0193,  0.0356, -0.0264, -0.0160,
               0.0050,  0.0121, -0.0498,  0.0146, -0.0372,  0.0089, -0.0298,  0.0399,
               0.0347, -0.0108,  0.0353, -0.0157, -0.0174, -0.0355,  0.0131,  0.0192,
               0.0432, -0.0373,  0.0332, -0.0114,  0.0318, -0.0132, -0.0002, -0.0403,
               0.0447, -0.0203,  0.0274, -0.0342, -0.0080,  0.0389,  0.0318,  0.0043,
               0.0192,  0.0158,  0.0490,  0.0272, -0.0142, -0.0218,  0.0353, -0.0035,
               0.0169, -0.0432,  0.0079,  0.0499, -0.0018,  0.0296, -0.0337, -0.0214,
               0.0376,  0.0054,  0.0384,  0.0403, -0.0050, -0.0075, -0.0203,  0.0318,
               0.0285, -0.0415,  0.0395, -0.0045, -0.0020,  0.0245, -0.0361,  0.0150,
               0.0347,  0.0185, -0.0093, -0.0056,  0.0021, -0.0026,  0.0046,  0.0380,
              -0.0403, -0.0422, -0.0218, -0.0198, -0.0446,  0.0296,  0.0276,  0.0089,
              -0.0049,  0.0100,  0.0118,  0.0283, -0.0334,  0.0287,  0.0236, -0.0404],
             requires_grad=True)
       tensor: tensor([ 0.0523, -0.0292,  0.0273,  0.1130, -0.1221, -0.0301, -0.0207, -0.0004,
               0.0262, -0.0813,  0.0444, -0.0165, -0.0252,  0.0806, -0.0251, -0.0477,
              -0.0964, -0.0773,  0.0956, -0.1304, -0.0227, -0.0166, -0.0507, -0.0412,
               0.0829, -0.0533,  0.0776, -0.0555,  0.1376,  0.0208, -0.0055, -0.0125,
              -0.0041, -0.0186,  0.0247,  0.0902, -0.0040, -0.0262, -0.0154,  0.0425,
               0.0448, -0.0340,  0.0756, -0.0699, -0.0605, -0.0363,  0.0278,  0.0412,
              -0.0084,  0.0181,  0.0751, -0.0466, -0.0437, -0.0335,  0.0715, -0.0624,
               0.0333,  0.0213,  0.0205,  0.0236, -0.0448, -0.0932,  0.0688,  0.0458,
               0.0226,  0.0806,  0.0200,  0.0317, -0.0296, -0.0642,  0.0867, -0.1344,
               0.0021,  0.0203, -0.0926, -0.0053, -0.0611,  0.0150, -0.0447, -0.0945,
               0.0837,  0.0649,  0.0609,  0.0122,  0.0407, -0.0284,  0.0042,  0.0202,
               0.0736, -0.0268, -0.0105, -0.0601,  0.0477,  0.0708,  0.0400,  0.0461,
              -0.0125, -0.0138, -0.0739,  0.0362,  0.0398, -0.0847, -0.0054,  0.0625,
              -0.1151, -0.0861, -0.1333, -0.1284,  0.0016,  0.0346,  0.0692,  0.0263,
              -0.0565, -0.0479, -0.0471,  0.0167, -0.0569,  0.1117,  0.0053, -0.0830],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[-0., 0., -0.,  ..., -0., 0., 0.],
              [0., 0., -0.,  ..., -0., 0., 0.],
              [0., -0., 0.,  ..., 0., -0., -0.],
              ...,
              [-0., -0., 0.,  ..., 0., -0., 0.],
              [0., 0., 0.,  ..., -0., 0., -0.],
              [-0., 0., 0.,  ..., 0., 0., 0.]])
       scale: tensor([[0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              ...,
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[-0.0222,  0.0364, -0.0196,  ..., -0.0034,  0.0235,  0.0133],
              [ 0.0133,  0.0200, -0.0305,  ..., -0.0283,  0.0068,  0.0075],
              [ 0.0367, -0.0009,  0.0018,  ...,  0.0404, -0.0319, -0.0467],
              ...,
              [-0.0015, -0.0323,  0.0344,  ...,  0.0226, -0.0282,  0.0235],
              [ 0.0195,  0.0296,  0.0472,  ..., -0.0476,  0.0057, -0.0356],
              [-0.0482,  0.0452,  0.0364,  ...,  0.0029,  0.0322,  0.0217]])
      (bias): Normal:
       loc: tensor([-0., -0., -0., 0., -0., 0., -0., 0., 0., -0., 0., -0., 0., 0., -0., 0., -0., -0., 0., -0., -0., -0., -0., -0.,
              0., -0., -0., -0., 0., 0., -0., -0., 0., 0., -0., 0., -0., 0., -0., 0., 0., -0., 0., -0., -0., -0., 0., 0.,
              0., -0., 0., -0., 0., -0., -0., -0., 0., -0., 0., -0., -0., 0., 0., 0., 0., 0., 0., 0., -0., -0., 0., -0.,
              0., -0., 0., 0., -0., 0., -0., -0., 0., 0., 0., 0., -0., -0., -0., 0., 0., -0., 0., -0., -0., 0., -0., 0.,
              0., 0., -0., -0., 0., -0., 0., 0., -0., -0., -0., -0., -0., 0., 0., 0., -0., 0., 0., 0., -0., 0., 0., -0.])
       scale: tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-0.0030, -0.0106, -0.0147,  0.0311, -0.0320,  0.0124, -0.0280,  0.0140,
               0.0088, -0.0196,  0.0488, -0.0001,  0.0003,  0.0401, -0.0014,  0.0427,
              -0.0403, -0.0008,  0.0406, -0.0468, -0.0297, -0.0245, -0.0370, -0.0124,
               0.0320, -0.0158, -0.0113, -0.0198,  0.0193,  0.0356, -0.0264, -0.0160,
               0.0050,  0.0121, -0.0498,  0.0146, -0.0372,  0.0089, -0.0298,  0.0399,
               0.0347, -0.0108,  0.0353, -0.0157, -0.0174, -0.0355,  0.0131,  0.0192,
               0.0432, -0.0373,  0.0332, -0.0114,  0.0318, -0.0132, -0.0002, -0.0403,
               0.0447, -0.0203,  0.0274, -0.0342, -0.0080,  0.0389,  0.0318,  0.0043,
               0.0192,  0.0158,  0.0490,  0.0272, -0.0142, -0.0218,  0.0353, -0.0035,
               0.0169, -0.0432,  0.0079,  0.0499, -0.0018,  0.0296, -0.0337, -0.0214,
               0.0376,  0.0054,  0.0384,  0.0403, -0.0050, -0.0075, -0.0203,  0.0318,
               0.0285, -0.0415,  0.0395, -0.0045, -0.0020,  0.0245, -0.0361,  0.0150,
               0.0347,  0.0185, -0.0093, -0.0056,  0.0021, -0.0026,  0.0046,  0.0380,
              -0.0403, -0.0422, -0.0218, -0.0198, -0.0446,  0.0296,  0.0276,  0.0089,
              -0.0049,  0.0100,  0.0118,  0.0283, -0.0334,  0.0287,  0.0236, -0.0404])
    )
    (observed): Observed()
  )
  (fc2): Linear(
    in_features=120, out_features=2, bias=True
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498]], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[ 5.3861e-02, -3.2810e-02,  1.8512e-02,  4.5986e-02,  5.1913e-02,
               -4.3799e-02,  5.6503e-03,  2.7175e-02,  6.8377e-02,  1.3057e-02,
               -3.2690e-02, -6.1064e-02,  3.8826e-02,  6.6544e-02,  5.5398e-02,
                3.0196e-02,  1.7987e-02,  5.5327e-02, -5.0021e-02, -7.3435e-02,
                6.4632e-02, -5.7208e-02, -4.0015e-02, -8.9137e-02, -2.3537e-02,
               -7.9152e-03,  2.1925e-02,  4.2662e-02,  2.8947e-06, -8.5279e-02,
               -7.4841e-03,  9.8418e-03, -7.7589e-02,  8.2463e-02,  4.1143e-02,
                3.5816e-03,  4.9777e-02, -4.8500e-02,  5.3974e-02, -6.0017e-02,
                6.7712e-02,  2.5674e-02,  6.1376e-02,  4.4900e-03, -6.7039e-02,
               -7.8965e-02, -6.7107e-02, -2.3485e-02, -1.5222e-02, -8.7112e-03,
               -7.6909e-02,  7.4443e-02, -2.8730e-02,  5.1316e-02,  1.8881e-04,
               -3.4971e-02, -7.5644e-02, -1.2047e-02, -5.6335e-02, -3.9648e-02,
               -5.9713e-03, -5.5715e-02,  5.2674e-02, -1.8986e-02,  1.1715e-02,
                7.8946e-02,  8.7757e-02,  7.8873e-02, -4.0747e-02, -9.0934e-02,
                8.4907e-02,  6.0053e-02,  6.8620e-02, -6.7833e-02,  4.9746e-02,
               -9.3147e-03,  4.5588e-02,  2.6278e-02,  7.8289e-02,  6.9648e-02,
               -1.8778e-02,  7.0376e-02,  4.2418e-02,  8.1578e-02, -2.5462e-03,
               -8.2635e-02, -3.2393e-02, -1.8944e-02,  5.2082e-02,  7.6844e-02,
                7.8650e-02,  1.0107e-02,  5.8002e-02,  5.6146e-02,  1.0183e-02,
                8.1826e-02,  2.2654e-02,  1.4227e-02,  6.7762e-02, -1.4747e-02,
               -1.7642e-02, -6.6754e-02, -3.3121e-02,  8.3696e-02, -1.6725e-02,
                5.1801e-02,  8.2761e-02, -1.6347e-03, -7.2732e-02, -8.5545e-02,
               -2.1219e-02, -8.6543e-02,  1.6206e-02, -3.9126e-02,  5.2650e-02,
                8.4007e-02,  6.5445e-03, -2.2124e-02, -2.6831e-03,  9.0644e-02],
              [-7.6951e-02, -7.7030e-02, -2.8385e-02,  8.8830e-02,  1.8159e-02,
                1.6277e-02, -8.6646e-02, -4.8097e-03, -7.2484e-02, -2.3892e-02,
               -3.8862e-02,  8.1929e-02,  9.8960e-03, -7.4610e-02,  4.4434e-02,
                9.1217e-02,  1.7353e-02, -4.1310e-02,  4.6109e-03, -6.0577e-03,
               -5.4844e-02, -8.3153e-02, -7.1516e-02, -3.4588e-03, -8.7724e-02,
               -2.7643e-02,  2.0528e-02, -3.5310e-02, -7.1740e-02,  9.2993e-03,
               -4.4849e-02, -2.9480e-02,  8.8236e-02,  6.5175e-03, -1.8974e-02,
               -6.3304e-02, -4.7579e-02, -7.5561e-02,  6.3819e-03,  6.4816e-02,
               -5.2397e-03,  4.9444e-02, -7.8613e-02,  3.2730e-02, -4.0234e-02,
                7.0352e-02,  9.4854e-03, -6.6504e-02, -5.6702e-02,  8.2212e-02,
               -1.0986e-03, -7.2008e-02, -2.7400e-02,  2.1016e-02,  3.5173e-02,
               -7.9398e-02, -4.5183e-02, -6.5127e-02, -3.0206e-02,  3.6873e-02,
               -4.1272e-02, -5.1494e-03, -7.6218e-03,  6.9450e-02,  3.0325e-02,
                1.2730e-02, -6.4263e-02,  8.1150e-02, -7.4637e-02,  4.3607e-02,
               -5.5105e-02,  7.7573e-02,  7.7680e-02,  4.4945e-02, -7.6075e-02,
                3.1183e-02,  4.1456e-02, -5.6365e-02, -6.4912e-02,  1.1328e-02,
               -6.9009e-02,  2.9723e-02,  3.2704e-02,  1.9641e-02, -2.5972e-02,
               -4.7365e-02, -6.7854e-02,  3.1308e-02, -4.7928e-02, -3.9339e-02,
                1.7182e-03, -7.1567e-02, -7.8225e-02,  8.7315e-02,  4.9131e-02,
               -8.9911e-02, -4.4713e-03, -3.8737e-02, -7.6371e-02,  3.4342e-02,
               -1.7424e-02,  2.5424e-02, -7.5859e-02,  8.9858e-02, -6.9163e-02,
               -7.2367e-02,  8.4212e-02,  5.0855e-02,  1.1414e-03,  7.7345e-02,
                5.0493e-02,  7.5050e-02, -2.1079e-02,  7.8312e-02,  7.6594e-03,
                4.9073e-02,  6.5474e-02,  2.7924e-02, -1.4094e-02,  4.2029e-02]],
             requires_grad=True)
       tensor: tensor([[ 0.0610, -0.0128, -0.0120,  0.0714,  0.0156,  0.0183,  0.0681, -0.0840,
                0.1129,  0.0066, -0.0570, -0.0925,  0.0874,  0.0386,  0.0270,  0.0112,
                0.0199,  0.0883, -0.0757, -0.0829,  0.0854, -0.1206, -0.0577, -0.0897,
                0.0741, -0.0573,  0.0606, -0.0012,  0.0103, -0.1070, -0.0114,  0.1001,
               -0.0395, -0.0006,  0.0756,  0.0616,  0.0758, -0.0650,  0.0724, -0.0838,
                0.1125,  0.0607,  0.0729, -0.0289, -0.1317, -0.1161, -0.0362,  0.0083,
               -0.0673,  0.0753, -0.0666,  0.0700, -0.0002,  0.0904,  0.0204, -0.0528,
               -0.1112,  0.0246, -0.0508, -0.0364, -0.0725, -0.0273,  0.0819,  0.0152,
               -0.0765,  0.0628,  0.0228, -0.0045,  0.0116, -0.1370,  0.0286,  0.0078,
                0.0368, -0.1036, -0.0161, -0.0211,  0.0430, -0.0009,  0.0646,  0.0536,
               -0.0243,  0.1421, -0.0034,  0.0893, -0.0097, -0.1771,  0.0357, -0.0548,
                0.0414,  0.0090,  0.0756, -0.0161,  0.0447,  0.0344,  0.0186,  0.1482,
               -0.0190,  0.0250,  0.0333, -0.1116, -0.0008, -0.0638,  0.0016,  0.1462,
                0.0523, -0.0028,  0.1291,  0.0462, -0.0976, -0.0499, -0.0413, -0.0380,
                0.0625,  0.0406, -0.0037,  0.0522, -0.0341, -0.0405, -0.0160,  0.1534],
              [-0.0298, -0.1157, -0.0818,  0.0838, -0.0245,  0.0392, -0.1349, -0.0704,
               -0.1346, -0.0168,  0.0534,  0.0690, -0.0314, -0.0291,  0.0875,  0.1339,
                0.0260, -0.1362,  0.0432,  0.0740, -0.0278, -0.0642, -0.1747,  0.0349,
               -0.0983,  0.0362, -0.0348, -0.1002, -0.0262,  0.0045, -0.1418, -0.0607,
                0.0709, -0.0468, -0.0267, -0.0706, -0.0858, -0.0905,  0.0190,  0.0397,
               -0.0106, -0.0135, -0.0531, -0.0055, -0.0194,  0.0866,  0.0335, -0.0849,
               -0.0636, -0.0231, -0.0386, -0.0960, -0.0768,  0.0532,  0.0323, -0.0124,
               -0.0476, -0.0436, -0.0623,  0.0220, -0.0545,  0.0014, -0.0546,  0.1407,
               -0.0416, -0.0167, -0.0326,  0.0189, -0.0989,  0.0821, -0.0453,  0.1176,
                0.0483,  0.1579, -0.1453, -0.0338,  0.1282, -0.0891, -0.0673, -0.0274,
               -0.1313,  0.0243, -0.0285, -0.0651, -0.0463,  0.0518, -0.0594, -0.0117,
               -0.0808,  0.0316,  0.0344, -0.0550, -0.0490,  0.0915,  0.0668, -0.1171,
                0.0758, -0.0164, -0.1157,  0.1160,  0.0045,  0.0031, -0.1212,  0.0278,
               -0.0541, -0.1058,  0.0086,  0.0400,  0.0132,  0.1076,  0.1130,  0.0112,
                0.0434,  0.0284,  0.0447,  0.0700,  0.0243, -0.0166, -0.0543,  0.0415]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-0.0687, -0.0438], requires_grad=True)
       tensor: tensor([-0.0407,  0.0235], grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[0., -0., 0., 0., 0., -0., 0., 0., 0., 0., -0., -0., 0., 0., 0., 0., 0., 0., -0., -0., 0., -0., -0., -0.,
               -0., -0., 0., 0., 0., -0., -0., 0., -0., 0., 0., 0., 0., -0., 0., -0., 0., 0., 0., 0., -0., -0., -0., -0.,
               -0., -0., -0., 0., -0., 0., 0., -0., -0., -0., -0., -0., -0., -0., 0., -0., 0., 0., 0., 0., -0., -0., 0., 0.,
               0., -0., 0., -0., 0., 0., 0., 0., -0., 0., 0., 0., -0., -0., -0., -0., 0., 0., 0., 0., 0., 0., 0., 0.,
               0., 0., 0., -0., -0., -0., -0., 0., -0., 0., 0., -0., -0., -0., -0., -0., 0., -0., 0., 0., 0., -0., -0., 0.],
              [-0., -0., -0., 0., 0., 0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., 0., -0., 0., -0., -0., -0., -0., -0.,
               -0., -0., 0., -0., -0., 0., -0., -0., 0., 0., -0., -0., -0., -0., 0., 0., -0., 0., -0., 0., -0., 0., 0., -0.,
               -0., 0., -0., -0., -0., 0., 0., -0., -0., -0., -0., 0., -0., -0., -0., 0., 0., 0., -0., 0., -0., 0., -0., 0.,
               0., 0., -0., 0., 0., -0., -0., 0., -0., 0., 0., 0., -0., -0., -0., 0., -0., -0., 0., -0., -0., 0., 0., -0.,
               -0., -0., -0., 0., -0., 0., -0., 0., -0., -0., 0., 0., 0., 0., 0., 0., -0., 0., 0., 0., 0., 0., -0., 0.]])
       scale: tensor([[0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913],
              [0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[ 5.3861e-02, -3.2810e-02,  1.8512e-02,  4.5986e-02,  5.1913e-02,
               -4.3799e-02,  5.6503e-03,  2.7175e-02,  6.8377e-02,  1.3057e-02,
               -3.2690e-02, -6.1064e-02,  3.8826e-02,  6.6544e-02,  5.5398e-02,
                3.0196e-02,  1.7987e-02,  5.5327e-02, -5.0021e-02, -7.3435e-02,
                6.4632e-02, -5.7208e-02, -4.0015e-02, -8.9137e-02, -2.3537e-02,
               -7.9152e-03,  2.1925e-02,  4.2662e-02,  2.8947e-06, -8.5279e-02,
               -7.4841e-03,  9.8418e-03, -7.7589e-02,  8.2463e-02,  4.1143e-02,
                3.5816e-03,  4.9777e-02, -4.8500e-02,  5.3974e-02, -6.0017e-02,
                6.7712e-02,  2.5674e-02,  6.1376e-02,  4.4900e-03, -6.7039e-02,
               -7.8965e-02, -6.7107e-02, -2.3485e-02, -1.5222e-02, -8.7112e-03,
               -7.6909e-02,  7.4443e-02, -2.8730e-02,  5.1316e-02,  1.8881e-04,
               -3.4971e-02, -7.5644e-02, -1.2047e-02, -5.6335e-02, -3.9648e-02,
               -5.9713e-03, -5.5715e-02,  5.2674e-02, -1.8986e-02,  1.1715e-02,
                7.8946e-02,  8.7757e-02,  7.8873e-02, -4.0747e-02, -9.0934e-02,
                8.4907e-02,  6.0053e-02,  6.8620e-02, -6.7833e-02,  4.9746e-02,
               -9.3147e-03,  4.5588e-02,  2.6278e-02,  7.8289e-02,  6.9648e-02,
               -1.8778e-02,  7.0376e-02,  4.2418e-02,  8.1578e-02, -2.5462e-03,
               -8.2635e-02, -3.2393e-02, -1.8944e-02,  5.2082e-02,  7.6844e-02,
                7.8650e-02,  1.0107e-02,  5.8002e-02,  5.6146e-02,  1.0183e-02,
                8.1826e-02,  2.2654e-02,  1.4227e-02,  6.7762e-02, -1.4747e-02,
               -1.7642e-02, -6.6754e-02, -3.3121e-02,  8.3696e-02, -1.6725e-02,
                5.1801e-02,  8.2761e-02, -1.6347e-03, -7.2732e-02, -8.5545e-02,
               -2.1219e-02, -8.6543e-02,  1.6206e-02, -3.9126e-02,  5.2650e-02,
                8.4007e-02,  6.5445e-03, -2.2124e-02, -2.6831e-03,  9.0644e-02],
              [-7.6951e-02, -7.7030e-02, -2.8385e-02,  8.8830e-02,  1.8159e-02,
                1.6277e-02, -8.6646e-02, -4.8097e-03, -7.2484e-02, -2.3892e-02,
               -3.8862e-02,  8.1929e-02,  9.8960e-03, -7.4610e-02,  4.4434e-02,
                9.1217e-02,  1.7353e-02, -4.1310e-02,  4.6109e-03, -6.0577e-03,
               -5.4844e-02, -8.3153e-02, -7.1516e-02, -3.4588e-03, -8.7724e-02,
               -2.7643e-02,  2.0528e-02, -3.5310e-02, -7.1740e-02,  9.2993e-03,
               -4.4849e-02, -2.9480e-02,  8.8236e-02,  6.5175e-03, -1.8974e-02,
               -6.3304e-02, -4.7579e-02, -7.5561e-02,  6.3819e-03,  6.4816e-02,
               -5.2397e-03,  4.9444e-02, -7.8613e-02,  3.2730e-02, -4.0234e-02,
                7.0352e-02,  9.4854e-03, -6.6504e-02, -5.6702e-02,  8.2212e-02,
               -1.0986e-03, -7.2008e-02, -2.7400e-02,  2.1016e-02,  3.5173e-02,
               -7.9398e-02, -4.5183e-02, -6.5127e-02, -3.0206e-02,  3.6873e-02,
               -4.1272e-02, -5.1494e-03, -7.6218e-03,  6.9450e-02,  3.0325e-02,
                1.2730e-02, -6.4263e-02,  8.1150e-02, -7.4637e-02,  4.3607e-02,
               -5.5105e-02,  7.7573e-02,  7.7680e-02,  4.4945e-02, -7.6075e-02,
                3.1183e-02,  4.1456e-02, -5.6365e-02, -6.4912e-02,  1.1328e-02,
               -6.9009e-02,  2.9723e-02,  3.2704e-02,  1.9641e-02, -2.5972e-02,
               -4.7365e-02, -6.7854e-02,  3.1308e-02, -4.7928e-02, -3.9339e-02,
                1.7182e-03, -7.1567e-02, -7.8225e-02,  8.7315e-02,  4.9131e-02,
               -8.9911e-02, -4.4713e-03, -3.8737e-02, -7.6371e-02,  3.4342e-02,
               -1.7424e-02,  2.5424e-02, -7.5859e-02,  8.9858e-02, -6.9163e-02,
               -7.2367e-02,  8.4212e-02,  5.0855e-02,  1.1414e-03,  7.7345e-02,
                5.0493e-02,  7.5050e-02, -2.1079e-02,  7.8312e-02,  7.6594e-03,
                4.9073e-02,  6.5474e-02,  2.7924e-02, -1.4094e-02,  4.2029e-02]])
      (bias): Normal:
       loc: tensor([-0., -0.])
       scale: tensor([0.7071, 0.7071])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-0.0687, -0.0438])
    )
    (observed): Observed()
  )
)

Fit the model

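Finally, we can set up the training loop. For each batch we condition the model on the labels with net.observe, draw fresh samples for the network's random variables with borch.sample, run the forward pass, and minimize the loss returned by infer.vi_loss. The KL term is scaled by 1 / len(loader) so that, summed over an epoch, it is counted once rather than once per batch.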

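optimizer = torch.optim.Adam(net.parameters())
for epoch in range(1):  # a single epoch; real data needs many more
    for data, target in loader:
        # condition the `classification` random variable on the labels
        net.observe(classification=target)
        # draw new samples for all random variables in the network
        borch.sample(net)
        net(data)
        # variational loss, with the KL term spread evenly over the batches
        loss = infer.vi_loss(**borch.pq_to_infer(net), kl_scaling=1 / len(loader))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()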
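To see why kl_scaling=1 / len(loader) is used, the plain-Python sketch below checks that per-batch losses of the form "data term + kl_scaling × KL" add up, over one epoch, to the full-data loss with the KL counted exactly once. It assumes that infer.vi_loss has this overall shape; borch's actual implementation may estimate the terms differently (for example from samples). The helpers normal_kl and batch_loss and all numbers are hypothetical, chosen only for illustration.

```python
import math


def normal_kl(mu_q, sigma_q, mu_p, sigma_p):
    """Closed-form KL(N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2))."""
    return (math.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
            - 0.5)


def batch_loss(batch_nll, kl, kl_scaling):
    """Hypothetical per-batch loss: negative log-likelihood plus scaled KL."""
    return batch_nll + kl_scaling * kl


# Made-up negative log-likelihoods for the 4 batches of one epoch.
batch_nlls = [3.0, 2.5, 4.0, 3.5]
n_batches = len(batch_nlls)

# KL between a Normal posterior and a zero-mean Normal prior for one weight.
kl = normal_kl(0.1, 0.05, 0.0, 0.0913)

# Summing the scaled per-batch losses recovers the full-data loss.
epoch_loss = sum(batch_loss(nll, kl, 1 / n_batches) for nll in batch_nlls)
full_data_loss = sum(batch_nlls) + kl

print(math.isclose(epoch_loss, full_data_loss))  # True
```

Without the scaling, the KL term would be added once per batch, which over-weights the prior by a factor of len(loader).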

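Now we can check the accuracy. Note that we first have to stop conditioning on the target by calling net.observe(None); otherwise the classification variable would stay fixed to the labels instead of being predicted by the network.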

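net.observe(None)  # stop conditioning on the labels
tot_acc = 0
with torch.no_grad():
    for data, target in loader:
        borch.sample(net)
        out = net(data)  # predicted class for each image
        # percentage of images whose predicted class matches the label
        acc = float((target == out).sum().float() / target.shape[0]) * 100
        tot_acc += acc
    tot_acc /= len(loader)  # mean accuracy over all batches
print(tot_acc)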

Out:

50.0

The accuracy is at chance level. This is expected: the images are white noise and carry no information about the labels, and with two equally frequent classes chance level is 50%.
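Because labels = torch.randperm(2).repeat(10) yields exactly ten images of each class, any constant prediction scores exactly 50%. The plain-Python sketch below (standing in for the torch code; accuracy is a hypothetical helper) checks this, and also shows that averaging per-batch accuracies, as the evaluation loop does, equals the overall accuracy when all batches have the same size.

```python
def accuracy(predictions, targets):
    """Percentage of positions where the prediction equals the target."""
    correct = sum(p == t for p, t in zip(predictions, targets))
    return 100.0 * correct / len(targets)


# Same layout as torch.randperm(2).repeat(10): a permutation of [0, 1], repeated.
targets = [1, 0] * 10

# A constant predictor is right on exactly half of a balanced data set.
print(accuracy([0] * 20, targets))  # 50.0
print(accuracy([1] * 20, targets))  # 50.0

# Mean of per-batch accuracies equals the overall accuracy for equal-size batches.
preds = [0, 0, 1, 1] * 5
batches = [(preds[i:i + 5], targets[i:i + 5]) for i in range(0, 20, 5)]
mean_of_batches = sum(accuracy(p, t) for p, t in batches) / len(batches)
print(mean_of_batches == accuracy(preds, targets))  # True
```

If the last batch is smaller than the others, the mean of batch accuracies can differ slightly from the overall accuracy.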

If you have trouble reaching a higher accuracy on real data, consider training for more epochs, setting up an augmentation pipeline (see the data loading tutorial), or changing the posterior. The posterior can be changed with net.apply and borch.set_posteriors, passing the posterior class from borch.posterior that you want to use:

net.apply(borch.set_posteriors(borch.posterior.Automatic))

Out:

Net(
  (posterior): Automatic()
  (prior): Module(
    (classification): Categorical:
     logits: tensor([[ 0.3925, -0.3818],
            [ 0.2547, -0.3470],
            [ 0.3052, -0.3358],
            [ 0.3430, -0.1107],
            [ 0.2112, -0.4776],
            [ 0.4108, -0.1150],
            [ 0.5736, -0.1190],
            [ 0.5299, -0.1948],
            [ 0.4381, -0.1530],
            [ 0.4474, -0.2609],
            [ 0.1968, -0.2226],
            [ 0.3060, -0.2168],
            [ 0.5697, -0.3060],
            [ 0.4261, -0.2987],
            [ 0.3829, -0.3391],
            [ 0.3738, -0.3848],
            [ 0.2025, -0.4219],
            [ 0.1772, -0.2161],
            [ 0.3705, -0.2277],
            [ 0.3645, -0.5349]])
     posterior: Automatic()
     prior: Module()
     observed: Observed()
     tensor: tensor([])
  )
  (observed): Observed()
  (conv1): Conv2d(
    1, 6, kernel_size=(5, 5), stride=(1, 1)
    (posterior): Automatic(
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-0., -0., -0., 0., 0., -0.], requires_grad=True)
       tensor: tensor([ 0.3744,  0.0923,  0.0835, -0.4157, -0.3875,  0.3015],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[[[-0., 0., 0., -0., -0.],
                [-0., -0., -0., -0., -0.],
                [-0., 0., -0., 0., 0.],
                [-0., 0., 0., -0., -0.],
                [-0., 0., -0., -0., -0.]]],


              [[[-0., 0., -0., -0., -0.],
                [0., -0., 0., -0., 0.],
                [0., -0., -0., 0., 0.],
                [0., 0., 0., 0., -0.],
                [-0., -0., -0., -0., 0.]]],


              [[[0., 0., -0., 0., -0.],
                [-0., -0., -0., -0., -0.],
                [0., -0., -0., 0., 0.],
                [0., 0., -0., -0., 0.],
                [0., -0., -0., -0., -0.]]],


              [[[0., 0., -0., -0., -0.],
                [-0., 0., 0., 0., -0.],
                [-0., -0., 0., 0., -0.],
                [-0., 0., 0., -0., -0.],
                [-0., 0., -0., -0., 0.]]],


              [[[0., 0., 0., 0., 0.],
                [0., -0., -0., -0., 0.],
                [-0., -0., 0., -0., 0.],
                [-0., -0., 0., 0., -0.],
                [0., 0., 0., 0., -0.]]],


              [[[-0., 0., 0., 0., -0.],
                [-0., -0., -0., 0., -0.],
                [-0., -0., -0., -0., 0.],
                [-0., 0., 0., -0., 0.],
                [-0., 0., 0., -0., -0.]]]])
       scale: tensor([[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[[[-0.1567,  0.0194,  0.0769, -0.0673, -0.0263],
                [-0.1055, -0.0411, -0.0689, -0.0195, -0.0636],
                [-0.1941,  0.1049, -0.0666,  0.1288,  0.1553],
                [-0.0577,  0.0011,  0.0145, -0.0518, -0.1235],
                [-0.1586,  0.1436, -0.1999, -0.0364, -0.0872]]],


              [[[-0.0281,  0.1596, -0.0365, -0.0633, -0.1375],
                [ 0.0499, -0.0787,  0.1923, -0.1312,  0.0520],
                [ 0.0351, -0.0756, -0.0147,  0.1359,  0.1740],
                [ 0.0815,  0.0631,  0.0970,  0.0557, -0.1079],
                [-0.0794, -0.0029, -0.1620, -0.1147,  0.0705]]],


              [[[ 0.1613,  0.0524, -0.1935,  0.0940, -0.0838],
                [-0.1790, -0.0222, -0.0098, -0.1600, -0.1524],
                [ 0.1950, -0.1447, -0.1897,  0.1429,  0.0565],
                [ 0.1752,  0.0891, -0.0210, -0.0461,  0.1512],
                [ 0.1282, -0.1387, -0.1025, -0.1408, -0.0373]]],


              [[[ 0.1134,  0.0826, -0.1147, -0.1228, -0.0040],
                [-0.1923,  0.1244,  0.1793,  0.0183, -0.0433],
                [-0.1840, -0.1691,  0.1184,  0.0151, -0.1467],
                [-0.1828,  0.0363,  0.1660, -0.0660, -0.1319],
                [-0.1644,  0.1835, -0.0681, -0.0800,  0.0668]]],


              [[[ 0.0005,  0.1977,  0.1792,  0.0681,  0.1410],
                [ 0.0081, -0.0216, -0.0456, -0.1985,  0.0049],
                [-0.0568, -0.1405,  0.0604, -0.1020,  0.0084],
                [-0.1587, -0.1756,  0.1148,  0.0872, -0.1307],
                [ 0.1244,  0.0193,  0.0314,  0.0787, -0.1329]]],


              [[[-0.1363,  0.0538,  0.0294,  0.0635, -0.0873],
                [-0.0751, -0.0357, -0.0701,  0.0053, -0.1896],
                [-0.0187, -0.0438, -0.0541, -0.1959,  0.0103],
                [-0.1991,  0.0912,  0.1853, -0.0772,  0.0883],
                [-0.1368,  0.0420,  0.0322, -0.1162, -0.0013]]]])
      (bias): Normal:
       loc: tensor([-0., -0., -0., 0., 0., -0.])
       scale: tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-0.1428, -0.1690, -0.0722,  0.0667,  0.0933, -0.0337])
    )
    (observed): Observed()
  )
  (conv2): Conv2d(
    6, 16, kernel_size=(5, 5), stride=(1, 1)
    (posterior): Automatic(
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
              0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([0., 0., -0., 0., 0., -0., 0., 0., -0., -0., 0., 0., 0., 0., 0., 0.],
             requires_grad=True)
       tensor: tensor([ 0.3320,  0.1212, -0.0444,  0.3015,  0.2103, -0.0697, -0.1125,  0.1532,
              -0.1150,  0.0043, -0.0299, -0.2236,  0.1937,  0.1273, -0.6395, -0.3135],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[[[-0., -0., -0., 0., 0.],
                [0., 0., 0., 0., -0.],
                [-0., 0., -0., 0., -0.],
                [0., 0., 0., 0., 0.],
                [-0., 0., -0., -0., -0.]],

               [[0., -0., 0., 0., 0.],
                [0., -0., 0., 0., -0.],
                [0., -0., 0., 0., -0.],
                [-0., 0., -0., 0., -0.],
                [0., -0., 0., -0., 0.]],

               [[-0., 0., 0., 0., -0.],
                [-0., -0., -0., -0., 0.],
                [0., 0., -0., 0., -0.],
                [0., -0., 0., 0., -0.],
                [-0., 0., -0., 0., -0.]],

               [[-0., 0., 0., 0., 0.],
                [0., -0., -0., -0., 0.],
                [-0., -0., -0., -0., -0.],
                [0., -0., -0., 0., -0.],
                [0., 0., 0., 0., 0.]],

               [[-0., 0., -0., -0., -0.],
                [-0., -0., -0., 0., -0.],
                [0., -0., -0., -0., 0.],
                [-0., 0., 0., -0., -0.],
                [-0., 0., 0., -0., -0.]],

               [[-0., -0., 0., 0., -0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., -0., 0., -0.],
                [0., 0., 0., -0., 0.],
                [-0., -0., 0., -0., -0.]]],


              [[[-0., 0., -0., -0., -0.],
                [-0., 0., 0., -0., -0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., 0., -0., -0.],
                [0., 0., 0., -0., -0.]],

               [[-0., -0., -0., 0., -0.],
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., -0.],
                [0., 0., -0., 0., -0.],
                [-0., -0., -0., -0., 0.]],

               [[-0., -0., -0., 0., 0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., 0., -0., 0.],
                [0., 0., 0., -0., -0.],
                [0., -0., -0., -0., -0.]],

               [[-0., -0., -0., -0., -0.],
                [-0., 0., 0., 0., -0.],
                [0., 0., -0., -0., -0.],
                [0., 0., 0., 0., 0.],
                [0., 0., 0., -0., -0.]],

               [[-0., -0., -0., -0., -0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., -0., -0., 0.],
                [-0., 0., -0., -0., -0.],
                [-0., -0., -0., 0., 0.]],

               [[0., 0., -0., -0., -0.],
                [0., -0., 0., -0., -0.],
                [0., 0., 0., 0., -0.],
                [-0., -0., 0., 0., -0.],
                [0., -0., -0., 0., -0.]]],


              [[[0., 0., -0., 0., 0.],
                [0., -0., 0., -0., -0.],
                [-0., -0., -0., -0., -0.],
                [-0., -0., 0., 0., 0.],
                [0., 0., -0., 0., 0.]],

               [[-0., 0., 0., -0., -0.],
                [0., -0., 0., -0., -0.],
                [-0., 0., -0., -0., -0.],
                [0., 0., 0., -0., 0.],
                [-0., 0., 0., -0., -0.]],

               [[0., 0., -0., 0., 0.],
                [0., -0., 0., 0., 0.],
                [-0., -0., -0., 0., -0.],
                [-0., -0., 0., 0., 0.],
                [0., -0., -0., -0., 0.]],

               [[-0., 0., 0., 0., 0.],
                [-0., 0., 0., -0., -0.],
                [-0., -0., -0., -0., -0.],
                [-0., 0., -0., -0., -0.],
                [-0., 0., 0., 0., -0.]],

               [[-0., -0., -0., -0., 0.],
                [0., 0., 0., -0., -0.],
                [0., 0., -0., -0., 0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., 0., 0., -0.]],

               [[-0., -0., -0., 0., -0.],
                [0., -0., -0., 0., -0.],
                [0., -0., -0., 0., 0.],
                [0., 0., -0., 0., 0.],
                [0., -0., 0., 0., -0.]]],


              ...,


              [[[0., -0., -0., -0., 0.],
                [-0., 0., -0., 0., 0.],
                [-0., -0., -0., -0., -0.],
                [-0., -0., 0., -0., -0.],
                [-0., -0., 0., 0., -0.]],

               [[0., -0., -0., -0., 0.],
                [-0., 0., 0., 0., 0.],
                [-0., -0., 0., 0., 0.],
                [-0., -0., -0., 0., -0.],
                [0., 0., 0., 0., -0.]],

               [[0., 0., -0., 0., -0.],
                [0., 0., -0., 0., 0.],
                [-0., 0., -0., 0., 0.],
                [0., -0., -0., 0., 0.],
                [-0., 0., -0., -0., 0.]],

               [[0., 0., 0., 0., -0.],
                [-0., -0., -0., -0., 0.],
                [-0., 0., 0., -0., -0.],
                [0., -0., 0., -0., 0.],
                [-0., -0., -0., -0., 0.]],

               [[-0., -0., -0., -0., 0.],
                [0., -0., 0., -0., 0.],
                [0., -0., 0., 0., 0.],
                [-0., -0., 0., -0., -0.],
                [0., -0., -0., 0., -0.]],

               [[0., -0., -0., 0., -0.],
                [0., 0., -0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [-0., -0., -0., 0., -0.],
                [-0., -0., 0., 0., 0.]]],


              [[[-0., -0., -0., -0., 0.],
                [0., -0., -0., 0., -0.],
                [-0., 0., 0., -0., -0.],
                [0., -0., 0., -0., -0.],
                [-0., -0., 0., -0., 0.]],

               [[0., -0., 0., 0., 0.],
                [-0., 0., 0., -0., -0.],
                [-0., -0., -0., -0., -0.],
                [-0., 0., 0., 0., 0.],
                [0., -0., 0., -0., 0.]],

               [[0., 0., -0., -0., -0.],
                [0., -0., -0., -0., -0.],
                [0., 0., 0., -0., -0.],
                [-0., 0., -0., -0., -0.],
                [0., 0., -0., 0., 0.]],

               [[0., -0., 0., 0., -0.],
                [0., -0., 0., 0., 0.],
                [-0., 0., 0., -0., 0.],
                [0., -0., 0., 0., -0.],
                [-0., 0., 0., 0., -0.]],

               [[0., 0., 0., 0., 0.],
                [-0., -0., -0., -0., -0.],
                [0., -0., -0., 0., 0.],
                [0., 0., 0., -0., -0.],
                [0., 0., -0., 0., -0.]],

               [[-0., -0., -0., -0., 0.],
                [0., -0., 0., 0., -0.],
                [-0., -0., -0., 0., -0.],
                [-0., 0., -0., 0., -0.],
                [0., -0., 0., 0., 0.]]],


              [[[-0., 0., -0., 0., -0.],
                [0., -0., 0., 0., -0.],
                [-0., 0., 0., 0., -0.],
                [0., -0., -0., -0., -0.],
                [-0., -0., -0., -0., -0.]],

               [[0., -0., -0., -0., 0.],
                [-0., 0., 0., 0., 0.],
                [0., -0., -0., -0., -0.],
                [-0., -0., 0., 0., -0.],
                [-0., -0., 0., 0., 0.]],

               [[0., -0., 0., 0., -0.],
                [-0., -0., 0., 0., -0.],
                [0., -0., 0., -0., 0.],
                [0., 0., 0., -0., -0.],
                [0., 0., 0., 0., -0.]],

               [[0., 0., 0., 0., 0.],
                [-0., -0., -0., 0., 0.],
                [-0., 0., -0., 0., -0.],
                [-0., -0., 0., 0., -0.],
                [0., 0., 0., -0., -0.]],

               [[-0., -0., -0., 0., -0.],
                [0., 0., -0., -0., -0.],
                [0., -0., 0., -0., 0.],
                [0., 0., -0., 0., -0.],
                [-0., 0., -0., -0., -0.]],

               [[0., -0., -0., 0., -0.],
                [-0., -0., -0., -0., -0.],
                [-0., -0., -0., 0., -0.],
                [-0., 0., -0., 0., 0.],
                [0., -0., 0., -0., 0.]]]])
       scale: tensor([[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              ...,


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[[[-1.6434e-02, -1.3113e-02, -4.1464e-02,  1.6377e-02,  6.2508e-02],
                [ 4.0202e-02,  4.9382e-02,  1.0253e-02,  3.4932e-02, -2.9341e-02],
                [-6.4074e-02,  5.2063e-02, -3.8546e-02,  1.5865e-02, -4.8314e-02],
                [ 2.0698e-02,  8.0888e-02,  7.9982e-02,  5.6686e-02,  1.1419e-02],
                [-6.1541e-02,  6.8319e-02, -7.6261e-02, -2.5828e-03, -5.7380e-02]],

               [[ 8.0672e-02, -7.8271e-02,  7.9289e-02,  2.1951e-02,  5.9238e-02],
                [ 1.1237e-02, -5.2598e-02,  4.1551e-02,  4.3546e-02, -5.8132e-02],
                [ 3.0288e-02, -7.5402e-03,  2.8025e-02,  1.5065e-02, -3.7260e-02],
                [-2.1855e-02,  4.4424e-02, -6.7250e-02,  4.3303e-02, -4.0667e-02],
                [ 4.9321e-02, -2.0541e-02,  5.6232e-02, -3.8466e-02,  5.1026e-02]],

               [[-2.6513e-02,  4.9569e-02,  1.8676e-02,  6.9054e-02, -2.6767e-02],
                [-5.7368e-02, -7.8838e-03, -7.9114e-02, -1.8557e-02,  2.8787e-04],
                [ 1.3084e-02,  4.6926e-02, -1.3894e-02,  2.7116e-02, -7.9155e-02],
                [ 7.5190e-02, -2.2396e-02,  5.5935e-02,  6.8270e-02, -3.2401e-02],
                [-1.0770e-02,  3.6370e-02, -7.0476e-02,  6.4846e-02, -6.7540e-02]],

               [[-3.7310e-02,  5.9737e-02,  5.9160e-05,  3.5063e-02,  7.1916e-02],
                [ 1.7257e-02, -6.6353e-03, -6.3498e-03, -1.9595e-02,  4.9856e-02],
                [-8.0101e-02, -4.1773e-02, -1.9761e-02, -1.6521e-02, -3.6788e-02],
                [ 9.3012e-04, -2.4376e-02, -1.1352e-03,  1.1929e-02, -1.3466e-02],
                [ 5.6119e-02,  2.2037e-02,  6.6074e-02,  3.7063e-02,  4.2302e-02]],

               [[-3.6769e-02,  3.7521e-02, -9.7236e-03, -6.8646e-02, -3.6374e-02],
                [-2.8501e-02, -8.0159e-02, -2.8122e-02,  1.8005e-02, -3.7241e-02],
                [ 4.2443e-02, -4.8021e-02, -3.0217e-02, -5.2788e-02,  2.2794e-02],
                [-6.3064e-02,  6.8354e-02,  4.3337e-02, -4.3442e-02, -6.5524e-02],
                [-5.6791e-02,  7.4918e-02,  5.9263e-02, -1.2248e-04, -3.5795e-02]],

               [[-2.1802e-02, -9.4069e-03,  7.6862e-02,  1.4458e-02, -7.0315e-02],
                [-5.9874e-02,  5.0751e-02,  3.9215e-02,  7.7538e-02,  3.2766e-02],
                [-4.4479e-02,  5.2775e-02, -1.9053e-03,  3.2342e-02, -4.6321e-02],
                [ 3.8624e-02,  2.0750e-02,  4.6860e-02, -5.3962e-02,  6.4275e-02],
                [-6.5864e-02, -1.3428e-02,  1.7864e-02, -7.0872e-02, -4.7788e-02]]],


              [[[-6.8434e-03,  1.4991e-02, -4.0166e-02, -4.2922e-02, -3.7892e-02],
                [-4.6419e-03,  7.9605e-02,  2.4481e-02, -1.8153e-02, -6.5677e-02],
                [-4.4087e-02, -3.3966e-02,  7.7397e-02, -2.4133e-02, -2.1281e-02],
                [ 7.2108e-02,  5.7730e-02,  1.8671e-02, -6.6956e-02, -7.8317e-02],
                [ 5.1377e-02,  4.0833e-02,  1.3915e-02, -7.3671e-02, -6.4171e-02]],

               [[-2.2992e-02, -1.5135e-02, -1.6468e-02,  3.4729e-02, -5.0686e-02],
                [ 3.3976e-02,  4.7649e-02,  9.0078e-03,  6.1119e-02,  7.7885e-02],
                [ 2.0299e-02,  4.7323e-02,  5.7702e-02,  4.8393e-02, -4.6623e-02],
                [ 4.0535e-02,  5.1355e-02, -3.0590e-02,  7.3194e-02, -6.9639e-02],
                [-2.0004e-02, -3.7198e-02, -7.3253e-02, -3.4263e-02,  1.5673e-02]],

               [[-2.4125e-02, -2.1872e-02, -4.1021e-02,  1.5590e-02,  7.0267e-02],
                [-1.2325e-02,  4.8418e-02,  3.7418e-02,  5.6973e-02,  1.5516e-02],
                [-5.0112e-02,  4.1789e-02,  5.5392e-02, -3.5548e-02,  2.8206e-02],
                [ 5.9003e-02,  7.3764e-03,  1.5419e-02, -1.6909e-02, -4.0654e-02],
                [ 4.1070e-02, -2.2652e-02, -4.4021e-02, -8.1407e-03, -7.5206e-02]],

               [[-6.3301e-02, -6.5342e-02, -3.3752e-03, -6.6840e-02, -1.8425e-02],
                [-1.6499e-02,  5.8059e-02,  3.5353e-02,  9.1365e-03, -2.8343e-02],
                [ 4.6293e-02,  3.0543e-02, -7.3024e-02, -6.8207e-02, -2.7875e-02],
                [ 8.0904e-02,  1.1077e-02,  4.2119e-02,  5.4343e-02,  2.1999e-02],
                [ 6.2393e-02,  2.9035e-02,  2.6746e-02, -2.8318e-02, -4.9989e-02]],

               [[-6.2379e-02, -4.8560e-02, -5.2721e-02, -3.9246e-02, -6.8517e-03],
                [-7.4939e-03,  9.3801e-03,  1.3410e-02,  5.7410e-02,  4.4898e-03],
                [-1.7655e-02,  3.3112e-02, -6.4055e-02, -1.3577e-02,  6.3291e-02],
                [-3.9472e-02,  1.4227e-02, -4.6944e-02, -1.8557e-02, -3.4740e-02],
                [-6.3974e-02, -6.9448e-02, -2.1668e-02,  2.4177e-02,  3.0717e-02]],

               [[ 2.8530e-02,  2.3247e-02, -3.8633e-02, -5.7804e-02, -1.9294e-02],
                [ 5.4420e-02, -1.2094e-02,  2.1143e-02, -6.5897e-02, -2.8639e-02],
                [ 4.7260e-02,  4.2903e-02,  7.4266e-02,  6.9400e-02, -6.4887e-02],
                [-7.5132e-02, -5.4750e-02,  4.6103e-03,  3.0465e-02, -6.0162e-02],
                [ 3.2707e-02, -3.7524e-02, -6.7505e-02,  2.1123e-02, -1.5651e-02]]],


              [[[ 6.0358e-02,  5.9896e-02, -3.1081e-02,  5.8683e-02,  1.5452e-02],
                [ 6.0256e-02, -7.1520e-02,  7.8586e-02, -3.8772e-02, -7.1890e-02],
                [-5.1354e-02, -2.9084e-03, -3.9233e-02, -2.1499e-02, -2.9419e-02],
                [-3.2572e-02, -6.9616e-02,  2.9291e-02,  7.2235e-02,  2.5144e-02],
                [ 7.6527e-02,  3.1913e-03, -1.8299e-02,  1.5759e-02,  6.3982e-02]],

               [[-6.2217e-02,  2.4136e-02,  6.5684e-02, -5.2996e-02, -8.5318e-03],
                [ 5.4878e-02, -7.0118e-02,  5.6222e-02, -5.8217e-02, -1.2457e-02],
                [-9.0102e-03,  4.0819e-02, -5.2410e-02, -4.3693e-02, -8.1261e-03],
                [ 3.9352e-02,  4.2597e-02,  6.4178e-02, -1.6116e-02,  4.4007e-02],
                [-4.6907e-02,  5.0872e-02,  1.4034e-02, -7.7642e-02, -3.1652e-02]],

               [[ 1.8691e-02,  3.9128e-02, -1.8538e-03,  6.9222e-02,  4.7985e-02],
                [ 2.5163e-02, -3.2308e-02,  5.8934e-02,  6.4200e-02,  7.5079e-02],
                [-3.8752e-02, -6.2834e-02, -1.3630e-02,  4.7745e-02, -3.6710e-02],
                [-6.5912e-02, -7.5509e-02,  2.0538e-03,  6.1806e-02,  4.7332e-02],
                [ 5.2663e-02, -6.0765e-02, -1.1656e-02, -1.2399e-02,  6.5297e-02]],

               [[-4.0377e-02,  2.2776e-02,  2.0396e-03,  6.3307e-02,  7.7342e-02],
                [-8.6686e-03,  6.8417e-02,  4.9833e-02, -5.8394e-02, -6.8530e-02],
                [-6.4711e-02, -6.4908e-02, -3.2846e-02, -3.7337e-02, -3.2760e-02],
                [-7.8387e-02,  1.2714e-03, -3.3095e-02, -1.5624e-03, -1.5552e-02],
                [-6.5617e-02,  5.9709e-02,  7.6255e-02,  7.4220e-02, -3.9595e-02]],

               [[-2.4471e-02, -1.4723e-02, -4.3525e-03, -5.7851e-02,  1.1639e-02],
                [ 4.5532e-02,  4.3314e-02,  2.7463e-02, -3.2127e-02, -6.0824e-02],
                [ 4.7108e-02,  1.2112e-02, -1.6862e-02, -5.4160e-02,  4.8685e-03],
                [-2.5893e-02,  6.4832e-02,  3.3282e-02,  4.9884e-02,  6.3713e-02],
                [-1.9860e-02,  1.2712e-02,  7.0452e-04,  4.6135e-02, -5.4728e-02]],

               [[-5.1852e-02, -1.5589e-02, -3.1799e-02,  4.5747e-02, -2.9827e-02],
                [ 6.4932e-02, -5.3074e-02, -4.9272e-02,  1.8426e-02, -4.6095e-02],
                [ 4.1712e-02, -7.9372e-02, -7.9577e-02,  1.2126e-02,  4.2022e-02],
                [ 5.8650e-02,  1.3046e-02, -5.0546e-02,  7.1611e-02,  5.8748e-02],
                [ 4.8559e-02, -8.1399e-02,  6.8672e-02,  7.3071e-02, -4.6508e-02]]],


              ...,


              [[[ 7.2944e-02, -5.6815e-02, -7.2140e-02, -8.0878e-02,  8.1223e-02],
                [-2.6174e-02,  4.4648e-02, -1.7627e-05,  6.8991e-02,  9.3131e-04],
                [-1.9168e-02, -1.1712e-02, -3.9127e-02, -6.5451e-02, -1.9835e-02],
                [-2.9851e-02, -7.2093e-02,  6.1742e-03, -6.0501e-02, -3.0240e-02],
                [-3.6711e-02, -3.9918e-02,  1.8570e-02,  2.7867e-02, -3.9091e-02]],

               [[ 4.4294e-02, -6.2949e-02, -4.9712e-02, -6.0654e-02,  2.6511e-02],
                [-1.1918e-02,  1.9399e-02,  1.9778e-03,  4.6715e-02,  7.9662e-02],
                [-3.9779e-02, -3.2971e-02,  2.6502e-03,  6.0599e-02,  6.1761e-02],
                [-7.2092e-02, -3.8731e-02, -1.5203e-02,  3.3408e-02, -7.3232e-02],
                [ 2.1354e-02,  4.9467e-02,  6.6561e-02,  6.4517e-02, -4.9400e-02]],

               [[ 2.5279e-02,  7.5811e-02, -5.6423e-02,  7.1795e-03, -8.0469e-02],
                [ 6.3054e-02,  1.5441e-02, -7.9545e-02,  6.0103e-02,  4.7542e-02],
                [-2.6974e-02,  6.4899e-02, -7.6267e-02,  3.2200e-02,  9.7143e-03],
                [ 4.1850e-02, -1.8550e-02, -5.0626e-02,  3.7149e-02,  7.6128e-02],
                [-2.8989e-02,  2.5409e-03, -3.2850e-02, -1.1957e-02,  4.2580e-04]],

               [[ 4.6736e-02,  7.7891e-02,  2.2977e-02,  3.5759e-02, -7.9195e-02],
                [-4.3826e-02, -7.9846e-02, -5.2120e-02, -3.8209e-03,  2.4057e-02],
                [-6.7396e-03,  2.7530e-02,  1.1896e-03, -1.6895e-02, -5.0218e-02],
                [ 5.6456e-02, -7.6683e-02,  2.4498e-02, -5.4710e-02,  5.6294e-02],
                [-3.0637e-03, -2.6177e-02, -3.8865e-02, -3.9652e-02,  4.4595e-02]],

               [[-8.0799e-02, -7.9691e-02, -2.4048e-02, -6.6943e-02,  5.5213e-02],
                [ 1.1116e-02, -3.8443e-02,  4.3369e-02, -7.6902e-02,  5.1385e-02],
                [ 5.6263e-02, -5.4902e-02,  8.0991e-02,  1.6011e-02,  6.6421e-02],
                [-2.4895e-02, -4.4881e-02,  4.6953e-02, -4.1781e-02, -4.2947e-02],
                [ 8.0550e-02, -7.2696e-02, -4.6141e-02,  6.7832e-03, -1.6691e-03]],

               [[ 2.6609e-02, -3.9203e-02, -7.8157e-03,  2.2936e-04, -2.7554e-02],
                [ 4.0520e-02,  1.1102e-02, -2.2165e-02,  6.4671e-02,  1.1872e-02],
                [ 2.5477e-02,  3.2211e-02,  5.6317e-02,  5.1697e-02,  5.5899e-02],
                [-3.0296e-02, -3.9487e-02, -2.5797e-02,  5.7478e-02, -4.8781e-03],
                [-6.3375e-02, -4.3827e-02,  3.5311e-03,  4.7217e-02,  6.8362e-02]]],


              [[[-4.5381e-02, -7.7842e-02, -6.9001e-02, -7.6422e-03,  6.8520e-02],
                [ 2.3377e-02, -5.9736e-03, -6.8239e-02,  7.2911e-02, -6.6242e-02],
                [-1.5282e-02,  1.7386e-02,  3.9979e-02, -6.8327e-03, -1.7662e-03],
                [ 5.4649e-02, -4.8377e-03,  7.7069e-02, -8.0424e-02, -2.7894e-02],
                [-6.3750e-02, -2.7770e-02,  5.7462e-02, -1.8159e-02,  5.8960e-02]],

               [[ 1.5038e-02, -8.0078e-02,  1.0708e-02,  2.2493e-02,  2.2514e-02],
                [-2.7322e-02,  4.5916e-02,  7.1295e-02, -5.6998e-02, -5.2429e-02],
                [-2.4198e-02, -4.0081e-02, -7.5517e-02, -6.0738e-02, -1.9848e-02],
                [-8.0915e-02,  1.1733e-02,  7.0872e-02,  4.2211e-02,  3.7455e-03],
                [ 5.6451e-02, -2.0291e-02,  5.9699e-02, -3.8810e-02,  9.7062e-03]],

               [[ 7.0948e-02,  7.7596e-02, -5.9511e-02, -2.7747e-03, -2.9197e-02],
                [ 5.6304e-02, -5.9313e-02, -3.6894e-03, -3.4498e-02, -3.1743e-02],
                [ 6.2984e-02,  7.1278e-02,  1.8568e-02, -8.1057e-02, -7.4301e-02],
                [-1.7063e-02,  3.7341e-02, -1.7987e-02, -6.2014e-03, -1.3535e-02],
                [ 3.3733e-02,  3.2608e-02, -1.8692e-02,  6.1727e-02,  1.0257e-02]],

               [[ 4.3113e-04, -6.9241e-02,  2.2611e-02,  4.1913e-02, -6.6395e-02],
                [ 7.5128e-03, -7.3346e-02,  8.0353e-02,  1.2347e-02,  5.5333e-02],
                [-9.7800e-03,  4.5897e-02,  2.8835e-02, -3.6708e-02,  3.9655e-02],
                [ 2.7716e-02, -7.1659e-02,  7.1108e-03,  1.1511e-02, -4.8559e-02],
                [-3.0865e-02,  7.5560e-02,  2.8310e-02,  7.4005e-02, -5.0888e-02]],

               [[ 7.5087e-02,  6.3344e-02,  5.9466e-02,  1.0437e-02,  9.3939e-03],
                [-1.4452e-03, -5.0765e-02, -3.6996e-02, -6.8923e-02, -7.4329e-02],
                [ 1.1036e-02, -2.6916e-02, -6.9722e-02,  5.9740e-02,  4.6108e-02],
                [ 2.0379e-02,  3.6167e-02,  4.8153e-02, -3.0691e-02, -5.5250e-02],
                [ 3.5924e-02,  4.5421e-02, -4.7335e-02,  6.4587e-02, -5.7064e-02]],

               [[-1.6970e-03, -7.8021e-02, -6.0369e-02, -8.0641e-02,  7.1452e-02],
                [ 1.6848e-02, -7.5881e-02,  2.5285e-02,  2.5364e-02, -1.0818e-02],
                [-3.0854e-02, -2.4429e-02, -6.4815e-02,  8.1414e-03, -7.9674e-02],
                [-6.2038e-02,  7.4582e-02, -1.7759e-02,  2.3795e-02, -1.5795e-02],
                [ 3.7823e-02, -3.3319e-04,  7.1363e-03,  7.7572e-02,  4.3771e-02]]],


              [[[-4.8656e-03,  4.3062e-02, -3.2547e-02,  9.7140e-03, -5.3167e-02],
                [ 4.2759e-02, -4.1656e-02,  6.4357e-02,  3.5642e-02, -7.8376e-02],
                [-1.2937e-02,  6.4533e-02,  1.5182e-02,  1.1444e-02, -7.4220e-02],
                [ 6.3483e-02, -1.1542e-02, -4.0774e-02, -1.2172e-02, -2.7794e-02],
                [-8.1438e-03, -5.5991e-02, -2.9966e-02, -8.0014e-03, -5.2937e-02]],

               [[ 9.9251e-03, -2.7150e-02, -1.5934e-02, -3.4809e-02,  2.2487e-02],
                [-2.9249e-03,  6.8871e-02,  4.3621e-03,  2.6227e-02,  4.3713e-02],
                [ 8.1283e-02, -3.1387e-02, -6.9915e-02, -1.7858e-02, -2.1714e-02],
                [-3.5359e-02, -1.3766e-02,  3.6173e-02,  9.1202e-03, -3.9747e-02],
                [-7.2135e-02, -7.3420e-02,  6.0504e-02,  3.1594e-02,  7.6891e-02]],

               [[ 7.2759e-02, -6.5420e-02,  6.7763e-02,  7.2741e-02, -7.4671e-02],
                [-5.5163e-02, -7.5269e-02,  1.3287e-02,  1.8645e-02, -3.4054e-02],
                [ 6.5525e-02, -4.1262e-03,  4.6500e-02, -6.6291e-02,  5.8884e-02],
                [ 3.0486e-02,  3.6131e-03,  1.1222e-02, -3.3646e-02, -6.5889e-02],
                [ 4.7762e-02,  3.6352e-02,  9.7470e-03,  7.7495e-03, -5.5064e-02]],

               [[ 4.2110e-02,  5.1736e-02,  4.9755e-02,  1.8245e-02,  3.1093e-02],
                [-5.8074e-02, -4.1158e-02, -5.9566e-03,  6.2394e-02,  1.6582e-02],
                [-7.2003e-02,  1.4616e-02, -3.5987e-02,  3.0575e-02, -4.4705e-02],
                [-4.7500e-02, -1.9091e-02,  1.2661e-02,  2.4751e-02, -7.1824e-02],
                [ 4.3771e-02,  4.9023e-02,  7.2368e-02, -2.3195e-02, -3.0777e-02]],

               [[-8.2526e-03, -1.3523e-02, -6.9580e-02,  2.5552e-02, -1.5779e-02],
                [ 2.2318e-03,  2.7111e-02, -8.7496e-03, -2.3582e-02, -6.8521e-02],
                [ 7.4568e-02, -4.6680e-02,  7.4333e-02, -6.5834e-02,  8.0266e-02],
                [ 1.0070e-02,  5.4708e-02, -1.4732e-03,  1.9077e-02, -2.5033e-02],
                [-2.7357e-02,  1.9236e-02, -4.7921e-02, -5.5013e-02, -7.4643e-02]],

               [[ 4.0488e-02, -7.1390e-02, -2.3527e-02,  1.3764e-02, -1.5115e-02],
                [-3.7438e-02, -7.9287e-02, -6.0580e-02, -3.2224e-02, -2.1884e-02],
                [-7.0937e-02, -5.7632e-02, -1.2339e-02,  5.2566e-02, -5.8696e-02],
                [-6.4373e-02,  5.0876e-02, -6.8186e-02,  6.8750e-02,  4.6615e-02],
                [ 3.0661e-02, -6.8377e-02,  1.7900e-02, -8.8543e-03,  1.4958e-02]]]])
      (bias): Normal:
       loc: tensor([0., 0., -0., 0., 0., -0., 0., 0., -0., -0., 0., 0., 0., 0., 0., 0.])
       scale: tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
              0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([ 0.0482,  0.0700, -0.0483,  0.0543,  0.0427, -0.0160,  0.0652,  0.0702,
              -0.0693, -0.0685,  0.0211,  0.0540,  0.0356,  0.0235,  0.0208,  0.0305])
    )
    (observed): Observed()
  )
  (fc1): Linear(
    in_features=400, out_features=120, bias=True
    (posterior): Automatic(
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-0., -0., -0., 0., -0., 0., -0., 0., 0., -0., 0., -0., 0., 0., -0., 0., -0., -0., 0., -0., -0., -0., -0., -0.,
              0., -0., -0., -0., 0., 0., -0., -0., 0., 0., -0., 0., -0., 0., -0., 0., 0., -0., 0., -0., -0., -0., 0., 0.,
              0., -0., 0., -0., 0., -0., -0., -0., 0., -0., 0., -0., -0., 0., 0., 0., 0., 0., 0., 0., -0., -0., 0., -0.,
              0., -0., 0., 0., -0., 0., -0., -0., 0., 0., 0., 0., -0., -0., -0., 0., 0., -0., 0., -0., -0., 0., -0., 0.,
              0., 0., -0., -0., 0., -0., 0., 0., -0., -0., -0., -0., -0., 0., 0., 0., -0., 0., 0., 0., -0., 0., 0., -0.],
             requires_grad=True)
       tensor: tensor([ 0.0598,  0.0720, -0.2056, -0.0677,  0.0583,  0.0021,  0.0079,  0.0343,
               0.1597,  0.1023, -0.0047,  0.0677, -0.0050, -0.0228,  0.0133,  0.1010,
               0.0707, -0.1196,  0.0009,  0.2002,  0.0956,  0.0229,  0.0240, -0.0367,
               0.0650,  0.2132,  0.0634, -0.0123,  0.0125,  0.0665, -0.0738,  0.0954,
               0.0280, -0.1036, -0.1214, -0.0371,  0.0403,  0.0539, -0.0327, -0.0039,
               0.0645,  0.0404,  0.0437, -0.1351, -0.0247,  0.1724, -0.0036, -0.1897,
              -0.0021,  0.1507,  0.0524,  0.1568,  0.0330,  0.0332,  0.1069, -0.0358,
              -0.0007, -0.0262,  0.0324, -0.0762,  0.1193, -0.0476, -0.0853,  0.0407,
              -0.0146,  0.0540, -0.0463,  0.0140,  0.1053, -0.1393,  0.0070,  0.0167,
               0.0542,  0.1323,  0.0909,  0.0447,  0.0251, -0.0830, -0.0047, -0.0313,
              -0.0959, -0.0684, -0.1227,  0.1241, -0.0781,  0.0037, -0.0389,  0.1310,
              -0.0885,  0.0379,  0.0523,  0.0267, -0.0916,  0.1664,  0.1827, -0.0398,
              -0.0645,  0.1336,  0.0071, -0.0911, -0.0304,  0.1082,  0.1484, -0.0560,
              -0.0540,  0.0596,  0.1664,  0.0995,  0.1534, -0.0178, -0.0171, -0.0267,
              -0.0456,  0.0378, -0.0477,  0.0744,  0.1072, -0.1904, -0.0535,  0.1517],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[-0., 0., -0.,  ..., -0., 0., 0.],
              [0., 0., -0.,  ..., -0., 0., 0.],
              [0., -0., 0.,  ..., 0., -0., -0.],
              ...,
              [-0., -0., 0.,  ..., 0., -0., 0.],
              [0., 0., 0.,  ..., -0., 0., -0.],
              [-0., 0., 0.,  ..., 0., 0., 0.]])
       scale: tensor([[0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              ...,
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[-0.0222,  0.0364, -0.0196,  ..., -0.0034,  0.0235,  0.0133],
              [ 0.0133,  0.0200, -0.0305,  ..., -0.0283,  0.0068,  0.0075],
              [ 0.0367, -0.0009,  0.0018,  ...,  0.0404, -0.0319, -0.0467],
              ...,
              [-0.0015, -0.0323,  0.0344,  ...,  0.0226, -0.0282,  0.0235],
              [ 0.0195,  0.0296,  0.0472,  ..., -0.0476,  0.0057, -0.0356],
              [-0.0482,  0.0452,  0.0364,  ...,  0.0029,  0.0322,  0.0217]])
      (bias): Normal:
       loc: tensor([-0., -0., -0., 0., -0., 0., -0., 0., 0., -0., 0., -0., 0., 0., -0., 0., -0., -0., 0., -0., -0., -0., -0., -0.,
              0., -0., -0., -0., 0., 0., -0., -0., 0., 0., -0., 0., -0., 0., -0., 0., 0., -0., 0., -0., -0., -0., 0., 0.,
              0., -0., 0., -0., 0., -0., -0., -0., 0., -0., 0., -0., -0., 0., 0., 0., 0., 0., 0., 0., -0., -0., 0., -0.,
              0., -0., 0., 0., -0., 0., -0., -0., 0., 0., 0., 0., -0., -0., -0., 0., 0., -0., 0., -0., -0., 0., -0., 0.,
              0., 0., -0., -0., 0., -0., 0., 0., -0., -0., -0., -0., -0., 0., 0., 0., -0., 0., 0., 0., -0., 0., 0., -0.])
       scale: tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-0.0030, -0.0106, -0.0147,  0.0311, -0.0320,  0.0124, -0.0280,  0.0140,
               0.0088, -0.0196,  0.0488, -0.0001,  0.0003,  0.0401, -0.0014,  0.0427,
              -0.0403, -0.0008,  0.0406, -0.0468, -0.0297, -0.0245, -0.0370, -0.0124,
               0.0320, -0.0158, -0.0113, -0.0198,  0.0193,  0.0356, -0.0264, -0.0160,
               0.0050,  0.0121, -0.0498,  0.0146, -0.0372,  0.0089, -0.0298,  0.0399,
               0.0347, -0.0108,  0.0353, -0.0157, -0.0174, -0.0355,  0.0131,  0.0192,
               0.0432, -0.0373,  0.0332, -0.0114,  0.0318, -0.0132, -0.0002, -0.0403,
               0.0447, -0.0203,  0.0274, -0.0342, -0.0080,  0.0389,  0.0318,  0.0043,
               0.0192,  0.0158,  0.0490,  0.0272, -0.0142, -0.0218,  0.0353, -0.0035,
               0.0169, -0.0432,  0.0079,  0.0499, -0.0018,  0.0296, -0.0337, -0.0214,
               0.0376,  0.0054,  0.0384,  0.0403, -0.0050, -0.0075, -0.0203,  0.0318,
               0.0285, -0.0415,  0.0395, -0.0045, -0.0020,  0.0245, -0.0361,  0.0150,
               0.0347,  0.0185, -0.0093, -0.0056,  0.0021, -0.0026,  0.0046,  0.0380,
              -0.0403, -0.0422, -0.0218, -0.0198, -0.0446,  0.0296,  0.0276,  0.0089,
              -0.0049,  0.0100,  0.0118,  0.0283, -0.0334,  0.0287,  0.0236, -0.0404])
    )
    (observed): Observed()
  )
  (fc2): Linear(
    in_features=120, out_features=2, bias=True
    (posterior): Automatic(
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.7071, 0.7071], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-0., -0.], requires_grad=True)
       tensor: tensor([0.3527, 0.0185], grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[0., -0., 0., 0., 0., -0., 0., 0., 0., 0., -0., -0., 0., 0., 0., 0., 0., 0., -0., -0., 0., -0., -0., -0.,
               -0., -0., 0., 0., 0., -0., -0., 0., -0., 0., 0., 0., 0., -0., 0., -0., 0., 0., 0., 0., -0., -0., -0., -0.,
               -0., -0., -0., 0., -0., 0., 0., -0., -0., -0., -0., -0., -0., -0., 0., -0., 0., 0., 0., 0., -0., -0., 0., 0.,
               0., -0., 0., -0., 0., 0., 0., 0., -0., 0., 0., 0., -0., -0., -0., -0., 0., 0., 0., 0., 0., 0., 0., 0.,
               0., 0., 0., -0., -0., -0., -0., 0., -0., 0., 0., -0., -0., -0., -0., -0., 0., -0., 0., 0., 0., -0., -0., 0.],
              [-0., -0., -0., 0., 0., 0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., 0., -0., 0., -0., -0., -0., -0., -0.,
               -0., -0., 0., -0., -0., 0., -0., -0., 0., 0., -0., -0., -0., -0., 0., 0., -0., 0., -0., 0., -0., 0., 0., -0.,
               -0., 0., -0., -0., -0., 0., 0., -0., -0., -0., -0., 0., -0., -0., -0., 0., 0., 0., -0., 0., -0., 0., -0., 0.,
               0., 0., -0., 0., 0., -0., -0., 0., -0., 0., 0., 0., -0., -0., -0., 0., -0., -0., 0., -0., -0., 0., 0., -0.,
               -0., -0., -0., 0., -0., 0., -0., 0., -0., -0., 0., 0., 0., 0., 0., 0., -0., 0., 0., 0., 0., 0., -0., 0.]])
       scale: tensor([[0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913],
              [0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[ 5.3861e-02, -3.2810e-02,  1.8512e-02,  4.5986e-02,  5.1913e-02,
               -4.3799e-02,  5.6503e-03,  2.7175e-02,  6.8377e-02,  1.3057e-02,
               -3.2690e-02, -6.1064e-02,  3.8826e-02,  6.6544e-02,  5.5398e-02,
                3.0196e-02,  1.7987e-02,  5.5327e-02, -5.0021e-02, -7.3435e-02,
                6.4632e-02, -5.7208e-02, -4.0015e-02, -8.9137e-02, -2.3537e-02,
               -7.9152e-03,  2.1925e-02,  4.2662e-02,  2.8947e-06, -8.5279e-02,
               -7.4841e-03,  9.8418e-03, -7.7589e-02,  8.2463e-02,  4.1143e-02,
                3.5816e-03,  4.9777e-02, -4.8500e-02,  5.3974e-02, -6.0017e-02,
                6.7712e-02,  2.5674e-02,  6.1376e-02,  4.4900e-03, -6.7039e-02,
               -7.8965e-02, -6.7107e-02, -2.3485e-02, -1.5222e-02, -8.7112e-03,
               -7.6909e-02,  7.4443e-02, -2.8730e-02,  5.1316e-02,  1.8881e-04,
               -3.4971e-02, -7.5644e-02, -1.2047e-02, -5.6335e-02, -3.9648e-02,
               -5.9713e-03, -5.5715e-02,  5.2674e-02, -1.8986e-02,  1.1715e-02,
                7.8946e-02,  8.7757e-02,  7.8873e-02, -4.0747e-02, -9.0934e-02,
                8.4907e-02,  6.0053e-02,  6.8620e-02, -6.7833e-02,  4.9746e-02,
               -9.3147e-03,  4.5588e-02,  2.6278e-02,  7.8289e-02,  6.9648e-02,
               -1.8778e-02,  7.0376e-02,  4.2418e-02,  8.1578e-02, -2.5462e-03,
               -8.2635e-02, -3.2393e-02, -1.8944e-02,  5.2082e-02,  7.6844e-02,
                7.8650e-02,  1.0107e-02,  5.8002e-02,  5.6146e-02,  1.0183e-02,
                8.1826e-02,  2.2654e-02,  1.4227e-02,  6.7762e-02, -1.4747e-02,
               -1.7642e-02, -6.6754e-02, -3.3121e-02,  8.3696e-02, -1.6725e-02,
                5.1801e-02,  8.2761e-02, -1.6347e-03, -7.2732e-02, -8.5545e-02,
               -2.1219e-02, -8.6543e-02,  1.6206e-02, -3.9126e-02,  5.2650e-02,
                8.4007e-02,  6.5445e-03, -2.2124e-02, -2.6831e-03,  9.0644e-02],
              [-7.6951e-02, -7.7030e-02, -2.8385e-02,  8.8830e-02,  1.8159e-02,
                1.6277e-02, -8.6646e-02, -4.8097e-03, -7.2484e-02, -2.3892e-02,
               -3.8862e-02,  8.1929e-02,  9.8960e-03, -7.4610e-02,  4.4434e-02,
                9.1217e-02,  1.7353e-02, -4.1310e-02,  4.6109e-03, -6.0577e-03,
               -5.4844e-02, -8.3153e-02, -7.1516e-02, -3.4588e-03, -8.7724e-02,
               -2.7643e-02,  2.0528e-02, -3.5310e-02, -7.1740e-02,  9.2993e-03,
               -4.4849e-02, -2.9480e-02,  8.8236e-02,  6.5175e-03, -1.8974e-02,
               -6.3304e-02, -4.7579e-02, -7.5561e-02,  6.3819e-03,  6.4816e-02,
               -5.2397e-03,  4.9444e-02, -7.8613e-02,  3.2730e-02, -4.0234e-02,
                7.0352e-02,  9.4854e-03, -6.6504e-02, -5.6702e-02,  8.2212e-02,
               -1.0986e-03, -7.2008e-02, -2.7400e-02,  2.1016e-02,  3.5173e-02,
               -7.9398e-02, -4.5183e-02, -6.5127e-02, -3.0206e-02,  3.6873e-02,
               -4.1272e-02, -5.1494e-03, -7.6218e-03,  6.9450e-02,  3.0325e-02,
                1.2730e-02, -6.4263e-02,  8.1150e-02, -7.4637e-02,  4.3607e-02,
               -5.5105e-02,  7.7573e-02,  7.7680e-02,  4.4945e-02, -7.6075e-02,
                3.1183e-02,  4.1456e-02, -5.6365e-02, -6.4912e-02,  1.1328e-02,
               -6.9009e-02,  2.9723e-02,  3.2704e-02,  1.9641e-02, -2.5972e-02,
               -4.7365e-02, -6.7854e-02,  3.1308e-02, -4.7928e-02, -3.9339e-02,
                1.7182e-03, -7.1567e-02, -7.8225e-02,  8.7315e-02,  4.9131e-02,
               -8.9911e-02, -4.4713e-03, -3.8737e-02, -7.6371e-02,  3.4342e-02,
               -1.7424e-02,  2.5424e-02, -7.5859e-02,  8.9858e-02, -6.9163e-02,
               -7.2367e-02,  8.4212e-02,  5.0855e-02,  1.1414e-03,  7.7345e-02,
                5.0493e-02,  7.5050e-02, -2.1079e-02,  7.8312e-02,  7.6594e-03,
                4.9073e-02,  6.5474e-02,  2.7924e-02, -1.4094e-02,  4.2029e-02]])
      (bias): Normal:
       loc: tensor([-0., -0.])
       scale: tensor([0.7071, 0.7071])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-0.0687, -0.0438])
    )
    (observed): Observed()
  )
)
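
In this example, the prior entries in the summary above seem to follow a simple pattern. Weight priors have scale 1/sqrt(in_features), and bias priors have scale 1/sqrt(out_features); for conv2 that means its 16 output channels. This pattern is read off the printed numbers and is not a rule taken from the borch documentation. The standard-library sketch below assumes the pattern and recomputes the printed values from it.

```python
import math

# Prior scales printed in the model summary above: (fan, printed scale).
# Assumed pattern (inferred from the output, not from the borch docs):
#   weight prior scale = 1 / sqrt(in_features)
#   bias prior scale   = 1 / sqrt(out_features)
printed = {
    ("conv2", "bias"): (16, 0.2500),    # 16 output channels
    ("fc1", "weight"): (400, 0.0500),   # in_features=400
    ("fc1", "bias"): (120, 0.0913),     # out_features=120
    ("fc2", "weight"): (120, 0.0913),   # in_features=120
    ("fc2", "bias"): (2, 0.7071),       # out_features=2
}


def predicted_scale(fan):
    """Scale predicted by the 1/sqrt(fan) pattern, rounded like the printout."""
    return round(1 / math.sqrt(fan), 4)


for (layer, param), (fan, shown) in printed.items():
    pred = predicted_scale(fan)
    print(f"{layer}.{param}: fan={fan:4d} predicted={pred:.4f} printed={shown:.4f}")
```

The values 0.0913 and 0.7071 also appear as the initial scale of the Automatic posterior for fc1.bias and fc2.bias. This suggests that, at least here, Automatic initialises its scale from the prior.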

You can also set the posterior when you create a module:

nn.Linear(10, 10, posterior=borch.posterior.Normal(log_scale=-3))

Out:

Linear(
  in_features=10, out_features=10, bias=True
  (posterior): Normal(
    (weight): Normal:
     posterior: Automatic()
     prior: Module()
     observed: Observed()
     scale: Transform:
     tensor([[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498]], grad_fn=<ExpBackward0>)
     loc: Parameter containing:
    tensor([[0., 0., -0., -0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., -0., -0., -0., -0., -0., -0., -0.],
            [0., -0., 0., 0., -0., -0., -0., -0., -0., -0.],
            [-0., 0., -0., -0., 0., -0., 0., -0., -0., -0.],
            [0., -0., -0., 0., 0., 0., 0., 0., -0., -0.],
            [0., 0., 0., 0., -0., 0., -0., 0., -0., 0.],
            [0., 0., 0., 0., -0., -0., 0., 0., 0., -0.],
            [0., 0., -0., 0., -0., -0., 0., -0., -0., -0.],
            [-0., 0., 0., -0., -0., 0., -0., 0., -0., -0.],
            [-0., -0., 0., 0., 0., 0., 0., 0., 0., 0.]], requires_grad=True)
     tensor: tensor([[-0.0074,  0.0521, -0.0124, -0.0409,  0.0264,  0.0437, -0.0267,  0.0628,
              0.0505,  0.0381],
            [-0.0073,  0.0132, -0.0146,  0.0461,  0.0400, -0.0276, -0.0511, -0.0327,
              0.0057, -0.1127],
            [-0.0828, -0.0232, -0.0328, -0.0699, -0.0015, -0.0479,  0.0434, -0.0257,
             -0.0307, -0.0163],
            [-0.0135,  0.0351, -0.0783, -0.0044,  0.0843,  0.0475,  0.0530,  0.0148,
             -0.0002,  0.0309],
            [ 0.0177,  0.0899, -0.0343, -0.0520,  0.0359, -0.0923,  0.0123,  0.0111,
             -0.0587,  0.0311],
            [-0.0475, -0.0160,  0.0131, -0.0638, -0.0513, -0.0493,  0.0084, -0.0723,
              0.0861,  0.0588],
            [ 0.0710,  0.0325, -0.0378, -0.1117,  0.0043, -0.0647,  0.0132,  0.0545,
              0.0478, -0.0080],
            [-0.0560,  0.0021, -0.0093,  0.0394, -0.0130,  0.0267, -0.0057, -0.0756,
             -0.0694,  0.0105],
            [ 0.0652,  0.0100,  0.0970, -0.0560,  0.0092, -0.0568,  0.0200, -0.0449,
             -0.0877,  0.0570],
            [-0.0767,  0.0075, -0.0159,  0.0510,  0.0111, -0.0010, -0.0614,  0.0479,
              0.0140,  0.0625]], grad_fn=<AddBackward0>)
    (bias): Normal:
     posterior: Automatic()
     prior: Module()
     observed: Observed()
     scale: Transform:
     tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
            0.0498], grad_fn=<ExpBackward0>)
     loc: Parameter containing:
    tensor([0., 0., 0., 0., 0., 0., -0., -0., 0., -0.], requires_grad=True)
     tensor: tensor([-0.0160, -0.0384,  0.0385,  0.0560,  0.0037, -0.0786, -0.0205, -0.0698,
            -0.0399, -0.0347], grad_fn=<AddBackward0>)
  )
  (prior): Module(
    (weight): Normal:
     loc: tensor([[0., 0., -0., -0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., -0., -0., -0., -0., -0., -0., -0.],
            [0., -0., 0., 0., -0., -0., -0., -0., -0., -0.],
            [-0., 0., -0., -0., 0., -0., 0., -0., -0., -0.],
            [0., -0., -0., 0., 0., 0., 0., 0., -0., -0.],
            [0., 0., 0., 0., -0., 0., -0., 0., -0., 0.],
            [0., 0., 0., 0., -0., -0., 0., 0., 0., -0.],
            [0., 0., -0., 0., -0., -0., 0., -0., -0., -0.],
            [-0., 0., 0., -0., -0., 0., -0., 0., -0., -0.],
            [-0., -0., 0., 0., 0., 0., 0., 0., 0., 0.]])
     scale: tensor([[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162]])
     posterior: Automatic()
     prior: Module()
     observed: Observed()
     tensor: tensor([[ 0.0537,  0.1889, -0.0853, -0.2389,  0.2932,  0.1727,  0.2749,  0.0216,
              0.2554,  0.2329],
            [ 0.3147,  0.2312,  0.2066, -0.0289, -0.3128, -0.0764, -0.0221, -0.1260,
             -0.0566, -0.2622],
            [ 0.2383, -0.1345,  0.3103,  0.0020, -0.0411, -0.0745, -0.2340, -0.1291,
             -0.1219, -0.2262],
            [-0.1985,  0.1491, -0.2021, -0.0445,  0.3105, -0.0486,  0.1779, -0.0055,
             -0.0425, -0.2117],
            [ 0.0271, -0.0579, -0.3026,  0.2160,  0.0549,  0.1236,  0.2550,  0.0574,
             -0.2970, -0.1247],
            [ 0.3069,  0.2616,  0.2535,  0.3027, -0.0497,  0.2073, -0.0330,  0.1368,
             -0.1539,  0.0674],
            [ 0.2222,  0.2572,  0.1273,  0.0229, -0.0218, -0.2938,  0.0404,  0.0300,
              0.0719, -0.2847],
            [ 0.2613,  0.0033, -0.1948,  0.1279, -0.2686, -0.2299,  0.0870, -0.0815,
             -0.2531, -0.1094],
            [-0.0470,  0.1280,  0.0507, -0.2919, -0.2141,  0.2617, -0.0497,  0.2454,
             -0.0208, -0.0923],
            [-0.2441, -0.2966,  0.0220,  0.2931,  0.1903,  0.0392,  0.0920,  0.2355,
              0.1240,  0.2759]])
    (bias): Normal:
     loc: tensor([0., 0., 0., 0., 0., 0., -0., -0., 0., -0.])
     scale: tensor([0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
            0.3162])
     posterior: Automatic()
     prior: Module()
     observed: Observed()
     tensor: tensor([ 0.2346,  0.2745,  0.2184,  0.2540,  0.0325,  0.2008, -0.0309, -0.2569,
             0.1033, -0.1293])
  )
  (observed): Observed()
)
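
The printout above makes the effect of log_scale easy to check. The posterior's scale is shown as a `Transform` with `grad_fn=<ExpBackward0>`, which suggests it is stored on the log scale. On that reading, log_scale=-3 gives an initial scale of exp(-3) ≈ 0.0498. The prior of the same layer keeps its default scale of 1/sqrt(10) ≈ 0.3162. The sketch below recomputes both numbers. It assumes the scale is exactly exp(log_scale), which is inferred from the grad function and not stated in the text here.

```python
import math

# Assumption: the Normal posterior parameterises its scale as exp(log_scale),
# as suggested by grad_fn=<ExpBackward0> in the printed output.
log_scale = -3
posterior_scale = math.exp(log_scale)

# Default prior scale for nn.Linear(10, 10), assuming the 1/sqrt(fan) pattern
# seen in the printed priors (in_features = out_features = 10 here).
prior_scale = 1 / math.sqrt(10)

print(f"posterior scale: {posterior_scale:.4f}")  # matches the printed 0.0498
print(f"prior scale:     {prior_scale:.4f}")      # matches the printed 0.3162

# How much narrower the initial posterior is than the prior.
print(f"ratio prior/posterior: {prior_scale / posterior_scale:.1f}")
```

Because loc starts at zero, this narrow starting scale is consistent with the posterior weight values printed above, which all lie within about ±0.11.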

See the borch.posterior documentation for the other available posteriors and the arguments they accept. Note that not every posterior works with every parameter, but you can use different posteriors for different borch.Module instances in your network.

Exercises

  1. Use what you have learned to train an image classifier for MNIST; you should be able to reach an accuracy above 98%. Note: you can load MNIST with torchvision.datasets.MNIST.

  2. Fit the same model architecture with plain torch and compare its likelihood with that of the borch network. What are the differences, and why?

  3. Port the model to CIFAR and see how far you can improve the accuracy.

  4. Show how the Categorical distribution is related to the cross-entropy loss function commonly used in frequentist deep learning.

Total running time of the script: ( 0 minutes 0.222 seconds)
