PyTorch LightningCLI: Blueprint for a Scalable Training Pipeline

In this post, I provide a guide to the PyTorch Lightning framework and present a blueprint for a training pipeline, using image classification as an example. I’ll start with a common issue in dataset preparation: splitting a dataset into training and validation subsets while applying different transformations to each, a problem typically encountered when using PyTorch’s built-in utilities. This is more of a personal preference and not a crucial element of the post, so feel free to skip it. In the next section, we’ll explore how Lightning organizes training code by separating data and model logic, making it scalable and maintainable. However, this separation introduces challenges when the model requires information from the data pipeline, such as the number of classes or class balancing weights for loss functions. I explain how to resolve these issues by leveraging LightningCLI’s argument linking, which allows seamless parameter sharing between modules. Finally, the post culminates in a comprehensive blueprint for building a scalable, maintainable Lightning-based training system.

We'll cover the following sections:

  • The split problem
  • PyTorch Lightning and its CLI
  • Putting it all together
  • Wrapping it up

The split problem

For years, I’ve been using PyTorch dataset classes for various problems, and when working on computer vision tasks, my typical approach has been to load the data into a single dataset and then split it into train, validation, and test sets. This is quite straightforward using PyTorch’s torch.utils.data.random_split function. However, this method has a limitation: it returns Subset objects, which contain a subset of indices and a reference to the original dataset:

class Subset(Dataset[T_co]):
    """
    Subset of a dataset at specified indices.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """

    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        if isinstance(idx, list):
            return self.dataset[[self.indices[i] for i in idx]]
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)

This implies that you cannot define different augmentation transforms for the individual splits, as they share the parent dataset. The workaround is to create a wrapper class around the train and validation splits that applies the desired transforms:

import torch
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision import transforms
import numpy as np

# This contains the whole dataset -> no transforms applied here
class CompleteDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample

    def __len__(self):
        return len(self.data)

# Wrapper class to apply different transforms to the splits
class SubsetTransformWrapper(Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, idx):
        sample = self.subset[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        return len(self.subset)

# Data and transforms
data = np.random.rand(100, 32, 32, 3).astype(np.float32)  # Example data (100 images, 32x32, 3 channels, HWC layout expected by ToTensor)

train_transform = transforms.Compose([transforms.ToTensor(), transforms.RandomHorizontalFlip()])
val_transform = transforms.Compose([transforms.ToTensor()])

# Creating the dataset
dataset = CompleteDataset(data)

# Splitting into train and validation
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_subset, val_subset = random_split(dataset, [train_size, val_size])

# Wrapping the subsets with the desired transforms
train_dataset = SubsetTransformWrapper(train_subset, transform=train_transform)
val_dataset = SubsetTransformWrapper(val_subset, transform=val_transform)

# Loading data
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

# Iterate through a batch of train and validation samples
train_batch = next(iter(train_loader))
val_batch = next(iter(val_loader))

print(train_batch.size(), val_batch.size())

The following figure illustrates this approach:

Although this solution works, it feels unnecessarily cumbersome. I’d like to avoid adding this extra wrapper class and find a more elegant approach. To achieve this, we should handle the data loading outside the PyTorch Dataset class and avoid relying on its built-in splitting function – as we will see later in the blueprint. This is more of a personal preference than a critical issue, but I believe the final LightningDataModule design will offer a useful template.

PyTorch Lightning and its CLI

For some time I’ve been using PyTorch Lightning, which modularizes and streamlines the training process by organizing PyTorch code to eliminate boilerplate and unlock scalability, including distributed training options. It offers several useful features, such as the fast_dev_run flag, which runs just one train-validation-test step instead of a full training cycle. This is invaluable for quickly testing the entire pipeline without committing to hours of training that might fail later. Another helpful flag allows overfitting the model on a small set of samples to ensure the neural network architecture is functional and capable of learning. Furthermore, Lightning simplifies the process by defining two core components: LightningModule for the model and training logic, and LightningDataModule for managing data loaders. This modularization removes the need to manually implement the training loop, which is typically repetitive and prone to errors. After implementing these modules we can train the model with only a few lines of code:

import lightning as L
from modules import MyLightningModule, MyLightningDataModule

model, data_module = MyLightningModule(), MyLightningDataModule()

# Initialize a Trainer
trainer = L.Trainer()
# Fit the model to data
trainer.fit(model, datamodule=data_module)

# Testing the model
trainer.test(model, datamodule=data_module)
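
The sanity-check features mentioned above (fast_dev_run and overfitting a few batches) are ordinary Trainer arguments, so they drop straight into this snippet; for example:

# Smoke-test the full pipeline with a single train/val/test batch
trainer = L.Trainer(fast_dev_run=True)
trainer.fit(model, datamodule=data_module)

# Or check that the architecture can actually learn by overfitting a few batches
trainer = L.Trainer(overfit_batches=10, max_epochs=50)
trainer.fit(model, datamodule=data_module)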

Both LightningModule and LightningDataModule come with well-defined methods that structure the training process efficiently. In LightningModule, methods like forward(), training_step(), validation_step(), and test_step() define how data flows through the model and how the training, validation, and test steps are executed. Meanwhile, the configure_optimizers() method sets up the optimizer and learning rate scheduler. The LightningDataModule provides methods like prepare_data() for downloading and preparing data, and setup() for initializing and splitting the dataset. Importantly, these methods are designed to ensure compatibility with distributed training environments, such as multi-GPU or TPU setups. PyTorch Lightning abstracts away the complexity of device management, automatically handling things like data parallelism, gradient synchronization, and checkpointing, allowing one to scale from a single GPU to large clusters with minimal code changes. Specifically, the prepare_data() method is only called once per machine, ensuring that tasks like data downloading or unzipping files are not repeated across multiple GPU processes, saving time and resources. Note that each process requires its own dataloader, which is why we initialize the actual datasets in the setup() method. This ensures that each GPU or process has its own independent access to the data, allowing efficient parallel processing during training without duplicating the preparation steps performed by prepare_data(), which are common to all processes.

CAUTION

prepare_data is called from the main process once per machine. It is not recommended to assign state here (e.g. self.x = y) since it is called on a single process and if you assign states here then they won’t be available for other processes. For state assignment use the setup() method.
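
To make this concrete, here is a minimal sketch of the split between the two hooks (the download and loading helpers are hypothetical placeholders):

import lightning as L

class MyDataModule(L.LightningDataModule):
    def prepare_data(self):
        # Called once per machine: download/unzip only, no state assignment here
        download_dataset_if_missing("data/")  # hypothetical helper

    def setup(self, stage=None):
        # Called in every process: safe place to assign state
        self.dataset = load_samples_from_disk("data/")  # hypothetical helper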

Another thing I find particularly useful is LightningCLI (Lightning’s command-line interface), which automates argument parsing. It automatically exposes the __init__ parameters of the LightningModule and LightningDataModule as command-line arguments, eliminating the need to handle argument parsing manually. LightningCLI also supports YAML configs, which I find much cleaner than passing a long list of command-line arguments. Furthermore, it allows using configs and command-line arguments together, so we can define default parameters in a YAML file and override specific ones via the command line. Managing command-line arguments manually can become tedious, since we need to track changes in function parameters, but LightningCLI handles this seamlessly by linking the hyperparameters of the Lightning modules directly to the CLI and config files.
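
As a rough sketch (the class paths and values below are placeholders, not part of the blueprint later in this post), such a config groups the trainer, model, and data arguments in one YAML file:

config.yaml
trainer:
  max_epochs: 10
model:
  class_path: my_package.MyLightningModule
  init_args:
    learning_rate: 0.001
data:
  class_path: my_package.MyLightningDataModule
  init_args:
    batch_size: 32

Individual values can then be overridden at the command line, e.g. --trainer.max_epochs 20, without editing the file.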

After this brief detour (which will make sense later hopefully), let’s return to the core question: how should we utilize the LightningDataModule effectively?

The setup(stage=...) method is called before each stage (fit, validate, test, predict), primarily to set up the datasets for the corresponding dataloaders, allowing flexibility to adjust data handling as needed for each stage. During the fit stage, both training and validation datasets can be set up jointly, since fit covers the entire train-validate loop – making it the ideal place to handle the training-validation split. So, we already know that downloading data to the local machine should be handled in the prepare_data() method, while the actual dataset construction and splitting mechanisms should be implemented in the setup() method. Now let's address the issue of how the separated Lightning modules pass information to each other.

A common challenge when using LightningDataModule and LightningModule is that key dataset properties, like num_classes in classification tasks, are needed when building the model architecture. However, since Lightning separates data processing and model definition into different modules, passing this information between them can become complicated, especially when using LightningCLI. An initial solution might be to compute such properties in the setup() method of the DataModule and access them in the model’s setup() or through hooks like on_fit_start(). However, this approach is suboptimal, as setup() in LightningModule is not intended for model initialization, as discussed in this GitHub issue. Using setup() for model initialization causes checkpoint loading issues because load_from_checkpoint() expects the model layers to be defined during instantiation. The recommended solution is to define layers in the __init__() method rather than in setup().

To avoid this trap, the recommended approach is to use link_arguments in LightningCLI, allowing num_classes or other parameters computed in the DataModule to be passed to the LightningModule during instantiation. However, since prepare_data() only runs after the modules are instantiated, parameters like num_classes might not be available in time. As a workaround, defining the linked parameter as a @property ensures it is “lazily” computed when first accessed, guaranteeing it is initialized properly before model construction – as discussed here.

Putting it all together

Knowing all this, we can construct the final design template. Let’s take the case of image classification where you have images in different folders representing different categories. A blueprint for the datamodule might look like this:

data.py
import lightning as L
from torch.utils.data import Dataset, DataLoader

# Single torch dataset with transforms
class TorchDataset(Dataset):
    def __init__(self, data, transform=None):
        self.transform = transform
        self.data = data

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        return len(self.data)

# LightningDataModule implementation
class ImageClassificationModule(L.LightningDataModule):
    def __init__(self, data_dir, train_transforms, val_transforms, data_splits=(0.7, 0.15, 0.15), random_seed=42):
        super().__init__()
        self.data_dir = data_dir
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        self.data_splits = data_splits
        self.random_seed = random_seed

        self.all_data = None
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        self._num_classes = None
        self._class_weights = None

    @property
    def num_classes(self):
        # Ensure data is prepared if not already
        if self._num_classes is None:
            self.prepare_data()  # Ensure data is downloaded and prepared
            self._num_classes = infer_num_classes(self.data_dir)  # Analyse data and infer number of classes
        return self._num_classes

    @property
    def class_weights(self):
        # Ensure data is prepared if not already
        if self._class_weights is None:
            self.prepare_data()  # Ensure data is downloaded and prepared
            self._class_weights = calculate_class_weights(self.data_dir)  # Analyse data and calculate class weights
        return self._class_weights

    def prepare_data(self):
        # Download and unzip data if it is not yet on the machine
        ...

    def setup(self, stage: str = None):
        # Create train-val-test split with our custom function
        self.all_data = load_all_data(self.data_dir)
        train_data, val_data, test_data = random_split_data(self.all_data, self.data_splits, self.random_seed)

        if stage == "fit" or stage is None:
            self.train_dataset = TorchDataset(train_data, self.train_transforms)
            self.val_dataset = TorchDataset(val_data, self.val_transforms)

        if stage == "test" or stage is None:
            self.test_dataset = TorchDataset(test_data, self.val_transforms)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=32, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=32)

You can see that here we use custom data loading and splitting logic inside the setup() method and instantiate the PyTorch Dataset class with the splits and corresponding transforms, avoiding the extra wrapper class discussed in the first section. Furthermore, lazy linked-parameter initialization with a manual prepare_data() call inside the property ensures the data is prepared before the desired parameter values are calculated – these are then linked to the LightningModule right after instantiation. You might stop here and ask: wasn't the whole point of the prepare_data() method to run only once per machine, and now we are calling it multiple times through the properties – what's going on? Typically, well-written data downloading scripts handle the case where the data is already available, ensuring the expensive operation is only performed once. This way, linked parameters like num_classes are dynamically initialized when needed, without any data preparation overhead on repeated calls.
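
A minimal sketch of such an idempotent prepare_data(), assuming a hypothetical download_and_extract helper:

import os
import lightning as L

class ImageClassificationModule(L.LightningDataModule):
    ...

    def prepare_data(self):
        # Idempotent: only download and unzip if the data is not already on disk
        if not os.path.isdir(self.data_dir):
            download_and_extract(self.data_dir)  # hypothetical helper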

The corresponding LightningModule will be straightforward, as we’ll rely on LightningCLI to link parameters from the DataModule dynamically:

main.py
from lightning.pytorch.cli import LightningCLI

# Use LightningCLI to link arguments
class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate")
        parser.link_arguments("data.class_weights", "model.init_args.class_weights", apply_on="instantiate")

# Running LightningCLI with linked arguments
if __name__ == "__main__":
    cli = MyLightningCLI(run=False)
    # Fit model
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    # Run the test
    cli.trainer.test(ckpt_path='best', datamodule=cli.datamodule)

By using the link_arguments feature in LightningCLI, the num_classes and class_weights from the DataModule will be automatically passed to the LightningModule during instantiation – as we set apply_on="instantiate". Thus, the LightningModule would look something like this:

model.py
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F

class ImageClassificationModel(L.LightningModule):
    def __init__(self, num_classes, class_weights=None):
        super().__init__()
        self.num_classes = num_classes
        # class_weights arrives as a plain list (see the note below), so convert it to a
        # tensor and register it as a buffer so it is moved to the correct device
        if class_weights is not None:
            self.register_buffer("class_weights", torch.tensor(class_weights, dtype=torch.float32))
        else:
            self.class_weights = None
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 64 * 64, 256),  # assumes 256x256 inputs: two 2x2 poolings give 64x64 maps with 64 channels
            nn.ReLU(),
            nn.Linear(256, self.num_classes)
        )

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # Use the class balancing weights linked from the datamodule
        loss = F.cross_entropy(logits, y, weight=self.class_weights)
        return loss

    def validation_step(self, batch, batch_idx):
        ...

If you run the CLI script with --help (e.g., python main.py --model path.to.LightningModule --data path.to.LightningDataModule --help), the linked parameters will be listed but won’t appear as explicit arguments for the model.

IMPORTANT

Note that the linked parameters must be JSON serializable, meaning they should be basic Python types like lists or floats – complex objects like PyTorch tensors cannot be passed as arguments and linked between modules.
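
In practice this just means converting tensors to plain Python values on the DataModule side before they are linked. Here is a minimal sketch of what the calculate_class_weights placeholder from the blueprint could look like (load_labels is a hypothetical helper):

import torch

def calculate_class_weights(data_dir):
    # Inverse-frequency weights, returned as plain Python floats
    labels = load_labels(data_dir)  # hypothetical helper returning a list of class indices
    counts = torch.bincount(torch.as_tensor(labels)).float()
    weights = counts.sum() / (len(counts) * counts)
    return weights.tolist()  # JSON-serializable, safe to link between modules

On the model side, the list can be turned back into a tensor before being passed to the loss, as done with register_buffer in the LightningModule above.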

Wrapping it up

The key takeaways of this post are:

  • To avoid using extra wrapper classes for dataset splitting and transformations, handle data loading and splitting outside the PyTorch Dataset class.
  • PyTorch Lightning separates training into LightningModule and LightningDataModule, making the code organized and scalable, but this separation introduces a challenge when sharing parameters between modules.
  • Use LightningCLI with link_arguments to seamlessly pass information like num_classes and class_weights from the DataModule to the LightningModule.
  • Put only data download and stateless preparations in prepare_data(); put data loading, dataset construction, and splitting in setup(); and use a @property-lazy-init workaround to ensure that the linked parameters are correctly initialized before link_arguments passes these parameters between the modules.
