Using DALI in PyTorch Lightning

Overview

This example shows how to use DALI in PyTorch Lightning.

Let us grab a toy example showcasing a classification network and see how DALI can accelerate it.

The DALI_EXTRA_PATH environment variable should point to a copy of the DALI_extra repository. Please make sure that the release tag associated with your DALI version is checked out.
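
If you want to verify the setup up front, a quick check like the one below (an illustrative sketch, not part of the original example) confirms that the variable is set and that the MNIST data used later in this notebook is present:

[ ]:
import os

# Illustrative sanity check: DALI_EXTRA_PATH must be set and must contain the
# MNIST training data that the pipeline defined below reads from.
dali_extra = os.environ.get('DALI_EXTRA_PATH')
assert dali_extra is not None, "Set DALI_EXTRA_PATH to a checkout of the DALI_extra repository"
assert os.path.isdir(os.path.join(dali_extra, 'db', 'MNIST', 'training')), \
    "MNIST data not found - check out the release tag matching your DALI version"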

[1]:
import torch
from torch.nn import functional as F
from torch import nn
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer
from torch.optim import Adam
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import os

BATCH_SIZE = 64

We will start by implementing a training class that uses the native PyTorch data loader.

[2]:
class LitMNIST(LightningModule):

    def __init__(self):
        super().__init__()

        # mnist images are (1, 28, 28) (channels, height, width)
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

    def forward(self, x):
        batch_size, channels, height, width = x.size()

        # (b, 1, 28, 28) -> (b, 1*28*28)
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)

        x = F.log_softmax(x, dim=1)
        return x

    def process_batch(self, batch):
        return batch

    def training_step(self, batch, batch_idx):
        x, y = self.process_batch(batch)
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def cross_entropy_loss(self, logits, labels):
        return F.nll_loss(logits, labels)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

    def prepare_data(self):
        # transforms for images
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.1307,), (0.3081,))])
        self.mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, num_workers=8, pin_memory=True)

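Before wiring up the trainer, a quick smoke test (an illustrative sketch, not part of the original example) shows what the forward pass produces for a random MNIST-shaped batch:

[ ]:
# Illustrative only: push a random batch shaped like MNIST through the model.
model = LitMNIST()
dummy = torch.randn(BATCH_SIZE, 1, 28, 28)  # (batch, channels, height, width)
log_probs = model(dummy)
print(log_probs.shape)             # torch.Size([64, 10])
print(log_probs.exp().sum(dim=1))  # each row sums to ~1, as expected after log_softmax
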
Let us see how it works.

[3]:
model = LitMNIST()
trainer = Trainer(gpus=1, distributed_backend="ddp", max_epochs=5)
trainer.fit(model)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type   | Params
-----------------------------------
0 | layer_1 | Linear | 100 K
1 | layer_2 | Linear | 33 K
2 | layer_3 | Linear | 2 K

[3]:
1

The next step is to define a DALI pipeline that will be used for loading and pre-processing data.

[4]:
import nvidia.dali as dali
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.plugin.pytorch import DALIClassificationIterator

# Path to MNIST dataset
data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/training/')

class MnistPipeline(Pipeline):
    def __init__(self, batch_size, device, device_id=0, shard_id=0, num_shards=1, num_threads=4, seed=0):
        super(MnistPipeline, self).__init__(
            batch_size, num_threads, device_id, seed)
        self.device = device
        self.reader = ops.Caffe2Reader(path=data_path, shard_id=shard_id, num_shards=num_shards, random_shuffle=True)
        self.decode = ops.ImageDecoder(
            device='mixed' if device == 'gpu' else 'cpu',
            output_type=types.GRAY)
        self.cmn = ops.CropMirrorNormalize(
            device=device,
            dtype=types.FLOAT,
            std=[0.3081 * 255],
            mean=[0.1307 * 255],
            output_layout="CHW")
        self.to_int64 = ops.Cast(dtype=types.INT64, device=device)

    def define_graph(self):
        inputs, labels = self.reader(name="Reader")
        images = self.decode(inputs)
        images = self.cmn(images)
        if self.device == "gpu":
            labels = labels.gpu()
        # PyTorch expects labels as INT64
        labels = self.to_int64(labels)

        return (images, labels)

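To see what the pipeline produces on its own (a standalone check of ours, not part of the original example; it assumes DALI_EXTRA_PATH is set as described above), it can be built and run directly:

[ ]:
# Illustrative only: build the pipeline and run a single iteration on the CPU.
pipe = MnistPipeline(BATCH_SIZE, device='cpu')
pipe.build()
images, labels = pipe.run()     # DALI TensorLists
print(images.as_array().shape)  # e.g. (64, 1, 28, 28)
print(labels.as_array().shape)  # e.g. (64, 1)
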
Now we are ready to modify the training class to use the DALI pipeline we have just defined. Because we want to integrate with PyTorch, we wrap our pipeline with a PyTorch DALI iterator, which can replace the native data loader with only minor changes in the code. The DALI iterator returns a list of dictionaries, where each element in the list corresponds to a pipeline instance, and the entries in each dictionary map to the outputs of the pipeline. For more information, check the documentation of DALIGenericIterator.
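
For instance, one batch drawn from a DALIClassificationIterator built on the pipeline above looks like this (an illustrative sketch, not part of the original example):

[ ]:
# Illustrative only: inspect the structure returned by the iterator.
pipe = MnistPipeline(BATCH_SIZE, device='cpu')
it = DALIClassificationIterator(pipe, reader_name="Reader",
                                fill_last_batch=False, auto_reset=True)
batch = next(iter(it))          # a list with one dict per pipeline instance
print(len(batch))               # 1 - a single pipeline here
print(batch[0]["data"].shape)   # torch.Size([64, 1, 28, 28])
print(batch[0]["label"].shape)  # torch.Size([64, 1])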

[5]:
class DALILitMNIST(LitMNIST):
    def __init__(self):
        super().__init__()

    def prepare_data(self):
        device_id = self.local_rank
        shard_id = self.global_rank
        num_shards = self.trainer.world_size
        mnist_pipeline = MnistPipeline(BATCH_SIZE, device='cpu', device_id=device_id, shard_id=shard_id,
                                       num_shards=num_shards, num_threads=8)
        self.train_loader = DALIClassificationIterator(mnist_pipeline, reader_name="Reader",
                                                       fill_last_batch=False, auto_reset=True)

    def train_dataloader(self):
        return self.train_loader

    def process_batch(self, batch):
        x = batch[0]["data"]
        y = batch[0]["label"].squeeze(-1).cuda().long()
        return (x, y)

We can now run the training.

[6]:
model = DALILitMNIST()
trainer = Trainer(gpus=1, distributed_backend="ddp", max_epochs=5)
trainer.fit(model)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using environment variable NODE_RANK for node rank (0).
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type   | Params
-----------------------------------
0 | layer_1 | Linear | 100 K
1 | layer_2 | Linear | 33 K
2 | layer_3 | Linear | 2 K

[6]:
1

For even better integration, we can provide a custom DALI iterator wrapper so that no extra processing is required inside LitMNIST.process_batch. Also, PyTorch Lightning can learn the size of the dataset this way.

[7]:
class BetterDALILitMNIST(LitMNIST):
    def __init__(self):
        super().__init__()

    def prepare_data(self):
        device_id = self.local_rank
        shard_id = self.global_rank
        num_shards = self.trainer.world_size
        mnist_pipeline = MnistPipeline(BATCH_SIZE, device='cpu', device_id=device_id, shard_id=shard_id,
                                       num_shards=num_shards, num_threads=8)

        class LightningWrapper(DALIClassificationIterator):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

            def __next__(self):
                out = super().__next__()
                # DDP is used, so there is only one pipeline per process.
                # We also need to turn the dict returned by DALIClassificationIterator
                # into an iterable and squeeze the labels.
                out = out[0]
                return [out[k] if k != "label" else torch.squeeze(out[k]) for k in self.output_map]

        self.train_loader = LightningWrapper(mnist_pipeline, reader_name="Reader",
                                             fill_last_batch=False, auto_reset=True)

    def train_dataloader(self):
        return self.train_loader

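The wrapper keeps the length reported by DALIClassificationIterator (which is how Lightning sees the dataset size) and reshapes each batch into a plain [images, labels] pair, so the base class' training_step can consume it unchanged. A rough standalone illustration of the same idea (hypothetical code of ours, since LightningWrapper above is defined inside prepare_data):

[ ]:
# Illustrative only: what each wrapped batch looks like to Lightning.
pipe = MnistPipeline(BATCH_SIZE, device='cpu')
loader = DALIClassificationIterator(pipe, reader_name="Reader",
                                    fill_last_batch=False, auto_reset=True)
print(len(loader))             # number of batches per epoch, taken from the reader
out = next(iter(loader))[0]    # dict with "data" and "label"
x, y = out["data"], torch.squeeze(out["label"])  # what LightningWrapper.__next__ returns
print(x.shape, y.shape)        # e.g. torch.Size([64, 1, 28, 28]) torch.Size([64])
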
Let us run the training one more time.

[8]:
model = BetterDALILitMNIST()
trainer = Trainer(gpus=1, distributed_backend="ddp", max_epochs=5)
trainer.fit(model)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using environment variable NODE_RANK for node rank (0).
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type   | Params
-----------------------------------
0 | layer_1 | Linear | 100 K
1 | layer_2 | Linear | 33 K
2 | layer_3 | Linear | 2 K

[8]:
1