Getting started with JAX and DALI#
This tutorial shows how to use DALI with JAX. You will learn how to use DALI as a data source for JAX workflows, how to use the GPU to accelerate data preprocessing, and how to scale up your training to multiple GPUs.
Prerequisites#
This tutorial assumes that you have already installed DALI and JAX with GPU support. If you haven’t done so, please follow the DALI installation guide and JAX installation guide.
Test data used for this example can be found on the DALI GitHub page.
[1]:
image_dir = "../../data/images"
Quick start#
The code below is a simple, ready-to-copy example of how to use DALI with JAX. The following sections go through it step by step and explain how it works.
[2]:
import nvidia.dali.fn as fn
from nvidia.dali.plugin.jax import data_iterator
@data_iterator(output_map=["images", "labels"], reader_name="image_reader")
def iterator_fn():
    jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
    images = fn.decoders.image(jpegs, device="mixed")
    images = fn.resize(images, resize_x=256, resize_y=256)
    return images, labels


iterator = iterator_fn(batch_size=8)
batch = next(iterator)  # batch of data ready to be used by JAX
Running DALI with JAX#
In DALI, the main concept that you need to get accustomed to is the Pipeline. It is a graph of operations that are executed asynchronously with regard to the main Python thread. The Pipeline is defined in Python, but the execution is done in C++, which makes it fast and efficient. Computations defined as parts of a Pipeline can be executed on the CPU or the GPU.
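To make the Pipeline lifecycle concrete, here is a minimal sketch of defining, building, and running a pipeline directly with the standard pipeline_def decorator; the pipeline name and parameters are illustrative, and the rest of this tutorial uses a JAX-specific helper instead:

from nvidia.dali.pipeline import pipeline_def
import nvidia.dali.fn as fn


@pipeline_def(batch_size=4, num_threads=2, device_id=0)
def example_pipeline():
    # declare the graph: read and decode images from image_dir
    jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
    return fn.decoders.image(jpegs), labels


pipe = example_pipeline()
pipe.build()                 # builds the C++ execution graph
images, labels = pipe.run()  # runs one iteration; returns DALI TensorLists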
We start by defining a pipeline function. Here we declare what the graph of operations will look like. In this starting example we use operations from the nvidia.dali.fn module. This module contains all the basic operations that are available in DALI:
[3]:
import nvidia.dali.fn as fn


def simple_pipeline():
    jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
    images = fn.decoders.image(jpegs)
    images = fn.resize(images, resize_x=256, resize_y=256)
    return images, labels
This function defines the following preprocessing pipeline:
- read raw JPEGs and labels from the storage,
- decode the JPEGs to RGB format,
- resize the decoded images to 256x256.
Even though it is very simple, it can serve as a good approximation of the usual preprocessing pipeline used in many computer vision tasks. The only difference is that we do not apply any augmentations here. We will cover how to apply augmentations in one of the following sections.
Above we defined the graph of operations. Now we need to run it in a JAX context. DALI provides a simple-to-use decorator that transforms a pipeline definition function into a JAX data iterator.
[4]:
from nvidia.dali.plugin.jax import data_iterator
This decorator can be used to create a function that produces data iterators compatible with JAX. The decorator accepts arguments that control the final iterator. For now we will focus on:
- the pipeline function to transform,
- output_map - a list of names for the outputs that will be used in JAX,
- reader_name - the name of the reader operator that is used to read data from the storage.
[5]:
iterator_fn = data_iterator(
    simple_pipeline, output_map=["images", "labels"], reader_name="image_reader"
)
Note how the reader_name value is the same as the name value used for the fn.readers.file operator in the pipeline function. It tells the iterator which operator is used to read the data from the storage and should be queried for the number of samples in the dataset.
iterator_fn is a function that produces data iterators. It can be used to create multiple iterators that iterate over the same dataset, which is useful when you want to reuse some code for training and validation pipelines. Applying the decorator adds some additional arguments to the function. One of them is batch_size, which controls the number of samples per batch of data produced by the iterator:
[6]:
iterator = iterator_fn(batch_size=1)
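As an aside, because iterator_fn is reusable, the same definition could also back a second iterator, for example for validation. A hypothetical two-line sketch (the batch sizes here are arbitrary); the rest of this section keeps using the single iterator created above:

train_iterator = iterator_fn(batch_size=8)  # hypothetical training iterator
val_iterator = iterator_fn(batch_size=1)    # hypothetical validation iterator over the same dataset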
We can use the iterator to get data. The iterator returns a dictionary with the keys defined in the output_map argument. The values are JAX arrays.
[7]:
output = next(iterator)
print(output.keys())
print(type(output["images"]))
print(output["images"].shape)
print(type(output["labels"]))
print(output["labels"].shape)
dict_keys(['images', 'labels'])
<class 'jaxlib.xla_extension.ArrayImpl'>
(1, 256, 256, 3)
<class 'jaxlib.xla_extension.ArrayImpl'>
(1, 1)
To visualize the outputs we create a simple function that displays the images using matplotlib:
[8]:
import matplotlib.pyplot as plt


def show_image(image):
    plt.imshow(image)
We can now use it to inspect the content of the output:
[9]:
image = output["images"]
show_image(image[0])
print(f'Label = {output["labels"][0]}')
Label = [0]
We can continue to run the iterator and get the next batch:
[10]:
output = next(iterator)
image = output["images"]
show_image(image[0])
print(f'Label = {output["labels"][0]}')
Label = [0]
Note that the data_iterator decorator can be used in a declarative way as well. Since this is a more condensed way to express the same thing, we will use it in the following sections:
[11]:
@data_iterator(output_map=["images", "labels"], reader_name="image_reader")
def iterator_fn():
    jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
    images = fn.decoders.image(jpegs)
    images = fn.resize(images, resize_x=256, resize_y=256)
    return images, labels


iterator = iterator_fn(batch_size=1)
The iterator is compatible with the Python __iter__ protocol. It can be used in for loops or with the next function. It is also compatible with the len function, which returns the number of batches in the dataset.
[12]:
iterator = iterator_fn(batch_size=1)
print(f"Iterator size: {len(iterator)}")
for batch_id, batch in enumerate(iterator):
    print(batch_id)
Iterator size: 21
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
Batching#
So far we returned only one sample per iteration or next call. Note that the returned arrays had a leading dimension of size 1. This is because the iterator returns a batch of data; the size of the batch is controlled by the batch_size argument of the iterator_fn function.
We can use the same function to create an iterator that returns batches containing multiple samples.
[13]:
iterator = iterator_fn(batch_size=8)
[14]:
batch = next(iterator)
print(type(batch["images"]))
print(batch["images"].shape)
print(type(batch["labels"]))
print(batch["labels"].shape)
<class 'jaxlib.xla_extension.ArrayImpl'>
(8, 256, 256, 3)
<class 'jaxlib.xla_extension.ArrayImpl'>
(8, 1)
We can adjust the show_image function to plot the whole batch:
[15]:
from matplotlib import gridspec


def show_image(images):
    columns = 4
    rows = (images.shape[0] + 1) // columns
    plt.figure(figsize=(24, (24 // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    for j in range(rows * columns):
        plt.subplot(gs[j])
        plt.axis("off")
        plt.imshow(images[j])
[16]:
show_image(batch["images"])
GPU acceleration#
So far the preprocessing has used only the CPU, even though we passed device_id=0. Note the backing device of the output JAX arrays:
[17]:
print(f'Images backing device: {batch["images"].device()}')
Images backing device: TFRT_CPU_0
One of the main features that DALI offers is the ability to run the preprocessing on the GPU. For our simple example, let's run the image resizing on the GPU. The only change required is to move the decoded images to the GPU, which we can do with the gpu method. fn.resize will recognize that its input is on the GPU and will execute the operation on the GPU as well.
[18]:
@data_iterator(output_map=["images", "labels"], reader_name="image_reader")
def iterator_fn():
    jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
    images = fn.decoders.image(jpegs)
    images = images.gpu()
    images = fn.resize(images, resize_x=256, resize_y=256)
    return images, labels


iterator = iterator_fn(batch_size=8)
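A quick check, reusing the earlier backing-device inspection (output not reproduced here), confirms the placement:

batch = next(iterator)
print(f'Images backing device: {batch["images"].device()}')  # expected to report a CUDA device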
We see that images is backed by the GPU now.
As mentioned before, decorating the function with data_iterator adds some arguments to it. We discussed batch_size already. Another one is device_id. It is used to control which GPU will be used to execute the computations. We can use it to run the whole pipeline on another GPU like this:
[19]:
iterator = iterator_fn(batch_size=1, device_id=1)
batch = next(iterator)
print(f'Images backing device: {batch["images"].device()}')
Images backing device: cuda:1
Note how the backing device for the output changed to the GPU with device ID equal to 1.
Hardware accelerated decoding#
Another important feature of DALI is the ability to accelerate JPEG decoding on the GPU using the nvJPEG and nvJPEG2000 libraries. This is especially useful when you have a lot of high-resolution images that need to be decoded. nvJPEG and nvJPEG2000 are designed to take advantage of the GPU's hardware-accelerated decoder to remove this bottleneck. To learn more about them you can take a look at this developer page.
Not all aspects of the JPEG decoding process are well suited to GPU parallelization. The CPU is responsible for handling the sequential part of the decoding. To make use of this hybrid approach, you can set mixed as the device argument value in the fn.decoders.image operator.
After this change the decoded outputs are backed by the GPU, so it is no longer necessary to move them to the GPU manually. The updated code looks like this:
[20]:
@data_iterator(output_map=["images", "labels"], reader_name="image_reader")
def iterator_fn():
    jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
    images = fn.decoders.image(jpegs, device="mixed")
    images = fn.resize(images, resize_x=256, resize_y=256)
    return images, labels


iterator = iterator_fn(batch_size=8)
[21]:
batch = next(iterator)
print(f'Images backing device: {batch["images"].device()}')
print(f'Labels backing device: {batch["labels"].device()}')
Images backing device: cuda:0
Labels backing device: TFRT_CPU_0
Random shuffling#
One of the required steps in training is to shuffle the data, and DALI provides a way to do it efficiently. We can use the random_shuffle argument of the reader. The argument takes a boolean value: if it is set to True the data is shuffled randomly, and if it is set to False the data is returned in the order it was read from the storage. To ensure the reproducibility of the results we can also set the seed argument to a fixed value:
[22]:
@data_iterator(output_map=["images", "labels"], reader_name="image_reader")
def iterator_fn():
    jpegs, labels = fn.readers.file(
        file_root=image_dir, name="image_reader", random_shuffle=True
    )
    images = fn.decoders.image(jpegs, device="mixed")
    images = fn.resize(images, resize_x=256, resize_y=256)
    return images, labels


iterator = iterator_fn(batch_size=8, seed=0)

batch = next(iterator)
show_image(batch["images"])
Augmentations#
DALI provides a wide range of augmentations that can be used to improve the quality of the training data. To learn more about them you can take a look at the DALI documentation. In this section we cover how to use them in JAX workflows. For this simple example we want to apply a random rotation to the images, by a random angle in the range [-10, 10]. To do it we use fn.random.uniform to generate random angles and fn.rotate to perform the rotation.
[23]:
@data_iterator(output_map=["images", "labels"], reader_name="image_reader")
def iterator_fn():
    jpegs, labels = fn.readers.file(
        file_root=image_dir, name="image_reader", random_shuffle=True
    )
    images = fn.decoders.image(jpegs, device="mixed")
    angle = fn.random.uniform(range=(-10.0, 10.0))
    images = fn.rotate(images, angle=angle, fill_value=0)
    images = fn.resize(images, resize_x=256, resize_y=256)
    return images, labels


iterator = iterator_fn(batch_size=8, seed=0)

batch = next(iterator)
show_image(batch["images"])
Multiple GPUs#
One of the strengths of JAX is the ability to scale up training to multiple GPUs very easily. DALI provides a simple way to do this as well, compatible with the JAX scale-up mechanism: the data_iterator decorator accepts a sharding argument. We can pass the same value that is used to scale up other computations in JAX.
Let's assume that in a JAX workflow we want to parallelize the computations across multiple GPUs along the batch dimension. One of the ways to accomplish this is with NamedSharding. It is a simple way to express the sharding pattern: it maps dimension names to the list of devices that will be used to execute computations along those dimensions. In our case we want to shard along the dimension that we name "batch" to communicate what it represents, and we want to use all available GPUs to execute computations along this dimension. We can use jax.devices() to get the list of available devices. The code looks like this:
[24]:
import jax
from jax.sharding import Mesh, PartitionSpec, NamedSharding
mesh = Mesh(jax.devices(), axis_names=("batch"))
sharding = NamedSharding(mesh, PartitionSpec("batch"))
print(sharding)
NamedSharding(mesh=Mesh('batch': 2), spec=PartitionSpec('batch',))
We want the DALI iterator to return outputs compatible with this sharding pattern. We can do it by passing the sharding argument to the data_iterator decorator.
One modification needed in the decorated function is the introduction of the num_shards and shard_id arguments. They pass the sharding information to the reader, which then reads only the part of the dataset assigned to the current shard. To learn more about the DALI sharding mechanism, look into the DALI sharding docs.
[25]:
@data_iterator(output_map=["images", "labels"], reader_name="image_reader")
def iterator_fn(num_shards=1, shard_id=0):
    jpegs, labels = fn.readers.file(
        file_root=image_dir,
        name="image_reader",
        num_shards=num_shards,
        shard_id=shard_id,
    )
    images = fn.decoders.image(jpegs, device="mixed")
    angle = fn.random.uniform(range=(-10.0, 10.0))
    images = fn.rotate(images, angle=angle, fill_value=0)
    images = fn.resize(images, resize_x=256, resize_y=256)
    return images, labels.gpu()
The default values of num_shards and shard_id set above are the same as the default values for these arguments in fn.readers.file. In this configuration there is no sharding: the reader reads the whole dataset and there is only one shard with shard_id == 0. When we pass sharding to the decorator, these arguments are set automatically to the appropriate values.
Now we are ready to scale up iterator_fn to multiple GPUs. Note the addition of the sharding argument to the decorator:
[26]:
@data_iterator(
    output_map=["images", "labels"],
    reader_name="image_reader",
    sharding=sharding,
)
def iterator_fn(num_shards=1, shard_id=0):
    jpegs, labels = fn.readers.file(
        file_root=image_dir,
        name="image_reader",
        num_shards=num_shards,
        shard_id=shard_id,
    )
    images = fn.decoders.image(jpegs, device="mixed")
    angle = fn.random.uniform(range=(-10.0, 10.0))
    images = fn.rotate(images, angle=angle, fill_value=0)
    images = fn.resize(images, resize_x=256, resize_y=256)
    return images, labels.gpu()
In this setup we want to spread the preprocessing computations across multiple GPUs. We no longer need to pass the device_id; it is set automatically based on the sharding. At runtime there will be multiple instances of the pipeline running on different GPUs, each reading a different part of the dataset. The iterator will get outputs from all the pipelines and build a batch from them. The batch will be sharded along the "batch" dimension, and the sharding pattern will be compatible with the sharding pattern used in the rest of the JAX workflow.
[27]:
iterator = iterator_fn(batch_size=8)
[28]:
batch = next(iterator)
images = batch["images"]
print(f"Images shape: {images.shape}")
print(f"Images backing device: {images.devices()}")
print(f"Images sharding: {images.sharding}")
Images shape: (8, 256, 256, 3)
Images backing device: {cuda(id=0), cuda(id=1)}
Images sharding: NamedSharding(mesh=Mesh('batch': 2), spec=PartitionSpec('batch',))
We can use jax.debug.visualize_array_sharding to visualize how the data is distributed across the GPUs.
[29]:
jax.debug.visualize_array_sharding(images.ravel())
GPU 0 GPU 1
We can further look into the content of the batch to see how the data is distributed across the GPUs. We can see that the first half of the batch is on the first GPU and the second half is on the second GPU:
[30]:
print(
    f"Shard 0 device: {images.device_buffers[0].device()}, "
    f"shape: {images.device_buffers[0].shape}"
)
print(
    f"Shard 1 device: {images.device_buffers[1].device()}, "
    f"shape: {images.device_buffers[1].shape}"
)
Shard 0 device: cuda:0, shape: (4, 256, 256, 3)
Shard 1 device: cuda:1, shape: (4, 256, 256, 3)
Even though the data is distributed across multiple GPUs, we can still use the show_image function to visualize it. It will automatically collect the data from all the GPUs:
[31]:
show_image(batch["images"])
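To illustrate that the sharded batch plugs directly into the rest of a JAX workflow, here is a minimal sketch that feeds it into a jitted function; the normalize function is just a hypothetical placeholder, not part of the tutorial:

import jax
import jax.numpy as jnp


@jax.jit
def normalize(images):
    # placeholder per-pixel computation; each GPU processes its own shard
    return images.astype(jnp.float32) / 255.0


normalized = normalize(batch["images"])
print(normalized.sharding)  # expected to match the batch's NamedSharding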
Technical details#
This section dives deeper into the technical aspects of the functionalities used in this tutorial, providing a more comprehensive understanding of the tools and concepts employed.
num_threads performance considerations#
num_threads is another argument added to a function decorated with data_iterator. It is used to control the number of CPU threads that will be used to execute the iterator. It is important to set it to the right value to get the best performance. The optimal value depends on the use case, batch size, and hardware configuration. Let's take a look at how it affects the performance of the iterator.
First, we create a function that runs the iterator for a fixed number of iterations. We will use it to benchmark the performance of the iterator for different values of num_threads:
[32]:
def run_iterator(iterator, epoch_num=10):
    for epoch in range(epoch_num):
        for batch in iterator:
            pass
We instantiate the iterator with num_threads=1 and run it:
[33]:
iterator = iterator_fn(batch_size=64, num_threads=1)
[34]:
%%timeit
run_iterator(iterator)
188 ms ± 11.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Now, we can run the benchmark for a different value of num_threads:
[35]:
iterator = iterator_fn(batch_size=64, num_threads=8)
[36]:
%%timeit
run_iterator(iterator)
89.3 ms ± 640 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
We see how changing num_threads affects the performance of the iterator. Note that if you run this tutorial on a different machine you may see different results.
To achieve the best performance, num_threads needs to be fine-tuned for the specific use case. The optimal value depends on the hardware configuration and the batch size.
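A simple way to find a good setting is to sweep a few candidate values and time each one. A minimal sketch using the iterator_fn and run_iterator defined above (the candidate values are arbitrary):

import time

for num_threads in (1, 2, 4, 8):  # arbitrary candidates, adjust for your machine
    iterator = iterator_fn(batch_size=64, num_threads=num_threads)
    start = time.perf_counter()
    run_iterator(iterator, epoch_num=10)
    elapsed = time.perf_counter() - start
    print(f"num_threads={num_threads}: {elapsed:.3f} s")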
data_iterator decorator internals#
First, let's look at the type of the iterator:
[37]:
print(type(iterator))
<class 'nvidia.dali.plugin.jax.iterator.DALIGenericIterator'>
DALIGenericIterator is a high-level interface that simplifies the integration of DALI pipelines with machine learning frameworks. It is a JAX-compatible wrapper for DALI pipeline objects. A DALI pipeline and iterator can be created manually without the data_iterator decorator. Let's go back to the original pipeline definition. We can create a pipeline object and later use this object to create an iterator:
[38]:
from nvidia.dali.pipeline import pipeline_def


@pipeline_def
def simple_pipeline():
    jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
    images = fn.decoders.image(jpegs, device="mixed")
    images = fn.resize(images, resize_x=256, resize_y=256)
    return images, labels


pipeline = simple_pipeline(batch_size=8, num_threads=1, device_id=0)

print(type(pipeline))
<class 'nvidia.dali.pipeline.Pipeline'>
With the pipeline object we create an iterator:
[39]:
from nvidia.dali.plugin.jax import DALIGenericIterator
iterator = DALIGenericIterator(
    pipeline, output_map=["images", "labels"], reader_name="image_reader"
)
print(type(iterator))
<class 'nvidia.dali.plugin.jax.iterator.DALIGenericIterator'>
The iterator created this way is the same as the one created with the data_iterator decorator. We can use it to get data in the same way:
[40]:
batch = next(iterator)
print(f'Images backing device: {batch["images"].device()}')
print(f'Labels backing device: {batch["labels"].device()}')
Images backing device: cuda:0
Labels backing device: TFRT_CPU_0
Where to go next#
Congratulations, you learned how to integrate DALI with JAX to efficiently process and augment your data for machine learning workflows. With this foundation, you can explore more advanced topics and enhance your knowledge further:
- to continue learning about DALI and JAX, visit the related section of the DALI documentation. It contains more detailed information and end-to-end training examples for DALI and JAX, including libraries from the JAX ecosystem like Flax, T5X and Pax.
- to find out more about DALI in general, visit the DALI documentation. It contains detailed information about all the features of DALI.
- if you have any questions about DALI, visit the DALI GitHub page and create an issue. We will be happy to hear from you.
- if you prefer video content, we have a curated list of video materials about DALI.