Neural Networks

Neural networks can be constructed using the borch.nn package.

Now that you’ve had a glimpse of autograd, note that nn depends on autograd to define models and differentiate them. An nn.Module contains layers and a method forward(input) that returns the output.

For example, look at this network that classifies digit images:

convnet

It is a simple feed-forward network. It takes an input, feeds it through several layers one after the other, and then finally gives the output.
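To make the flow through the layers concrete, the spatial size of the data can be traced layer by layer with plain arithmetic. The sketch below assumes a 32x32 single-channel input and the layer sizes used by the network defined later in this section (two 5x5 convolutions without padding, each followed by 2x2 max pooling). The helper names conv_out and pool_out are hypothetical and are not part of borch or torch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output size of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, window):
    """Output size of max pooling whose stride equals the window size."""
    return size // window


size = 32                  # assumed input height/width
size = conv_out(size, 5)   # first 5x5 convolution
size = pool_out(size, 2)   # first 2x2 max pooling
size = conv_out(size, 5)   # second 5x5 convolution
size = pool_out(size, 2)   # second 2x2 max pooling

# 16 output channels from the second convolution, flattened for the
# first fully connected layer
flat_features = 16 * size * size
print(size, flat_features)
```

The final value is the number of input features expected by the first fully connected layer.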

A typical training procedure for a neural network is as follows:

  • Define a network that has some learnable parameters and/or random variables

  • For each batch in a dataset, do:

    • Process the input data through the network

    • Compute the loss (how far is the output from being correct?)

    • Propagate gradients back into the network’s parameters

    • Update the weights of the network, typically using a simple update rule: weight = weight - learning_rate * gradient
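The steps above can be sketched with plain torch. This is a minimal illustration under stated assumptions, not the borch training procedure: the toy data, the torch.nn.Linear stand-in model and the mean squared error loss are all chosen here for illustration only.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy data: targets follow y = 2x + 1
x = torch.linspace(-1.0, 1.0, 32).unsqueeze(1)
y = 2.0 * x + 1.0

# Plain torch.nn stand-in for a network with learnable parameters
model = torch.nn.Linear(1, 1)
learning_rate = 0.1

losses = []
for step in range(200):
    prediction = model(x)                      # process the input
    loss = torch.mean((prediction - y) ** 2)   # compute the loss
    model.zero_grad()                          # clear old gradients
    loss.backward()                            # propagate gradients back
    with torch.no_grad():
        for weight in model.parameters():
            # weight = weight - learning_rate * gradient
            weight -= learning_rate * weight.grad
    losses.append(loss.item())

print(losses[0], losses[-1])
```

In practice an optimizer object usually replaces the manual update inside the torch.no_grad() block; the explicit loop is kept here only to mirror the update rule as written.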

Define the network

Let’s define this network:

import torch
import torch.nn.functional as F
import borch
from borch import distributions, posterior, nn, infer


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__(posterior=posterior.Automatic())
        # 1 input image channel, 6 output channels, 5x5 convolution kernel
        self.conv1 = nn.Conv2d(1, 6, 5)

        # 6 input channels, 16 output channels, 5x5 convolution kernel
        self.conv2 = nn.Conv2d(6, 16, 5)

        # An affine operation: y = Wx + b
        # NB after two 5x5 convolutions without padding, each followed by
        # 2x2 max pooling, an image with initial dimension 32x32 has
        # spatial dimension 5x5 (with 16 channels)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the window is square you can specify just a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten to (batch_size, num_features); -1 infers the batch size
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Specifying the likelihood function
        self.classification = distributions.Categorical(logits=x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
print(net)

Out:

Net(
  (posterior): Automatic()
  (prior): Module()
  (observed): Observed()
  (conv1): Conv2d(
    1, 6, kernel_size=(5, 5), stride=(1, 1)
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[[[-0.0774,  0.1503, -0.1255, -0.0640,  0.0630],
                [ 0.0595,  0.1217, -0.0309, -0.0112,  0.0043],
                [-0.1666,  0.1739, -0.0512,  0.1344, -0.0321],
                [-0.1129, -0.1998,  0.1292,  0.1143,  0.1476],
                [-0.0027,  0.0059,  0.1920,  0.0098,  0.0048]]],


              [[[ 0.0141,  0.0599,  0.1305, -0.0544,  0.1488],
                [ 0.1270, -0.0026, -0.1380, -0.1876,  0.0489],
                [ 0.1652,  0.0822,  0.0593,  0.1513,  0.0839],
                [ 0.0577, -0.0596, -0.0329,  0.1613, -0.0014],
                [ 0.1026, -0.1543,  0.1927,  0.0072,  0.1568]]],


              [[[ 0.1353, -0.0796,  0.1384,  0.1742,  0.1012],
                [-0.0558,  0.0222,  0.0847,  0.0301, -0.0592],
                [-0.1298,  0.0103,  0.0711, -0.1305, -0.1070],
                [ 0.1687,  0.1726, -0.1407,  0.0858, -0.1460],
                [ 0.0043, -0.0089,  0.0811, -0.0686, -0.0656]]],


              [[[ 0.1872, -0.0864,  0.0585,  0.1785, -0.0565],
                [-0.1216,  0.1257, -0.0771, -0.0275,  0.0386],
                [ 0.1893, -0.0486,  0.1845, -0.0519,  0.1210],
                [ 0.0378,  0.0850,  0.0683,  0.1155, -0.0512],
                [ 0.0487,  0.0793, -0.0349,  0.1779, -0.1606]]],


              [[[-0.1173,  0.0174,  0.0488,  0.0573,  0.1199],
                [ 0.0558,  0.0349, -0.1852,  0.0371, -0.0067],
                [-0.1553, -0.1714,  0.0416, -0.0263,  0.1124],
                [-0.1636,  0.0513, -0.1301,  0.1010, -0.1350],
                [-0.0892, -0.1678,  0.1047, -0.0118, -0.1502]]],


              [[[-0.0300, -0.1080,  0.0871, -0.1145,  0.0902],
                [-0.1072,  0.1453,  0.1286,  0.1711, -0.0089],
                [-0.0012, -0.1712,  0.0279,  0.0078, -0.0613],
                [-0.0928,  0.0310, -0.0577, -0.0740,  0.0623],
                [ 0.0680,  0.1578,  0.1678,  0.1280, -0.0358]]]], requires_grad=True)
       tensor: tensor([[[[-0.0912,  0.1651, -0.0824, -0.1180,  0.1151],
                [ 0.1682,  0.1379,  0.0727,  0.0383,  0.0096],
                [-0.1481,  0.2358, -0.0234,  0.1954, -0.0651],
                [-0.1287, -0.1321,  0.0942,  0.2415,  0.2055],
                [ 0.0139,  0.0093,  0.1414, -0.0114,  0.0490]]],


              [[[ 0.0070,  0.0804,  0.1924, -0.0535,  0.1665],
                [ 0.1279, -0.0944, -0.0823, -0.1923,  0.0228],
                [ 0.2292,  0.0453,  0.1051,  0.1499,  0.0351],
                [ 0.0366, -0.0931, -0.0020,  0.1834, -0.0846],
                [ 0.0633, -0.2066,  0.0871,  0.0926,  0.1730]]],


              [[[ 0.1552, -0.0400,  0.1363,  0.1708,  0.1237],
                [-0.0265, -0.0098,  0.0622,  0.0631, -0.1061],
                [-0.1840,  0.0792,  0.1222, -0.2076, -0.0974],
                [ 0.0809,  0.1328, -0.0938, -0.0133, -0.0958],
                [-0.0706,  0.0437,  0.0714, -0.1809, -0.1122]]],


              [[[ 0.2537, -0.0915,  0.0056,  0.1602, -0.1162],
                [-0.1037,  0.2019, -0.0716,  0.0006, -0.0065],
                [ 0.1379, -0.0639,  0.1727, -0.1089,  0.2738],
                [ 0.0247,  0.0751,  0.0598,  0.0804, -0.1690],
                [ 0.0494,  0.1083, -0.0208,  0.0432, -0.1384]]],


              [[[-0.2125,  0.0266,  0.0491,  0.0178,  0.0777],
                [ 0.0899,  0.0169, -0.2210,  0.0300,  0.0083],
                [-0.1368, -0.2311,  0.0537, -0.0592,  0.1702],
                [-0.2015,  0.0771, -0.1454,  0.0375, -0.2007],
                [-0.0780, -0.1716,  0.0919, -0.0333, -0.1659]]],


              [[[-0.0071, -0.1041,  0.1053, -0.1506,  0.0268],
                [-0.1314,  0.1820,  0.0954,  0.1974, -0.0269],
                [-0.0590, -0.2216,  0.0056, -0.1258, -0.1532],
                [-0.0695,  0.0420, -0.0078, -0.1728,  0.1632],
                [-0.0141,  0.2390,  0.1728,  0.1627, -0.1093]]]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-0.1635,  0.0177,  0.1092,  0.1259,  0.1528,  0.0503],
             requires_grad=True)
       tensor: tensor([-0.1425,  0.0416,  0.1181,  0.1505,  0.1053,  0.0927],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[[[-0., 0., -0., -0., 0.],
                [0., 0., -0., -0., 0.],
                [-0., 0., -0., 0., -0.],
                [-0., -0., 0., 0., 0.],
                [-0., 0., 0., 0., 0.]]],


              [[[0., 0., 0., -0., 0.],
                [0., -0., -0., -0., 0.],
                [0., 0., 0., 0., 0.],
                [0., -0., -0., 0., -0.],
                [0., -0., 0., 0., 0.]]],


              [[[0., -0., 0., 0., 0.],
                [-0., 0., 0., 0., -0.],
                [-0., 0., 0., -0., -0.],
                [0., 0., -0., 0., -0.],
                [0., -0., 0., -0., -0.]]],


              [[[0., -0., 0., 0., -0.],
                [-0., 0., -0., -0., 0.],
                [0., -0., 0., -0., 0.],
                [0., 0., 0., 0., -0.],
                [0., 0., -0., 0., -0.]]],


              [[[-0., 0., 0., 0., 0.],
                [0., 0., -0., 0., -0.],
                [-0., -0., 0., -0., 0.],
                [-0., 0., -0., 0., -0.],
                [-0., -0., 0., -0., -0.]]],


              [[[-0., -0., 0., -0., 0.],
                [-0., 0., 0., 0., -0.],
                [-0., -0., 0., 0., -0.],
                [-0., 0., -0., -0., 0.],
                [0., 0., 0., 0., -0.]]]])
       scale: tensor([[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[[[-0.0774,  0.1503, -0.1255, -0.0640,  0.0630],
                [ 0.0595,  0.1217, -0.0309, -0.0112,  0.0043],
                [-0.1666,  0.1739, -0.0512,  0.1344, -0.0321],
                [-0.1129, -0.1998,  0.1292,  0.1143,  0.1476],
                [-0.0027,  0.0059,  0.1920,  0.0098,  0.0048]]],


              [[[ 0.0141,  0.0599,  0.1305, -0.0544,  0.1488],
                [ 0.1270, -0.0026, -0.1380, -0.1876,  0.0489],
                [ 0.1652,  0.0822,  0.0593,  0.1513,  0.0839],
                [ 0.0577, -0.0596, -0.0329,  0.1613, -0.0014],
                [ 0.1026, -0.1543,  0.1927,  0.0072,  0.1568]]],


              [[[ 0.1353, -0.0796,  0.1384,  0.1742,  0.1012],
                [-0.0558,  0.0222,  0.0847,  0.0301, -0.0592],
                [-0.1298,  0.0103,  0.0711, -0.1305, -0.1070],
                [ 0.1687,  0.1726, -0.1407,  0.0858, -0.1460],
                [ 0.0043, -0.0089,  0.0811, -0.0686, -0.0656]]],


              [[[ 0.1872, -0.0864,  0.0585,  0.1785, -0.0565],
                [-0.1216,  0.1257, -0.0771, -0.0275,  0.0386],
                [ 0.1893, -0.0486,  0.1845, -0.0519,  0.1210],
                [ 0.0378,  0.0850,  0.0683,  0.1155, -0.0512],
                [ 0.0487,  0.0793, -0.0349,  0.1779, -0.1606]]],


              [[[-0.1173,  0.0174,  0.0488,  0.0573,  0.1199],
                [ 0.0558,  0.0349, -0.1852,  0.0371, -0.0067],
                [-0.1553, -0.1714,  0.0416, -0.0263,  0.1124],
                [-0.1636,  0.0513, -0.1301,  0.1010, -0.1350],
                [-0.0892, -0.1678,  0.1047, -0.0118, -0.1502]]],


              [[[-0.0300, -0.1080,  0.0871, -0.1145,  0.0902],
                [-0.1072,  0.1453,  0.1286,  0.1711, -0.0089],
                [-0.0012, -0.1712,  0.0279,  0.0078, -0.0613],
                [-0.0928,  0.0310, -0.0577, -0.0740,  0.0623],
                [ 0.0680,  0.1578,  0.1678,  0.1280, -0.0358]]]])
      (bias): Normal:
       loc: tensor([-0., 0., 0., 0., 0., 0.])
       scale: tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-0.1635,  0.0177,  0.1092,  0.1259,  0.1528,  0.0503])
    )
    (observed): Observed()
  )
  (conv2): Conv2d(
    6, 16, kernel_size=(5, 5), stride=(1, 1)
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              ...,


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[[[ 1.7231e-02,  4.3731e-02,  5.4681e-02,  3.4570e-02,  4.3578e-02],
                [-2.7786e-02, -4.7320e-02,  3.1995e-02,  4.2537e-02,  7.5819e-02],
                [ 8.9881e-03,  2.6359e-02, -8.3574e-03, -5.6418e-02, -3.9465e-02],
                [-1.9657e-02,  5.7231e-03,  5.5066e-03, -1.4926e-02, -7.4856e-02],
                [-6.2679e-02,  3.7226e-02,  8.0120e-02, -2.3663e-02, -5.3511e-02]],

               [[-7.5590e-02,  6.4561e-02, -4.3921e-02, -7.7799e-03, -5.2721e-02],
                [-1.0355e-02,  5.9574e-02,  1.1166e-02, -4.8245e-02,  4.1576e-02],
                [ 5.6966e-02, -7.2947e-03, -1.9590e-03,  6.9857e-02, -2.1827e-02],
                [ 3.2483e-02, -6.1154e-02,  7.6961e-02,  4.5617e-02, -5.6834e-02],
                [ 3.4771e-02, -3.0784e-02,  4.8407e-02,  2.7358e-02,  6.5453e-02]],

               [[-7.3062e-02, -2.7048e-02,  1.0655e-02,  7.5364e-02,  6.3020e-02],
                [ 1.2004e-02, -3.4794e-02,  4.0299e-02,  6.4137e-02, -2.4964e-02],
                [ 7.4683e-02, -2.4405e-02, -6.7041e-02,  2.7558e-02,  2.7585e-02],
                [-2.3105e-02,  3.5924e-02, -1.3142e-02,  3.6770e-02, -5.1734e-02],
                [-3.0248e-03, -7.3805e-02,  1.5975e-02, -3.8458e-02, -3.4536e-02]],

               [[ 1.2704e-02, -6.5938e-02,  5.8697e-04, -4.0066e-03,  6.0744e-02],
                [-6.6682e-02,  7.7161e-02,  8.1204e-02,  1.0834e-02, -9.7154e-03],
                [-4.5248e-02,  4.1191e-02, -3.1746e-02,  1.4713e-02,  1.5446e-02],
                [-6.1854e-02, -5.4387e-02,  3.8440e-02, -3.7393e-02,  7.6462e-02],
                [ 6.8184e-02, -3.0699e-02, -6.0217e-02,  3.2648e-02, -2.2715e-02]],

               [[ 7.6031e-02,  6.2038e-02,  2.1564e-02, -6.5555e-02, -6.4686e-02],
                [-4.2324e-02,  3.6401e-02, -6.7974e-02, -3.9416e-02, -4.4166e-02],
                [ 6.7690e-02, -6.9973e-02, -4.7069e-02, -7.6832e-03, -4.0938e-03],
                [-2.6045e-02, -4.6521e-02,  3.4383e-03,  6.4282e-02, -2.5852e-02],
                [ 3.2686e-02,  3.1602e-02,  3.5400e-02, -7.0873e-02,  7.1903e-02]],

               [[ 6.0529e-02, -5.2981e-02, -3.9675e-02, -6.5714e-02,  2.8020e-02],
                [ 3.5057e-02,  1.3588e-02, -3.5319e-02, -7.1891e-02, -6.0487e-02],
                [-6.3268e-02,  1.4282e-02, -3.6912e-02,  3.1696e-02,  5.1382e-02],
                [-7.2075e-02, -5.2067e-02,  5.7057e-02, -6.0579e-02, -5.5177e-02],
                [-2.2326e-03,  6.6584e-02, -6.6819e-02, -5.9396e-02,  4.9889e-02]]],


              [[[ 6.1514e-02,  3.9812e-02, -7.7776e-02, -4.9428e-02, -2.5424e-02],
                [-1.3566e-02, -3.6820e-02, -1.2480e-02,  4.0800e-02, -7.4332e-02],
                [ 2.8230e-02, -2.4562e-02, -2.9086e-02,  6.3475e-02, -1.5058e-02],
                [-1.5993e-02,  2.6303e-02,  8.7087e-03, -5.3252e-02,  2.7997e-02],
                [ 2.8034e-02, -3.0461e-02, -1.0604e-02,  4.4506e-02, -1.2910e-02]],

               [[-7.8455e-02, -7.8485e-02, -5.5244e-02,  4.7309e-03, -3.7835e-02],
                [-5.3291e-02,  3.5150e-02, -6.4878e-02, -7.4903e-02, -5.5693e-02],
                [ 3.2479e-02,  6.5471e-02,  4.4178e-02, -5.4057e-02, -7.9336e-02],
                [-6.8247e-02,  6.5587e-02, -8.7773e-04, -3.5329e-02, -3.8033e-02],
                [-3.8157e-02,  4.2570e-02,  7.6967e-02, -4.9785e-02, -2.8498e-02]],

               [[ 4.3747e-02,  5.4973e-02, -2.6460e-02, -1.5235e-02,  7.8260e-02],
                [ 1.8068e-02,  1.9185e-02, -6.4915e-02,  7.4767e-03,  6.4246e-02],
                [-6.3240e-02, -7.0417e-03, -6.4609e-02, -3.5764e-03,  1.0317e-02],
                [-4.5486e-02,  7.3897e-02, -2.8131e-02, -8.0268e-02,  1.2427e-02],
                [ 2.6636e-02,  5.3320e-02,  8.0113e-02, -2.4988e-02, -3.3729e-02]],

               [[-3.0252e-02,  2.0032e-02,  2.2845e-02, -3.2871e-02, -7.2805e-02],
                [-1.2501e-02,  1.2056e-02,  1.7524e-02,  3.1636e-02, -3.6845e-02],
                [ 2.2584e-02,  1.5433e-02, -8.1113e-02,  4.8879e-03, -4.5628e-02],
                [-6.5354e-03, -1.2652e-02, -6.9698e-02,  6.7797e-02,  6.2843e-02],
                [ 1.2404e-02,  1.8414e-02, -4.5958e-02, -6.2785e-03,  4.0931e-02]],

               [[-1.8430e-02, -2.1989e-02, -8.0820e-03,  7.2810e-02, -9.3746e-03],
                [-7.0612e-02,  5.1172e-02, -7.0489e-03, -4.5064e-02,  3.5534e-02],
                [-3.2953e-02,  3.0727e-02,  4.8882e-02, -3.9965e-02, -2.3211e-02],
                [ 6.9945e-02,  2.6839e-02,  3.3160e-03, -2.7622e-02,  7.6837e-02],
                [-1.7656e-02, -4.1508e-03,  3.8821e-02,  3.6178e-02,  3.2764e-02]],

               [[-3.1492e-02, -7.9136e-02, -4.2417e-02,  3.8367e-02, -7.3788e-02],
                [-4.9922e-02,  2.4759e-02,  1.3668e-03,  7.1963e-02, -6.7763e-02],
                [-4.6869e-02,  6.6676e-02,  3.5155e-02, -1.8898e-02, -3.7238e-02],
                [-1.2625e-02, -5.9552e-02, -6.8551e-02,  5.6212e-02, -3.5397e-02],
                [-3.0192e-02, -5.3948e-02,  4.1616e-03, -1.9126e-02, -2.7772e-02]]],


              [[[ 7.5304e-02,  4.8642e-02,  6.6816e-03, -4.8956e-02,  7.0525e-03],
                [ 5.3506e-02, -6.3276e-02,  3.2309e-02, -6.3837e-02, -7.6400e-03],
                [ 8.0993e-02,  5.0589e-02, -4.4620e-02, -1.7407e-02, -4.0847e-02],
                [-7.9109e-02, -5.0230e-02, -6.1707e-02,  3.5480e-02,  9.4730e-03],
                [-4.0758e-02,  3.8149e-02, -2.6608e-02,  6.5586e-02,  6.7719e-02]],

               [[ 3.7999e-02,  5.1426e-02,  4.6990e-02, -8.4217e-04, -7.8228e-02],
                [ 7.4991e-02, -1.3921e-03, -5.2225e-02, -3.7669e-02, -1.3416e-02],
                [-6.8178e-02, -3.8191e-02, -4.8837e-02,  5.6142e-02,  8.0981e-02],
                [-4.9641e-02,  2.5925e-02,  3.2090e-02,  1.2274e-02, -7.0899e-02],
                [ 2.0209e-02,  7.2872e-02, -2.2630e-02, -6.1033e-02,  4.0857e-02]],

               [[-6.5934e-02,  6.1233e-02,  3.0293e-02, -7.4631e-02,  7.8141e-02],
                [-1.4982e-02,  2.9501e-02, -6.2282e-02,  2.6212e-02, -2.5934e-02],
                [ 3.0596e-02,  2.5406e-02,  4.6264e-02, -2.3452e-02, -1.9204e-02],
                [ 8.5599e-03,  7.3909e-02, -6.4926e-02, -8.6878e-03, -3.7090e-02],
                [ 3.4250e-02, -5.0255e-02, -7.8042e-02, -5.1085e-02, -8.0715e-02]],

               [[ 2.6317e-02, -3.6153e-02, -5.3455e-02,  2.1931e-03,  6.2524e-04],
                [-4.6965e-02,  3.9974e-02, -1.2460e-02, -4.2091e-02, -4.6497e-02],
                [-3.4564e-02, -1.1512e-02,  2.1072e-02, -6.0303e-02,  1.7830e-02],
                [-6.8750e-02,  2.9228e-02,  2.5533e-02,  5.8319e-02, -2.6735e-02],
                [ 3.2792e-02,  3.2718e-02,  5.9794e-02, -7.6940e-02, -4.0923e-02]],

               [[-7.0395e-02, -7.5975e-02, -2.7344e-02, -2.3934e-02, -2.9200e-02],
                [ 5.4427e-02, -6.5287e-02,  3.4746e-02,  1.0117e-02, -6.4013e-02],
                [-7.9322e-02,  7.8159e-02, -7.0473e-02,  4.4684e-02,  2.1939e-02],
                [ 5.5900e-02, -6.5708e-02,  3.3804e-02, -1.8570e-02, -7.0815e-02],
                [ 9.0783e-03, -7.2442e-02,  7.7275e-02,  5.2036e-03,  2.7754e-02]],

               [[-3.5030e-02, -1.2288e-02,  6.1587e-02,  1.0093e-02, -3.9849e-02],
                [ 3.1310e-02,  3.2095e-02, -7.3972e-02,  6.0673e-02,  6.6248e-02],
                [ 7.5880e-02,  7.9613e-02, -1.4684e-02,  6.6668e-02,  1.1897e-02],
                [-6.6942e-02,  6.9498e-02,  7.3219e-02, -5.7326e-02,  7.9767e-02],
                [-2.2455e-02,  2.7784e-02, -5.6006e-02, -2.8644e-02, -4.7902e-02]]],


              ...,


              [[[ 5.4641e-02, -2.8195e-02,  1.2863e-02,  6.9362e-02, -6.5228e-02],
                [-1.4326e-02, -2.4860e-02,  6.7933e-02,  6.0904e-03, -5.2127e-02],
                [-3.5126e-02,  1.9526e-02, -4.2415e-02,  7.7512e-03,  6.7621e-02],
                [-7.7149e-02, -3.9294e-02,  2.8953e-02,  2.8484e-02, -5.9953e-02],
                [ 3.0683e-02, -7.1082e-02, -1.7986e-02,  1.3298e-02, -6.4964e-02]],

               [[ 3.5600e-02, -6.4155e-02,  4.7039e-02, -2.0131e-02, -2.7915e-02],
                [ 2.9918e-02,  6.5944e-03, -6.6870e-02, -6.3787e-02, -4.9677e-02],
                [-3.6079e-02, -2.8304e-02,  2.9721e-02,  2.8190e-02, -5.0218e-02],
                [ 6.4923e-02, -4.9635e-02, -3.6667e-03,  7.9379e-02,  2.5979e-02],
                [-4.8337e-02,  7.7505e-02, -7.3112e-02,  2.4510e-02,  2.5683e-02]],

               [[ 2.8887e-03, -2.1671e-02, -5.2347e-02, -6.3329e-02, -4.0586e-02],
                [-1.2757e-02,  3.3395e-02, -7.8268e-02,  7.3369e-02,  5.0369e-04],
                [ 2.6221e-02, -2.9271e-02, -6.5565e-02, -1.6796e-02, -4.9055e-02],
                [-5.8221e-02, -4.2509e-02,  4.6818e-06,  2.6047e-05,  4.1964e-02],
                [ 1.0361e-02, -1.0747e-02, -5.5872e-02, -4.5506e-02, -2.9223e-02]],

               [[-1.9352e-02, -8.0087e-02, -3.3809e-03,  3.9983e-02,  6.5648e-02],
                [-2.5674e-02,  5.8915e-02,  1.8416e-02,  5.8460e-02,  3.2707e-02],
                [-6.6357e-02,  6.9795e-02,  8.6752e-03,  5.9294e-02,  1.7985e-02],
                [-6.5379e-02,  2.3563e-02,  5.0532e-02,  3.5488e-03, -4.1146e-02],
                [ 8.1383e-02,  5.7224e-02, -7.2400e-02, -4.0180e-02, -6.1370e-02]],

               [[ 5.9150e-02,  5.6013e-02, -4.5474e-02, -4.0693e-02,  4.2932e-02],
                [-1.3553e-02,  4.4707e-02, -2.7249e-02,  2.3061e-02,  1.9638e-02],
                [ 2.1247e-02, -6.3221e-02,  5.1882e-02, -1.6282e-02,  6.9770e-02],
                [-7.0485e-02,  7.4524e-02,  4.4509e-02, -6.5970e-02,  2.9617e-02],
                [-8.1458e-02, -4.9716e-02,  1.2315e-02, -2.0425e-02,  3.0172e-02]],

               [[ 5.2212e-02, -3.5905e-02,  2.5783e-02, -6.0258e-02,  3.5215e-02],
                [-7.5870e-02, -1.5704e-02,  3.3627e-02, -4.0729e-02,  5.7335e-02],
                [-3.5374e-02, -7.5164e-02, -7.3468e-02, -1.1014e-02,  1.6214e-02],
                [-3.1993e-02, -2.4012e-02, -1.6525e-03, -6.9434e-02,  2.8824e-02],
                [-2.4923e-02,  7.8550e-02,  4.5400e-02,  2.7779e-02, -6.5854e-02]]],


              [[[ 2.3195e-02, -6.8559e-02, -1.9293e-02, -3.7088e-02, -7.3186e-02],
                [ 7.8055e-02,  5.0381e-03,  3.0678e-02, -5.8232e-02,  2.8428e-02],
                [-2.2133e-02, -1.6136e-03,  6.5804e-02,  3.9714e-03, -3.9261e-02],
                [ 2.5493e-02, -1.4515e-02,  3.1299e-02, -1.6629e-02, -2.5878e-02],
                [-2.4748e-02, -6.8695e-02,  4.8038e-02,  1.7510e-02, -2.4795e-02]],

               [[ 2.7344e-02,  5.4179e-02,  1.8617e-02,  6.7468e-02,  4.8763e-02],
                [ 3.7600e-02,  3.9927e-02,  5.1062e-02,  2.1710e-02, -3.2169e-02],
                [-3.8513e-02, -4.6700e-02, -3.3343e-02,  5.7257e-02,  7.1398e-02],
                [ 3.1596e-02, -6.1682e-02, -1.1294e-02, -4.6606e-02, -1.9235e-02],
                [-5.1762e-02,  4.1756e-03,  5.5901e-02,  5.0582e-02, -3.5234e-02]],

               [[ 6.8751e-02,  1.9294e-02,  9.9260e-04,  4.8577e-02,  1.1296e-02],
                [-2.5931e-03,  5.6043e-02, -3.9379e-02, -1.5890e-02,  2.7560e-02],
                [-6.1309e-02,  4.4243e-02, -6.8550e-02, -7.5816e-02,  6.4328e-02],
                [-3.5933e-02,  1.5707e-02, -4.1360e-03,  2.3218e-02, -5.3996e-02],
                [-5.8497e-02,  2.2945e-02, -1.6730e-02, -4.9801e-02,  4.5134e-02]],

               [[ 4.3283e-02,  1.4086e-02, -7.7765e-03, -2.4735e-02, -6.1307e-02],
                [-3.7167e-02,  1.0755e-05, -3.5806e-02, -2.9059e-02,  7.9799e-02],
                [-2.3940e-02,  3.7716e-02,  2.2327e-02,  4.3423e-02, -4.9689e-02],
                [-3.8642e-02, -4.3529e-02, -2.7830e-02, -4.9579e-02,  6.6404e-02],
                [ 1.3762e-02, -2.3933e-02,  7.1659e-02, -1.3726e-02, -8.0242e-02]],

               [[-1.5402e-02,  1.8298e-02, -5.1471e-02,  2.4580e-02,  9.6023e-03],
                [ 6.1876e-02,  7.8261e-02,  2.6394e-02, -7.8227e-02, -7.6062e-02],
                [ 6.7169e-02, -7.8952e-03, -7.6834e-02, -7.2395e-02, -8.1512e-02],
                [ 2.4895e-02,  7.4719e-02, -6.9676e-02, -2.0183e-02,  4.7940e-02],
                [-1.6931e-02,  6.4322e-02,  4.9096e-02,  6.7067e-02, -5.1128e-02]],

               [[-4.2134e-03,  7.9587e-02, -7.7337e-02, -1.6919e-02, -4.4513e-02],
                [ 4.5003e-02,  2.9848e-02, -3.2239e-02, -5.3997e-02,  3.4833e-02],
                [-4.6084e-02,  7.7325e-02, -8.2341e-04, -2.9711e-02,  4.7059e-02],
                [ 7.1990e-02, -1.8925e-02,  6.9833e-02, -3.8232e-02, -5.3586e-02],
                [ 5.6777e-02,  5.4212e-02, -7.0351e-02,  7.8116e-02, -5.8073e-03]]],


              [[[-3.9872e-02,  2.2878e-02, -5.4838e-02, -7.8741e-02, -2.4075e-02],
                [ 5.5670e-02, -7.5194e-02, -2.3993e-02, -1.3565e-02, -7.6118e-02],
                [ 5.9835e-02,  7.7078e-02, -1.9101e-02,  3.7423e-02,  5.8969e-02],
                [-7.6931e-02,  5.4068e-02,  7.6462e-02,  6.0935e-02,  6.1393e-02],
                [-3.0153e-02, -6.9821e-02, -7.9367e-02, -6.8787e-02,  4.8573e-02]],

               [[-7.7210e-04,  1.2697e-02, -7.1333e-02, -4.3644e-02, -3.0627e-02],
                [ 6.4280e-02, -6.3660e-02,  5.7517e-02,  3.5869e-02, -2.4693e-02],
                [ 4.2786e-03, -8.1200e-02,  6.9931e-02, -2.5703e-04,  1.7692e-02],
                [ 8.5987e-03, -1.2595e-02,  7.9205e-02,  3.0073e-02, -3.2985e-02],
                [-6.3697e-02,  3.2692e-02, -1.9431e-02,  5.3542e-02, -2.0049e-02]],

               [[ 4.7488e-02,  3.1486e-02,  4.3938e-02,  3.8207e-02, -6.3004e-02],
                [ 7.6382e-02,  1.8666e-02,  1.0028e-02,  6.2085e-02,  5.3552e-02],
                [-3.0010e-02,  2.7386e-02, -2.2148e-02, -5.4034e-02, -2.1415e-02],
                [-3.2287e-02, -4.1362e-02,  1.2052e-02, -6.5838e-02, -4.6819e-02],
                [-6.8102e-02,  5.9098e-02,  2.8529e-02, -5.3848e-02,  2.3559e-02]],

               [[ 2.5513e-02,  3.7517e-02,  5.5636e-02,  4.3730e-02, -3.5048e-02],
                [-5.4454e-02,  7.0706e-02, -5.7952e-02,  2.3890e-02,  3.0251e-02],
                [ 2.0294e-02, -6.2255e-02, -7.7577e-02, -6.8416e-02, -4.8070e-02],
                [ 5.3928e-02, -6.0171e-02,  4.9991e-02,  4.6665e-02, -1.5579e-02],
                [ 1.9901e-02, -6.1094e-02, -1.4091e-02, -6.6292e-02,  1.2545e-02]],

               [[ 7.3009e-02,  7.1030e-02, -3.5882e-02, -5.9879e-02,  2.1529e-02],
                [-2.7738e-02, -4.8476e-02, -3.5715e-03, -5.2242e-03,  4.8341e-03],
                [-4.1100e-02,  2.7022e-02, -5.5728e-02,  5.0925e-02,  2.2531e-02],
                [ 6.5409e-02,  2.5243e-02,  5.5194e-02,  5.0815e-02,  3.7556e-02],
                [ 4.0211e-02,  3.1016e-02, -2.9596e-02, -4.3925e-03, -4.0317e-02]],

               [[ 4.9369e-02,  4.1262e-02,  7.0892e-02, -7.3260e-02,  4.2668e-02],
                [ 3.7235e-02,  3.5402e-02,  3.3255e-02,  5.4474e-02,  5.3561e-02],
                [-1.7112e-02, -1.1525e-02,  1.9306e-02,  9.0656e-04,  2.4812e-02],
                [-1.0841e-02, -3.1940e-02,  6.6983e-02, -1.3595e-02,  7.1947e-02],
                [-1.4649e-02,  3.0190e-02, -4.8740e-02, -8.1647e-02, -4.3005e-02]]]],
             requires_grad=True)
       tensor: tensor([[[[ 1.8883e-02,  6.8035e-02,  4.9171e-02,  9.7844e-02,  8.8935e-02],
                [-1.8658e-02, -6.0121e-02,  1.3134e-01,  1.0360e-01,  1.0962e-01],
                [-4.9652e-02,  9.1593e-03, -1.0749e-01, -2.3633e-02, -6.8945e-02],
                [-7.2921e-02,  5.4190e-03, -3.7902e-02, -9.1251e-02, -9.6080e-02],
                [-9.6740e-02,  3.3653e-02,  1.2174e-01, -2.1966e-02, -6.8750e-02]],

               [[-5.1564e-02,  9.9327e-02, -9.0770e-02,  1.1558e-01, -6.4334e-02],
                [-3.3721e-02,  1.3854e-01,  7.1261e-02,  2.3334e-02,  6.2038e-03],
                [ 4.1357e-02,  4.7015e-02,  1.1618e-01,  1.1209e-01, -1.5182e-02],
                [-4.2263e-02, -5.2076e-02,  1.2271e-01,  6.1315e-02, -3.5989e-02],
                [ 2.4518e-02, -7.3358e-02,  1.1719e-01,  6.8904e-02,  6.7703e-02]],

               [[-9.5581e-02, -2.1765e-02,  2.3421e-03,  1.1011e-01,  2.8519e-02],
                [ 9.2380e-02, -8.1339e-02,  4.1274e-02,  8.2023e-02,  4.4325e-02],
                [ 1.2266e-01, -8.0220e-02,  1.5860e-02,  2.0438e-02,  9.9168e-02],
                [-5.8377e-02,  3.0828e-02, -4.9333e-02,  2.4652e-02, -9.0897e-02],
                [-2.0278e-03, -9.2374e-02,  7.0469e-02,  4.6639e-02,  1.4054e-02]],

               [[ 6.7318e-02, -1.4048e-01, -6.6024e-02, -7.5259e-02, -7.3505e-03],
                [-5.9789e-03,  8.2757e-02, -3.2293e-03, -6.3011e-02, -3.5699e-02],
                [-4.6309e-03,  2.9589e-02, -8.0518e-02, -3.5417e-02, -8.1604e-02],
                [-4.7370e-02, -1.8663e-02,  1.4938e-02, -2.6669e-02,  4.6225e-02],
                [ 1.0167e-01, -4.0576e-03, -6.3230e-02,  1.9653e-02,  2.6441e-03]],

               [[ 7.7611e-02,  9.2593e-02,  2.8293e-02, -8.5795e-02,  2.8690e-03],
                [ 2.5574e-03,  1.1727e-01, -1.4177e-01, -8.9946e-02, -8.8752e-02],
                [ 1.1321e-01, -3.1915e-03, -1.5451e-01, -1.1661e-01,  3.4048e-02],
                [-5.9230e-02,  3.2312e-02,  6.5634e-02,  6.5836e-02,  6.9050e-02],
                [ 2.9482e-02,  5.5532e-02,  1.4458e-02, -2.6399e-02,  1.5079e-01]],

               [[ 8.1164e-02, -1.5860e-02, -6.3951e-02, -1.4953e-01,  2.6079e-02],
                [ 2.9781e-03,  3.3932e-02, -7.3611e-02, -3.7866e-02,  6.0404e-03],
                [-9.4234e-02,  6.2938e-03, -1.1837e-01, -2.1665e-02,  7.8637e-02],
                [-9.3179e-02, -4.4013e-02,  4.6450e-02,  1.6881e-02, -7.4646e-03],
                [-1.1145e-02, -3.3850e-02, -8.3531e-02, -7.9519e-02,  1.1401e-01]]],


              [[[ 3.2203e-02,  2.0910e-02, -9.8143e-02,  4.6689e-02, -3.8075e-02],
                [ 1.1719e-01,  5.9686e-03,  1.1960e-02, -3.0230e-03, -5.2190e-02],
                [ 1.9934e-03, -4.0891e-02, -3.4388e-02,  1.7186e-01, -2.0041e-02],
                [ 6.2537e-02, -4.8928e-02,  3.7143e-02, -7.1022e-02, -5.8980e-02],
                [-1.8138e-02, -1.2401e-03,  9.8512e-04,  1.2971e-02, -1.6621e-02]],

               [[-6.7650e-02, -1.2190e-01, -2.9892e-02, -7.9844e-03,  7.0830e-02],
                [ 1.9891e-02,  7.0736e-02, -5.9314e-02, -8.9657e-02,  5.4045e-02],
                [ 6.0116e-02,  1.1283e-01,  8.1388e-02, -3.8265e-03, -8.1494e-02],
                [-1.2480e-01,  4.0550e-02, -1.2108e-02,  3.0710e-02, -6.6255e-03],
                [-3.1892e-02,  4.0138e-02,  1.0084e-01, -1.1013e-01,  1.3274e-03]],

               [[ 1.6757e-01,  4.7177e-02,  3.3340e-02, -3.8067e-02,  8.9873e-02],
                [-3.3208e-02, -1.9158e-03,  2.2761e-02, -1.7572e-02,  1.7825e-01],
                [-3.9615e-02,  3.0436e-02, -2.7566e-02,  2.7972e-02,  4.4888e-02],
                [-1.1037e-01,  2.3430e-02, -5.5122e-02, -1.4218e-01, -3.7357e-02],
                [ 2.8481e-02,  8.2003e-02,  7.9964e-02, -2.2991e-03,  2.2325e-02]],

               [[-4.3186e-02,  3.3446e-02,  7.1732e-02,  3.7266e-02,  8.8012e-03],
                [-1.2796e-02,  6.7978e-02,  7.0809e-02,  7.3238e-02, -4.9962e-02],
                [-1.9777e-02,  2.5981e-02, -1.3383e-01, -4.4816e-02, -4.7046e-02],
                [-2.1233e-02,  6.0056e-02, -1.1861e-01,  7.6347e-02,  5.1444e-02],
                [ 2.1803e-03,  7.3725e-02, -5.2894e-02, -1.7344e-02,  9.9680e-02]],

               [[-8.6873e-02, -8.9221e-02, -5.2683e-02,  1.4289e-01, -8.9839e-02],
                [-7.3905e-03,  3.6452e-02,  2.7346e-02, -8.3526e-02,  3.3003e-02],
                [-9.6205e-02,  1.0080e-01,  1.0827e-01, -1.0169e-01, -1.2534e-01],
                [ 1.2297e-01,  4.5520e-02, -8.3773e-02, -1.2272e-02,  1.1171e-02],
                [-2.2908e-02, -5.9739e-03, -3.1049e-02,  4.7375e-02,  3.4770e-03]],

               [[-4.6835e-02, -1.4363e-01, -1.4266e-01,  4.5771e-02, -7.2627e-02],
                [-1.0377e-01,  5.4805e-02,  2.4762e-02,  1.2288e-01, -6.7578e-02],
                [-1.0513e-01,  8.9687e-02,  4.6560e-02, -2.5460e-02, -7.3801e-02],
                [ 3.8986e-02, -3.2050e-02,  2.2372e-02,  3.5860e-02,  1.3424e-02],
                [-4.8953e-02, -7.0792e-02, -3.7677e-02, -9.9571e-02, -3.7092e-02]]],


              [[[ 2.7008e-02,  9.3211e-02, -2.5650e-02, -1.2120e-01,  8.9026e-02],
                [ 3.3563e-02,  4.8096e-02,  2.4675e-02, -6.3760e-02, -1.0420e-02],
                [ 4.8996e-02,  1.9733e-02, -6.8690e-03,  5.3957e-02,  2.5845e-02],
                [-1.0207e-01,  5.9896e-02, -5.3257e-02,  6.9946e-02,  7.8127e-03],
                [ 3.7401e-02,  2.6128e-03, -6.7030e-02,  1.2134e-01,  1.4503e-01]],

               [[ 3.0161e-02,  1.6157e-02,  8.4664e-02,  4.0789e-02, -5.2141e-02],
                [ 6.8237e-02, -1.7970e-02,  2.1938e-02, -3.8875e-02, -4.1224e-02],
                [ 2.0714e-02, -2.9881e-02, -7.7453e-03,  1.2165e-01,  7.5875e-02],
                [ 1.2965e-02,  1.2694e-02,  5.0254e-02,  3.9009e-02, -1.0212e-01],
                [ 1.5359e-02,  3.5578e-03,  3.9632e-02, -9.0168e-02,  1.0299e-02]],

               [[-2.0960e-02,  3.3475e-02,  4.8820e-02, -5.4706e-02,  5.5381e-02],
                [ 2.4722e-02, -2.7748e-02, -2.8843e-02, -2.2392e-02, -6.9717e-02],
                [-2.7829e-02,  1.4679e-03,  2.4757e-02, -8.7816e-02,  1.8764e-02],
                [-1.0710e-01,  1.1481e-01, -5.6012e-02,  8.1546e-02, -9.8836e-02],
                [ 2.9070e-02, -6.9273e-02, -5.0700e-02, -4.3182e-02, -8.8499e-02]],

               [[ 1.5313e-02, -3.8199e-02,  3.5146e-03, -1.5000e-02, -4.4576e-02],
                [-2.5990e-02,  5.2993e-03,  3.3307e-02, -1.1827e-01, -8.6286e-03],
                [ 1.0449e-02,  5.4776e-02,  5.0992e-03, -1.5860e-01, -7.7440e-03],
                [-3.6204e-03, -4.7962e-02, -4.2371e-02,  6.8431e-02, -5.3339e-02],
                [ 4.4269e-02,  4.1502e-02, -1.2462e-02, -7.4518e-02, -7.7262e-02]],

               [[ 5.1610e-03, -5.2510e-02,  5.5136e-02, -1.9081e-02, -9.9088e-02],
                [ 8.0757e-02, -3.4951e-02,  4.9554e-02, -3.2164e-02,  5.6647e-02],
                [-8.6125e-02,  6.2141e-02, -7.2851e-02,  6.7731e-02,  1.7170e-02],
                [ 5.2774e-02, -1.6591e-01,  3.4490e-02,  5.7166e-03, -6.3123e-02],
                [ 4.2355e-02, -5.6360e-02,  7.3484e-02,  3.0939e-02,  3.8095e-02]],

               [[ 1.5646e-03,  1.7834e-02,  7.3988e-03, -1.5700e-02, -3.9332e-02],
                [ 1.0043e-02,  5.1222e-02, -9.1322e-02,  5.1683e-02,  1.2517e-02],
                [ 8.1686e-02, -1.2309e-02, -2.5122e-02,  7.3020e-02,  5.9552e-02],
                [-5.6445e-02,  2.9868e-02,  2.9082e-02, -1.0408e-02,  1.5398e-01],
                [ 5.2559e-02, -1.1959e-02, -1.4521e-02,  9.7013e-02, -9.5779e-02]]],


              ...,


              [[[ 1.0793e-01, -2.4074e-02,  3.9709e-02,  1.0110e-01, -1.1722e-01],
                [-1.4333e-02, -1.4521e-02,  2.3373e-03,  6.7708e-02, -4.9924e-02],
                [ 8.0992e-02,  6.2219e-03, -9.5247e-02, -1.0383e-02,  1.2116e-01],
                [-1.3052e-01,  3.1068e-02, -2.0858e-02,  1.0203e-01, -1.2493e-02],
                [-4.8291e-02, -3.7350e-02, -2.5973e-02, -5.2889e-02, -1.4211e-01]],

               [[ 1.4048e-02, -6.0941e-02,  1.4314e-01, -4.0112e-02,  2.0749e-03],
                [-5.3904e-02,  4.6111e-02, -9.0083e-02, -6.0871e-02,  2.1673e-02],
                [-6.1354e-02, -2.4384e-02,  3.1761e-02,  3.5867e-02, -5.3144e-02],
                [ 9.2168e-02,  1.2534e-02, -2.1070e-02,  1.7463e-01,  4.6518e-02],
                [-8.0494e-02,  8.2646e-02, -1.1117e-01,  2.9695e-02, -6.1277e-02]],

               [[-6.9924e-02, -3.8576e-02, -8.7181e-02, -1.2090e-01, -7.3311e-02],
                [ 6.9342e-02,  9.1480e-02, -8.6245e-02,  1.2549e-01,  1.5314e-02],
                [-5.9550e-02, -6.7689e-02, -3.1209e-02, -6.1005e-02, -4.3274e-02],
                [-8.6596e-02,  2.7538e-02, -8.3665e-03,  6.6909e-02,  4.2178e-02],
                [ 4.2158e-02,  3.3340e-03,  9.0755e-03,  1.9898e-03, -2.1537e-02]],

               [[-1.8037e-02, -4.7885e-03,  7.9153e-03,  3.5667e-03,  6.1726e-02],
                [-9.7950e-02,  2.7148e-02,  5.6938e-02,  2.7169e-02,  5.3373e-02],
                [-9.3417e-02,  7.2626e-02, -2.6297e-02,  4.8845e-03,  2.5102e-02],
                [-1.8309e-02, -2.4018e-04,  7.9144e-02,  9.0640e-02, -2.4433e-02],
                [ 1.6471e-01,  9.4758e-02, -1.0681e-01, -1.2206e-01, -4.2883e-02]],

               [[-3.1497e-02, -4.8369e-03, -6.9565e-02,  3.7852e-02,  8.5132e-02],
                [-1.4033e-02,  9.7230e-02, -6.4714e-02,  5.6117e-02,  8.1022e-02],
                [ 5.0067e-03, -8.6493e-02, -2.8838e-02,  2.2062e-02,  1.9418e-02],
                [-5.2541e-05,  1.3477e-01,  1.5361e-02, -3.6824e-02,  3.4045e-02],
                [-1.7272e-01,  2.9210e-03,  1.5939e-02, -1.3572e-02,  4.1629e-02]],

               [[ 4.4768e-02, -3.5056e-02,  3.7457e-02, -9.0991e-02,  4.8750e-02],
                [-1.3768e-02, -5.2716e-02,  4.2374e-02, -6.0435e-02,  6.5938e-02],
                [ 4.1509e-02, -5.4461e-02, -1.1085e-01,  1.2433e-02, -6.0321e-02],
                [-8.3952e-03, -2.5181e-02,  8.1283e-03, -2.1729e-02,  1.3753e-01],
                [ 3.2617e-03,  1.1116e-01, -5.8501e-02,  7.8371e-03, -4.6730e-03]]],


              [[[ 7.5170e-03, -6.9893e-02, -7.2599e-02, -3.9132e-02, -8.8662e-02],
                [ 8.1816e-02,  1.0796e-02,  4.0675e-02, -1.0288e-01,  1.6084e-02],
                [ 5.3820e-02, -3.8369e-03, -3.4138e-02, -6.7229e-02,  3.4613e-02],
                [ 3.6387e-02, -6.1193e-02,  1.1059e-01,  4.5436e-02,  2.8397e-02],
                [-1.6659e-02, -9.3031e-02,  2.7387e-02,  9.7177e-02, -4.2562e-02]],

               [[-2.0895e-02,  6.4056e-02,  9.3401e-02,  7.4680e-02,  4.2360e-02],
                [ 3.6859e-02,  1.3341e-03, -1.7436e-02,  1.3251e-02, -4.4198e-02],
                [-1.3440e-02, -5.0658e-02, -2.3869e-02, -1.0197e-03,  1.3315e-01],
                [ 2.6638e-02, -1.4311e-01, -2.0728e-02, -3.4138e-02, -6.8038e-03],
                [-2.3564e-02, -1.3118e-02,  6.4697e-02,  1.9203e-02, -5.7187e-02]],

               [[ 7.9665e-02, -1.1841e-02,  8.7050e-03,  3.5707e-02, -7.4566e-02],
                [ 1.0561e-01,  2.4171e-02, -5.8087e-02, -3.1631e-02,  1.4183e-01],
                [-2.9622e-02,  6.5245e-02, -8.4141e-02, -9.6315e-03,  7.5537e-02],
                [-1.1621e-02, -2.8678e-02, -1.8603e-02, -3.6197e-02, -6.8541e-02],
                [-5.3699e-02,  8.2185e-05,  1.3339e-02, -5.6374e-02,  9.1662e-02]],

               [[ 5.0422e-02, -3.0555e-02, -1.1994e-02, -2.8864e-02, -4.7944e-02],
                [-9.2082e-02, -2.0114e-02,  1.1080e-03, -8.7728e-02,  8.1596e-03],
                [-5.4058e-03,  4.7326e-02, -1.7694e-02,  1.1631e-02, -1.4565e-01],
                [-6.5642e-02, -5.3809e-03,  4.6345e-02, -1.3860e-02,  5.7886e-02],
                [-1.1082e-02, -3.7631e-02,  9.4582e-02, -9.6432e-02, -3.4886e-02]],

               [[-8.1977e-03,  4.6430e-02, -7.8213e-04,  3.4898e-02,  2.3575e-03],
                [ 1.0584e-01,  9.4443e-02, -1.8790e-02, -1.0291e-01, -1.2920e-01],
                [ 2.9621e-02,  3.0077e-02, -9.9012e-02, -6.2667e-02, -1.1825e-01],
                [-1.2209e-03,  1.8486e-02, -1.2279e-01, -8.0225e-02, -4.3227e-02],
                [-4.5238e-02,  4.2571e-02,  3.5998e-02,  4.2050e-02, -1.2023e-01]],

               [[ 4.7820e-02,  1.6323e-01, -9.7011e-02,  1.2791e-02, -7.0172e-02],
                [ 4.1511e-02,  8.5733e-02, -5.9874e-03, -1.0885e-02,  6.1121e-02],
                [-1.4934e-01,  7.1932e-02,  2.5868e-02, -1.3355e-02,  1.3271e-01],
                [ 2.7480e-03, -2.4276e-02,  4.9873e-02, -1.1099e-02, -1.0676e-01],
                [ 6.6207e-02,  1.6302e-02, -1.1455e-01,  1.0844e-02, -4.8347e-02]]],


              [[[-9.9760e-02,  1.9150e-02, -1.4017e-01, -1.1585e-01, -3.8237e-03],
                [-3.0495e-02, -1.0176e-01, -1.6690e-02, -2.9941e-02, -8.9302e-03],
                [ 8.2867e-02,  2.4127e-02,  9.1499e-02,  7.4052e-02,  1.6828e-03],
                [-1.1307e-01,  6.4954e-05,  1.1348e-01,  1.8891e-02,  2.4209e-03],
                [-3.4641e-03, -1.0385e-01, -1.3858e-01, -8.9685e-02,  7.7434e-02]],

               [[-4.5025e-02,  1.1394e-01, -6.5345e-02, -3.0052e-02, -4.3911e-02],
                [ 6.8716e-02, -9.5503e-02,  5.0614e-02,  7.8437e-02,  1.0309e-01],
                [-2.7321e-02, -8.1157e-02,  6.0459e-02, -8.1802e-02, -6.3444e-02],
                [ 9.0615e-02, -1.3963e-01,  9.8714e-02,  1.2384e-01, -1.0084e-01],
                [-1.5519e-02, -6.3194e-03, -3.4565e-02,  8.4971e-02,  3.9733e-02]],

               [[-2.8205e-02,  7.0934e-02,  6.6078e-02,  7.2542e-02, -9.7110e-02],
                [ 7.3296e-02,  3.7896e-03, -7.2323e-02,  1.6002e-02,  2.9101e-02],
                [-7.3723e-02,  1.5889e-01, -1.8016e-02,  6.1716e-02, -1.9445e-02],
                [-8.9607e-02, -7.3749e-02,  9.8301e-03, -3.4523e-02, -2.1987e-02],
                [-9.3157e-02,  7.1487e-02, -4.6032e-02,  5.4644e-02,  5.0541e-02]],

               [[ 6.0963e-02,  8.3188e-02,  4.2527e-02,  5.3783e-02, -3.3220e-02],
                [-7.2273e-02,  8.5987e-02,  1.5571e-02,  3.0730e-02,  2.9485e-02],
                [-8.2758e-03, -1.0446e-01, -5.6690e-02, -9.5882e-02, -1.0161e-01],
                [ 8.2187e-02, -8.8814e-02,  1.6184e-02,  4.1667e-02,  2.8868e-02],
                [ 2.5378e-02, -4.8991e-02,  2.9006e-02, -9.1864e-03, -9.6781e-02]],

               [[ 2.9548e-02, -1.1967e-02, -2.0336e-02,  4.9937e-02, -5.2592e-02],
                [-7.7833e-02, -6.7267e-03,  6.7678e-02,  1.2339e-02,  2.1119e-02],
                [-3.2693e-02, -2.8428e-02, -7.3010e-02,  7.4197e-02,  6.0370e-02],
                [ 1.1135e-01,  6.2439e-03,  6.6132e-02,  8.8876e-02,  6.7757e-02],
                [ 8.4736e-03,  6.1874e-02, -5.5572e-03, -3.2001e-03,  3.9575e-02]],

               [[ 9.4309e-02,  6.2423e-02,  1.6870e-01, -1.4243e-01,  2.1112e-02],
                [-2.2676e-02, -5.8523e-02,  4.6869e-03,  1.1016e-01,  2.4293e-02],
                [-1.0934e-01, -3.8908e-02, -3.0408e-02, -2.9326e-03,  1.8012e-02],
                [-5.2711e-02, -9.1650e-02,  1.5264e-01,  3.1413e-02,  1.0127e-01],
                [-5.4185e-02, -1.2086e-01, -5.9944e-02, -3.9325e-02, -5.0325e-02]]]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([ 0.0382, -0.0121, -0.0262,  0.0070,  0.0597,  0.0315, -0.0379, -0.0624,
              -0.0017, -0.0016,  0.0567,  0.0240,  0.0155,  0.0450, -0.0316, -0.0051],
             requires_grad=True)
       tensor: tensor([ 0.0458, -0.0146,  0.0629,  0.0288,  0.0599,  0.0193, -0.0135, -0.0289,
               0.0030, -0.0048, -0.0006, -0.0199,  0.0055,  0.0746, -0.0526,  0.0112],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[[[0., 0., 0., 0., 0.],
                [-0., -0., 0., 0., 0.],
                [0., 0., -0., -0., -0.],
                [-0., 0., 0., -0., -0.],
                [-0., 0., 0., -0., -0.]],

               [[-0., 0., -0., -0., -0.],
                [-0., 0., 0., -0., 0.],
                [0., -0., -0., 0., -0.],
                [0., -0., 0., 0., -0.],
                [0., -0., 0., 0., 0.]],

               [[-0., -0., 0., 0., 0.],
                [0., -0., 0., 0., -0.],
                [0., -0., -0., 0., 0.],
                [-0., 0., -0., 0., -0.],
                [-0., -0., 0., -0., -0.]],

               [[0., -0., 0., -0., 0.],
                [-0., 0., 0., 0., -0.],
                [-0., 0., -0., 0., 0.],
                [-0., -0., 0., -0., 0.],
                [0., -0., -0., 0., -0.]],

               [[0., 0., 0., -0., -0.],
                [-0., 0., -0., -0., -0.],
                [0., -0., -0., -0., -0.],
                [-0., -0., 0., 0., -0.],
                [0., 0., 0., -0., 0.]],

               [[0., -0., -0., -0., 0.],
                [0., 0., -0., -0., -0.],
                [-0., 0., -0., 0., 0.],
                [-0., -0., 0., -0., -0.],
                [-0., 0., -0., -0., 0.]]],


              [[[0., 0., -0., -0., -0.],
                [-0., -0., -0., 0., -0.],
                [0., -0., -0., 0., -0.],
                [-0., 0., 0., -0., 0.],
                [0., -0., -0., 0., -0.]],

               [[-0., -0., -0., 0., -0.],
                [-0., 0., -0., -0., -0.],
                [0., 0., 0., -0., -0.],
                [-0., 0., -0., -0., -0.],
                [-0., 0., 0., -0., -0.]],

               [[0., 0., -0., -0., 0.],
                [0., 0., -0., 0., 0.],
                [-0., -0., -0., -0., 0.],
                [-0., 0., -0., -0., 0.],
                [0., 0., 0., -0., -0.]],

               [[-0., 0., 0., -0., -0.],
                [-0., 0., 0., 0., -0.],
                [0., 0., -0., 0., -0.],
                [-0., -0., -0., 0., 0.],
                [0., 0., -0., -0., 0.]],

               [[-0., -0., -0., 0., -0.],
                [-0., 0., -0., -0., 0.],
                [-0., 0., 0., -0., -0.],
                [0., 0., 0., -0., 0.],
                [-0., -0., 0., 0., 0.]],

               [[-0., -0., -0., 0., -0.],
                [-0., 0., 0., 0., -0.],
                [-0., 0., 0., -0., -0.],
                [-0., -0., -0., 0., -0.],
                [-0., -0., 0., -0., -0.]]],


              [[[0., 0., 0., -0., 0.],
                [0., -0., 0., -0., -0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., -0., 0., 0.],
                [-0., 0., -0., 0., 0.]],

               [[0., 0., 0., -0., -0.],
                [0., -0., -0., -0., -0.],
                [-0., -0., -0., 0., 0.],
                [-0., 0., 0., 0., -0.],
                [0., 0., -0., -0., 0.]],

               [[-0., 0., 0., -0., 0.],
                [-0., 0., -0., 0., -0.],
                [0., 0., 0., -0., -0.],
                [0., 0., -0., -0., -0.],
                [0., -0., -0., -0., -0.]],

               [[0., -0., -0., 0., 0.],
                [-0., 0., -0., -0., -0.],
                [-0., -0., 0., -0., 0.],
                [-0., 0., 0., 0., -0.],
                [0., 0., 0., -0., -0.]],

               [[-0., -0., -0., -0., -0.],
                [0., -0., 0., 0., -0.],
                [-0., 0., -0., 0., 0.],
                [0., -0., 0., -0., -0.],
                [0., -0., 0., 0., 0.]],

               [[-0., -0., 0., 0., -0.],
                [0., 0., -0., 0., 0.],
                [0., 0., -0., 0., 0.],
                [-0., 0., 0., -0., 0.],
                [-0., 0., -0., -0., -0.]]],


              ...,


              [[[0., -0., 0., 0., -0.],
                [-0., -0., 0., 0., -0.],
                [-0., 0., -0., 0., 0.],
                [-0., -0., 0., 0., -0.],
                [0., -0., -0., 0., -0.]],

               [[0., -0., 0., -0., -0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., 0., 0., -0.],
                [0., -0., -0., 0., 0.],
                [-0., 0., -0., 0., 0.]],

               [[0., -0., -0., -0., -0.],
                [-0., 0., -0., 0., 0.],
                [0., -0., -0., -0., -0.],
                [-0., -0., 0., 0., 0.],
                [0., -0., -0., -0., -0.]],

               [[-0., -0., -0., 0., 0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., 0., 0., -0.],
                [0., 0., -0., -0., -0.]],

               [[0., 0., -0., -0., 0.],
                [-0., 0., -0., 0., 0.],
                [0., -0., 0., -0., 0.],
                [-0., 0., 0., -0., 0.],
                [-0., -0., 0., -0., 0.]],

               [[0., -0., 0., -0., 0.],
                [-0., -0., 0., -0., 0.],
                [-0., -0., -0., -0., 0.],
                [-0., -0., -0., -0., 0.],
                [-0., 0., 0., 0., -0.]]],


              [[[0., -0., -0., -0., -0.],
                [0., 0., 0., -0., 0.],
                [-0., -0., 0., 0., -0.],
                [0., -0., 0., -0., -0.],
                [-0., -0., 0., 0., -0.]],

               [[0., 0., 0., 0., 0.],
                [0., 0., 0., 0., -0.],
                [-0., -0., -0., 0., 0.],
                [0., -0., -0., -0., -0.],
                [-0., 0., 0., 0., -0.]],

               [[0., 0., 0., 0., 0.],
                [-0., 0., -0., -0., 0.],
                [-0., 0., -0., -0., 0.],
                [-0., 0., -0., 0., -0.],
                [-0., 0., -0., -0., 0.]],

               [[0., 0., -0., -0., -0.],
                [-0., 0., -0., -0., 0.],
                [-0., 0., 0., 0., -0.],
                [-0., -0., -0., -0., 0.],
                [0., -0., 0., -0., -0.]],

               [[-0., 0., -0., 0., 0.],
                [0., 0., 0., -0., -0.],
                [0., -0., -0., -0., -0.],
                [0., 0., -0., -0., 0.],
                [-0., 0., 0., 0., -0.]],

               [[-0., 0., -0., -0., -0.],
                [0., 0., -0., -0., 0.],
                [-0., 0., -0., -0., 0.],
                [0., -0., 0., -0., -0.],
                [0., 0., -0., 0., -0.]]],


              [[[-0., 0., -0., -0., -0.],
                [0., -0., -0., -0., -0.],
                [0., 0., -0., 0., 0.],
                [-0., 0., 0., 0., 0.],
                [-0., -0., -0., -0., 0.]],

               [[-0., 0., -0., -0., -0.],
                [0., -0., 0., 0., -0.],
                [0., -0., 0., -0., 0.],
                [0., -0., 0., 0., -0.],
                [-0., 0., -0., 0., -0.]],

               [[0., 0., 0., 0., -0.],
                [0., 0., 0., 0., 0.],
                [-0., 0., -0., -0., -0.],
                [-0., -0., 0., -0., -0.],
                [-0., 0., 0., -0., 0.]],

               [[0., 0., 0., 0., -0.],
                [-0., 0., -0., 0., 0.],
                [0., -0., -0., -0., -0.],
                [0., -0., 0., 0., -0.],
                [0., -0., -0., -0., 0.]],

               [[0., 0., -0., -0., 0.],
                [-0., -0., -0., -0., 0.],
                [-0., 0., -0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [0., 0., -0., -0., -0.]],

               [[0., 0., 0., -0., 0.],
                [0., 0., 0., 0., 0.],
                [-0., -0., 0., 0., 0.],
                [-0., -0., 0., -0., 0.],
                [-0., 0., -0., -0., -0.]]]])
       scale: tensor([[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              ...,


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[[[ 1.7231e-02,  4.3731e-02,  5.4681e-02,  3.4570e-02,  4.3578e-02],
                [-2.7786e-02, -4.7320e-02,  3.1995e-02,  4.2537e-02,  7.5819e-02],
                [ 8.9881e-03,  2.6359e-02, -8.3574e-03, -5.6418e-02, -3.9465e-02],
                [-1.9657e-02,  5.7231e-03,  5.5066e-03, -1.4926e-02, -7.4856e-02],
                [-6.2679e-02,  3.7226e-02,  8.0120e-02, -2.3663e-02, -5.3511e-02]],

               [[-7.5590e-02,  6.4561e-02, -4.3921e-02, -7.7799e-03, -5.2721e-02],
                [-1.0355e-02,  5.9574e-02,  1.1166e-02, -4.8245e-02,  4.1576e-02],
                [ 5.6966e-02, -7.2947e-03, -1.9590e-03,  6.9857e-02, -2.1827e-02],
                [ 3.2483e-02, -6.1154e-02,  7.6961e-02,  4.5617e-02, -5.6834e-02],
                [ 3.4771e-02, -3.0784e-02,  4.8407e-02,  2.7358e-02,  6.5453e-02]],

               [[-7.3062e-02, -2.7048e-02,  1.0655e-02,  7.5364e-02,  6.3020e-02],
                [ 1.2004e-02, -3.4794e-02,  4.0299e-02,  6.4137e-02, -2.4964e-02],
                [ 7.4683e-02, -2.4405e-02, -6.7041e-02,  2.7558e-02,  2.7585e-02],
                [-2.3105e-02,  3.5924e-02, -1.3142e-02,  3.6770e-02, -5.1734e-02],
                [-3.0248e-03, -7.3805e-02,  1.5975e-02, -3.8458e-02, -3.4536e-02]],

               [[ 1.2704e-02, -6.5938e-02,  5.8697e-04, -4.0066e-03,  6.0744e-02],
                [-6.6682e-02,  7.7161e-02,  8.1204e-02,  1.0834e-02, -9.7154e-03],
                [-4.5248e-02,  4.1191e-02, -3.1746e-02,  1.4713e-02,  1.5446e-02],
                [-6.1854e-02, -5.4387e-02,  3.8440e-02, -3.7393e-02,  7.6462e-02],
                [ 6.8184e-02, -3.0699e-02, -6.0217e-02,  3.2648e-02, -2.2715e-02]],

               [[ 7.6031e-02,  6.2038e-02,  2.1564e-02, -6.5555e-02, -6.4686e-02],
                [-4.2324e-02,  3.6401e-02, -6.7974e-02, -3.9416e-02, -4.4166e-02],
                [ 6.7690e-02, -6.9973e-02, -4.7069e-02, -7.6832e-03, -4.0938e-03],
                [-2.6045e-02, -4.6521e-02,  3.4383e-03,  6.4282e-02, -2.5852e-02],
                [ 3.2686e-02,  3.1602e-02,  3.5400e-02, -7.0873e-02,  7.1903e-02]],

               [[ 6.0529e-02, -5.2981e-02, -3.9675e-02, -6.5714e-02,  2.8020e-02],
                [ 3.5057e-02,  1.3588e-02, -3.5319e-02, -7.1891e-02, -6.0487e-02],
                [-6.3268e-02,  1.4282e-02, -3.6912e-02,  3.1696e-02,  5.1382e-02],
                [-7.2075e-02, -5.2067e-02,  5.7057e-02, -6.0579e-02, -5.5177e-02],
                [-2.2326e-03,  6.6584e-02, -6.6819e-02, -5.9396e-02,  4.9889e-02]]],


              [[[ 6.1514e-02,  3.9812e-02, -7.7776e-02, -4.9428e-02, -2.5424e-02],
                [-1.3566e-02, -3.6820e-02, -1.2480e-02,  4.0800e-02, -7.4332e-02],
                [ 2.8230e-02, -2.4562e-02, -2.9086e-02,  6.3475e-02, -1.5058e-02],
                [-1.5993e-02,  2.6303e-02,  8.7087e-03, -5.3252e-02,  2.7997e-02],
                [ 2.8034e-02, -3.0461e-02, -1.0604e-02,  4.4506e-02, -1.2910e-02]],

               [[-7.8455e-02, -7.8485e-02, -5.5244e-02,  4.7309e-03, -3.7835e-02],
                [-5.3291e-02,  3.5150e-02, -6.4878e-02, -7.4903e-02, -5.5693e-02],
                [ 3.2479e-02,  6.5471e-02,  4.4178e-02, -5.4057e-02, -7.9336e-02],
                [-6.8247e-02,  6.5587e-02, -8.7773e-04, -3.5329e-02, -3.8033e-02],
                [-3.8157e-02,  4.2570e-02,  7.6967e-02, -4.9785e-02, -2.8498e-02]],

               [[ 4.3747e-02,  5.4973e-02, -2.6460e-02, -1.5235e-02,  7.8260e-02],
                [ 1.8068e-02,  1.9185e-02, -6.4915e-02,  7.4767e-03,  6.4246e-02],
                [-6.3240e-02, -7.0417e-03, -6.4609e-02, -3.5764e-03,  1.0317e-02],
                [-4.5486e-02,  7.3897e-02, -2.8131e-02, -8.0268e-02,  1.2427e-02],
                [ 2.6636e-02,  5.3320e-02,  8.0113e-02, -2.4988e-02, -3.3729e-02]],

               [[-3.0252e-02,  2.0032e-02,  2.2845e-02, -3.2871e-02, -7.2805e-02],
                [-1.2501e-02,  1.2056e-02,  1.7524e-02,  3.1636e-02, -3.6845e-02],
                [ 2.2584e-02,  1.5433e-02, -8.1113e-02,  4.8879e-03, -4.5628e-02],
                [-6.5354e-03, -1.2652e-02, -6.9698e-02,  6.7797e-02,  6.2843e-02],
                [ 1.2404e-02,  1.8414e-02, -4.5958e-02, -6.2785e-03,  4.0931e-02]],

               [[-1.8430e-02, -2.1989e-02, -8.0820e-03,  7.2810e-02, -9.3746e-03],
                [-7.0612e-02,  5.1172e-02, -7.0489e-03, -4.5064e-02,  3.5534e-02],
                [-3.2953e-02,  3.0727e-02,  4.8882e-02, -3.9965e-02, -2.3211e-02],
                [ 6.9945e-02,  2.6839e-02,  3.3160e-03, -2.7622e-02,  7.6837e-02],
                [-1.7656e-02, -4.1508e-03,  3.8821e-02,  3.6178e-02,  3.2764e-02]],

               [[-3.1492e-02, -7.9136e-02, -4.2417e-02,  3.8367e-02, -7.3788e-02],
                [-4.9922e-02,  2.4759e-02,  1.3668e-03,  7.1963e-02, -6.7763e-02],
                [-4.6869e-02,  6.6676e-02,  3.5155e-02, -1.8898e-02, -3.7238e-02],
                [-1.2625e-02, -5.9552e-02, -6.8551e-02,  5.6212e-02, -3.5397e-02],
                [-3.0192e-02, -5.3948e-02,  4.1616e-03, -1.9126e-02, -2.7772e-02]]],


              [[[ 7.5304e-02,  4.8642e-02,  6.6816e-03, -4.8956e-02,  7.0525e-03],
                [ 5.3506e-02, -6.3276e-02,  3.2309e-02, -6.3837e-02, -7.6400e-03],
                [ 8.0993e-02,  5.0589e-02, -4.4620e-02, -1.7407e-02, -4.0847e-02],
                [-7.9109e-02, -5.0230e-02, -6.1707e-02,  3.5480e-02,  9.4730e-03],
                [-4.0758e-02,  3.8149e-02, -2.6608e-02,  6.5586e-02,  6.7719e-02]],

               [[ 3.7999e-02,  5.1426e-02,  4.6990e-02, -8.4217e-04, -7.8228e-02],
                [ 7.4991e-02, -1.3921e-03, -5.2225e-02, -3.7669e-02, -1.3416e-02],
                [-6.8178e-02, -3.8191e-02, -4.8837e-02,  5.6142e-02,  8.0981e-02],
                [-4.9641e-02,  2.5925e-02,  3.2090e-02,  1.2274e-02, -7.0899e-02],
                [ 2.0209e-02,  7.2872e-02, -2.2630e-02, -6.1033e-02,  4.0857e-02]],

               [[-6.5934e-02,  6.1233e-02,  3.0293e-02, -7.4631e-02,  7.8141e-02],
                [-1.4982e-02,  2.9501e-02, -6.2282e-02,  2.6212e-02, -2.5934e-02],
                [ 3.0596e-02,  2.5406e-02,  4.6264e-02, -2.3452e-02, -1.9204e-02],
                [ 8.5599e-03,  7.3909e-02, -6.4926e-02, -8.6878e-03, -3.7090e-02],
                [ 3.4250e-02, -5.0255e-02, -7.8042e-02, -5.1085e-02, -8.0715e-02]],

               [[ 2.6317e-02, -3.6153e-02, -5.3455e-02,  2.1931e-03,  6.2524e-04],
                [-4.6965e-02,  3.9974e-02, -1.2460e-02, -4.2091e-02, -4.6497e-02],
                [-3.4564e-02, -1.1512e-02,  2.1072e-02, -6.0303e-02,  1.7830e-02],
                [-6.8750e-02,  2.9228e-02,  2.5533e-02,  5.8319e-02, -2.6735e-02],
                [ 3.2792e-02,  3.2718e-02,  5.9794e-02, -7.6940e-02, -4.0923e-02]],

               [[-7.0395e-02, -7.5975e-02, -2.7344e-02, -2.3934e-02, -2.9200e-02],
                [ 5.4427e-02, -6.5287e-02,  3.4746e-02,  1.0117e-02, -6.4013e-02],
                [-7.9322e-02,  7.8159e-02, -7.0473e-02,  4.4684e-02,  2.1939e-02],
                [ 5.5900e-02, -6.5708e-02,  3.3804e-02, -1.8570e-02, -7.0815e-02],
                [ 9.0783e-03, -7.2442e-02,  7.7275e-02,  5.2036e-03,  2.7754e-02]],

               [[-3.5030e-02, -1.2288e-02,  6.1587e-02,  1.0093e-02, -3.9849e-02],
                [ 3.1310e-02,  3.2095e-02, -7.3972e-02,  6.0673e-02,  6.6248e-02],
                [ 7.5880e-02,  7.9613e-02, -1.4684e-02,  6.6668e-02,  1.1897e-02],
                [-6.6942e-02,  6.9498e-02,  7.3219e-02, -5.7326e-02,  7.9767e-02],
                [-2.2455e-02,  2.7784e-02, -5.6006e-02, -2.8644e-02, -4.7902e-02]]],


              ...,


              [[[ 5.4641e-02, -2.8195e-02,  1.2863e-02,  6.9362e-02, -6.5228e-02],
                [-1.4326e-02, -2.4860e-02,  6.7933e-02,  6.0904e-03, -5.2127e-02],
                [-3.5126e-02,  1.9526e-02, -4.2415e-02,  7.7512e-03,  6.7621e-02],
                [-7.7149e-02, -3.9294e-02,  2.8953e-02,  2.8484e-02, -5.9953e-02],
                [ 3.0683e-02, -7.1082e-02, -1.7986e-02,  1.3298e-02, -6.4964e-02]],

               [[ 3.5600e-02, -6.4155e-02,  4.7039e-02, -2.0131e-02, -2.7915e-02],
                [ 2.9918e-02,  6.5944e-03, -6.6870e-02, -6.3787e-02, -4.9677e-02],
                [-3.6079e-02, -2.8304e-02,  2.9721e-02,  2.8190e-02, -5.0218e-02],
                [ 6.4923e-02, -4.9635e-02, -3.6667e-03,  7.9379e-02,  2.5979e-02],
                [-4.8337e-02,  7.7505e-02, -7.3112e-02,  2.4510e-02,  2.5683e-02]],

               [[ 2.8887e-03, -2.1671e-02, -5.2347e-02, -6.3329e-02, -4.0586e-02],
                [-1.2757e-02,  3.3395e-02, -7.8268e-02,  7.3369e-02,  5.0369e-04],
                [ 2.6221e-02, -2.9271e-02, -6.5565e-02, -1.6796e-02, -4.9055e-02],
                [-5.8221e-02, -4.2509e-02,  4.6818e-06,  2.6047e-05,  4.1964e-02],
                [ 1.0361e-02, -1.0747e-02, -5.5872e-02, -4.5506e-02, -2.9223e-02]],

               [[-1.9352e-02, -8.0087e-02, -3.3809e-03,  3.9983e-02,  6.5648e-02],
                [-2.5674e-02,  5.8915e-02,  1.8416e-02,  5.8460e-02,  3.2707e-02],
                [-6.6357e-02,  6.9795e-02,  8.6752e-03,  5.9294e-02,  1.7985e-02],
                [-6.5379e-02,  2.3563e-02,  5.0532e-02,  3.5488e-03, -4.1146e-02],
                [ 8.1383e-02,  5.7224e-02, -7.2400e-02, -4.0180e-02, -6.1370e-02]],

               [[ 5.9150e-02,  5.6013e-02, -4.5474e-02, -4.0693e-02,  4.2932e-02],
                [-1.3553e-02,  4.4707e-02, -2.7249e-02,  2.3061e-02,  1.9638e-02],
                [ 2.1247e-02, -6.3221e-02,  5.1882e-02, -1.6282e-02,  6.9770e-02],
                [-7.0485e-02,  7.4524e-02,  4.4509e-02, -6.5970e-02,  2.9617e-02],
                [-8.1458e-02, -4.9716e-02,  1.2315e-02, -2.0425e-02,  3.0172e-02]],

               [[ 5.2212e-02, -3.5905e-02,  2.5783e-02, -6.0258e-02,  3.5215e-02],
                [-7.5870e-02, -1.5704e-02,  3.3627e-02, -4.0729e-02,  5.7335e-02],
                [-3.5374e-02, -7.5164e-02, -7.3468e-02, -1.1014e-02,  1.6214e-02],
                [-3.1993e-02, -2.4012e-02, -1.6525e-03, -6.9434e-02,  2.8824e-02],
                [-2.4923e-02,  7.8550e-02,  4.5400e-02,  2.7779e-02, -6.5854e-02]]],


              [[[ 2.3195e-02, -6.8559e-02, -1.9293e-02, -3.7088e-02, -7.3186e-02],
                [ 7.8055e-02,  5.0381e-03,  3.0678e-02, -5.8232e-02,  2.8428e-02],
                [-2.2133e-02, -1.6136e-03,  6.5804e-02,  3.9714e-03, -3.9261e-02],
                [ 2.5493e-02, -1.4515e-02,  3.1299e-02, -1.6629e-02, -2.5878e-02],
                [-2.4748e-02, -6.8695e-02,  4.8038e-02,  1.7510e-02, -2.4795e-02]],

               [[ 2.7344e-02,  5.4179e-02,  1.8617e-02,  6.7468e-02,  4.8763e-02],
                [ 3.7600e-02,  3.9927e-02,  5.1062e-02,  2.1710e-02, -3.2169e-02],
                [-3.8513e-02, -4.6700e-02, -3.3343e-02,  5.7257e-02,  7.1398e-02],
                [ 3.1596e-02, -6.1682e-02, -1.1294e-02, -4.6606e-02, -1.9235e-02],
                [-5.1762e-02,  4.1756e-03,  5.5901e-02,  5.0582e-02, -3.5234e-02]],

               [[ 6.8751e-02,  1.9294e-02,  9.9260e-04,  4.8577e-02,  1.1296e-02],
                [-2.5931e-03,  5.6043e-02, -3.9379e-02, -1.5890e-02,  2.7560e-02],
                [-6.1309e-02,  4.4243e-02, -6.8550e-02, -7.5816e-02,  6.4328e-02],
                [-3.5933e-02,  1.5707e-02, -4.1360e-03,  2.3218e-02, -5.3996e-02],
                [-5.8497e-02,  2.2945e-02, -1.6730e-02, -4.9801e-02,  4.5134e-02]],

               [[ 4.3283e-02,  1.4086e-02, -7.7765e-03, -2.4735e-02, -6.1307e-02],
                [-3.7167e-02,  1.0755e-05, -3.5806e-02, -2.9059e-02,  7.9799e-02],
                [-2.3940e-02,  3.7716e-02,  2.2327e-02,  4.3423e-02, -4.9689e-02],
                [-3.8642e-02, -4.3529e-02, -2.7830e-02, -4.9579e-02,  6.6404e-02],
                [ 1.3762e-02, -2.3933e-02,  7.1659e-02, -1.3726e-02, -8.0242e-02]],

               [[-1.5402e-02,  1.8298e-02, -5.1471e-02,  2.4580e-02,  9.6023e-03],
                [ 6.1876e-02,  7.8261e-02,  2.6394e-02, -7.8227e-02, -7.6062e-02],
                [ 6.7169e-02, -7.8952e-03, -7.6834e-02, -7.2395e-02, -8.1512e-02],
                [ 2.4895e-02,  7.4719e-02, -6.9676e-02, -2.0183e-02,  4.7940e-02],
                [-1.6931e-02,  6.4322e-02,  4.9096e-02,  6.7067e-02, -5.1128e-02]],

               [[-4.2134e-03,  7.9587e-02, -7.7337e-02, -1.6919e-02, -4.4513e-02],
                [ 4.5003e-02,  2.9848e-02, -3.2239e-02, -5.3997e-02,  3.4833e-02],
                [-4.6084e-02,  7.7325e-02, -8.2341e-04, -2.9711e-02,  4.7059e-02],
                [ 7.1990e-02, -1.8925e-02,  6.9833e-02, -3.8232e-02, -5.3586e-02],
                [ 5.6777e-02,  5.4212e-02, -7.0351e-02,  7.8116e-02, -5.8073e-03]]],


              [[[-3.9872e-02,  2.2878e-02, -5.4838e-02, -7.8741e-02, -2.4075e-02],
                [ 5.5670e-02, -7.5194e-02, -2.3993e-02, -1.3565e-02, -7.6118e-02],
                [ 5.9835e-02,  7.7078e-02, -1.9101e-02,  3.7423e-02,  5.8969e-02],
                [-7.6931e-02,  5.4068e-02,  7.6462e-02,  6.0935e-02,  6.1393e-02],
                [-3.0153e-02, -6.9821e-02, -7.9367e-02, -6.8787e-02,  4.8573e-02]],

               [[-7.7210e-04,  1.2697e-02, -7.1333e-02, -4.3644e-02, -3.0627e-02],
                [ 6.4280e-02, -6.3660e-02,  5.7517e-02,  3.5869e-02, -2.4693e-02],
                [ 4.2786e-03, -8.1200e-02,  6.9931e-02, -2.5703e-04,  1.7692e-02],
                [ 8.5987e-03, -1.2595e-02,  7.9205e-02,  3.0073e-02, -3.2985e-02],
                [-6.3697e-02,  3.2692e-02, -1.9431e-02,  5.3542e-02, -2.0049e-02]],

               [[ 4.7488e-02,  3.1486e-02,  4.3938e-02,  3.8207e-02, -6.3004e-02],
                [ 7.6382e-02,  1.8666e-02,  1.0028e-02,  6.2085e-02,  5.3552e-02],
                [-3.0010e-02,  2.7386e-02, -2.2148e-02, -5.4034e-02, -2.1415e-02],
                [-3.2287e-02, -4.1362e-02,  1.2052e-02, -6.5838e-02, -4.6819e-02],
                [-6.8102e-02,  5.9098e-02,  2.8529e-02, -5.3848e-02,  2.3559e-02]],

               [[ 2.5513e-02,  3.7517e-02,  5.5636e-02,  4.3730e-02, -3.5048e-02],
                [-5.4454e-02,  7.0706e-02, -5.7952e-02,  2.3890e-02,  3.0251e-02],
                [ 2.0294e-02, -6.2255e-02, -7.7577e-02, -6.8416e-02, -4.8070e-02],
                [ 5.3928e-02, -6.0171e-02,  4.9991e-02,  4.6665e-02, -1.5579e-02],
                [ 1.9901e-02, -6.1094e-02, -1.4091e-02, -6.6292e-02,  1.2545e-02]],

               [[ 7.3009e-02,  7.1030e-02, -3.5882e-02, -5.9879e-02,  2.1529e-02],
                [-2.7738e-02, -4.8476e-02, -3.5715e-03, -5.2242e-03,  4.8341e-03],
                [-4.1100e-02,  2.7022e-02, -5.5728e-02,  5.0925e-02,  2.2531e-02],
                [ 6.5409e-02,  2.5243e-02,  5.5194e-02,  5.0815e-02,  3.7556e-02],
                [ 4.0211e-02,  3.1016e-02, -2.9596e-02, -4.3925e-03, -4.0317e-02]],

               [[ 4.9369e-02,  4.1262e-02,  7.0892e-02, -7.3260e-02,  4.2668e-02],
                [ 3.7235e-02,  3.5402e-02,  3.3255e-02,  5.4474e-02,  5.3561e-02],
                [-1.7112e-02, -1.1525e-02,  1.9306e-02,  9.0656e-04,  2.4812e-02],
                [-1.0841e-02, -3.1940e-02,  6.6983e-02, -1.3595e-02,  7.1947e-02],
                [-1.4649e-02,  3.0190e-02, -4.8740e-02, -8.1647e-02, -4.3005e-02]]]])
      (bias): Normal:
       loc: tensor([0., -0., -0., 0., 0., 0., -0., -0., -0., -0., 0., 0., 0., 0., -0., -0.])
       scale: tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
              0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([ 0.0382, -0.0121, -0.0262,  0.0070,  0.0597,  0.0315, -0.0379, -0.0624,
              -0.0017, -0.0016,  0.0567,  0.0240,  0.0155,  0.0450, -0.0316, -0.0051])
    )
    (observed): Observed()
  )
  (fc1): Linear(
    in_features=400, out_features=120, bias=True
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              ...,
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498]],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[-3.3777e-02,  4.7079e-03, -2.3445e-02,  ..., -6.8087e-03,
               -4.6413e-02, -4.1360e-02],
              [-1.6733e-02,  5.1208e-03,  3.9803e-02,  ..., -4.5116e-02,
                1.8346e-03, -1.1031e-02],
              [ 2.3320e-02,  4.0388e-03, -4.6767e-02,  ..., -3.7066e-02,
                3.7666e-02, -1.2776e-02],
              ...,
              [-1.3560e-05,  3.7272e-02, -1.6224e-02,  ..., -3.3796e-03,
                3.3060e-02,  4.5754e-02],
              [-9.2607e-03, -4.9655e-02, -3.0438e-02,  ...,  1.7757e-02,
               -4.1499e-02, -1.2796e-02],
              [-3.5203e-02, -3.5148e-03,  4.2838e-03,  ..., -2.5652e-02,
               -7.0994e-03, -2.2834e-02]], requires_grad=True)
       tensor: tensor([[-0.0664,  0.0043,  0.0788,  ...,  0.0156,  0.0026,  0.0044],
              [ 0.0383, -0.0106,  0.0170,  ..., -0.0362, -0.0802, -0.0022],
              [ 0.0117,  0.0738, -0.0664,  ..., -0.1182, -0.0065, -0.0954],
              ...,
              [ 0.1008, -0.0091, -0.0501,  ..., -0.0091, -0.0095,  0.0133],
              [-0.0029,  0.0151,  0.0108,  ...,  0.0052, -0.0712, -0.0245],
              [-0.0082,  0.0598,  0.0241,  ..., -0.1353,  0.0168, -0.0179]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-0.0091, -0.0222, -0.0435, -0.0427,  0.0332, -0.0145,  0.0444, -0.0225,
              -0.0252, -0.0384, -0.0325, -0.0306,  0.0136, -0.0440,  0.0226,  0.0197,
              -0.0346,  0.0281, -0.0020, -0.0201, -0.0482, -0.0426,  0.0018,  0.0065,
               0.0293,  0.0076,  0.0418, -0.0487,  0.0359, -0.0118,  0.0179, -0.0176,
               0.0471, -0.0438,  0.0071,  0.0235,  0.0245, -0.0395,  0.0206,  0.0396,
              -0.0399, -0.0174,  0.0419,  0.0254, -0.0254, -0.0495, -0.0108,  0.0297,
               0.0295,  0.0209, -0.0487,  0.0192,  0.0433, -0.0472, -0.0043, -0.0372,
               0.0491, -0.0487,  0.0278,  0.0408,  0.0268, -0.0455,  0.0474, -0.0153,
              -0.0464, -0.0248,  0.0130,  0.0203, -0.0359, -0.0385, -0.0100,  0.0170,
               0.0464,  0.0401, -0.0339,  0.0245,  0.0023,  0.0165,  0.0467,  0.0407,
              -0.0495,  0.0334,  0.0381,  0.0333, -0.0411, -0.0329,  0.0327, -0.0023,
               0.0222,  0.0137, -0.0227, -0.0370,  0.0094,  0.0082, -0.0387, -0.0141,
              -0.0045, -0.0353, -0.0282,  0.0103, -0.0167, -0.0474, -0.0296, -0.0452,
              -0.0296, -0.0141, -0.0436, -0.0171, -0.0035,  0.0408,  0.0441, -0.0025,
               0.0083,  0.0330, -0.0323,  0.0105,  0.0155,  0.0186, -0.0094, -0.0043],
             requires_grad=True)
       tensor: tensor([-0.1194,  0.0018, -0.0382, -0.0899,  0.0034,  0.0102,  0.0553, -0.0548,
              -0.0950,  0.0837, -0.0716,  0.0613,  0.0356, -0.1366,  0.0281,  0.0833,
              -0.0585,  0.0916, -0.0311, -0.0853, -0.0854, -0.0659,  0.0470, -0.0908,
               0.0431,  0.1137,  0.0463, -0.1019, -0.0060, -0.0466, -0.0754,  0.0031,
               0.0433, -0.0394,  0.0111, -0.0469,  0.0161, -0.0257, -0.0163,  0.0087,
              -0.0052, -0.0057,  0.0504, -0.0218, -0.0147, -0.0882,  0.0371,  0.0567,
               0.0220,  0.0803, -0.0448, -0.0313,  0.0571, -0.1098, -0.0201, -0.0401,
               0.0568, -0.0305, -0.0368,  0.0482,  0.0158, -0.0946,  0.0302, -0.0686,
              -0.0144,  0.0005,  0.1154,  0.0250,  0.0130, -0.0541, -0.0267,  0.0535,
               0.0697,  0.0690, -0.0767,  0.0776,  0.0745,  0.0335,  0.0614,  0.0944,
              -0.0984,  0.0523,  0.0240,  0.0436, -0.0207,  0.0420,  0.1366, -0.0816,
               0.0789, -0.0659,  0.0373, -0.0096,  0.0804,  0.0902, -0.1037, -0.0539,
               0.0304, -0.1223,  0.0599, -0.0474, -0.0649, -0.1068,  0.0388, -0.0095,
              -0.0560, -0.0583, -0.0698, -0.0469,  0.0231,  0.0534,  0.0424,  0.0062,
              -0.0409,  0.0445, -0.0633, -0.0341,  0.0322,  0.0177,  0.0837, -0.0139],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[-0., 0., -0.,  ..., -0., -0., -0.],
              [-0., 0., 0.,  ..., -0., 0., -0.],
              [0., 0., -0.,  ..., -0., 0., -0.],
              ...,
              [-0., 0., -0.,  ..., -0., 0., 0.],
              [-0., -0., -0.,  ..., 0., -0., -0.],
              [-0., -0., 0.,  ..., -0., -0., -0.]])
       scale: tensor([[0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              ...,
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[-3.3777e-02,  4.7079e-03, -2.3445e-02,  ..., -6.8087e-03,
               -4.6413e-02, -4.1360e-02],
              [-1.6733e-02,  5.1208e-03,  3.9803e-02,  ..., -4.5116e-02,
                1.8346e-03, -1.1031e-02],
              [ 2.3320e-02,  4.0388e-03, -4.6767e-02,  ..., -3.7066e-02,
                3.7666e-02, -1.2776e-02],
              ...,
              [-1.3560e-05,  3.7272e-02, -1.6224e-02,  ..., -3.3796e-03,
                3.3060e-02,  4.5754e-02],
              [-9.2607e-03, -4.9655e-02, -3.0438e-02,  ...,  1.7757e-02,
               -4.1499e-02, -1.2796e-02],
              [-3.5203e-02, -3.5148e-03,  4.2838e-03,  ..., -2.5652e-02,
               -7.0994e-03, -2.2834e-02]])
      (bias): Normal:
       loc: tensor([-0., -0., -0., -0., 0., -0., 0., -0., -0., -0., -0., -0., 0., -0., 0., 0., -0., 0., -0., -0., -0., -0., 0., 0.,
              0., 0., 0., -0., 0., -0., 0., -0., 0., -0., 0., 0., 0., -0., 0., 0., -0., -0., 0., 0., -0., -0., -0., 0.,
              0., 0., -0., 0., 0., -0., -0., -0., 0., -0., 0., 0., 0., -0., 0., -0., -0., -0., 0., 0., -0., -0., -0., 0.,
              0., 0., -0., 0., 0., 0., 0., 0., -0., 0., 0., 0., -0., -0., 0., -0., 0., 0., -0., -0., 0., 0., -0., -0.,
              -0., -0., -0., 0., -0., -0., -0., -0., -0., -0., -0., -0., -0., 0., 0., -0., 0., 0., -0., 0., 0., 0., -0., -0.])
       scale: tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-0.0091, -0.0222, -0.0435, -0.0427,  0.0332, -0.0145,  0.0444, -0.0225,
              -0.0252, -0.0384, -0.0325, -0.0306,  0.0136, -0.0440,  0.0226,  0.0197,
              -0.0346,  0.0281, -0.0020, -0.0201, -0.0482, -0.0426,  0.0018,  0.0065,
               0.0293,  0.0076,  0.0418, -0.0487,  0.0359, -0.0118,  0.0179, -0.0176,
               0.0471, -0.0438,  0.0071,  0.0235,  0.0245, -0.0395,  0.0206,  0.0396,
              -0.0399, -0.0174,  0.0419,  0.0254, -0.0254, -0.0495, -0.0108,  0.0297,
               0.0295,  0.0209, -0.0487,  0.0192,  0.0433, -0.0472, -0.0043, -0.0372,
               0.0491, -0.0487,  0.0278,  0.0408,  0.0268, -0.0455,  0.0474, -0.0153,
              -0.0464, -0.0248,  0.0130,  0.0203, -0.0359, -0.0385, -0.0100,  0.0170,
               0.0464,  0.0401, -0.0339,  0.0245,  0.0023,  0.0165,  0.0467,  0.0407,
              -0.0495,  0.0334,  0.0381,  0.0333, -0.0411, -0.0329,  0.0327, -0.0023,
               0.0222,  0.0137, -0.0227, -0.0370,  0.0094,  0.0082, -0.0387, -0.0141,
              -0.0045, -0.0353, -0.0282,  0.0103, -0.0167, -0.0474, -0.0296, -0.0452,
              -0.0296, -0.0141, -0.0436, -0.0171, -0.0035,  0.0408,  0.0441, -0.0025,
               0.0083,  0.0330, -0.0323,  0.0105,  0.0155,  0.0186, -0.0094, -0.0043])
    )
    (observed): Observed()
  )
  (fc2): Linear(
    in_features=120, out_features=10, bias=True
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              ...,
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498]],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[ 0.0330, -0.0638,  0.0145,  ..., -0.0799,  0.0038, -0.0028],
              [ 0.0647, -0.0898,  0.0264,  ...,  0.0438,  0.0560,  0.0141],
              [-0.0616, -0.0866, -0.0050,  ...,  0.0206, -0.0182, -0.0549],
              ...,
              [-0.0021,  0.0008,  0.0321,  ..., -0.0234,  0.0801,  0.0765],
              [-0.0517, -0.0638,  0.0722,  ..., -0.0724, -0.0263, -0.0710],
              [-0.0566, -0.0889, -0.0738,  ...,  0.0263, -0.0151, -0.0537]],
             requires_grad=True)
       tensor: tensor([[ 0.0217, -0.1135, -0.0030,  ..., -0.0561,  0.0181,  0.0389],
              [ 0.0759, -0.0429,  0.0090,  ...,  0.0921,  0.1001,  0.0168],
              [-0.0243, -0.0910,  0.0879,  ...,  0.0186,  0.0450, -0.0486],
              ...,
              [ 0.0081,  0.0176, -0.0045,  ..., -0.0133,  0.1110,  0.0398],
              [-0.0471, -0.0219,  0.0584,  ..., -0.0652,  0.0176, -0.0921],
              [-0.0964, -0.1910, -0.1096,  ..., -0.0227, -0.0748, -0.1086]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([ 0.0342, -0.0681,  0.0492, -0.0765, -0.0070, -0.0712, -0.0436,  0.0281,
               0.0831,  0.0615], requires_grad=True)
       tensor: tensor([ 0.0767, -0.0408,  0.1565, -0.0743, -0.0248, -0.0158, -0.0733, -0.0389,
              -0.0016,  0.0762], grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[0., -0., 0.,  ..., -0., 0., -0.],
              [0., -0., 0.,  ..., 0., 0., 0.],
              [-0., -0., -0.,  ..., 0., -0., -0.],
              ...,
              [-0., 0., 0.,  ..., -0., 0., 0.],
              [-0., -0., 0.,  ..., -0., -0., -0.],
              [-0., -0., -0.,  ..., 0., -0., -0.]])
       scale: tensor([[0.0913, 0.0913, 0.0913,  ..., 0.0913, 0.0913, 0.0913],
              [0.0913, 0.0913, 0.0913,  ..., 0.0913, 0.0913, 0.0913],
              [0.0913, 0.0913, 0.0913,  ..., 0.0913, 0.0913, 0.0913],
              ...,
              [0.0913, 0.0913, 0.0913,  ..., 0.0913, 0.0913, 0.0913],
              [0.0913, 0.0913, 0.0913,  ..., 0.0913, 0.0913, 0.0913],
              [0.0913, 0.0913, 0.0913,  ..., 0.0913, 0.0913, 0.0913]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[ 0.0330, -0.0638,  0.0145,  ..., -0.0799,  0.0038, -0.0028],
              [ 0.0647, -0.0898,  0.0264,  ...,  0.0438,  0.0560,  0.0141],
              [-0.0616, -0.0866, -0.0050,  ...,  0.0206, -0.0182, -0.0549],
              ...,
              [-0.0021,  0.0008,  0.0321,  ..., -0.0234,  0.0801,  0.0765],
              [-0.0517, -0.0638,  0.0722,  ..., -0.0724, -0.0263, -0.0710],
              [-0.0566, -0.0889, -0.0738,  ...,  0.0263, -0.0151, -0.0537]])
      (bias): Normal:
       loc: tensor([0., -0., 0., -0., -0., -0., -0., 0., 0., 0.])
       scale: tensor([0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
              0.3162])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([ 0.0342, -0.0681,  0.0492, -0.0765, -0.0070, -0.0712, -0.0436,  0.0281,
               0.0831,  0.0615])
    )
    (observed): Observed()
  )
)

You just have to define the forward function, and the backward function (where gradients are computed) is automatically defined for you using autograd. You can use any of the Tensor operations in the forward function.

The learnable parameters of a model are returned by net.parameters():

params = list(net.parameters())
print(len(params))
print(params[0].size())

Out:

16
torch.Size([6, 1, 5, 5])
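As a rough cross-check of the printed shape, the sizes of the weight and bias tensors can be worked out by hand from the layer definitions. The sketch below does this in plain Python. The helper names are illustrative only, and reading the 16 entries as a loc and a scale tensor for each of the 8 weight/bias tensors is an assumption based on the printed module structure, not something the borch documentation is quoted as guaranteeing.

```python
# Layer definitions copied from the Net class above.
conv_layers = {"conv1": (1, 6, 5), "conv2": (6, 16, 5)}        # (in_channels, out_channels, kernel)
linear_layers = {"fc1": (16 * 5 * 5, 120), "fc2": (120, 10)}   # (in_features, out_features)

shapes = {}
for name, (c_in, c_out, k) in conv_layers.items():
    shapes[name + ".weight"] = (c_out, c_in, k, k)
    shapes[name + ".bias"] = (c_out,)
for name, (f_in, f_out) in linear_layers.items():
    shapes[name + ".weight"] = (f_out, f_in)
    shapes[name + ".bias"] = (f_out,)


def numel(shape):
    """Number of scalar values in a tensor of the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n


n_tensors = len(shapes)                             # weight and bias tensors
n_values = sum(numel(s) for s in shapes.values())   # scalar values across them

print(shapes["conv1.weight"])
print(n_tensors, n_values)
```

Under the stated assumption, two learnable tensors per weight or bias give 2 * 8 = 16 entries, and the first one has the shape of conv1's weight.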

Let’s try a random 32x32 input:

input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)

Out:

tensor([[-0.3077, -0.5381,  0.6875,  0.3936, -1.1296,  0.1405, -0.0033, -1.0102,
          0.4813,  0.6858]], grad_fn=<AddmmBackward0>)
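The 32x32 input size follows from the layer definitions: fc1 expects 16 * 5 * 5 features, so the spatial size must shrink to 5x5 after the two convolution and pooling stages. A minimal sketch of that arithmetic, using the standard output-size formulas for unpadded convolutions and non-overlapping pooling (the helper names are hypothetical):

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, window):
    """Spatial output size of non-overlapping max pooling (stride == window)."""
    return size // window


size = 32
trace = [size]
for _ in range(2):                     # conv1 + pool, then conv2 + pool
    size = conv_out(size, kernel=5)
    trace.append(size)
    size = pool_out(size, window=2)
    trace.append(size)

flattened = 16 * size * size           # 16 channels after conv2
print(trace, flattened)
```

The trace is 32, 28, 14, 10, 5, which gives the 400 input features of fc1.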

Zero the gradient buffers of all parameters and backpropagate with random gradients:

net.zero_grad()
out.backward(torch.randn(1, 10))

Note

borch.nn only supports mini-batches. The entire borch.nn package only supports inputs that are a mini-batch of samples, and not a single sample.

For example, nn.Conv2d will take in a 4D Tensor of nSamples x nChannels x Height x Width.

If you have a single sample, just use input.unsqueeze(0) to add a fake batch dimension.
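As a small illustration of the batch dimension, the sketch below uses plain torch tensors only (no borch modules are involved) and shows what unsqueeze(0) does to the shape of a single sample:

```python
import torch

single = torch.randn(1, 32, 32)   # one sample: nChannels x Height x Width
batch = single.unsqueeze(0)       # add a leading batch dimension of size 1

print(single.shape)
print(batch.shape)
```

The data is unchanged; only a dimension of size one is added at the front, giving nSamples x nChannels x Height x Width.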

Before proceeding further, let’s recap all the classes you’ve seen so far.

Recap:
  • torch.Tensor - A multi-dimensional array with support for autograd operations like backward(). Also holds the gradient w.r.t. the tensor.

  • nn.Module - Neural network module. Convenient way of encapsulating parameters, with helpers for moving them to GPU, exporting, loading, etc.

  • nn.Parameter - A kind of Tensor that is automatically registered as a parameter when assigned as an attribute to a Module.

  • autograd.Function - Implements forward and backward definitions of an autograd operation. Every Tensor operation creates at least a single Function node that connects to the functions that created a Tensor and encodes its history.

At this point, we covered:
  • Defining a neural network

  • Processing inputs and calling backward

Still Left:
  • Computing the loss

  • Updating the weights of the network

Loss Function

A loss function takes the (output, target) pair of inputs, and computes a value that estimates how far away the output is from the target.

There are several different loss functions under the nn package. A simple one is nn.MSELoss, which computes the mean-squared error between the input and the target. However, minimizing such a loss is only equivalent to a maximum likelihood approach: it yields a point estimate of the weights and no measure of their uncertainty.
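To make the link between mean-squared error and maximum likelihood concrete, here is a plain-Python sketch (not the nn.MSELoss implementation). It assumes a Gaussian likelihood with fixed standard deviation 1, under which the negative log-likelihood is a constant plus a multiple of the MSE; the function names and numbers are illustrative.

```python
import math


def mse(output, target):
    """Mean of the squared differences."""
    return sum((o - t) ** 2 for o, t in zip(output, target)) / len(output)


def gaussian_nll(output, target, sigma=1.0):
    """Negative log-likelihood of target under Normal(output, sigma)."""
    n = len(output)
    squared_error = sum((o - t) ** 2 for o, t in zip(output, target))
    return 0.5 * n * math.log(2 * math.pi * sigma ** 2) + squared_error / (2 * sigma ** 2)


output = [0.5, -1.0, 2.0]
target = [1.0, 0.0, 2.0]
n = len(output)

const = 0.5 * n * math.log(2 * math.pi)      # does not depend on the output
nll = gaussian_nll(output, target)
nll_from_mse = const + 0.5 * n * mse(output, target)
print(nll, nll_from_mse)
```

Because the two quantities differ only by a constant and a positive factor, minimizing one minimizes the other.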

In order to infer the posterior of the weights, and thus capture their uncertainty as well, we have to use the infer package. In this example we will use the infer.vi_loss function, which automatically constructs a loss function for variational inference given the latent variables in your model.

Similar to how it’s done for random variables, we can also observe on the module using keyword arguments matching the names of the random variables we want to observe. This adds those random variables to the likelihood term, and no distribution is inferred over them. For example:

target = torch.randint(10, (1,))  # a dummy target, for example
net.observe(classification=target)
borch.sample(net)
output = net(input)
loss = infer.vi_loss(**borch.pq_to_infer(net))
print(loss)

Out:

tensor(9293.3262, grad_fn=<AddBackward0>)
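For intuition about what a variational inference loss typically contains, the sketch below writes out a negative ELBO for univariate normal distributions in plain Python. It is not the implementation of infer.vi_loss; the function names, the closed-form KL term, and the numbers are illustrative assumptions.

```python
import math


def kl_normal(loc_q, scale_q, loc_p, scale_p):
    """Closed-form KL(q || p) for two univariate normal distributions."""
    return (math.log(scale_p / scale_q)
            + (scale_q ** 2 + (loc_q - loc_p) ** 2) / (2 * scale_p ** 2)
            - 0.5)


def negative_elbo(log_likelihood, kl_terms, kl_scaling=1.0):
    """Negative log-likelihood of the observed data plus the (scaled) KL terms."""
    return -log_likelihood + kl_scaling * sum(kl_terms)


kl_same = kl_normal(0.0, 1.0, 0.0, 1.0)      # posterior equals prior
kl_shifted = kl_normal(1.0, 1.0, 0.0, 1.0)   # posterior mean moved by one prior scale
loss = negative_elbo(log_likelihood=-2.3, kl_terms=[kl_same, kl_shifted])
print(kl_same, kl_shifted, loss)
```

The KL terms penalize approximating distributions that drift from their priors, while the likelihood term rewards fitting the observed targets.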

Now, if you follow loss in the backward direction, you will see a graph of computations that looks like this:

input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d
      -> view -> linear -> relu -> linear ->
      -> loss

So, when we call loss.backward(), the whole graph is differentiated w.r.t. the neural net parameters, and all Tensors in the graph that have requires_grad=True will have their .grad Tensor accumulated with the gradient.

Backprop

To backpropagate the error, all we have to do is call loss.backward(). You need to clear the existing gradients first, though; otherwise the new gradients will be accumulated into the existing ones.

Now we shall call loss.backward(), and have a look at conv1’s bias gradients before and after the backward.

net.zero_grad()  # zeroes the gradient buffers of all parameters

The gradient of the loc parameter of the approximating distribution of conv1.bias after zeroing the gradients is:

print(net.conv1.posterior.bias.loc.grad)

loss.backward()

Out:

tensor([0., 0., 0., 0., 0., 0.])

After calling backward, the gradient is:

print(net.conv1.posterior.bias.loc.grad)

Out:

tensor([-0.6064, -0.4782,  0.9458,  0.5497,  1.2112, -0.2230])
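The reason for clearing gradients before backward can be seen in isolation with a tiny example that uses only plain torch (the tensor w and the quadratic objective are illustrative, not part of the network above):

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

(w ** 2).sum().backward()
first = w.grad.clone()      # gradient of sum(w^2) is 2 * w

(w ** 2).sum().backward()   # second backward pass without clearing
second = w.grad.clone()     # the new gradient was added to the old one

w.grad.zero_()              # reset the buffer in place
cleared = w.grad.clone()

print(first, second, cleared)
```

The second call does not overwrite the stored gradient but adds to it, which is why the training code zeroes the buffers before each backward pass.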

The only thing left to learn is:

  • Updating the weights of the network

Update the weights

The simplest update rule used in practice is the Stochastic Gradient Descent (SGD):

weight = weight - learning_rate * gradient

We can implement this using simple Python code:

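learning_rate = 0.01
with torch.no_grad():  # the update itself should not be tracked by autograd
    for p in net.parameters():
        if p.grad is not None:  # skip parameters that received no gradient
            p -= learning_rate * p.grad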
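To see the update rule at work without any network, here is a plain-Python sketch that minimizes the toy function f(w) = (w - 3)^2. The function, starting point, and step count are illustrative choices.

```python
def grad(w):
    """Derivative of f(w) = (w - 3) ** 2."""
    return 2.0 * (w - 3.0)


learning_rate = 0.1
weight = 0.0
for step in range(50):
    # weight = weight - learning_rate * gradient
    weight = weight - learning_rate * grad(weight)

print(weight)
```

Each step moves the weight a fraction of the way towards the minimum at 3, so after 50 steps it is very close to it.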

However, as you use neural networks, you may want to use various different update rules such as SGD, Nesterov-SGD, Adam, RMSProp, etc. To enable this, torch provides a small package, torch.optim, that implements all these methods. Using it is very simple:

import torch.optim as optim

# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)

# in your training loop:
n_batch_epoch = 10  # number of batches per epoch, usually len(dataloader)
optimizer.zero_grad()  # zero the gradient buffers
borch.sample(net)
output = net(input)
loss = infer.vi_loss(**borch.pq_to_infer(net), kl_scaling=1 / n_batch_epoch)
loss.backward()
optimizer.step()  # Does the update
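One common motivation for a factor like kl_scaling=1 / n_batch_epoch is that the KL term should count once per pass over the data rather than once per batch. This is an assumption about the intent, as is the premise that kl_scaling simply multiplies the KL term; the numbers below are illustrative. In plain Python:

```python
n_batch_epoch = 10
kl = 4.0                             # illustrative KL value, taken as equal for every batch
batch_nll = [1.5] * n_batch_epoch    # illustrative per-batch negative log-likelihoods

# Per-batch loss with the KL term scaled down by the number of batches.
epoch_total = sum(nll + kl / n_batch_epoch for nll in batch_nll)

# Loss over the full dataset in one go: all likelihood terms plus one KL term.
full_data_loss = sum(batch_nll) + kl
print(epoch_total, full_data_loss)
```

Summed over one epoch, the scaled mini-batch losses match the full-data loss, so the prior is not weighted n_batch_epoch times too heavily.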

Exercises

  1. The neural network package contains various modules and loss functions that form the building blocks of deep neural networks. Have a look at the documentation to see what is available.

  2. Try designing your own feed-forward networks with two different types of nonlinearities, e.g. relu.
