Use PyTorch Lightning

PyTorch Lightning handles much of the boilerplate engineering code, allowing one to focus on the research code. See https://www.pytorchlightning.ai/ for more information. The main benefits of PyTorch Lightning are:

  • Models become hardware agnostic

  • Code is clear to read because engineering code is abstracted away

  • Easier to reproduce

  • Make fewer mistakes because Lightning handles the tricky engineering

  • Keeps all the flexibility (LightningModules are still PyTorch modules), but removes a ton of boilerplate

Some of the functionality available out of the box with minimal setup:
  • 16-bit precision

  • Multi-GPU training

  • Multi-node training

  • Early stopping

  • Model checkpointing

Let's start off with some imports.

from borch import nn, infer, distributions
import torch
import pytorch_lightning as pl
import borch
import torch.nn.functional as F

from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST
from torchvision import transforms
from pytorch_lightning.callbacks import ModelCheckpoint

Here we show how to use PyTorch Lightning to train a borch model on MNIST.

The main challenge when using PyTorch Lightning is scaling vi_loss. The loss is normally computed as loss = infer.vi_loss(**borch.pq_to_infer(self.net), kl_scaling=x.shape[0]/len(dataset)), and the kl_scaling term needs both the size of the batch and the length of the dataset.
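To see why kl_scaling needs both numbers, here is a small sketch of the arithmetic, assuming (as the name suggests) that kl_scaling multiplies the KL term of the loss once per batch. With a factor of batch_size/len(dataset) per batch, the factors over one epoch add up to exactly 1, so the KL divergence is counted once per pass over the data rather than once per batch. The helper kl_scalings is hypothetical, written in plain Python for illustration, and does not call borch.

```python
from fractions import Fraction


def kl_scalings(dataset_len, batch_size):
    """Hypothetical helper: the kl_scaling factor for every batch in one epoch.

    Mirrors kl_scaling=x.shape[0]/len(dataset); the last batch may be
    smaller than batch_size, so its factor is smaller too.
    """
    scalings = []
    for start in range(0, dataset_len, batch_size):
        this_batch = min(batch_size, dataset_len - start)
        scalings.append(Fraction(this_batch, dataset_len))
    return scalings


scalings = kl_scalings(dataset_len=60000, batch_size=128)
print(len(scalings))       # 469 batches per epoch
print(float(scalings[0]))  # factor for a full batch of 128
print(sum(scalings))       # 1: the KL term is counted once per epoch
```

With these numbers the last batch holds only 96 samples, which is why the scaling uses the actual batch size x.shape[0] rather than a fixed value.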

PyTorch Lightning abstracts away the training loop, so the model only implements a training_step:

def training_step(self, batch, batch_idx):
    # Only the current batch is available here, not the dataset.
    (x, target) = batch
    self.net.observe(classification=target)
    borch.sample(self.net)
    self.net(x)
    return infer.vi_loss(**borch.pq_to_infer(self.net))

The issue is that we only get access to the batch, not to the data loader or the dataset. One could store the dataset in a global variable and read its length from there, but then the model would be tied to that particular dataset and could no longer be reused with other datasets.

Instead, borch provides a dataset wrapper, AddDatasetLength. It wraps any torch.utils.data.Dataset and returns the length of the dataset together with each element.
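Conceptually, such a wrapper only has to follow the map-style Dataset protocol (__len__ and __getitem__) and pair each item with the dataset length. The class below is a hypothetical sketch of that idea, not borch's actual source; it avoids importing torch so that it runs on its own, and a plain list of (image, label) pairs stands in for MNIST.

```python
class AddDatasetLengthSketch:
    """Hypothetical sketch: wrap a map-style dataset so each item also carries len(dataset)."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        # The wrapper has exactly as many items as the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, index):
        # Return (dataset length, original item), so that a collated batch
        # unpacks as ds_len, (x, target).
        return len(self.dataset), self.dataset[index]


# A toy list of (image, label) pairs standing in for MNIST.
toy_dataset = [("img0", 5), ("img1", 0), ("img2", 4)]
wrapped = AddDatasetLengthSketch(toy_dataset)
print(len(wrapped))  # 3
print(wrapped[1])    # (3, ('img1', 0))
```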

from borch.utils.data import AddDatasetLength
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
mnist = MNIST('mnist', train=True, download=True, transform=transform)
data_set = AddDatasetLength(mnist)
data_loader = DataLoader(data_set, batch_size=3)
print(next(iter(data_loader)))

Out:

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist/MNIST/raw/train-images-idx3-ubyte.gz
Extracting mnist/MNIST/raw/train-images-idx3-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting mnist/MNIST/raw/train-labels-idx1-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw

[tensor([60000, 60000, 60000]), [tensor([[[[-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          ...,
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242]]],


        [[[-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          ...,
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242]]],


        [[[-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          ...,
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
          [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242]]]]), tensor([5, 0, 4])]]
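The output shows how the DataLoader's default collation rearranges the wrapped samples: the per-sample lengths become one tensor (tensor([60000, 60000, 60000])), and the (image, label) pairs become a batch of images and a batch of labels. Because every entry in that first tensor is the same, reading ds_len[0] is enough. The function below is a simplified, hypothetical stand-in for PyTorch's default collation that only transposes nested tuples into lists (the real one also stacks values into tensors); it is meant to show the structure of a batch, not to replace the real collate function.

```python
def collate_like_default(samples):
    """Hypothetical, simplified stand-in for PyTorch's default collation.

    Only transposes nested tuples into lists; the real default collation
    also stacks numbers and tensors into tensors.
    """
    first = samples[0]
    if isinstance(first, tuple):
        # Transpose: a list of tuples becomes a list of per-field batches.
        return [collate_like_default(list(field)) for field in zip(*samples)]
    return list(samples)


# Three wrapped samples, shaped like AddDatasetLength output: (length, (image, label)).
samples = [(60000, ("img0", 5)), (60000, ("img1", 0)), (60000, ("img2", 4))]
batch = collate_like_default(samples)
ds_len, (x, target) = batch
print(ds_len)     # [60000, 60000, 60000]
print(target)     # [5, 0, 4]
print(ds_len[0])  # 60000
```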

Now that each batch carries the length of the dataset, we can write a training_step using borch and PyTorch Lightning:

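def training_step(self, batch, batch_idx):
    # Each batch now carries the dataset length alongside the data.
    ds_len, (x, target) = batch
    self.net.observe(classification=target)
    loss = 0
    for _ in range(self.subsamples):
        borch.sample(self)
        self(x)
        # Scale the KL term by the fraction of the dataset in this batch.
        loss += infer.vi_loss(**borch.pq_to_infer(self.net), kl_scaling=x.shape[0] / ds_len[0])
    loss /= self.subsamples  # average over the sampled losses
    loss /= x.shape[0]  # normalize by the batch size
    self.net.observe(None)  # clear the observations
    return loss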
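The loop over self.subsamples averages several sampled losses, which gives a lower-variance Monte Carlo estimate than a single sample, at the cost of extra forward passes. The sketch below illustrates only that averaging effect; noisy_loss is a hypothetical stand-in for one sampled vi_loss (a fixed value plus Gaussian noise), not borch's loss.

```python
import random
import statistics


def noisy_loss(rng):
    """Hypothetical stand-in for one sampled vi_loss: true value 5.0 plus sampling noise."""
    return 5.0 + rng.gauss(0.0, 1.0)


def averaged_loss(rng, subsamples):
    """Average several sampled losses, like the loop over self.subsamples."""
    loss = 0.0
    for _ in range(subsamples):
        loss += noisy_loss(rng)
    return loss / subsamples


rng = random.Random(0)  # fixed seed for reproducibility
single = [averaged_loss(rng, 1) for _ in range(2000)]
averaged = [averaged_loss(rng, 8) for _ in range(2000)]
print(statistics.variance(single))    # close to 1.0
print(statistics.variance(averaged))  # close to 1/8
```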

To keep the example clear, we separate the network code (NeuralNetwork) from the PyTorch Lightning code (LitModel), and connect them with self.net = NeuralNetwork() in the __init__ method of LitModel.

class NeuralNetwork(nn.Module):
    """Our MNIST classification network"""
    def __init__(self):
        super(NeuralNetwork, self).__init__(posterior=borch.posterior.Automatic())
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )
    def forward(self, x):
        """Run through the network and construct the likelihood"""
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        self.classification = distributions.Categorical(logits=logits)
        return self.classification

class LitModel(pl.LightningModule):
    """The pytorch_lightning module that helps with the training"""
    def __init__(self, subsamples=2, learning_rate=0.001, batch_size=128):
        """Store a few useful hyperparameters"""
        super().__init__()
        self.net = NeuralNetwork()
        self.subsamples = subsamples
        self.learning_rate = learning_rate
        self.batch_size = batch_size
    def forward(self, x):
        """Run the input through the network"""
        return self.net(x)
    def training_step(self, batch, batch_idx, prefix='train'):
        ds_len, (x, target) = batch
        self.net.observe(classification=target)
        loss = 0
        for _ in range(self.subsamples):
            borch.sample(self)
            self(x)
            # We scale the loss with how big part of the dataset we run trough
            loss+= infer.vi_loss(**borch.pq_to_infer(self.net), kl_scaling=x.shape[0]/ds_len[0])
        loss /= self.subsamples
        loss /= x.shape[0]
        self.net.observe(None)
        self.log(f"{prefix}_loss", loss)
        acc = (self.net.prior.classification.tensor == target).long().sum()/len(target)
        self.log(f"{prefix}_accuracy", acc)
        return loss
    def validation_step(self, batch, batch_idx):
        """Use the train step as the validation step"""
        return self.training_step(batch, batch_idx, 'validation')
    def test_step(self, batch, batch_idx):
        """Use the train step as the test step"""
        return self.training_step(batch, batch_idx, 'test')
    def configure_optimizers(self):
        """Set up the optimizers"""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
    def load_state_dict(self, state_dict, strict=False):
        """Change the strict=True default"""
        # Given how we store some intermediate samples, some keys might be
        # missing from the state dict, and some PyTorch Lightning functions
        # do not let us control the `strict` argument. The easiest fix is
        # to change the default here.
        return super().load_state_dict(state_dict, strict=strict)
    def train_dataloader(self):
        """Train data loader"""
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        mnist_train = MNIST('data/',train=True, download=True, transform=transform)
        return DataLoader(AddDatasetLength(mnist_train), batch_size=self.batch_size, num_workers=8)
    def val_dataloader(self):
        """Use the train data loader as the validation data loader"""
        return self.train_dataloader()
    def test_dataloader(self):
        """Use the validation data loader as the test data loader"""
        return self.val_dataloader()
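The load_state_dict override in LitModel relaxes PyTorch's strict key matching. With strict=True, torch.nn.Module.load_state_dict raises an error if the state dict has missing or unexpected keys; with strict=False it reports them instead, in the missing_keys and unexpected_keys fields of its return value. The helper below is a hypothetical pure-Python illustration of that check, not PyTorch's code, and the key names (including the stored sample key) are made up for the example.

```python
def compare_state_dict_keys(model_keys, loaded_keys, strict=True):
    """Hypothetical sketch of strict vs. non-strict state-dict key matching."""
    missing = sorted(set(model_keys) - set(loaded_keys))
    unexpected = sorted(set(loaded_keys) - set(model_keys))
    if strict and (missing or unexpected):
        raise RuntimeError(f"missing keys: {missing}, unexpected keys: {unexpected}")
    return missing, unexpected


# Made-up key names: the model expects a stored sample that the checkpoint lacks.
model_keys = ["net.linear_relu_stack.0.weight", "net.linear_relu_stack.0.bias", "net.sample_buffer"]
checkpoint_keys = ["net.linear_relu_stack.0.weight", "net.linear_relu_stack.0.bias"]

try:
    compare_state_dict_keys(model_keys, checkpoint_keys, strict=True)
except RuntimeError as err:
    print("strict=True fails:", err)

missing, unexpected = compare_state_dict_keys(model_keys, checkpoint_keys, strict=False)
print(missing, unexpected)  # ['net.sample_buffer'] []
```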

Here we use train_accuracy as the checkpointing metric; with a separate validation set, it would be better to monitor, e.g., validation_accuracy.

checkpoint_callback = ModelCheckpoint(monitor="train_accuracy", mode="max")  # higher accuracy is better
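ModelCheckpoint keeps the checkpoint with the best value of the monitored metric, and its mode argument decides whether "best" means the lowest value (mode="min", the default) or the highest (mode="max"). For an accuracy metric, mode="max" is the appropriate choice. The sketch below is a hypothetical pure-Python illustration of that bookkeeping, not Lightning's implementation, and the accuracy values are made up.

```python
def best_epoch(metric_per_epoch, mode="min"):
    """Hypothetical sketch: return (epoch, value) of the best monitored metric."""
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    pick = min if mode == "min" else max
    # Compare on the metric value; ties keep the earliest epoch.
    return pick(enumerate(metric_per_epoch), key=lambda item: item[1])


accuracies = [0.81, 0.90, 0.88]  # made-up train_accuracy values
print(best_epoch(accuracies, mode="min"))  # (0, 0.81): the worst epoch
print(best_epoch(accuracies, mode="max"))  # (1, 0.9): the best epoch
```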

We construct a Trainer object that holds all the training configuration; see the PyTorch Lightning documentation for all available settings.

trainer = pl.Trainer(
    fast_dev_run=5,  # run only 5 batches as a quick test
    max_epochs=5,  # maximum number of epochs; `fast_dev_run` makes it stop sooner
    min_epochs=1,  # minimum number of epochs; `fast_dev_run` makes it stop sooner
    # precision=16,  # use 16-bit precision; requires hardware support
    gradient_clip_val=5,  # use gradient clipping
    # gradient_clip_algorithm="value",  # clipping algorithm; the default clips by norm
    # stochastic_weight_avg=True,
    auto_lr_find=True,  # automatically search for a good learning_rate (not guaranteed to work)
    # auto_scale_batch_size="binsearch",  # find the largest batch_size that fits in memory, for higher GPU utilization
    callbacks=[checkpoint_callback],
)

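The first step is to tune the model: the trainer searches for a good learning_rate and, if auto_scale_batch_size is enabled, a good batch_size. With fast_dev_run enabled, the learning rate finder is skipped, as the warning below shows.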

model = LitModel()
trainer.tune(model)

Out:

/home/docs/checkouts/readthedocs.org/user_builds/borch/envs/latest/lib/python3.7/site-packages/pytorch_lightning/tuner/lr_finder.py:197: UserWarning: Skipping learning rate finder since fast_dev_run is enabled.
  rank_zero_warn("Skipping learning rate finder since fast_dev_run is enabled.", UserWarning)

{'lr_find': None}

Once tuning is done, we can fit the model.

trainer.fit(model)

Out:

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

/home/docs/checkouts/readthedocs.org/user_builds/borch/envs/latest/lib/python3.7/site-packages/torch/utils/data/dataloader.py:481: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  cpuset_checked))
/home/docs/checkouts/readthedocs.org/user_builds/borch/envs/latest/lib/python3.7/site-packages/pytorch_lightning/trainer/data_loading.py:429: UserWarning: The number of training samples (5) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"

Training: 0it [00:00, ?it/s]
Training:   0%|          | 0/10 [00:00<?, ?it/s]
Epoch 0:   0%|          | 0/10 [00:00<?, ?it/s]
Epoch 0:  10%|#         | 1/10 [00:00<00:05,  1.74it/s]
Epoch 0:  10%|#         | 1/10 [00:00<00:05,  1.74it/s, loss=5.55, v_num=]
Epoch 0:  20%|##        | 2/10 [00:00<00:02,  2.84it/s, loss=5.55, v_num=]
Epoch 0:  20%|##        | 2/10 [00:00<00:02,  2.83it/s, loss=5.65, v_num=]
Epoch 0:  30%|###       | 3/10 [00:00<00:01,  3.60it/s, loss=5.65, v_num=]
Epoch 0:  30%|###       | 3/10 [00:00<00:01,  3.59it/s, loss=5.7, v_num=]
Epoch 0:  40%|####      | 4/10 [00:00<00:01,  4.16it/s, loss=5.7, v_num=]
Epoch 0:  40%|####      | 4/10 [00:00<00:01,  4.16it/s, loss=5.76, v_num=]
Epoch 0:  50%|#####     | 5/10 [00:01<00:01,  4.62it/s, loss=5.76, v_num=]
Epoch 0:  50%|#####     | 5/10 [00:01<00:01,  4.62it/s, loss=5.68, v_num=]

Validating: 0it [00:00, ?it/s]

Validating:   0%|          | 0/5 [00:00<?, ?it/s]

Validating:  20%|##        | 1/5 [00:00<00:02,  1.90it/s]
Epoch 0:  70%|#######   | 7/10 [00:01<00:00,  4.35it/s, loss=5.68, v_num=]

Validating:  60%|######    | 3/5 [00:00<00:00,  5.03it/s]
Epoch 0:  90%|######### | 9/10 [00:01<00:00,  5.06it/s, loss=5.68, v_num=]

Validating: 100%|##########| 5/5 [00:00<00:00,  7.25it/s]
Epoch 0: 100%|##########| 10/10 [00:01<00:00,  5.15it/s, loss=5.68, v_num=]

The logged metrics can now be viewed in TensorBoard with the command tensorboard --logdir ./lightning_logs

Remember to load the best checkpoint.

if checkpoint_callback.best_model_path:
    model = LitModel.load_from_checkpoint(checkpoint_callback.best_model_path, strict=False)

Total running time of the script: ( 0 minutes 3.950 seconds)

Gallery generated by Sphinx-Gallery