Neural Networks

Neural networks can be constructed using the borch.nn package.

Now that you’ve had a glimpse of autograd, we can look at nn, which depends on autograd to define models and differentiate them. An nn.Module contains layers and a method forward(input) that returns the output.
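As a rough mental model only, the pattern "a module holds layers, and forward runs the input through them" can be sketched in plain Python. `ToyModule`, `scale_by_two` and `relu` below are hypothetical illustrations, not part of borch or torch. A real nn.Module also registers its parameters (and, in borch, its random variables and posteriors), and it is normally called as `net(input)` rather than `net.forward(input)`.

```python
class ToyModule:
    """Hypothetical stand-in for nn.Module: holds layers and defines forward()."""

    def __init__(self, layers):
        # Layers are plain callables, applied in the order given.
        self.layers = layers

    def forward(self, x):
        # Feed the input through each layer, one after the other.
        for layer in self.layers:
            x = layer(x)
        return x

    def __call__(self, x):
        # Calling the module delegates to forward(), mirroring net(input).
        return self.forward(x)


def scale_by_two(values):
    return [2.0 * v for v in values]


def relu(values):
    return [max(0.0, v) for v in values]


toy_net = ToyModule([scale_by_two, relu])
print(toy_net([-1.0, 0.5, 3.0]))  # [0.0, 1.0, 6.0]
```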

For example, look at this network that classifies digit images:

Figure: convnet

It is a simple feed-forward network: it takes the input, feeds it through several layers one after the other, and finally gives the output.
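"One after the other" can be made concrete by tracing how the spatial size of an input shrinks through the layers of the network defined in the next section. This sketch assumes a 32x32 single-channel input (the size mentioned in the code comments), 5x5 convolutions with stride 1 and no padding, and 2x2 max pooling with stride 2. The helper `out_size` is a hypothetical name; it applies the usual output-size formula floor((n + 2p - k) / s) + 1 for convolution and pooling without dilation.

```python
def out_size(n, kernel, stride=1, padding=0):
    """Output length along one spatial axis for conv/pool layers (no dilation)."""
    return (n + 2 * padding - kernel) // stride + 1


size = 32                                  # assumed 32x32 input image
size = out_size(size, kernel=5)            # conv1 (5x5): 32 -> 28
size = out_size(size, kernel=2, stride=2)  # 2x2 max pool: 28 -> 14
size = out_size(size, kernel=5)            # conv2 (5x5): 14 -> 10
size = out_size(size, kernel=2, stride=2)  # 2x2 max pool: 10 -> 5

channels = 16                              # output channels of conv2
flat_features = channels * size * size     # number of inputs to fc1
print(size, flat_features)                 # 5 400
```

The result, 400, matches the `16 * 5 * 5` input size given to `fc1`.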

A typical training procedure for a neural network is as follows:

  • Define a network that has some learnable parameters and/or random variables

  • For each batch in a dataset, do:

    • Process the input data through the network

    • Compute the loss (how far is the output from being correct?)

    • Propagate gradients back into the network’s parameters

    • Update the weights of the network, typically using a simple update rule: weight = weight - learning_rate * gradient
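The steps above can be sketched end to end in plain Python on a toy problem: fitting a single weight so that `weight * x` matches data generated by y = 3x, using a mean squared error loss whose gradient is derived by hand instead of by autograd. The data, the batch split, the learning rate and the number of epochs are arbitrary illustrative choices. This is plain deterministic gradient descent and ignores the probabilistic side of borch entirely.

```python
# Toy data generated by y = 3x, so the best weight is 3.0.
data = [(1.0, 3.0), (2.0, 6.0), (3.0, 9.0), (4.0, 12.0)]
batches = [data[:2], data[2:]]  # two mini-batches of two examples each

weight = 0.0          # the single learnable parameter
learning_rate = 0.01

for _ in range(200):  # epochs
    for batch in batches:
        # 1. Process the input data through the "network".
        preds = [weight * x for x, _ in batch]
        # 2. Compute the loss (mean squared error over the batch).
        loss = sum((p - y) ** 2 for p, (_, y) in zip(preds, batch)) / len(batch)
        # 3. Gradient of the loss w.r.t. the weight, derived by hand.
        grad = sum(2 * (p - y) * x for p, (x, y) in zip(preds, batch)) / len(batch)
        # 4. Update rule: weight = weight - learning_rate * gradient.
        weight = weight - learning_rate * grad

print(round(weight, 4))  # 3.0
```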

Define the network

Let’s define this network:

import torch
import torch.nn.functional as F
import borch
from borch import distributions, posterior, nn, infer


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__(posterior=posterior.Automatic())
        # 1 input image channel, 6 output channels, 5x5 convolution kernel
        self.conv1 = nn.Conv2d(1, 6, 5)

        # 6 input channels, 16 output channels, 5x5 convolution kernel
        self.conv2 = nn.Conv2d(6, 16, 5)

        # An affine operation: y = Wx + b
        # NB: with no padding, each 5x5 convolution shrinks each spatial side
        # by 4 and each 2x2 max pooling halves it, so an image with initial
        # dimension 32x32 ends up as 5x5 (with 16 channels):
        # 32 -> 28 -> 14 -> 10 -> 5
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the window is square, a single number is enough
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Flatten everything except the batch dimension (-1 infers the batch size)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        # Specify the likelihood: a categorical distribution over the 10 classes,
        # parameterized by the network's logits
        self.classification = distributions.Categorical(logits=x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = Net()
print(net)

Out:

Net(
  (posterior): Automatic()
  (prior): Module()
  (observed): Observed()
  (conv1): Conv2d(
    1, 6, kernel_size=(5, 5), stride=(1, 1)
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[[[-0.0991, -0.0592, -0.1623,  0.0523,  0.1645],
                [-0.1893, -0.0215,  0.1007, -0.1699, -0.1765],
                [ 0.1281, -0.0728, -0.1345, -0.0970,  0.0558],
                [-0.1093,  0.1952, -0.0354,  0.0050,  0.0558],
                [-0.1692,  0.1254, -0.0014,  0.1250,  0.1932]]],


              [[[ 0.0767,  0.0089, -0.1279, -0.1135,  0.1844],
                [-0.1334,  0.0288,  0.1844,  0.0812, -0.1021],
                [-0.1645,  0.1263, -0.1469, -0.0918, -0.1910],
                [ 0.0580,  0.0896, -0.1576,  0.0805, -0.1268],
                [ 0.1165,  0.0309, -0.0574,  0.0782,  0.0083]]],


              [[[-0.1977, -0.0520, -0.0045,  0.1165, -0.0447],
                [-0.0220,  0.1359,  0.0219, -0.1281, -0.0446],
                [-0.1952,  0.0164,  0.1802,  0.0125, -0.1155],
                [-0.1422, -0.1467,  0.0933, -0.1485, -0.1893],
                [-0.0724,  0.0799, -0.1816,  0.1123, -0.0929]]],


              [[[-0.0213, -0.1834, -0.0827,  0.0897,  0.1789],
                [ 0.1426, -0.0666,  0.0739, -0.0140, -0.0703],
                [ 0.1502, -0.0841, -0.0333,  0.0246, -0.0301],
                [-0.0028,  0.1863,  0.1322,  0.0346, -0.1045],
                [ 0.1198,  0.1824,  0.1617, -0.1592, -0.1136]]],


              [[[-0.0311,  0.1311, -0.1360,  0.1307,  0.0552],
                [ 0.1491,  0.1264,  0.0093,  0.0053,  0.1045],
                [-0.1270, -0.1262, -0.1517, -0.1297,  0.1289],
                [-0.0223,  0.0269, -0.1771,  0.0697, -0.0873],
                [ 0.1871,  0.1114, -0.1391,  0.0208, -0.0346]]],


              [[[-0.1517,  0.0542,  0.0801, -0.1400, -0.1205],
                [-0.0011, -0.1382, -0.0382, -0.0469,  0.0242],
                [ 0.0768,  0.0658,  0.0869, -0.1353,  0.0962],
                [ 0.1768,  0.1289,  0.0925,  0.0991, -0.1173],
                [ 0.0856, -0.0907,  0.1015, -0.1604, -0.0407]]]], requires_grad=True)
       tensor: tensor([[[[-0.0614,  0.0123, -0.1081,  0.1236,  0.1454],
                [-0.1888,  0.0317,  0.0477, -0.1962, -0.1923],
                [ 0.1091,  0.0490, -0.1451, -0.2202, -0.0304],
                [-0.1762,  0.1958,  0.0447, -0.0126, -0.0172],
                [-0.2533,  0.0799, -0.0985,  0.0271,  0.2088]]],


              [[[ 0.0831,  0.0297, -0.1840, -0.0885,  0.1750],
                [-0.1555,  0.0938,  0.1390,  0.2179, -0.1170],
                [-0.1460,  0.1144, -0.1382, -0.1290, -0.1353],
                [ 0.1784,  0.0510, -0.1081,  0.0149, -0.1268],
                [ 0.1006,  0.0944, -0.0335,  0.0982,  0.0166]]],


              [[[-0.1809,  0.0275,  0.0079,  0.1430, -0.0050],
                [-0.0176,  0.1644,  0.0566, -0.2271, -0.0801],
                [-0.2015,  0.0562,  0.1753,  0.0234, -0.1854],
                [-0.1811, -0.1378,  0.1392, -0.1488, -0.2823],
                [-0.2067, -0.0193, -0.1639,  0.1127, -0.1008]]],


              [[[ 0.0130, -0.2083, -0.1244,  0.0793,  0.2028],
                [ 0.2372, -0.0538,  0.1101,  0.0460, -0.0670],
                [ 0.1380, -0.0999,  0.0380,  0.1050, -0.0276],
                [ 0.0336,  0.1392,  0.0764,  0.0153, -0.1006],
                [ 0.1250,  0.1943,  0.2652, -0.0947, -0.1738]]],


              [[[-0.0794,  0.0738, -0.0990,  0.1658,  0.1148],
                [ 0.2231,  0.0902,  0.0110,  0.0146,  0.1010],
                [-0.1262, -0.2206, -0.1288, -0.1480,  0.1115],
                [-0.0217,  0.0519, -0.1776,  0.0722, -0.0266],
                [ 0.1953,  0.0690, -0.1989, -0.0013, -0.0624]]],


              [[[-0.1406,  0.0837,  0.1120, -0.1516, -0.0857],
                [ 0.0686, -0.1454, -0.0464, -0.0448,  0.0909],
                [ 0.1257,  0.0885,  0.0548, -0.1890,  0.0385],
                [ 0.2158,  0.0925,  0.1184,  0.1180, -0.0726],
                [ 0.0904, -0.0799,  0.1279, -0.1520,  0.0205]]]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-0.0725, -0.0897,  0.0513, -0.0100, -0.1734,  0.0447],
             requires_grad=True)
       tensor: tensor([-0.1067, -0.1109,  0.0153, -0.0572, -0.1683, -0.0096],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[[[-0., -0., -0., 0., 0.],
                [-0., -0., 0., -0., -0.],
                [0., -0., -0., -0., 0.],
                [-0., 0., -0., 0., 0.],
                [-0., 0., -0., 0., 0.]]],


              [[[0., 0., -0., -0., 0.],
                [-0., 0., 0., 0., -0.],
                [-0., 0., -0., -0., -0.],
                [0., 0., -0., 0., -0.],
                [0., 0., -0., 0., 0.]]],


              [[[-0., -0., -0., 0., -0.],
                [-0., 0., 0., -0., -0.],
                [-0., 0., 0., 0., -0.],
                [-0., -0., 0., -0., -0.],
                [-0., 0., -0., 0., -0.]]],


              [[[-0., -0., -0., 0., 0.],
                [0., -0., 0., -0., -0.],
                [0., -0., -0., 0., -0.],
                [-0., 0., 0., 0., -0.],
                [0., 0., 0., -0., -0.]]],


              [[[-0., 0., -0., 0., 0.],
                [0., 0., 0., 0., 0.],
                [-0., -0., -0., -0., 0.],
                [-0., 0., -0., 0., -0.],
                [0., 0., -0., 0., -0.]]],


              [[[-0., 0., 0., -0., -0.],
                [-0., -0., -0., -0., 0.],
                [0., 0., 0., -0., 0.],
                [0., 0., 0., 0., -0.],
                [0., -0., 0., -0., -0.]]]])
       scale: tensor([[[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]],


              [[[0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000],
                [0.2000, 0.2000, 0.2000, 0.2000, 0.2000]]]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[[[-0.0991, -0.0592, -0.1623,  0.0523,  0.1645],
                [-0.1893, -0.0215,  0.1007, -0.1699, -0.1765],
                [ 0.1281, -0.0728, -0.1345, -0.0970,  0.0558],
                [-0.1093,  0.1952, -0.0354,  0.0050,  0.0558],
                [-0.1692,  0.1254, -0.0014,  0.1250,  0.1932]]],


              [[[ 0.0767,  0.0089, -0.1279, -0.1135,  0.1844],
                [-0.1334,  0.0288,  0.1844,  0.0812, -0.1021],
                [-0.1645,  0.1263, -0.1469, -0.0918, -0.1910],
                [ 0.0580,  0.0896, -0.1576,  0.0805, -0.1268],
                [ 0.1165,  0.0309, -0.0574,  0.0782,  0.0083]]],


              [[[-0.1977, -0.0520, -0.0045,  0.1165, -0.0447],
                [-0.0220,  0.1359,  0.0219, -0.1281, -0.0446],
                [-0.1952,  0.0164,  0.1802,  0.0125, -0.1155],
                [-0.1422, -0.1467,  0.0933, -0.1485, -0.1893],
                [-0.0724,  0.0799, -0.1816,  0.1123, -0.0929]]],


              [[[-0.0213, -0.1834, -0.0827,  0.0897,  0.1789],
                [ 0.1426, -0.0666,  0.0739, -0.0140, -0.0703],
                [ 0.1502, -0.0841, -0.0333,  0.0246, -0.0301],
                [-0.0028,  0.1863,  0.1322,  0.0346, -0.1045],
                [ 0.1198,  0.1824,  0.1617, -0.1592, -0.1136]]],


              [[[-0.0311,  0.1311, -0.1360,  0.1307,  0.0552],
                [ 0.1491,  0.1264,  0.0093,  0.0053,  0.1045],
                [-0.1270, -0.1262, -0.1517, -0.1297,  0.1289],
                [-0.0223,  0.0269, -0.1771,  0.0697, -0.0873],
                [ 0.1871,  0.1114, -0.1391,  0.0208, -0.0346]]],


              [[[-0.1517,  0.0542,  0.0801, -0.1400, -0.1205],
                [-0.0011, -0.1382, -0.0382, -0.0469,  0.0242],
                [ 0.0768,  0.0658,  0.0869, -0.1353,  0.0962],
                [ 0.1768,  0.1289,  0.0925,  0.0991, -0.1173],
                [ 0.0856, -0.0907,  0.1015, -0.1604, -0.0407]]]])
      (bias): Normal:
       loc: tensor([-0., -0., 0., -0., -0., 0.])
       scale: tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.4082, 0.4082])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-0.0725, -0.0897,  0.0513, -0.0100, -0.1734,  0.0447])
    )
    (observed): Observed()
  )
  (conv2): Conv2d(
    6, 16, kernel_size=(5, 5), stride=(1, 1)
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              ...,


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]],


              [[[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]],

               [[0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
                [0.0498, 0.0498, 0.0498, 0.0498, 0.0498]]]], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[[[ 2.4966e-03, -6.7341e-02,  6.2799e-02, -3.4370e-03,  4.5322e-02],
                [ 1.4607e-02,  6.2478e-02,  6.8280e-02, -3.5654e-02, -1.8825e-02],
                [ 7.1212e-02,  1.5050e-02, -1.8025e-02,  2.5491e-02,  2.8724e-02],
                [-2.2547e-02,  6.0692e-02,  5.2667e-02, -5.7841e-03,  2.6376e-02],
                [-4.3343e-02, -5.9860e-02, -1.2973e-02,  6.2583e-03,  8.0137e-02]],

               [[-2.1569e-02, -7.1375e-02,  1.2631e-02, -1.8075e-02, -4.7912e-02],
                [-2.8027e-02, -2.6252e-02, -5.9709e-02,  6.7074e-02,  2.6864e-02],
                [-6.6509e-02, -8.1214e-02,  2.3467e-02, -6.6464e-02, -6.0669e-02],
                [ 5.0579e-03,  1.1671e-02, -6.6733e-02, -4.9352e-02, -2.0956e-02],
                [ 7.6966e-03,  2.2004e-02, -2.8989e-02,  5.1644e-02, -6.1621e-02]],

               [[-5.4252e-02, -8.1522e-02, -6.2672e-02, -6.3098e-02, -5.2104e-02],
                [ 2.5544e-02,  3.6878e-02, -4.4395e-02, -4.9102e-02,  4.8144e-02],
                [-3.2576e-02,  8.0204e-02, -9.2122e-03, -7.7311e-03,  6.7384e-02],
                [-1.5682e-02,  9.5757e-03, -6.5708e-02, -2.7689e-02, -6.5013e-02],
                [-6.4995e-02,  2.7979e-02, -1.6400e-03,  4.6764e-02, -4.3071e-02]],

               [[ 5.6996e-02,  5.9034e-03, -4.6686e-02, -6.9235e-02,  1.9127e-03],
                [ 1.8545e-02, -1.9516e-02, -5.0705e-02, -3.5648e-02,  6.0718e-02],
                [-6.7171e-02, -7.3959e-02,  5.4578e-02, -1.1379e-02,  4.7259e-02],
                [ 1.8278e-02,  1.0828e-02,  2.9502e-02, -4.8256e-02,  1.5783e-02],
                [ 3.0306e-02, -5.6080e-03,  1.1095e-02, -8.0521e-02,  6.7744e-02]],

               [[ 6.4667e-03, -3.1020e-02, -5.5962e-02, -4.0058e-02, -3.4994e-02],
                [-4.6501e-02, -5.0324e-02, -7.3864e-02,  1.8048e-02, -5.4226e-02],
                [-7.8412e-02, -4.8003e-02, -1.1779e-02, -7.8603e-02, -3.8339e-02],
                [-2.3014e-02,  5.4240e-02, -4.7970e-02,  2.5364e-02,  3.3201e-02],
                [ 6.0505e-02,  6.4193e-02,  7.4750e-04, -2.1983e-02,  7.6744e-02]],

               [[-6.7106e-02,  3.3650e-02, -7.8031e-02,  6.1907e-02, -7.8580e-02],
                [-3.2164e-02, -4.9644e-02,  5.5871e-02, -1.9880e-02, -5.9899e-02],
                [ 7.6375e-02,  1.6907e-02, -3.7618e-02, -4.6346e-02,  4.5166e-02],
                [-6.2424e-02,  1.6883e-02,  3.9223e-02, -1.4790e-02, -6.1880e-02],
                [ 3.8489e-02, -7.1589e-05, -1.7487e-02, -6.2649e-02,  6.8018e-02]]],


              [[[-2.5034e-02, -4.6058e-02, -2.3337e-02,  2.2215e-02, -1.4850e-02],
                [-4.3984e-02,  1.4747e-02, -7.9597e-03, -6.3069e-02, -2.2592e-02],
                [ 4.3075e-02,  1.9039e-02,  3.0107e-02,  2.3445e-03, -6.9232e-02],
                [ 6.4497e-03, -4.7578e-02,  7.6891e-02, -5.1616e-02,  1.2707e-02],
                [-3.6706e-02,  7.6937e-02, -5.9667e-02, -8.1446e-02,  1.7545e-02]],

               [[-4.4402e-02,  6.4599e-02,  4.1490e-02, -2.8644e-02, -4.8206e-02],
                [ 2.6031e-02,  5.7371e-02,  2.0308e-02, -7.4061e-02, -4.4877e-02],
                [-2.7835e-02,  5.2467e-02,  4.0908e-02, -8.1401e-02, -6.8110e-02],
                [ 7.6093e-02, -9.4644e-03, -5.2426e-03,  7.7898e-03, -1.6856e-02],
                [-7.1420e-02,  7.8023e-02,  8.0673e-02,  3.5987e-02,  7.1558e-02]],

               [[-2.2731e-02,  6.9536e-03, -3.5987e-02,  7.5870e-02, -7.4783e-02],
                [ 3.9431e-02, -3.5946e-02,  3.5435e-02,  3.6262e-02,  1.0886e-02],
                [-5.3538e-02,  7.1992e-02, -3.2372e-04, -1.5915e-02, -7.1567e-02],
                [ 2.4658e-02, -1.1047e-02, -5.5648e-02,  1.9212e-02, -1.2692e-02],
                [-8.4647e-03, -4.5983e-02, -6.7419e-04,  9.8715e-03, -4.6148e-02]],

               [[-2.0329e-02,  4.4603e-02,  6.4148e-02,  7.1504e-02, -4.5712e-02],
                [ 4.6664e-02,  7.8892e-02, -3.1212e-02, -8.2775e-03, -6.5268e-02],
                [-5.3630e-02, -3.4377e-02,  4.4915e-02, -3.5955e-02,  4.4886e-02],
                [-6.4155e-02, -1.4317e-02, -6.8112e-02, -2.0220e-02,  3.0197e-02],
                [-6.6971e-02,  7.1244e-02,  3.6331e-02,  4.9435e-02,  9.6099e-04]],

               [[ 7.0543e-02, -2.7817e-02,  6.0078e-02,  4.7984e-02, -4.3316e-02],
                [ 7.6346e-02,  1.2656e-02, -6.2175e-02,  7.6718e-02,  1.8006e-02],
                [ 7.4139e-02, -5.7349e-02, -1.0675e-02, -2.4720e-02, -4.4425e-02],
                [-7.0388e-02, -6.9265e-02,  2.1314e-02,  4.5109e-02, -7.1493e-02],
                [-3.2351e-02,  3.5873e-02, -4.8951e-03, -7.2958e-02, -4.4238e-02]],

               [[ 8.1172e-02, -5.8811e-02,  4.0991e-02,  6.9875e-02, -6.0488e-02],
                [-6.3614e-02,  2.7591e-02,  4.7980e-02,  8.1415e-02, -3.3863e-02],
                [ 2.1980e-02, -2.3815e-02,  1.8170e-02, -2.5447e-02,  2.7371e-02],
                [ 8.1534e-02, -3.9964e-02, -7.7124e-02,  5.2303e-02, -2.0802e-02],
                [ 6.0792e-02, -7.2300e-02, -2.8772e-02,  5.7216e-02, -3.4328e-02]]],


              [[[-2.7990e-03,  1.8808e-02,  6.7455e-04, -6.6091e-02,  6.6982e-02],
                [-6.9830e-02,  6.5788e-02, -4.8964e-02, -4.1140e-03,  3.1098e-03],
                [-5.4660e-02,  7.1862e-02,  7.1153e-03, -6.5570e-02, -5.1447e-02],
                [ 6.0571e-02,  2.0612e-03,  2.3722e-02,  1.5375e-02, -5.1639e-02],
                [ 4.7830e-02, -2.5379e-02,  6.9856e-02,  4.5925e-02,  9.6594e-03]],

               [[ 2.2661e-02,  6.6682e-03, -4.9346e-02,  2.0883e-03,  3.6717e-02],
                [-2.9341e-02, -7.5506e-02, -2.3876e-02, -4.3604e-02,  6.0555e-02],
                [-5.9189e-03, -9.2285e-03, -7.6425e-02,  5.5105e-02, -3.4522e-02],
                [ 4.5115e-02,  6.4247e-03,  4.4812e-02,  1.9783e-02,  7.9421e-02],
                [ 1.4393e-02,  6.9105e-02, -5.1842e-02, -5.6427e-04, -3.0378e-02]],

               [[-2.9724e-02,  2.7430e-02,  2.0606e-02, -6.8593e-02,  7.5719e-02],
                [-1.1574e-02, -6.1978e-02,  1.5036e-02,  6.5520e-02, -5.5500e-02],
                [-5.4038e-03,  1.6463e-03, -1.6045e-02,  5.0524e-02, -6.1660e-02],
                [ 7.9241e-02, -6.3365e-02, -7.6820e-02, -7.4036e-02, -3.8586e-02],
                [-2.4314e-02,  8.9899e-03, -7.1113e-03, -3.9968e-02, -5.5614e-02]],

               [[-2.1929e-02, -1.1871e-02, -4.8945e-02,  2.5946e-02, -6.4562e-03],
                [-3.3630e-02,  5.2410e-02,  3.1839e-02,  3.7700e-02,  4.5499e-02],
                [ 2.5421e-02, -6.7532e-02,  7.4772e-02, -4.7501e-03, -5.6375e-02],
                [ 3.3231e-03,  2.4121e-02,  8.1080e-02,  1.0045e-02,  7.1572e-02],
                [ 4.9069e-02, -3.3530e-02,  1.1358e-02, -5.2008e-04, -4.7222e-02]],

               [[ 4.6323e-02, -2.6136e-04, -5.2806e-02, -2.4812e-02,  4.8923e-02],
                [ 3.7666e-02, -5.7144e-02, -6.4298e-03, -1.7247e-02,  3.1968e-02],
                [-6.0489e-02,  2.8763e-02,  2.8486e-02,  3.5758e-02,  1.7142e-02],
                [ 8.0486e-02, -3.3536e-02,  3.6080e-02, -2.8437e-02, -1.9267e-03],
                [-6.2648e-02, -5.2696e-02, -6.9294e-02,  3.6589e-02,  1.8661e-02]],

               [[ 9.3734e-03,  1.9768e-02, -3.5296e-02, -2.7572e-02,  6.0065e-02],
                [ 4.1112e-02, -3.4800e-02, -7.9605e-02,  5.3310e-02, -1.9664e-02],
                [-1.3619e-02,  2.5598e-02, -2.6082e-02,  1.0369e-02, -6.4816e-02],
                [-5.6846e-03,  2.9120e-02,  3.2444e-02, -1.8672e-02, -6.1514e-03],
                [-4.0733e-02,  1.6849e-02,  3.8275e-02, -3.8825e-02, -5.8020e-02]]],


              ...,


              [[[ 5.9329e-02, -6.5870e-02,  1.5859e-02, -1.5105e-02, -8.8580e-03],
                [-5.0859e-02,  3.3974e-04,  3.0749e-02,  2.3093e-02,  4.2401e-02],
                [-5.4646e-02,  7.5729e-02, -5.2260e-02, -4.6122e-03,  5.1098e-02],
                [ 7.6842e-02,  1.2823e-02,  3.1960e-03, -3.9222e-02, -2.0589e-02],
                [-6.3892e-03,  4.3896e-02, -2.7891e-02, -1.8322e-02,  3.2296e-02]],

               [[-4.1220e-02,  8.0820e-02, -4.3842e-02,  7.9302e-02, -4.9983e-02],
                [-7.3771e-03, -2.5980e-03,  3.6210e-03, -2.1433e-02,  7.3244e-02],
                [-5.0760e-02, -2.5616e-03,  6.2678e-02,  7.3593e-02,  2.8301e-02],
                [ 6.7007e-02, -1.9064e-02,  8.3973e-03, -5.9798e-02, -2.6974e-02],
                [ 3.2876e-02,  3.8669e-02, -3.8719e-02, -1.0153e-02,  7.5301e-02]],

               [[-5.0442e-02,  7.7989e-03,  3.9936e-02,  3.3326e-02,  7.9093e-02],
                [-2.9757e-03,  5.9154e-02,  5.3348e-02,  7.3507e-02, -7.6732e-03],
                [ 4.6368e-02,  2.1609e-02,  6.4193e-02,  6.0536e-02, -6.3037e-02],
                [ 5.3429e-02, -2.5135e-02, -2.1193e-02, -5.1761e-04,  1.9996e-02],
                [ 4.3215e-02, -2.7851e-02, -6.4092e-02,  6.4469e-03,  4.8372e-02]],

               [[ 5.6178e-02,  3.5691e-02, -8.8206e-03, -3.0902e-02,  4.1369e-02],
                [ 3.2037e-03, -7.3717e-02, -7.2683e-02,  3.9951e-02,  2.7550e-02],
                [ 7.8738e-02, -7.8033e-02, -5.3960e-02, -3.4911e-02,  6.0757e-02],
                [-1.3279e-02,  8.0203e-02,  5.5343e-02, -2.4808e-02,  7.3956e-02],
                [ 3.9374e-02,  7.8239e-02, -7.7757e-02,  3.6571e-02, -2.9753e-02]],

               [[ 1.4708e-02, -7.2628e-02,  6.9122e-02, -4.9947e-02, -2.0280e-02],
                [-4.6593e-02, -7.6654e-02,  3.2037e-02, -1.1588e-02, -4.1092e-03],
                [ 6.4820e-02,  1.2277e-02, -5.5943e-02,  3.8722e-02, -2.4028e-02],
                [-5.1028e-02,  6.5714e-02, -1.7858e-02,  3.3266e-04,  7.5303e-02],
                [-2.6313e-02,  2.8460e-02,  7.5677e-02,  6.7317e-02,  2.4816e-02]],

               [[ 7.5268e-02, -1.0377e-02,  8.7081e-03,  2.6532e-02, -3.6530e-02],
                [ 3.2920e-02, -5.4807e-03,  5.5366e-02,  6.0630e-02,  7.4726e-02],
                [ 6.9652e-03, -4.8734e-02, -5.6073e-02, -6.8096e-02, -3.3923e-02],
                [ 3.3389e-03,  5.3901e-02,  7.1177e-02,  5.2428e-02, -6.1103e-03],
                [ 6.7079e-02, -2.1961e-02,  2.5835e-02, -7.2561e-02, -1.7097e-02]]],


              [[[-1.4495e-02, -5.7205e-02, -6.9848e-02,  2.6035e-02, -6.8173e-02],
                [ 7.7289e-03,  1.0050e-02,  2.6973e-02, -3.7391e-02,  4.0037e-02],
                [-8.8926e-03,  2.8563e-02, -5.2584e-02,  7.0450e-02,  3.2828e-03],
                [ 4.1858e-02, -5.4062e-03,  8.6848e-04, -2.5418e-02, -1.4656e-02],
                [ 3.8785e-03, -1.3989e-02,  6.0276e-02, -5.7459e-02, -8.8772e-03]],

               [[-2.5420e-02, -4.8067e-02,  4.2092e-02, -1.0464e-04,  2.9935e-03],
                [-6.1568e-02, -2.6599e-02,  1.3947e-02,  6.6099e-02,  2.1827e-02],
                [-3.0738e-02,  2.9112e-02, -3.2363e-02,  5.0961e-02, -5.7377e-02],
                [-7.7591e-02, -1.8587e-03,  4.1255e-02, -7.4086e-02,  2.4900e-02],
                [-5.8474e-02, -6.6314e-02,  3.9553e-02,  6.3980e-02,  3.6565e-02]],

               [[ 1.8101e-02,  4.8489e-03, -7.8886e-02, -5.7999e-02,  7.0357e-02],
                [-4.3515e-02, -4.7323e-02, -5.6187e-02,  2.3487e-02, -4.3308e-02],
                [-4.8880e-02, -5.8023e-02, -7.3924e-02, -9.6583e-03,  6.2028e-02],
                [ 5.2484e-02, -2.6062e-02, -4.3238e-03, -7.2771e-02,  3.2814e-02],
                [ 5.0127e-02,  3.3270e-02,  6.3871e-02, -3.8169e-02,  7.1588e-02]],

               [[-6.5380e-02, -7.1597e-02,  8.0692e-02, -5.3291e-02,  4.4471e-02],
                [ 7.1552e-02,  4.8723e-02, -6.1756e-02, -4.6115e-02,  7.7099e-02],
                [-1.9059e-02,  1.1879e-03,  1.7027e-02, -4.2909e-02,  5.6986e-02],
                [-8.7295e-03, -4.8808e-02, -5.7670e-02,  2.4997e-04, -1.1744e-02],
                [-1.3926e-02,  3.1644e-02,  1.1507e-02,  8.1024e-02,  7.2102e-03]],

               [[ 2.1168e-02, -2.9686e-02, -4.1679e-02, -4.7661e-02,  3.7415e-02],
                [ 2.2248e-02, -2.6836e-02,  3.3736e-02,  7.9188e-02, -5.4294e-02],
                [-6.3778e-02, -2.1948e-02, -3.8511e-02,  3.4796e-02,  6.7318e-02],
                [ 1.9141e-03, -1.8587e-02,  7.3516e-02, -5.1300e-02,  5.7174e-02],
                [-3.4286e-02, -8.8658e-03,  2.8017e-02, -6.1345e-02,  7.5492e-02]],

               [[-3.9274e-03,  5.5152e-02, -7.5673e-02,  2.3418e-02,  2.4453e-02],
                [-4.7555e-02,  3.8749e-02,  7.6313e-02,  1.7370e-02, -3.5157e-02],
                [ 2.5302e-03, -2.2002e-02, -3.2479e-02, -2.4231e-02, -3.4979e-02],
                [ 5.0821e-02,  4.9669e-02, -3.9780e-03, -5.6009e-02, -7.5135e-02],
                [-1.4819e-02, -3.7873e-02,  8.0048e-02,  6.6795e-02,  1.9579e-04]]],


              [[[-6.7337e-03, -3.1325e-02, -3.8130e-04,  2.8326e-02,  5.0600e-02],
                [ 2.1436e-02, -4.8592e-02,  6.1161e-02, -6.3545e-02, -6.6114e-02],
                [ 2.6338e-03,  5.4122e-02, -4.5045e-02,  1.2242e-02,  8.8401e-03],
                [-7.7559e-02, -1.9867e-02,  6.5748e-03,  4.9552e-02, -6.9732e-02],
                [-3.9603e-02, -2.5227e-02, -7.4894e-02, -4.6245e-02,  3.7202e-02]],

               [[ 7.4730e-02,  4.4270e-02, -1.6544e-02, -1.4239e-02,  5.9522e-03],
                [ 2.7072e-03, -5.8596e-02,  4.8411e-02,  4.3153e-02, -4.5938e-02],
                [ 7.8193e-02, -6.7088e-02, -6.9651e-02, -6.0037e-02,  2.6229e-02],
                [-7.0300e-03, -3.4991e-02, -7.6215e-02,  9.9203e-03,  6.9492e-02],
                [ 4.9921e-02, -2.0404e-02,  6.9527e-02,  3.0870e-02,  3.4420e-02]],

               [[-1.8141e-02,  4.8283e-02,  1.9330e-02, -4.4328e-02, -1.4779e-02],
                [-5.6117e-02,  7.9438e-02,  6.8914e-03,  2.9340e-02,  5.3000e-02],
                [ 5.4209e-02, -2.0673e-02,  4.3754e-02, -3.3216e-02, -1.8343e-02],
                [ 2.5629e-02,  3.6082e-02,  7.0708e-02,  7.3608e-02,  5.9628e-02],
                [ 6.0885e-02,  1.6643e-02,  1.6415e-02, -1.0011e-02,  7.8816e-02]],

               [[ 7.5320e-02, -5.7523e-02, -4.1853e-02,  1.0916e-02, -3.0991e-02],
                [-6.2630e-02,  1.4596e-02, -1.0427e-02,  3.6875e-02, -1.6136e-03],
                [-4.4583e-02, -3.0317e-02, -2.2016e-02, -6.6638e-02,  6.7848e-02],
                [-5.1430e-02, -3.2587e-02, -1.1150e-03,  3.4666e-03, -3.3550e-02],
                [ 3.4297e-02, -7.6208e-02,  7.9594e-02,  2.8101e-02,  5.8770e-02]],

               [[-7.3166e-02,  6.3813e-02,  5.1557e-02, -8.4318e-03,  6.9600e-02],
                [ 7.9715e-02,  1.2783e-03, -5.8935e-02, -4.6001e-02,  2.3448e-04],
                [-2.1650e-02,  7.1989e-02,  5.8779e-02,  3.0808e-02, -5.0626e-02],
                [-7.0207e-02,  6.4429e-02,  3.2764e-02,  4.4986e-02, -5.6941e-02],
                [-3.9577e-02,  2.8508e-03,  1.5647e-02, -5.7797e-02,  3.3754e-02]],

               [[-3.6933e-02,  6.9224e-02,  6.7209e-02,  1.0161e-02,  2.9785e-02],
                [ 1.1000e-02,  5.7507e-02,  7.7336e-02, -2.3910e-02,  1.2587e-02],
                [-3.8743e-02,  9.7108e-03,  4.9643e-02,  3.2226e-02, -6.9066e-02],
                [ 1.7411e-02, -5.1872e-02, -4.3662e-02, -2.2543e-02,  4.4947e-02],
                [ 5.7543e-02, -2.1366e-02, -2.5460e-02,  4.8669e-02,  4.8413e-02]]]],
             requires_grad=True)
       tensor: tensor([[[[ 0.0030,  0.0370,  0.0061,  0.0055,  0.1897],
                [ 0.0061,  0.0060, -0.0193,  0.0377, -0.0564],
                [-0.0500, -0.0086,  0.0329,  0.0670,  0.0101],
                [-0.0847,  0.0830,  0.0771, -0.0188,  0.0843],
                [-0.0097, -0.1136, -0.0237,  0.0625,  0.0676]],

               [[-0.0559, -0.0980, -0.0725, -0.0112, -0.0243],
                [ 0.0962, -0.0593, -0.1873,  0.0279,  0.0782],
                [-0.0591, -0.0855,  0.0373, -0.1575, -0.0922],
                [ 0.0329,  0.0644, -0.0940, -0.0186, -0.0622],
                [-0.0482,  0.0230,  0.0336, -0.0282, -0.1814]],

               [[-0.0127, -0.1343, -0.1221, -0.0167, -0.0065],
                [ 0.0476,  0.0722, -0.0504,  0.0020,  0.0555],
                [-0.0666,  0.1025, -0.0479,  0.0443,  0.0820],
                [ 0.0718,  0.0899, -0.0440, -0.0267, -0.0442],
                [-0.0810,  0.0078,  0.0804,  0.0138,  0.0176]],

               [[ 0.0109,  0.0394, -0.0121,  0.0404, -0.0129],
                [ 0.0802, -0.0914, -0.0956,  0.0313,  0.0754],
                [-0.0038, -0.0588, -0.0027, -0.0544, -0.0149],
                [-0.0293,  0.0581,  0.0832,  0.1311,  0.0171],
                [ 0.1106,  0.0351,  0.0081, -0.0740,  0.0295]],

               [[ 0.0067,  0.0062, -0.0142,  0.0262, -0.0578],
                [-0.0709,  0.0171, -0.0674, -0.0588,  0.0029],
                [-0.0295, -0.0527, -0.0683, -0.1257, -0.0551],
                [ 0.0356,  0.0503, -0.0011,  0.0599, -0.0024],
                [ 0.1453,  0.0082, -0.0518, -0.0353,  0.0338]],

               [[ 0.0246,  0.0324, -0.1101,  0.0527, -0.1270],
                [-0.1302, -0.0539,  0.0103,  0.0226,  0.0360],
                [ 0.0771,  0.0761, -0.0768, -0.0466,  0.0357],
                [-0.0788, -0.0133,  0.1923,  0.1031, -0.0516],
                [-0.0493,  0.0171, -0.0626, -0.0484,  0.1355]]],


              [[[ 0.0296, -0.0681, -0.0929,  0.0268,  0.0190],
                [-0.0455, -0.0379,  0.0346, -0.0537, -0.0268],
                [-0.0023, -0.0677, -0.0214, -0.0291, -0.1081],
                [-0.0020, -0.0793,  0.1388, -0.0793,  0.1056],
                [-0.0718,  0.1434, -0.1299, -0.0603, -0.0080]],

               [[-0.1244,  0.0348,  0.0574, -0.0776,  0.0472],
                [ 0.1165,  0.0356,  0.0074,  0.0251, -0.1098],
                [ 0.0321, -0.0125,  0.0018, -0.1240,  0.0353],
                [ 0.0994, -0.0156,  0.0453, -0.0251, -0.0393],
                [ 0.0301,  0.1138,  0.1745,  0.1075,  0.1367]],

               [[-0.0349, -0.0200, -0.0847,  0.0510,  0.0069],
                [ 0.0411,  0.0056,  0.0714,  0.1483,  0.0639],
                [-0.1383,  0.0720, -0.0054,  0.0541, -0.0916],
                [-0.0191, -0.0420, -0.0035,  0.0365, -0.0419],
                [-0.0620, -0.0209,  0.0207,  0.0278, -0.0520]],

               [[-0.0598, -0.0202,  0.0199,  0.0228,  0.0366],
                [ 0.0405,  0.0057,  0.0308,  0.0057, -0.0100],
                [-0.0716, -0.0653, -0.0008, -0.0686, -0.0948],
                [-0.0877, -0.0224, -0.0703, -0.0918,  0.0621],
                [-0.1118,  0.0707, -0.0052,  0.0841, -0.0137]],

               [[ 0.0930,  0.0124, -0.0211,  0.0541, -0.1070],
                [ 0.0211,  0.0281, -0.0532,  0.0517, -0.0016],
                [ 0.0772,  0.0063, -0.0234,  0.0156, -0.0652],
                [-0.0644, -0.0647, -0.0120,  0.1149, -0.0564],
                [ 0.0767, -0.0070, -0.0011, -0.1237, -0.0066]],

               [[ 0.1717, -0.0064,  0.0110,  0.0309,  0.0137],
                [-0.0550,  0.0663,  0.0573,  0.1322, -0.0082],
                [ 0.0342, -0.0386, -0.0328, -0.0154,  0.0448],
                [ 0.0547, -0.0263, -0.1192,  0.0387, -0.0164],
                [ 0.0607, -0.0326,  0.1569,  0.0261, -0.0246]]],


              [[[ 0.0702,  0.0203, -0.0141, -0.0446,  0.0751],
                [-0.0347,  0.1131, -0.1338, -0.0294,  0.0571],
                [-0.0677,  0.1090, -0.0009, -0.0625, -0.1291],
                [ 0.0354,  0.0200, -0.0549, -0.0156, -0.0812],
                [ 0.0049, -0.0936,  0.0989,  0.1016,  0.0297]],

               [[ 0.0672, -0.0527, -0.0339, -0.0094,  0.0326],
                [ 0.0930, -0.1736,  0.0808, -0.0811,  0.0135],
                [-0.0180, -0.0720, -0.1406, -0.0214, -0.0106],
                [-0.0289, -0.0049, -0.0305,  0.0357,  0.1551],
                [-0.0046,  0.0355, -0.0194,  0.0280, -0.0185]],

               [[-0.0964,  0.0292,  0.0972, -0.0208,  0.0972],
                [-0.1380, -0.0010,  0.0390,  0.0697, -0.1097],
                [-0.0065, -0.0035, -0.0658,  0.0803, -0.1039],
                [ 0.0459, -0.0645, -0.0060, -0.0435, -0.0126],
                [-0.0245,  0.0109, -0.0870, -0.0560, -0.0631]],

               [[-0.0394, -0.0649, -0.0273,  0.0062,  0.0212],
                [-0.0515,  0.0840, -0.0245,  0.0682,  0.0838],
                [ 0.0320, -0.0481,  0.0371, -0.0384, -0.1137],
                [ 0.0538, -0.0604,  0.0587, -0.0396,  0.0582],
                [ 0.0630, -0.0448,  0.0376,  0.0443, -0.0243]],

               [[ 0.0887, -0.0682, -0.0835,  0.0265,  0.0202],
                [-0.0194, -0.0398, -0.0036, -0.0129,  0.0483],
                [ 0.0099,  0.0179, -0.0808,  0.0102, -0.0288],
                [ 0.0867, -0.0921, -0.1018, -0.0525, -0.0209],
                [-0.0754,  0.0110, -0.0890,  0.1195, -0.0101]],

               [[ 0.0732,  0.0196, -0.0186, -0.0783,  0.0396],
                [ 0.0464, -0.0444, -0.1216,  0.0254, -0.0550],
                [ 0.0327,  0.0962, -0.0902,  0.0010, -0.0489],
                [-0.0611,  0.0270,  0.0678,  0.0009,  0.0287],
                [-0.1257, -0.0570,  0.0649, -0.0547, -0.0710]]],


              ...,


              [[[-0.0368, -0.1070, -0.0377, -0.0151, -0.0040],
                [-0.0080, -0.0498, -0.0739,  0.0521, -0.0047],
                [-0.1140,  0.1242, -0.0421,  0.0485,  0.1163],
                [ 0.0760,  0.1164,  0.0229, -0.0359,  0.0471],
                [-0.0256, -0.0217,  0.0007, -0.0934, -0.0431]],

               [[-0.0146,  0.1180, -0.0894,  0.0662, -0.1869],
                [ 0.0360, -0.0623, -0.0055,  0.0395,  0.0304],
                [-0.0796,  0.0298,  0.1031,  0.1473,  0.0148],
                [ 0.0487, -0.0528,  0.0374, -0.0222, -0.0536],
                [ 0.0577,  0.0646,  0.0502, -0.0029,  0.0426]],

               [[-0.0677, -0.0192,  0.0540,  0.0301,  0.0904],
                [-0.0196,  0.1227,  0.1267,  0.0565,  0.0055],
                [-0.0280, -0.0297,  0.0549, -0.0250, -0.1382],
                [-0.1008,  0.0365, -0.0657,  0.0342,  0.0065],
                [ 0.0315,  0.0172, -0.0969,  0.0013,  0.0104]],

               [[ 0.0558,  0.0658, -0.0555, -0.0110,  0.0538],
                [ 0.0588, -0.1003, -0.1523, -0.0112,  0.0897],
                [ 0.0105,  0.0035, -0.0389,  0.0006,  0.1152],
                [ 0.0710,  0.1802,  0.0374, -0.0865,  0.0942],
                [ 0.0541,  0.0844, -0.0715,  0.0436, -0.0691]],

               [[ 0.0120, -0.0954,  0.0155,  0.0152, -0.1001],
                [-0.0493, -0.0088,  0.0529, -0.0759, -0.0161],
                [ 0.0656,  0.0814,  0.0555,  0.0120, -0.0367],
                [-0.0500,  0.1128,  0.0131, -0.0284,  0.0928],
                [ 0.0287,  0.0125,  0.0348,  0.0482, -0.0143]],

               [[ 0.0911,  0.0056, -0.0845,  0.0777, -0.0683],
                [ 0.0142, -0.0229,  0.0458,  0.0650,  0.0809],
                [ 0.0135, -0.0528, -0.1453, -0.0548, -0.0564],
                [ 0.0639,  0.0091,  0.0875,  0.0836, -0.0286],
                [ 0.0162, -0.0257, -0.0204, -0.1649, -0.0647]]],


              [[[-0.0013, -0.0171, -0.1675, -0.0056, -0.0280],
                [ 0.0808,  0.0189,  0.1015, -0.0198,  0.0340],
                [-0.1070, -0.0596, -0.0489,  0.1452, -0.0884],
                [-0.0348, -0.0879, -0.0137,  0.0157,  0.0132],
                [ 0.0066,  0.0364,  0.1176, -0.0999,  0.0079]],

               [[ 0.0413, -0.0700,  0.0778,  0.0368,  0.0522],
                [-0.0291, -0.0256,  0.0869,  0.0585,  0.0262],
                [-0.1031, -0.0223, -0.0122,  0.0745,  0.0264],
                [-0.1007,  0.0369, -0.0398, -0.0670,  0.0970],
                [-0.1686, -0.0153, -0.0234,  0.0754,  0.0909]],

               [[ 0.0407,  0.0859, -0.1172, -0.0371,  0.0481],
                [-0.0890,  0.0338, -0.0192,  0.0476, -0.0217],
                [ 0.0182, -0.0838, -0.0416, -0.0070, -0.0653],
                [-0.0378, -0.0439,  0.0132, -0.0796, -0.0450],
                [ 0.0619, -0.0380,  0.0575, -0.0631,  0.1044]],

               [[ 0.0278, -0.1035,  0.0012, -0.0706,  0.0558],
                [ 0.0631,  0.1811, -0.0200,  0.0622,  0.0970],
                [ 0.0673, -0.0611,  0.0469,  0.0018,  0.0879],
                [-0.0202, -0.0397, -0.0025,  0.0565, -0.0940],
                [-0.0053,  0.0735,  0.0329,  0.0304,  0.0019]],

               [[-0.0521, -0.0473, -0.0060, -0.0507,  0.0482],
                [-0.0048, -0.0338, -0.0170,  0.1147, -0.0254],
                [-0.0622, -0.0311, -0.0690,  0.0485,  0.0562],
                [ 0.0360,  0.0941,  0.1680, -0.0755,  0.0932],
                [-0.0503, -0.0668, -0.0178, -0.1147,  0.0155]],

               [[-0.0955,  0.0673, -0.1261, -0.0133,  0.0173],
                [-0.0105,  0.0789,  0.0800,  0.0514, -0.0300],
                [ 0.0112,  0.0047,  0.0450, -0.0191, -0.1006],
                [ 0.1001,  0.0414,  0.0202,  0.0338, -0.0103],
                [-0.1475, -0.1024,  0.0735, -0.0058,  0.0517]]],


              [[[-0.0090, -0.0069, -0.0367, -0.0155,  0.0832],
                [-0.0323,  0.0123,  0.0614,  0.0114, -0.0722],
                [-0.0454, -0.0036, -0.0285,  0.1304,  0.0580],
                [-0.0847,  0.0040,  0.0986,  0.0794, -0.0332],
                [-0.0672, -0.0382, -0.0653, -0.0962,  0.1406]],

               [[ 0.1578,  0.0159, -0.0342,  0.0398, -0.0326],
                [ 0.0567, -0.1326,  0.0710,  0.0133, -0.0371],
                [ 0.0955, -0.1701, -0.1302, -0.0422, -0.0215],
                [-0.0294, -0.0221, -0.1544,  0.0118,  0.1777],
                [ 0.0897, -0.1050,  0.0651,  0.0025, -0.0192]],

               [[-0.0447,  0.0564,  0.0178, -0.0194, -0.0053],
                [-0.0803,  0.1073, -0.0077,  0.0310,  0.1199],
                [ 0.0350,  0.2188,  0.0568, -0.1098, -0.1173],
                [ 0.0527,  0.1389,  0.0367,  0.1086,  0.0371],
                [ 0.0656,  0.0234, -0.0382, -0.0206,  0.2017]],

               [[ 0.0579, -0.0759, -0.0199,  0.0364, -0.0006],
                [-0.0441, -0.0335,  0.0095,  0.0706,  0.0474],
                [-0.0435, -0.0494, -0.0207,  0.0210,  0.0661],
                [-0.1438, -0.0200,  0.0809, -0.0481, -0.0648],
                [ 0.0745, -0.0318,  0.0249,  0.0515,  0.0518]],

               [[-0.0142,  0.0744, -0.0306, -0.0203,  0.1105],
                [ 0.0364, -0.0508, -0.0803, -0.0077,  0.0171],
                [-0.0659,  0.1150,  0.0811, -0.0835, -0.0413],
                [-0.0712,  0.0401, -0.0108,  0.0200, -0.0529],
                [-0.0456, -0.0428,  0.0376, -0.0927, -0.0055]],

               [[ 0.1133,  0.0363,  0.0195,  0.0877,  0.0687],
                [-0.0174,  0.0007,  0.1192, -0.0410, -0.0328],
                [-0.1145, -0.0187,  0.0342,  0.0645, -0.0715],
                [-0.0164, -0.0221, -0.0441,  0.0314, -0.0280],
                [ 0.0618, -0.0397, -0.0189,  0.0394,  0.0238]]]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([ 0.0396, -0.0348,  0.0458,  0.0297, -0.0694,  0.0213, -0.0060, -0.0400,
              -0.0470,  0.0165,  0.0250,  0.0796,  0.0477,  0.0254,  0.0719,  0.0583],
             requires_grad=True)
       tensor: tensor([ 0.1063, -0.0750, -0.0130,  0.1064, -0.0199,  0.0071, -0.0617, -0.0020,
              -0.0503, -0.0594,  0.0255,  0.0815, -0.0090,  0.1249,  0.0318,  0.1622],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[[[0., -0., 0., -0., 0.],
                [0., 0., 0., -0., -0.],
                [0., 0., -0., 0., 0.],
                [-0., 0., 0., -0., 0.],
                [-0., -0., -0., 0., 0.]],

               [[-0., -0., 0., -0., -0.],
                [-0., -0., -0., 0., 0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., -0., -0., -0.],
                [0., 0., -0., 0., -0.]],

               [[-0., -0., -0., -0., -0.],
                [0., 0., -0., -0., 0.],
                [-0., 0., -0., -0., 0.],
                [-0., 0., -0., -0., -0.],
                [-0., 0., -0., 0., -0.]],

               [[0., 0., -0., -0., 0.],
                [0., -0., -0., -0., 0.],
                [-0., -0., 0., -0., 0.],
                [0., 0., 0., -0., 0.],
                [0., -0., 0., -0., 0.]],

               [[0., -0., -0., -0., -0.],
                [-0., -0., -0., 0., -0.],
                [-0., -0., -0., -0., -0.],
                [-0., 0., -0., 0., 0.],
                [0., 0., 0., -0., 0.]],

               [[-0., 0., -0., 0., -0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., -0., -0., 0.],
                [-0., 0., 0., -0., -0.],
                [0., -0., -0., -0., 0.]]],


              [[[-0., -0., -0., 0., -0.],
                [-0., 0., -0., -0., -0.],
                [0., 0., 0., 0., -0.],
                [0., -0., 0., -0., 0.],
                [-0., 0., -0., -0., 0.]],

               [[-0., 0., 0., -0., -0.],
                [0., 0., 0., -0., -0.],
                [-0., 0., 0., -0., -0.],
                [0., -0., -0., 0., -0.],
                [-0., 0., 0., 0., 0.]],

               [[-0., 0., -0., 0., -0.],
                [0., -0., 0., 0., 0.],
                [-0., 0., -0., -0., -0.],
                [0., -0., -0., 0., -0.],
                [-0., -0., -0., 0., -0.]],

               [[-0., 0., 0., 0., -0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., 0., -0., 0.],
                [-0., -0., -0., -0., 0.],
                [-0., 0., 0., 0., 0.]],

               [[0., -0., 0., 0., -0.],
                [0., 0., -0., 0., 0.],
                [0., -0., -0., -0., -0.],
                [-0., -0., 0., 0., -0.],
                [-0., 0., -0., -0., -0.]],

               [[0., -0., 0., 0., -0.],
                [-0., 0., 0., 0., -0.],
                [0., -0., 0., -0., 0.],
                [0., -0., -0., 0., -0.],
                [0., -0., -0., 0., -0.]]],


              [[[-0., 0., 0., -0., 0.],
                [-0., 0., -0., -0., 0.],
                [-0., 0., 0., -0., -0.],
                [0., 0., 0., 0., -0.],
                [0., -0., 0., 0., 0.]],

               [[0., 0., -0., 0., 0.],
                [-0., -0., -0., -0., 0.],
                [-0., -0., -0., 0., -0.],
                [0., 0., 0., 0., 0.],
                [0., 0., -0., -0., -0.]],

               [[-0., 0., 0., -0., 0.],
                [-0., -0., 0., 0., -0.],
                [-0., 0., -0., 0., -0.],
                [0., -0., -0., -0., -0.],
                [-0., 0., -0., -0., -0.]],

               [[-0., -0., -0., 0., -0.],
                [-0., 0., 0., 0., 0.],
                [0., -0., 0., -0., -0.],
                [0., 0., 0., 0., 0.],
                [0., -0., 0., -0., -0.]],

               [[0., -0., -0., -0., 0.],
                [0., -0., -0., -0., 0.],
                [-0., 0., 0., 0., 0.],
                [0., -0., 0., -0., -0.],
                [-0., -0., -0., 0., 0.]],

               [[0., 0., -0., -0., 0.],
                [0., -0., -0., 0., -0.],
                [-0., 0., -0., 0., -0.],
                [-0., 0., 0., -0., -0.],
                [-0., 0., 0., -0., -0.]]],


              ...,


              [[[0., -0., 0., -0., -0.],
                [-0., 0., 0., 0., 0.],
                [-0., 0., -0., -0., 0.],
                [0., 0., 0., -0., -0.],
                [-0., 0., -0., -0., 0.]],

               [[-0., 0., -0., 0., -0.],
                [-0., -0., 0., -0., 0.],
                [-0., -0., 0., 0., 0.],
                [0., -0., 0., -0., -0.],
                [0., 0., -0., -0., 0.]],

               [[-0., 0., 0., 0., 0.],
                [-0., 0., 0., 0., -0.],
                [0., 0., 0., 0., -0.],
                [0., -0., -0., -0., 0.],
                [0., -0., -0., 0., 0.]],

               [[0., 0., -0., -0., 0.],
                [0., -0., -0., 0., 0.],
                [0., -0., -0., -0., 0.],
                [-0., 0., 0., -0., 0.],
                [0., 0., -0., 0., -0.]],

               [[0., -0., 0., -0., -0.],
                [-0., -0., 0., -0., -0.],
                [0., 0., -0., 0., -0.],
                [-0., 0., -0., 0., 0.],
                [-0., 0., 0., 0., 0.]],

               [[0., -0., 0., 0., -0.],
                [0., -0., 0., 0., 0.],
                [0., -0., -0., -0., -0.],
                [0., 0., 0., 0., -0.],
                [0., -0., 0., -0., -0.]]],


              [[[-0., -0., -0., 0., -0.],
                [0., 0., 0., -0., 0.],
                [-0., 0., -0., 0., 0.],
                [0., -0., 0., -0., -0.],
                [0., -0., 0., -0., -0.]],

               [[-0., -0., 0., -0., 0.],
                [-0., -0., 0., 0., 0.],
                [-0., 0., -0., 0., -0.],
                [-0., -0., 0., -0., 0.],
                [-0., -0., 0., 0., 0.]],

               [[0., 0., -0., -0., 0.],
                [-0., -0., -0., 0., -0.],
                [-0., -0., -0., -0., 0.],
                [0., -0., -0., -0., 0.],
                [0., 0., 0., -0., 0.]],

               [[-0., -0., 0., -0., 0.],
                [0., 0., -0., -0., 0.],
                [-0., 0., 0., -0., 0.],
                [-0., -0., -0., 0., -0.],
                [-0., 0., 0., 0., 0.]],

               [[0., -0., -0., -0., 0.],
                [0., -0., 0., 0., -0.],
                [-0., -0., -0., 0., 0.],
                [0., -0., 0., -0., 0.],
                [-0., -0., 0., -0., 0.]],

               [[-0., 0., -0., 0., 0.],
                [-0., 0., 0., 0., -0.],
                [0., -0., -0., -0., -0.],
                [0., 0., -0., -0., -0.],
                [-0., -0., 0., 0., 0.]]],


              [[[-0., -0., -0., 0., 0.],
                [0., -0., 0., -0., -0.],
                [0., 0., -0., 0., 0.],
                [-0., -0., 0., 0., -0.],
                [-0., -0., -0., -0., 0.]],

               [[0., 0., -0., -0., 0.],
                [0., -0., 0., 0., -0.],
                [0., -0., -0., -0., 0.],
                [-0., -0., -0., 0., 0.],
                [0., -0., 0., 0., 0.]],

               [[-0., 0., 0., -0., -0.],
                [-0., 0., 0., 0., 0.],
                [0., -0., 0., -0., -0.],
                [0., 0., 0., 0., 0.],
                [0., 0., 0., -0., 0.]],

               [[0., -0., -0., 0., -0.],
                [-0., 0., -0., 0., -0.],
                [-0., -0., -0., -0., 0.],
                [-0., -0., -0., 0., -0.],
                [0., -0., 0., 0., 0.]],

               [[-0., 0., 0., -0., 0.],
                [0., 0., -0., -0., 0.],
                [-0., 0., 0., 0., -0.],
                [-0., 0., 0., 0., -0.],
                [-0., 0., 0., -0., 0.]],

               [[-0., 0., 0., 0., 0.],
                [0., 0., 0., -0., 0.],
                [-0., 0., 0., 0., -0.],
                [0., -0., -0., -0., 0.],
                [0., -0., -0., 0., 0.]]]])
       scale: tensor([[[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              ...,


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]],


              [[[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]],

               [[0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816],
                [0.0816, 0.0816, 0.0816, 0.0816, 0.0816]]]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[[[ 2.4966e-03, -6.7341e-02,  6.2799e-02, -3.4370e-03,  4.5322e-02],
                [ 1.4607e-02,  6.2478e-02,  6.8280e-02, -3.5654e-02, -1.8825e-02],
                [ 7.1212e-02,  1.5050e-02, -1.8025e-02,  2.5491e-02,  2.8724e-02],
                [-2.2547e-02,  6.0692e-02,  5.2667e-02, -5.7841e-03,  2.6376e-02],
                [-4.3343e-02, -5.9860e-02, -1.2973e-02,  6.2583e-03,  8.0137e-02]],

               [[-2.1569e-02, -7.1375e-02,  1.2631e-02, -1.8075e-02, -4.7912e-02],
                [-2.8027e-02, -2.6252e-02, -5.9709e-02,  6.7074e-02,  2.6864e-02],
                [-6.6509e-02, -8.1214e-02,  2.3467e-02, -6.6464e-02, -6.0669e-02],
                [ 5.0579e-03,  1.1671e-02, -6.6733e-02, -4.9352e-02, -2.0956e-02],
                [ 7.6966e-03,  2.2004e-02, -2.8989e-02,  5.1644e-02, -6.1621e-02]],

               [[-5.4252e-02, -8.1522e-02, -6.2672e-02, -6.3098e-02, -5.2104e-02],
                [ 2.5544e-02,  3.6878e-02, -4.4395e-02, -4.9102e-02,  4.8144e-02],
                [-3.2576e-02,  8.0204e-02, -9.2122e-03, -7.7311e-03,  6.7384e-02],
                [-1.5682e-02,  9.5757e-03, -6.5708e-02, -2.7689e-02, -6.5013e-02],
                [-6.4995e-02,  2.7979e-02, -1.6400e-03,  4.6764e-02, -4.3071e-02]],

               [[ 5.6996e-02,  5.9034e-03, -4.6686e-02, -6.9235e-02,  1.9127e-03],
                [ 1.8545e-02, -1.9516e-02, -5.0705e-02, -3.5648e-02,  6.0718e-02],
                [-6.7171e-02, -7.3959e-02,  5.4578e-02, -1.1379e-02,  4.7259e-02],
                [ 1.8278e-02,  1.0828e-02,  2.9502e-02, -4.8256e-02,  1.5783e-02],
                [ 3.0306e-02, -5.6080e-03,  1.1095e-02, -8.0521e-02,  6.7744e-02]],

               [[ 6.4667e-03, -3.1020e-02, -5.5962e-02, -4.0058e-02, -3.4994e-02],
                [-4.6501e-02, -5.0324e-02, -7.3864e-02,  1.8048e-02, -5.4226e-02],
                [-7.8412e-02, -4.8003e-02, -1.1779e-02, -7.8603e-02, -3.8339e-02],
                [-2.3014e-02,  5.4240e-02, -4.7970e-02,  2.5364e-02,  3.3201e-02],
                [ 6.0505e-02,  6.4193e-02,  7.4750e-04, -2.1983e-02,  7.6744e-02]],

               [[-6.7106e-02,  3.3650e-02, -7.8031e-02,  6.1907e-02, -7.8580e-02],
                [-3.2164e-02, -4.9644e-02,  5.5871e-02, -1.9880e-02, -5.9899e-02],
                [ 7.6375e-02,  1.6907e-02, -3.7618e-02, -4.6346e-02,  4.5166e-02],
                [-6.2424e-02,  1.6883e-02,  3.9223e-02, -1.4790e-02, -6.1880e-02],
                [ 3.8489e-02, -7.1589e-05, -1.7487e-02, -6.2649e-02,  6.8018e-02]]],


              [[[-2.5034e-02, -4.6058e-02, -2.3337e-02,  2.2215e-02, -1.4850e-02],
                [-4.3984e-02,  1.4747e-02, -7.9597e-03, -6.3069e-02, -2.2592e-02],
                [ 4.3075e-02,  1.9039e-02,  3.0107e-02,  2.3445e-03, -6.9232e-02],
                [ 6.4497e-03, -4.7578e-02,  7.6891e-02, -5.1616e-02,  1.2707e-02],
                [-3.6706e-02,  7.6937e-02, -5.9667e-02, -8.1446e-02,  1.7545e-02]],

               [[-4.4402e-02,  6.4599e-02,  4.1490e-02, -2.8644e-02, -4.8206e-02],
                [ 2.6031e-02,  5.7371e-02,  2.0308e-02, -7.4061e-02, -4.4877e-02],
                [-2.7835e-02,  5.2467e-02,  4.0908e-02, -8.1401e-02, -6.8110e-02],
                [ 7.6093e-02, -9.4644e-03, -5.2426e-03,  7.7898e-03, -1.6856e-02],
                [-7.1420e-02,  7.8023e-02,  8.0673e-02,  3.5987e-02,  7.1558e-02]],

               [[-2.2731e-02,  6.9536e-03, -3.5987e-02,  7.5870e-02, -7.4783e-02],
                [ 3.9431e-02, -3.5946e-02,  3.5435e-02,  3.6262e-02,  1.0886e-02],
                [-5.3538e-02,  7.1992e-02, -3.2372e-04, -1.5915e-02, -7.1567e-02],
                [ 2.4658e-02, -1.1047e-02, -5.5648e-02,  1.9212e-02, -1.2692e-02],
                [-8.4647e-03, -4.5983e-02, -6.7419e-04,  9.8715e-03, -4.6148e-02]],

               [[-2.0329e-02,  4.4603e-02,  6.4148e-02,  7.1504e-02, -4.5712e-02],
                [ 4.6664e-02,  7.8892e-02, -3.1212e-02, -8.2775e-03, -6.5268e-02],
                [-5.3630e-02, -3.4377e-02,  4.4915e-02, -3.5955e-02,  4.4886e-02],
                [-6.4155e-02, -1.4317e-02, -6.8112e-02, -2.0220e-02,  3.0197e-02],
                [-6.6971e-02,  7.1244e-02,  3.6331e-02,  4.9435e-02,  9.6099e-04]],

               [[ 7.0543e-02, -2.7817e-02,  6.0078e-02,  4.7984e-02, -4.3316e-02],
                [ 7.6346e-02,  1.2656e-02, -6.2175e-02,  7.6718e-02,  1.8006e-02],
                [ 7.4139e-02, -5.7349e-02, -1.0675e-02, -2.4720e-02, -4.4425e-02],
                [-7.0388e-02, -6.9265e-02,  2.1314e-02,  4.5109e-02, -7.1493e-02],
                [-3.2351e-02,  3.5873e-02, -4.8951e-03, -7.2958e-02, -4.4238e-02]],

               [[ 8.1172e-02, -5.8811e-02,  4.0991e-02,  6.9875e-02, -6.0488e-02],
                [-6.3614e-02,  2.7591e-02,  4.7980e-02,  8.1415e-02, -3.3863e-02],
                [ 2.1980e-02, -2.3815e-02,  1.8170e-02, -2.5447e-02,  2.7371e-02],
                [ 8.1534e-02, -3.9964e-02, -7.7124e-02,  5.2303e-02, -2.0802e-02],
                [ 6.0792e-02, -7.2300e-02, -2.8772e-02,  5.7216e-02, -3.4328e-02]]],


              [[[-2.7990e-03,  1.8808e-02,  6.7455e-04, -6.6091e-02,  6.6982e-02],
                [-6.9830e-02,  6.5788e-02, -4.8964e-02, -4.1140e-03,  3.1098e-03],
                [-5.4660e-02,  7.1862e-02,  7.1153e-03, -6.5570e-02, -5.1447e-02],
                [ 6.0571e-02,  2.0612e-03,  2.3722e-02,  1.5375e-02, -5.1639e-02],
                [ 4.7830e-02, -2.5379e-02,  6.9856e-02,  4.5925e-02,  9.6594e-03]],

               [[ 2.2661e-02,  6.6682e-03, -4.9346e-02,  2.0883e-03,  3.6717e-02],
                [-2.9341e-02, -7.5506e-02, -2.3876e-02, -4.3604e-02,  6.0555e-02],
                [-5.9189e-03, -9.2285e-03, -7.6425e-02,  5.5105e-02, -3.4522e-02],
                [ 4.5115e-02,  6.4247e-03,  4.4812e-02,  1.9783e-02,  7.9421e-02],
                [ 1.4393e-02,  6.9105e-02, -5.1842e-02, -5.6427e-04, -3.0378e-02]],

               [[-2.9724e-02,  2.7430e-02,  2.0606e-02, -6.8593e-02,  7.5719e-02],
                [-1.1574e-02, -6.1978e-02,  1.5036e-02,  6.5520e-02, -5.5500e-02],
                [-5.4038e-03,  1.6463e-03, -1.6045e-02,  5.0524e-02, -6.1660e-02],
                [ 7.9241e-02, -6.3365e-02, -7.6820e-02, -7.4036e-02, -3.8586e-02],
                [-2.4314e-02,  8.9899e-03, -7.1113e-03, -3.9968e-02, -5.5614e-02]],

               [[-2.1929e-02, -1.1871e-02, -4.8945e-02,  2.5946e-02, -6.4562e-03],
                [-3.3630e-02,  5.2410e-02,  3.1839e-02,  3.7700e-02,  4.5499e-02],
                [ 2.5421e-02, -6.7532e-02,  7.4772e-02, -4.7501e-03, -5.6375e-02],
                [ 3.3231e-03,  2.4121e-02,  8.1080e-02,  1.0045e-02,  7.1572e-02],
                [ 4.9069e-02, -3.3530e-02,  1.1358e-02, -5.2008e-04, -4.7222e-02]],

               [[ 4.6323e-02, -2.6136e-04, -5.2806e-02, -2.4812e-02,  4.8923e-02],
                [ 3.7666e-02, -5.7144e-02, -6.4298e-03, -1.7247e-02,  3.1968e-02],
                [-6.0489e-02,  2.8763e-02,  2.8486e-02,  3.5758e-02,  1.7142e-02],
                [ 8.0486e-02, -3.3536e-02,  3.6080e-02, -2.8437e-02, -1.9267e-03],
                [-6.2648e-02, -5.2696e-02, -6.9294e-02,  3.6589e-02,  1.8661e-02]],

               [[ 9.3734e-03,  1.9768e-02, -3.5296e-02, -2.7572e-02,  6.0065e-02],
                [ 4.1112e-02, -3.4800e-02, -7.9605e-02,  5.3310e-02, -1.9664e-02],
                [-1.3619e-02,  2.5598e-02, -2.6082e-02,  1.0369e-02, -6.4816e-02],
                [-5.6846e-03,  2.9120e-02,  3.2444e-02, -1.8672e-02, -6.1514e-03],
                [-4.0733e-02,  1.6849e-02,  3.8275e-02, -3.8825e-02, -5.8020e-02]]],


              ...,


              [[[ 5.9329e-02, -6.5870e-02,  1.5859e-02, -1.5105e-02, -8.8580e-03],
                [-5.0859e-02,  3.3974e-04,  3.0749e-02,  2.3093e-02,  4.2401e-02],
                [-5.4646e-02,  7.5729e-02, -5.2260e-02, -4.6122e-03,  5.1098e-02],
                [ 7.6842e-02,  1.2823e-02,  3.1960e-03, -3.9222e-02, -2.0589e-02],
                [-6.3892e-03,  4.3896e-02, -2.7891e-02, -1.8322e-02,  3.2296e-02]],

               [[-4.1220e-02,  8.0820e-02, -4.3842e-02,  7.9302e-02, -4.9983e-02],
                [-7.3771e-03, -2.5980e-03,  3.6210e-03, -2.1433e-02,  7.3244e-02],
                [-5.0760e-02, -2.5616e-03,  6.2678e-02,  7.3593e-02,  2.8301e-02],
                [ 6.7007e-02, -1.9064e-02,  8.3973e-03, -5.9798e-02, -2.6974e-02],
                [ 3.2876e-02,  3.8669e-02, -3.8719e-02, -1.0153e-02,  7.5301e-02]],

               [[-5.0442e-02,  7.7989e-03,  3.9936e-02,  3.3326e-02,  7.9093e-02],
                [-2.9757e-03,  5.9154e-02,  5.3348e-02,  7.3507e-02, -7.6732e-03],
                [ 4.6368e-02,  2.1609e-02,  6.4193e-02,  6.0536e-02, -6.3037e-02],
                [ 5.3429e-02, -2.5135e-02, -2.1193e-02, -5.1761e-04,  1.9996e-02],
                [ 4.3215e-02, -2.7851e-02, -6.4092e-02,  6.4469e-03,  4.8372e-02]],

               [[ 5.6178e-02,  3.5691e-02, -8.8206e-03, -3.0902e-02,  4.1369e-02],
                [ 3.2037e-03, -7.3717e-02, -7.2683e-02,  3.9951e-02,  2.7550e-02],
                [ 7.8738e-02, -7.8033e-02, -5.3960e-02, -3.4911e-02,  6.0757e-02],
                [-1.3279e-02,  8.0203e-02,  5.5343e-02, -2.4808e-02,  7.3956e-02],
                [ 3.9374e-02,  7.8239e-02, -7.7757e-02,  3.6571e-02, -2.9753e-02]],

               [[ 1.4708e-02, -7.2628e-02,  6.9122e-02, -4.9947e-02, -2.0280e-02],
                [-4.6593e-02, -7.6654e-02,  3.2037e-02, -1.1588e-02, -4.1092e-03],
                [ 6.4820e-02,  1.2277e-02, -5.5943e-02,  3.8722e-02, -2.4028e-02],
                [-5.1028e-02,  6.5714e-02, -1.7858e-02,  3.3266e-04,  7.5303e-02],
                [-2.6313e-02,  2.8460e-02,  7.5677e-02,  6.7317e-02,  2.4816e-02]],

               [[ 7.5268e-02, -1.0377e-02,  8.7081e-03,  2.6532e-02, -3.6530e-02],
                [ 3.2920e-02, -5.4807e-03,  5.5366e-02,  6.0630e-02,  7.4726e-02],
                [ 6.9652e-03, -4.8734e-02, -5.6073e-02, -6.8096e-02, -3.3923e-02],
                [ 3.3389e-03,  5.3901e-02,  7.1177e-02,  5.2428e-02, -6.1103e-03],
                [ 6.7079e-02, -2.1961e-02,  2.5835e-02, -7.2561e-02, -1.7097e-02]]],


              [[[-1.4495e-02, -5.7205e-02, -6.9848e-02,  2.6035e-02, -6.8173e-02],
                [ 7.7289e-03,  1.0050e-02,  2.6973e-02, -3.7391e-02,  4.0037e-02],
                [-8.8926e-03,  2.8563e-02, -5.2584e-02,  7.0450e-02,  3.2828e-03],
                [ 4.1858e-02, -5.4062e-03,  8.6848e-04, -2.5418e-02, -1.4656e-02],
                [ 3.8785e-03, -1.3989e-02,  6.0276e-02, -5.7459e-02, -8.8772e-03]],

               [[-2.5420e-02, -4.8067e-02,  4.2092e-02, -1.0464e-04,  2.9935e-03],
                [-6.1568e-02, -2.6599e-02,  1.3947e-02,  6.6099e-02,  2.1827e-02],
                [-3.0738e-02,  2.9112e-02, -3.2363e-02,  5.0961e-02, -5.7377e-02],
                [-7.7591e-02, -1.8587e-03,  4.1255e-02, -7.4086e-02,  2.4900e-02],
                [-5.8474e-02, -6.6314e-02,  3.9553e-02,  6.3980e-02,  3.6565e-02]],

               [[ 1.8101e-02,  4.8489e-03, -7.8886e-02, -5.7999e-02,  7.0357e-02],
                [-4.3515e-02, -4.7323e-02, -5.6187e-02,  2.3487e-02, -4.3308e-02],
                [-4.8880e-02, -5.8023e-02, -7.3924e-02, -9.6583e-03,  6.2028e-02],
                [ 5.2484e-02, -2.6062e-02, -4.3238e-03, -7.2771e-02,  3.2814e-02],
                [ 5.0127e-02,  3.3270e-02,  6.3871e-02, -3.8169e-02,  7.1588e-02]],

               [[-6.5380e-02, -7.1597e-02,  8.0692e-02, -5.3291e-02,  4.4471e-02],
                [ 7.1552e-02,  4.8723e-02, -6.1756e-02, -4.6115e-02,  7.7099e-02],
                [-1.9059e-02,  1.1879e-03,  1.7027e-02, -4.2909e-02,  5.6986e-02],
                [-8.7295e-03, -4.8808e-02, -5.7670e-02,  2.4997e-04, -1.1744e-02],
                [-1.3926e-02,  3.1644e-02,  1.1507e-02,  8.1024e-02,  7.2102e-03]],

               [[ 2.1168e-02, -2.9686e-02, -4.1679e-02, -4.7661e-02,  3.7415e-02],
                [ 2.2248e-02, -2.6836e-02,  3.3736e-02,  7.9188e-02, -5.4294e-02],
                [-6.3778e-02, -2.1948e-02, -3.8511e-02,  3.4796e-02,  6.7318e-02],
                [ 1.9141e-03, -1.8587e-02,  7.3516e-02, -5.1300e-02,  5.7174e-02],
                [-3.4286e-02, -8.8658e-03,  2.8017e-02, -6.1345e-02,  7.5492e-02]],

               [[-3.9274e-03,  5.5152e-02, -7.5673e-02,  2.3418e-02,  2.4453e-02],
                [-4.7555e-02,  3.8749e-02,  7.6313e-02,  1.7370e-02, -3.5157e-02],
                [ 2.5302e-03, -2.2002e-02, -3.2479e-02, -2.4231e-02, -3.4979e-02],
                [ 5.0821e-02,  4.9669e-02, -3.9780e-03, -5.6009e-02, -7.5135e-02],
                [-1.4819e-02, -3.7873e-02,  8.0048e-02,  6.6795e-02,  1.9579e-04]]],


              [[[-6.7337e-03, -3.1325e-02, -3.8130e-04,  2.8326e-02,  5.0600e-02],
                [ 2.1436e-02, -4.8592e-02,  6.1161e-02, -6.3545e-02, -6.6114e-02],
                [ 2.6338e-03,  5.4122e-02, -4.5045e-02,  1.2242e-02,  8.8401e-03],
                [-7.7559e-02, -1.9867e-02,  6.5748e-03,  4.9552e-02, -6.9732e-02],
                [-3.9603e-02, -2.5227e-02, -7.4894e-02, -4.6245e-02,  3.7202e-02]],

               [[ 7.4730e-02,  4.4270e-02, -1.6544e-02, -1.4239e-02,  5.9522e-03],
                [ 2.7072e-03, -5.8596e-02,  4.8411e-02,  4.3153e-02, -4.5938e-02],
                [ 7.8193e-02, -6.7088e-02, -6.9651e-02, -6.0037e-02,  2.6229e-02],
                [-7.0300e-03, -3.4991e-02, -7.6215e-02,  9.9203e-03,  6.9492e-02],
                [ 4.9921e-02, -2.0404e-02,  6.9527e-02,  3.0870e-02,  3.4420e-02]],

               [[-1.8141e-02,  4.8283e-02,  1.9330e-02, -4.4328e-02, -1.4779e-02],
                [-5.6117e-02,  7.9438e-02,  6.8914e-03,  2.9340e-02,  5.3000e-02],
                [ 5.4209e-02, -2.0673e-02,  4.3754e-02, -3.3216e-02, -1.8343e-02],
                [ 2.5629e-02,  3.6082e-02,  7.0708e-02,  7.3608e-02,  5.9628e-02],
                [ 6.0885e-02,  1.6643e-02,  1.6415e-02, -1.0011e-02,  7.8816e-02]],

               [[ 7.5320e-02, -5.7523e-02, -4.1853e-02,  1.0916e-02, -3.0991e-02],
                [-6.2630e-02,  1.4596e-02, -1.0427e-02,  3.6875e-02, -1.6136e-03],
                [-4.4583e-02, -3.0317e-02, -2.2016e-02, -6.6638e-02,  6.7848e-02],
                [-5.1430e-02, -3.2587e-02, -1.1150e-03,  3.4666e-03, -3.3550e-02],
                [ 3.4297e-02, -7.6208e-02,  7.9594e-02,  2.8101e-02,  5.8770e-02]],

               [[-7.3166e-02,  6.3813e-02,  5.1557e-02, -8.4318e-03,  6.9600e-02],
                [ 7.9715e-02,  1.2783e-03, -5.8935e-02, -4.6001e-02,  2.3448e-04],
                [-2.1650e-02,  7.1989e-02,  5.8779e-02,  3.0808e-02, -5.0626e-02],
                [-7.0207e-02,  6.4429e-02,  3.2764e-02,  4.4986e-02, -5.6941e-02],
                [-3.9577e-02,  2.8508e-03,  1.5647e-02, -5.7797e-02,  3.3754e-02]],

               [[-3.6933e-02,  6.9224e-02,  6.7209e-02,  1.0161e-02,  2.9785e-02],
                [ 1.1000e-02,  5.7507e-02,  7.7336e-02, -2.3910e-02,  1.2587e-02],
                [-3.8743e-02,  9.7108e-03,  4.9643e-02,  3.2226e-02, -6.9066e-02],
                [ 1.7411e-02, -5.1872e-02, -4.3662e-02, -2.2543e-02,  4.4947e-02],
                [ 5.7543e-02, -2.1366e-02, -2.5460e-02,  4.8669e-02,  4.8413e-02]]]])
      (bias): Normal:
       loc: tensor([0., -0., 0., 0., -0., 0., -0., -0., -0., 0., 0., 0., 0., 0., 0., 0.])
       scale: tensor([0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500,
              0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500, 0.2500])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([ 0.0396, -0.0348,  0.0458,  0.0297, -0.0694,  0.0213, -0.0060, -0.0400,
              -0.0470,  0.0165,  0.0250,  0.0796,  0.0477,  0.0254,  0.0719,  0.0583])
    )
    (observed): Observed()
  )
  (fc1): Linear(
    in_features=400, out_features=120, bias=True
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              ...,
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498]],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[-0.0091,  0.0136, -0.0051,  ..., -0.0231,  0.0462, -0.0289],
              [-0.0265, -0.0483, -0.0282,  ..., -0.0270,  0.0013,  0.0135],
              [-0.0033, -0.0453,  0.0016,  ..., -0.0309, -0.0074, -0.0022],
              ...,
              [ 0.0203,  0.0274,  0.0125,  ..., -0.0139,  0.0251,  0.0211],
              [-0.0178,  0.0492, -0.0229,  ..., -0.0063, -0.0401,  0.0024],
              [-0.0033, -0.0229,  0.0245,  ...,  0.0420,  0.0423, -0.0412]],
             requires_grad=True)
       tensor: tensor([[ 0.0163,  0.0648,  0.0315,  ..., -0.0154,  0.1702, -0.0326],
              [-0.0761, -0.0296, -0.0542,  ..., -0.0453,  0.0576, -0.0302],
              [ 0.0159, -0.0278, -0.0281,  ..., -0.0069,  0.0057, -0.0484],
              ...,
              [-0.0042,  0.0008,  0.0761,  ..., -0.0470,  0.0703,  0.0368],
              [-0.0550,  0.1441,  0.0373,  ...,  0.0340, -0.0733,  0.0047],
              [-0.1526, -0.0554, -0.0007,  ...,  0.0040,  0.0825, -0.0662]],
             grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498, 0.0498, 0.0498], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-0.0058, -0.0242,  0.0295, -0.0301, -0.0095, -0.0331, -0.0259, -0.0153,
               0.0069,  0.0327,  0.0171, -0.0467, -0.0340,  0.0456,  0.0205,  0.0115,
               0.0437, -0.0400, -0.0337,  0.0402,  0.0005,  0.0498,  0.0019, -0.0428,
               0.0204, -0.0429, -0.0188, -0.0454, -0.0102,  0.0484, -0.0102,  0.0098,
              -0.0489, -0.0116,  0.0008,  0.0365,  0.0245,  0.0127,  0.0432,  0.0170,
               0.0064,  0.0087, -0.0159, -0.0008,  0.0249,  0.0176,  0.0213,  0.0199,
              -0.0411, -0.0185,  0.0361,  0.0165, -0.0366,  0.0011,  0.0231, -0.0365,
              -0.0065,  0.0200, -0.0402, -0.0289,  0.0112,  0.0014, -0.0048,  0.0469,
               0.0315, -0.0378,  0.0392, -0.0483,  0.0469,  0.0200, -0.0357,  0.0011,
               0.0031,  0.0180,  0.0359,  0.0056,  0.0218, -0.0161, -0.0221, -0.0196,
              -0.0212,  0.0186,  0.0415, -0.0463, -0.0143, -0.0351,  0.0418,  0.0125,
              -0.0339,  0.0268,  0.0286, -0.0028,  0.0439, -0.0283, -0.0468,  0.0295,
              -0.0397, -0.0159,  0.0498, -0.0395,  0.0458, -0.0338,  0.0416, -0.0442,
              -0.0284,  0.0464,  0.0410, -0.0138,  0.0295, -0.0449, -0.0248,  0.0475,
              -0.0433, -0.0157,  0.0214,  0.0104, -0.0227, -0.0410,  0.0298,  0.0196],
             requires_grad=True)
       tensor: tensor([-1.3555e-01,  2.0939e-02, -1.0978e-02,  2.2384e-02, -3.5042e-02,
              -2.3534e-02,  1.4926e-02,  1.2443e-02,  9.3025e-04, -9.5893e-03,
              -6.1436e-03, -2.0465e-02, -6.8778e-03,  4.9478e-02,  3.5926e-02,
              -4.3168e-02,  5.8348e-02, -9.2632e-02, -8.5259e-02,  6.1411e-02,
               3.5284e-02,  9.3958e-03,  2.7893e-02, -3.3541e-02, -7.2413e-03,
              -2.7861e-02, -2.7886e-02, -2.5454e-02,  2.0168e-02,  5.8621e-02,
               3.8074e-03, -3.6508e-02, -6.4902e-02, -2.8883e-02, -7.6834e-04,
              -7.5272e-03,  6.1758e-03,  7.4719e-02,  2.9627e-02, -1.8554e-02,
               5.9529e-02,  1.5179e-01,  6.4342e-02,  7.4196e-04, -8.1407e-03,
               8.9840e-03,  6.9436e-02, -3.7395e-02, -9.3132e-03, -5.0251e-03,
               6.5105e-02, -1.0527e-02,  3.9077e-02, -1.3772e-02, -4.9026e-02,
               1.2038e-02, -4.4850e-02,  9.1619e-02, -1.1016e-02, -6.6605e-04,
               4.6487e-02, -9.2120e-03, -4.6353e-02,  9.6699e-03, -3.4516e-02,
              -1.3860e-02,  5.3031e-02, -5.4605e-02,  6.1908e-02,  4.8374e-02,
              -2.9948e-02,  6.3221e-02,  1.0043e-02, -1.6848e-02, -3.4269e-02,
               3.2234e-03,  4.0958e-03, -1.0895e-04, -7.2229e-02,  1.4132e-03,
              -1.5699e-02, -1.1046e-02, -3.4086e-02, -2.1041e-02,  2.7204e-02,
              -8.7042e-02,  7.2495e-02,  6.3488e-02, -2.4440e-02,  6.2661e-02,
               2.7276e-02,  7.7142e-02, -9.2377e-03, -8.9469e-02, -8.0198e-02,
              -5.0707e-02, -1.9137e-02, -4.0547e-02,  1.1058e-01, -1.1626e-01,
               1.4978e-02, -8.0421e-02,  2.4994e-02, -8.8205e-02, -3.4078e-02,
               4.4075e-02,  1.1147e-03, -1.4480e-02,  6.0224e-02, -6.7068e-02,
              -1.2512e-01,  7.6667e-02,  3.1673e-02, -6.0719e-02,  5.4334e-02,
              -3.1811e-03, -1.9614e-01, -2.6915e-02, -7.9046e-03, -3.3319e-02],
             grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[-0., 0., -0.,  ..., -0., 0., -0.],
              [-0., -0., -0.,  ..., -0., 0., 0.],
              [-0., -0., 0.,  ..., -0., -0., -0.],
              ...,
              [0., 0., 0.,  ..., -0., 0., 0.],
              [-0., 0., -0.,  ..., -0., -0., 0.],
              [-0., -0., 0.,  ..., 0., 0., -0.]])
       scale: tensor([[0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              ...,
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500],
              [0.0500, 0.0500, 0.0500,  ..., 0.0500, 0.0500, 0.0500]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[-0.0091,  0.0136, -0.0051,  ..., -0.0231,  0.0462, -0.0289],
              [-0.0265, -0.0483, -0.0282,  ..., -0.0270,  0.0013,  0.0135],
              [-0.0033, -0.0453,  0.0016,  ..., -0.0309, -0.0074, -0.0022],
              ...,
              [ 0.0203,  0.0274,  0.0125,  ..., -0.0139,  0.0251,  0.0211],
              [-0.0178,  0.0492, -0.0229,  ..., -0.0063, -0.0401,  0.0024],
              [-0.0033, -0.0229,  0.0245,  ...,  0.0420,  0.0423, -0.0412]])
      (bias): Normal:
       loc: tensor([-0., -0., 0., -0., -0., -0., -0., -0., 0., 0., 0., -0., -0., 0., 0., 0., 0., -0., -0., 0., 0., 0., 0., -0.,
              0., -0., -0., -0., -0., 0., -0., 0., -0., -0., 0., 0., 0., 0., 0., 0., 0., 0., -0., -0., 0., 0., 0., 0.,
              -0., -0., 0., 0., -0., 0., 0., -0., -0., 0., -0., -0., 0., 0., -0., 0., 0., -0., 0., -0., 0., 0., -0., 0.,
              0., 0., 0., 0., 0., -0., -0., -0., -0., 0., 0., -0., -0., -0., 0., 0., -0., 0., 0., -0., 0., -0., -0., 0.,
              -0., -0., 0., -0., 0., -0., 0., -0., -0., 0., 0., -0., 0., -0., -0., 0., -0., -0., 0., 0., -0., -0., 0., 0.])
       scale: tensor([0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913, 0.0913,
              0.0913, 0.0913, 0.0913])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-0.0058, -0.0242,  0.0295, -0.0301, -0.0095, -0.0331, -0.0259, -0.0153,
               0.0069,  0.0327,  0.0171, -0.0467, -0.0340,  0.0456,  0.0205,  0.0115,
               0.0437, -0.0400, -0.0337,  0.0402,  0.0005,  0.0498,  0.0019, -0.0428,
               0.0204, -0.0429, -0.0188, -0.0454, -0.0102,  0.0484, -0.0102,  0.0098,
              -0.0489, -0.0116,  0.0008,  0.0365,  0.0245,  0.0127,  0.0432,  0.0170,
               0.0064,  0.0087, -0.0159, -0.0008,  0.0249,  0.0176,  0.0213,  0.0199,
              -0.0411, -0.0185,  0.0361,  0.0165, -0.0366,  0.0011,  0.0231, -0.0365,
              -0.0065,  0.0200, -0.0402, -0.0289,  0.0112,  0.0014, -0.0048,  0.0469,
               0.0315, -0.0378,  0.0392, -0.0483,  0.0469,  0.0200, -0.0357,  0.0011,
               0.0031,  0.0180,  0.0359,  0.0056,  0.0218, -0.0161, -0.0221, -0.0196,
              -0.0212,  0.0186,  0.0415, -0.0463, -0.0143, -0.0351,  0.0418,  0.0125,
              -0.0339,  0.0268,  0.0286, -0.0028,  0.0439, -0.0283, -0.0468,  0.0295,
              -0.0397, -0.0159,  0.0498, -0.0395,  0.0458, -0.0338,  0.0416, -0.0442,
              -0.0284,  0.0464,  0.0410, -0.0138,  0.0295, -0.0449, -0.0248,  0.0475,
              -0.0433, -0.0157,  0.0214,  0.0104, -0.0227, -0.0410,  0.0298,  0.0196])
    )
    (observed): Observed()
  )
  (fc2): Linear(
    in_features=120, out_features=10, bias=True
    (posterior): Normal(
      (weight): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([[0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              ...,
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498],
              [0.0498, 0.0498, 0.0498,  ..., 0.0498, 0.0498, 0.0498]],
             grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([[ 0.0398,  0.0313, -0.0450,  ...,  0.0049, -0.0503, -0.0522],
              [ 0.0619,  0.0350, -0.0323,  ...,  0.0821, -0.0089, -0.0656],
              [-0.0563,  0.0827, -0.0371,  ..., -0.0067, -0.0121, -0.0045],
              ...,
              [ 0.0112,  0.0345,  0.0310,  ..., -0.0570, -0.0379,  0.0196],
              [ 0.0261,  0.0156,  0.0901,  ...,  0.0250,  0.0211,  0.0911],
              [ 0.0547,  0.0208,  0.0781,  ...,  0.0858,  0.0083, -0.0799]],
             requires_grad=True)
       tensor: tensor([[ 8.0554e-03, -4.1716e-05,  2.3532e-02,  ...,  3.5975e-02,
               -5.7889e-02, -6.4972e-02],
              [ 7.9123e-02,  1.1317e-01, -1.3441e-02,  ...,  8.0784e-02,
               -5.1207e-02, -1.2873e-01],
              [-9.2812e-03,  1.8911e-02,  3.3624e-02,  ..., -1.5495e-02,
                6.0509e-02, -1.7716e-02],
              ...,
              [ 4.3931e-02,  8.3967e-02,  5.6798e-02,  ..., -1.7927e-02,
                3.2743e-03,  7.3954e-02],
              [-1.2302e-03, -3.4438e-02,  1.3092e-01,  ...,  6.6330e-03,
                3.5277e-02,  9.5786e-02],
              [ 8.5259e-02,  2.0223e-02, -1.3733e-02,  ...,  6.2607e-02,
               -2.0986e-02, -8.0120e-02]], grad_fn=<AddBackward0>)
      (bias): Normal:
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       scale: Transform:
       tensor([0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498, 0.0498,
              0.0498], grad_fn=<ExpBackward0>)
       loc: Parameter containing:
      tensor([-0.0608, -0.0649, -0.0132,  0.0874, -0.0895, -0.0182, -0.0491, -0.0399,
               0.0076, -0.0610], requires_grad=True)
       tensor: tensor([-0.0337,  0.0438, -0.0347,  0.1284, -0.1277, -0.0252, -0.0512, -0.0522,
               0.0990, -0.1157], grad_fn=<AddBackward0>)
    )
    (prior): Module(
      (weight): Normal:
       loc: tensor([[0., 0., -0.,  ..., 0., -0., -0.],
              [0., 0., -0.,  ..., 0., -0., -0.],
              [-0., 0., -0.,  ..., -0., -0., -0.],
              ...,
              [0., 0., 0.,  ..., -0., -0., 0.],
              [0., 0., 0.,  ..., 0., 0., 0.],
              [0., 0., 0.,  ..., 0., 0., -0.]])
       scale: tensor([[0.0913, 0.0913, 0.0913,  ..., 0.0913, 0.0913, 0.0913],
              [0.0913, 0.0913, 0.0913,  ..., 0.0913, 0.0913, 0.0913],
              [0.0913, 0.0913, 0.0913,  ..., 0.0913, 0.0913, 0.0913],
              ...,
              [0.0913, 0.0913, 0.0913,  ..., 0.0913, 0.0913, 0.0913],
              [0.0913, 0.0913, 0.0913,  ..., 0.0913, 0.0913, 0.0913],
              [0.0913, 0.0913, 0.0913,  ..., 0.0913, 0.0913, 0.0913]])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([[ 0.0398,  0.0313, -0.0450,  ...,  0.0049, -0.0503, -0.0522],
              [ 0.0619,  0.0350, -0.0323,  ...,  0.0821, -0.0089, -0.0656],
              [-0.0563,  0.0827, -0.0371,  ..., -0.0067, -0.0121, -0.0045],
              ...,
              [ 0.0112,  0.0345,  0.0310,  ..., -0.0570, -0.0379,  0.0196],
              [ 0.0261,  0.0156,  0.0901,  ...,  0.0250,  0.0211,  0.0911],
              [ 0.0547,  0.0208,  0.0781,  ...,  0.0858,  0.0083, -0.0799]])
      (bias): Normal:
       loc: tensor([-0., -0., -0., 0., -0., -0., -0., -0., 0., -0.])
       scale: tensor([0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162, 0.3162,
              0.3162])
       posterior: Automatic()
       prior: Module()
       observed: Observed()
       tensor: tensor([-0.0608, -0.0649, -0.0132,  0.0874, -0.0895, -0.0182, -0.0491, -0.0399,
               0.0076, -0.0610])
    )
    (observed): Observed()
  )
)

You just have to define the forward function, and the backward function (where gradients are computed) is automatically defined for you using autograd. You can use any of the Tensor operations in the forward function.

The learnable parameters of a model are returned by net.parameters():

params = list(net.parameters())
print(len(params))
print(params[0].size())

Out:

16
torch.Size([6, 1, 5, 5])
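For comparison, here is a plain PyTorch model with the same layer shapes. `PlainNet` is a hypothetical name used only for this sketch, and forward is omitted because we only count parameters. It has 8 parameter tensors: one weight and one bias per layer. The borch model reports 16, which is consistent with the printout above: each weight and bias is replaced by an approximating Normal distribution with its own learnable loc and scale.

```python
import torch


class PlainNet(torch.nn.Module):
    """Hypothetical deterministic twin of Net: same layers, point-estimate weights."""

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 6, 5)
        self.conv2 = torch.nn.Conv2d(6, 16, 5)
        self.fc1 = torch.nn.Linear(16 * 5 * 5, 120)
        self.fc2 = torch.nn.Linear(120, 10)


plain = PlainNet()
for name, p in plain.named_parameters():
    print(name, tuple(p.shape))        # first line: conv1.weight (6, 1, 5, 5)
print(len(list(plain.parameters())))   # 8: one weight and one bias per layer
```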

Let's try a random 32x32 input:

input = torch.randn(1, 1, 32, 32)
out = net(input)
print(out)

Out:

tensor([[-0.3673,  0.0775, -0.4730,  0.1129, -0.8138, -0.3398,  0.2442,  0.6617,
         -0.4352, -0.4799]], grad_fn=<AddmmBackward0>)

Zero the gradient buffers of all parameters and backpropagate random gradients:

net.zero_grad()
out.backward(torch.randn(1, 10))
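Because out is not a scalar, backward needs a gradient tensor of the same shape (calling backward() with no argument on a non-scalar output raises an error). What it computes is a vector-Jacobian product. The small sketch below, using plain tensors rather than the network, shows this for y = 3x, whose Jacobian is 3 times the identity.

```python
import torch

x = torch.randn(4, requires_grad=True)
y = 3 * x                  # non-scalar output, like `out` above
v = torch.randn(4)         # plays the role of torch.randn(1, 10)
y.backward(v)              # computes v^T (dy/dx) and stores it in x.grad
# dy/dx = 3 * I, so the gradient is simply 3 * v
print(torch.allclose(x.grad, 3 * v))
```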

Note

borch.nn only supports mini-batches. The entire borch.nn package only supports inputs that are a mini-batch of samples, and not a single sample.

For example, nn.Conv2d will take in a 4D Tensor of nSamples x nChannels x Height x Width.

If you have a single sample, just use input.unsqueeze(0) to add a fake batch dimension.
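A minimal sketch of adding that fake batch dimension to a single 1-channel 32x32 image:

```python
import torch

single = torch.randn(1, 32, 32)   # one sample: nChannels x Height x Width
batch = single.unsqueeze(0)       # add a leading batch dimension of size 1
print(batch.shape)                # torch.Size([1, 1, 32, 32])
```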

Before proceeding further, let’s recap all the classes you’ve seen so far.

Recap:
  • torch.Tensor - A multi-dimensional array with support for autograd operations like backward(). Also holds the gradient w.r.t. the tensor.

  • nn.Module - Neural network module. Convenient way of encapsulating parameters, with helpers for moving them to GPU, exporting, loading, etc.

  • nn.Parameter - A kind of Tensor that is automatically registered as a parameter when assigned as an attribute to a Module.

  • autograd.Function - Implements forward and backward definitions of an autograd operation. Every Tensor operation creates at least one Function node, which connects to the functions that created a Tensor and encodes its history.
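To make the last item concrete, here is a sketch of a hand-written autograd.Function for y = x², with an explicit backward. `Square` is a name chosen for illustration; built-in operations such as x * x already come with their own Function nodes, so you would not normally write this yourself.

```python
import torch


class Square(torch.autograd.Function):
    """y = x**2 with a hand-written backward pass."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)       # keep x for use in backward
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x     # chain rule: dL/dx = dL/dy * 2x


x = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
y = Square.apply(x)
print(y.grad_fn)                       # the node recorded for this operation
y.sum().backward()
print(x.grad)                          # equals 2 * x
```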

At this point, we covered:
  • Defining a neural network

  • Processing inputs and calling backward

Still Left:
  • Computing the loss

  • Updating the weights of the network

Loss Function

A loss function takes the (output, target) pair of inputs, and computes a value that estimates how far away the output is from the target.

There are several different loss functions in the nn package. A simple one is nn.MSELoss, which computes the mean-squared error between the input and the target. However, these losses only amount to the maximum likelihood approach that is standard in deep learning: they yield a single point estimate of the weights rather than a distribution over them.
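The link to maximum likelihood can be checked numerically. The sketch below uses plain PyTorch with random data. It shows that MSELoss is the mean squared error, and that the mean negative log-likelihood under a unit-variance Gaussian differs from it only by a factor of 1/2 and the constant ½·log(2π), so minimising one minimises the other.

```python
import math

import torch
from torch.distributions import Normal

torch.manual_seed(0)
pred = torch.randn(8)     # stand-in for network outputs
y = torch.randn(8)        # stand-in for targets

mse = torch.nn.MSELoss()(pred, y)
print(torch.allclose(mse, ((pred - y) ** 2).mean()))

# Mean negative log-likelihood of y under N(pred, 1)
nll = -Normal(pred, 1.0).log_prob(y).mean()
print(torch.allclose(nll, 0.5 * mse + 0.5 * math.log(2 * math.pi)))
```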

To infer the posterior of the weights, and thus capture their uncertainty as well, we have to use the infer package. In this example we use the infer.vi_loss function, which automatically constructs a variational inference loss from the latent variables in your model.
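To give an idea of the kind of objective involved, here is a hand-rolled negative ELBO in plain PyTorch. It is a single-sample Monte Carlo estimate of the negative expected log-likelihood plus the KL divergence between the approximate posterior and the prior. This is not borch's implementation, and infer.vi_loss may differ in details. The toy model (one Bayesian linear layer, 4 features to 3 classes) and the names `negative_elbo`, `loc` and `log_scale` are hypothetical. The `kl_scaling` argument mimics the one passed to infer.vi_loss in the training loop later on.

```python
import torch
from torch.distributions import Categorical, Normal, kl_divergence

torch.manual_seed(0)

# Hypothetical toy model: a single Bayesian linear layer, 4 features -> 3 classes
prior = Normal(torch.zeros(3, 4), torch.ones(3, 4))
loc = torch.zeros(3, 4, requires_grad=True)                # variational mean
log_scale = torch.full((3, 4), -3.0, requires_grad=True)   # variational log std


def negative_elbo(x, y, kl_scaling=1.0):
    q = Normal(loc, log_scale.exp())
    w = q.rsample()                      # reparameterised sample, keeps gradients
    log_lik = Categorical(logits=x @ w.T).log_prob(y).sum()
    kl = kl_divergence(q, prior).sum()   # analytic KL between two Normals
    return -log_lik + kl_scaling * kl


x = torch.randn(16, 4)
y = torch.randint(3, (16,))
loss = negative_elbo(x, y, kl_scaling=1 / 10)
loss.backward()                          # gradients flow to loc and log_scale
print(loss.item(), loc.grad.shape)
```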

Similar to how it’s done for random variables, we can also call observe on the module, using keyword arguments that match the names of the random variables we want to observe. This adds those random variables to the likelihood term, and we will not infer a distribution over them. For example:

target = torch.randint(10, (1,))  # a dummy target, for example
net.observe(classification=target)
borch.sample(net)
output = net(input)
loss = infer.vi_loss(**borch.pq_to_infer(net))
print(loss)

Out:

tensor(9575.6943, grad_fn=<AddBackward0>)

Now, if you follow loss in the backward direction using its .grad_fn attribute, you will see a graph of computations that looks like this:

input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d
      -> view -> linear -> relu -> linear
      -> loss

So, when we call loss.backward(), the whole graph is differentiated w.r.t. the loss, and every Tensor in the graph that has requires_grad=True will have its .grad Tensor accumulated with the gradient.
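A small sketch of following grad_fn by hand on a toy expression (not the network). Each node's next_functions links to the nodes of its inputs. Only tensors with requires_grad=True receive a .grad. The printed class names, such as PowBackward0, may vary between PyTorch versions.

```python
import torch

w = torch.randn(3, requires_grad=True)   # learnable: gradients are tracked
x = torch.randn(3)                       # plain input: requires_grad=False
loss = ((w * x).relu().sum() - 1.0) ** 2

node = loss.grad_fn
while node is not None:
    print(type(node).__name__)
    # follow the first input edge; leaf nodes have no next_functions
    node = node.next_functions[0][0] if node.next_functions else None

loss.backward()
print(w.grad is not None, x.grad is None)
```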

Backprop

To backpropagate the error, all we have to do is call loss.backward(). You need to clear the existing gradients first, though; otherwise the new gradients will be added to the existing ones.
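A minimal sketch of this accumulation on a single tensor, with a toy loss of sum(w²) whose gradient is 2w:

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)


def compute_loss():
    # d(loss)/dw = 2 * w = [2., 4.]
    return (w ** 2).sum()


compute_loss().backward()
first = w.grad.clone()          # [2., 4.]

compute_loss().backward()       # no reset: the new gradient is added to the old one
accumulated = w.grad.clone()    # [4., 8.]

w.grad.zero_()                  # clear the buffer in place before the next backward
compute_loss().backward()
print(first, accumulated, w.grad)
```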

Now we call loss.backward() and look at the gradients of conv1’s bias before and after the backward pass.

net.zero_grad()  # zeroes the gradient buffers of all parameters

The gradient of the loc parameter of the approximating distribution of conv1.bias after zeroing the gradients is:

print(net.conv1.posterior.bias.loc.grad)

loss.backward()

Out:

tensor([0., 0., 0., 0., 0., 0.])

After calling backward, the gradient is:

print(net.conv1.posterior.bias.loc.grad)

Out:

tensor([-0.5059,  0.1341,  0.3560, -0.1543, -0.8424,  0.7598])

The only thing left to learn is:

  • Updating the weights of the network

Update the weights

The simplest update rule used in practice is the Stochastic Gradient Descent (SGD):

weight = weight - learning_rate * gradient

We can implement this using simple Python code:

learning_rate = 0.01
with torch.no_grad():  # the update itself should not be tracked by autograd
    for f in net.parameters():
        f -= learning_rate * f.grad

However, as you use neural networks, you will want to use various update rules such as SGD, Nesterov-SGD, Adam, RMSProp, etc. To enable this, PyTorch provides a small package, torch.optim, that implements all these methods. Using it is very simple:

import torch.optim as optim

# create your optimizer
optimizer = optim.SGD(net.parameters(), lr=0.01)

# in your training loop:
n_batch_epoch = 10  # number of batches per epoch, usually len(dataloader); used to scale the KL term
optimizer.zero_grad()  # zero the gradient buffers
borch.sample(net)
output = net(input)
loss = infer.vi_loss(**borch.pq_to_infer(net), kl_scaling=1 / n_batch_epoch)
loss.backward()
optimizer.step()  # Does the update
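To check that optim.SGD (without momentum or weight decay) applies the same rule as the hand-written loop above, here is a sketch on a small plain PyTorch model. A torch.nn.Linear stand-in and an MSE loss are used instead of the borch net so that it runs on its own. Two copies start from identical weights; one is updated by hand and the other by the optimizer.

```python
import torch

torch.manual_seed(0)
model_a = torch.nn.Linear(3, 1)
model_b = torch.nn.Linear(3, 1)
model_b.load_state_dict(model_a.state_dict())   # identical starting weights

x, y = torch.randn(8, 3), torch.randn(8, 1)
lr = 0.01

# Manual update: weight = weight - learning_rate * gradient
torch.nn.functional.mse_loss(model_a(x), y).backward()
with torch.no_grad():
    for p in model_a.parameters():
        p -= lr * p.grad

# The same step through torch.optim.SGD
opt = torch.optim.SGD(model_b.parameters(), lr=lr)
opt.zero_grad()
torch.nn.functional.mse_loss(model_b(x), y).backward()
opt.step()

for pa, pb in zip(model_a.parameters(), model_b.parameters()):
    print(torch.allclose(pa, pb))   # the two updates agree
```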

Exercises

  1. The neural network package contains various modules and loss functions that form the building blocks of deep neural networks. Have a look at the documentation to see what is available.

  2. Try designing your own feed-forward network with two different types of nonlinearities, e.g. relu and tanh.