Neural Networks

Neural networks can be constructed using the borch.nn package.

Now that you’ve had a glimpse of autograd, you can use nn, which depends on autograd to define models and differentiate them. An nn.Module contains layers and a forward(input) method that returns the output.
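In skeleton form (a sketch only, using the same posterior.Automatic construction as the full example below), a borch module looks like this:

from borch import nn, posterior


class Skeleton(nn.Module):
    def __init__(self):
        super().__init__(posterior=posterior.Automatic())
        self.layer = nn.Linear(4, 2)  # weights and biases become random variables

    def forward(self, x):
        return self.layer(x)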

For example, look at this network that classifies digit images:

[Figure: convnet — a convolutional network for digit classification]

It is a simple feed-forward network. It takes an input, feeds it through several layers one after the other, and then finally gives the output.

A typical training procedure for a neural network is as follows:

  • Define a network that has some learnable parameters and/or random variables

  • For each batch in a dataset, do:

    • Process the input data through the network

    • Compute the loss (how far is the output from being correct?)

    • Propagate gradients back into the network’s parameters

    • Update the weights of the network, typically using a simple update rule: weight = weight - learning_rate * gradient (a minimal sketch of this step follows the list)
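As a concrete illustration of that last update rule, here is a minimal sketch in plain PyTorch. It assumes net is an already-constructed module whose gradients have been populated by a backward pass, and learning_rate is a placeholder value; in practice you would usually use an optimizer such as torch.optim.SGD instead:

import torch

learning_rate = 0.01  # placeholder value
with torch.no_grad():  # the update itself should not be tracked by autograd
    for param in net.parameters():
        if param.grad is not None:  # skip parameters that received no gradient
            param -= learning_rate * param.grad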

Define the network

Let’s define this network:

import torch
import torch.nn.functional as F
import borch
from borch import distributions, posterior, nn, infer


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__(posterior=posterior.Automatic())
        # 1 input image channel, 6 output channels, 5x5 convolution kernel
        self.conv1 = nn.Conv2d(1, 6, 5)

        # 6 input channels, 16 output channels, 5x5 convolution kernel
        self.conv2 = nn.Conv2d(6, 16, 5)

        # An affine operation: y = Wx + b
        # NB: after two 5x5 convolutions with no padding, each followed by
        # 2x2 max pooling, a 32x32 input shrinks to 5x5 (with 16 channels):
        # 32 -> 28 -> 14 -> 10 -> 5
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the pooling window is square, you can specify a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten all dimensions except the batch dimension
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Specifying the likelihood function
        self.classification = distributions.Categorical(logits=x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
print(net)

Out:

Net(
  (posterior): Automatic()
  (prior): Module()
  (observed): Observed()
  (conv1): Conv2d(
    1, 6, kernel_size=(5, 5), stride=(1, 1)
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                ...]]], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[[[ 0.1680,  0.1718,  0.0098, -0.0428,  0.0831],
                ...]]], requires_grad=True)
       tensor: tensor([[[[ 0.1549,  0.1702,  0.0258, -0.0510,  0.1557],
                ...]]], grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([ 0.0894,  0.1877, -0.0535,  0.0308,  0.1593,  0.0810],
             requires_grad=True)
       tensor: tensor([ 0.1004,  0.2790, -0.0837, -0.0029,  0.1159,  0.0025],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[[[0., 0., 0., -0., 0.],
                ...]]])
       scale: tensor([[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                ...]]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[[[ 0.1680,  0.1718,  0.0098, -0.0428,  0.0831],
                ...]]])
      (bias): Normal:
       loc: tensor([0., 0., -0., 0., 0., 0.])
       scale: tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([ 0.0894,  0.1877, -0.0535,  0.0308,  0.1593,  0.0810])
    )
    (observed): Observed()
  )
  (conv2): Conv2d(
    6, 16, kernel_size=(5, 5), stride=(1, 1)
    (posterior): Normal(
      (weight): Normal:
       ...
      (bias): Normal:
       ...
    )
    (prior): Module(
      (weight): Normal:
       ...
                [-0., -0., -0., -0., 0.],
                [-0., 0., -0., 0., -0.]],

               [[0., 0., -0., 0., -0.],
                [0., 0., 0., 0., -0.],
                [-0., -0., 0., -0., -0.],
                [-0., -0., 0., -0., 0.],
                [-0., 0., 0., -0., -0.]],

               [[-0., 0., -0., -0., -0.],
                [-0., -0., 0., 0., -0.],
                [-0., 0., -0., 0., -0.],
                [0., 0., -0., 0., 0.],
                [-0., -0., 0., 0., 0.]],

               [[-0., -0., 0., 0., 0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., 0., -0., 0.],
                [-0., -0., 0., -0., 0.],
                [0., 0., 0., -0., 0.]]],


              [[[0., 0., -0., 0., 0.],
                [0., 0., 0., 0., -0.],
                [-0., -0., 0., -0., 0.],
                [-0., 0., 0., 0., -0.],
                [0., -0., -0., -0., -0.]],

               [[-0., -0., -0., -0., 0.],
                [-0., -0., -0., -0., 0.],
                [0., -0., -0., 0., 0.],
                [-0., 0., -0., 0., -0.],
                [-0., -0., 0., -0., 0.]],

               [[-0., 0., -0., -0., 0.],
                [-0., 0., -0., -0., -0.],
                [-0., -0., -0., -0., 0.],
                [-0., 0., -0., -0., 0.],
                [-0., -0., 0., -0., 0.]],

               [[-0., 0., 0., -0., -0.],
                [0., -0., 0., -0., 0.],
                [-0., -0., 0., 0., 0.],
                [0., -0., 0., -0., 0.],
                [0., -0., 0., 0., 0.]],

               [[-0., 0., 0., 0., -0.],
                [-0., 0., -0., -0., 0.],
                [-0., -0., 0., -0., -0.],
                [0., -0., -0., 0., -0.],
                [-0., 0., 0., -0., 0.]],

               [[0., -0., 0., 0., -0.],
                [0., 0., -0., -0., -0.],
                [0., -0., 0., -0., -0.],
                [0., -0., -0., 0., -0.],
                [-0., -0., 0., 0., 0.]]],


              [[[0., 0., 0., -0., -0.],
                [0., 0., -0., 0., -0.],
                [-0., 0., 0., -0., 0.],
                [0., -0., -0., -0., 0.],
                [-0., -0., -0., -0., 0.]],

               [[-0., -0., -0., -0., -0.],
                [-0., 0., 0., 0., 0.],
                [0., 0., 0., 0., -0.],
                [0., 0., 0., -0., -0.],
                [0., -0., -0., 0., -0.]],

               [[-0., -0., 0., -0., 0.],
                [0., -0., 0., 0., 0.],
                [-0., 0., -0., 0., -0.],
                [0., -0., 0., 0., 0.],
                [0., 0., 0., -0., 0.]],

               [[-0., -0., -0., -0., -0.],
                [0., -0., -0., -0., -0.],
                [-0., -0., -0., -0., 0.],
                [0., 0., 0., -0., -0.],
                [0., -0., -0., -0., -0.]],

               [[0., 0., -0., 0., 0.],
                [-0., 0., 0., -0., 0.],
                [-0., 0., 0., 0., -0.],
                [-0., -0., -0., 0., -0.],
                [-0., -0., -0., 0., -0.]],

               [[0., 0., 0., -0., -0.],
                [-0., 0., -0., 0., -0.],
                [0., -0., -0., -0., -0.],
                [-0., 0., -0., 0., -0.],
                [-0., 0., 0., 0., -0.]]]])
       scale: tensor([[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              ...,


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[[[ 4.2191e-02,  7.6840e-02, -7.4705e-02, -2.1865e-02,  6.1931e-02],
                [ 4.5672e-02, -4.5500e-02,  6.8328e-02, -1.9356e-02,  6.2420e-03],
                [-9.1323e-03,  6.4114e-02, -6.2778e-02, -2.9379e-02,  9.7347e-03],
                [ 1.5921e-02, -2.0886e-02, -6.3837e-02, -4.3951e-02,  5.5913e-04],
                [-1.2263e-02, -4.8414e-02, -4.4003e-02, -2.4253e-02,  3.2497e-02]],

               [[ 2.3382e-02,  3.9259e-02, -7.7953e-02, -3.2520e-02,  1.8270e-02],
                [-4.3315e-02,  7.9060e-02, -5.8311e-02, -2.8759e-02, -8.7089e-03],
                [ 5.5321e-02,  4.8438e-02,  6.5735e-02, -4.9760e-02, -3.5508e-02],
                [ 4.1005e-02,  4.6544e-02,  1.5205e-02, -3.6013e-02,  6.9640e-03],
                [-6.8387e-02, -4.9177e-02,  4.0450e-02,  7.0279e-02,  7.5872e-02]],

               [[-4.6557e-03,  1.8276e-02, -5.8220e-02, -2.1548e-02,  5.6864e-02],
                [-1.3140e-02,  3.1058e-02,  1.7065e-02, -6.0197e-02, -9.2233e-03],
                [-5.6885e-02, -6.7734e-02,  2.2703e-02,  7.8996e-02, -2.3009e-02],
                [-3.2973e-02,  3.9733e-02, -4.7892e-02,  1.3657e-02, -6.1810e-02],
                [ 6.4573e-02, -2.9753e-02, -5.5327e-02, -9.4562e-03,  2.1779e-02]],

               [[ 7.1124e-02,  4.0486e-02,  6.9724e-02,  7.6075e-02,  4.1246e-03],
                [ 8.3202e-03,  2.3893e-02, -6.6808e-03, -4.9518e-02,  3.8584e-02],
                [-5.2714e-02,  5.0159e-02, -6.6929e-02, -7.0058e-02, -4.7074e-02],
                [ 7.5636e-02,  2.7516e-02, -5.3383e-02, -7.9865e-02,  7.8899e-02],
                [ 6.7598e-02,  2.2288e-02, -4.3938e-02, -2.2255e-02,  3.8803e-02]],

               [[-1.8988e-02,  6.3276e-02, -6.4471e-02, -3.8395e-02, -5.1800e-02],
                [ 2.6698e-02,  5.7260e-03, -2.2055e-02,  2.9274e-02,  1.7332e-02],
                [-1.0218e-02, -2.7324e-02, -2.9080e-02,  1.3222e-02,  7.9084e-02],
                [ 4.3772e-02,  7.7684e-02,  4.8257e-02, -6.8435e-02, -2.0343e-02],
                [ 3.0093e-02,  3.1836e-03,  2.6288e-03, -8.2295e-03, -2.5020e-02]],

               [[-3.8859e-02,  5.2245e-02, -7.8403e-02, -4.2871e-02,  6.1650e-02],
                [-4.8324e-02,  5.1745e-02,  5.0752e-02,  7.3504e-02, -2.3351e-02],
                [ 7.5441e-02, -5.1618e-02,  8.0429e-02,  7.3450e-02,  7.8664e-02],
                [ 4.9288e-02, -1.9266e-02,  2.9147e-02, -8.0626e-02, -3.1476e-02],
                [ 6.6063e-02,  1.7808e-02,  2.6498e-02,  4.6047e-03,  7.2748e-02]]],


              [[[ 2.7147e-02, -6.4806e-02, -1.7023e-02,  4.9091e-03, -6.6140e-03],
                [-7.7151e-02, -1.2995e-02, -4.5008e-02, -1.8290e-02,  5.0563e-02],
                [-7.6304e-02, -8.0019e-02, -1.1984e-02, -3.8425e-02, -4.5299e-03],
                [-1.4411e-02,  5.2870e-02,  7.4521e-02,  4.9850e-02, -2.3332e-02],
                [ 3.2249e-03,  4.7371e-02, -5.4040e-02,  7.8740e-02,  6.2129e-02]],

               [[-5.5597e-02, -7.9082e-03,  5.7858e-03, -1.7938e-02,  3.2981e-02],
                [-5.7644e-02, -7.8341e-02, -5.2250e-02, -7.0176e-02,  9.1169e-03],
                [ 4.8751e-03,  2.7214e-02, -7.2253e-02, -7.9531e-02,  2.2115e-02],
                [-8.0066e-02,  6.1143e-03,  9.5127e-03,  7.6738e-02, -5.4762e-02],
                [-4.0489e-02, -4.4740e-02,  7.4822e-02,  5.8303e-02,  3.4314e-02]],

               [[ 3.2335e-02, -7.6859e-03,  7.2277e-02,  6.4942e-02, -2.1889e-02],
                [ 6.7059e-02, -2.5517e-02, -7.2487e-02,  6.3138e-02,  1.8800e-02],
                [-4.1668e-02,  3.3260e-02,  2.0425e-03, -9.8570e-03, -5.2738e-02],
                [ 5.0590e-02, -1.0069e-02, -1.0559e-02,  1.4182e-02,  5.3266e-02],
                [-1.0744e-02, -2.3853e-02,  1.9477e-02,  1.2789e-02, -1.9533e-03]],

               [[-7.1060e-03, -4.1691e-02,  3.4899e-02, -7.2925e-02, -5.2534e-02],
                [-4.6110e-02,  7.5498e-02,  6.1524e-02,  1.1099e-02,  7.3727e-02],
                [ 5.4388e-02,  3.9116e-02, -6.6412e-02, -2.6689e-02, -2.0881e-02],
                [ 2.6913e-02,  3.5247e-02, -2.6048e-02,  1.4267e-02,  2.6323e-02],
                [ 6.5432e-03, -7.3995e-03, -4.3528e-02,  5.0431e-02,  6.9511e-02]],

               [[-2.8527e-02,  6.1429e-02,  6.0427e-02, -4.1141e-02, -7.3619e-02],
                [ 2.1947e-02, -3.8028e-02, -6.7364e-02,  1.5408e-02,  5.2976e-02],
                [ 5.8108e-02, -6.1758e-02,  1.4653e-02, -7.6558e-02, -3.1448e-02],
                [ 6.2695e-02,  6.4849e-02,  3.1678e-02,  8.7257e-03,  8.9366e-03],
                [ 4.3787e-02, -2.6563e-02,  7.6010e-02, -4.9455e-02,  4.6068e-02]],

               [[-4.6879e-02, -1.0211e-02, -5.7937e-02, -6.0173e-02, -5.6010e-03],
                [ 6.7886e-02,  4.8736e-02,  7.5209e-02, -2.4597e-02,  1.2666e-02],
                [ 2.7294e-02,  5.9385e-02,  8.1056e-02,  7.2644e-02, -7.7046e-02],
                [ 1.7380e-02, -3.2567e-02, -2.1987e-02,  7.8235e-02, -5.2908e-02],
                [-2.0811e-02, -2.9994e-02, -4.6838e-02, -4.2671e-02,  3.8727e-02]]],


              [[[ 3.5134e-02,  5.2745e-02, -6.1149e-02, -7.6414e-02, -2.6176e-02],
                [-6.1641e-02, -8.1041e-02, -2.5100e-02, -7.2410e-02,  6.1282e-02],
                [-3.0138e-03,  5.0077e-02, -8.1627e-02, -4.3763e-02, -3.0137e-03],
                [-4.0017e-02, -4.3606e-02, -5.0771e-02, -1.0728e-02,  6.0225e-02],
                [-6.3614e-02,  5.9101e-02, -1.1370e-02,  6.4711e-02, -1.0095e-02]],

               [[-5.0975e-02,  6.0192e-02,  4.8900e-02,  8.0230e-02, -1.5046e-02],
                [ 7.2609e-02,  3.1814e-02,  6.7249e-02,  5.9835e-02,  7.7418e-02],
                [ 2.9254e-02,  3.7293e-02,  1.1221e-02,  3.3947e-03,  4.2694e-02],
                [-6.1252e-02,  3.9850e-02,  8.3872e-03,  6.4936e-02,  5.5738e-02],
                [-3.8509e-02, -5.6070e-03,  6.8840e-02,  6.8610e-02, -4.3579e-02]],

               [[ 1.8686e-02,  2.1468e-02,  5.9046e-02,  5.2732e-02,  5.1538e-02],
                [ 7.0772e-02,  4.9593e-02,  4.1163e-02,  6.8360e-02, -3.9729e-02],
                [-4.0441e-02, -1.5649e-02,  7.2554e-02, -2.2384e-02,  3.3869e-02],
                [-1.1148e-02,  7.9215e-02,  1.0318e-02, -4.9182e-02,  9.6277e-03],
                [ 4.1858e-02,  3.0790e-02, -3.0381e-02,  6.6102e-02, -5.1207e-02]],

               [[-6.6738e-03,  6.4776e-02, -3.3882e-02,  7.1115e-02,  3.1006e-02],
                [-2.2906e-02, -2.9046e-02,  3.5319e-02,  5.3543e-02,  2.1489e-02],
                [ 3.3100e-02, -3.5833e-02, -1.8264e-02,  6.3019e-03,  3.8628e-02],
                [-3.4829e-02, -2.0159e-02, -4.5294e-02,  1.4057e-02, -7.9188e-02],
                [-3.1353e-02, -1.8218e-02, -8.0638e-02, -5.0035e-02, -3.4570e-02]],

               [[ 6.1931e-03, -3.6366e-03,  2.6574e-02, -1.3003e-02,  6.1312e-02],
                [ 4.6450e-02, -4.6889e-02, -2.5151e-02, -1.6860e-02,  6.3430e-02],
                [-6.3583e-02,  3.9752e-02, -4.2166e-02, -2.7013e-02, -3.3751e-03],
                [ 5.6220e-02,  8.1038e-02,  7.8622e-02,  5.7729e-02,  3.2215e-02],
                [-3.7003e-02, -6.6767e-02,  6.3296e-02, -7.1348e-02, -3.1742e-03]],

               [[ 7.3035e-02,  4.1458e-02, -4.6261e-02,  3.0164e-02, -1.8932e-03],
                [ 3.8696e-02,  5.4963e-02,  4.5004e-02, -5.6026e-02,  1.2815e-02],
                [-8.0065e-02,  1.8459e-02,  7.6491e-02,  7.3781e-02,  5.5816e-02],
                [ 4.5870e-02, -8.1061e-02, -4.7139e-02,  5.6221e-02,  6.0546e-02],
                [-6.1282e-02, -2.2380e-02, -6.3779e-02, -6.8776e-02, -4.3537e-02]]],


              ...,


              [[[ 5.2454e-02, -4.7911e-02, -7.2297e-02, -2.0057e-02, -5.5949e-02],
                [-1.7783e-02, -8.5585e-03,  2.1769e-02,  7.7195e-02, -2.6911e-02],
                [-5.5650e-02, -7.1240e-02, -4.3191e-02,  7.2941e-02,  7.0596e-04],
                [-1.1619e-02,  9.3917e-03, -2.8626e-02, -2.5418e-02, -1.0257e-02],
                [ 5.1407e-02,  4.8896e-02,  6.0359e-02,  3.2419e-02, -1.9229e-02]],

               [[ 1.5427e-02,  4.6055e-02,  3.8077e-02,  1.9711e-02, -4.7648e-02],
                [-2.7328e-02, -7.1221e-02,  5.6410e-02, -2.5372e-02, -4.0767e-02],
                [-3.1127e-02,  6.7051e-02, -7.9245e-02,  6.6682e-02, -2.6589e-03],
                [ 6.9368e-02, -4.1969e-02,  5.7844e-02,  4.0387e-02, -5.6537e-02],
                [ 1.3153e-02,  4.0302e-02,  7.1897e-02,  2.1402e-02,  5.5573e-02]],

               [[-4.6272e-02,  2.1553e-02,  1.4332e-03, -4.3438e-02, -2.7539e-03],
                [-3.4410e-02, -7.9327e-02, -3.9320e-02,  7.7251e-02,  2.2694e-02],
                [-7.6279e-02, -3.1001e-02,  6.8287e-02, -7.7937e-02, -6.5913e-02],
                [-5.2818e-02, -3.7519e-02, -2.3019e-02, -3.6675e-02,  7.9680e-02],
                [-6.1808e-02,  5.1876e-02, -7.8567e-02,  8.1565e-02, -6.8310e-02]],

               [[ 5.5309e-02,  8.7365e-03, -5.6561e-03,  7.1133e-02, -4.6917e-02],
                [ 2.1503e-03,  8.4254e-03,  1.1892e-02,  7.9317e-02, -7.8111e-03],
                [-5.6072e-02, -6.5883e-02,  3.5538e-02, -7.8521e-02, -6.9720e-02],
                [-1.9393e-02, -7.3491e-02,  8.1700e-03, -1.9778e-03,  2.8227e-02],
                [-1.6722e-02,  5.5947e-02,  6.2404e-03, -5.1478e-03, -2.5255e-02]],

               [[-7.8407e-02,  7.2449e-02, -6.2823e-02, -6.6556e-03, -1.6768e-02],
                [-3.4428e-02, -1.3495e-02,  6.7206e-04,  2.5644e-02, -3.1432e-02],
                [-7.0893e-02,  2.9846e-02, -5.0699e-03,  2.9505e-02, -1.8095e-02],
                [ 6.7203e-02,  5.6675e-03, -3.5811e-02,  1.9892e-02,  5.2260e-02],
                [-6.8506e-02, -7.1065e-02,  3.3085e-02,  4.3750e-02,  1.6927e-02]],

               [[-6.5512e-02, -5.6648e-02,  4.6444e-02,  7.4754e-02,  1.1064e-02],
                [ 6.3824e-02,  6.5449e-02, -4.9581e-02, -1.4683e-02, -4.2090e-02],
                [-4.7219e-02, -3.5248e-02,  3.6657e-04, -2.6233e-02,  7.6344e-02],
                [-6.2184e-02, -6.0576e-02,  8.0155e-02, -4.5559e-02,  5.8270e-02],
                [ 3.0984e-02,  5.9146e-02,  5.2023e-02, -8.1066e-02,  1.9684e-02]]],


              [[[ 7.3765e-02,  3.5141e-03, -2.1752e-03,  7.4100e-02,  7.4641e-02],
                [ 4.8591e-02,  7.6598e-02,  2.9565e-02,  2.1044e-02, -5.0159e-02],
                [-6.1025e-02, -2.8593e-02,  2.4891e-02, -6.0556e-02,  6.2090e-02],
                [-1.5755e-02,  4.0377e-02,  3.5295e-03,  5.3803e-02, -6.6197e-02],
                [ 4.4328e-02, -1.8031e-02, -7.4217e-02, -9.6550e-03, -6.4711e-03]],

               [[-7.8584e-02, -4.6375e-02, -3.3496e-02, -6.2696e-02,  3.7588e-02],
                [-1.4592e-02, -2.4390e-02, -5.0144e-02, -5.0049e-02,  3.2448e-02],
                [ 7.5607e-02, -6.8160e-02, -1.8825e-02,  6.0839e-02,  2.1488e-02],
                [-6.6120e-02,  6.7261e-02, -5.8349e-02,  4.2199e-02, -5.2458e-02],
                [-4.8494e-02, -5.4509e-02,  9.2273e-05, -3.4371e-02,  5.9357e-04]],

               [[-7.8414e-02,  7.0503e-02, -3.8403e-02, -4.4648e-02,  7.5146e-02],
                [-5.6907e-02,  3.4213e-05, -6.8368e-02, -5.0295e-02, -4.3787e-02],
                [-5.8034e-02, -6.0247e-02, -5.8353e-02, -2.5280e-02,  6.9298e-02],
                [-1.4764e-03,  4.3781e-02, -2.5240e-02, -4.3813e-02,  2.8583e-02],
                [-4.1418e-02, -4.6101e-02,  2.7372e-03, -3.2267e-02,  5.5665e-02]],

               [[-7.9642e-02,  5.8555e-02,  5.8883e-02, -6.3063e-02, -3.0692e-02],
                [ 5.0561e-02, -6.8027e-02,  4.7893e-02, -6.2822e-02,  1.0364e-02],
                [-6.9147e-02, -7.2420e-02,  7.8318e-02,  6.8844e-02,  5.5427e-03],
                [ 6.0117e-02, -3.5221e-02,  7.4160e-02, -1.2965e-02,  6.2557e-02],
                [ 4.4865e-04, -6.0832e-02,  7.6485e-03,  6.0057e-02,  2.3943e-02]],

               [[-2.4230e-02,  6.7864e-02,  7.2121e-02,  3.4157e-02, -7.9538e-02],
                [-3.3560e-02,  1.1223e-02, -1.7517e-02, -5.4311e-02,  6.5203e-02],
                [-7.7116e-02, -6.3548e-02,  3.0147e-02, -5.6751e-02, -4.5155e-02],
                [ 7.9671e-02, -6.0086e-02, -6.9947e-02,  1.2981e-02, -2.5239e-02],
                [-2.4339e-03,  7.9699e-02,  2.4683e-02, -4.5459e-02,  9.2444e-03]],

               [[ 4.2156e-02, -5.2864e-02,  7.0858e-02,  7.3412e-02, -1.2266e-02],
                [ 4.9417e-02,  4.1968e-02, -6.1616e-02, -1.7818e-02, -2.4427e-02],
                [ 2.1473e-02, -1.1306e-02,  4.7061e-02, -2.3690e-02, -6.6106e-02],
                [ 8.0653e-02, -6.1394e-02, -9.6741e-03,  3.2104e-02, -3.5300e-02],
                [-3.7291e-02, -1.6968e-02,  1.3973e-02,  5.5290e-02,  4.0622e-03]]],


              [[[ 2.2983e-02,  1.0195e-02,  4.0193e-02, -7.5025e-02, -1.5421e-03],
                [ 6.0510e-02,  1.8363e-02, -3.8094e-02,  4.5445e-02, -2.9622e-03],
                [-5.3075e-02,  3.4295e-02,  5.0654e-02, -1.9342e-02,  3.2361e-02],
                [ 7.3380e-02, -1.0884e-02, -5.3324e-02, -5.0394e-02,  3.1872e-02],
                [-4.9773e-02, -2.8900e-02, -2.6879e-02, -2.8097e-02,  6.6398e-02]],

               [[-2.5804e-02, -4.6905e-02, -4.2523e-02, -5.0381e-02, -2.0208e-02],
                [-4.8815e-02,  2.2532e-02,  6.9881e-02,  2.1225e-02,  4.3858e-02],
                [ 3.8482e-02,  7.5890e-03,  1.4969e-02,  1.8850e-02, -2.9226e-02],
                [ 3.2008e-02,  2.3823e-02,  7.4640e-02, -3.2508e-02, -5.9983e-02],
                [ 2.0176e-02, -2.4570e-02, -4.5569e-02,  6.8979e-02, -4.4682e-02]],

               [[-4.6043e-02, -3.2025e-02,  2.8151e-02, -4.8214e-02,  7.1796e-02],
                [ 6.4546e-02, -2.4392e-03,  6.5753e-02,  1.0233e-02,  3.1131e-02],
                [-2.9636e-02,  7.3704e-02, -7.7781e-02,  8.0873e-02, -8.5714e-04],
                [ 3.0164e-02, -6.3718e-02,  8.1516e-02,  1.4934e-02,  6.1368e-02],
                [ 4.9876e-02,  9.5079e-03,  7.7987e-02, -3.8619e-02,  5.9578e-02]],

               [[-4.2667e-03, -4.3751e-02, -2.8233e-02, -4.1592e-02, -6.0769e-02],
                [ 4.8450e-02, -1.2896e-02, -4.6601e-03, -6.4206e-03, -1.1431e-02],
                [-7.7503e-02, -1.0773e-02, -8.0490e-03, -3.1181e-03,  1.3285e-02],
                [ 6.8160e-02,  6.7861e-02,  1.4774e-02, -1.3599e-02, -2.4826e-02],
                [ 4.7705e-02, -7.3378e-02, -1.3825e-02, -3.1077e-02, -3.7972e-02]],

               [[ 5.6202e-02,  4.1837e-02, -2.8773e-02,  7.2132e-02,  2.5515e-02],
                [-3.3901e-02,  7.0663e-03,  6.8551e-02, -9.9631e-03,  3.0268e-02],
                [-7.2736e-02,  2.3806e-02,  2.9459e-02,  6.0356e-02, -4.3899e-03],
                [-2.8373e-02, -1.1198e-02, -1.2081e-02,  1.7582e-02, -6.8985e-02],
                [-7.2981e-02, -5.7756e-02, -2.6205e-02,  5.1297e-02, -5.9565e-02]],

               [[ 6.5016e-03,  6.3833e-02,  1.0038e-02, -7.4870e-02, -8.0244e-02],
                [-1.2594e-02,  2.0456e-02, -4.8056e-02,  1.1776e-02, -7.6218e-02],
                [ 5.1599e-02, -9.6947e-03, -1.6420e-02, -5.6519e-02, -7.0490e-02],
                [-5.8097e-02,  7.7906e-03, -3.3099e-03,  1.3623e-02, -5.2314e-02],
                [-1.3419e-02,  6.9299e-02,  7.2203e-02,  4.8437e-02, -4.1129e-02]]]])
      (bias): Normal:
       loc: tensor([-0., 0., 0., 0., 0., 0., 0., 0., -0., 0., 0., 0., -0., 0., 0., -0.])
       scale: tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
              0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-0.0365,  0.0072,  0.0093,  0.0025,  0.0342,  0.0735,  0.0308,  0.0723,
              -0.0758,  0.0133,  0.0423,  0.0314, -0.0622,  0.0366,  0.0547, -0.0157])
    )
    (observed): Observed()
  )
  (fc1): Linear(
    in_features=400, out_features=120, bias=True
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              ...,
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498]],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[ 0.0290,  0.0450,  0.0432,  ..., -0.0188,  0.0460,  0.0460],
              [-0.0440,  0.0107, -0.0302,  ..., -0.0457,  0.0447, -0.0215],
              [-0.0185,  0.0321, -0.0076,  ..., -0.0138,  0.0309,  0.0219],
              ...,
              [ 0.0011,  0.0479,  0.0438,  ...,  0.0297, -0.0310,  0.0051],
              [ 0.0461, -0.0489,  0.0107,  ..., -0.0152,  0.0484, -0.0418],
              [ 0.0300, -0.0080,  0.0244,  ..., -0.0223,  0.0077, -0.0147]],
             requires_grad=True)
       tensor: tensor([[-0.0771,  0.0682,  0.0952,  ..., -0.0025,  0.0747,  0.0905],
              [-0.0600, -0.0432, -0.0064,  ..., -0.0666,  0.0176,  0.0168],
              [ 0.0137,  0.0625,  0.0067,  ...,  0.0625,  0.0137,  0.1347],
              ...,
              [ 0.0379, -0.0104,  0.0276,  ...,  0.0230, -0.0285,  0.0612],
              [ 0.0619, -0.0173,  0.0673,  ...,  0.0148,  0.0073,  0.0058],
              [ 0.1118, -0.0446, -0.0601,  ...,  0.0168, -0.0052, -0.0763]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([ 3.4268e-02, -4.9700e-02,  1.3390e-02, -2.0985e-02, -1.6808e-02,
               2.4641e-02, -4.8857e-02,  3.6003e-02,  1.0424e-02, -2.8776e-02,
               4.2286e-02, -2.7017e-02,  4.5771e-03,  2.5405e-02,  1.3040e-02,
               4.0760e-02,  3.4614e-02,  9.6058e-03, -1.0389e-02,  1.7939e-02,
               2.7261e-02,  1.0442e-02, -4.0359e-02,  6.0608e-04, -2.8432e-02,
              -2.3294e-02,  2.5094e-02, -2.7444e-02, -3.9087e-02,  4.8333e-02,
              -3.8067e-02,  4.3731e-02, -1.1337e-02,  6.5628e-03, -3.9661e-02,
              -3.0479e-02,  4.2353e-02, -2.1659e-02,  1.7937e-02,  4.8377e-02,
              -7.3253e-03, -2.8690e-02, -2.0960e-02, -2.0081e-02, -4.1321e-02,
              -2.5133e-02, -3.6166e-02,  3.5816e-02,  3.0482e-03, -4.1686e-02,
               3.5028e-02,  3.1139e-02, -2.5572e-03, -8.7303e-03,  4.4109e-02,
               1.0638e-02,  3.5676e-02,  1.9173e-02,  2.6669e-02, -3.7786e-02,
              -1.8936e-02, -1.0854e-05,  3.6708e-02, -2.0134e-02, -4.2009e-02,
               3.9515e-02, -3.7057e-02,  2.5732e-02, -2.7906e-02,  3.4639e-02,
               4.3312e-03,  1.9998e-02, -2.2635e-02,  2.4380e-02, -1.2083e-02,
               3.3281e-02,  4.6888e-02,  3.9643e-02,  1.1001e-02,  4.8635e-02,
              -1.2259e-03,  3.2456e-02, -1.0144e-02, -6.7254e-03,  1.5385e-03,
               1.6456e-02, -3.1296e-02, -1.5280e-02,  3.8949e-02, -1.9992e-02,
              -2.8284e-02,  4.0517e-02, -9.7160e-03,  7.7505e-04,  1.6642e-02,
               4.0081e-02,  9.7422e-04,  3.7670e-02,  4.7919e-02,  4.4317e-02,
               1.0160e-02,  1.9730e-02, -8.4861e-03,  3.3960e-02, -2.3660e-02,
              -4.4850e-02, -3.0455e-02, -1.5874e-03,  1.4935e-02,  2.1578e-02,
               4.7099e-02, -3.9308e-02,  1.2687e-02, -3.1501e-02,  2.6750e-02,
               1.3101e-02, -4.7801e-02,  4.5735e-02,  2.1969e-02,  1.0239e-02],
             requires_grad=True)
       tensor: tensor([ 0.0461, -0.0686, -0.0018, -0.0174, -0.1039,  0.0644,  0.0129, -0.0319,
               0.0540, -0.0665,  0.0386, -0.0862, -0.0050, -0.0043,  0.0145,  0.0309,
               0.1006,  0.0050, -0.0034,  0.0625, -0.0306,  0.0265, -0.1264, -0.0380,
               0.0116, -0.0690,  0.0944, -0.0293, -0.0990,  0.1073,  0.0308,  0.0658,
              -0.0371,  0.0542, -0.0284, -0.0286, -0.0326, -0.0041,  0.0036,  0.0200,
              -0.0654, -0.0131, -0.0922, -0.0521, -0.0524,  0.0307,  0.0191,  0.0887,
               0.0523, -0.0082,  0.0263,  0.0368, -0.0741,  0.0028,  0.0159,  0.0315,
              -0.0049,  0.0991,  0.0387, -0.0024, -0.0229,  0.1443, -0.0206, -0.0440,
               0.0354,  0.0254, -0.0601,  0.0427, -0.0854,  0.0387,  0.0435,  0.0373,
               0.0068,  0.0281,  0.0257, -0.0699, -0.0028,  0.0847,  0.0240, -0.0674,
              -0.0045,  0.1192,  0.0438, -0.0321,  0.0262, -0.0189, -0.0056, -0.0746,
               0.0501,  0.0481, -0.1232, -0.0209,  0.0074, -0.0143, -0.0851,  0.1050,
               0.0118,  0.0893,  0.0022, -0.0285, -0.0378, -0.0167, -0.0383,  0.0616,
              -0.0673, -0.0795, -0.0137,  0.0920,  0.0490,  0.0294,  0.1110,  0.0153,
               0.0588, -0.0804,  0.0791,  0.1213,  0.0279,  0.0537,  0.0901, -0.0692],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[0., 0., 0.,  ..., -0., 0., 0.],
              [-0., 0., -0.,  ..., -0., 0., -0.],
              [-0., 0., -0.,  ..., -0., 0., 0.],
              ...,
              [0., 0., 0.,  ..., 0., -0., 0.],
              [0., -0., 0.,  ..., -0., 0., -0.],
              [0., -0., 0.,  ..., -0., 0., -0.]])
       scale: tensor([[0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              ...,
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[ 0.0290,  0.0450,  0.0432,  ..., -0.0188,  0.0460,  0.0460],
              [-0.0440,  0.0107, -0.0302,  ..., -0.0457,  0.0447, -0.0215],
              [-0.0185,  0.0321, -0.0076,  ..., -0.0138,  0.0309,  0.0219],
              ...,
              [ 0.0011,  0.0479,  0.0438,  ...,  0.0297, -0.0310,  0.0051],
              [ 0.0461, -0.0489,  0.0107,  ..., -0.0152,  0.0484, -0.0418],
              [ 0.0300, -0.0080,  0.0244,  ..., -0.0223,  0.0077, -0.0147]])
      (bias): Normal:
       loc: tensor([0., -0., 0., -0., -0., 0., -0., 0., 0., -0., 0., -0., 0., 0., 0., 0., 0., 0., -0., 0., 0., 0., -0., 0.,
              -0., -0., 0., -0., -0., 0., -0., 0., -0., 0., -0., -0., 0., -0., 0., 0., -0., -0., -0., -0., -0., -0., -0., 0.,
              0., -0., 0., 0., -0., -0., 0., 0., 0., 0., 0., -0., -0., -0., 0., -0., -0., 0., -0., 0., -0., 0., 0., 0.,
              -0., 0., -0., 0., 0., 0., 0., 0., -0., 0., -0., -0., 0., 0., -0., -0., 0., -0., -0., 0., -0., 0., 0., 0.,
              0., 0., 0., 0., 0., 0., -0., 0., -0., -0., -0., -0., 0., 0., 0., -0., 0., -0., 0., 0., -0., 0., 0., 0.])
       scale: tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([ 3.4268e-02, -4.9700e-02,  1.3390e-02, -2.0985e-02, -1.6808e-02,
               2.4641e-02, -4.8857e-02,  3.6003e-02,  1.0424e-02, -2.8776e-02,
               4.2286e-02, -2.7017e-02,  4.5771e-03,  2.5405e-02,  1.3040e-02,
               4.0760e-02,  3.4614e-02,  9.6058e-03, -1.0389e-02,  1.7939e-02,
               2.7261e-02,  1.0442e-02, -4.0359e-02,  6.0608e-04, -2.8432e-02,
              -2.3294e-02,  2.5094e-02, -2.7444e-02, -3.9087e-02,  4.8333e-02,
              -3.8067e-02,  4.3731e-02, -1.1337e-02,  6.5628e-03, -3.9661e-02,
              -3.0479e-02,  4.2353e-02, -2.1659e-02,  1.7937e-02,  4.8377e-02,
              -7.3253e-03, -2.8690e-02, -2.0960e-02, -2.0081e-02, -4.1321e-02,
              -2.5133e-02, -3.6166e-02,  3.5816e-02,  3.0482e-03, -4.1686e-02,
               3.5028e-02,  3.1139e-02, -2.5572e-03, -8.7303e-03,  4.4109e-02,
               1.0638e-02,  3.5676e-02,  1.9173e-02,  2.6669e-02, -3.7786e-02,
              -1.8936e-02, -1.0854e-05,  3.6708e-02, -2.0134e-02, -4.2009e-02,
               3.9515e-02, -3.7057e-02,  2.5732e-02, -2.7906e-02,  3.4639e-02,
               4.3312e-03,  1.9998e-02, -2.2635e-02,  2.4380e-02, -1.2083e-02,
               3.3281e-02,  4.6888e-02,  3.9643e-02,  1.1001e-02,  4.8635e-02,
              -1.2259e-03,  3.2456e-02, -1.0144e-02, -6.7254e-03,  1.5385e-03,
               1.6456e-02, -3.1296e-02, -1.5280e-02,  3.8949e-02, -1.9992e-02,
              -2.8284e-02,  4.0517e-02, -9.7160e-03,  7.7505e-04,  1.6642e-02,
               4.0081e-02,  9.7422e-04,  3.7670e-02,  4.7919e-02,  4.4317e-02,
               1.0160e-02,  1.9730e-02, -8.4861e-03,  3.3960e-02, -2.3660e-02,
              -4.4850e-02, -3.0455e-02, -1.5874e-03,  1.4935e-02,  2.1578e-02,
               4.7099e-02, -3.9308e-02,  1.2687e-02, -3.1501e-02,  2.6750e-02,
               1.3101e-02, -4.7801e-02,  4.5735e-02,  2.1969e-02,  1.0239e-02])
    )
    (observed): Observed()
  )
  (fc2): Linear(
    in_features=120, out_features=10, bias=True
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              ...,
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498]],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[ 0.0853, -0.0135,  0.0683,  ...,  0.0101, -0.0386,  0.0218],
              [ 0.0366, -0.0514,  0.0796,  ..., -0.0416, -0.0514,  0.0824],
              [-0.0907,  0.0680,  0.0275,  ..., -0.0242,  0.0592, -0.0864],
              ...,
              [-0.0862,  0.0863, -0.0316,  ...,  0.0718,  0.0438,  0.0558],
              [-0.0711, -0.0183, -0.0767,  ...,  0.0480,  0.0798, -0.0622],
              [ 0.0360, -0.0011, -0.0885,  ...,  0.0526,  0.0213, -0.0500]],
             requires_grad=True)
       tensor: tensor([[ 0.0446, -0.0119,  0.0613,  ...,  0.0736, -0.1029, -0.0459],
              [ 0.0366, -0.1579,  0.0943,  ...,  0.0078, -0.0376,  0.0493],
              [-0.1436,  0.0732,  0.0015,  ..., -0.0362, -0.0015, -0.0596],
              ...,
              [-0.0071,  0.0492,  0.0063,  ...,  0.0985,  0.0816,  0.0473],
              [-0.1276,  0.0231, -0.0409,  ...,  0.0153,  0.0586, -0.0618],
              [ 0.0311, -0.0013, -0.0697,  ...,  0.0642, -0.0082, -0.0596]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([ 0.0895,  0.0306, -0.0073,  0.0383,  0.0428, -0.0883,  0.0808, -0.0065,
               0.0228,  0.0389], requires_grad=True)
       tensor: tensor([-0.0167,  0.0773,  0.0665,  0.0994,  0.0772, -0.0356, -0.0222, -0.0223,
              -0.0587,  0.0343], grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[0., -0., 0.,  ..., 0., -0., 0.],
              [0., -0., 0.,  ..., -0., -0., 0.],
              [-0., 0., 0.,  ..., -0., 0., -0.],
              ...,
              [-0., 0., -0.,  ..., 0., 0., 0.],
              [-0., -0., -0.,  ..., 0., 0., -0.],
              [0., -0., -0.,  ..., 0., 0., -0.]])
       scale: tensor([[0.0913, 0.0913, 0.0913,  ..., 0.0913, 0.0913, 0.0913],
              [0.0913, 0.0913, 0.0913,  ..., 0.0913, 0.0913, 0.0913],
              [0.0913, 0.0913, 0.0913,  ..., 0.0913, 0.0913, 0.0913],
              ...,
              [0.0913, 0.0913, 0.0913,  ..., 0.0913, 0.0913, 0.0913],
              [0.0913, 0.0913, 0.0913,  ..., 0.0913, 0.0913, 0.0913],
              [0.0913, 0.0913, 0.0913,  ..., 0.0913, 0.0913, 0.0913]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[ 0.0853, -0.0135,  0.0683,  ...,  0.0101, -0.0386,  0.0218],
              [ 0.0366, -0.0514,  0.0796,  ..., -0.0416, -0.0514,  0.0824],
              [-0.0907,  0.0680,  0.0275,  ..., -0.0242,  0.0592, -0.0864],
              ...,
              [-0.0862,  0.0863, -0.0316,  ...,  0.0718,  0.0438,  0.0558],
              [-0.0711, -0.0183, -0.0767,  ...,  0.0480,  0.0798, -0.0622],
              [ 0.0360, -0.0011, -0.0885,  ...,  0.0526,  0.0213, -0.0500]])
      (bias): Normal:
       loc: tensor([0., 0., -0., 0., 0., -0., 0., -0., 0., 0.])
       scale: tensor([0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
              0.3162])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([ 0.0895,  0.0306, -0.0073,  0.0383,  0.0428, -0.0883,  0.0808, -0.0065,
               0.0228,  0.0389])
    )
    (observed): Observed()
  )
)

You just have to define the forward function, and the backward function (where gradients are computed) is automatically defined for you using autograd. You can use any of the Tensor operations in the forward function.
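To see this at work in isolation, here is a minimal sketch with plain tensors, independent of the network above:

x = torch.ones(2, 2, requires_grad=True)  # track operations on x
y = (3 * x ** 2).sum()                    # any Tensor operations can be used
y.backward()                              # the backward pass is derived automatically
print(x.grad)                             # dy/dx = 6 * x, a 2x2 tensor of sixes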

The learnable parameters of a model are returned by net.parameters()

params = list(net.parameters())
print(len(params))
print(params[0].size())

Out:

16
torch.Size([6, 1, 5, 5])
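The count of 16 comes from each of the four layers contributing a weight and a bias, each with (presumably) a loc and an unconstrained scale parameter in the posterior. A quick way to check, assuming the usual named_parameters() from torch is available on the module, is:

for name, p in net.named_parameters():
    print(name, tuple(p.size()))  # every learnable parameter and its shape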

Let's try a random 32x32 input.

input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)

Out:

tensor([[ 0.4980,  0.6893,  0.0096, -0.4008, -0.1188, -0.2803,  0.3551,  0.4623,
         -0.2278, -0.0520]], grad_fn=<AddmmBackward0>)

Zero the gradient buffers of all parameters and backpropagate with random gradients:

net.zero_grad()
out.backward(torch.randn(1, 10))

Note

borch.nn only supports mini-batches. The entire borch.nn package only supports inputs that are a mini-batch of samples, and not a single sample.

For example, nn.Conv2d will take in a 4D Tensor of nSamples x nChannels x Height x Width.

If you have a single sample, just use input.unsqueeze(0) to add a fake batch dimension.
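For instance, a single 1x32x32 image gets its batch dimension like this:

image = torch.randn(1, 32, 32)  # a single sample: nChannels x Height x Width
batch = image.unsqueeze(0)      # add a fake batch dimension at position 0
print(batch.size())             # torch.Size([1, 1, 32, 32])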

Before proceeding further, let’s recap all the classes you’ve seen so far.

Recap:
  • torch.Tensor - A multi-dimensional array with support for autograd operations like backward(). Also holds the gradient w.r.t. the tensor.

  • nn.Module - Neural network module. Convenient way of encapsulating parameters, with helpers for moving them to GPU, exporting, loading, etc.

  • nn.Parameter - A kind of Tensor that is automatically registered as a parameter when assigned as an attribute to a Module.

  • autograd.Function - Implements forward and backward definitions of an autograd operation. Every Tensor operation creates at least a single Function node that connects to the functions that created the Tensor and encodes its history.

At this point, we covered:
  • Defining a neural network

  • Processing inputs and calling backward

Still Left:
  • Computing the loss

  • Updating the weights of the network

Loss Function

A loss function takes the (output, target) pair of inputs, and computes a value that estimates how far away the output is from the target.

There are several different loss functions under the nn package. A simple loss is nn.MSELoss, which computes the mean-squared error between the input and the target. Such losses are, however, only equivalent to a maximum likelihood approach in deep learning.
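As a quick (non-Bayesian) illustration of such a loss, using a made-up dummy_mse_target with the same shape as the network output:

dummy_mse_target = torch.randn(1, 10)    # made-up target, same shape as the output
print(F.mse_loss(out, dummy_mse_target)) # mean-squared error between output and target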

In order to infer the posterior of the weights, and thus capture the uncertainty of the weights as well, we have to use the infer package. In this example we will use the infer.vi_loss function, which automatically creates a suitable loss function for variational inference given the latent variables in your model.

Similar to how it's done for random variables, we can also observe on the module using keyword arguments matching the names of the random variables we want to observe. This adds those random variables to the likelihood term, and we will not infer a distribution over them. For example:

target = torch.randint(10, (1,))  # a dummy target, for example
net.observe(classification=target)
borch.sample(net)
output = net(input)
loss = infer.vi_loss(**borch.pq_to_infer(net))
print(loss)

Out:

tensor(9549.6875, grad_fn=<AddBackward0>)

Now, if you follow loss in the backward direction, you will see a graph of computations that looks like this:

input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d
      -> view -> linear -> relu -> linear ->
      -> loss

So, when we call loss.backward(), the whole graph is differentiated w.r.t. the loss, and all Tensors in the graph that have requires_grad=True will have their .grad Tensor accumulated with the gradient.
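You can inspect a couple of steps of this graph by following grad_fn backwards; exactly which nodes appear depends on how the loss was assembled:

print(loss.grad_fn)                       # the last operation that produced loss
print(loss.grad_fn.next_functions[0][0])  # one step further back in the graph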

Backprop

To backpropagate the error, all we have to do is call loss.backward(). You need to clear the existing gradients first though, or the new gradients will be accumulated onto the existing ones.

Now we shall call loss.backward(), and have a look at conv1's bias gradients before and after the backward call.

net.zero_grad()  # zeroes the gradient buffers of all parameters

The gradient of the loc parameter of the approximating distribution of conv1.bias, after zeroing the gradients, is

print(net.conv1.posterior.bias.loc.grad)

Out:

tensor([0., 0., 0., 0., 0., 0.])

After calling backward, the gradient is

loss.backward()
print(net.conv1.posterior.bias.loc.grad)

Out:

tensor([ 0.6376,  1.3013, -0.0847,  0.4150,  1.0615,  0.2077])

The only thing left to learn is:

  • Updating the weights of the network

Update the weights

The simplest update rule used in practice is the Stochastic Gradient Descent (SGD):

weight = weight - learning_rate * gradient

We can implement this using simple Python code:

learning_rate = 0.01
for f in net.parameters():
    f.data.sub_(f.grad.data * learning_rate)  # in-place: weight = weight - lr * gradient
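The same step is often written with torch.no_grad() instead of going through .data; both update the parameters in place without recording the operation in the autograd graph:

with torch.no_grad():
    for f in net.parameters():
        f -= f.grad * learning_rate  # in-place SGD step, invisible to autograd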

However, as you use neural networks, you want to use various update rules such as SGD, Nesterov-SGD, Adam, RMSProp, etc. To enable this, PyTorch provides a small package, torch.optim, that implements all these methods. Using it is very simple:

import torch.optim as optim

# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)

# in your training loop:
n_batch_epoch = 10  # number of batches per epoch usually len(dataloader)
optimizer.zero_grad()  # zero the gradient buffers
borch.sample(net)
output = net(input)
loss = infer.vi_loss(**borch.pq_to_infer(net), kl_scaling=1 / n_batch_epoch)
loss.backward()
optimizer.step()  # Does the update
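Over a full epoch this becomes a loop. The dataloader below is a hypothetical iterable of (input, target) batches, used only to sketch the shape of such a loop:

for data, target in dataloader:  # hypothetical: yields (input, target) batches
    optimizer.zero_grad()                 # zero the gradient buffers
    net.observe(classification=target)    # condition the likelihood on this batch
    borch.sample(net)                     # draw fresh samples of the latent variables
    net(data)                             # forward pass
    loss = infer.vi_loss(**borch.pq_to_infer(net),
                         kl_scaling=1 / len(dataloader))
    loss.backward()
    optimizer.step()                      # does the update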

Exercises

  1. The neural network package contains various modules and loss functions that form the building blocks of deep neural networks. Have a look at the documentation to see what is available.

  2. Try designing your own feed-forward networks with two different types of non-linearities, e.g. relu
