Training an image classifier

In this tutorial you will learn the basics of creating an image classifier with the borch.nn package and fitting it with the borch.infer package.

Let's start off by importing what we need.

import torch
from torch.utils.data import TensorDataset, DataLoader
import borch
from borch import nn, infer, distributions
import torch.nn.functional as F

The borch.nn module provides neural network layers for deep probabilistic programming. Its interface is almost identical to that of torch.nn, so in many cases you can simply switch

from torch import nn

to

from borch import nn
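Because the layer constructors take the same arguments, code written against one namespace can often run against the other. As a minimal sketch, the snippet below uses borch.nn when it is installed and otherwise falls back to torch.nn. It assumes only that borch.nn.Linear accepts the same (in_features, out_features) arguments as torch.nn.Linear, which the model in this tutorial relies on; the helper build_head is hypothetical.

```python
import torch

try:
    # Probabilistic layers, if borch is installed
    from borch import nn
except ImportError:
    # Fall back to the deterministic PyTorch layers, which share the constructor signatures
    from torch import nn


def build_head(in_features, num_classes):
    """Hypothetical helper: the same constructor call works with either namespace."""
    return nn.Linear(in_features, num_classes)


head = build_head(16 * 5 * 5, 2)
out = head(torch.randn(3, 16 * 5 * 5))
print(out.shape)
```

Note that the fallback produces ordinary deterministic layers, so the probabilistic behaviour that borch.nn adds is lost in that case.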

Data

In this example we use simulated data and do not run the fitting until convergence; the aim is to show how the model is set up and how the training loop can be constructed. We simply generate random data, where data holds the images and labels holds the class of each image.

data = torch.randn(20, 1, 32, 32)
labels = torch.randperm(2).repeat(10)
data_set = TensorDataset(data, labels)
loader = DataLoader(data_set, batch_size=20)
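A quick sanity check makes the shapes concrete. The sketch below uses plain PyTorch, and the helper conv_out is our own, not part of borch. It confirms that the loader yields a single batch of 20 images with 10 labels per class, and derives the 16 * 5 * 5 input size of fc1 in the model below from the 32x32 image size.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Recreate the simulated data: 20 single-channel 32x32 "images"
data = torch.randn(20, 1, 32, 32)
# randperm(2) is a shuffled [0, 1]; repeating it 10 times gives 20 labels, 10 per class
labels = torch.randperm(2).repeat(10)
loader = DataLoader(TensorDataset(data, labels), batch_size=20)

# With batch_size=20 the loader yields the whole data set as one batch
batches = list(loader)
images, targets = batches[0]
print(len(batches), images.shape, targets.shape)
print(torch.bincount(targets))  # number of examples per class


def conv_out(size, kernel, stride=1):
    """Spatial size after an unpadded convolution or pooling window."""
    return (size - kernel) // stride + 1


# Follow a 32x32 image through conv(5) -> pool(2) -> conv(5) -> pool(2)
side = data.shape[-1]
side = conv_out(side, 5)     # conv1: 32 -> 28
side = conv_out(side, 2, 2)  # max pool: 28 -> 14
side = conv_out(side, 5)     # conv2: 14 -> 10
side = conv_out(side, 2, 2)  # max pool: 10 -> 5
print(16 * side * side)      # flattened features entering fc1: 16 channels of 5x5
```

If the image size changes, the in_features of fc1 must be recomputed the same way, otherwise the view in forward will not match the Linear layer.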

Model

Let's set up the model. To get the most out of infer and borch, we need to choose a likelihood distribution. For classification, distributions.Categorical is a suitable choice.
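Parameterising the Categorical with logits means the network output needs no softmax: the distribution normalises the logits itself, and the log-likelihood of the observed labels equals the negative cross-entropy. The sketch below shows this with torch.distributions.Categorical; it assumes, without checking borch's source, that borch's distributions.Categorical follows the same interface.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

torch.manual_seed(0)
logits = torch.randn(4, 2)            # unnormalised scores: 4 examples, 2 classes
labels = torch.tensor([0, 1, 1, 0])

dist = Categorical(logits=logits)

# The distribution applies a softmax internally, so each row of probs sums to 1
probs = dist.probs
print(probs.sum(dim=-1))

# The per-example log-likelihood of the labels is the negative cross-entropy
log_lik = dist.log_prob(labels)
nll = F.cross_entropy(logits, labels, reduction="none")
print(torch.allclose(log_lik, -nll))
```

In a plain PyTorch classifier, minimising F.cross_entropy is therefore the same as maximising this log-likelihood.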

class Net(borch.Module):
    def __init__(self):
        super(Net, self).__init__(posterior=borch.posterior.Automatic())
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 2)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the window is square, you can specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Specify the likelihood: a Categorical distribution over the classes, parameterised by the logits
        self.classification = distributions.Categorical(logits=x)
        return self.classification

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
print(net)

Out:

Net(
  (posterior): Automatic()
  (prior): Module()
  (observed): Observed()
  (conv1): Conv2d(
    1, 6, kernel_size=(5, 5), stride=(1, 1)
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[[[-2.3901e-02,  1.7009e-01, -4.0790e-02,  1.9047e-01,  1.1854e-01],
                [-1.0204e-01, -1.4715e-01,  9.0185e-03,  8.0316e-02, -1.4466e-01],
                [ 1.7187e-01,  8.4639e-02,  1.6665e-01, -1.5608e-01, -9.5238e-02],
                [ 1.9682e-01, -9.3068e-02, -1.9482e-02,  1.6220e-01, -5.5923e-02],
                [-1.5368e-01, -1.2494e-01,  3.1866e-02, -1.7428e-01, -6.5961e-02]]],


              [[[-5.4801e-02,  5.0452e-02, -1.5372e-01, -1.1482e-01,  1.5138e-01],
                [ 1.8247e-03,  4.6463e-02, -1.7931e-01, -8.4841e-02, -6.3566e-02],
                [-1.9791e-02, -1.0920e-01,  1.2796e-01, -6.0495e-02,  1.2142e-01],
                [-1.0610e-01,  2.8335e-02,  2.0862e-02, -1.2132e-01, -5.6049e-03],
                [ 2.7007e-02,  1.5627e-01,  7.0422e-02, -2.1336e-03, -5.9226e-02]]],


              [[[-1.6497e-01, -9.5347e-02,  9.7235e-02,  1.7565e-01, -1.4118e-01],
                [-1.1203e-02, -5.6668e-02,  9.0249e-02,  1.9961e-01, -2.0049e-02],
                [-4.5493e-02, -2.0235e-02, -1.9463e-01,  1.5131e-01,  1.6076e-01],
                [-1.9071e-01, -1.6333e-01,  1.0380e-01, -7.2042e-02,  1.0249e-01],
                [ 1.7660e-01,  1.8708e-02, -1.0379e-01,  8.4113e-02, -1.3492e-01]]],


              [[[ 1.3284e-01, -1.7679e-02, -9.9538e-02,  1.5133e-01,  1.0864e-01],
                [-1.9522e-01,  1.0066e-01, -1.0742e-01, -1.1599e-01,  1.6930e-01],
                [-1.0281e-01, -1.4473e-01,  1.6300e-01, -7.4540e-02, -6.5797e-02],
                [ 1.5015e-01,  7.7701e-03,  7.3404e-02,  9.5653e-02, -1.2661e-01],
                [-3.2228e-02, -7.9872e-02,  1.9932e-01,  6.0159e-02,  1.3894e-01]]],


              [[[-1.7235e-01,  5.8651e-06,  8.9371e-02, -1.5355e-01, -1.2702e-01],
                [ 5.9223e-02,  1.1539e-01, -5.5243e-03, -6.4484e-02, -1.3380e-01],
                [ 9.6366e-03, -1.0979e-01, -1.1570e-01,  7.1673e-03,  8.9918e-02],
                [ 2.4720e-02,  5.8142e-02, -1.0872e-01, -1.4363e-01, -1.1776e-01],
                [-8.8460e-02, -1.7740e-01, -7.1380e-02, -1.1692e-01, -1.7076e-02]]],


              [[[-1.8614e-01, -1.2378e-01, -1.3271e-01, -1.5860e-02, -9.4571e-02],
                [-7.8788e-03,  4.7546e-02,  1.4185e-01,  8.6187e-02, -1.0654e-01],
                [-6.4892e-02, -7.4628e-02,  5.9973e-02,  5.0245e-02,  1.1456e-01],
                [-4.4647e-02,  9.8620e-02, -1.2445e-01, -6.6966e-02,  3.5321e-02],
                [ 1.9966e-01, -3.8652e-03,  3.5615e-03, -1.3894e-02,  3.5925e-02]]]],
             requires_grad=True)
       tensor: tensor([[[[-0.0311,  0.0859, -0.0472,  0.2721, -0.0104],
                [-0.0725, -0.1375,  0.0592,  0.0649, -0.1419],
                [ 0.1607,  0.0185,  0.2350, -0.1482, -0.0662],
                [ 0.2044, -0.0331,  0.0378,  0.2124, -0.0684],
                [-0.1102, -0.1244,  0.0971, -0.1783, -0.1476]]],


              [[[ 0.0512,  0.0096, -0.1303, -0.0858,  0.1533],
                [-0.0701,  0.0477, -0.0650, -0.0767, -0.1122],
                [-0.0093, -0.0608,  0.1847, -0.0801,  0.1446],
                [-0.1251,  0.0316,  0.0093, -0.1505, -0.0172],
                [-0.0315,  0.0236,  0.1091,  0.0313, -0.0320]]],


              [[[-0.1591, -0.1604,  0.0829,  0.1024, -0.1755],
                [-0.0662, -0.1338,  0.0292,  0.2482, -0.0993],
                [ 0.0087,  0.0244, -0.2304,  0.2126,  0.1096],
                [-0.1508, -0.0875,  0.1158, -0.0840,  0.1429],
                [ 0.1580,  0.0260, -0.0853,  0.0094, -0.1336]]],


              [[[ 0.1035, -0.0026, -0.0805,  0.1173,  0.0393],
                [-0.2130,  0.0323, -0.1763, -0.1442,  0.1475],
                [-0.1606, -0.0697,  0.1412, -0.1098, -0.0839],
                [ 0.1098, -0.0369,  0.0569,  0.1697, -0.0874],
                [-0.0676, -0.1037,  0.1526,  0.1673,  0.0780]]],


              [[[-0.1714,  0.0032,  0.0813, -0.1082, -0.0881],
                [-0.0233,  0.1377, -0.0773, -0.0227, -0.0461],
                [ 0.0277, -0.1762, -0.0500, -0.0424,  0.0864],
                [ 0.0196,  0.0805, -0.1020, -0.2220, -0.1493],
                [-0.0397, -0.2590, -0.0630, -0.1079, -0.0202]]],


              [[[-0.1821, -0.1409, -0.1910, -0.0148, -0.0221],
                [ 0.0964,  0.1303,  0.2632,  0.0626, -0.0929],
                [-0.1410, -0.0631,  0.0544,  0.0835,  0.0102],
                [-0.0231,  0.0293, -0.1860,  0.0420,  0.1065],
                [ 0.1520, -0.0128, -0.0434,  0.0051,  0.0573]]]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-0.1531, -0.0004,  0.1924, -0.0493,  0.0953,  0.0265],
             requires_grad=True)
       tensor: tensor([-0.1049,  0.0772,  0.1882, -0.0163,  0.1152,  0.1543],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[[[-0., 0., -0., 0., 0.],
                [-0., -0., 0., 0., -0.],
                [0., 0., 0., -0., -0.],
                [0., -0., -0., 0., -0.],
                [-0., -0., 0., -0., -0.]]],


              [[[-0., 0., -0., -0., 0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., 0., -0., 0.],
                [-0., 0., 0., -0., -0.],
                [0., 0., 0., -0., -0.]]],


              [[[-0., -0., 0., 0., -0.],
                [-0., -0., 0., 0., -0.],
                [-0., -0., -0., 0., 0.],
                [-0., -0., 0., -0., 0.],
                [0., 0., -0., 0., -0.]]],


              [[[0., -0., -0., 0., 0.],
                [-0., 0., -0., -0., 0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., 0., 0., -0.],
                [-0., -0., 0., 0., 0.]]],


              [[[-0., 0., 0., -0., -0.],
                [0., 0., -0., -0., -0.],
                [0., -0., -0., 0., 0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., -0., -0., -0.]]],


              [[[-0., -0., -0., -0., -0.],
                [-0., 0., 0., 0., -0.],
                [-0., -0., 0., 0., 0.],
                [-0., 0., -0., -0., 0.],
                [0., -0., 0., -0., 0.]]]])
       scale: tensor([[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[[[-2.3901e-02,  1.7009e-01, -4.0790e-02,  1.9047e-01,  1.1854e-01],
                [-1.0204e-01, -1.4715e-01,  9.0185e-03,  8.0316e-02, -1.4466e-01],
                [ 1.7187e-01,  8.4639e-02,  1.6665e-01, -1.5608e-01, -9.5238e-02],
                [ 1.9682e-01, -9.3068e-02, -1.9482e-02,  1.6220e-01, -5.5923e-02],
                [-1.5368e-01, -1.2494e-01,  3.1866e-02, -1.7428e-01, -6.5961e-02]]],


              [[[-5.4801e-02,  5.0452e-02, -1.5372e-01, -1.1482e-01,  1.5138e-01],
                [ 1.8247e-03,  4.6463e-02, -1.7931e-01, -8.4841e-02, -6.3566e-02],
                [-1.9791e-02, -1.0920e-01,  1.2796e-01, -6.0495e-02,  1.2142e-01],
                [-1.0610e-01,  2.8335e-02,  2.0862e-02, -1.2132e-01, -5.6049e-03],
                [ 2.7007e-02,  1.5627e-01,  7.0422e-02, -2.1336e-03, -5.9226e-02]]],


              [[[-1.6497e-01, -9.5347e-02,  9.7235e-02,  1.7565e-01, -1.4118e-01],
                [-1.1203e-02, -5.6668e-02,  9.0249e-02,  1.9961e-01, -2.0049e-02],
                [-4.5493e-02, -2.0235e-02, -1.9463e-01,  1.5131e-01,  1.6076e-01],
                [-1.9071e-01, -1.6333e-01,  1.0380e-01, -7.2042e-02,  1.0249e-01],
                [ 1.7660e-01,  1.8708e-02, -1.0379e-01,  8.4113e-02, -1.3492e-01]]],


              [[[ 1.3284e-01, -1.7679e-02, -9.9538e-02,  1.5133e-01,  1.0864e-01],
                [-1.9522e-01,  1.0066e-01, -1.0742e-01, -1.1599e-01,  1.6930e-01],
                [-1.0281e-01, -1.4473e-01,  1.6300e-01, -7.4540e-02, -6.5797e-02],
                [ 1.5015e-01,  7.7701e-03,  7.3404e-02,  9.5653e-02, -1.2661e-01],
                [-3.2228e-02, -7.9872e-02,  1.9932e-01,  6.0159e-02,  1.3894e-01]]],


              [[[-1.7235e-01,  5.8651e-06,  8.9371e-02, -1.5355e-01, -1.2702e-01],
                [ 5.9223e-02,  1.1539e-01, -5.5243e-03, -6.4484e-02, -1.3380e-01],
                [ 9.6366e-03, -1.0979e-01, -1.1570e-01,  7.1673e-03,  8.9918e-02],
                [ 2.4720e-02,  5.8142e-02, -1.0872e-01, -1.4363e-01, -1.1776e-01],
                [-8.8460e-02, -1.7740e-01, -7.1380e-02, -1.1692e-01, -1.7076e-02]]],


              [[[-1.8614e-01, -1.2378e-01, -1.3271e-01, -1.5860e-02, -9.4571e-02],
                [-7.8788e-03,  4.7546e-02,  1.4185e-01,  8.6187e-02, -1.0654e-01],
                [-6.4892e-02, -7.4628e-02,  5.9973e-02,  5.0245e-02,  1.1456e-01],
                [-4.4647e-02,  9.8620e-02, -1.2445e-01, -6.6966e-02,  3.5321e-02],
                [ 1.9966e-01, -3.8652e-03,  3.5615e-03, -1.3894e-02,  3.5925e-02]]]])
      (bias): Normal:
       loc: tensor([-0., -0., 0., -0., 0., 0.])
       scale: tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-0.1531, -0.0004,  0.1924, -0.0493,  0.0953,  0.0265])
    )
    (observed): Observed()
  )
  (conv2): Conv2d(
    6, 16, kernel_size=(5, 5), stride=(1, 1)
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              ...,


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[[[-2.9409e-02,  6.0003e-02,  5.3245e-02,  2.8255e-02,  5.1865e-02],
                [-1.3120e-03,  6.9924e-02, -2.1958e-02,  4.5196e-02,  5.5987e-02],
                [ 7.2948e-02, -1.2874e-02, -1.4010e-02, -7.2656e-02, -3.8272e-02],
                [-6.1391e-02,  2.6227e-02,  2.2223e-02,  5.1530e-02,  7.8572e-02],
                [ 1.1986e-02, -3.5154e-02,  2.2590e-02, -2.8268e-02, -4.1092e-02]],

               [[-7.7353e-02, -3.3001e-03,  2.2446e-02,  3.6329e-02, -8.4312e-03],
                [ 5.9898e-02,  3.3921e-02, -4.0756e-02,  5.6990e-02, -7.9915e-02],
                [ 1.1058e-02, -7.5468e-02,  3.1362e-02, -1.2730e-02,  7.0498e-02],
                [-6.7726e-03, -3.9453e-02,  6.8676e-02,  2.8066e-02,  4.1992e-02],
                [ 8.0445e-02,  6.9594e-02,  2.9287e-02,  4.4932e-02, -7.2432e-02]],

               [[-6.8457e-02, -7.5792e-02,  3.8843e-02,  3.2058e-02,  2.1625e-02],
                [ 7.5601e-02, -5.7455e-02,  1.2939e-02, -3.5165e-02, -5.6556e-02],
                [ 8.3346e-03,  2.2498e-02,  6.8399e-03, -2.3131e-02, -2.1235e-03],
                [-6.8738e-02, -6.5464e-02, -4.5512e-02, -2.1878e-02, -6.3732e-02],
                [ 4.4843e-02,  1.4243e-02, -7.8475e-02,  2.0007e-02,  6.8106e-02]],

               [[-7.1642e-03, -4.0418e-02,  1.6455e-03, -7.4808e-03,  3.0932e-02],
                [-5.6909e-02, -6.4950e-02, -4.3916e-02, -6.4650e-02,  7.4918e-02],
                [ 2.7178e-03,  7.6559e-02,  2.3546e-02,  1.0888e-02, -3.2943e-02],
                [-3.1639e-02, -4.4183e-02, -1.1435e-02,  1.8024e-02,  7.9902e-02],
                [-1.6372e-02,  6.6337e-03, -5.5328e-02,  1.6804e-02, -1.1563e-02]],

               [[ 3.4681e-02,  3.3633e-02, -5.8334e-02, -4.5784e-02, -2.0550e-02],
                [ 8.0547e-02,  4.9685e-02,  7.5298e-02, -2.9211e-02,  2.2224e-02],
                [-2.4409e-02,  7.0003e-02,  2.1558e-04, -4.4977e-02,  2.3133e-02],
                [-7.4964e-02, -6.4867e-02, -4.5671e-02,  3.1878e-02,  3.5104e-02],
                [ 2.2541e-02,  3.2949e-02, -3.5216e-02, -4.0980e-02, -1.9989e-02]],

               [[-7.0267e-02, -2.8274e-03, -3.7755e-02,  3.8392e-02, -1.3160e-05],
                [-2.4893e-02,  2.6938e-02, -6.6667e-02, -6.4968e-02, -7.5396e-02],
                [ 4.5711e-02, -3.3377e-02,  2.4278e-02, -4.1101e-03,  1.8952e-02],
                [ 2.0236e-02,  6.4522e-03, -5.7711e-02, -4.5909e-02,  7.4404e-02],
                [-3.7771e-02, -3.1536e-03,  6.4748e-02, -4.4604e-02, -2.2752e-02]]],


              [[[-2.9177e-02,  5.0791e-02, -5.5935e-03,  4.7169e-02, -1.1366e-02],
                [ 2.2291e-02, -2.1153e-02, -1.4060e-02, -5.3690e-02,  6.6127e-02],
                [-1.1869e-03,  7.9363e-03,  6.0124e-02,  8.7514e-04, -2.6833e-02],
                [-4.4174e-02,  1.9975e-02,  2.7939e-02,  2.7791e-02, -1.7919e-02],
                [-2.9191e-02,  2.5500e-02,  3.6869e-03,  3.1063e-02, -3.7672e-02]],

               [[-6.7084e-02,  5.3586e-02,  1.8485e-02, -7.8099e-03,  4.6288e-02],
                [ 4.9055e-02, -1.4634e-02,  3.2799e-02,  3.9726e-02, -7.0032e-02],
                [ 7.0897e-02,  3.4370e-02, -1.4814e-02, -3.9030e-02,  3.0867e-02],
                [ 3.5541e-02, -7.2574e-02, -1.5650e-03, -8.1162e-02,  2.6245e-02],
                [ 5.5721e-02, -2.2033e-02, -7.2623e-02, -5.1459e-02,  3.3599e-02]],

               [[ 1.1310e-02, -2.9816e-02,  6.3727e-02,  3.9850e-02,  1.3761e-02],
                [ 3.0453e-02, -4.8504e-02,  5.3189e-02,  1.9425e-02,  4.7484e-02],
                [ 6.1376e-02, -7.8290e-02, -6.8859e-02,  1.8497e-02, -1.1496e-02],
                [-7.8178e-02, -4.5904e-02,  7.3181e-02,  2.9441e-02,  4.6967e-02],
                [ 7.6978e-02,  7.2934e-02,  5.6798e-02,  5.8828e-02,  4.4637e-02]],

               [[ 6.0281e-02,  7.7289e-02,  7.9016e-02, -4.1437e-02,  3.1101e-02],
                [-5.0620e-02,  3.3108e-02, -5.8687e-02,  2.7694e-02,  5.4294e-02],
                [ 2.1156e-02,  1.7004e-03,  2.4742e-02,  6.9593e-02,  5.7699e-02],
                [ 6.8876e-02,  3.2239e-02,  3.3322e-02,  2.9973e-02,  7.4267e-02],
                [-2.4736e-03,  3.8454e-02, -2.5898e-02,  2.0443e-02,  6.0816e-02]],

               [[ 1.6487e-02, -6.8994e-03,  5.2835e-02,  5.7784e-02, -1.7036e-02],
                [-7.1883e-02,  3.1576e-04, -5.5839e-02,  1.4949e-02, -1.9834e-03],
                [-2.7395e-02,  3.9861e-04, -1.0588e-02,  9.9140e-03,  5.1499e-02],
                [ 7.4060e-02, -1.0655e-02,  1.1668e-02,  4.9183e-02,  5.2846e-02],
                [ 2.8634e-02,  4.2678e-02, -1.4281e-02,  1.3904e-03,  7.6289e-02]],

               [[-5.1256e-02, -2.2514e-02, -7.2964e-02, -4.4120e-02, -5.8914e-02],
                [-4.1579e-02,  2.8281e-02,  3.9429e-02,  7.5058e-03,  6.5170e-03],
                [-3.4494e-02,  7.5710e-02,  4.1078e-02,  4.4451e-02,  4.2661e-02],
                [-5.4398e-02,  5.1592e-02, -2.6367e-02, -3.2980e-02,  5.3860e-02],
                [-4.2436e-02, -8.4286e-03,  7.5331e-02, -6.6725e-02,  4.9887e-02]]],


              [[[-1.5211e-02, -6.2506e-03,  3.0621e-03,  3.0725e-02, -7.0877e-03],
                [ 1.1974e-02, -5.2611e-02, -2.7415e-02,  4.3479e-02, -4.2108e-02],
                [ 3.3816e-02,  6.1523e-02, -9.9011e-03, -3.7770e-02,  6.5915e-04],
                [ 5.3678e-03,  5.9921e-02, -3.4530e-02,  5.1942e-02,  5.3762e-02],
                [-4.7293e-02, -6.2274e-02, -7.5059e-02,  8.1645e-02,  2.1149e-02]],

               [[-1.4459e-02, -2.7155e-02, -2.5730e-02,  7.6751e-02, -1.6932e-02],
                [-5.3342e-02, -2.6885e-02,  4.3476e-02, -7.9174e-02, -3.5761e-02],
                [ 4.2970e-02,  2.5516e-02, -6.6640e-02, -2.9457e-03, -8.2757e-03],
                [-2.5080e-02,  4.1672e-02,  4.2424e-02, -4.8704e-02, -6.0434e-02],
                [ 1.2884e-02, -7.9950e-02, -7.0913e-02, -8.0863e-02, -5.4536e-02]],

               [[-5.4303e-02,  6.7885e-02, -5.3922e-02,  6.5582e-02, -5.2617e-03],
                [-8.4440e-03,  8.0911e-02, -3.8667e-02, -5.6241e-04,  7.0876e-02],
                [ 5.4673e-02,  1.3465e-02, -2.7178e-02, -3.6691e-02, -2.6519e-02],
                [-2.8238e-02, -5.0765e-02, -3.4076e-02,  3.1219e-02, -1.8919e-02],
                [-4.6076e-02, -7.6516e-02,  3.1247e-02,  4.1743e-02,  7.5575e-02]],

               [[ 3.8787e-03, -8.1731e-03,  4.7381e-03,  5.8261e-02,  4.6416e-02],
                [ 7.7171e-02, -6.5924e-02,  1.5769e-02,  2.6777e-02,  7.7365e-02],
                [-7.3126e-02,  4.7624e-02, -7.0620e-02,  7.3309e-02,  7.7585e-03],
                [-7.9208e-02, -7.3783e-03, -5.3142e-02,  3.4386e-02,  5.6230e-03],
                [-4.4492e-02,  7.5403e-02, -2.8887e-02,  2.6937e-02,  7.0698e-03]],

               [[-1.7104e-02, -6.6983e-02, -5.7655e-02,  5.1198e-02, -8.0137e-02],
                [-7.2406e-02, -3.7856e-02, -7.2086e-02,  7.2641e-02, -7.0749e-02],
                [ 6.6330e-02,  2.8797e-02,  6.2197e-02, -4.6068e-03,  1.7392e-03],
                [ 4.2384e-02,  5.9572e-02, -4.5953e-02,  6.6345e-03,  7.3979e-02],
                [-4.8313e-02, -1.3063e-02,  1.7648e-02, -5.0903e-02, -7.1852e-02]],

               [[ 3.2147e-02, -8.1585e-02,  6.7507e-02, -7.7056e-02,  1.7667e-02],
                [ 2.0535e-03, -7.4221e-02,  1.0343e-02,  4.3018e-02,  9.7351e-03],
                [-2.1410e-02,  5.4089e-02, -2.4102e-02, -4.0551e-02, -3.6118e-03],
                [ 4.9847e-02,  6.9608e-02,  3.6233e-03,  5.7025e-02,  6.3206e-02],
                [ 1.4611e-02, -2.9885e-02,  5.6140e-02, -6.4338e-02,  8.5266e-03]]],


              ...,


              [[[ 5.7051e-02,  3.4026e-02, -3.7723e-02,  1.4372e-02, -4.4266e-03],
                [-8.0557e-02,  1.1810e-02, -6.9374e-02,  3.4264e-02, -3.9068e-02],
                [ 3.2814e-02, -4.9334e-02, -3.2234e-02,  3.7901e-02,  9.9268e-03],
                [ 1.2846e-03, -5.9199e-02, -5.6303e-02,  1.2189e-03,  7.8874e-02],
                [ 7.6858e-04,  2.4341e-02,  4.0423e-02, -7.7602e-02, -3.6388e-02]],

               [[ 5.9494e-02,  4.4230e-02, -5.9128e-02, -1.6639e-02, -6.4884e-02],
                [-2.3457e-02, -7.5842e-03, -3.3986e-02, -2.0435e-02, -4.2466e-02],
                [-6.8915e-02,  1.6417e-02, -9.0384e-03, -5.6058e-02,  1.1540e-02],
                [ 3.9632e-02,  3.8881e-02,  5.5834e-02,  7.5591e-02,  1.8463e-02],
                [ 4.5034e-02, -6.4665e-02,  6.7883e-02,  7.1108e-02,  8.0694e-02]],

               [[-7.4652e-02, -2.9270e-02, -7.4301e-02, -1.4067e-02, -6.0331e-02],
                [-7.9629e-02, -2.5316e-03, -3.4649e-02,  7.9736e-02,  2.4963e-02],
                [-1.4102e-02,  3.0896e-02, -5.4594e-02,  5.7641e-02,  7.8276e-02],
                [ 3.3722e-02,  1.6397e-02,  6.6251e-02,  2.5637e-02,  4.0073e-02],
                [ 1.9370e-02, -1.4960e-02, -4.0503e-02, -3.6491e-02, -6.9970e-02]],

               [[ 3.9434e-02,  6.7049e-02,  7.1627e-02,  6.9307e-02, -5.7508e-03],
                [ 2.8151e-02,  7.9890e-02, -6.4687e-02, -6.8959e-02,  6.8179e-02],
                [ 1.2583e-02,  6.6052e-02,  6.7770e-02,  1.0853e-02,  6.3935e-02],
                [ 4.4214e-02, -5.4527e-02, -6.3199e-02, -2.4454e-02, -8.0348e-02],
                [-1.1810e-04,  6.2292e-02, -2.1831e-02, -4.1282e-02,  3.4718e-02]],

               [[-8.9495e-03, -3.5923e-02, -4.9030e-02,  1.7068e-02,  5.7835e-02],
                [-6.2950e-02,  6.9258e-02,  1.4909e-02, -3.9252e-02,  3.0917e-02],
                [-5.0831e-02, -2.6109e-02, -4.2526e-02,  4.9180e-03, -6.7907e-02],
                [-1.4867e-02,  8.3498e-03, -6.3780e-02, -6.3819e-02, -7.7414e-02],
                [ 6.5369e-02,  3.5118e-02, -3.5070e-02,  3.1514e-02, -1.7773e-02]],

               [[-1.9000e-02,  4.8772e-02, -4.0550e-02,  5.7766e-02, -4.8687e-02],
                [ 7.0112e-02,  7.4851e-02, -5.0324e-02,  4.2522e-02,  6.6367e-02],
                [-6.6793e-02, -6.3487e-02, -6.3574e-02,  7.3530e-02, -6.7062e-02],
                [ 1.9297e-02,  3.9876e-02,  7.0333e-03,  3.6541e-02,  3.0865e-02],
                [ 6.9009e-02,  2.7737e-03, -6.0400e-02,  1.5249e-03, -1.5177e-03]]],


              [[[-4.9098e-02, -1.2656e-02,  3.0326e-02,  4.6450e-02,  4.1143e-02],
                [ 6.5180e-02, -4.5543e-02, -6.0194e-02, -8.1101e-02,  7.3691e-02],
                [-5.2880e-02, -5.3283e-02, -4.6874e-02,  2.0506e-02,  1.4432e-02],
                [ 5.3466e-05,  6.1875e-02, -5.2208e-02, -2.1149e-02, -6.5709e-02],
                [-7.2209e-02, -2.8706e-02,  6.6109e-02,  5.8108e-02, -1.8114e-02]],

               [[-5.8877e-02,  3.5183e-02,  6.5460e-02,  5.2934e-02,  3.5997e-02],
                [-6.5718e-02,  2.7700e-02,  7.1110e-02, -5.7825e-02,  6.1866e-03],
                [-5.3281e-03, -7.6189e-02, -6.9421e-02,  6.4743e-02, -1.1912e-02],
                [ 7.6864e-02, -5.8819e-03, -2.0277e-02, -1.6263e-02,  4.5729e-02],
                [ 2.0473e-03, -3.1893e-02, -3.0088e-02,  6.1322e-02, -1.3287e-02]],

               [[ 4.0654e-02, -3.8251e-03, -5.8287e-02, -6.9760e-03, -4.9954e-02],
                [-3.1949e-02, -6.5679e-02, -9.8746e-04, -5.5646e-02, -1.6937e-03],
                [-5.0579e-02,  5.1921e-02,  4.0006e-02, -5.3846e-02,  3.6710e-03],
                [-5.5284e-03,  5.2453e-02,  3.5617e-02, -4.4475e-02,  2.7835e-02],
                [ 3.6465e-02,  2.2936e-02,  4.9494e-02, -6.8768e-02, -6.8512e-02]],

               [[-1.5606e-02, -5.8101e-02, -4.8349e-02, -5.4572e-03, -8.1381e-02],
                [ 3.3837e-02,  6.9886e-02,  2.5937e-03, -4.4428e-02, -6.1442e-03],
                [-3.3799e-02,  7.6725e-02,  1.5202e-02,  2.7467e-02, -7.2112e-02],
                [-5.3887e-02,  5.3134e-02, -5.5426e-02,  8.1476e-02,  1.0773e-02],
                [-1.9578e-02,  1.7628e-02, -2.2382e-02,  6.7076e-02, -1.3475e-02]],

               [[ 3.7281e-02,  2.7106e-02, -7.8289e-03, -6.1201e-02, -4.5366e-02],
                [-5.1809e-02, -1.0889e-02,  4.4019e-02, -4.0099e-02, -6.2939e-02],
                [ 7.8826e-02,  1.4336e-02, -7.8953e-02, -4.1699e-03,  2.1759e-02],
                [ 4.3422e-02,  6.1053e-02, -5.1035e-02,  2.5170e-02,  8.1194e-02],
                [-3.5907e-02,  3.5084e-02,  5.4858e-02,  5.7819e-02, -6.8527e-02]],

               [[ 6.0340e-02, -4.5873e-02,  4.5307e-02, -1.8559e-02, -5.9891e-02],
                [ 7.1101e-02,  5.7979e-03, -2.1455e-02, -5.7839e-02, -2.6964e-02],
                [ 4.5972e-02,  4.6237e-02, -1.8353e-02,  5.5372e-03,  5.8802e-02],
                [-8.0939e-02,  2.2098e-03, -2.7943e-03,  6.9556e-02,  3.5299e-03],
                [-2.4275e-02, -6.1490e-02, -2.4350e-02, -5.8685e-02, -7.6820e-02]]],


              [[[-5.8326e-02,  4.3804e-02,  5.4642e-02,  2.9479e-02,  5.5766e-02],
                [-6.2955e-02,  4.9442e-02, -1.7882e-02, -6.4492e-02, -3.5590e-02],
                [ 7.8974e-02,  1.8189e-02, -4.3076e-02, -4.6822e-02, -5.9352e-02],
                [ 1.1472e-02,  6.9467e-02, -3.5045e-02, -1.3463e-03, -7.0617e-02],
                [-5.7437e-02, -5.7150e-02,  4.9108e-02,  2.2168e-02, -5.4964e-02]],

               [[-3.2895e-02, -2.2746e-03,  6.8428e-02, -7.4781e-02,  6.5675e-02],
                [-8.0232e-02, -2.6468e-02, -2.1136e-02,  2.1449e-02,  6.4572e-02],
                [ 2.9930e-03,  1.1987e-02,  4.8122e-03,  3.4183e-02, -7.8918e-02],
                [ 6.3749e-02, -2.5083e-02,  1.1253e-02, -4.4485e-02,  3.3380e-02],
                [ 5.0096e-03, -1.7321e-02,  8.0185e-02, -2.3853e-02, -2.9333e-03]],

               [[ 2.6648e-02, -7.6799e-02,  3.2204e-03, -7.7476e-02, -4.4615e-03],
                [ 5.7110e-02,  7.8575e-02,  5.3204e-02, -7.8592e-02,  4.1383e-03],
                [ 1.6194e-02,  2.5400e-02,  7.4070e-02, -3.9092e-03, -2.9417e-02],
                [-7.9407e-02,  2.5042e-02, -3.8854e-02,  2.8143e-02,  2.8485e-03],
                [-3.3828e-02, -7.5645e-02,  7.8511e-02, -4.4048e-02,  6.0887e-02]],

               [[-6.4552e-02, -3.1646e-02,  6.5499e-02, -6.8577e-02, -5.1529e-02],
                [ 6.1176e-02, -4.8461e-02,  4.7687e-02, -3.0069e-02, -1.7665e-02],
                [ 7.7632e-02, -1.7017e-02, -6.2812e-02, -1.8810e-02, -4.1500e-02],
                [ 6.1360e-02, -1.9826e-02, -6.4593e-02,  3.5071e-02, -5.9178e-02],
                [-6.6739e-02,  2.6098e-02, -5.5998e-02,  8.1334e-02,  3.7472e-02]],

               [[-5.5207e-02,  1.4355e-02, -2.2037e-02, -2.4025e-02,  7.2631e-02],
                [-1.0448e-02,  1.9105e-03, -5.5223e-02,  4.6377e-02, -6.8534e-02],
                [-2.4292e-02,  7.5258e-02, -8.0224e-02, -6.6001e-02, -4.6628e-02],
                [ 4.5334e-02, -2.3274e-02, -4.3572e-02,  4.3487e-03, -4.6057e-02],
                [-5.3757e-02, -2.0336e-02, -5.2245e-02,  2.2213e-02, -6.7578e-03]],

               [[ 5.7154e-02,  6.9033e-02, -2.7450e-02, -5.9039e-02,  3.0233e-02],
                [ 5.5904e-02,  5.2798e-02, -2.2586e-02,  2.8411e-02, -6.8010e-03],
                [ 5.1257e-02, -4.3710e-02,  8.7161e-03,  1.9411e-02, -3.5285e-03],
                [-8.0450e-02,  6.1012e-02, -7.7756e-02, -2.1472e-02,  4.7537e-02],
                [-4.7231e-02,  3.7300e-02,  2.7754e-02, -2.4025e-02,  1.0065e-02]]]],
             requires_grad=True)
       tensor: tensor([[[[-9.3244e-02,  7.2930e-02, -1.4004e-02,  4.0448e-02,  4.7394e-02],
                [ 3.4764e-03,  3.3045e-03, -6.4835e-02,  4.5686e-02,  9.0246e-02],
                [ 2.8020e-02,  1.8739e-02, -5.1935e-02, -3.1779e-02, -6.5367e-02],
                [-1.0352e-01,  6.2800e-02,  3.3458e-02,  1.6324e-01,  7.5646e-02],
                [-1.4166e-02, -7.0325e-02, -2.7509e-02, -1.0614e-02, -7.1825e-02]],

               [[-1.2719e-01,  8.5667e-02,  3.7926e-02,  3.3971e-04, -1.0475e-02],
                [ 7.9021e-02,  3.0413e-03, -1.6426e-01,  1.3341e-02, -2.5737e-02],
                [ 2.6697e-02, -1.3385e-01,  6.9021e-03,  3.0326e-02,  6.2550e-02],
                [-1.3212e-02, -6.8929e-02,  1.4790e-01, -2.7115e-02,  2.0803e-02],
                [ 7.5669e-02,  1.1428e-01, -5.9129e-02,  7.8492e-02, -3.0813e-02]],

               [[-6.6693e-02, -1.3390e-01,  3.4061e-02, -1.3415e-02,  6.6936e-03],
                [ 8.5119e-02, -8.8495e-02,  5.1307e-03,  5.9909e-02, -7.1546e-02],
                [-3.1164e-02,  2.3079e-02,  4.8668e-02,  2.0298e-02, -1.6369e-02],
                [-1.4313e-01, -5.0021e-02, -5.2223e-02,  1.3387e-02, -1.0466e-01],
                [ 5.5770e-02,  1.9499e-02, -1.1884e-03,  1.0549e-02,  7.3979e-02]],

               [[-1.7652e-03, -8.5820e-03,  3.8955e-02, -5.1522e-04,  3.0524e-02],
                [-7.6496e-02, -1.0952e-01,  1.2580e-02, -1.0383e-01,  1.3392e-02],
                [ 5.0957e-02,  7.2177e-02,  6.7416e-02,  4.0250e-02, -6.1816e-02],
                [ 7.9146e-02, -2.6732e-02,  3.3583e-02, -5.8278e-02,  1.7562e-01],
                [-7.7750e-03,  1.5841e-02, -1.4416e-01,  7.1971e-02, -3.2296e-02]],

               [[ 3.1765e-02,  7.7612e-03, -9.2562e-03, -7.5009e-02, -7.6771e-02],
                [ 5.2192e-02,  8.9829e-02,  8.0666e-02, -7.5012e-02,  8.2851e-02],
                [-2.9721e-02,  5.4033e-02, -5.8591e-02, -8.4672e-02, -3.5956e-02],
                [ 4.6556e-03, -8.2645e-02, -2.9310e-02,  2.2807e-03,  6.9149e-02],
                [ 5.1460e-03, -5.1742e-02, -9.3557e-02, -2.8508e-03,  6.1673e-02]],

               [[-7.7168e-02,  7.8402e-02, -1.1065e-01,  3.5423e-02,  2.2741e-02],
                [-3.0649e-02,  5.4842e-02, -5.2164e-02, -7.5771e-02, -3.3854e-02],
                [-5.6595e-03, -2.0763e-02,  5.1015e-02,  6.8125e-04, -4.5188e-02],
                [-6.0728e-02,  1.5548e-02,  1.1190e-02, -3.3515e-02,  9.3201e-02],
                [-8.1993e-02,  5.8329e-02,  4.3559e-02, -1.3500e-01, -6.2694e-02]]],


              [[[-2.5201e-03,  1.0502e-01, -1.6741e-02,  7.6163e-02, -1.7975e-02],
                [ 7.7132e-02, -4.8635e-02,  2.0458e-02, -1.4585e-01,  2.5789e-02],
                [-2.3901e-02, -3.0057e-02,  1.0268e-01, -9.8991e-03, -1.2776e-02],
                [-3.0604e-02, -1.9044e-02, -1.0034e-01, -3.4933e-03, -1.0588e-01],
                [-4.3267e-02,  4.4136e-02, -1.4653e-02,  8.0375e-02, -4.3509e-03]],

               [[-7.2069e-02,  7.5176e-02,  3.5200e-02, -1.4117e-02,  1.4985e-01],
                [ 9.3505e-02, -1.5317e-02, -2.7400e-02,  7.2958e-02, -3.5223e-02],
                [ 1.0651e-01,  4.8038e-03, -4.8003e-03, -2.4901e-02,  5.0562e-02],
                [ 8.1387e-02, -1.0440e-01,  9.1612e-03, -1.0838e-01,  3.4357e-02],
                [ 3.7520e-02,  1.0998e-02, -1.2370e-01, -1.4700e-01,  2.2564e-02]],

               [[-2.6191e-02,  5.5962e-03,  5.2655e-02,  3.7240e-02, -9.0327e-04],
                [ 5.6995e-03, -5.7081e-03,  2.0693e-02,  1.2346e-02,  4.0957e-03],
                [ 9.9342e-02, -9.1494e-02, -7.2406e-02, -2.3302e-03, -4.5766e-02],
                [-9.5160e-02, -4.8321e-02,  5.4787e-02,  7.6552e-02,  6.8904e-02],
                [ 8.1214e-02,  4.4162e-02,  3.9048e-02,  2.1962e-01,  7.6550e-02]],

               [[ 4.3722e-02,  1.4723e-01,  9.8580e-02,  5.3083e-02,  1.7335e-02],
                [ 1.6842e-02,  9.2504e-02, -1.2577e-01,  4.3781e-02,  4.5560e-02],
                [-2.2627e-03,  4.4276e-02, -3.1316e-02, -3.0269e-03,  1.1357e-01],
                [ 1.1986e-01, -5.7009e-02, -1.8555e-02,  3.8186e-03,  1.8096e-01],
                [-2.2767e-02,  1.2045e-01, -2.4305e-02,  5.8296e-02,  9.8688e-02]],

               [[ 1.5614e-02, -5.0428e-02,  9.2436e-02,  1.5314e-02,  3.8010e-02],
                [-3.0040e-02, -4.5455e-02, -8.2217e-02, -4.5022e-03,  1.6350e-02],
                [-3.2035e-02, -7.1502e-02,  6.1119e-02,  1.0737e-02,  1.0597e-01],
                [ 1.1230e-01,  2.2913e-03,  2.2935e-02,  4.4280e-02,  4.6554e-02],
                [ 6.3816e-02,  5.0896e-02, -8.1320e-03, -5.9140e-02, -1.3816e-02]],

               [[ 2.7798e-02,  2.0180e-02, -6.8479e-02, -3.8563e-02, -2.1970e-02],
                [-1.1560e-02,  1.0920e-01,  4.0527e-02,  1.5456e-02, -3.0192e-02],
                [-3.4818e-02,  2.9930e-02,  3.8703e-02,  6.9933e-02,  1.3544e-01],
                [-1.1075e-02,  1.0404e-01,  2.5879e-03, -3.4677e-02,  1.3333e-01],
                [-9.6145e-02,  4.9468e-02,  2.3339e-02, -2.3944e-02,  2.2928e-02]]],


              [[[ 1.6603e-03, -7.8303e-02,  1.2174e-01,  1.7890e-02,  4.8606e-02],
                [-5.6957e-03, -1.3151e-01, -5.8330e-02,  2.0716e-02, -1.2664e-01],
                [ 1.0507e-01,  7.2393e-02,  3.5766e-02, -4.7793e-03, -1.0268e-02],
                [ 2.1223e-03,  2.5655e-02,  1.9555e-02,  1.4210e-01,  1.0301e-02],
                [-1.4564e-01, -1.8514e-02, -4.7687e-02,  1.5104e-01, -3.0745e-02]],

               [[-3.2338e-04,  2.1239e-02, -5.0826e-02,  6.4949e-02,  5.0461e-04],
                [-8.2764e-02, -1.5075e-02,  1.5609e-02, -1.3402e-01, -1.1833e-02],
                [ 4.4014e-02, -5.9378e-03, -1.2508e-01, -7.1479e-03,  6.9006e-02],
                [-8.8556e-03, -7.4991e-02,  1.4296e-02, -6.0615e-02,  3.2985e-02],
                [ 1.5668e-01, -8.7098e-02, -5.0354e-02, -1.2569e-01, -1.5741e-01]],

               [[-1.0422e-01,  1.4817e-01, -9.1310e-02,  5.6767e-02, -1.2098e-02],
                [-2.7485e-02,  1.3706e-01, -3.2490e-03, -6.4062e-02,  1.3419e-01],
                [ 1.0312e-02, -8.3901e-02, -1.7662e-02, -1.5195e-01, -8.3465e-02],
                [-5.6543e-04, -4.3555e-02, -5.6245e-02,  7.3568e-02,  2.1678e-02],
                [ 7.6391e-03, -4.8854e-03,  5.4257e-02,  7.8329e-03,  1.4124e-01]],

               [[-4.9119e-02, -5.2545e-02,  2.3897e-02,  3.6596e-04,  4.2827e-02],
                [ 2.3076e-02, -1.8356e-02,  2.9836e-02,  1.2485e-01,  1.2366e-01],
                [-7.8947e-02,  1.4472e-01, -4.3746e-02,  1.0565e-02,  4.6012e-02],
                [-5.7214e-02, -4.1273e-03, -5.6967e-02, -4.2106e-03,  7.4851e-02],
                [-9.0327e-02,  8.1366e-03, -5.6940e-02,  8.7756e-02, -1.0730e-01]],

               [[-2.5770e-02, -1.0837e-01, -6.5197e-02,  1.0767e-01, -1.8528e-02],
                [ 3.6587e-02, -1.0597e-01, -6.8699e-02,  2.4231e-02, -1.2268e-01],
                [ 1.2846e-01, -1.1661e-02,  5.9071e-02,  3.8168e-02,  8.9368e-02],
                [ 7.8908e-02,  3.6826e-02, -1.2949e-02, -5.2112e-02,  6.9411e-02],
                [-1.5190e-02, -6.7784e-02,  4.7824e-02, -4.4707e-02, -1.0409e-01]],

               [[ 6.1472e-02, -8.7888e-02,  9.1761e-02, -9.8433e-02, -1.2520e-02],
                [-3.4953e-02, -1.2213e-01,  5.5129e-02, -2.4383e-03, -3.0046e-02],
                [-2.7596e-02,  5.2428e-02,  1.3167e-01, -3.1463e-02,  1.8369e-02],
                [ 7.4151e-03,  9.0555e-02, -1.2976e-02,  4.0466e-02,  1.1523e-01],
                [ 1.1646e-01, -2.0131e-02,  7.8028e-02, -6.9566e-02, -5.8961e-04]]],


              ...,


              [[[ 1.8656e-02,  3.4279e-02, -1.0672e-01, -1.0023e-02,  2.2085e-02],
                [-4.6725e-02,  7.7616e-02, -1.4231e-02,  4.5028e-02, -2.9861e-02],
                [ 9.5183e-03, -1.7526e-03,  5.1253e-02,  7.1305e-02,  2.6603e-02],
                [-3.0437e-02, -7.2302e-02, -1.0535e-01,  1.4485e-02,  6.3855e-02],
                [-5.4683e-03,  5.6009e-02,  7.2122e-03, -1.0415e-01,  3.5188e-02]],

               [[ 1.2157e-01,  6.7376e-02,  2.1207e-02, -1.0825e-01, -1.5351e-01],
                [-7.0602e-02, -9.9218e-03,  3.5316e-02, -1.9505e-02, -8.3313e-02],
                [-1.7510e-02,  9.2322e-02, -6.4570e-02, -8.1725e-02,  1.3421e-02],
                [ 2.8115e-02, -4.7646e-03,  8.4901e-02,  3.6822e-02, -4.3886e-02],
                [-1.4940e-02, -7.2755e-02,  1.3097e-01,  4.6871e-02,  5.6414e-03]],

               [[-1.4081e-01,  4.8099e-02, -1.5321e-01,  3.9479e-02, -7.0164e-02],
                [-7.3554e-02, -1.1309e-02, -6.2749e-02,  1.2635e-01, -4.7698e-02],
                [ 1.8018e-02,  1.0312e-01, -4.4310e-02,  7.4449e-02,  2.4631e-02],
                [ 3.6606e-02, -2.5128e-02,  3.0890e-02,  3.3290e-03,  3.1577e-02],
                [-1.8735e-02, -1.2198e-01, -1.8037e-02, -9.1278e-02, -1.3295e-01]],

               [[-2.3376e-02,  6.9800e-02,  2.6887e-02, -1.7161e-02, -8.7163e-02],
                [ 1.5594e-02,  5.3034e-02, -1.1990e-01,  7.9062e-02,  1.0965e-01],
                [ 1.7833e-02,  8.8824e-02,  3.1692e-02,  8.7611e-02,  6.3664e-02],
                [-1.8869e-03, -1.1539e-01, -4.1857e-02, -8.9586e-02, -7.2488e-02],
                [ 1.9608e-02,  1.3425e-01, -8.0856e-02, -8.3590e-02, -4.5203e-02]],

               [[ 1.0801e-02,  1.5272e-02, -1.4688e-01, -3.7206e-02,  5.1565e-02],
                [-2.0180e-02, -5.6527e-03, -8.9918e-02, -7.9719e-02,  2.5003e-02],
                [ 4.3385e-02, -4.1702e-03,  6.6365e-02, -8.6634e-03, -1.1395e-01],
                [ 1.2979e-02,  7.9360e-02, -3.8312e-02, -5.8354e-02, -7.3427e-02],
                [ 9.6145e-03, -1.4716e-02, -1.0418e-02, -4.8538e-02, -4.1533e-02]],

               [[ 4.0994e-02, -6.5618e-03, -3.9249e-02,  1.4940e-01,  6.1836e-02],
                [ 1.4003e-01,  1.0897e-01, -9.9539e-02,  7.6101e-02,  4.7514e-02],
                [ 1.8425e-02, -4.1603e-02, -3.0696e-02,  3.7560e-02, -5.6115e-02],
                [ 1.1114e-01, -1.9242e-04,  1.1626e-02,  5.5963e-02, -4.1088e-02],
                [ 1.1746e-01, -1.4505e-02, -3.8933e-02, -9.7828e-03, -4.2146e-02]]],


              [[[ 2.7058e-02, -5.4692e-02, -1.6676e-02,  1.3220e-01,  6.2700e-02],
                [ 3.9145e-02, -6.8409e-02, -7.5723e-03,  6.6492e-03,  3.2803e-02],
                [-5.9434e-02, -1.6559e-02, -6.5511e-02, -1.8650e-02,  4.8143e-02],
                [-7.7847e-03,  9.0031e-02, -6.6468e-02, -5.4535e-02, -1.9140e-02],
                [-9.0769e-02, -1.1950e-01,  6.6162e-02,  7.2317e-02, -2.5983e-02]],

               [[-1.8429e-02,  3.7081e-02,  3.3936e-03,  8.1601e-02,  7.5250e-02],
                [-1.6394e-01, -1.9931e-02,  9.0637e-02, -1.5919e-01,  6.6636e-02],
                [ 3.0836e-02, -4.4388e-02, -1.4237e-01,  6.4615e-02, -1.9187e-02],
                [ 2.0550e-01,  5.2410e-02, -8.6873e-02,  9.8989e-03,  1.0537e-02],
                [ 9.6366e-02, -4.6873e-02,  2.8437e-02,  9.9133e-02, -1.2426e-02]],

               [[ 4.3380e-02, -1.6536e-02, -1.1810e-01, -1.5150e-02, -3.4028e-02],
                [-5.2022e-02, -4.7191e-02,  3.4113e-02,  2.2352e-02, -1.6375e-02],
                [-2.4627e-03,  6.9880e-02,  3.8121e-02, -1.2208e-01,  4.7194e-02],
                [ 3.0015e-02,  8.7410e-03, -2.2237e-02,  7.8091e-02,  7.9454e-02],
                [-5.1503e-02,  2.3619e-02,  8.4588e-02, -1.0277e-01,  8.3154e-03]],

               [[-6.7868e-02, -1.2149e-01, -1.7491e-02,  1.0022e-01, -4.3358e-02],
                [ 3.1464e-02,  7.4058e-02, -3.2289e-03, -2.6757e-02, -8.5261e-02],
                [ 6.9376e-02,  1.2054e-01,  3.5335e-02,  9.2568e-02,  2.1358e-02],
                [-4.7479e-02,  2.5571e-02, -4.7574e-02,  1.1774e-02, -2.6935e-02],
                [-4.9366e-02,  1.1098e-02, -4.2080e-02,  8.9565e-02,  4.6191e-03]],

               [[-1.7267e-02,  5.0607e-02, -4.9050e-02, -1.0601e-02, -5.8215e-02],
                [-1.5839e-02, -2.6607e-02, -1.8453e-02, -7.0404e-02, -1.2247e-01],
                [ 1.5462e-01,  4.8792e-03, -1.8811e-01,  5.1346e-03,  3.5262e-02],
                [ 6.8772e-02,  6.6209e-02, -8.8063e-02,  1.0676e-01,  1.2461e-01],
                [ 4.1533e-03, -4.0087e-02,  6.4295e-02, -8.9270e-03, -4.2519e-02]],

               [[ 4.6680e-02, -2.2108e-02,  4.5739e-02, -6.1995e-02, -7.3645e-02],
                [ 1.2634e-01,  5.2019e-02, -1.0544e-01, -1.1378e-02,  4.0753e-03],
                [ 7.9469e-02,  3.9514e-02, -1.6562e-02,  3.0074e-02,  3.8365e-02],
                [-1.0846e-01, -1.3977e-01, -4.5255e-03, -7.7981e-03, -2.7929e-02],
                [ 1.2333e-02, -7.6110e-02, -1.5976e-02,  2.1001e-02, -2.0793e-02]]],


              [[[ 3.5135e-02,  3.3776e-02, -2.1454e-02,  4.0599e-02,  7.6321e-02],
                [-2.2444e-02,  5.9005e-02, -9.7919e-02, -5.0292e-02, -2.5280e-02],
                [ 8.2182e-02,  1.1689e-02, -4.0238e-02, -1.5465e-02, -1.0631e-01],
                [-7.9225e-02,  3.7144e-02, -2.1691e-02,  7.6563e-02, -2.9466e-02],
                [-6.8622e-02, -8.2420e-02,  1.0508e-01,  5.7378e-02, -5.2815e-02]],

               [[-8.5097e-02,  8.7198e-02,  1.3588e-01, -1.0821e-02,  6.6252e-02],
                [-8.5424e-02, -1.2942e-01, -3.9451e-02,  2.3167e-03,  9.8789e-02],
                [ 4.4885e-02,  7.8130e-02,  6.4348e-02,  1.4060e-01, -6.9407e-02],
                [ 6.7809e-02, -1.0070e-01,  8.3850e-03,  2.3277e-02,  7.3318e-02],
                [ 1.9004e-02,  3.3662e-04,  1.2365e-01,  3.7783e-02, -4.4722e-02]],

               [[-3.6487e-03, -3.5660e-02, -3.1783e-03, -1.0376e-01, -2.5887e-02],
                [ 8.2903e-02,  1.1622e-01,  7.4534e-02, -6.7094e-02,  2.8624e-02],
                [ 5.6662e-03,  7.5438e-02,  2.3080e-02, -2.7497e-02, -4.4862e-02],
                [-1.0469e-01, -1.1870e-03, -1.0996e-01,  1.3349e-01,  4.3062e-02],
                [-5.9718e-02, -1.3859e-02, -5.2602e-03,  4.7957e-02,  6.4259e-02]],

               [[ 1.6269e-02, -5.9308e-02,  4.2326e-02, -1.7657e-01, -3.3243e-02],
                [ 2.8233e-02, -3.1127e-02,  2.9193e-02, -9.8563e-02, -2.3034e-02],
                [ 6.9031e-02, -3.6136e-02, -6.7969e-02, -4.8913e-02, -5.7119e-02],
                [ 1.7338e-02, -1.3542e-02, -8.9504e-02,  7.9912e-02, -6.6703e-02],
                [-4.8714e-02, -1.3528e-02, -9.2340e-02,  1.3472e-01,  1.0488e-01]],

               [[-8.7143e-02,  4.2698e-02, -6.4632e-02,  1.7840e-02,  5.3781e-02],
                [ 1.1379e-02,  3.3266e-02, -1.4602e-02,  7.8038e-02, -1.2929e-01],
                [-1.1283e-01,  9.1412e-02, -1.4805e-02, -2.7218e-02, -1.6207e-01],
                [-4.6992e-04, -6.5583e-02, -3.9846e-02, -5.5303e-02, -8.9434e-02],
                [ 1.8681e-03, -7.4996e-02,  5.9021e-02,  2.6332e-02,  2.7155e-03]],

               [[ 4.9158e-02,  7.0112e-02, -1.1107e-01, -1.1723e-01,  8.6018e-02],
                [ 1.0823e-02,  6.9816e-02, -3.1992e-02,  4.0998e-02, -8.9671e-02],
                [ 2.2780e-02, -3.4835e-02,  5.5046e-03,  4.8949e-02, -4.9500e-02],
                [-6.3235e-02, -4.4342e-02, -2.1533e-01, -3.0609e-02,  4.5277e-02],
                [-5.5601e-02, -3.0792e-02,  4.2926e-02,  2.9293e-02, -2.3960e-02]]]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([ 0.0215, -0.0800, -0.0787, -0.0173, -0.0345,  0.0684,  0.0584, -0.0804,
               0.0098, -0.0490, -0.0535,  0.0145,  0.0056,  0.0082, -0.0256,  0.0140],
             requires_grad=True)
       tensor: tensor([ 0.0483, -0.0138, -0.0666, -0.0158,  0.0085,  0.0382,  0.0643, -0.0354,
               0.1643,  0.0064, -0.0694,  0.0054,  0.0250,  0.0626, -0.1082,  0.0414],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[[[-0., 0., 0., 0., 0.],
                [-0., 0., -0., 0., 0.],
                [0., -0., -0., -0., -0.],
                [-0., 0., 0., 0., 0.],
                [0., -0., 0., -0., -0.]],

               [[-0., -0., 0., 0., -0.],
                [0., 0., -0., 0., -0.],
                [0., -0., 0., -0., 0.],
                [-0., -0., 0., 0., 0.],
                [0., 0., 0., 0., -0.]],

               [[-0., -0., 0., 0., 0.],
                [0., -0., 0., -0., -0.],
                [0., 0., 0., -0., -0.],
                [-0., -0., -0., -0., -0.],
                [0., 0., -0., 0., 0.]],

               [[-0., -0., 0., -0., 0.],
                [-0., -0., -0., -0., 0.],
                [0., 0., 0., 0., -0.],
                [-0., -0., -0., 0., 0.],
                [-0., 0., -0., 0., -0.]],

               [[0., 0., -0., -0., -0.],
                [0., 0., 0., -0., 0.],
                [-0., 0., 0., -0., 0.],
                [-0., -0., -0., 0., 0.],
                [0., 0., -0., -0., -0.]],

               [[-0., -0., -0., 0., -0.],
                [-0., 0., -0., -0., -0.],
                [0., -0., 0., -0., 0.],
                [0., 0., -0., -0., 0.],
                [-0., -0., 0., -0., -0.]]],


              [[[-0., 0., -0., 0., -0.],
                [0., -0., -0., -0., 0.],
                [-0., 0., 0., 0., -0.],
                [-0., 0., 0., 0., -0.],
                [-0., 0., 0., 0., -0.]],

               [[-0., 0., 0., -0., 0.],
                [0., -0., 0., 0., -0.],
                [0., 0., -0., -0., 0.],
                [0., -0., -0., -0., 0.],
                [0., -0., -0., -0., 0.]],

               [[0., -0., 0., 0., 0.],
                [0., -0., 0., 0., 0.],
                [0., -0., -0., 0., -0.],
                [-0., -0., 0., 0., 0.],
                [0., 0., 0., 0., 0.]],

               [[0., 0., 0., -0., 0.],
                [-0., 0., -0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [-0., 0., -0., 0., 0.]],

               [[0., -0., 0., 0., -0.],
                [-0., 0., -0., 0., -0.],
                [-0., 0., -0., 0., 0.],
                [0., -0., 0., 0., 0.],
                [0., 0., -0., 0., 0.]],

               [[-0., -0., -0., -0., -0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., -0., -0., 0.],
                [-0., -0., 0., -0., 0.]]],


              [[[-0., -0., 0., 0., -0.],
                [0., -0., -0., 0., -0.],
                [0., 0., -0., -0., 0.],
                [0., 0., -0., 0., 0.],
                [-0., -0., -0., 0., 0.]],

               [[-0., -0., -0., 0., -0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., -0., -0., -0.],
                [-0., 0., 0., -0., -0.],
                [0., -0., -0., -0., -0.]],

               [[-0., 0., -0., 0., -0.],
                [-0., 0., -0., -0., 0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., -0., 0., -0.],
                [-0., -0., 0., 0., 0.]],

               [[0., -0., 0., 0., 0.],
                [0., -0., 0., 0., 0.],
                [-0., 0., -0., 0., 0.],
                [-0., -0., -0., 0., 0.],
                [-0., 0., -0., 0., 0.]],

               [[-0., -0., -0., 0., -0.],
                [-0., -0., -0., 0., -0.],
                [0., 0., 0., -0., 0.],
                [0., 0., -0., 0., 0.],
                [-0., -0., 0., -0., -0.]],

               [[0., -0., 0., -0., 0.],
                [0., -0., 0., 0., 0.],
                [-0., 0., -0., -0., -0.],
                [0., 0., 0., 0., 0.],
                [0., -0., 0., -0., 0.]]],


              ...,


              [[[0., 0., -0., 0., -0.],
                [-0., 0., -0., 0., -0.],
                [0., -0., -0., 0., 0.],
                [0., -0., -0., 0., 0.],
                [0., 0., 0., -0., -0.]],

               [[0., 0., -0., -0., -0.],
                [-0., -0., -0., -0., -0.],
                [-0., 0., -0., -0., 0.],
                [0., 0., 0., 0., 0.],
                [0., -0., 0., 0., 0.]],

               [[-0., -0., -0., -0., -0.],
                [-0., -0., -0., 0., 0.],
                [-0., 0., -0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [0., -0., -0., -0., -0.]],

               [[0., 0., 0., 0., -0.],
                [0., 0., -0., -0., 0.],
                [0., 0., 0., 0., 0.],
                [0., -0., -0., -0., -0.],
                [-0., 0., -0., -0., 0.]],

               [[-0., -0., -0., 0., 0.],
                [-0., 0., 0., -0., 0.],
                [-0., -0., -0., 0., -0.],
                [-0., 0., -0., -0., -0.],
                [0., 0., -0., 0., -0.]],

               [[-0., 0., -0., 0., -0.],
                [0., 0., -0., 0., 0.],
                [-0., -0., -0., 0., -0.],
                [0., 0., 0., 0., 0.],
                [0., 0., -0., 0., -0.]]],


              [[[-0., -0., 0., 0., 0.],
                [0., -0., -0., -0., 0.],
                [-0., -0., -0., 0., 0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., 0., 0., -0.]],

               [[-0., 0., 0., 0., 0.],
                [-0., 0., 0., -0., 0.],
                [-0., -0., -0., 0., -0.],
                [0., -0., -0., -0., 0.],
                [0., -0., -0., 0., -0.]],

               [[0., -0., -0., -0., -0.],
                [-0., -0., -0., -0., -0.],
                [-0., 0., 0., -0., 0.],
                [-0., 0., 0., -0., 0.],
                [0., 0., 0., -0., -0.]],

               [[-0., -0., -0., -0., -0.],
                [0., 0., 0., -0., -0.],
                [-0., 0., 0., 0., -0.],
                [-0., 0., -0., 0., 0.],
                [-0., 0., -0., 0., -0.]],

               [[0., 0., -0., -0., -0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., -0., -0., 0.],
                [0., 0., -0., 0., 0.],
                [-0., 0., 0., 0., -0.]],

               [[0., -0., 0., -0., -0.],
                [0., 0., -0., -0., -0.],
                [0., 0., -0., 0., 0.],
                [-0., 0., -0., 0., 0.],
                [-0., -0., -0., -0., -0.]]],


              [[[-0., 0., 0., 0., 0.],
                [-0., 0., -0., -0., -0.],
                [0., 0., -0., -0., -0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., 0., 0., -0.]],

               [[-0., -0., 0., -0., 0.],
                [-0., -0., -0., 0., 0.],
                [0., 0., 0., 0., -0.],
                [0., -0., 0., -0., 0.],
                [0., -0., 0., -0., -0.]],

               [[0., -0., 0., -0., -0.],
                [0., 0., 0., -0., 0.],
                [0., 0., 0., -0., -0.],
                [-0., 0., -0., 0., 0.],
                [-0., -0., 0., -0., 0.]],

               [[-0., -0., 0., -0., -0.],
                [0., -0., 0., -0., -0.],
                [0., -0., -0., -0., -0.],
                [0., -0., -0., 0., -0.],
                [-0., 0., -0., 0., 0.]],

               [[-0., 0., -0., -0., 0.],
                [-0., 0., -0., 0., -0.],
                [-0., 0., -0., -0., -0.],
                [0., -0., -0., 0., -0.],
                [-0., -0., -0., 0., -0.]],

               [[0., 0., -0., -0., 0.],
                [0., 0., -0., 0., -0.],
                [0., -0., 0., 0., -0.],
                [-0., 0., -0., -0., 0.],
                [-0., 0., 0., -0., 0.]]]])
       scale: tensor([[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              ...,


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[[[-2.9409e-02,  6.0003e-02,  5.3245e-02,  2.8255e-02,  5.1865e-02],
                [-1.3120e-03,  6.9924e-02, -2.1958e-02,  4.5196e-02,  5.5987e-02],
                [ 7.2948e-02, -1.2874e-02, -1.4010e-02, -7.2656e-02, -3.8272e-02],
                [-6.1391e-02,  2.6227e-02,  2.2223e-02,  5.1530e-02,  7.8572e-02],
                [ 1.1986e-02, -3.5154e-02,  2.2590e-02, -2.8268e-02, -4.1092e-02]],

               [[-7.7353e-02, -3.3001e-03,  2.2446e-02,  3.6329e-02, -8.4312e-03],
                [ 5.9898e-02,  3.3921e-02, -4.0756e-02,  5.6990e-02, -7.9915e-02],
                [ 1.1058e-02, -7.5468e-02,  3.1362e-02, -1.2730e-02,  7.0498e-02],
                [-6.7726e-03, -3.9453e-02,  6.8676e-02,  2.8066e-02,  4.1992e-02],
                [ 8.0445e-02,  6.9594e-02,  2.9287e-02,  4.4932e-02, -7.2432e-02]],

               [[-6.8457e-02, -7.5792e-02,  3.8843e-02,  3.2058e-02,  2.1625e-02],
                [ 7.5601e-02, -5.7455e-02,  1.2939e-02, -3.5165e-02, -5.6556e-02],
                [ 8.3346e-03,  2.2498e-02,  6.8399e-03, -2.3131e-02, -2.1235e-03],
                [-6.8738e-02, -6.5464e-02, -4.5512e-02, -2.1878e-02, -6.3732e-02],
                [ 4.4843e-02,  1.4243e-02, -7.8475e-02,  2.0007e-02,  6.8106e-02]],

               [[-7.1642e-03, -4.0418e-02,  1.6455e-03, -7.4808e-03,  3.0932e-02],
                [-5.6909e-02, -6.4950e-02, -4.3916e-02, -6.4650e-02,  7.4918e-02],
                [ 2.7178e-03,  7.6559e-02,  2.3546e-02,  1.0888e-02, -3.2943e-02],
                [-3.1639e-02, -4.4183e-02, -1.1435e-02,  1.8024e-02,  7.9902e-02],
                [-1.6372e-02,  6.6337e-03, -5.5328e-02,  1.6804e-02, -1.1563e-02]],

               [[ 3.4681e-02,  3.3633e-02, -5.8334e-02, -4.5784e-02, -2.0550e-02],
                [ 8.0547e-02,  4.9685e-02,  7.5298e-02, -2.9211e-02,  2.2224e-02],
                [-2.4409e-02,  7.0003e-02,  2.1558e-04, -4.4977e-02,  2.3133e-02],
                [-7.4964e-02, -6.4867e-02, -4.5671e-02,  3.1878e-02,  3.5104e-02],
                [ 2.2541e-02,  3.2949e-02, -3.5216e-02, -4.0980e-02, -1.9989e-02]],

               [[-7.0267e-02, -2.8274e-03, -3.7755e-02,  3.8392e-02, -1.3160e-05],
                [-2.4893e-02,  2.6938e-02, -6.6667e-02, -6.4968e-02, -7.5396e-02],
                [ 4.5711e-02, -3.3377e-02,  2.4278e-02, -4.1101e-03,  1.8952e-02],
                [ 2.0236e-02,  6.4522e-03, -5.7711e-02, -4.5909e-02,  7.4404e-02],
                [-3.7771e-02, -3.1536e-03,  6.4748e-02, -4.4604e-02, -2.2752e-02]]],


              [[[-2.9177e-02,  5.0791e-02, -5.5935e-03,  4.7169e-02, -1.1366e-02],
                [ 2.2291e-02, -2.1153e-02, -1.4060e-02, -5.3690e-02,  6.6127e-02],
                [-1.1869e-03,  7.9363e-03,  6.0124e-02,  8.7514e-04, -2.6833e-02],
                [-4.4174e-02,  1.9975e-02,  2.7939e-02,  2.7791e-02, -1.7919e-02],
                [-2.9191e-02,  2.5500e-02,  3.6869e-03,  3.1063e-02, -3.7672e-02]],

               [[-6.7084e-02,  5.3586e-02,  1.8485e-02, -7.8099e-03,  4.6288e-02],
                [ 4.9055e-02, -1.4634e-02,  3.2799e-02,  3.9726e-02, -7.0032e-02],
                [ 7.0897e-02,  3.4370e-02, -1.4814e-02, -3.9030e-02,  3.0867e-02],
                [ 3.5541e-02, -7.2574e-02, -1.5650e-03, -8.1162e-02,  2.6245e-02],
                [ 5.5721e-02, -2.2033e-02, -7.2623e-02, -5.1459e-02,  3.3599e-02]],

               [[ 1.1310e-02, -2.9816e-02,  6.3727e-02,  3.9850e-02,  1.3761e-02],
                [ 3.0453e-02, -4.8504e-02,  5.3189e-02,  1.9425e-02,  4.7484e-02],
                [ 6.1376e-02, -7.8290e-02, -6.8859e-02,  1.8497e-02, -1.1496e-02],
                [-7.8178e-02, -4.5904e-02,  7.3181e-02,  2.9441e-02,  4.6967e-02],
                [ 7.6978e-02,  7.2934e-02,  5.6798e-02,  5.8828e-02,  4.4637e-02]],

               [[ 6.0281e-02,  7.7289e-02,  7.9016e-02, -4.1437e-02,  3.1101e-02],
                [-5.0620e-02,  3.3108e-02, -5.8687e-02,  2.7694e-02,  5.4294e-02],
                [ 2.1156e-02,  1.7004e-03,  2.4742e-02,  6.9593e-02,  5.7699e-02],
                [ 6.8876e-02,  3.2239e-02,  3.3322e-02,  2.9973e-02,  7.4267e-02],
                [-2.4736e-03,  3.8454e-02, -2.5898e-02,  2.0443e-02,  6.0816e-02]],

               [[ 1.6487e-02, -6.8994e-03,  5.2835e-02,  5.7784e-02, -1.7036e-02],
                [-7.1883e-02,  3.1576e-04, -5.5839e-02,  1.4949e-02, -1.9834e-03],
                [-2.7395e-02,  3.9861e-04, -1.0588e-02,  9.9140e-03,  5.1499e-02],
                [ 7.4060e-02, -1.0655e-02,  1.1668e-02,  4.9183e-02,  5.2846e-02],
                [ 2.8634e-02,  4.2678e-02, -1.4281e-02,  1.3904e-03,  7.6289e-02]],

               [[-5.1256e-02, -2.2514e-02, -7.2964e-02, -4.4120e-02, -5.8914e-02],
                [-4.1579e-02,  2.8281e-02,  3.9429e-02,  7.5058e-03,  6.5170e-03],
                [-3.4494e-02,  7.5710e-02,  4.1078e-02,  4.4451e-02,  4.2661e-02],
                [-5.4398e-02,  5.1592e-02, -2.6367e-02, -3.2980e-02,  5.3860e-02],
                [-4.2436e-02, -8.4286e-03,  7.5331e-02, -6.6725e-02,  4.9887e-02]]],


              [[[-1.5211e-02, -6.2506e-03,  3.0621e-03,  3.0725e-02, -7.0877e-03],
                [ 1.1974e-02, -5.2611e-02, -2.7415e-02,  4.3479e-02, -4.2108e-02],
                [ 3.3816e-02,  6.1523e-02, -9.9011e-03, -3.7770e-02,  6.5915e-04],
                [ 5.3678e-03,  5.9921e-02, -3.4530e-02,  5.1942e-02,  5.3762e-02],
                [-4.7293e-02, -6.2274e-02, -7.5059e-02,  8.1645e-02,  2.1149e-02]],

               [[-1.4459e-02, -2.7155e-02, -2.5730e-02,  7.6751e-02, -1.6932e-02],
                [-5.3342e-02, -2.6885e-02,  4.3476e-02, -7.9174e-02, -3.5761e-02],
                [ 4.2970e-02,  2.5516e-02, -6.6640e-02, -2.9457e-03, -8.2757e-03],
                [-2.5080e-02,  4.1672e-02,  4.2424e-02, -4.8704e-02, -6.0434e-02],
                [ 1.2884e-02, -7.9950e-02, -7.0913e-02, -8.0863e-02, -5.4536e-02]],

               [[-5.4303e-02,  6.7885e-02, -5.3922e-02,  6.5582e-02, -5.2617e-03],
                [-8.4440e-03,  8.0911e-02, -3.8667e-02, -5.6241e-04,  7.0876e-02],
                [ 5.4673e-02,  1.3465e-02, -2.7178e-02, -3.6691e-02, -2.6519e-02],
                [-2.8238e-02, -5.0765e-02, -3.4076e-02,  3.1219e-02, -1.8919e-02],
                [-4.6076e-02, -7.6516e-02,  3.1247e-02,  4.1743e-02,  7.5575e-02]],

               [[ 3.8787e-03, -8.1731e-03,  4.7381e-03,  5.8261e-02,  4.6416e-02],
                [ 7.7171e-02, -6.5924e-02,  1.5769e-02,  2.6777e-02,  7.7365e-02],
                [-7.3126e-02,  4.7624e-02, -7.0620e-02,  7.3309e-02,  7.7585e-03],
                [-7.9208e-02, -7.3783e-03, -5.3142e-02,  3.4386e-02,  5.6230e-03],
                [-4.4492e-02,  7.5403e-02, -2.8887e-02,  2.6937e-02,  7.0698e-03]],

               [[-1.7104e-02, -6.6983e-02, -5.7655e-02,  5.1198e-02, -8.0137e-02],
                [-7.2406e-02, -3.7856e-02, -7.2086e-02,  7.2641e-02, -7.0749e-02],
                [ 6.6330e-02,  2.8797e-02,  6.2197e-02, -4.6068e-03,  1.7392e-03],
                [ 4.2384e-02,  5.9572e-02, -4.5953e-02,  6.6345e-03,  7.3979e-02],
                [-4.8313e-02, -1.3063e-02,  1.7648e-02, -5.0903e-02, -7.1852e-02]],

               [[ 3.2147e-02, -8.1585e-02,  6.7507e-02, -7.7056e-02,  1.7667e-02],
                [ 2.0535e-03, -7.4221e-02,  1.0343e-02,  4.3018e-02,  9.7351e-03],
                [-2.1410e-02,  5.4089e-02, -2.4102e-02, -4.0551e-02, -3.6118e-03],
                [ 4.9847e-02,  6.9608e-02,  3.6233e-03,  5.7025e-02,  6.3206e-02],
                [ 1.4611e-02, -2.9885e-02,  5.6140e-02, -6.4338e-02,  8.5266e-03]]],


              ...,


              [[[ 5.7051e-02,  3.4026e-02, -3.7723e-02,  1.4372e-02, -4.4266e-03],
                [-8.0557e-02,  1.1810e-02, -6.9374e-02,  3.4264e-02, -3.9068e-02],
                [ 3.2814e-02, -4.9334e-02, -3.2234e-02,  3.7901e-02,  9.9268e-03],
                [ 1.2846e-03, -5.9199e-02, -5.6303e-02,  1.2189e-03,  7.8874e-02],
                [ 7.6858e-04,  2.4341e-02,  4.0423e-02, -7.7602e-02, -3.6388e-02]],

               [[ 5.9494e-02,  4.4230e-02, -5.9128e-02, -1.6639e-02, -6.4884e-02],
                [-2.3457e-02, -7.5842e-03, -3.3986e-02, -2.0435e-02, -4.2466e-02],
                [-6.8915e-02,  1.6417e-02, -9.0384e-03, -5.6058e-02,  1.1540e-02],
                [ 3.9632e-02,  3.8881e-02,  5.5834e-02,  7.5591e-02,  1.8463e-02],
                [ 4.5034e-02, -6.4665e-02,  6.7883e-02,  7.1108e-02,  8.0694e-02]],

               [[-7.4652e-02, -2.9270e-02, -7.4301e-02, -1.4067e-02, -6.0331e-02],
                [-7.9629e-02, -2.5316e-03, -3.4649e-02,  7.9736e-02,  2.4963e-02],
                [-1.4102e-02,  3.0896e-02, -5.4594e-02,  5.7641e-02,  7.8276e-02],
                [ 3.3722e-02,  1.6397e-02,  6.6251e-02,  2.5637e-02,  4.0073e-02],
                [ 1.9370e-02, -1.4960e-02, -4.0503e-02, -3.6491e-02, -6.9970e-02]],

               [[ 3.9434e-02,  6.7049e-02,  7.1627e-02,  6.9307e-02, -5.7508e-03],
                [ 2.8151e-02,  7.9890e-02, -6.4687e-02, -6.8959e-02,  6.8179e-02],
                [ 1.2583e-02,  6.6052e-02,  6.7770e-02,  1.0853e-02,  6.3935e-02],
                [ 4.4214e-02, -5.4527e-02, -6.3199e-02, -2.4454e-02, -8.0348e-02],
                [-1.1810e-04,  6.2292e-02, -2.1831e-02, -4.1282e-02,  3.4718e-02]],

               [[-8.9495e-03, -3.5923e-02, -4.9030e-02,  1.7068e-02,  5.7835e-02],
                [-6.2950e-02,  6.9258e-02,  1.4909e-02, -3.9252e-02,  3.0917e-02],
                [-5.0831e-02, -2.6109e-02, -4.2526e-02,  4.9180e-03, -6.7907e-02],
                [-1.4867e-02,  8.3498e-03, -6.3780e-02, -6.3819e-02, -7.7414e-02],
                [ 6.5369e-02,  3.5118e-02, -3.5070e-02,  3.1514e-02, -1.7773e-02]],

               [[-1.9000e-02,  4.8772e-02, -4.0550e-02,  5.7766e-02, -4.8687e-02],
                [ 7.0112e-02,  7.4851e-02, -5.0324e-02,  4.2522e-02,  6.6367e-02],
                [-6.6793e-02, -6.3487e-02, -6.3574e-02,  7.3530e-02, -6.7062e-02],
                [ 1.9297e-02,  3.9876e-02,  7.0333e-03,  3.6541e-02,  3.0865e-02],
                [ 6.9009e-02,  2.7737e-03, -6.0400e-02,  1.5249e-03, -1.5177e-03]]],


              [[[-4.9098e-02, -1.2656e-02,  3.0326e-02,  4.6450e-02,  4.1143e-02],
                [ 6.5180e-02, -4.5543e-02, -6.0194e-02, -8.1101e-02,  7.3691e-02],
                [-5.2880e-02, -5.3283e-02, -4.6874e-02,  2.0506e-02,  1.4432e-02],
                [ 5.3466e-05,  6.1875e-02, -5.2208e-02, -2.1149e-02, -6.5709e-02],
                [-7.2209e-02, -2.8706e-02,  6.6109e-02,  5.8108e-02, -1.8114e-02]],

               [[-5.8877e-02,  3.5183e-02,  6.5460e-02,  5.2934e-02,  3.5997e-02],
                [-6.5718e-02,  2.7700e-02,  7.1110e-02, -5.7825e-02,  6.1866e-03],
                [-5.3281e-03, -7.6189e-02, -6.9421e-02,  6.4743e-02, -1.1912e-02],
                [ 7.6864e-02, -5.8819e-03, -2.0277e-02, -1.6263e-02,  4.5729e-02],
                [ 2.0473e-03, -3.1893e-02, -3.0088e-02,  6.1322e-02, -1.3287e-02]],

               [[ 4.0654e-02, -3.8251e-03, -5.8287e-02, -6.9760e-03, -4.9954e-02],
                [-3.1949e-02, -6.5679e-02, -9.8746e-04, -5.5646e-02, -1.6937e-03],
                [-5.0579e-02,  5.1921e-02,  4.0006e-02, -5.3846e-02,  3.6710e-03],
                [-5.5284e-03,  5.2453e-02,  3.5617e-02, -4.4475e-02,  2.7835e-02],
                [ 3.6465e-02,  2.2936e-02,  4.9494e-02, -6.8768e-02, -6.8512e-02]],

               [[-1.5606e-02, -5.8101e-02, -4.8349e-02, -5.4572e-03, -8.1381e-02],
                [ 3.3837e-02,  6.9886e-02,  2.5937e-03, -4.4428e-02, -6.1442e-03],
                [-3.3799e-02,  7.6725e-02,  1.5202e-02,  2.7467e-02, -7.2112e-02],
                [-5.3887e-02,  5.3134e-02, -5.5426e-02,  8.1476e-02,  1.0773e-02],
                [-1.9578e-02,  1.7628e-02, -2.2382e-02,  6.7076e-02, -1.3475e-02]],

               [[ 3.7281e-02,  2.7106e-02, -7.8289e-03, -6.1201e-02, -4.5366e-02],
                [-5.1809e-02, -1.0889e-02,  4.4019e-02, -4.0099e-02, -6.2939e-02],
                [ 7.8826e-02,  1.4336e-02, -7.8953e-02, -4.1699e-03,  2.1759e-02],
                [ 4.3422e-02,  6.1053e-02, -5.1035e-02,  2.5170e-02,  8.1194e-02],
                [-3.5907e-02,  3.5084e-02,  5.4858e-02,  5.7819e-02, -6.8527e-02]],

               [[ 6.0340e-02, -4.5873e-02,  4.5307e-02, -1.8559e-02, -5.9891e-02],
                [ 7.1101e-02,  5.7979e-03, -2.1455e-02, -5.7839e-02, -2.6964e-02],
                [ 4.5972e-02,  4.6237e-02, -1.8353e-02,  5.5372e-03,  5.8802e-02],
                [-8.0939e-02,  2.2098e-03, -2.7943e-03,  6.9556e-02,  3.5299e-03],
                [-2.4275e-02, -6.1490e-02, -2.4350e-02, -5.8685e-02, -7.6820e-02]]],


              [[[-5.8326e-02,  4.3804e-02,  5.4642e-02,  2.9479e-02,  5.5766e-02],
                [-6.2955e-02,  4.9442e-02, -1.7882e-02, -6.4492e-02, -3.5590e-02],
                [ 7.8974e-02,  1.8189e-02, -4.3076e-02, -4.6822e-02, -5.9352e-02],
                [ 1.1472e-02,  6.9467e-02, -3.5045e-02, -1.3463e-03, -7.0617e-02],
                [-5.7437e-02, -5.7150e-02,  4.9108e-02,  2.2168e-02, -5.4964e-02]],

               [[-3.2895e-02, -2.2746e-03,  6.8428e-02, -7.4781e-02,  6.5675e-02],
                [-8.0232e-02, -2.6468e-02, -2.1136e-02,  2.1449e-02,  6.4572e-02],
                [ 2.9930e-03,  1.1987e-02,  4.8122e-03,  3.4183e-02, -7.8918e-02],
                [ 6.3749e-02, -2.5083e-02,  1.1253e-02, -4.4485e-02,  3.3380e-02],
                [ 5.0096e-03, -1.7321e-02,  8.0185e-02, -2.3853e-02, -2.9333e-03]],

               [[ 2.6648e-02, -7.6799e-02,  3.2204e-03, -7.7476e-02, -4.4615e-03],
                [ 5.7110e-02,  7.8575e-02,  5.3204e-02, -7.8592e-02,  4.1383e-03],
                [ 1.6194e-02,  2.5400e-02,  7.4070e-02, -3.9092e-03, -2.9417e-02],
                [-7.9407e-02,  2.5042e-02, -3.8854e-02,  2.8143e-02,  2.8485e-03],
                [-3.3828e-02, -7.5645e-02,  7.8511e-02, -4.4048e-02,  6.0887e-02]],

               [[-6.4552e-02, -3.1646e-02,  6.5499e-02, -6.8577e-02, -5.1529e-02],
                [ 6.1176e-02, -4.8461e-02,  4.7687e-02, -3.0069e-02, -1.7665e-02],
                [ 7.7632e-02, -1.7017e-02, -6.2812e-02, -1.8810e-02, -4.1500e-02],
                [ 6.1360e-02, -1.9826e-02, -6.4593e-02,  3.5071e-02, -5.9178e-02],
                [-6.6739e-02,  2.6098e-02, -5.5998e-02,  8.1334e-02,  3.7472e-02]],

               [[-5.5207e-02,  1.4355e-02, -2.2037e-02, -2.4025e-02,  7.2631e-02],
                [-1.0448e-02,  1.9105e-03, -5.5223e-02,  4.6377e-02, -6.8534e-02],
                [-2.4292e-02,  7.5258e-02, -8.0224e-02, -6.6001e-02, -4.6628e-02],
                [ 4.5334e-02, -2.3274e-02, -4.3572e-02,  4.3487e-03, -4.6057e-02],
                [-5.3757e-02, -2.0336e-02, -5.2245e-02,  2.2213e-02, -6.7578e-03]],

               [[ 5.7154e-02,  6.9033e-02, -2.7450e-02, -5.9039e-02,  3.0233e-02],
                [ 5.5904e-02,  5.2798e-02, -2.2586e-02,  2.8411e-02, -6.8010e-03],
                [ 5.1257e-02, -4.3710e-02,  8.7161e-03,  1.9411e-02, -3.5285e-03],
                [-8.0450e-02,  6.1012e-02, -7.7756e-02, -2.1472e-02,  4.7537e-02],
                [-4.7231e-02,  3.7300e-02,  2.7754e-02, -2.4025e-02,  1.0065e-02]]]])
      (bias): Normal:
       loc: tensor([0., -0., -0., -0., -0., 0., 0., -0., 0., -0., -0., 0., 0., 0., -0., 0.])
       scale: tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
              0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([ 0.0215, -0.0800, -0.0787, -0.0173, -0.0345,  0.0684,  0.0584, -0.0804,
               0.0098, -0.0490, -0.0535,  0.0145,  0.0056,  0.0082, -0.0256,  0.0140])
    )
    (observed): Observed()
  )
  (fc1): Linear(
    in_features=400, out_features=120, bias=True
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              ...,
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498]],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[-0.0489, -0.0457,  0.0358,  ...,  0.0488, -0.0310, -0.0318],
              [-0.0198,  0.0492, -0.0495,  ...,  0.0437, -0.0228, -0.0161],
              [ 0.0042,  0.0213, -0.0018,  ..., -0.0004,  0.0377,  0.0324],
              ...,
              [ 0.0020, -0.0197,  0.0377,  ..., -0.0133, -0.0496,  0.0166],
              [ 0.0128, -0.0165,  0.0298,  ..., -0.0352,  0.0281,  0.0219],
              [ 0.0448, -0.0166, -0.0012,  ..., -0.0042, -0.0289, -0.0339]],
             requires_grad=True)
       tensor: tensor([[-0.0611, -0.0210, -0.0065,  ...,  0.1009, -0.0345, -0.0676],
              [-0.0331,  0.0645,  0.0254,  ...,  0.0572, -0.0789,  0.0253],
              [-0.0835,  0.0054, -0.0547,  ...,  0.0489, -0.0221, -0.0500],
              ...,
              [-0.0506, -0.0202,  0.0091,  ..., -0.0377, -0.0570,  0.0214],
              [ 0.0031,  0.0121,  0.0345,  ..., -0.0832,  0.0463,  0.1483],
              [ 0.0786, -0.0678,  0.0186,  ..., -0.0506, -0.0724, -0.0756]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([ 0.0350,  0.0458, -0.0330,  0.0477,  0.0383, -0.0136, -0.0182,  0.0285,
               0.0197, -0.0431,  0.0120, -0.0445,  0.0171, -0.0019, -0.0141, -0.0021,
              -0.0429, -0.0159,  0.0028,  0.0272, -0.0290,  0.0047,  0.0452, -0.0022,
               0.0279,  0.0323, -0.0433,  0.0049,  0.0063, -0.0388,  0.0090, -0.0233,
               0.0251,  0.0375,  0.0274, -0.0337,  0.0122,  0.0217,  0.0230, -0.0405,
              -0.0476, -0.0063, -0.0021,  0.0267,  0.0014,  0.0228, -0.0130, -0.0471,
              -0.0170, -0.0349, -0.0472,  0.0116,  0.0002, -0.0426, -0.0129,  0.0492,
               0.0117, -0.0143, -0.0025,  0.0040, -0.0466, -0.0037,  0.0341, -0.0261,
               0.0327, -0.0433,  0.0025,  0.0201,  0.0211, -0.0235,  0.0472, -0.0291,
               0.0431, -0.0314, -0.0255,  0.0108, -0.0499, -0.0164, -0.0294, -0.0290,
              -0.0305, -0.0172,  0.0238,  0.0029, -0.0029,  0.0172,  0.0227,  0.0006,
               0.0120, -0.0068, -0.0043, -0.0289,  0.0060,  0.0199,  0.0122,  0.0423,
              -0.0015, -0.0034, -0.0201, -0.0374,  0.0159, -0.0258,  0.0075, -0.0097,
              -0.0048,  0.0477, -0.0470,  0.0045,  0.0128, -0.0441,  0.0218,  0.0365,
              -0.0206,  0.0348, -0.0249,  0.0256,  0.0222,  0.0019,  0.0289,  0.0248],
             requires_grad=True)
       tensor: tensor([-0.0037,  0.0363, -0.0302,  0.0283, -0.0190,  0.0309,  0.0084, -0.0720,
               0.0065, -0.0307,  0.0175, -0.0668,  0.0494,  0.0374,  0.0391,  0.0511,
              -0.0019, -0.0578, -0.0897,  0.0769, -0.0574,  0.0433,  0.0668,  0.0505,
              -0.0497,  0.0800, -0.1190,  0.0599,  0.0376, -0.0340,  0.0294, -0.0257,
               0.0688,  0.0540,  0.0491, -0.0268,  0.1529, -0.0314,  0.0392, -0.1326,
              -0.0932, -0.0667,  0.0950,  0.0431,  0.1215, -0.0887,  0.0444, -0.0129,
               0.0161, -0.0260,  0.0330,  0.0153, -0.0312, -0.1064,  0.0864,  0.0355,
               0.0347, -0.0217, -0.0082,  0.0588, -0.0612, -0.0228,  0.1342, -0.1155,
               0.0447,  0.0436, -0.0388, -0.0279,  0.0129, -0.0139,  0.0634, -0.0067,
               0.0634, -0.0554,  0.0294, -0.0108, -0.0582, -0.1314, -0.0447, -0.0445,
               0.0491, -0.0629,  0.1620,  0.0590, -0.0344, -0.1214, -0.0177, -0.0116,
               0.0295, -0.0703, -0.0219, -0.1209,  0.0475,  0.0840, -0.1216,  0.0770,
               0.0231, -0.0501,  0.0136, -0.0599,  0.0445, -0.0822, -0.0009,  0.1143,
               0.0308, -0.0435, -0.0758, -0.0836, -0.0294,  0.0312,  0.0669,  0.0201,
               0.0623,  0.0240, -0.0735,  0.0121,  0.0403,  0.0362,  0.0488,  0.0317],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[-0., -0., 0.,  ..., 0., -0., -0.],
              [-0., 0., -0.,  ..., 0., -0., -0.],
              [0., 0., -0.,  ..., -0., 0., 0.],
              ...,
              [0., -0., 0.,  ..., -0., -0., 0.],
              [0., -0., 0.,  ..., -0., 0., 0.],
              [0., -0., -0.,  ..., -0., -0., -0.]])
       scale: tensor([[0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              ...,
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[-0.0489, -0.0457,  0.0358,  ...,  0.0488, -0.0310, -0.0318],
              [-0.0198,  0.0492, -0.0495,  ...,  0.0437, -0.0228, -0.0161],
              [ 0.0042,  0.0213, -0.0018,  ..., -0.0004,  0.0377,  0.0324],
              ...,
              [ 0.0020, -0.0197,  0.0377,  ..., -0.0133, -0.0496,  0.0166],
              [ 0.0128, -0.0165,  0.0298,  ..., -0.0352,  0.0281,  0.0219],
              [ 0.0448, -0.0166, -0.0012,  ..., -0.0042, -0.0289, -0.0339]])
      (bias): Normal:
       loc: tensor([0., 0., -0., 0., 0., -0., -0., 0., 0., -0., 0., -0., 0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., -0.,
              0., 0., -0., 0., 0., -0., 0., -0., 0., 0., 0., -0., 0., 0., 0., -0., -0., -0., -0., 0., 0., 0., -0., -0.,
              -0., -0., -0., 0., 0., -0., -0., 0., 0., -0., -0., 0., -0., -0., 0., -0., 0., -0., 0., 0., 0., -0., 0., -0.,
              0., -0., -0., 0., -0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., 0., 0., -0., -0., -0., 0., 0., 0., 0.,
              -0., -0., -0., -0., 0., -0., 0., -0., -0., 0., -0., 0., 0., -0., 0., 0., -0., 0., -0., 0., 0., 0., 0., 0.])
       scale: tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([ 0.0350,  0.0458, -0.0330,  0.0477,  0.0383, -0.0136, -0.0182,  0.0285,
               0.0197, -0.0431,  0.0120, -0.0445,  0.0171, -0.0019, -0.0141, -0.0021,
              -0.0429, -0.0159,  0.0028,  0.0272, -0.0290,  0.0047,  0.0452, -0.0022,
               0.0279,  0.0323, -0.0433,  0.0049,  0.0063, -0.0388,  0.0090, -0.0233,
               0.0251,  0.0375,  0.0274, -0.0337,  0.0122,  0.0217,  0.0230, -0.0405,
              -0.0476, -0.0063, -0.0021,  0.0267,  0.0014,  0.0228, -0.0130, -0.0471,
              -0.0170, -0.0349, -0.0472,  0.0116,  0.0002, -0.0426, -0.0129,  0.0492,
               0.0117, -0.0143, -0.0025,  0.0040, -0.0466, -0.0037,  0.0341, -0.0261,
               0.0327, -0.0433,  0.0025,  0.0201,  0.0211, -0.0235,  0.0472, -0.0291,
               0.0431, -0.0314, -0.0255,  0.0108, -0.0499, -0.0164, -0.0294, -0.0290,
              -0.0305, -0.0172,  0.0238,  0.0029, -0.0029,  0.0172,  0.0227,  0.0006,
               0.0120, -0.0068, -0.0043, -0.0289,  0.0060,  0.0199,  0.0122,  0.0423,
              -0.0015, -0.0034, -0.0201, -0.0374,  0.0159, -0.0258,  0.0075, -0.0097,
              -0.0048,  0.0477, -0.0470,  0.0045,  0.0128, -0.0441,  0.0218,  0.0365,
              -0.0206,  0.0348, -0.0249,  0.0256,  0.0222,  0.0019,  0.0289,  0.0248])
    )
    (observed): Observed()
  )
  (fc2): Linear(
    in_features=120, out_features=2, bias=True
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
               0.0498, 0.0498, 0.0498]], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[ 0.0760, -0.0099, -0.0237, -0.0145, -0.0806,  0.0095, -0.0646,  0.0145,
               -0.0077, -0.0512, -0.0072, -0.0008,  0.0279, -0.0383, -0.0224,  0.0434,
               -0.0912, -0.0509,  0.0756, -0.0889,  0.0356, -0.0862, -0.0046,  0.0507,
               -0.0356, -0.0616, -0.0509, -0.0035,  0.0123,  0.0190, -0.0453,  0.0815,
               -0.0149,  0.0448, -0.0308, -0.0292, -0.0423,  0.0691,  0.0686, -0.0398,
               -0.0657,  0.0157, -0.0508,  0.0847, -0.0897,  0.0655,  0.0407,  0.0535,
               -0.0541, -0.0812,  0.0122, -0.0665, -0.0799,  0.0247, -0.0409,  0.0105,
               -0.0471, -0.0825, -0.0042,  0.0652, -0.0086, -0.0002, -0.0784, -0.0430,
                0.0104, -0.0905, -0.0506, -0.0340, -0.0407, -0.0163, -0.0497, -0.0516,
                0.0852,  0.0711,  0.0833,  0.0214,  0.0743,  0.0575,  0.0583, -0.0007,
                0.0814,  0.0736, -0.0248,  0.0284,  0.0873, -0.0174,  0.0206, -0.0740,
                0.0276, -0.0414,  0.0508, -0.0087, -0.0581,  0.0255,  0.0058,  0.0142,
               -0.0266,  0.0067, -0.0468,  0.0654,  0.0305, -0.0043, -0.0613,  0.0733,
                0.0400, -0.0446, -0.0243, -0.0434, -0.0616,  0.0371,  0.0253,  0.0681,
                0.0847, -0.0068,  0.0176, -0.0169, -0.0387,  0.0219, -0.0046, -0.0663],
              [-0.0324, -0.0686,  0.0105,  0.0805,  0.0090,  0.0304,  0.0097, -0.0191,
               -0.0591,  0.0876, -0.0748,  0.0383,  0.0680,  0.0441,  0.0479,  0.0484,
                0.0302, -0.0039,  0.0855, -0.0066,  0.0661, -0.0492,  0.0843, -0.0566,
                0.0517,  0.0880,  0.0308, -0.0874, -0.0144,  0.0143, -0.0663, -0.0484,
               -0.0368,  0.0709,  0.0610,  0.0495, -0.0031, -0.0503,  0.0562, -0.0030,
                0.0753,  0.0173,  0.0221, -0.0259,  0.0145,  0.0206, -0.0740,  0.0226,
               -0.0414,  0.0712, -0.0427, -0.0477, -0.0386, -0.0709, -0.0451, -0.0469,
                0.0882,  0.0519,  0.0840,  0.0558,  0.0087,  0.0270,  0.0901,  0.0010,
                0.0620,  0.0696,  0.0825,  0.0557, -0.0043, -0.0531, -0.0447,  0.0474,
                0.0724,  0.0483, -0.0868,  0.0503, -0.0060,  0.0524, -0.0355,  0.0002,
               -0.0195, -0.0888, -0.0211, -0.0551, -0.0292, -0.0041, -0.0416,  0.0861,
                0.0530,  0.0840, -0.0316, -0.0839, -0.0451, -0.0664,  0.0725,  0.0301,
                0.0456, -0.0145,  0.0455, -0.0850,  0.0010, -0.0722, -0.0800,  0.0512,
               -0.0753, -0.0348, -0.0249,  0.0067,  0.0063, -0.0506, -0.0310,  0.0592,
                0.0374, -0.0612, -0.0298, -0.0912, -0.0550, -0.0416, -0.0526, -0.0344]],
             requires_grad=True)
       tensor: tensor([[ 0.1279, -0.0866, -0.0686,  0.0555, -0.0638,  0.0050, -0.0615, -0.0530,
               -0.0351, -0.0392, -0.0150, -0.0328,  0.0167,  0.0180, -0.0615,  0.2103,
               -0.0772, -0.0698,  0.1113, -0.0363,  0.0195, -0.1669,  0.1130,  0.0848,
               -0.0937, -0.1075, -0.0521,  0.0558, -0.0513, -0.0210,  0.0153,  0.0361,
               -0.0155,  0.0751, -0.0595, -0.0718, -0.1218,  0.1084, -0.0045, -0.0181,
               -0.0626,  0.0720, -0.0465,  0.0618, -0.0506,  0.0861,  0.0443,  0.0697,
               -0.0507, -0.0497, -0.0155, -0.0383, -0.0927,  0.0032,  0.0311,  0.0651,
               -0.0475, -0.0090, -0.0013,  0.0427,  0.0172, -0.0191, -0.0302,  0.0091,
                0.0124, -0.1253, -0.0366, -0.0554, -0.0467, -0.0135, -0.1149, -0.0916,
                0.1359,  0.0563,  0.0745,  0.0458,  0.0777,  0.0731,  0.1130, -0.0496,
                0.0771,  0.1293, -0.0316,  0.0355,  0.0950,  0.0435,  0.0511, -0.0068,
               -0.0193, -0.0411, -0.0321, -0.0270, -0.0606,  0.0143,  0.1399,  0.1026,
               -0.0679,  0.0816, -0.0428,  0.1694, -0.0943, -0.0838, -0.0730,  0.0811,
               -0.0070, -0.1130, -0.0120, -0.0559, -0.0140,  0.0470,  0.0229,  0.0403,
                0.1721, -0.0774,  0.0857,  0.0897, -0.0025, -0.0037, -0.1085, -0.0464],
              [-0.0931, -0.0503,  0.0025,  0.0769,  0.0082, -0.0647,  0.0253,  0.0304,
               -0.1109,  0.1148, -0.0474,  0.0097, -0.0118,  0.0804,  0.1731,  0.0798,
                0.1452,  0.0859,  0.0719, -0.0387,  0.0086,  0.0264,  0.1407, -0.0077,
                0.0467,  0.0441,  0.0155, -0.1639,  0.1011,  0.0373, -0.0116, -0.0501,
               -0.0860,  0.0579, -0.0236,  0.0648,  0.0227, -0.1241,  0.1053,  0.0215,
                0.1890,  0.0845,  0.0857, -0.1136,  0.0768, -0.0305, -0.0581,  0.0310,
               -0.0614,  0.1498, -0.0297,  0.0213, -0.0608, -0.0596, -0.0909,  0.0021,
                0.1340,  0.0689,  0.0496,  0.0734, -0.0441, -0.0336,  0.1958, -0.0212,
                0.1678,  0.0891,  0.1948,  0.0640, -0.0615,  0.0174, -0.0600,  0.0613,
                0.0016,  0.1543, -0.1042,  0.0458,  0.0164,  0.1102, -0.0672, -0.0036,
               -0.0017, -0.0501,  0.0129, -0.0794, -0.0849, -0.0259, -0.1499,  0.0300,
                0.0810,  0.1579, -0.0678, -0.1053,  0.0361, -0.0487,  0.0602, -0.0358,
                0.0447, -0.0971,  0.0473, -0.1580, -0.0150, -0.0541, -0.1103,  0.0850,
               -0.0783, -0.0095, -0.0338, -0.0945, -0.0547, -0.0184, -0.1288,  0.0039,
               -0.0115, -0.0988, -0.0244, -0.1131,  0.0605, -0.0540, -0.0119, -0.0788]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-0.0576,  0.0882], requires_grad=True)
       tensor: tensor([-0.0505,  0.0428], grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[0., -0., -0., -0., -0., 0., -0., 0., -0., -0., -0., -0., 0., -0., -0., 0., -0., -0., 0., -0., 0., -0., -0., 0.,
               -0., -0., -0., -0., 0., 0., -0., 0., -0., 0., -0., -0., -0., 0., 0., -0., -0., 0., -0., 0., -0., 0., 0., 0.,
               -0., -0., 0., -0., -0., 0., -0., 0., -0., -0., -0., 0., -0., -0., -0., -0., 0., -0., -0., -0., -0., -0., -0., -0.,
               0., 0., 0., 0., 0., 0., 0., -0., 0., 0., -0., 0., 0., -0., 0., -0., 0., -0., 0., -0., -0., 0., 0., 0.,
               -0., 0., -0., 0., 0., -0., -0., 0., 0., -0., -0., -0., -0., 0., 0., 0., 0., -0., 0., -0., -0., 0., -0., -0.],
              [-0., -0., 0., 0., 0., 0., 0., -0., -0., 0., -0., 0., 0., 0., 0., 0., 0., -0., 0., -0., 0., -0., 0., -0.,
               0., 0., 0., -0., -0., 0., -0., -0., -0., 0., 0., 0., -0., -0., 0., -0., 0., 0., 0., -0., 0., 0., -0., 0.,
               -0., 0., -0., -0., -0., -0., -0., -0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -0., -0., -0., 0.,
               0., 0., -0., 0., -0., 0., -0., 0., -0., -0., -0., -0., -0., -0., -0., 0., 0., 0., -0., -0., -0., -0., 0., 0.,
               0., -0., 0., -0., 0., -0., -0., 0., -0., -0., -0., 0., 0., -0., -0., 0., 0., -0., -0., -0., -0., -0., -0., -0.]])
       scale: tensor([[0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913],
              [0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[ 0.0760, -0.0099, -0.0237, -0.0145, -0.0806,  0.0095, -0.0646,  0.0145,
               -0.0077, -0.0512, -0.0072, -0.0008,  0.0279, -0.0383, -0.0224,  0.0434,
               -0.0912, -0.0509,  0.0756, -0.0889,  0.0356, -0.0862, -0.0046,  0.0507,
               -0.0356, -0.0616, -0.0509, -0.0035,  0.0123,  0.0190, -0.0453,  0.0815,
               -0.0149,  0.0448, -0.0308, -0.0292, -0.0423,  0.0691,  0.0686, -0.0398,
               -0.0657,  0.0157, -0.0508,  0.0847, -0.0897,  0.0655,  0.0407,  0.0535,
               -0.0541, -0.0812,  0.0122, -0.0665, -0.0799,  0.0247, -0.0409,  0.0105,
               -0.0471, -0.0825, -0.0042,  0.0652, -0.0086, -0.0002, -0.0784, -0.0430,
                0.0104, -0.0905, -0.0506, -0.0340, -0.0407, -0.0163, -0.0497, -0.0516,
                0.0852,  0.0711,  0.0833,  0.0214,  0.0743,  0.0575,  0.0583, -0.0007,
                0.0814,  0.0736, -0.0248,  0.0284,  0.0873, -0.0174,  0.0206, -0.0740,
                0.0276, -0.0414,  0.0508, -0.0087, -0.0581,  0.0255,  0.0058,  0.0142,
               -0.0266,  0.0067, -0.0468,  0.0654,  0.0305, -0.0043, -0.0613,  0.0733,
                0.0400, -0.0446, -0.0243, -0.0434, -0.0616,  0.0371,  0.0253,  0.0681,
                0.0847, -0.0068,  0.0176, -0.0169, -0.0387,  0.0219, -0.0046, -0.0663],
              [-0.0324, -0.0686,  0.0105,  0.0805,  0.0090,  0.0304,  0.0097, -0.0191,
               -0.0591,  0.0876, -0.0748,  0.0383,  0.0680,  0.0441,  0.0479,  0.0484,
                0.0302, -0.0039,  0.0855, -0.0066,  0.0661, -0.0492,  0.0843, -0.0566,
                0.0517,  0.0880,  0.0308, -0.0874, -0.0144,  0.0143, -0.0663, -0.0484,
               -0.0368,  0.0709,  0.0610,  0.0495, -0.0031, -0.0503,  0.0562, -0.0030,
                0.0753,  0.0173,  0.0221, -0.0259,  0.0145,  0.0206, -0.0740,  0.0226,
               -0.0414,  0.0712, -0.0427, -0.0477, -0.0386, -0.0709, -0.0451, -0.0469,
                0.0882,  0.0519,  0.0840,  0.0558,  0.0087,  0.0270,  0.0901,  0.0010,
                0.0620,  0.0696,  0.0825,  0.0557, -0.0043, -0.0531, -0.0447,  0.0474,
                0.0724,  0.0483, -0.0868,  0.0503, -0.0060,  0.0524, -0.0355,  0.0002,
               -0.0195, -0.0888, -0.0211, -0.0551, -0.0292, -0.0041, -0.0416,  0.0861,
                0.0530,  0.0840, -0.0316, -0.0839, -0.0451, -0.0664,  0.0725,  0.0301,
                0.0456, -0.0145,  0.0455, -0.0850,  0.0010, -0.0722, -0.0800,  0.0512,
               -0.0753, -0.0348, -0.0249,  0.0067,  0.0063, -0.0506, -0.0310,  0.0592,
                0.0374, -0.0612, -0.0298, -0.0912, -0.0550, -0.0416, -0.0526, -0.0344]])
      (bias): Normal:
       loc: tensor([-0., 0.])
       scale: tensor([0.7071, 0.7071])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-0.0576,  0.0882])
    )
    (observed): Observed()
  )
)
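In the printout above, each posterior Normal stores loc as a trainable Parameter, a scale produced through an exponential (grad_fn=<ExpBackward0>), and the current sample tensor produced by an addition (grad_fn=<AddBackward0>). That pattern is consistent with the reparameterization trick: draw standard normal noise eps, then compute loc + scale * eps, so the sample stays differentiable with respect to loc and the unconstrained log-scale. The plain-Python sketch below illustrates that idea. It assumes the scale is parameterized as exp of an unconstrained value; sample_normal_posterior is a hypothetical helper, not borch's implementation.

```python
import math
import random


def sample_normal_posterior(loc, log_scale, rng):
    """Reparameterized draw: loc + exp(log_scale) * eps, with eps ~ N(0, 1).

    Hypothetical illustration. The noise eps does not depend on loc or
    log_scale, so the sample is a smooth function of both parameters.
    """
    scale = [math.exp(s) for s in log_scale]  # exp keeps every scale positive
    eps = [rng.gauss(0.0, 1.0) for _ in loc]  # parameter-free noise
    return [m + s * e for m, s, e in zip(loc, scale, eps)]


# loc and scale values read off the fc2 bias posterior printed above
loc = [-0.0576, 0.0882]
log_scale = [math.log(0.0498)] * 2
rng = random.Random(0)
print(sample_normal_posterior(loc, log_scale, rng))
```

Storing log_scale rather than scale lets an optimizer such as Adam update it freely without ever producing a negative standard deviation.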

Fit the model

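Finally, we can set up the training loop. For each batch, the loop conditions the model on the batch labels, draws new samples for the network's random variables, runs the forward pass, and takes one optimizer step on the variational loss.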

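optimizer = torch.optim.Adam(net.parameters())
for epoch in range(1):  # a single epoch for illustration; train longer in practice
    for data, target in loader:
        # condition the likelihood on the observed labels of this batch
        net.observe(classification=target)
        # draw new samples for all random variables in the network
        borch.sample(net)
        net(data)
        # variational loss; the KL term is scaled by 1 / number of batches
        loss = infer.vi_loss(**borch.pq_to_infer(net), kl_scaling=1 / len(loader))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()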
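Why kl_scaling=1 / len(loader)? We assume here, without having checked borch's source, that infer.vi_loss combines a data-fit term for the current batch with a KL term multiplied by kl_scaling. Under that assumption each minibatch carries 1/len(loader) of the KL, so the losses summed over one epoch count the KL exactly once, matching the full-data negative ELBO. The stdlib sketch below checks this for a single Normal weight, using the closed-form KL between two univariate Normals. The NLL values are made up, and kl_normal and minibatch_loss are hypothetical helpers, not borch functions. Whether borch evaluates the KL analytically or by Monte Carlo does not change the scaling argument.

```python
import math


def kl_normal(q_loc, q_scale, p_loc, p_scale):
    """Closed-form KL(q || p) between two univariate Normal distributions."""
    return (
        math.log(p_scale / q_scale)
        + (q_scale**2 + (q_loc - p_loc) ** 2) / (2 * p_scale**2)
        - 0.5
    )


def minibatch_loss(batch_nll, kl, num_batches):
    """Hypothetical per-batch loss: this batch's NLL plus a 1/num_batches share of the KL."""
    return batch_nll + kl / num_batches


# One fc1 bias entry: posterior N(0.0350, 0.0498) against prior N(0, 0.0913),
# values read off the printout above
kl = kl_normal(0.0350, 0.0498, 0.0, 0.0913)

# Made-up negative log-likelihoods for four minibatches of one epoch
batch_nlls = [13.2, 12.7, 14.1, 13.5]
epoch_loss = sum(minibatch_loss(nll, kl, len(batch_nlls)) for nll in batch_nlls)
full_data_loss = sum(batch_nlls) + kl  # full-data negative ELBO
print(round(kl, 4), math.isclose(epoch_loss, full_data_loss))
```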

Now we can check the accuracy. Before evaluating, stop conditioning on the targets by calling net.observe(None).

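net.observe(None)  # stop conditioning on the labels
tot_acc = 0
with torch.no_grad():
    for i, (data, target) in enumerate(loader):
        borch.sample(net)  # draw one sample from the posterior
        out = net(data)  # sampled class label for each image
        # percentage of correct predictions in this batch
        acc = float((target == out).sum().float() / target.shape[0]) * 100
        tot_acc += acc
    tot_acc /= i + 1  # mean of the per-batch accuracies
print(tot_acc)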

Out:

69.9999988079071
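The evaluation loop averages per-batch accuracies (tot_acc /= i + 1). With the single batch of 20 used here this equals the overall accuracy, but if the last batch were smaller, for example with a real dataset whose size is not a multiple of batch_size, that ragged batch would be over-weighted. The plain-Python sketch below contrasts the two averages on hypothetical (correct, batch_size) counts; dividing the total number of correct predictions by the total number of samples weights every sample equally.

```python
def mean_of_batch_accuracies(batches):
    """Average of per-batch accuracies, as computed by tot_acc /= i + 1."""
    return sum(100 * correct / size for correct, size in batches) / len(batches)


def overall_accuracy(batches):
    """Correct predictions over all samples, so each sample counts equally."""
    total_correct = sum(correct for correct, _ in batches)
    total_samples = sum(size for _, size in batches)
    return 100 * total_correct / total_samples


# Hypothetical counts: one full batch of 20, then a ragged batch of 3
batches = [(14, 20), (0, 3)]
print(mean_of_batch_accuracies(batches), overall_accuracy(batches))
```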

The accuracy is essentially at chance level. This is expected: the images are white noise and the labels are unrelated to them, so there is nothing for the model to learn.
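To see why 70% is consistent with chance, treat the 20 predictions as independent fair coin flips. This is an assumption: the labels are balanced 10/10, and we suppose the sampled predictions carry no information about them. The probability of getting at least 14 of 20 right by luck is then a binomial tail, computed below with the standard library (p_at_least is a helper written for this illustration).

```python
from math import comb


def p_at_least(k, n, p=0.5):
    """Probability that a Binomial(n, p) variable is at least k."""
    return sum(comb(n, j) * p**j * (1 - p) ** (n - j) for j in range(k, n + 1))


# 70% of 20 samples is 14 correct predictions
print(p_at_least(14, 20))
```

The result is 60460 / 2**20, roughly 5.8%, so a single 70% score measured on the same 20 samples used for training is weak evidence that the model learned anything.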

If you have trouble reaching higher accuracy on real data, consider training for more epochs, setting up an augmentation pipeline (see the data loading tutorial), or changing the posterior. The posterior can be changed using

net.apply(borch.set_posteriors(borch.posterior.Automatic))

Out:

Net(
  (posterior): Automatic()
  (prior): Module(
    (classification): Categorical:
     logits: tensor([[-0.1017, -0.1535],
            [ 0.1836, -0.1202],
            [ 0.2641, -0.0718],
            [ 0.1435,  0.0717],
            [ 0.1556,  0.0155],
            [ 0.0704,  0.0826],
            [-0.0427, -0.1238],
            [-0.1261, -0.1772],
            [ 0.0826,  0.0842],
            [ 0.0605,  0.0630],
            [ 0.1126, -0.0491],
            [ 0.1652, -0.0307],
            [ 0.2373, -0.1125],
            [ 0.2617,  0.0919],
            [ 0.0945, -0.0213],
            [-0.1082, -0.0242],
            [ 0.2161,  0.0436],
            [ 0.2820, -0.0276],
            [ 0.1257, -0.0373],
            [ 0.2219,  0.0406]])
     posterior: Automatic()
     prior: Module()
     observed: Observed()
     tensor: tensor([])
  )
  (observed): Observed()
  (conv1): Conv2d(
    1, 6, kernel_size=(5, 5), stride=(1, 1)
    (posterior): Automatic(
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-0., -0., 0., -0., 0., 0.], requires_grad=True)
       tensor: tensor([-0.5569, -0.0093, -0.1423, -0.3891, -0.2014, -0.4750],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[[[-0., 0., -0., 0., 0.],
                [-0., -0., 0., 0., -0.],
                [0., 0., 0., -0., -0.],
                [0., -0., -0., 0., -0.],
                [-0., -0., 0., -0., -0.]]],


              [[[-0., 0., -0., -0., 0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., 0., -0., 0.],
                [-0., 0., 0., -0., -0.],
                [0., 0., 0., -0., -0.]]],


              [[[-0., -0., 0., 0., -0.],
                [-0., -0., 0., 0., -0.],
                [-0., -0., -0., 0., 0.],
                [-0., -0., 0., -0., 0.],
                [0., 0., -0., 0., -0.]]],


              [[[0., -0., -0., 0., 0.],
                [-0., 0., -0., -0., 0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., 0., 0., -0.],
                [-0., -0., 0., 0., 0.]]],


              [[[-0., 0., 0., -0., -0.],
                [0., 0., -0., -0., -0.],
                [0., -0., -0., 0., 0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., -0., -0., -0.]]],


              [[[-0., -0., -0., -0., -0.],
                [-0., 0., 0., 0., -0.],
                [-0., -0., 0., 0., 0.],
                [-0., 0., -0., -0., 0.],
                [0., -0., 0., -0., 0.]]]])
       scale: tensor([[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[[[-2.3901e-02,  1.7009e-01, -4.0790e-02,  1.9047e-01,  1.1854e-01],
                [-1.0204e-01, -1.4715e-01,  9.0185e-03,  8.0316e-02, -1.4466e-01],
                [ 1.7187e-01,  8.4639e-02,  1.6665e-01, -1.5608e-01, -9.5238e-02],
                [ 1.9682e-01, -9.3068e-02, -1.9482e-02,  1.6220e-01, -5.5923e-02],
                [-1.5368e-01, -1.2494e-01,  3.1866e-02, -1.7428e-01, -6.5961e-02]]],


              [[[-5.4801e-02,  5.0452e-02, -1.5372e-01, -1.1482e-01,  1.5138e-01],
                [ 1.8247e-03,  4.6463e-02, -1.7931e-01, -8.4841e-02, -6.3566e-02],
                [-1.9791e-02, -1.0920e-01,  1.2796e-01, -6.0495e-02,  1.2142e-01],
                [-1.0610e-01,  2.8335e-02,  2.0862e-02, -1.2132e-01, -5.6049e-03],
                [ 2.7007e-02,  1.5627e-01,  7.0422e-02, -2.1336e-03, -5.9226e-02]]],


              [[[-1.6497e-01, -9.5347e-02,  9.7235e-02,  1.7565e-01, -1.4118e-01],
                [-1.1203e-02, -5.6668e-02,  9.0249e-02,  1.9961e-01, -2.0049e-02],
                [-4.5493e-02, -2.0235e-02, -1.9463e-01,  1.5131e-01,  1.6076e-01],
                [-1.9071e-01, -1.6333e-01,  1.0380e-01, -7.2042e-02,  1.0249e-01],
                [ 1.7660e-01,  1.8708e-02, -1.0379e-01,  8.4113e-02, -1.3492e-01]]],


              [[[ 1.3284e-01, -1.7679e-02, -9.9538e-02,  1.5133e-01,  1.0864e-01],
                [-1.9522e-01,  1.0066e-01, -1.0742e-01, -1.1599e-01,  1.6930e-01],
                [-1.0281e-01, -1.4473e-01,  1.6300e-01, -7.4540e-02, -6.5797e-02],
                [ 1.5015e-01,  7.7701e-03,  7.3404e-02,  9.5653e-02, -1.2661e-01],
                [-3.2228e-02, -7.9872e-02,  1.9932e-01,  6.0159e-02,  1.3894e-01]]],


              [[[-1.7235e-01,  5.8651e-06,  8.9371e-02, -1.5355e-01, -1.2702e-01],
                [ 5.9223e-02,  1.1539e-01, -5.5243e-03, -6.4484e-02, -1.3380e-01],
                [ 9.6366e-03, -1.0979e-01, -1.1570e-01,  7.1673e-03,  8.9918e-02],
                [ 2.4720e-02,  5.8142e-02, -1.0872e-01, -1.4363e-01, -1.1776e-01],
                [-8.8460e-02, -1.7740e-01, -7.1380e-02, -1.1692e-01, -1.7076e-02]]],


              [[[-1.8614e-01, -1.2378e-01, -1.3271e-01, -1.5860e-02, -9.4571e-02],
                [-7.8788e-03,  4.7546e-02,  1.4185e-01,  8.6187e-02, -1.0654e-01],
                [-6.4892e-02, -7.4628e-02,  5.9973e-02,  5.0245e-02,  1.1456e-01],
                [-4.4647e-02,  9.8620e-02, -1.2445e-01, -6.6966e-02,  3.5321e-02],
                [ 1.9966e-01, -3.8652e-03,  3.5615e-03, -1.3894e-02,  3.5925e-02]]]])
      (bias): Normal:
       loc: tensor([-0., -0., 0., -0., 0., 0.])
       scale: tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-0.1531, -0.0004,  0.1924, -0.0493,  0.0953,  0.0265])
    )
    (observed): Observed()
  )
  (conv2): Conv2d(
    6, 16, kernel_size=(5, 5), stride=(1, 1)
    (posterior): Automatic(
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
              0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([0., -0., -0., -0., -0., 0., 0., -0., 0., -0., -0., 0., 0., 0., -0., 0.],
             requires_grad=True)
       tensor: tensor([ 0.3142,  0.4254, -0.2881, -0.2760,  0.0442,  0.1406, -0.1709, -0.1576,
              -0.0890,  0.1179,  0.0289,  0.0774, -0.6501, -0.1156, -0.0588,  0.1319],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[[[-0., 0., 0., 0., 0.],
                [-0., 0., -0., 0., 0.],
                [0., -0., -0., -0., -0.],
                [-0., 0., 0., 0., 0.],
                [0., -0., 0., -0., -0.]],

               [[-0., -0., 0., 0., -0.],
                [0., 0., -0., 0., -0.],
                [0., -0., 0., -0., 0.],
                [-0., -0., 0., 0., 0.],
                [0., 0., 0., 0., -0.]],

               [[-0., -0., 0., 0., 0.],
                [0., -0., 0., -0., -0.],
                [0., 0., 0., -0., -0.],
                [-0., -0., -0., -0., -0.],
                [0., 0., -0., 0., 0.]],

               [[-0., -0., 0., -0., 0.],
                [-0., -0., -0., -0., 0.],
                [0., 0., 0., 0., -0.],
                [-0., -0., -0., 0., 0.],
                [-0., 0., -0., 0., -0.]],

               [[0., 0., -0., -0., -0.],
                [0., 0., 0., -0., 0.],
                [-0., 0., 0., -0., 0.],
                [-0., -0., -0., 0., 0.],
                [0., 0., -0., -0., -0.]],

               [[-0., -0., -0., 0., -0.],
                [-0., 0., -0., -0., -0.],
                [0., -0., 0., -0., 0.],
                [0., 0., -0., -0., 0.],
                [-0., -0., 0., -0., -0.]]],


              [[[-0., 0., -0., 0., -0.],
                [0., -0., -0., -0., 0.],
                [-0., 0., 0., 0., -0.],
                [-0., 0., 0., 0., -0.],
                [-0., 0., 0., 0., -0.]],

               [[-0., 0., 0., -0., 0.],
                [0., -0., 0., 0., -0.],
                [0., 0., -0., -0., 0.],
                [0., -0., -0., -0., 0.],
                [0., -0., -0., -0., 0.]],

               [[0., -0., 0., 0., 0.],
                [0., -0., 0., 0., 0.],
                [0., -0., -0., 0., -0.],
                [-0., -0., 0., 0., 0.],
                [0., 0., 0., 0., 0.]],

               [[0., 0., 0., -0., 0.],
                [-0., 0., -0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [-0., 0., -0., 0., 0.]],

               [[0., -0., 0., 0., -0.],
                [-0., 0., -0., 0., -0.],
                [-0., 0., -0., 0., 0.],
                [0., -0., 0., 0., 0.],
                [0., 0., -0., 0., 0.]],

               [[-0., -0., -0., -0., -0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., -0., -0., 0.],
                [-0., -0., 0., -0., 0.]]],


              [[[-0., -0., 0., 0., -0.],
                [0., -0., -0., 0., -0.],
                [0., 0., -0., -0., 0.],
                [0., 0., -0., 0., 0.],
                [-0., -0., -0., 0., 0.]],

               [[-0., -0., -0., 0., -0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., -0., -0., -0.],
                [-0., 0., 0., -0., -0.],
                [0., -0., -0., -0., -0.]],

               [[-0., 0., -0., 0., -0.],
                [-0., 0., -0., -0., 0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., -0., 0., -0.],
                [-0., -0., 0., 0., 0.]],

               [[0., -0., 0., 0., 0.],
                [0., -0., 0., 0., 0.],
                [-0., 0., -0., 0., 0.],
                [-0., -0., -0., 0., 0.],
                [-0., 0., -0., 0., 0.]],

               [[-0., -0., -0., 0., -0.],
                [-0., -0., -0., 0., -0.],
                [0., 0., 0., -0., 0.],
                [0., 0., -0., 0., 0.],
                [-0., -0., 0., -0., -0.]],

               [[0., -0., 0., -0., 0.],
                [0., -0., 0., 0., 0.],
                [-0., 0., -0., -0., -0.],
                [0., 0., 0., 0., 0.],
                [0., -0., 0., -0., 0.]]],


              ...,


              [[[0., 0., -0., 0., -0.],
                [-0., 0., -0., 0., -0.],
                [0., -0., -0., 0., 0.],
                [0., -0., -0., 0., 0.],
                [0., 0., 0., -0., -0.]],

               [[0., 0., -0., -0., -0.],
                [-0., -0., -0., -0., -0.],
                [-0., 0., -0., -0., 0.],
                [0., 0., 0., 0., 0.],
                [0., -0., 0., 0., 0.]],

               [[-0., -0., -0., -0., -0.],
                [-0., -0., -0., 0., 0.],
                [-0., 0., -0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [0., -0., -0., -0., -0.]],

               [[0., 0., 0., 0., -0.],
                [0., 0., -0., -0., 0.],
                [0., 0., 0., 0., 0.],
                [0., -0., -0., -0., -0.],
                [-0., 0., -0., -0., 0.]],

               [[-0., -0., -0., 0., 0.],
                [-0., 0., 0., -0., 0.],
                [-0., -0., -0., 0., -0.],
                [-0., 0., -0., -0., -0.],
                [0., 0., -0., 0., -0.]],

               [[-0., 0., -0., 0., -0.],
                [0., 0., -0., 0., 0.],
                [-0., -0., -0., 0., -0.],
                [0., 0., 0., 0., 0.],
                [0., 0., -0., 0., -0.]]],


              [[[-0., -0., 0., 0., 0.],
                [0., -0., -0., -0., 0.],
                [-0., -0., -0., 0., 0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., 0., 0., -0.]],

               [[-0., 0., 0., 0., 0.],
                [-0., 0., 0., -0., 0.],
                [-0., -0., -0., 0., -0.],
                [0., -0., -0., -0., 0.],
                [0., -0., -0., 0., -0.]],

               [[0., -0., -0., -0., -0.],
                [-0., -0., -0., -0., -0.],
                [-0., 0., 0., -0., 0.],
                [-0., 0., 0., -0., 0.],
                [0., 0., 0., -0., -0.]],

               [[-0., -0., -0., -0., -0.],
                [0., 0., 0., -0., -0.],
                [-0., 0., 0., 0., -0.],
                [-0., 0., -0., 0., 0.],
                [-0., 0., -0., 0., -0.]],

               [[0., 0., -0., -0., -0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., -0., -0., 0.],
                [0., 0., -0., 0., 0.],
                [-0., 0., 0., 0., -0.]],

               [[0., -0., 0., -0., -0.],
                [0., 0., -0., -0., -0.],
                [0., 0., -0., 0., 0.],
                [-0., 0., -0., 0., 0.],
                [-0., -0., -0., -0., -0.]]],


              [[[-0., 0., 0., 0., 0.],
                [-0., 0., -0., -0., -0.],
                [0., 0., -0., -0., -0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., 0., 0., -0.]],

               [[-0., -0., 0., -0., 0.],
                [-0., -0., -0., 0., 0.],
                [0., 0., 0., 0., -0.],
                [0., -0., 0., -0., 0.],
                [0., -0., 0., -0., -0.]],

               [[0., -0., 0., -0., -0.],
                [0., 0., 0., -0., 0.],
                [0., 0., 0., -0., -0.],
                [-0., 0., -0., 0., 0.],
                [-0., -0., 0., -0., 0.]],

               [[-0., -0., 0., -0., -0.],
                [0., -0., 0., -0., -0.],
                [0., -0., -0., -0., -0.],
                [0., -0., -0., 0., -0.],
                [-0., 0., -0., 0., 0.]],

               [[-0., 0., -0., -0., 0.],
                [-0., 0., -0., 0., -0.],
                [-0., 0., -0., -0., -0.],
                [0., -0., -0., 0., -0.],
                [-0., -0., -0., 0., -0.]],

               [[0., 0., -0., -0., 0.],
                [0., 0., -0., 0., -0.],
                [0., -0., 0., 0., -0.],
                [-0., 0., -0., -0., 0.],
                [-0., 0., 0., -0., 0.]]]])
       scale: tensor([[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              ...,


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[[[-2.9409e-02,  6.0003e-02,  5.3245e-02,  2.8255e-02,  5.1865e-02],
                [-1.3120e-03,  6.9924e-02, -2.1958e-02,  4.5196e-02,  5.5987e-02],
                [ 7.2948e-02, -1.2874e-02, -1.4010e-02, -7.2656e-02, -3.8272e-02],
                [-6.1391e-02,  2.6227e-02,  2.2223e-02,  5.1530e-02,  7.8572e-02],
                [ 1.1986e-02, -3.5154e-02,  2.2590e-02, -2.8268e-02, -4.1092e-02]],

               [[-7.7353e-02, -3.3001e-03,  2.2446e-02,  3.6329e-02, -8.4312e-03],
                [ 5.9898e-02,  3.3921e-02, -4.0756e-02,  5.6990e-02, -7.9915e-02],
                [ 1.1058e-02, -7.5468e-02,  3.1362e-02, -1.2730e-02,  7.0498e-02],
                [-6.7726e-03, -3.9453e-02,  6.8676e-02,  2.8066e-02,  4.1992e-02],
                [ 8.0445e-02,  6.9594e-02,  2.9287e-02,  4.4932e-02, -7.2432e-02]],

               [[-6.8457e-02, -7.5792e-02,  3.8843e-02,  3.2058e-02,  2.1625e-02],
                [ 7.5601e-02, -5.7455e-02,  1.2939e-02, -3.5165e-02, -5.6556e-02],
                [ 8.3346e-03,  2.2498e-02,  6.8399e-03, -2.3131e-02, -2.1235e-03],
                [-6.8738e-02, -6.5464e-02, -4.5512e-02, -2.1878e-02, -6.3732e-02],
                [ 4.4843e-02,  1.4243e-02, -7.8475e-02,  2.0007e-02,  6.8106e-02]],

               [[-7.1642e-03, -4.0418e-02,  1.6455e-03, -7.4808e-03,  3.0932e-02],
                [-5.6909e-02, -6.4950e-02, -4.3916e-02, -6.4650e-02,  7.4918e-02],
                [ 2.7178e-03,  7.6559e-02,  2.3546e-02,  1.0888e-02, -3.2943e-02],
                [-3.1639e-02, -4.4183e-02, -1.1435e-02,  1.8024e-02,  7.9902e-02],
                [-1.6372e-02,  6.6337e-03, -5.5328e-02,  1.6804e-02, -1.1563e-02]],

               [[ 3.4681e-02,  3.3633e-02, -5.8334e-02, -4.5784e-02, -2.0550e-02],
                [ 8.0547e-02,  4.9685e-02,  7.5298e-02, -2.9211e-02,  2.2224e-02],
                [-2.4409e-02,  7.0003e-02,  2.1558e-04, -4.4977e-02,  2.3133e-02],
                [-7.4964e-02, -6.4867e-02, -4.5671e-02,  3.1878e-02,  3.5104e-02],
                [ 2.2541e-02,  3.2949e-02, -3.5216e-02, -4.0980e-02, -1.9989e-02]],

               [[-7.0267e-02, -2.8274e-03, -3.7755e-02,  3.8392e-02, -1.3160e-05],
                [-2.4893e-02,  2.6938e-02, -6.6667e-02, -6.4968e-02, -7.5396e-02],
                [ 4.5711e-02, -3.3377e-02,  2.4278e-02, -4.1101e-03,  1.8952e-02],
                [ 2.0236e-02,  6.4522e-03, -5.7711e-02, -4.5909e-02,  7.4404e-02],
                [-3.7771e-02, -3.1536e-03,  6.4748e-02, -4.4604e-02, -2.2752e-02]]],


              [[[-2.9177e-02,  5.0791e-02, -5.5935e-03,  4.7169e-02, -1.1366e-02],
                [ 2.2291e-02, -2.1153e-02, -1.4060e-02, -5.3690e-02,  6.6127e-02],
                [-1.1869e-03,  7.9363e-03,  6.0124e-02,  8.7514e-04, -2.6833e-02],
                [-4.4174e-02,  1.9975e-02,  2.7939e-02,  2.7791e-02, -1.7919e-02],
                [-2.9191e-02,  2.5500e-02,  3.6869e-03,  3.1063e-02, -3.7672e-02]],

               [[-6.7084e-02,  5.3586e-02,  1.8485e-02, -7.8099e-03,  4.6288e-02],
                [ 4.9055e-02, -1.4634e-02,  3.2799e-02,  3.9726e-02, -7.0032e-02],
                [ 7.0897e-02,  3.4370e-02, -1.4814e-02, -3.9030e-02,  3.0867e-02],
                [ 3.5541e-02, -7.2574e-02, -1.5650e-03, -8.1162e-02,  2.6245e-02],
                [ 5.5721e-02, -2.2033e-02, -7.2623e-02, -5.1459e-02,  3.3599e-02]],

               [[ 1.1310e-02, -2.9816e-02,  6.3727e-02,  3.9850e-02,  1.3761e-02],
                [ 3.0453e-02, -4.8504e-02,  5.3189e-02,  1.9425e-02,  4.7484e-02],
                [ 6.1376e-02, -7.8290e-02, -6.8859e-02,  1.8497e-02, -1.1496e-02],
                [-7.8178e-02, -4.5904e-02,  7.3181e-02,  2.9441e-02,  4.6967e-02],
                [ 7.6978e-02,  7.2934e-02,  5.6798e-02,  5.8828e-02,  4.4637e-02]],

               [[ 6.0281e-02,  7.7289e-02,  7.9016e-02, -4.1437e-02,  3.1101e-02],
                [-5.0620e-02,  3.3108e-02, -5.8687e-02,  2.7694e-02,  5.4294e-02],
                [ 2.1156e-02,  1.7004e-03,  2.4742e-02,  6.9593e-02,  5.7699e-02],
                [ 6.8876e-02,  3.2239e-02,  3.3322e-02,  2.9973e-02,  7.4267e-02],
                [-2.4736e-03,  3.8454e-02, -2.5898e-02,  2.0443e-02,  6.0816e-02]],

               [[ 1.6487e-02, -6.8994e-03,  5.2835e-02,  5.7784e-02, -1.7036e-02],
                [-7.1883e-02,  3.1576e-04, -5.5839e-02,  1.4949e-02, -1.9834e-03],
                [-2.7395e-02,  3.9861e-04, -1.0588e-02,  9.9140e-03,  5.1499e-02],
                [ 7.4060e-02, -1.0655e-02,  1.1668e-02,  4.9183e-02,  5.2846e-02],
                [ 2.8634e-02,  4.2678e-02, -1.4281e-02,  1.3904e-03,  7.6289e-02]],

               [[-5.1256e-02, -2.2514e-02, -7.2964e-02, -4.4120e-02, -5.8914e-02],
                [-4.1579e-02,  2.8281e-02,  3.9429e-02,  7.5058e-03,  6.5170e-03],
                [-3.4494e-02,  7.5710e-02,  4.1078e-02,  4.4451e-02,  4.2661e-02],
                [-5.4398e-02,  5.1592e-02, -2.6367e-02, -3.2980e-02,  5.3860e-02],
                [-4.2436e-02, -8.4286e-03,  7.5331e-02, -6.6725e-02,  4.9887e-02]]],


              [[[-1.5211e-02, -6.2506e-03,  3.0621e-03,  3.0725e-02, -7.0877e-03],
                [ 1.1974e-02, -5.2611e-02, -2.7415e-02,  4.3479e-02, -4.2108e-02],
                [ 3.3816e-02,  6.1523e-02, -9.9011e-03, -3.7770e-02,  6.5915e-04],
                [ 5.3678e-03,  5.9921e-02, -3.4530e-02,  5.1942e-02,  5.3762e-02],
                [-4.7293e-02, -6.2274e-02, -7.5059e-02,  8.1645e-02,  2.1149e-02]],

               [[-1.4459e-02, -2.7155e-02, -2.5730e-02,  7.6751e-02, -1.6932e-02],
                [-5.3342e-02, -2.6885e-02,  4.3476e-02, -7.9174e-02, -3.5761e-02],
                [ 4.2970e-02,  2.5516e-02, -6.6640e-02, -2.9457e-03, -8.2757e-03],
                [-2.5080e-02,  4.1672e-02,  4.2424e-02, -4.8704e-02, -6.0434e-02],
                [ 1.2884e-02, -7.9950e-02, -7.0913e-02, -8.0863e-02, -5.4536e-02]],

               [[-5.4303e-02,  6.7885e-02, -5.3922e-02,  6.5582e-02, -5.2617e-03],
                [-8.4440e-03,  8.0911e-02, -3.8667e-02, -5.6241e-04,  7.0876e-02],
                [ 5.4673e-02,  1.3465e-02, -2.7178e-02, -3.6691e-02, -2.6519e-02],
                [-2.8238e-02, -5.0765e-02, -3.4076e-02,  3.1219e-02, -1.8919e-02],
                [-4.6076e-02, -7.6516e-02,  3.1247e-02,  4.1743e-02,  7.5575e-02]],

               [[ 3.8787e-03, -8.1731e-03,  4.7381e-03,  5.8261e-02,  4.6416e-02],
                [ 7.7171e-02, -6.5924e-02,  1.5769e-02,  2.6777e-02,  7.7365e-02],
                [-7.3126e-02,  4.7624e-02, -7.0620e-02,  7.3309e-02,  7.7585e-03],
                [-7.9208e-02, -7.3783e-03, -5.3142e-02,  3.4386e-02,  5.6230e-03],
                [-4.4492e-02,  7.5403e-02, -2.8887e-02,  2.6937e-02,  7.0698e-03]],

               [[-1.7104e-02, -6.6983e-02, -5.7655e-02,  5.1198e-02, -8.0137e-02],
                [-7.2406e-02, -3.7856e-02, -7.2086e-02,  7.2641e-02, -7.0749e-02],
                [ 6.6330e-02,  2.8797e-02,  6.2197e-02, -4.6068e-03,  1.7392e-03],
                [ 4.2384e-02,  5.9572e-02, -4.5953e-02,  6.6345e-03,  7.3979e-02],
                [-4.8313e-02, -1.3063e-02,  1.7648e-02, -5.0903e-02, -7.1852e-02]],

               [[ 3.2147e-02, -8.1585e-02,  6.7507e-02, -7.7056e-02,  1.7667e-02],
                [ 2.0535e-03, -7.4221e-02,  1.0343e-02,  4.3018e-02,  9.7351e-03],
                [-2.1410e-02,  5.4089e-02, -2.4102e-02, -4.0551e-02, -3.6118e-03],
                [ 4.9847e-02,  6.9608e-02,  3.6233e-03,  5.7025e-02,  6.3206e-02],
                [ 1.4611e-02, -2.9885e-02,  5.6140e-02, -6.4338e-02,  8.5266e-03]]],


              ...,


              [[[ 5.7051e-02,  3.4026e-02, -3.7723e-02,  1.4372e-02, -4.4266e-03],
                [-8.0557e-02,  1.1810e-02, -6.9374e-02,  3.4264e-02, -3.9068e-02],
                [ 3.2814e-02, -4.9334e-02, -3.2234e-02,  3.7901e-02,  9.9268e-03],
                [ 1.2846e-03, -5.9199e-02, -5.6303e-02,  1.2189e-03,  7.8874e-02],
                [ 7.6858e-04,  2.4341e-02,  4.0423e-02, -7.7602e-02, -3.6388e-02]],

               [[ 5.9494e-02,  4.4230e-02, -5.9128e-02, -1.6639e-02, -6.4884e-02],
                [-2.3457e-02, -7.5842e-03, -3.3986e-02, -2.0435e-02, -4.2466e-02],
                [-6.8915e-02,  1.6417e-02, -9.0384e-03, -5.6058e-02,  1.1540e-02],
                [ 3.9632e-02,  3.8881e-02,  5.5834e-02,  7.5591e-02,  1.8463e-02],
                [ 4.5034e-02, -6.4665e-02,  6.7883e-02,  7.1108e-02,  8.0694e-02]],

               [[-7.4652e-02, -2.9270e-02, -7.4301e-02, -1.4067e-02, -6.0331e-02],
                [-7.9629e-02, -2.5316e-03, -3.4649e-02,  7.9736e-02,  2.4963e-02],
                [-1.4102e-02,  3.0896e-02, -5.4594e-02,  5.7641e-02,  7.8276e-02],
                [ 3.3722e-02,  1.6397e-02,  6.6251e-02,  2.5637e-02,  4.0073e-02],
                [ 1.9370e-02, -1.4960e-02, -4.0503e-02, -3.6491e-02, -6.9970e-02]],

               [[ 3.9434e-02,  6.7049e-02,  7.1627e-02,  6.9307e-02, -5.7508e-03],
                [ 2.8151e-02,  7.9890e-02, -6.4687e-02, -6.8959e-02,  6.8179e-02],
                [ 1.2583e-02,  6.6052e-02,  6.7770e-02,  1.0853e-02,  6.3935e-02],
                [ 4.4214e-02, -5.4527e-02, -6.3199e-02, -2.4454e-02, -8.0348e-02],
                [-1.1810e-04,  6.2292e-02, -2.1831e-02, -4.1282e-02,  3.4718e-02]],

               [[-8.9495e-03, -3.5923e-02, -4.9030e-02,  1.7068e-02,  5.7835e-02],
                [-6.2950e-02,  6.9258e-02,  1.4909e-02, -3.9252e-02,  3.0917e-02],
                [-5.0831e-02, -2.6109e-02, -4.2526e-02,  4.9180e-03, -6.7907e-02],
                [-1.4867e-02,  8.3498e-03, -6.3780e-02, -6.3819e-02, -7.7414e-02],
                [ 6.5369e-02,  3.5118e-02, -3.5070e-02,  3.1514e-02, -1.7773e-02]],

               [[-1.9000e-02,  4.8772e-02, -4.0550e-02,  5.7766e-02, -4.8687e-02],
                [ 7.0112e-02,  7.4851e-02, -5.0324e-02,  4.2522e-02,  6.6367e-02],
                [-6.6793e-02, -6.3487e-02, -6.3574e-02,  7.3530e-02, -6.7062e-02],
                [ 1.9297e-02,  3.9876e-02,  7.0333e-03,  3.6541e-02,  3.0865e-02],
                [ 6.9009e-02,  2.7737e-03, -6.0400e-02,  1.5249e-03, -1.5177e-03]]],


              [[[-4.9098e-02, -1.2656e-02,  3.0326e-02,  4.6450e-02,  4.1143e-02],
                [ 6.5180e-02, -4.5543e-02, -6.0194e-02, -8.1101e-02,  7.3691e-02],
                [-5.2880e-02, -5.3283e-02, -4.6874e-02,  2.0506e-02,  1.4432e-02],
                [ 5.3466e-05,  6.1875e-02, -5.2208e-02, -2.1149e-02, -6.5709e-02],
                [-7.2209e-02, -2.8706e-02,  6.6109e-02,  5.8108e-02, -1.8114e-02]],

               [[-5.8877e-02,  3.5183e-02,  6.5460e-02,  5.2934e-02,  3.5997e-02],
                [-6.5718e-02,  2.7700e-02,  7.1110e-02, -5.7825e-02,  6.1866e-03],
                [-5.3281e-03, -7.6189e-02, -6.9421e-02,  6.4743e-02, -1.1912e-02],
                [ 7.6864e-02, -5.8819e-03, -2.0277e-02, -1.6263e-02,  4.5729e-02],
                [ 2.0473e-03, -3.1893e-02, -3.0088e-02,  6.1322e-02, -1.3287e-02]],

               [[ 4.0654e-02, -3.8251e-03, -5.8287e-02, -6.9760e-03, -4.9954e-02],
                [-3.1949e-02, -6.5679e-02, -9.8746e-04, -5.5646e-02, -1.6937e-03],
                [-5.0579e-02,  5.1921e-02,  4.0006e-02, -5.3846e-02,  3.6710e-03],
                [-5.5284e-03,  5.2453e-02,  3.5617e-02, -4.4475e-02,  2.7835e-02],
                [ 3.6465e-02,  2.2936e-02,  4.9494e-02, -6.8768e-02, -6.8512e-02]],

               [[-1.5606e-02, -5.8101e-02, -4.8349e-02, -5.4572e-03, -8.1381e-02],
                [ 3.3837e-02,  6.9886e-02,  2.5937e-03, -4.4428e-02, -6.1442e-03],
                [-3.3799e-02,  7.6725e-02,  1.5202e-02,  2.7467e-02, -7.2112e-02],
                [-5.3887e-02,  5.3134e-02, -5.5426e-02,  8.1476e-02,  1.0773e-02],
                [-1.9578e-02,  1.7628e-02, -2.2382e-02,  6.7076e-02, -1.3475e-02]],

               [[ 3.7281e-02,  2.7106e-02, -7.8289e-03, -6.1201e-02, -4.5366e-02],
                [-5.1809e-02, -1.0889e-02,  4.4019e-02, -4.0099e-02, -6.2939e-02],
                [ 7.8826e-02,  1.4336e-02, -7.8953e-02, -4.1699e-03,  2.1759e-02],
                [ 4.3422e-02,  6.1053e-02, -5.1035e-02,  2.5170e-02,  8.1194e-02],
                [-3.5907e-02,  3.5084e-02,  5.4858e-02,  5.7819e-02, -6.8527e-02]],

               [[ 6.0340e-02, -4.5873e-02,  4.5307e-02, -1.8559e-02, -5.9891e-02],
                [ 7.1101e-02,  5.7979e-03, -2.1455e-02, -5.7839e-02, -2.6964e-02],
                [ 4.5972e-02,  4.6237e-02, -1.8353e-02,  5.5372e-03,  5.8802e-02],
                [-8.0939e-02,  2.2098e-03, -2.7943e-03,  6.9556e-02,  3.5299e-03],
                [-2.4275e-02, -6.1490e-02, -2.4350e-02, -5.8685e-02, -7.6820e-02]]],


              [[[-5.8326e-02,  4.3804e-02,  5.4642e-02,  2.9479e-02,  5.5766e-02],
                [-6.2955e-02,  4.9442e-02, -1.7882e-02, -6.4492e-02, -3.5590e-02],
                [ 7.8974e-02,  1.8189e-02, -4.3076e-02, -4.6822e-02, -5.9352e-02],
                [ 1.1472e-02,  6.9467e-02, -3.5045e-02, -1.3463e-03, -7.0617e-02],
                [-5.7437e-02, -5.7150e-02,  4.9108e-02,  2.2168e-02, -5.4964e-02]],

               [[-3.2895e-02, -2.2746e-03,  6.8428e-02, -7.4781e-02,  6.5675e-02],
                [-8.0232e-02, -2.6468e-02, -2.1136e-02,  2.1449e-02,  6.4572e-02],
                [ 2.9930e-03,  1.1987e-02,  4.8122e-03,  3.4183e-02, -7.8918e-02],
                [ 6.3749e-02, -2.5083e-02,  1.1253e-02, -4.4485e-02,  3.3380e-02],
                [ 5.0096e-03, -1.7321e-02,  8.0185e-02, -2.3853e-02, -2.9333e-03]],

               [[ 2.6648e-02, -7.6799e-02,  3.2204e-03, -7.7476e-02, -4.4615e-03],
                [ 5.7110e-02,  7.8575e-02,  5.3204e-02, -7.8592e-02,  4.1383e-03],
                [ 1.6194e-02,  2.5400e-02,  7.4070e-02, -3.9092e-03, -2.9417e-02],
                [-7.9407e-02,  2.5042e-02, -3.8854e-02,  2.8143e-02,  2.8485e-03],
                [-3.3828e-02, -7.5645e-02,  7.8511e-02, -4.4048e-02,  6.0887e-02]],

               [[-6.4552e-02, -3.1646e-02,  6.5499e-02, -6.8577e-02, -5.1529e-02],
                [ 6.1176e-02, -4.8461e-02,  4.7687e-02, -3.0069e-02, -1.7665e-02],
                [ 7.7632e-02, -1.7017e-02, -6.2812e-02, -1.8810e-02, -4.1500e-02],
                [ 6.1360e-02, -1.9826e-02, -6.4593e-02,  3.5071e-02, -5.9178e-02],
                [-6.6739e-02,  2.6098e-02, -5.5998e-02,  8.1334e-02,  3.7472e-02]],

               [[-5.5207e-02,  1.4355e-02, -2.2037e-02, -2.4025e-02,  7.2631e-02],
                [-1.0448e-02,  1.9105e-03, -5.5223e-02,  4.6377e-02, -6.8534e-02],
                [-2.4292e-02,  7.5258e-02, -8.0224e-02, -6.6001e-02, -4.6628e-02],
                [ 4.5334e-02, -2.3274e-02, -4.3572e-02,  4.3487e-03, -4.6057e-02],
                [-5.3757e-02, -2.0336e-02, -5.2245e-02,  2.2213e-02, -6.7578e-03]],

               [[ 5.7154e-02,  6.9033e-02, -2.7450e-02, -5.9039e-02,  3.0233e-02],
                [ 5.5904e-02,  5.2798e-02, -2.2586e-02,  2.8411e-02, -6.8010e-03],
                [ 5.1257e-02, -4.3710e-02,  8.7161e-03,  1.9411e-02, -3.5285e-03],
                [-8.0450e-02,  6.1012e-02, -7.7756e-02, -2.1472e-02,  4.7537e-02],
                [-4.7231e-02,  3.7300e-02,  2.7754e-02, -2.4025e-02,  1.0065e-02]]]])
      (bias): Normal:
       loc: tensor([0., -0., -0., -0., -0., 0., 0., -0., 0., -0., -0., 0., 0., 0., -0., 0.])
       scale: tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
              0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([ 0.0215, -0.0800, -0.0787, -0.0173, -0.0345,  0.0684,  0.0584, -0.0804,
               0.0098, -0.0490, -0.0535,  0.0145,  0.0056,  0.0082, -0.0256,  0.0140])
    )
    (observed): Observed()
  )
  (fc1): Linear(
    in_features=400, out_features=120, bias=True
    (posterior): Automatic(
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([0., 0., -0., 0., 0., -0., -0., 0., 0., -0., 0., -0., 0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., -0.,
              0., 0., -0., 0., 0., -0., 0., -0., 0., 0., 0., -0., 0., 0., 0., -0., -0., -0., -0., 0., 0., 0., -0., -0.,
              -0., -0., -0., 0., 0., -0., -0., 0., 0., -0., -0., 0., -0., -0., 0., -0., 0., -0., 0., 0., 0., -0., 0., -0.,
              0., -0., -0., 0., -0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., 0., 0., -0., -0., -0., 0., 0., 0., 0.,
              -0., -0., -0., -0., 0., -0., 0., -0., -0., 0., -0., 0., 0., -0., 0., 0., -0., 0., -0., 0., 0., 0., 0., 0.],
             requires_grad=True)
       tensor: tensor([ 0.0928, -0.0557, -0.0987, -0.0677,  0.0549, -0.0778,  0.0566, -0.0783,
              -0.1079, -0.0651,  0.1489, -0.0263, -0.0720,  0.0551, -0.1069,  0.0144,
              -0.0300,  0.0989,  0.0729, -0.0451,  0.0721,  0.0009,  0.1521,  0.1494,
               0.1114, -0.0930, -0.1384, -0.0729, -0.0837, -0.0453,  0.0476,  0.0383,
              -0.0459, -0.0405, -0.1977,  0.0385, -0.0533, -0.0039,  0.1198,  0.0202,
              -0.1197, -0.0355, -0.0581, -0.1910,  0.0959, -0.0030,  0.0761,  0.1127,
              -0.0603, -0.1291,  0.0822,  0.1164,  0.0096, -0.1714,  0.0553, -0.1927,
              -0.1611, -0.1322, -0.0809, -0.0486,  0.1202, -0.0975, -0.0132,  0.0918,
               0.1541, -0.1653,  0.0311,  0.0700,  0.0002, -0.0480, -0.0241, -0.2082,
              -0.1500,  0.0072,  0.0052,  0.0860,  0.0694,  0.1631, -0.0141, -0.1648,
              -0.1694, -0.0393,  0.0137, -0.0039,  0.0152,  0.0567,  0.0944,  0.0417,
               0.0136, -0.1908, -0.0396,  0.1616,  0.1286,  0.2245, -0.0121, -0.1299,
              -0.1069,  0.0543,  0.0613, -0.1372, -0.0757, -0.0744, -0.0156, -0.0350,
              -0.1974, -0.0417,  0.1595, -0.1018, -0.0879, -0.0230, -0.1762,  0.0350,
              -0.1005, -0.0702, -0.1774,  0.0834,  0.0309,  0.1576,  0.0500,  0.1578],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[-0., -0., 0.,  ..., 0., -0., -0.],
              [-0., 0., -0.,  ..., 0., -0., -0.],
              [0., 0., -0.,  ..., -0., 0., 0.],
              ...,
              [0., -0., 0.,  ..., -0., -0., 0.],
              [0., -0., 0.,  ..., -0., 0., 0.],
              [0., -0., -0.,  ..., -0., -0., -0.]])
       scale: tensor([[0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              ...,
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[-0.0489, -0.0457,  0.0358,  ...,  0.0488, -0.0310, -0.0318],
              [-0.0198,  0.0492, -0.0495,  ...,  0.0437, -0.0228, -0.0161],
              [ 0.0042,  0.0213, -0.0018,  ..., -0.0004,  0.0377,  0.0324],
              ...,
              [ 0.0020, -0.0197,  0.0377,  ..., -0.0133, -0.0496,  0.0166],
              [ 0.0128, -0.0165,  0.0298,  ..., -0.0352,  0.0281,  0.0219],
              [ 0.0448, -0.0166, -0.0012,  ..., -0.0042, -0.0289, -0.0339]])
      (bias): Normal:
       loc: tensor([0., 0., -0., 0., 0., -0., -0., 0., 0., -0., 0., -0., 0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., -0.,
              0., 0., -0., 0., 0., -0., 0., -0., 0., 0., 0., -0., 0., 0., 0., -0., -0., -0., -0., 0., 0., 0., -0., -0.,
              -0., -0., -0., 0., 0., -0., -0., 0., 0., -0., -0., 0., -0., -0., 0., -0., 0., -0., 0., 0., 0., -0., 0., -0.,
              0., -0., -0., 0., -0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., 0., 0., -0., -0., -0., 0., 0., 0., 0.,
              -0., -0., -0., -0., 0., -0., 0., -0., -0., 0., -0., 0., 0., -0., 0., 0., -0., 0., -0., 0., 0., 0., 0., 0.])
       scale: tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([ 0.0350,  0.0458, -0.0330,  0.0477,  0.0383, -0.0136, -0.0182,  0.0285,
               0.0197, -0.0431,  0.0120, -0.0445,  0.0171, -0.0019, -0.0141, -0.0021,
              -0.0429, -0.0159,  0.0028,  0.0272, -0.0290,  0.0047,  0.0452, -0.0022,
               0.0279,  0.0323, -0.0433,  0.0049,  0.0063, -0.0388,  0.0090, -0.0233,
               0.0251,  0.0375,  0.0274, -0.0337,  0.0122,  0.0217,  0.0230, -0.0405,
              -0.0476, -0.0063, -0.0021,  0.0267,  0.0014,  0.0228, -0.0130, -0.0471,
              -0.0170, -0.0349, -0.0472,  0.0116,  0.0002, -0.0426, -0.0129,  0.0492,
               0.0117, -0.0143, -0.0025,  0.0040, -0.0466, -0.0037,  0.0341, -0.0261,
               0.0327, -0.0433,  0.0025,  0.0201,  0.0211, -0.0235,  0.0472, -0.0291,
               0.0431, -0.0314, -0.0255,  0.0108, -0.0499, -0.0164, -0.0294, -0.0290,
              -0.0305, -0.0172,  0.0238,  0.0029, -0.0029,  0.0172,  0.0227,  0.0006,
               0.0120, -0.0068, -0.0043, -0.0289,  0.0060,  0.0199,  0.0122,  0.0423,
              -0.0015, -0.0034, -0.0201, -0.0374,  0.0159, -0.0258,  0.0075, -0.0097,
              -0.0048,  0.0477, -0.0470,  0.0045,  0.0128, -0.0441,  0.0218,  0.0365,
              -0.0206,  0.0348, -0.0249,  0.0256,  0.0222,  0.0019,  0.0289,  0.0248])
    )
    (observed): Observed()
  )
  (fc2): Linear(
    in_features=120, out_features=2, bias=True
    (posterior): Automatic(
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.7071, 0.7071], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-0., 0.], requires_grad=True)
       tensor: tensor([-0.4664, -0.5310], grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[0., -0., -0., -0., -0., 0., -0., 0., -0., -0., -0., -0., 0., -0., -0., 0., -0., -0., 0., -0., 0., -0., -0., 0.,
               -0., -0., -0., -0., 0., 0., -0., 0., -0., 0., -0., -0., -0., 0., 0., -0., -0., 0., -0., 0., -0., 0., 0., 0.,
               -0., -0., 0., -0., -0., 0., -0., 0., -0., -0., -0., 0., -0., -0., -0., -0., 0., -0., -0., -0., -0., -0., -0., -0.,
               0., 0., 0., 0., 0., 0., 0., -0., 0., 0., -0., 0., 0., -0., 0., -0., 0., -0., 0., -0., -0., 0., 0., 0.,
               -0., 0., -0., 0., 0., -0., -0., 0., 0., -0., -0., -0., -0., 0., 0., 0., 0., -0., 0., -0., -0., 0., -0., -0.],
              [-0., -0., 0., 0., 0., 0., 0., -0., -0., 0., -0., 0., 0., 0., 0., 0., 0., -0., 0., -0., 0., -0., 0., -0.,
               0., 0., 0., -0., -0., 0., -0., -0., -0., 0., 0., 0., -0., -0., 0., -0., 0., 0., 0., -0., 0., 0., -0., 0.,
               -0., 0., -0., -0., -0., -0., -0., -0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -0., -0., -0., 0.,
               0., 0., -0., 0., -0., 0., -0., 0., -0., -0., -0., -0., -0., -0., -0., 0., 0., 0., -0., -0., -0., -0., 0., 0.,
               0., -0., 0., -0., 0., -0., -0., 0., -0., -0., -0., 0., 0., -0., -0., 0., 0., -0., -0., -0., -0., -0., -0., -0.]])
       scale: tensor([[0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913],
              [0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
               0.0913, 0.0913, 0.0913]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[ 0.0760, -0.0099, -0.0237, -0.0145, -0.0806,  0.0095, -0.0646,  0.0145,
               -0.0077, -0.0512, -0.0072, -0.0008,  0.0279, -0.0383, -0.0224,  0.0434,
               -0.0912, -0.0509,  0.0756, -0.0889,  0.0356, -0.0862, -0.0046,  0.0507,
               -0.0356, -0.0616, -0.0509, -0.0035,  0.0123,  0.0190, -0.0453,  0.0815,
               -0.0149,  0.0448, -0.0308, -0.0292, -0.0423,  0.0691,  0.0686, -0.0398,
               -0.0657,  0.0157, -0.0508,  0.0847, -0.0897,  0.0655,  0.0407,  0.0535,
               -0.0541, -0.0812,  0.0122, -0.0665, -0.0799,  0.0247, -0.0409,  0.0105,
               -0.0471, -0.0825, -0.0042,  0.0652, -0.0086, -0.0002, -0.0784, -0.0430,
                0.0104, -0.0905, -0.0506, -0.0340, -0.0407, -0.0163, -0.0497, -0.0516,
                0.0852,  0.0711,  0.0833,  0.0214,  0.0743,  0.0575,  0.0583, -0.0007,
                0.0814,  0.0736, -0.0248,  0.0284,  0.0873, -0.0174,  0.0206, -0.0740,
                0.0276, -0.0414,  0.0508, -0.0087, -0.0581,  0.0255,  0.0058,  0.0142,
               -0.0266,  0.0067, -0.0468,  0.0654,  0.0305, -0.0043, -0.0613,  0.0733,
                0.0400, -0.0446, -0.0243, -0.0434, -0.0616,  0.0371,  0.0253,  0.0681,
                0.0847, -0.0068,  0.0176, -0.0169, -0.0387,  0.0219, -0.0046, -0.0663],
              [-0.0324, -0.0686,  0.0105,  0.0805,  0.0090,  0.0304,  0.0097, -0.0191,
               -0.0591,  0.0876, -0.0748,  0.0383,  0.0680,  0.0441,  0.0479,  0.0484,
                0.0302, -0.0039,  0.0855, -0.0066,  0.0661, -0.0492,  0.0843, -0.0566,
                0.0517,  0.0880,  0.0308, -0.0874, -0.0144,  0.0143, -0.0663, -0.0484,
               -0.0368,  0.0709,  0.0610,  0.0495, -0.0031, -0.0503,  0.0562, -0.0030,
                0.0753,  0.0173,  0.0221, -0.0259,  0.0145,  0.0206, -0.0740,  0.0226,
               -0.0414,  0.0712, -0.0427, -0.0477, -0.0386, -0.0709, -0.0451, -0.0469,
                0.0882,  0.0519,  0.0840,  0.0558,  0.0087,  0.0270,  0.0901,  0.0010,
                0.0620,  0.0696,  0.0825,  0.0557, -0.0043, -0.0531, -0.0447,  0.0474,
                0.0724,  0.0483, -0.0868,  0.0503, -0.0060,  0.0524, -0.0355,  0.0002,
               -0.0195, -0.0888, -0.0211, -0.0551, -0.0292, -0.0041, -0.0416,  0.0861,
                0.0530,  0.0840, -0.0316, -0.0839, -0.0451, -0.0664,  0.0725,  0.0301,
                0.0456, -0.0145,  0.0455, -0.0850,  0.0010, -0.0722, -0.0800,  0.0512,
               -0.0753, -0.0348, -0.0249,  0.0067,  0.0063, -0.0506, -0.0310,  0.0592,
                0.0374, -0.0612, -0.0298, -0.0912, -0.0550, -0.0416, -0.0526, -0.0344]])
      (bias): Normal:
       loc: tensor([-0., 0.])
       scale: tensor([0.7071, 0.7071])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-0.0576,  0.0882])
    )
    (observed): Observed()
  )
)

The posterior can also be set when the module is created:

nn.Linear(10, 10, posterior=borch.posterior.Normal(log_scale=-3))

Out:

Linear(
  in_features=10, out_features=10, bias=True
  (posterior): Normal(
    (weight): Normal:
     posterior: Automatic()
     prior: Module()
     observed: Observed()
     scale: Transform:
     tensor([[0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498],
            [0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
             0.0498]], grad_fn=<ExpBackward0>)
     loc: Parameter containing:
    tensor([[-0., 0., 0., -0., 0., -0., 0., 0., 0., -0.],
            [0., 0., -0., -0., -0., -0., -0., -0., 0., 0.],
            [0., 0., 0., -0., -0., -0., 0., 0., -0., 0.],
            [0., -0., 0., -0., 0., 0., -0., -0., -0., 0.],
            [-0., 0., 0., -0., -0., -0., -0., -0., 0., 0.],
            [-0., -0., 0., -0., 0., 0., 0., -0., 0., -0.],
            [0., -0., 0., 0., 0., -0., 0., 0., -0., -0.],
            [0., -0., 0., -0., -0., -0., -0., 0., 0., -0.],
            [-0., -0., 0., 0., -0., 0., 0., -0., -0., 0.],
            [-0., 0., 0., -0., -0., 0., -0., 0., -0., 0.]], requires_grad=True)
     tensor: tensor([[-0.0003,  0.0373, -0.0517, -0.0143,  0.0863,  0.0142, -0.0411,  0.0303,
              0.0511, -0.0203],
            [ 0.0079, -0.0920, -0.0046, -0.0104, -0.0095,  0.0203,  0.0310,  0.0202,
             -0.0121, -0.0181],
            [-0.0519, -0.0255,  0.0884, -0.0048,  0.0630,  0.0475, -0.0290, -0.0855,
              0.0418, -0.0381],
            [-0.1063,  0.0282,  0.0966, -0.0082, -0.0365,  0.0427, -0.0479,  0.0883,
              0.0667,  0.0052],
            [-0.0698, -0.0686,  0.0423,  0.0692,  0.1213,  0.0422,  0.0058,  0.0344,
             -0.0401, -0.0121],
            [-0.0083,  0.0183,  0.0247,  0.0571, -0.0163,  0.1211,  0.0445, -0.1065,
             -0.0504,  0.0108],
            [-0.0505, -0.0583, -0.0083, -0.0533,  0.0170, -0.0274, -0.1463, -0.0559,
              0.0722,  0.0532],
            [-0.0191, -0.0246,  0.0021,  0.1291, -0.0634,  0.0648, -0.0016,  0.0228,
             -0.0267,  0.0685],
            [ 0.0373,  0.0184,  0.0326, -0.1323,  0.0879, -0.0471, -0.0354, -0.0547,
             -0.0245,  0.0197],
            [-0.0411,  0.0050,  0.0211, -0.0370,  0.0525, -0.0592,  0.0245,  0.0202,
             -0.0753,  0.0099]], grad_fn=<AddBackward0>)
    (bias): Normal:
     posterior: Automatic()
     prior: Module()
     observed: Observed()
     scale: Transform:
     tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
            0.0498], grad_fn=<ExpBackward0>)
     loc: Parameter containing:
    tensor([-0., -0., -0., 0., 0., -0., -0., 0., 0., -0.], requires_grad=True)
     tensor: tensor([-0.0311, -0.0460, -0.1053,  0.0139,  0.0346, -0.0215,  0.0681,  0.0781,
             0.0452,  0.0135], grad_fn=<AddBackward0>)
  )
  (prior): Module(
    (weight): Normal:
     loc: tensor([[-0., 0., 0., -0., 0., -0., 0., 0., 0., -0.],
            [0., 0., -0., -0., -0., -0., -0., -0., 0., 0.],
            [0., 0., 0., -0., -0., -0., 0., 0., -0., 0.],
            [0., -0., 0., -0., 0., 0., -0., -0., -0., 0.],
            [-0., 0., 0., -0., -0., -0., -0., -0., 0., 0.],
            [-0., -0., 0., -0., 0., 0., 0., -0., 0., -0.],
            [0., -0., 0., 0., 0., -0., 0., 0., -0., -0.],
            [0., -0., 0., -0., -0., -0., -0., 0., 0., -0.],
            [-0., -0., 0., 0., -0., 0., 0., -0., -0., 0.],
            [-0., 0., 0., -0., -0., 0., -0., 0., -0., 0.]])
     scale: tensor([[0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162],
            [0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
             0.3162]])
     posterior: Automatic()
     prior: Module()
     observed: Observed()
     tensor: tensor([[-0.2634,  0.2061,  0.3041, -0.1349,  0.0784, -0.1657,  0.1957,  0.2490,
              0.3160, -0.0849],
            [ 0.0335,  0.2293, -0.0953, -0.1795, -0.1057, -0.0171, -0.1661, -0.0754,
              0.0487,  0.0602],
            [ 0.2953,  0.0108,  0.1470, -0.1304, -0.2690, -0.3157,  0.2241,  0.0583,
             -0.1642,  0.2801],
            [ 0.1391, -0.0884,  0.2268, -0.0267,  0.1603,  0.0974, -0.0735, -0.3121,
             -0.0606,  0.2517],
            [-0.2838,  0.1884,  0.2694, -0.1517, -0.0660, -0.2486, -0.0599, -0.1401,
              0.2265,  0.2869],
            [-0.2059, -0.0081,  0.2682, -0.1052,  0.1061,  0.2965,  0.1716, -0.0551,
              0.1544, -0.1494],
            [ 0.1081, -0.1222,  0.0729,  0.0693,  0.2599, -0.2775,  0.0092,  0.0497,
             -0.2638, -0.1386],
            [ 0.0835, -0.2367,  0.0503, -0.1869, -0.0921, -0.2095, -0.2027,  0.0749,
              0.1702, -0.2820],
            [-0.2977, -0.0958,  0.0508,  0.1128, -0.2777,  0.1770,  0.1128, -0.2343,
             -0.0854,  0.0976],
            [-0.1118,  0.3101,  0.1880, -0.0396, -0.1929,  0.1096, -0.1623,  0.1923,
             -0.2513,  0.1232]])
    (bias): Normal:
     loc: tensor([-0., -0., -0., 0., 0., -0., -0., 0., 0., -0.])
     scale: tensor([0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
            0.3162])
     posterior: Automatic()
     prior: Module()
     observed: Observed()
     tensor: tensor([-0.2393, -0.2382, -0.0876,  0.1571,  0.1483, -0.3024, -0.2600,  0.2103,
             0.3014, -0.0652])
  )
  (observed): Observed()
)
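The scales in this output can be checked by hand. The plain-Python sketch below assumes that the Normal posterior stores `log_scale` and exponentiates it. The `grad_fn=<ExpBackward0>` on the posterior scale suggests this. Under that assumption, `log_scale=-3` should give a posterior scale of exp(-3) ≈ 0.0498. The second half only records an observation from the printed reprs in this tutorial. The prior scales shown match 1/sqrt(n), where n is `in_features` for weights and the number of outputs for biases. This has not been confirmed as a documented borch rule. The helper names `posterior_scale` and `inverse_sqrt` are hypothetical and exist only for this check.

```python
import math


def posterior_scale(log_scale: float) -> float:
    """Scale implied by a log_scale argument, assuming scale = exp(log_scale)."""
    return math.exp(log_scale)


def inverse_sqrt(n: int) -> float:
    """1/sqrt(n), compared against the prior scales printed in the reprs."""
    return 1.0 / math.sqrt(n)


# log_scale=-3 as passed to borch.posterior.Normal in the example above.
print(round(posterior_scale(-3), 4))  # 0.0498

# (n, printed prior scale) pairs read off the repr output in this tutorial:
# Linear(10, 10), conv2 bias (16 channels), fc1 weight (400 inputs),
# fc2 weight (120 inputs), fc2 bias (2 outputs).
printed_prior_scales = [(10, 0.3162), (16, 0.2500), (400, 0.0500), (120, 0.0913), (2, 0.7071)]
for n, printed in printed_prior_scales:
    print(n, round(inverse_sqrt(n), 4), printed)
```

If the observation holds, wider layers receive a proportionally narrower prior. Check the borch source before relying on this.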

See the borch.posterior documentation for the other available posteriors and the arguments they accept. Note that not every posterior works with every parameter, but different borch.Module instances in the same network can use different posteriors.

Exercises

  1. Use what you have learned to train an image classifier on MNIST; you should be able to reach an accuracy above 98%. Note: MNIST is available as torchvision.datasets.MNIST.

  2. Fit the same model architecture with plain torch and compare its likelihood with that of the borch network. What are the differences, and why?

  3. Port the model to CIFAR and see how you can improve the accuracy.

  4. Show how the Categorical distribution is related to the cross-entropy loss function commonly used in frequentist deep learning.
