PyTorch LightningCLI: Blueprint for a Scalable Training Pipeline
In this post, I provide a guide to the PyTorch Lightning framework and present a blueprint for a training pipeline, using image classification as an example. I’ll start with a common issue in dataset preparation: splitting a dataset into training and validation subsets while applying different transformations to each, a problem typically encountered when using PyTorch’s built-in utilities. This is more of a personal preference and not a crucial element of the post, so feel free to skip it. In the next section, we’ll explore how Lightning organizes training code by separating data and model logic, making it scalable and maintainable. However, this separation introduces challenges when the model requires information from the data pipeline, such as the number of classes or class balancing weights for loss functions. I explain how to resolve these issues by leveraging LightningCLI’s argument linking, which allows seamless parameter sharing between modules. Finally, the post culminates in a comprehensive blueprint for building a scalable, maintainable Lightning-based training system.
The split problem
For years, I’ve been using PyTorch dataset classes for various problems, and when working on computer vision tasks, my typical approach has been to load the data into a single dataset and then split it into train, validation, and test sets. This is quite straightforward using PyTorch’s torch.utils.data.random_split
function. However, this method has a limitation: it returns a Subset
class, which contains a subset of indices and a reference to the original dataset:
class Subset(Dataset[T_co]):
    """
    Subset of a dataset at specified indices.

    Args:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    dataset: Dataset[T_co]
    indices: Sequence[int]

    def __init__(self, dataset: Dataset[T_co], indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        if isinstance(idx, list):
            return self.dataset[[self.indices[i] for i in idx]]
        return self.dataset[self.indices[idx]]

    def __len__(self):
        return len(self.indices)
This means you cannot define different augmentation transforms for the individual splits, as they share the parent dataset. The workaround is to create a wrapper class around the train and validation splits with the desired transforms:
import torch
from torch.utils.data import Dataset, random_split, DataLoader
from torchvision import transforms
import numpy as np

# This contains the whole dataset -> no transforms applied here
class CompleteDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample

    def __len__(self):
        return len(self.data)

# Wrapper class to apply different transforms to the splits
class SubsetTransformWrapper(Dataset):
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, idx):
        sample = self.subset[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        return len(self.subset)

# Data and transforms
data = np.random.rand(100, 32, 32, 3).astype(np.float32)  # Example data: 100 images, 32x32, 3 channels (HWC, as expected by ToTensor)
train_transform = transforms.Compose([transforms.ToTensor(), transforms.RandomHorizontalFlip()])
val_transform = transforms.Compose([transforms.ToTensor()])

# Creating the dataset
dataset = CompleteDataset(data)

# Splitting into train and validation
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_subset, val_subset = random_split(dataset, [train_size, val_size])

# Wrapping the subsets with the desired transforms
train_dataset = SubsetTransformWrapper(train_subset, transform=train_transform)
val_dataset = SubsetTransformWrapper(val_subset, transform=val_transform)

# Loading data
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

# Iterate through a batch of train and validation samples
train_batch = next(iter(train_loader))
val_batch = next(iter(val_loader))
print(train_batch.size(), val_batch.size())
The following figure illustrates this approach:
Although this solution works, it feels unnecessarily cumbersome. I’d like to avoid the extra wrapper class and find a more elegant approach. To achieve this, we should handle the data loading outside the PyTorch Dataset
class and avoid relying on its built-in splitting function – as we will see later in the blueprint. This is more of a personal preference than a critical issue, but I believe the final LightningDataModule design offers a useful template.
PyTorch Lightning and its CLI
For some time I’ve been using PyTorch Lightning, which modularizes and streamlines the training process by organizing PyTorch code to eliminate boilerplate and unlock scalability, including distributed training options. It offers several useful features, such as the fast_dev_run
flag, which runs just one train-validation-test step instead of a full training cycle. This is invaluable for quickly testing the entire pipeline without committing to hours of training that might fail later. Another helpful flag allows overfitting the model on a small set of samples to ensure the neural network architecture is functional and capable of learning. Furthermore, Lightning simplifies the process by defining two core components: LightningModule
for the model and training logic, and LightningDataModule
for managing data loaders. This modularization removes the need to manually implement the training loop, which is typically repetitive and prone to errors. After implementing these modules we can train the model with only a few lines of code:
import lightning as L
from modules import MyLightningModule, MyLightningDataModule
model, data_module = MyLightningModule(), MyLightningDataModule()
# Initialize a Trainer
trainer = L.Trainer()
# Fit the model to data
trainer.fit(model, datamodule=data_module)
# Testing the model
trainer.test(model, datamodule=data_module)
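Coming back to the debugging features mentioned above, a quick sanity check of the pipeline might look like this (a minimal sketch reusing model and data_module from the snippet above; fast_dev_run and overfit_batches are standard Trainer arguments):
# Run a single batch of training and validation to smoke-test the whole pipeline
trainer = L.Trainer(fast_dev_run=True)
trainer.fit(model, datamodule=data_module)

# Try to overfit a handful of batches to verify the architecture can actually learn
trainer = L.Trainer(overfit_batches=10, max_epochs=100)
trainer.fit(model, datamodule=data_module)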
Both LightningModule
and LightningDataModule
come with well-defined methods that structure the training process efficiently. In LightningModule
, methods like forward()
, training_step()
, validation_step()
, and test_step()
define how data flows through the model and how the training, validation, and test steps are executed. Meanwhile, the configure_optimizers()
method sets up the optimizer and learning rate scheduler. The LightningDataModule
provides methods like prepare_data()
for downloading and preparing data, and setup()
for initializing and splitting the dataset. Importantly, these methods are designed to ensure compatibility with distributed training environments, such as multi-GPU or TPU setups. PyTorch Lightning abstracts away the complexity of device management, automatically handling things like data parallelism, gradient synchronization, and checkpointing, allowing one to scale from a single GPU to large clusters with minimal code changes. Specifically, the prepare_data()
method is only called once per machine, ensuring that tasks like data downloading or unzipping files are not repeated across multiple GPU processes, saving time and resources. Note that each process requires its own dataloader, which is why we initialize the actual datasets in the setup()
method. This ensures that each GPU or process has its own independent access to the data, allowing efficient parallel processing during training without duplicating the data preparation steps performed by prepare_data(), which are common to all processes.
CAUTION
prepare_data
is called from the main process once per machine. It is not recommended to assign state here (e.g. self.x = y
) since it is called on a single process and if you assign states here then they won’t be available for other processes. For state assignment use the setup()
method.
Another thing I find particularly useful is LightningCLI
(Lightning’s Command Line Interface), which automates argument parsing. It exposes the __init__ parameters of the LightningModule, the LightningDataModule, and the Trainer as command-line arguments, eliminating the need to write argument-parsing code by hand. LightningCLI also supports YAML configs, which I find much cleaner than passing a long list of command-line arguments. Furthermore, configs and command-line arguments can be combined, so we can define default parameters in a YAML file and override specific ones via the command line. Managing command-line arguments manually becomes tedious as function signatures change, but LightningCLI handles this seamlessly by linking the hyperparameters of the Lightning modules directly to the CLI and config files.
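A minimal entry point might look like the following sketch (reusing the hypothetical modules package from the earlier snippet):
# main.py – minimal LightningCLI entry point (sketch)
from lightning.pytorch.cli import LightningCLI

from modules import MyLightningModule, MyLightningDataModule

def cli_main():
    # Exposes the __init__ parameters of both modules and the Trainer as CLI/YAML options
    # and provides the fit/validate/test/predict subcommands
    LightningCLI(MyLightningModule, MyLightningDataModule)

if __name__ == "__main__":
    cli_main()
It can then be invoked with, for example, python main.py fit --config config.yaml --trainer.max_epochs 20, where the YAML file holds the defaults and the command line overrides individual values.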
After this brief detour (which will make sense later hopefully), let’s return to the core question: how should we utilize the LightningDataModule
effectively?
The setup(stage=...)
method is called before each stage (fit
, validate
, test
, predict
), primarily to set up the datasets for the corresponding dataloaders, allowing flexibility to adjust data handling as needed for each stage. During the fit
stage, both training and validation datasets can be set up jointly, since fit
covers the entire train-validate loop – making it the ideal place to handle the training-validation splitting. So, we already know that downloading data to the local machine should be handled in the prepare_data()
method, while the actual dataset construction and splitting mechanisms should be implemented in the setup()
method. Now let’s address how the separated Lightning modules pass information to each other.
A common challenge when using LightningDataModule
and LightningModule
is that key dataset properties, like num_classes
in classification tasks, are needed when building the model architecture. However, since Lightning separates data processing and model definition into different modules, passing this information between them can become complicated, especially when using LightningCLI
. An initial solution might be to compute such properties in the setup()
method of the DataModule
and access them in the model’s setup()
or through hooks like on_fit_start()
. However, this approach is suboptimal, as setup()
in LightningModule is not intended for model initialization, as discussed in this GitHub issue. Using setup()
for model initialization causes checkpoint loading issues because load_from_checkpoint()
expects the model layers to be defined during instantiation. The recommended solution is to define layers in the __init__()
method rather than in setup()
.
To avoid this trap, the recommended approach is to use link_arguments
in LightningCLI
, allowing num_classes
or other parameters computed in the DataModule
to be passed to the LightningModule
during instantiation. However, since prepare_data()
runs after __init__()
, parameters like num_classes
might not be available in time. As a workaround, define the linked parameter as a @property
so that it is “lazily” computed when first accessed, guaranteeing it is properly initialized before model construction – as discussed here.
Putting it all together
Knowing all this, we can construct the final design template. Let’s take the case of image classification, where images belonging to different categories are stored in different folders. A blueprint might look like this for the datamodule
:
import lightning as L
from torch.utils.data import Dataset, DataLoader

# Single torch dataset with transforms
class TorchDataset(Dataset):
    def __init__(self, data, transform=None):
        self.transform = transform
        self.data = data

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        return len(self.data)

# LightningDataModule implementation
class ImageClassificationModule(L.LightningDataModule):
    def __init__(self, data_dir, train_transforms, val_transforms, data_splits=(0.8, 0.1, 0.1), random_seed=42):
        super().__init__()
        self.data_dir = data_dir
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        self.data_splits = data_splits
        self.random_seed = random_seed
        self.all_data = None
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self._num_classes = None
        self._class_weights = None

    @property
    def num_classes(self):
        # Lazily computed: ensure data is downloaded and prepared first
        if self._num_classes is None:
            self.prepare_data()
            self._num_classes = infer_num_classes(self.data_dir)  # Analyse data and infer the number of classes
        return self._num_classes

    @property
    def class_weights(self):
        # Lazily computed: ensure data is downloaded and prepared first
        if self._class_weights is None:
            self.prepare_data()
            self._class_weights = calculate_class_weights(self.data_dir)  # Analyse data and calculate class weights
        return self._class_weights

    def prepare_data(self):
        # Download and unzip data if it is not yet on the machine
        ...

    def setup(self, stage: str = None):
        # Create the train-val-test split with our custom function
        self.all_data = load_all_data(self.data_dir)
        train_data, val_data, test_data = random_split_data(self.all_data, self.data_splits, self.random_seed)
        if stage == "fit" or stage is None:
            self.train_dataset = TorchDataset(train_data, self.train_transforms)
            self.val_dataset = TorchDataset(val_data, self.val_transforms)
        if stage == "test" or stage is None:
            self.test_dataset = TorchDataset(test_data, self.val_transforms)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=32, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=32)
You can see that here we use custom data loading and splitting logic inside the setup()
method and instantiate a plain PyTorch Dataset
with each split and its corresponding transforms, avoiding the extra wrapper class discussed in the first section. Furthermore, lazy linked-parameter initialization with a manual prepare_data()
call inside the property ensures the data is prepared before the desired parameter values are calculated – these are then linked to the LightningModule
right after instantiation. You might stop here and ask: the whole point of the prepare_data()
method was to run only once per machine, and now we are calling it multiple times through the properties, what's going on? Typically, well-written data downloading scripts handle cases where the data is already available, ensuring the operation is only performed once. This way, we ensure that linked parameters like num_classes
are dynamically initialized when needed, without any data preparation overhead during repeated calls.
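To keep these repeated prepare_data() calls cheap, the download step can simply be guarded by a check for data that is already on disk. A minimal sketch of such a guard (my own example, assuming the data ships as a single archive at a hypothetical archive_url):
import urllib.request
import zipfile
from pathlib import Path

def download_and_unzip_if_needed(data_dir, archive_url):
    """Idempotent data preparation: do nothing if the data is already on disk."""
    data_dir = Path(data_dir)
    marker = data_dir / ".prepared"  # written once the archive has been extracted
    if marker.exists():
        return  # already prepared -> repeated calls are essentially free
    data_dir.mkdir(parents=True, exist_ok=True)
    archive_path = data_dir / "dataset.zip"
    urllib.request.urlretrieve(archive_url, archive_path)
    with zipfile.ZipFile(archive_path) as zf:
        zf.extractall(data_dir)
    marker.touch()
With a guard like this, the calls made from the num_classes and class_weights properties only pay the download cost once.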
The corresponding LightningModule
will be straightforward, as we’ll rely on LightningCLI
to link parameters from the DataModule
dynamically:
# Use LightningCLI to link arguments
from lightning.pytorch.cli import LightningCLI

class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate")
        parser.link_arguments("data.class_weights", "model.init_args.class_weights", apply_on="instantiate")

# Running LightningCLI with linked arguments
if __name__ == "__main__":
    cli = MyLightningCLI(run=False)
    # Fit model
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    # Run the test
    cli.trainer.test(ckpt_path='best', datamodule=cli.datamodule)
By using the link_arguments
feature in LightningCLI
, the num_classes
and class_weights
from the DataModule
will be automatically passed to the LightningModule
during instantiation – as we set apply_on="instantiate"
. Thus, the LightningModule
would look something like this:
import lightning as L
import torch
import torch.nn as nn
import torch.nn.functional as F

class ImageClassificationModel(L.LightningModule):
    def __init__(self, num_classes, class_weights=None):
        super().__init__()
        self.num_classes = num_classes
        # Linked parameters must be basic Python types (see the note below), so the class
        # weights arrive as a list of floats; convert them to a tensor and register it as
        # a buffer so it is moved to the right device automatically.
        if class_weights is not None:
            self.register_buffer("class_weights", torch.tensor(class_weights, dtype=torch.float32))
        else:
            self.class_weights = None
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 64 * 64, 256),  # assumes 256x256 inputs: two 2x2 poolings leave 64x64 feature maps
            nn.ReLU(),
            nn.Linear(256, self.num_classes)
        )

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        # Use the class balancing weights linked from the datamodule
        loss = F.cross_entropy(logits, y, weight=self.class_weights)
        return loss

    def validation_step(self, batch, batch_idx):
        ...
If you run the CLI script with --help
(e.g., python main.py --model path.to.LightningModule --data path.to.LightningDataModule --help
), the linked parameters will be listed but won’t appear as explicit arguments for the model.
IMPORTANT
Note that the linked parameters must be JSON serializable, meaning they should be basic Python types like lists or floats – complex objects like PyTorch tensors cannot be passed as arguments and linked between modules.
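As an illustration, a class-weight helper like the calculate_class_weights placeholder used in the datamodule could return exactly such a list. This is my own sketch, not from the original post, assuming an ImageFolder-style layout where each class has its own subfolder:
from pathlib import Path

def calculate_class_weights(data_dir):
    """Hypothetical helper: inverse-frequency class weights as a plain list of floats."""
    class_dirs = sorted(p for p in Path(data_dir).iterdir() if p.is_dir())
    counts = [len(list(d.glob("*"))) for d in class_dirs]  # number of images per class folder
    total = sum(counts)
    # Weight each class inversely to its frequency so rare classes contribute more to the loss
    return [total / (len(counts) * c) for c in counts]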
Wrapping it up
The key takeaways of this post are:
- To avoid extra wrapper classes for dataset splitting and transformations, handle data loading and splitting outside the PyTorch Dataset class.
- PyTorch Lightning organizes training into LightningModule and LightningDataModule, making the code organized and scalable, but the separation introduces a challenge when sharing parameters between modules.
- Use LightningCLI with link_arguments to seamlessly pass information like num_classes and class_weights from the DataModule to the LightningModule.
- Put only data download and stateless preparation in prepare_data(); put data loading, dataset construction, and splitting in setup(); and use the lazy @property workaround to ensure the linked parameters are correctly initialized before link_arguments passes them between the modules.