PyTorch DALI Proxy

Overview

DALI Proxy is a tool designed to integrate NVIDIA DALI pipelines with PyTorch data workers while maintaining the simplicity of PyTorch’s dataset logic. The key features of DALI Proxy include:

  • Efficient GPU Utilization: DALI Proxy ensures GPU data processing occurs in the process running the main loop. This avoids performance degradation caused by multiple CUDA contexts for the same GPU.

  • Selective Offloading: Users can offload parts of the data processing pipeline to DALI while retaining PyTorch Dataset logic, making it ideal for multi-modal applications.

This tutorial explains the key components, workflow, and usage of DALI Proxy in PyTorch.

Note

Disclaimer: At present, the output of the DALI proxy cannot be processed further within the Dataset. It is only a reference to data that has not been produced yet, so it must be passed as-is to the main loop. Any post-processing outside of DALI must happen after the iterator has produced the actual data.

DALI Proxy Workflow

Key Components

  1. DALI Pipeline: A user-defined DALI pipeline processes input data.

  2. DALI Server: The server runs a background thread to execute the DALI pipeline asynchronously.

  3. DALI Proxy: A callable interface between PyTorch data workers and the DALI Server.

  4. PyTorch Dataset and DataLoader: The Dataset remains agnostic of DALI internals and uses the Proxy for preprocessing.

Workflow Summary

  • A DALI pipeline is defined and connected to a DALI Server, which executes the pipeline in a background thread.

  • The DALI Proxy provides an interface for PyTorch data workers to request DALI processing asynchronously.

  • Each data worker invokes the proxy, which returns a reference to a future processed sample.

  • During batch collation, the proxy groups data into a batch and sends it to the server for execution.

  • The server processes the batch asynchronously and outputs the actual data to an output queue.

  • The PyTorch DataLoader retrieves either the processed data or references to pending pipeline runs. The pending pipeline run references are then replaced with actual data, waiting for the data if necessary.
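The workflow above can be modeled in plain Python. The sketch below is a toy and not DALI's implementation. ToyServer, SampleRef, BatchRef and the toy collate method are hypothetical names, and a list comprehension stands in for the DALI pipeline. It shows only the shape of the protocol: proxy calls return references immediately, collation schedules one run per batch, a background thread executes that run, and produce_data swaps the reference for real data, blocking until the data is ready.

```python
import itertools
import queue
import threading
from dataclasses import dataclass


@dataclass(frozen=True)
class SampleRef:
    """Hypothetical reference to one sample that has not been processed yet."""
    index: int


@dataclass(frozen=True)
class BatchRef:
    """Hypothetical reference to a pending pipeline run (one batch)."""
    run_id: int


class ToyServer:
    """Toy model of the proxy/server protocol; not DALI's implementation."""

    def __init__(self, process_fn):
        self._process_fn = process_fn  # stands in for the DALI pipeline
        self._pending = []             # inputs recorded by proxy calls
        self._runs = queue.Queue()     # scheduled runs for the background thread
        self._results = {}             # run_id -> processed batch
        self._ready = threading.Condition()
        self._run_ids = itertools.count()
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def proxy(self, sample):
        # Record the input and return a reference immediately, without processing.
        self._pending.append(sample)
        return SampleRef(len(self._pending) - 1)

    def collate(self, refs):
        # Group the referenced inputs into one batch and schedule a single run.
        batch = [self._pending[ref.index] for ref in refs]
        self._pending = []
        run_id = next(self._run_ids)
        self._runs.put((run_id, batch))
        return BatchRef(run_id)

    def _loop(self):
        # Background thread: execute each run as soon as it is scheduled.
        while True:
            run_id, batch = self._runs.get()
            if run_id is None:  # shutdown signal
                return
            outputs = [self._process_fn(sample) for sample in batch]
            with self._ready:
                self._results[run_id] = outputs
                self._ready.notify_all()

    def produce_data(self, batch_ref):
        # Main loop: replace the reference with real data, waiting if necessary.
        with self._ready:
            self._ready.wait_for(lambda: batch_ref.run_id in self._results)
            return self._results.pop(batch_ref.run_id)

    def stop(self):
        self._runs.put((None, None))
        self._thread.join()


server = ToyServer(lambda x: x * 10)
refs = [server.proxy(i) for i in range(4)]  # what the data workers would return
batch_ref = server.collate(refs)            # what collation would produce
result = server.produce_data(batch_ref)     # what the main loop finally sees
server.stop()
print(result)  # [0, 10, 20, 30]
```

For brevity, the toy runs everything in one process. In the real setup, the proxy is called from PyTorch data worker processes.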

API

class nvidia.dali.plugin.pytorch.experimental.proxy.DALIServer(pipeline, deterministic=False)
__enter__()

Starts the DALI pipeline thread.

__exit__(exc_type, exc_value, tb)

Stops the DALI pipeline thread.

__init__(pipeline, deterministic=False)

Initializes a new DALI server instance.

Parameters:
  • pipeline (Pipeline) – DALI pipeline to run.

  • deterministic (bool) – If True, pipeline runs are always executed in the same order. This matters when the pipeline is stateful and reproducible results are needed. Enabling it also reduces performance, because DALI processing can be scheduled only after the data loader has returned the batch information, rather than as soon as a data worker collates the batch.
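To see why deterministic matters for a stateful pipeline, consider a toy model in which a single seeded random generator advances once per executed batch. This is not DALI's scheduler. The hypothetical function run_stateful_pipeline records which random value each batch id received. If batches reach the executor in whatever order workers finish collating them, that assignment changes from run to run. Executing them in the order the loader returns them makes it reproducible.

```python
import random


def run_stateful_pipeline(schedule, seed=0):
    """Toy stateful 'pipeline': one shared RNG, advanced once per executed batch.

    `schedule` is the order in which batch ids reach the executor; the result maps
    each batch id to the random value it received, so it depends on that order.
    """
    rng = random.Random(seed)
    return {batch_id: rng.random() for batch_id in schedule}


# Without determinism (toy): batches run in the order workers finish collating
# them, which can differ between two otherwise identical epochs.
arrival_a = [1, 0, 3, 2]
arrival_b = [0, 2, 1, 3]
nondet_a = run_stateful_pipeline(arrival_a)
nondet_b = run_stateful_pipeline(arrival_b)

# With determinism (toy): batches run in the order the loader returns them.
det_a = run_stateful_pipeline(sorted(arrival_a))
det_b = run_stateful_pipeline(sorted(arrival_b))

print(nondet_a == nondet_b, det_a == det_b)  # False True
```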

Example 1 - Full integration with PyTorch via DALI proxy DataLoader:

import numpy as np
import torchvision
from nvidia.dali import pipeline_def, fn, types
from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy

@pipeline_def
def rn50_train_pipe():
    rng = fn.random.coin_flip(probability=0.5)
    filepaths = fn.external_source(name="images", no_copy=True)
    jpegs = fn.io.file.read(filepaths)
    images = fn.decoders.image_random_crop(
        jpegs,
        device="mixed",
        output_type=types.RGB,
        random_aspect_ratio=[0.75, 4.0 / 3.0],
        random_area=[0.08, 1.0],
    )
    images = fn.resize(
        images,
        size=[224, 224],
        interp_type=types.INTERP_LINEAR,
        antialias=False,
    )
    output = fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        output_layout="CHW",
        crop=(224, 224),
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
        mirror=rng,
    )
    return output

def read_filepath(path):
    # Pass the file path to DALI as raw bytes; fn.io.file.read opens the file
    return np.frombuffer(path.encode(), dtype=np.int8)

nworkers = 8
batch_size = 16  # must not exceed the pipeline batch size below
jpeg = "/path/to/images"  # placeholder: root of an ImageFolder-style dataset
pipe = rn50_train_pipe(
    batch_size=batch_size, num_threads=3, device_id=0,
    prefetch_queue_depth=2*nworkers)

# The scope makes sure the server starts and stops at enter/exit
with dali_proxy.DALIServer(pipe) as dali_server:
    # DALI proxy instance can be used as a transform callable
    dataset = torchvision.datasets.ImageFolder(
        jpeg, transform=dali_server.proxy, loader=read_filepath)

    # Same interface as torch DataLoader, but takes a dali_server as first argument
    loader = dali_proxy.DataLoader(
        dali_server,
        dataset,
        batch_size=batch_size,
        num_workers=nworkers,
        drop_last=True,
    )

    for data, target in loader:
        # data is the batch produced by DALI; consume it here
        ...
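The read_filepath helper hands each path to the pipeline as the raw UTF-8 bytes of the string, viewed as int8, and fn.io.file.read then reads that file. The stdlib-only sketch below uses array in place of NumPy, and encode_path/decode_path are hypothetical helper names. It shows that this encoding is lossless even for non-ASCII paths, whose bytes simply appear as negative int8 values.

```python
from array import array


def encode_path(path: str) -> array:
    # Same idea as read_filepath: the UTF-8 bytes of the path viewed as int8.
    return array("b", path.encode())


def decode_path(encoded: array) -> str:
    # Reinterpret the signed bytes as raw bytes and decode them back to text.
    return encoded.tobytes().decode()


path = "/data/train/n01440764/ünïcode.JPEG"  # hypothetical file path
encoded = encode_path(path)
print(min(encoded) < 0)              # True: non-ASCII bytes are negative as int8
print(decode_path(encoded) == path)  # True: the round trip is lossless
```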

Example 2 - Manual execution using DALI proxy / DALI server and PyTorch’s default_collate:

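import numpy as np
import torch
from nvidia.dali import pipeline_def, fn
from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy

batch_size = 4

@pipeline_def
def my_pipe():
    a = fn.external_source(name="a", no_copy=True)
    b = fn.external_source(name="b", no_copy=True)
    return a + b, a - b

# device_id=None creates a CPU-only pipeline
with dali_proxy.DALIServer(
    my_pipe(batch_size=batch_size,
            num_threads=3, device_id=None)) as dali_server:

    outs = []
    for _ in range(batch_size):
        a = np.array(np.random.rand(3, 3), dtype=np.float32)
        b = np.array(np.random.rand(3, 3), dtype=np.float32)
        # Each call returns references to the future outputs
        out0, out1 = dali_server.proxy(a=a, b=b)
        outs.append((a, b, out0, out1))

    # Collating the references sends the whole batch to the server for execution
    outs = torch.utils.data.dataloader.default_collate(outs)

    # Replace the references with the actual outputs (a + b and a - b)
    a, b, a_plus_b, a_minus_b = dali_server.produce_data(outs)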

Example 3 - Full integration with PyTorch, but using the original PyTorch DataLoader:

pipe = rn50_train_pipe(...)
with dali_proxy.DALIServer(pipe) as dali_server:
    dataset = torchvision.datasets.ImageFolder(
        jpeg, transform=dali_server.proxy, loader=read_filepath)

    # Using PyTorch DataLoader directly
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=nworkers,
        drop_last=True,
    )

    for data, target in loader:
        # replaces the output reference with actual data
        data = dali_server.produce_data(data)
        ...
produce_data(obj)

Recursively visits all elements of a nested structure and replaces instances of DALIOutputBatchRef with the actual data provided by the DALI server. See nvidia.dali.plugin.pytorch.experimental.proxy.DALIServer for a full example.

Parameters:

obj – The object to map (can be an instance of any class).

Returns:

A new object where any instance of DALIOutputBatchRef has been replaced with actual data.
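The traversal performed by produce_data can be pictured as a small recursive function. This is a sketch and not the library's code. PendingRef is a hypothetical placeholder for DALIOutputBatchRef, and lookup stands in for fetching the server's output. The sketch handles dicts, lists and plain tuples and returns every other object unchanged.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PendingRef:
    """Hypothetical placeholder standing in for DALIOutputBatchRef."""
    key: str


def replace_refs(obj, lookup):
    """Return a copy of `obj` with every PendingRef swapped for lookup(ref)."""
    if isinstance(obj, PendingRef):
        return lookup(obj)
    if isinstance(obj, dict):
        return {k: replace_refs(v, lookup) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Rebuild the same container type (plain lists and tuples only)
        return type(obj)(replace_refs(v, lookup) for v in obj)
    return obj  # labels, strings, numbers, ... are left untouched


results = {"images": [[1, 2], [3, 4]]}
batch = {"x": PendingRef("images"), "meta": ("label", [PendingRef("images"), 7])}
out = replace_refs(batch, lambda ref: results[ref.key])
print(out)  # {'x': [[1, 2], [3, 4]], 'meta': ('label', [[[1, 2], [3, 4]], 7])}
```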

start_thread()

Starts the DALI pipeline thread. Note: using the with scope (__enter__/__exit__) is preferred.

stop_thread()

Stops the DALI pipeline thread. Note: using the with scope (__enter__/__exit__) is preferred.
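The with scope is preferred because __exit__ runs even when the loop body raises, so the thread is always stopped. Below is a minimal sketch using a hypothetical ToyServer that mirrors only the start_thread/stop_thread/context-manager shape of DALIServer. Its thread does nothing except wait for a stop event.

```python
import threading


class ToyServer:
    """Hypothetical stand-in mirroring only DALIServer's start/stop shape."""

    def __init__(self):
        self._stop_event = threading.Event()
        self._thread = None

    def start_thread(self):
        if self._thread is None:  # starting an already running toy is a no-op
            self._stop_event.clear()
            self._thread = threading.Thread(target=self._stop_event.wait, daemon=True)
            self._thread.start()

    def stop_thread(self):
        if self._thread is not None:
            self._stop_event.set()
            self._thread.join()
            self._thread = None

    @property
    def running(self):
        return self._thread is not None and self._thread.is_alive()

    def __enter__(self):
        self.start_thread()
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.stop_thread()  # runs on normal exit and when the body raises
        return False        # do not swallow the exception


server = ToyServer()
was_running = False
try:
    with server:
        was_running = server.running
        raise RuntimeError("training loop failed")
except RuntimeError:
    pass
print(was_running, server.running)  # True False
```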

class nvidia.dali.plugin.pytorch.experimental.proxy.DataLoader(*args, **kwargs)

DALI data loader to be used in the main loop, which replaces the pipeline run references with actual data produced by the DALI server. See nvidia.dali.plugin.pytorch.experimental.proxy.DALIServer for a full example.

__init__(dali_server, *args, **kwargs)

Same interface as PyTorch’s DataLoader, except for the additional dali_server argument, which comes first.
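Conceptually, this wrapper does what Example 3 does by hand: it resolves every batch with the server before handing it to the training loop. The class below is a simplified, hypothetical illustration of that idea only. ProducingLoader and toy_produce_data are not part of DALI, and any callable can play the role of produce_data.

```python
class ProducingLoader:
    """Hypothetical wrapper that resolves references in every batch it yields."""

    def __init__(self, produce_data, loader):
        self._produce_data = produce_data  # e.g. dali_server.produce_data
        self._loader = loader              # any iterable of batches

    def __iter__(self):
        for batch in self._loader:
            yield self._produce_data(batch)

    def __len__(self):
        return len(self._loader)


def toy_produce_data(batch):
    # Stand-in resolver: pretend each "ref-N" string is replaced by real data.
    ref, target = batch
    return ref.replace("ref", "data"), target


loader = ProducingLoader(toy_produce_data, [("ref-0", 0), ("ref-1", 1)])
print(list(loader))  # [('data-0', 0), ('data-1', 1)]
```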

Example Usage

DALI Proxy in a Nutshell

import numpy as np
from torchvision import datasets
from nvidia.dali import pipeline_def, fn, types
from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy

def read_filepath(path):
    # Hand the file path to DALI as raw bytes instead of decoding the image in Python
    return np.frombuffer(path.encode(), dtype=np.int8)

# Step 1: Define a DALI pipeline
@pipeline_def
def my_dali_pipeline():
    filepaths = fn.external_source(name="images", no_copy=True)
    jpegs = fn.io.file.read(filepaths)
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
    images = fn.resize(images, size=[224, 224])
    return fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, output_layout="CHW",
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
    )

# Step 2: Initialize DALI server. The scope makes sure to start and stop the background thread
with dali_proxy.DALIServer(my_dali_pipeline(batch_size=64, num_threads=3, device_id=0)) as dali_server:
    # Step 3: Define a PyTorch Dataset using the DALI proxy
    dataset = datasets.ImageFolder("/path/to/images", transform=dali_server.proxy, loader=read_filepath)

    # Step 4: Use DALI proxy DataLoader
    loader = dali_proxy.DataLoader(dali_server, dataset, batch_size=64, num_workers=8, drop_last=True)

    # Step 5: Consume data
    for data, target in loader:
        print(data.shape)  # Processed data ready

How It Works

1. DALI Pipeline

The DALI pipeline defines the data processing steps. Input data is fed using external_source().

from nvidia.dali import pipeline_def, fn, types

@pipeline_def
def example_pipeline():
    images = fn.external_source(name="images", no_copy=True)
    images = fn.io.file.read(images)
    images = fn.decoders.image(images, device="mixed", output_type=types.RGB)
    return fn.resize(images, size=[224, 224])

pipeline = example_pipeline(batch_size=32, num_threads=2, device_id=0)

2. DALI Server and Proxy

The nvidia.dali.plugin.pytorch.experimental.proxy.DALIServer manages the execution of the pipeline. The Proxy acts as an interface for PyTorch data workers. Note that the DALI pipeline should contain at least one input (an external_source() instance), and that the names of those nodes then become the inputs to the DALI proxy callable.

from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy
with dali_proxy.DALIServer(pipeline) as dali_server:
    # images: file paths encoded as int8 arrays, e.g. with read_filepath
    future_samples = [dali_server.proxy(image) for image in images]

With more than one input, we can use either positional or keyword arguments:

import numpy as np
from nvidia.dali import pipeline_def, fn, types
from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy

@pipeline_def
def example_pipeline2(device):
   a = fn.external_source(name="a", no_copy=True)
   b = fn.external_source(name="b", no_copy=True)
   return a + b, b - a

with dali_proxy.DALIServer(example_pipeline2(...)) as dali_server:
   a = np.array(...)
   b = np.array(...)

   # Option 1: positional arguments
   a_plus_b, b_minus_a = dali_server.proxy(a, b)

   # Option 2: named arguments
   a_plus_b, b_minus_a = dali_server.proxy(b=b, a=a)
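The way proxy arguments are matched to external_source names can be sketched as an argument-binding function. This is a toy model built around a hypothetical bind_inputs helper. In particular, the rule that positional arguments follow the order of the listed input names is an assumption of the sketch and not documented DALI behavior.

```python
def bind_inputs(input_names, *args, **kwargs):
    """Map proxy-style call arguments onto named pipeline inputs.

    Assumption (toy model): positional arguments follow the order of `input_names`.
    """
    if len(args) > len(input_names):
        raise TypeError(f"expected at most {len(input_names)} positional inputs")
    bound = dict(zip(input_names, args))
    for name, value in kwargs.items():
        if name not in input_names:
            raise TypeError(f"unknown input {name!r}")
        if name in bound:
            raise TypeError(f"input {name!r} given twice")
        bound[name] = value
    missing = [n for n in input_names if n not in bound]
    if missing:
        raise TypeError(f"missing inputs: {missing}")
    return bound


names = ("a", "b")  # the external_source names used in example_pipeline2
print(bind_inputs(names, 1, 2) == bind_inputs(names, b=2, a=1))  # True
```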

It is also possible to start and stop the server explicitly:

dali_server = dali_proxy.DALIServer(pipeline)
dataset = datasets.ImageFolder("/path/to/images", transform=dali_server.proxy, loader=read_filepath)
loader = dali_proxy.DataLoader(dali_server, dataset, batch_size=32, num_workers=8, drop_last=True)

# Optional: the thread is started on the first attempt to get data from the loader anyway
dali_server.start_thread()

try:
    for data, target in loader:
        ...
finally:
    # This is needed to make sure the thread is stopped, even if the loop raises
    dali_server.stop_thread()

When possible, use the with scope.

3. Integration with PyTorch DataLoader

The nvidia.dali.plugin.pytorch.experimental.proxy.DataLoader wrapper provided by DALI Proxy simplifies the integration process.

from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy

with dali_proxy.DALIServer(pipeline) as dali_server:
   dataset = CustomDataset(dali_server.proxy, data=images)
   loader = dali_proxy.DataLoader(dali_server, dataset, batch_size=32, num_workers=4)
   for data, _ in loader:
      print(data.shape)  # Ready-to-use processed batch

If you use a data loader other than nvidia.dali.plugin.pytorch.experimental.proxy.DataLoader (for example, PyTorch’s own DataLoader or a custom one), call dali_server.produce_data() explicitly:

with dali_proxy.DALIServer(pipeline) as dali_server:
   dataset = CustomDataset(dali_server.proxy, data=images)
   loader = MyCustomDataloader(...)
   for data, _ in loader:
      # Replaces instances of DALIOutputBatchRef with actual data
      processed_data = dali_server.produce_data(data)
      print(processed_data.shape)  # data is now ready

4. Integration with PyTorch Dataset

The PyTorch Dataset can directly use the proxy as a transform function. Note that we can choose to offload only part of the processing to DALI, while keeping some of the original data intact.

import torch

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, transform_fn, data):
        self.data = data
        self.transform_fn = transform_fn

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        filename, label = self.data[idx]
        return self.transform_fn(filename), label  # Returns processed sample and the original label
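The same pattern extends to multi-modal data: only the image goes through the proxy, and the other fields are processed in Python. CaptionedImageDataset and the stand-in fake_proxy below are hypothetical. The class implements the Dataset protocol (__len__/__getitem__) without importing torch so that it runs on its own. In line with the disclaimer above, the proxy's output is returned unchanged.

```python
class CaptionedImageDataset:
    """Hypothetical multi-modal dataset following the torch Dataset protocol."""

    def __init__(self, transform_fn, samples):
        self.transform_fn = transform_fn  # e.g. dali_server.proxy
        self.samples = samples            # list of (image_path, caption, label)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, caption, label = self.samples[idx]
        tokens = caption.lower().split()  # text stays in Python, not DALI
        # The transform output is returned as-is; it must not be post-processed here.
        return self.transform_fn(path), tokens, label


def fake_proxy(path):
    # Stand-in for dali_server.proxy so the sketch runs without DALI.
    return f"<ref to {path}>"


ds = CaptionedImageDataset(fake_proxy, [("cat.jpg", "A Cat on a mat", 3)])
print(len(ds), ds[0])
# 1 ('<ref to cat.jpg>', ['a', 'cat', 'on', 'a', 'mat'], 3)
```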

5. Data Collation and Execution

This step is usually abstracted away inside the PyTorch DataLoader and the user doesn’t need to take care of it explicitly. The default_collate function combines processed samples into a batch. DALI executes the pipeline asynchronously when a batch is collated.

import numpy as np
from torch.utils.data.dataloader import default_collate

with dali_proxy.DALIServer(example_pipeline2(...)) as dali_server:
   outs = []
   for _ in range(10):
      a = np.array(np.random.rand(3, 3), dtype=np.float32)
      b = np.array(np.random.rand(3, 3), dtype=np.float32)
      a_plus_b, b_minus_a = dali_server.proxy(a, b)
      outs.append((a_plus_b, b_minus_a))

   # Collate into a single batch run reference
   outs = default_collate(outs)

   # And we can now replace the run reference with actual data
   outs = dali_server.produce_data(outs)

Summary#

DALI Proxy provides a clean and efficient way to integrate NVIDIA DALI with PyTorch. It offloads computationally intensive processing to DALI, keeps GPU work in the main process, and leaves PyTorch’s Dataset and DataLoader interfaces intact, which combines flexibility with high performance. This approach is particularly useful for large-scale data pipelines and multi-modal workflows.