Operator Objects (Legacy)#
Overview#
Before the fn API and the pipeline_def decorator were introduced, DALI used the Operator Objects API. It more closely resembles the internals of DALI but is harder to use. Because you may still find code samples written in this older API, this example can serve as a quick reference. It goes through the same steps as the Getting Started page, but with the legacy Operator Objects API.
Defining the Pipeline#
Let us start by defining a very simple pipeline for a classification task: determining whether a picture contains a dog or a cat.
We prepared a directory structure containing pictures of dogs and cats:
[1]:
import os.path
import fnmatch

for root, dir, files in os.walk("data/images"):
    depth = root.count("/")
    ret = ""
    if depth > 0:
        ret += " " * (depth - 1) + "|-"
    print(ret + root)
    for items in fnmatch.filter(files, "*"):
        print(" " * len(ret) + "|-" + items)
|-data/images
|-file_list.txt
|-data/images/dog
|-dog_4.jpg
|-dog_5.jpg
|-dog_9.jpg
|-dog_6.jpg
|-dog_3.jpg
|-dog_7.jpg
|-dog_10.jpg
|-dog_2.jpg
|-dog_8.jpg
|-dog_1.jpg
|-dog_11.jpg
|-data/images/kitten
|-cat_10.jpg
|-cat_5.jpg
|-cat_9.jpg
|-cat_8.jpg
|-cat_1.jpg
|-cat_7.jpg
|-cat_6.jpg
|-cat_3.jpg
|-cat_2.jpg
|-cat_4.jpg
Our simple pipeline will read images from this directory, decode them and return (image, label) pairs.
[2]:
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types

image_dir = "data/images"
max_batch_size = 8


class SimplePipeline(Pipeline):
    def __init__(self, max_batch_size, num_threads, device_id):
        super(SimplePipeline, self).__init__(
            max_batch_size, num_threads, device_id, seed=1234
        )
        self.input = ops.readers.File(file_root=image_dir)
        self.decode = ops.decoders.Image(device="cpu", output_type=types.RGB)

    def define_graph(self):
        jpegs, labels = self.input()
        images = self.decode(jpegs)
        return (images, labels)
The SimplePipeline class is a subclass of dali.pipeline.Pipeline, which provides most of the methods to create and launch a pipeline. The only 2 methods that we need to implement are the constructor and the define_graph function.
In the constructor we first call our superclass constructor, in order to set global parameters of the pipeline:
batch size,
number of threads used to perform computation on the CPU,
which GPU device to use (SimplePipeline does not yet use GPU for compute though),
seed for random number generation.
In the constructor we also define member variables of our SimplePipeline class as operations defined in the dali.ops module:
readers.File - traverses the directory and returns pairs of (encoded image, label)
decoders.Image - takes an encoded image input and outputs a decoded RGB image
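As a rough illustration of what a directory-traversing reader does, the sketch below walks a root directory and pairs each file with an integer label derived from its subdirectory. It is plain Python, not DALI code: the helper name list_labeled_files is hypothetical, and assigning labels by the alphabetical order of subdirectory names is an assumption made for this sketch (it is consistent with the outputs shown in this example, where the dog images carry label 0).

```python
import os
import tempfile


def list_labeled_files(root):
    """Return (path, label) pairs, one label per immediate subdirectory.

    Assumption for this sketch: labels follow the alphabetical order
    of the subdirectory names.
    """
    class_dirs = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    pairs = []
    for label, class_dir in enumerate(class_dirs):
        class_path = os.path.join(root, class_dir)
        for name in sorted(os.listdir(class_path)):
            pairs.append((os.path.join(class_path, name), label))
    return pairs


# Build a tiny throwaway tree that mimics data/images and list it.
with tempfile.TemporaryDirectory() as demo_root:
    for class_dir, name in [("dog", "dog_1.jpg"), ("kitten", "cat_1.jpg")]:
        os.makedirs(os.path.join(demo_root, class_dir), exist_ok=True)
        open(os.path.join(demo_root, class_dir, name), "wb").close()
    demo_pairs = [
        (os.path.relpath(path, demo_root), label)
        for path, label in list_labeled_files(demo_root)
    ]
    print(demo_pairs)
```

The real reader additionally loads the encoded file contents; this sketch only shows how (file, label) pairs can be derived from a directory layout.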
In the define_graph function we define the actual flow of computation:
jpegs, labels = self.input()
uses our input operation to create jpegs (encoded images) and labels.
images = self.decode(jpegs)
Next, we use the decode operation to create images (decoded RGB images).
return (images, labels)
Finally, we specify which of the intermediate variables should be returned as outputs of the pipeline.
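The split between constructing operators and calling them can be pictured with a toy model. The sketch below is not DALI's implementation; ToyOperator, Node and describe are hypothetical names invented for illustration. It only shows the idea that an operator object is configured once, and that calling it produces symbolic handles describing the graph rather than actual data.

```python
class Node:
    """Symbolic handle to one operator output; it holds no image data."""

    def __init__(self, op_name, inputs, index):
        self.op_name = op_name
        self.inputs = inputs
        self.index = index


class ToyOperator:
    """Configured once (like in the constructor), called later (like in define_graph)."""

    def __init__(self, name, num_outputs=1, **params):
        self.name = name
        self.num_outputs = num_outputs
        self.params = params

    def __call__(self, *inputs):
        nodes = tuple(Node(self.name, inputs, i) for i in range(self.num_outputs))
        return nodes[0] if self.num_outputs == 1 else nodes


def describe(node):
    """Render the chain of operators that produces a node."""
    args = ", ".join(describe(i) for i in node.inputs)
    return f"{node.op_name}[{node.index}]({args})"


# "Constructor" step: configure the operators.
reader = ToyOperator("readers.File", num_outputs=2, file_root="data/images")
decoder = ToyOperator("decoders.Image", device="cpu")

# "define_graph" step: wire the operators together.
jpegs, labels = reader()
images = decoder(jpegs)

print(describe(images))  # decoders.Image[0](readers.File[0]())
print(describe(labels))  # readers.File[1]()
```

Nothing is read or decoded while the graph is being described; in this toy model, as in the pipeline above, work only happens once the pipeline is built and run.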
Building the Pipeline#
In order to use our SimplePipeline, we need to build it. This is achieved by calling the build function.
[3]:
pipe = SimplePipeline(max_batch_size, 1, 0)
pipe.build()
Running the Pipeline#
After the pipeline is built, we can run it to get a batch of results.
[4]:
pipe_out = pipe.run()
print(pipe_out)
(TensorListCPU(
[[[[255 255 255]
[255 255 255]
...
[ 86 46 55]
[ 86 46 55]]
[[255 255 255]
[255 255 255]
...
[ 86 46 55]
[ 86 46 55]]
...
[[158 145 154]
[158 147 155]
...
[ 93 38 41]
[ 93 38 41]]
[[157 145 155]
[158 146 156]
...
[ 93 38 41]
[ 93 38 41]]]
[[[ 69 77 80]
[ 69 77 80]
...
[ 97 105 108]
[ 97 105 108]]
[[ 69 77 80]
[ 70 78 81]
...
[ 97 105 108]
[ 97 105 108]]
...
[[199 203 206]
[199 203 206]
...
[206 210 213]
[206 210 213]]
[[199 203 206]
[199 203 206]
...
[206 210 213]
[206 210 213]]]
...
[[[ 26 28 25]
[ 26 28 25]
...
[ 34 39 33]
[ 34 39 33]]
[[ 26 28 25]
[ 26 28 25]
...
[ 34 39 33]
[ 34 39 33]]
...
[[ 35 46 30]
[ 36 47 31]
...
[114 99 106]
[127 114 121]]
[[ 35 46 30]
[ 35 46 30]
...
[107 92 99]
[112 97 102]]]
[[[182 185 132]
[180 183 128]
...
[ 98 103 9]
[ 97 102 8]]
[[180 183 130]
[179 182 127]
...
[ 93 98 4]
[ 91 96 2]]
...
[[ 69 111 71]
[ 68 111 66]
...
[147 159 121]
[148 163 124]]
[[ 64 109 68]
[ 64 110 64]
...
[113 123 88]
[104 116 80]]]],
dtype=DALIDataType.UINT8,
layout="HWC",
num_samples=8,
shape=[(427, 640, 3),
(427, 640, 3),
(425, 640, 3),
(480, 640, 3),
(485, 640, 3),
(427, 640, 3),
(409, 640, 3),
(427, 640, 3)]), TensorListCPU(
[[0]
[0]
[0]
[0]
[0]
[0]
[0]
[0]],
dtype=DALIDataType.INT32,
num_samples=8,
shape=[(1,), (1,), (1,), (1,), (1,), (1,), (1,), (1,)]))
The output of the pipeline, which we saved to the pipe_out variable, is a tuple of 2 elements (as expected - we specified 2 outputs in the define_graph function of the SimplePipeline class). Both of these elements are TensorListCPU objects - each containing a list of CPU tensors.
In order to show the results (just for debugging purposes - during the actual training we would not do that step, as it would make our batch of images do a round trip from GPU to CPU and back) we can send our data from DALI’s Tensor to a NumPy array. Not every TensorList can be accessed that way though - a TensorList is more general than a NumPy array and can hold tensors with different shapes. In order to check whether we can send it to NumPy directly, we can call the is_dense_tensor function of TensorList.
[5]:
images, labels = pipe_out
print("Images is_dense_tensor: " + str(images.is_dense_tensor()))
print("Labels is_dense_tensor: " + str(labels.is_dense_tensor()))
Images is_dense_tensor: False
Labels is_dense_tensor: True
As it turns out, the TensorList containing labels can be represented by a tensor, while the TensorList containing images cannot.
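The reason becomes clear from the per-sample shapes printed earlier. The sketch below is a plain-Python illustration of the idea, not DALI's implementation; the helper name is_dense is hypothetical. A batch can be viewed as one dense array only when every sample has the same shape.

```python
def is_dense(shapes):
    """True when every sample shape in the batch is identical."""
    return len(set(shapes)) <= 1


# Shapes copied from the pipeline output printed above.
image_shapes = [
    (427, 640, 3), (427, 640, 3), (425, 640, 3), (480, 640, 3),
    (485, 640, 3), (427, 640, 3), (409, 640, 3), (427, 640, 3),
]
label_shapes = [(1,)] * 8

print("Images dense:", is_dense(image_shapes))  # Images dense: False
print("Labels dense:", is_dense(label_shapes))  # Labels dense: True
```

The images differ in height, so they cannot be stacked into a single array without resizing or padding, whereas the eight labels all have shape (1,).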
Let us look at the shape and contents of the returned labels.
[6]:
print(labels)
TensorListCPU(
[[0]
[0]
[0]
[0]
[0]
[0]
[0]
[0]],
dtype=DALIDataType.INT32,
num_samples=8,
shape=[(1,), (1,), (1,), (1,), (1,), (1,), (1,), (1,)])
In order to see the images, we will need to loop over all tensors contained in the TensorList, accessed with its at method.
[7]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt

%matplotlib inline


def show_images(image_batch):
    columns = 4
    rows = (max_batch_size + 1) // (columns)
    fig = plt.figure(figsize=(24, (24 // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    for j in range(rows * columns):
        plt.subplot(gs[j])
        plt.axis("off")
        plt.imshow(image_batch.at(j))
[8]:
show_images(images)
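The row count in show_images works for a batch of 8 images in 4 columns. For other batch sizes, a possible way to size the grid is ceiling division, visiting only the cells that hold an image. The sketch below is a standalone illustration with hypothetical helper names (grid_shape, cell_positions); it does not plot anything.

```python
def grid_shape(num_images, columns=4):
    """Smallest (rows, columns) grid that fits num_images cells."""
    if num_images < 0 or columns < 1:
        raise ValueError("num_images must be >= 0 and columns >= 1")
    rows = -(-num_images // columns)  # ceiling division
    return rows, columns


def cell_positions(num_images, columns=4):
    """(row, column) of each image, filling the grid row by row."""
    return [divmod(j, columns) for j in range(num_images)]


print(grid_shape(8))          # (2, 4)
print(grid_shape(10))         # (3, 4)
print(cell_positions(5)[-1])  # (1, 0)
```

Iterating over cell_positions instead of rows * columns avoids asking the batch for an image index that does not exist when the last row is only partly filled.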
Adding Augmentations#
Random Shuffle#
As we can see from the example above, the first batch of images returned by our pipeline contains only dogs. That is because we did not shuffle our dataset, so readers.File returns images in the order in which it encountered them while traversing the directory structure.
Let us make a new pipeline that will change that.
[9]:
class ShuffledSimplePipeline(Pipeline):
    def __init__(self, max_batch_size, num_threads, device_id):
        super(ShuffledSimplePipeline, self).__init__(
            max_batch_size, num_threads, device_id, seed=1234
        )
        self.input = ops.readers.File(
            file_root=image_dir, random_shuffle=True, initial_fill=21
        )
        self.decode = ops.decoders.Image(device="cpu", output_type=types.RGB)

    def define_graph(self):
        jpegs, labels = self.input()
        images = self.decode(jpegs)
        return (images, labels)
We made 2 changes to SimplePipeline to obtain ShuffledSimplePipeline - we added 2 parameters to the readers.File operation:
random_shuffle enables shuffling of images in the reader. Shuffling is performed by using a buffer of images read from disk. When the reader is asked to provide the next image, it randomly selects an image from the buffer, outputs it and immediately replaces that spot in the buffer with a freshly read image.
initial_fill sets the capacity of the buffer. The default value of this parameter (1000), well suited for datasets containing thousands of examples, is too big for our very small dataset, which contains only 21 images. This could result in frequent duplicates in the returned batch. That is why in this example we set it to the size of our dataset.
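The buffer mechanism described above can be sketched in a few lines of plain Python. This is an illustration of the idea only, not DALI's reader: the function name buffered_shuffle is hypothetical, and the sketch makes a single pass over a finite list and drains the buffer at the end, which is an assumption made to keep the example short.

```python
import random


def buffered_shuffle(items, buffer_size, seed=1234):
    """Yield items in shuffled order using a fixed-capacity buffer."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    source = iter(items)

    # Initial fill: read up to buffer_size items.
    buffer = []
    for item in source:
        buffer.append(item)
        if len(buffer) == buffer_size:
            break

    # Steady state: output a random slot, refill it with a fresh item.
    for item in source:
        slot = rng.randrange(len(buffer))
        yield buffer[slot]
        buffer[slot] = item

    # Source exhausted: drain what is left in random order.
    while buffer:
        slot = rng.randrange(len(buffer))
        buffer[slot], buffer[-1] = buffer[-1], buffer[slot]
        yield buffer.pop()


file_names = ["img_%02d.jpg" % i for i in range(21)]
shuffled = list(buffered_shuffle(file_names, buffer_size=21))
print(shuffled[:8])
```

A larger buffer lets items travel further from their original position: with buffer_size=1 the order is unchanged, while a buffer as large as the dataset gives a full shuffle of one pass.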
Let us test the result of this modification.
[10]:
pipe = ShuffledSimplePipeline(max_batch_size, 1, 0)
pipe.build()
[11]:
pipe_out = pipe.run()
images, labels = pipe_out
show_images(images)
Now the images returned by the pipeline are shuffled properly.
Augmentations#
DALI can not only read images from disk and batch them into tensors, it can also perform various augmentations on those images to improve Deep Learning training results.
One example of such augmentations is rotation. Let us make a new pipeline, which rotates the images before outputting them.
[12]:
class RotatedSimplePipeline(Pipeline):
    def __init__(self, max_batch_size, num_threads, device_id):
        super(RotatedSimplePipeline, self).__init__(
            max_batch_size, num_threads, device_id, seed=1234
        )
        self.input = ops.readers.File(
            file_root=image_dir, random_shuffle=True, initial_fill=21
        )
        self.decode = ops.decoders.Image(device="cpu", output_type=types.RGB)
        self.rotate = ops.Rotate(angle=10.0, fill_value=0)

    def define_graph(self):
        jpegs, labels = self.input()
        images = self.decode(jpegs)
        rotated_images = self.rotate(images)
        return (rotated_images, labels)
To do that, we added a new operation to our pipeline: dali.ops.Rotate. To obtain information on required and optional arguments of any operation provided by DALI, we can use the help function.
[13]:
help(ops.Rotate)
Help on class Rotate in module nvidia.dali.ops:
class Rotate(builtins.object)
| Rotate(**kwargs)
|
| Rotates the images by the specified angle.
|
| This operator supports volumetric data.
|
| Supported backends
| * 'cpu'
| * 'gpu'
|
|
| Keyword args
| ------------
| angle : float or TensorList of float
| Angle, in degrees, by which the image is rotated.
|
| For two-dimensional data, the rotation is counter-clockwise, assuming the top-left corner is
| at ``(0,0)``. For three-dimensional data, the ``angle`` is a positive rotation around the provided
| axis.
| axis : float or list of float or TensorList of float, optional, default = `[]`
| Applies **only** to three-dimension and is the axis
| around which to rotate the image.
|
| The vector does not need to be normalized, but it must have a non-zero length.
| Reversing the vector is equivalent to changing the sign of ``angle``.
|
| bytes_per_sample_hint : int or list of int, optional, default = `[0]`
| Output size hint, in bytes per sample.
|
| If specified, the operator's outputs residing in GPU or page-locked host memory will be preallocated
| to accommodate a batch of samples of this size.
| dtype : nvidia.dali.types.DALIDataType, optional, default = `DALIDataType.NO_TYPE`
| Output data type.
|
| If not set, the input type is used.
| fill_value : nvidia.dali.types.DALIDataType, optional, default = `DALIDataType.FLOAT`
| Value used to fill areas that are outside the source image.
|
| If a value is not specified, the source coordinates are clamped and the border pixel is
| repeated.
| interp_type : nvidia.dali.types.DALIInterpType, optional, default = `DALIInterpType.INTERP_LINEAR`
| Type of interpolation used.
| keep_size : bool, optional, default = `False`
| If True, original canvas size is kept.
|
| If set to False (default), and the size is not set, the canvas size is adjusted to
| accommodate the rotated image with the least padding possible.
|
| preserve : bool, optional, default = `False`
| Prevents the operator from being removed from the
| graph even if its outputs are not used.
| seed : int, optional, default = `-1`
| Random seed.
|
| If not provided, it will be populated based on the global seed of the pipeline.
| size : float or list of float or TensorList of float, optional, default = `[]`
| Output size, in pixels/points.
|
| Non-integer sizes are rounded to nearest integer. The channel dimension should
| be excluded (for example, for RGB images, specify ``(480,640)``, not ``(480,640,3)``.
|
| output_dtype : nvidia.dali.types.DALIDataType
| .. warning::
|
| The argument ``output_dtype`` is a deprecated alias for ``dtype``. Use ``dtype`` instead.
|
| Methods defined here:
|
| __call__(self, *inputs, **kwargs)
| __call__(data, **kwargs)
|
| Operator call to be used in graph definition.
|
| Args
| ----
| data : TensorList ('HWC', 'DHWC')
| Input to the operator.
|
| __init__(self, **kwargs)
|
| ----------------------------------------------------------------------
| Readonly properties defined here:
|
| device
|
| preserve
|
| schema
|
| spec
|
| ----------------------------------------------------------------------
| Data descriptors defined here:
|
| __dict__
| dictionary for instance variables (if defined)
|
| __weakref__
| list of weak references to the object (if defined)
|
| ----------------------------------------------------------------------
| Data and other attributes defined here:
|
| schema_name = 'Rotate'
As we can see, Rotate can take multiple arguments, but only one of them, angle, is required - it tells the operator how much it should rotate images.
Let us test the newly created pipeline:
[14]:
pipe = RotatedSimplePipeline(max_batch_size, 1, 0)
pipe.build()
[15]:
pipe_out = pipe.run()
images, labels = pipe_out
show_images(images)
Tensors as Arguments and Random Number Generation#
Rotating every image by 10 degrees is not that interesting. To make a meaningful augmentation, we would like an operator that rotates our images by a random angle in a given range.
The help output for the Rotate operation tells us that the angle parameter can accept float or float tensor types of values. The second option, float tensor, enables us to feed the operator with different rotation angles for every image, via a tensor produced by another operation.
Random number generators are examples of operations that one can use with DALI. Let us use dali.ops.random.Uniform to make a pipeline that rotates images by a random angle.
[16]:
class RandomRotatedSimplePipeline(Pipeline):
    def __init__(self, max_batch_size, num_threads, device_id):
        super(RandomRotatedSimplePipeline, self).__init__(
            max_batch_size, num_threads, device_id, seed=1234
        )
        self.input = ops.readers.File(
            file_root=image_dir, random_shuffle=True, initial_fill=21
        )
        self.decode = ops.decoders.Image(device="cpu", output_type=types.RGB)
        self.rotate = ops.Rotate(fill_value=0)
        self.rng = ops.random.Uniform(range=(-10.0, 10.0))

    def define_graph(self):
        jpegs, labels = self.input()
        images = self.decode(jpegs)
        angle = self.rng()
        rotated_images = self.rotate(images, angle=angle)
        return (rotated_images, labels)
This time, instead of providing a fixed value for the angle argument in the constructor, we pass the output of the dali.ops.random.Uniform operator as angle when calling the rotate operation in define_graph.
Let us check the result:
[17]:
pipe = RandomRotatedSimplePipeline(max_batch_size, 1, 0)
pipe.build()
[18]:
pipe_out = pipe.run()
images, labels = pipe_out
show_images(images)
This time, the rotation angle is randomly selected from a value range.
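To get a feel for what a per-sample angle implies, the sketch below draws one angle per image and estimates the canvas needed to hold each rotated image when the original size is not kept. It is plain Python rather than DALI code: rotated_canvas_size is a hypothetical helper based on the bounding box of a rotated rectangle, and rounding to the nearest integer is an assumption, so DALI's actual output sizes may differ by a few pixels.

```python
import math
import random


def rotated_canvas_size(height, width, angle_deg):
    """Bounding box (height, width) of a height x width image rotated by angle_deg."""
    theta = math.radians(angle_deg)
    c, s = abs(math.cos(theta)), abs(math.sin(theta))
    new_height = round(height * c + width * s)
    new_width = round(height * s + width * c)
    return new_height, new_width


rng = random.Random(1234)
input_shapes = [(427, 640), (480, 640), (409, 640)]

# One independent angle per sample, drawn from the range [-10, 10] degrees.
angles = [rng.uniform(-10.0, 10.0) for _ in input_shapes]
output_shapes = [
    rotated_canvas_size(h, w, a) for (h, w), a in zip(input_shapes, angles)
]

for shape, angle, out in zip(input_shapes, angles, output_shapes):
    print(f"{shape} rotated by {angle:+.2f} deg -> {out}")
```

Because each sample gets its own angle, samples that started with the same shape can end up with different output shapes, so a rotated batch is generally not dense.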
Adding GPU Acceleration#
DALI offers access to GPU-accelerated operators that can increase the speed of the input and augmentation pipeline and let it scale to multi-GPU systems.
Copying Tensors to GPU#
Let us modify our previous example of the RandomRotatedSimplePipeline to use the GPU for the rotation.
[19]:
class RandomRotatedGPUPipeline(Pipeline):
    def __init__(self, max_batch_size, num_threads, device_id):
        super(RandomRotatedGPUPipeline, self).__init__(
            max_batch_size, num_threads, device_id, seed=1234
        )
        self.input = ops.readers.File(
            file_root=image_dir, random_shuffle=True, initial_fill=21
        )
        self.decode = ops.decoders.Image(device="cpu", output_type=types.RGB)
        self.rotate = ops.Rotate(device="gpu", fill_value=0)
        self.rng = ops.random.Uniform(range=(-10.0, 10.0))

    def define_graph(self):
        jpegs, labels = self.input()
        images = self.decode(jpegs)
        angle = self.rng()
        rotated_images = self.rotate(images.gpu(), angle=angle)
        return (rotated_images, labels)
In order to tell DALI that we want to use the GPU, we needed to make 2 changes to the pipeline:
we added a device = "gpu" parameter to the Rotate operation
we changed input to the rotate from images, which is a tensor on the CPU, to images.gpu() which copies it to the GPU
[20]:
pipe = RandomRotatedGPUPipeline(max_batch_size, 1, 0)
pipe.build()
[21]:
pipe_out = pipe.run()
print(pipe_out)
(TensorListGPU(
[[[[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]]
[[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]]
...
[[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]]
[[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]]]
[[[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]]
[[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]]
...
[[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]]
[[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]]]
...
[[[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]]
[[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]]
...
[[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]]
[[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]]]
[[[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]]
[[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]]
...
[[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]]
[[0 0 0]
[0 0 0]
...
[0 0 0]
[0 0 0]]]],
dtype=DALIDataType.UINT8,
layout="HWC",
num_samples=8,
shape=[(583, 710, 3),
(477, 682, 3),
(482, 642, 3),
(761, 736, 3),
(467, 666, 3),
(449, 654, 3),
(510, 662, 3),
(463, 664, 3)]), TensorListCPU(
[[0]
[0]
[1]
[1]
[0]
[1]
[0]
[0]],
dtype=DALIDataType.INT32,
num_samples=8,
shape=[(1,), (1,), (1,), (1,), (1,), (1,), (1,), (1,)]))
pipe_out still contains 2 TensorLists, but this time the first output, the result of the Rotate operation, is on the GPU. We cannot access the contents of a TensorListGPU directly from the CPU, so in order to visualize the result we need to copy it to the CPU by using the as_cpu method.
[22]:
images, labels = pipe_out
show_images(images.as_cpu())
Important Notice#
The default executor in DALI does not allow CPU operators to follow GPU ones. To enable more flexible pipelines, pass exec_dynamic=True to the pipeline constructor or the @pipeline_def decorator.
Hybrid Decoding#
Sometimes, especially for higher resolution images, decoding images stored in JPEG format may become a bottleneck. To address this problem, nvJPEG and nvJPEG2000 libraries were developed. They split the decoding process between CPU and GPU, significantly reducing the decoding time.
Specifying the “mixed” device parameter in decoders.Image enables nvJPEG and nvJPEG2000 support. Other file formats are still decoded on the CPU.
[23]:
class HybridPipeline(Pipeline):
    def __init__(self, max_batch_size, num_threads, device_id):
        super(HybridPipeline, self).__init__(
            max_batch_size, num_threads, device_id, seed=1234
        )
        self.input = ops.readers.File(
            file_root=image_dir, random_shuffle=True, initial_fill=21
        )
        self.decode = ops.decoders.Image(device="mixed", output_type=types.RGB)

    def define_graph(self):
        jpegs, labels = self.input()
        images = self.decode(jpegs)
        # images are on the GPU
        return (images, labels)
decoders.Image with device=mixed uses a hybrid approach of computation that employs both the CPU and the GPU. This means that it accepts CPU inputs, but returns GPU outputs. That is why the images objects returned from the pipeline are of type TensorListGPU.
[24]:
pipe = HybridPipeline(max_batch_size, 1, 0)
pipe.build()
[25]:
pipe_out = pipe.run()
images, labels = pipe_out
show_images(images.as_cpu())
Let us compare the speed of decoders.Image for the ‘cpu’ and ‘mixed’ backends by measuring the speed of ShuffledSimplePipeline and HybridPipeline with 4 CPU threads.
[26]:
from timeit import default_timer as timer

test_batch_size = 64


def speedtest(pipeclass, batch, n_threads):
    pipe = pipeclass(batch, n_threads, 0)
    pipe.build()
    # warmup
    for i in range(5):
        pipe.run()
    # test
    n_test = 20
    t_start = timer()
    for i in range(n_test):
        pipe.run()
    t = timer() - t_start
    print("Speed: {} imgs/s".format((n_test * batch) / t))
[27]:
speedtest(ShuffledSimplePipeline, test_batch_size, 4)
Speed: 2710.326124438788 imgs/s
[28]:
speedtest(HybridPipeline, test_batch_size, 4)
Speed: 5860.449939768643 imgs/s
As we can see, using GPU-accelerated decoding resulted in a significant speedup: in this run, about 5860 imgs/s versus about 2710 imgs/s, roughly 2.2 times faster.
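The measurement pattern used by speedtest (a few warmup iterations that are not timed, followed by a timed loop) can be reused for any callable. The sketch below is a generic variant that returns the throughput instead of printing it; measure_throughput and fake_run are hypothetical names, and the stand-in workload only exists to make the example runnable without DALI.

```python
from timeit import default_timer as timer


def measure_throughput(run_batch, batch_size, n_warmup=5, n_test=20):
    """Return samples per second for run_batch, excluding warmup iterations."""
    for _ in range(n_warmup):
        run_batch()
    t_start = timer()
    for _ in range(n_test):
        run_batch()
    elapsed = timer() - t_start
    return (n_test * batch_size) / elapsed


calls = []


def fake_run():
    # Stand-in for pipe.run(); does a small amount of CPU work.
    calls.append(sum(range(1000)))


speed = measure_throughput(fake_run, batch_size=64)
print("Speed: {:.1f} imgs/s".format(speed))
```

With a built pipeline, passing pipe.run as run_batch would mirror the speedtest function above; returning the number makes it easy to compute ratios between two configurations.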