Python Operators#
This example shows how to run custom Python code with the family of DALI python_function operators, which can be used to prototype new augmentations or to debug a pipeline. The idea behind these operators is to let you execute Python code that operates on DALI tensor data as part of the pipeline execution.
Defining an Operation#
The operator that we will use first is python_function, which wraps a regular Python function and runs it in a DALI pipeline. As an example, we define a function called edit_images.
[1]:
from nvidia.dali.pipeline import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import numpy as np


def edit_images(image1, image2):
    assert image1.shape == image2.shape
    h, w, c = image1.shape
    y, x = np.ogrid[0:h, 0:w]
    # Mask selecting everything outside a circle centered in the image
    mask = (x - w / 2) ** 2 + (y - h / 2) ** 2 > h * w / 9
    # Copy the inputs so that the original data is not modified in place
    result1 = np.copy(image1)
    result1[mask] = image2[mask]
    result2 = np.copy(image2)
    result2[mask] = image1[mask]
    return result1, result2
In this case, it takes two arrays as inputs and returns two outputs.
The code creates a circular mask and uses it to swap the circular parts between the two inputs.
python_function uses NumPy arrays as the data format on the CPU and CuPy arrays on the GPU.
Note: Both input images are copied, because the input data should not be modified.
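Since the CPU variant receives plain NumPy arrays, a function like edit_images can be prototyped and tested outside of DALI before it is put into a pipeline. A minimal standalone sketch (the random inputs and the 300x300 shape are arbitrary, chosen only for illustration):

import numpy as np

# Feed edit_images two random HWC uint8 arrays - the same kind of data
# the function receives from python_function on the CPU backend.
a = np.random.randint(0, 256, size=(300, 300, 3), dtype=np.uint8)
b = np.random.randint(0, 256, size=(300, 300, 3), dtype=np.uint8)

out1, out2 = edit_images(a, b)
assert out1.shape == a.shape and out2.shape == b.shape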
Warning
When the pipeline has conditional execution enabled, additional steps must be taken to prevent the function from being rewritten by AutoGraph. There are two ways to achieve this:
1. Define the function at global scope (i.e., outside of the pipeline_def scope).
2. If the function is the result of another "factory" function, the factory function must be decorated with nvidia.dali.pipeline.do_not_convert.
More details can be found in the nvidia.dali.pipeline.do_not_convert documentation.
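For the second option, a minimal sketch of what a decorated factory might look like (the brightness_factory name and its gain logic are invented for illustration; see the do_not_convert documentation for the authoritative usage):

import numpy as np
from nvidia.dali.pipeline import do_not_convert


@do_not_convert
def brightness_factory(gain):
    # Decorating the factory keeps AutoGraph from rewriting the returned
    # closure when conditional execution is enabled.
    def adjust(image):
        return np.clip(image.astype(np.float32) * gain, 0, 255).astype(np.uint8)

    return adjust


# The closure produced by the factory is what gets passed to the operator, e.g.:
# out = fn.python_function(images, function=brightness_factory(1.2))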
Defining a Pipeline#
To see the operator in action, we implement a simple data pipeline:
1. Load, decode, and resize the images to a common size.
2. Wrap edit_images by passing it as the function parameter to dali.fn.python_function. In addition to the function, we pass the number of outputs as a parameter.
3. Invoke python_function like any other DALI operator - the inputs will be passed to edit_images for processing.
[2]:
image_dir = "../data/images"
batch_size = 4


@pipeline_def(batch_size=batch_size, num_threads=4, device_id=0, seed=99)
def pipeline_fn():
    input1, _ = fn.readers.file(file_root=image_dir, random_shuffle=True)
    input2, _ = fn.readers.file(file_root=image_dir, random_shuffle=True)
    im1, im2 = fn.decoders.image(
        [input1, input2], device="cpu", output_type=types.RGB
    )
    res1, res2 = fn.resize([im1, im2], resize_x=300, resize_y=300)
    out1, out2 = fn.python_function(
        res1, res2, function=edit_images, num_outputs=2
    )
    return out1, out2
Running the Pipeline and Visualizing the Results#
To see the results, run the pipeline.
[3]:
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

%matplotlib inline


def show_images(image_batch):
    columns = 4
    rows = (batch_size + 1) // columns
    fig = plt.figure(figsize=(32, (32 // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    for j in range(rows * columns):
        plt.subplot(gs[j])
        plt.axis("off")
        plt.imshow(image_batch.at(j))


pipe = pipeline_fn()
pipe.build()
ims1, ims2 = pipe.run()
show_images(ims1)
show_images(ims2)
Variety of Python Operators#
In DALI, python_function comes in different flavors. The basic idea remains the same, but the data format on which the implementation operates differs in the following ways:
python_function - works on arrays.
torch_python_function - works on PyTorch tensors.
dl_tensor_python_function - works on DLPack tensors.
The most universal operator is dl_tensor_python_function. DLPack is an open standard for tensor storage, and many frameworks and libraries implement conversion methods to and from DLPack tensors. Internally, dl_tensor_python_function is used to implement all the other kinds of Python operators.
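To illustrate such conversions, the following standalone sketch shows a round trip between a PyTorch tensor and a DLPack capsule (it is independent of the pipeline defined below):

import torch
import torch.utils.dlpack as torch_dlpack

t = torch.arange(6, dtype=torch.float32).reshape(2, 3)
capsule = torch_dlpack.to_dlpack(t)      # PyTorch tensor -> DLPack capsule
t2 = torch_dlpack.from_dlpack(capsule)   # DLPack capsule -> PyTorch tensor (zero-copy)
assert torch.equal(t, t2)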
TorchPythonFunction and DLTensorPythonFunction#
This example shows how to use PyTorch functions in a DALI pipeline. The ideal way to run them is the torch_python_function operator, but we will also use dl_tensor_python_function to show how you can work with DLPack tensors.
We use the torchvision RandomPerspective transform in the perspective_fn function and wrap it in torch_python_function.
The dlpack_manipulation function shows you how to handle DLPack data:
1. The input batch is converted to a list of PyTorch tensors.
2. The converted input is processed.
3. The output is converted back to DLPack tensors.
Every Python operator has a batch_processing parameter. It determines whether the implementation function receives the whole batch as a list of tensors or is called once per sample. For historical reasons, this parameter is set to True by default for dl_tensor_python_function. We can look at dlpack_manipulation to see how to work with this kind of input.
[4]:
import nvidia.dali.plugin.pytorch as dalitorch
import torch
import torch.utils.dlpack as torch_dlpack
import torchvision.transforms as transforms

transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.RandomPerspective(p=1.0),
        transforms.ToTensor(),
    ]
)


def perspective_fn(t):
    # ToTensor returns a CHW tensor; transpose it back to the HWC layout
    return transform(t).transpose(2, 0).transpose(0, 1)


def dlpack_manipulation(dlpacks):
    # Convert the whole batch (a list of DLPack tensors) to PyTorch tensors
    tensors = [torch_dlpack.from_dlpack(dlpack) for dlpack in dlpacks]
    # Normalize to [0, 1] and take the square root of every sample
    output = [(tensor.to(torch.float32) / 255.0).sqrt() for tensor in tensors]
    # Reverse the order of samples in the batch
    output.reverse()
    # Convert the results back to DLPack tensors
    return [torch_dlpack.to_dlpack(tensor) for tensor in output]


@pipeline_def(batch_size=batch_size, num_threads=4, device_id=0, seed=99)
def torch_pipeline_fn():
    input, _ = fn.readers.file(file_root=image_dir, random_shuffle=True)
    im = fn.decoders.image(input, device="cpu", output_type=types.RGB)
    res = fn.resize(im, resize_x=300, resize_y=300)
    norm = fn.crop_mirror_normalize(res, std=255.0, mean=0.0)
    perspective = dalitorch.fn.torch_python_function(
        norm, function=perspective_fn
    )
    sqrt_color = fn.dl_tensor_python_function(res, function=dlpack_manipulation)
    return perspective, sqrt_color
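As mentioned earlier, batch_processing can also be set to False so that the function is called once per sample. A hedged sketch of such a variant (the dlpack_per_sample name and the commented-out usage are illustrative, not part of the original example):

def dlpack_per_sample(dlpack):
    # With batch_processing=False the argument is a single DLPack tensor,
    # not a list, so no loop over the batch is needed.
    tensor = torch_dlpack.from_dlpack(dlpack)
    return torch_dlpack.to_dlpack((tensor.to(torch.float32) / 255.0).sqrt())


# Hypothetical usage inside the pipeline definition:
# sqrt_color = fn.dl_tensor_python_function(
#     res, function=dlpack_per_sample, batch_processing=False
# )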
[5]:
torch_pipe = torch_pipeline_fn()
torch_pipe.build()
x, y = torch_pipe.run()
show_images(x)
show_images(y)