PyTorch DALI Proxy#
Overview#
DALI Proxy is a tool designed to integrate NVIDIA DALI pipelines with PyTorch data workers while maintaining the simplicity of PyTorch’s dataset logic. The key features of DALI Proxy include:
Efficient GPU Utilization: DALI Proxy ensures GPU data processing occurs in the process running the main loop. This avoids performance degradation caused by multiple CUDA contexts for the same GPU.
Selective Offloading: Users can offload parts of the data processing pipeline to DALI while retaining PyTorch Dataset logic, making it ideal for multi-modal applications.
This tutorial will explain the key components, workflow, and usage of DALI Proxy in PyTorch.
Note
Disclaimer: At present, data produced by the DALI proxy cannot be further processed within the Dataset. It must be passed as-is to the main loop. If post-processing outside of DALI is needed, it should occur only after the data has been generated by the iterator.
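For example, a minimal sketch of where post-processing belongs (postprocess here is a hypothetical non-DALI transformation):

# Inside the Dataset, the proxy output must be returned as-is:
#     out = self.transform_fn(filename)   # reference to a future sample
#     return postprocess(out), label      # WRONG: `out` is not real data yet
# In the main loop, the iterator has already produced real tensors:
for data, target in loader:
    data = postprocess(data)  # OK: `data` is actual data here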
DALI Proxy Workflow#
Key Components
DALI Pipeline: A user-defined DALI pipeline that processes the input data.
DALI Server: Runs a background thread that executes the DALI pipeline asynchronously.
DALI Proxy: A callable interface between PyTorch data workers and the DALI Server.
PyTorch Dataset and DataLoader: The Dataset remains agnostic of DALI internals and uses the proxy for preprocessing.
Workflow Summary
1. A DALI pipeline is defined and connected to a DALI Server, which executes the pipeline in a background thread.
2. The DALI Proxy provides an interface for PyTorch data workers to request DALI processing asynchronously.
3. Each data worker invokes the proxy, which returns a reference to a future processed sample.
4. During batch collation, the proxy groups the data into a batch and sends it to the server for execution.
5. The server processes the batch asynchronously and places the actual data in an output queue.
6. The PyTorch DataLoader retrieves either the processed data or references to pending pipeline runs; the pending references are then replaced with the actual data, waiting for it if necessary (see the sketch below).
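In code, the reference-then-resolve flow looks roughly like this (a sketch; pipe and sample are placeholders, and the DALI proxy DataLoader described below performs the last step automatically):

from torch.utils.data.dataloader import default_collate
from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy

with dali_proxy.DALIServer(pipe) as dali_server:
    ref = dali_server.proxy(sample)          # a reference to a future sample, not data
    batch = default_collate([ref])           # collation sends the batch to the server
    batch = dali_server.produce_data(batch)  # waits for and returns the actual data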
API#
- class nvidia.dali.plugin.pytorch.experimental.proxy.DALIServer(pipeline, deterministic=False)#
- __enter__()#
Starts the DALI pipeline thread.
- __exit__(exc_type, exc_value, tb)#
Stops the DALI pipeline thread.
- __init__(pipeline, deterministic=False)#
Initializes a new DALI server instance.
- Parameters:
pipeline (Pipeline) – The DALI pipeline to run.
deterministic (bool) – If True, ensures that the order of execution is always the same, which is important when the pipeline has state and reproducible results are desired. Enabling it also makes execution less performant, because the DALI processing can be scheduled only after the data loader has returned the batch information, rather than as soon as a data worker collates the batch.
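For instance, a server with a reproducible execution order could be created as follows (a sketch; pipe stands for a user-defined pipeline instance):

# Trades some scheduling overlap for a reproducible order of execution
dali_server = dali_proxy.DALIServer(pipe, deterministic=True)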
Example 1 - Full integration with PyTorch via DALI proxy DataLoader:
@pipeline_def
def rn50_train_pipe():
    rng = fn.random.coin_flip(probability=0.5)
    filepaths = fn.external_source(name="images", no_copy=True)
    jpegs = fn.io.file.read(filepaths)
    images = fn.decoders.image_random_crop(
        jpegs,
        device="mixed",
        output_type=types.RGB,
        random_aspect_ratio=[0.75, 4.0 / 3.0],
        random_area=[0.08, 1.0],
    )
    images = fn.resize(
        images,
        size=[224, 224],
        interp_type=types.INTERP_LINEAR,
        antialias=False,
    )
    output = fn.crop_mirror_normalize(
        images,
        dtype=types.FLOAT,
        output_layout="CHW",
        crop=(224, 224),
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
        mirror=rng,
    )
    return output

def read_filepath(path):
    return np.frombuffer(path.encode(), dtype=np.int8)

nworkers = 8

pipe = rn50_train_pipe(
    batch_size=16, num_threads=3, device_id=0,
    prefetch_queue_depth=2 * nworkers)

# The scope makes sure the server starts and stops at enter/exit
with dali_proxy.DALIServer(pipe) as dali_server:
    # The DALI proxy instance can be used as a transform callable
    dataset = torchvision.datasets.ImageFolder(
        jpeg, transform=dali_server.proxy, loader=read_filepath)

    # Same interface as the torch DataLoader, but takes a dali_server as first argument
    loader = nvidia.dali.plugin.pytorch.experimental.proxy.DataLoader(
        dali_server,
        dataset,
        batch_size=batch_size,
        num_workers=nworkers,
        drop_last=True,
    )

    for data, target in loader:
        # consume it
        ...
Example 2 - Manual execution using DALI proxy / DALI server and PyTorch’s default_collate:
@pipeline_def
def my_pipe():
    a = fn.external_source(name="a", no_copy=True)
    b = fn.external_source(name="b", no_copy=True)
    return a + b, a - b

with dali_proxy.DALIServer(
        my_pipe(batch_size=batch_size, num_threads=3,
                device_id=None)) as dali_server:

    outs = []
    for _ in range(batch_size):
        a = np.array(np.random.rand(3, 3), dtype=np.float32)
        b = np.array(np.random.rand(3, 3), dtype=np.float32)
        out0, out1 = dali_server.proxy(a=a, b=b)
        outs.append((a, b, out0, out1))

    outs = torch.utils.data.dataloader.default_collate(outs)
    a, b, a_plus_b, a_minus_b = dali_server.produce_data(outs)
Example 3 - Full integration with PyTorch, but using the original PyTorch DataLoader:
pipe = rn50_train_pipe(...)

with dali_proxy.DALIServer(pipe) as dali_server:
    dataset = torchvision.datasets.ImageFolder(
        jpeg, transform=dali_server.proxy, loader=read_filepath)

    # Using the PyTorch DataLoader directly
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=nworkers,
        drop_last=True,
    )

    for data, target in loader:
        # Replaces the output references with the actual data
        data = dali_server.produce_data(data)
        ...
- produce_data(obj)#
A generic function that recursively visits all elements in a nested structure and replaces instances of DALIOutputBatchRef with the actual data provided by the DALI server. See
nvidia.dali.plugin.pytorch.experimental.proxy.DALIServer
for a full example.
- Parameters:
obj – The object to map (can be an instance of any class).
- Returns:
A new object where any instance of DALIOutputBatchRef has been replaced with actual data.
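For illustration, produce_data can be applied to arbitrarily nested containers (a sketch; batch_ref and labels are hypothetical placeholders for a collated proxy output and a tensor of labels):

# Nested structures are traversed recursively; only instances of
# DALIOutputBatchRef are replaced, everything else is returned unchanged
out = dali_server.produce_data({"image": batch_ref, "label": labels})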
- start_thread()#
Starts the DALI pipeline thread. Note: using the scope's __enter__/__exit__ (a with statement) is preferred.
- stop_thread()#
Stops the DALI pipeline thread. Note: using the scope's __enter__/__exit__ (a with statement) is preferred.
- class nvidia.dali.plugin.pytorch.experimental.proxy.DataLoader(*args, **kwargs)#
DALI data loader to be used in the main loop, which replaces the pipeline run references with actual data produced by the DALI server. See
nvidia.dali.plugin.pytorch.experimental.proxy.DALIServer
for a full example.
- __init__(dali_server, *args, **kwargs)#
Same interface as PyTorch's DataLoader, except for the extra dali_server argument.
Example Usage#
DALI Proxy in a Nutshell#
from torchvision import datasets
from nvidia.dali import pipeline_def, fn, types
from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy

# Step 1: Define a DALI pipeline
@pipeline_def
def my_dali_pipeline():
    images = fn.external_source(name="images", no_copy=True)
    images = fn.resize(images, size=[224, 224])
    return fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, output_layout="CHW",
        mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],
        std=[0.229 * 255, 0.224 * 255, 0.225 * 255],
    )

# Step 2: Initialize the DALI server. The scope makes sure the background thread is started and stopped
with dali_proxy.DALIServer(
        my_dali_pipeline(batch_size=64, num_threads=3, device_id=0)) as dali_server:

    # Step 3: Define a PyTorch Dataset that uses the DALI proxy as its transform
    dataset = datasets.ImageFolder("/path/to/images", transform=dali_server.proxy)

    # Step 4: Use the DALI proxy DataLoader
    loader = dali_proxy.DataLoader(
        dali_server, dataset, batch_size=64, num_workers=8, drop_last=True)

    # Step 5: Consume the data
    for data, target in loader:
        print(data.shape)  # Processed data, ready to use
How It Works#
1. DALI Pipeline
The DALI pipeline defines the data processing steps. Input data is fed using external_source().
from nvidia.dali import pipeline_def, fn, types

@pipeline_def
def example_pipeline():
    filepaths = fn.external_source(name="images", no_copy=True)
    jpegs = fn.io.file.read(filepaths)
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
    return fn.resize(images, size=[224, 224])

pipeline = example_pipeline(batch_size=32, num_threads=2, device_id=0)
2. DALI Server and Proxy
The nvidia.dali.plugin.pytorch.experimental.proxy.DALIServer manages the execution of the pipeline, and the proxy acts as an interface for PyTorch data workers.
Note that the DALI pipeline should contain at least one input (an external_source() instance); the names of those nodes then become the inputs to the DALI proxy callable.
from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy

with dali_proxy.DALIServer(pipeline) as dali_server:
    future_samples = [dali_server.proxy(image) for image in images]
With more than one input, we can use either positional or keyword arguments:
import numpy as np
from nvidia.dali import pipeline_def, fn, types
from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy

@pipeline_def
def example_pipeline2():
    a = fn.external_source(name="a", no_copy=True)
    b = fn.external_source(name="b", no_copy=True)
    return a + b, b - a

with dali_proxy.DALIServer(example_pipeline2(...)) as dali_server:
    a = np.array(...)
    b = np.array(...)

    # Option 1: positional arguments
    a_plus_b, b_minus_a = dali_server.proxy(a, b)

    # Option 2: named arguments
    a_plus_b, b_minus_a = dali_server.proxy(b=b, a=a)
It is also possible to start and stop the server explicitly:
dali_server = dali_proxy.DALIServer(example_pipeline2(...))

dataset = datasets.ImageFolder("/path/to/images", transform=dali_server.proxy)
loader = dali_proxy.DataLoader(
    dali_server, dataset, batch_size=64, num_workers=8, drop_last=True)

# Optional: the thread will be started on the first attempt to get data from the loader anyway
dali_server.start_thread()

for data in loader:
    ...

# This is needed to make sure the background thread is stopped
dali_server.stop_thread()
When possible, use the with scope.
3. Integration with PyTorch DataLoader
The nvidia.dali.plugin.pytorch.experimental.proxy.DataLoader wrapper provided by DALI Proxy simplifies the integration process.
from nvidia.dali.plugin.pytorch.experimental import proxy as dali_proxy
with dali_proxy.DALIServer(pipeline) as dali_server:
    dataset = CustomDataset(dali_server.proxy, data=images)
    loader = dali_proxy.DataLoader(dali_server, dataset, batch_size=32, num_workers=4)

    for data, _ in loader:
        print(data.shape)  # Ready-to-use processed batch
If using a custom data loader instead of nvidia.dali.plugin.pytorch.experimental.proxy.DataLoader, call the DALI server explicitly to materialize the data:
with dali_proxy.DALIServer(pipeline) as dali_server:
    dataset = CustomDataset(dali_server.proxy, data=images)
    loader = MyCustomDataloader(...)

    for data, _ in loader:
        # Replaces instances of ``DALIOutputBatchRef`` with actual data
        processed_data = dali_server.produce_data(data)
        print(processed_data.shape)  # Data is now ready
4. Integration with PyTorch Dataset
The PyTorch Dataset can directly use the proxy as a transform function. Note that we can choose to offload only part of the processing to DALI, while keeping some of the original data intact.
import torch

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, transform_fn, data):
        self.data = data
        self.transform_fn = transform_fn

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        filename, label = self.data[idx]
        # Returns the processed sample and the original label
        return self.transform_fn(filename), label
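A hypothetical way to wire this dataset to the proxy (the file names and labels are placeholders):

with dali_proxy.DALIServer(pipeline) as dali_server:
    data = [("img0.jpg", 0), ("img1.jpg", 1)]
    dataset = CustomDataset(transform_fn=dali_server.proxy, data=data)
    sample_ref, label = dataset[0]  # sample_ref is a pending reference; label is untouched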
5. Data Collation and Execution
This step is usually abstracted away inside the PyTorch DataLoader, so the user does not need to handle it explicitly.
The default_collate function combines the processed samples into a batch; DALI executes the pipeline asynchronously once a batch is collated.
from torch.utils.data.dataloader import default_collate

with dali_proxy.DALIServer(example_pipeline2(...)) as dali_server:
    outs = []
    for _ in range(10):
        a = np.array(np.random.rand(3, 3), dtype=np.float32)
        b = np.array(np.random.rand(3, 3), dtype=np.float32)
        a_plus_b, b_minus_a = dali_server.proxy(a, b)
        outs.append((a_plus_b, b_minus_a))

    # Collate into a single batch run reference
    outs = default_collate(outs)

    # And we can now replace the run reference with actual data
    outs = dali_server.produce_data(outs)
Summary#
DALI Proxy provides a clean and efficient way to integrate NVIDIA DALI with PyTorch. By offloading the computationally intensive parts of data processing to DALI while keeping PyTorch's Dataset and DataLoader interfaces intact, it combines flexibility with high performance. This approach is particularly powerful in large-scale data pipelines and multi-modal workflows.