Using DALI in PyTorch Lightning
This example shows how to use DALI in PyTorch Lightning.
Let us grab a toy example showcasing a classification network and see how DALI can accelerate it.
The DALI_EXTRA_PATH environment variable should point to a DALI extra copy. Please make sure that the proper release tag, the one associated with your DALI version, is checked out.
import os

import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer, LightningModule
from torchvision import datasets, transforms
from torchvision.datasets import MNIST
# Workaround for the HTTP 403 error when downloading the MNIST dataset:
# send a browser-like User-agent header with every urllib request.
import urllib.request

opener = urllib.request.build_opener()
opener.addheaders = [("User-agent", "Mozilla/5.0")]
# Without installing it, the opener above is never used by later downloads.
urllib.request.install_opener(opener)
We will start by implementing a training class that uses the native data loader.
class LitMNIST(LightningModule):
    """Fully connected MNIST classifier fed by the native PyTorch ``DataLoader``.

    Subclasses can swap the data source by overriding ``prepare_data``,
    ``setup``, ``train_dataloader`` and ``process_batch``.
    """

    def __init__(self):
        # The base module must be initialised before sub-modules are assigned.
        super().__init__()

        # mnist images are (1, 28, 28) (channels, height, width)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        """Return per-class log-probabilities for a 4-D batch of images."""
        # Unpacking also rejects inputs that are not 4-D.
        batch_size, _channels, _height, _width = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        x = self.layer_3(x)
        return F.log_softmax(x, dim=1)

    def process_batch(self, batch):
        """Turn a loader batch into an ``(x, y)`` pair; the native loader already yields one."""
        return batch

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch."""
        x, y = self.process_batch(batch)
        logits = self(x)
        return self.cross_entropy_loss(logits, y)

    def cross_entropy_loss(self, logits, labels):
        """Return the negative log-likelihood of ``labels`` given log-probabilities ``logits``."""
        return F.nll_loss(logits, labels)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at a learning rate of 1e-3."""
        return Adam(self.parameters(), lr=1e-3)

    def prepare_data(self):
        """Download MNIST into the current working directory."""
        # Download only. No state is assigned here: Lightning may run this hook
        # on a single process, and setup() builds the dataset that is used.
        MNIST(os.getcwd(), train=True, download=True)

    def setup(self, stage=None):
        """Create the normalised training dataset from the downloaded files."""
        # transforms for images
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )
        self.mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transform)

    def train_dataloader(self):
        """Return the native PyTorch loader over the training dataset."""
        return DataLoader(self.mnist_train, batch_size=64, num_workers=8, pin_memory=True)
And see how it works
model = LitMNIST()
trainer = Trainer(max_epochs=5, devices=1, accelerator="gpu")
# DDP works only in non-interactive mode; to test it, uncomment the line below and run as a script
# trainer = Trainer(devices=8, accelerator="gpu", strategy="ddp", max_epochs=5)
## The MNIST data set is not always available to download due to network issues;
## to run this part of the example, uncomment the line below
# trainer.fit(model)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
The next step is to define a DALI pipeline that will be used for loading and pre-processing data.
import nvidia.dali as dali
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy
# Path to MNIST dataset
data_path = os.path.join(os.environ["DALI_EXTRA_PATH"], "db/MNIST/training/")


@pipeline_def
def GetMnistPipeline(device, shard_id=0, num_shards=1):
    """Define the DALI graph that loads and pre-processes MNIST training data.

    The ``pipeline_def`` decorator turns this graph definition into a pipeline
    factory, so callers also pass the pipeline-level keyword arguments
    (e.g. ``batch_size``, ``num_threads``, ``device_id``) when calling it.

    Args:
        device: ``"gpu"`` decodes with the ``"mixed"`` backend and moves the
            labels to the GPU; any other value keeps everything on the CPU.
        shard_id: Index of the shard this pipeline reads.
        num_shards: Total number of shards the dataset is split into.

    Returns:
        ``(images, labels)``: normalised FLOAT images in CHW layout and INT64
        labels.
    """
    jpegs, labels = fn.readers.caffe2(
        path=data_path,
        shard_id=shard_id,
        num_shards=num_shards,
        random_shuffle=True,
        name="Reader",
    )
    images = fn.decoders.image(
        jpegs,
        device="mixed" if device == "gpu" else "cpu",
        output_type=types.GRAY,
    )
    # Same mean/std as the torchvision Normalize transform, scaled by 255.
    images = fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        std=[0.3081 * 255],
        mean=[0.1307 * 255],
        output_layout="CHW",
    )
    if device == "gpu":
        labels = labels.gpu()
    # PyTorch expects labels as INT64
    labels = fn.cast(labels, dtype=types.INT64)
    return images, labels
Now we are ready to modify the training class to use the DALI pipeline we have just defined. Because we want to integrate with PyTorch, we wrap our pipeline with a PyTorch DALI iterator, which can replace the native data loader with some minor changes in the code. The DALI iterator returns a list of dictionaries, where each element in the list corresponds to a pipeline instance, and the entries in the dictionary map to the outputs of the pipeline.
For more information, check the documentation of DALIGenericIterator.
class DALILitMNIST(LitMNIST):
    """``LitMNIST`` variant that reads its training data through a DALI pipeline."""

    def __init__(self):
        super().__init__()

    def prepare_data(self):
        """Do nothing."""
        # no preparation is needed in DALI
        pass

    def setup(self, stage=None):
        """Build this process's DALI pipeline and wrap it in a PyTorch iterator."""
        device_id = self.local_rank
        shard_id = self.global_rank
        num_shards = self.trainer.world_size
        # NOTE(review): batch_size and num_threads mirror the native DataLoader
        # settings in LitMNIST (64 / 8), and device="gpu" matches the GPU
        # Trainer used below - confirm against the intended configuration.
        mnist_pipeline = GetMnistPipeline(
            batch_size=64,
            device="gpu",
            device_id=device_id,
            shard_id=shard_id,
            num_shards=num_shards,
            num_threads=8,
        )
        self.train_loader = DALIClassificationIterator(
            mnist_pipeline,
            reader_name="Reader",
            last_batch_policy=LastBatchPolicy.PARTIAL,
        )

    def train_dataloader(self):
        """Return the DALI iterator created in ``setup``."""
        return self.train_loader

    def process_batch(self, batch):
        """Unpack the single pipeline's output dictionary into an ``(x, y)`` pair."""
        x = batch[0]["data"]
        # Drop the trailing dimension of the labels.
        y = batch[0]["label"].squeeze(-1)
        return (x, y)
We can now run the training
# Even after the previous Trainer has finished its work it still keeps the GPU
# booked; remove PL_TRAINER_GPUS to force it to release the device.
os.environ.pop("PL_TRAINER_GPUS", None)
model = DALILitMNIST()
trainer = Trainer(max_epochs=5, devices=1, accelerator="gpu", num_sanity_val_steps=0)
# DDP works only in non-interactive mode; to test it, uncomment the line below and run as a script
# trainer = Trainer(devices=8, accelerator="gpu", strategy="ddp", max_epochs=5)
trainer.fit(model)
| Name | Type | Params
0 | layer_1 | Linear | 100 K
1 | layer_2 | Linear | 33.0 K
2 | layer_3 | Linear | 2.6 K
136 K Trainable params
0 Non-trainable params
136 K Total params
0.544 Total estimated model params size (MB)
For even better integration, we can provide a custom DALI iterator wrapper so that no extra processing is required inside LitMNIST.process_batch. Also, PyTorch can learn the size of the dataset this way.
class BetterDALILitMNIST(LitMNIST):
    """``LitMNIST`` variant whose DALI iterator yields batches ``process_batch`` can pass through."""

    def __init__(self):
        super().__init__()

    def prepare_data(self):
        """Do nothing."""
        # no preparation is needed in DALI
        pass

    def setup(self, stage=None):
        """Build this process's DALI pipeline and wrap it in ``LightningWrapper``."""
        device_id = self.local_rank
        shard_id = self.global_rank
        num_shards = self.trainer.world_size
        # NOTE(review): batch_size and num_threads mirror the native DataLoader
        # settings in LitMNIST (64 / 8), and device="gpu" matches the GPU
        # Trainer used below - confirm against the intended configuration.
        mnist_pipeline = GetMnistPipeline(
            batch_size=64,
            device="gpu",
            device_id=device_id,
            shard_id=shard_id,
            num_shards=num_shards,
            num_threads=8,
        )

        class LightningWrapper(DALIClassificationIterator):
            """Iterator that returns each batch as a list ordered by ``output_map``."""

            def __init__(self, *kargs, **kvargs):
                super().__init__(*kargs, **kvargs)

            def __next__(self):
                out = super().__next__()
                # DDP is used so only one pipeline per process
                # also we need to transform dict returned by DALIClassificationIterator to iterable
                # and squeeze the labels
                out = out[0]
                # Squeeze only the trailing dimension so that a batch holding a
                # single sample keeps its batch dimension.
                return [
                    out[k] if k != "label" else torch.squeeze(out[k], -1)
                    for k in self.output_map
                ]

        self.train_loader = LightningWrapper(
            mnist_pipeline,
            reader_name="Reader",
            last_batch_policy=LastBatchPolicy.PARTIAL,
        )

    def train_dataloader(self):
        """Return the wrapped DALI iterator created in ``setup``."""
        return self.train_loader
Let us run the training one more time
# Even after the previous Trainer has finished its work it still keeps the GPU
# booked; remove PL_TRAINER_GPUS to force it to release the device.
os.environ.pop("PL_TRAINER_GPUS", None)
model = BetterDALILitMNIST()
trainer = Trainer(max_epochs=5, devices=1, accelerator="gpu", num_sanity_val_steps=0)
# DDP works only in non-interactive mode; to test it, uncomment the line below and run as a script
# trainer = Trainer(devices=8, accelerator="gpu", strategy="ddp", max_epochs=5)
trainer.fit(model)
