Training a neural network with DALI and JAX

This simple example shows how to train a neural network implemented in JAX with DALI pipelines. It builds on the MNIST training example from the JAX codebase.

We will use the MNIST dataset in Caffe2 format from DALI_extra.

[1]:
import os

training_data_path = os.path.join(
    os.environ["DALI_EXTRA_PATH"], "db/MNIST/training/"
)
validation_data_path = os.path.join(
    os.environ["DALI_EXTRA_PATH"], "db/MNIST/testing/"
)
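Indexing os.environ["DALI_EXTRA_PATH"] directly raises a bare KeyError when the variable is unset, which can be confusing on a first run. A slightly more defensive variant is sketched below. dali_extra_subdir is a hypothetical helper (not part of DALI), and its environ parameter exists only to make it easy to test.

```python
import os


def dali_extra_subdir(relative_path, environ=os.environ):
    """Hypothetical helper: resolve a directory inside the DALI_extra checkout.

    Raises a descriptive error instead of a bare KeyError when
    DALI_EXTRA_PATH is unset, and checks that the directory exists.
    """
    root = environ.get("DALI_EXTRA_PATH")
    if not root:
        raise RuntimeError(
            "DALI_EXTRA_PATH is not set; point it at your local DALI_extra checkout"
        )
    path = os.path.join(root, relative_path)
    if not os.path.isdir(path):
        raise FileNotFoundError(f"{path} does not exist or is not a directory")
    return path


# Usage (equivalent to the cell above when the variable is set):
# training_data_path = dali_extra_subdir("db/MNIST/training/")
# validation_data_path = dali_extra_subdir("db/MNIST/testing/")
```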

The first step is to define a function that will later be used to create instances of DALI iterators. It defines all the preprocessing steps.

In this simple example, we use fn.readers.caffe2 to read data in Caffe2 format, fn.decoders.image to decode the images, fn.crop_mirror_normalize to normalize them, and fn.reshape to flatten each image into a vector of image_size * image_size values. We also move the labels from CPU to GPU memory with labels.gpu(). Our model expects training labels in one-hot encoding, so we convert them with fn.one_hot; this is done only in the shuffled (training) pipeline, while validation labels remain integer class indices.
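To make the arithmetic of these steps concrete, here is a pure-Python sketch of what the pipeline computes for a single sample. It is an illustration only (DALI runs these operators on the GPU); reference_preprocess is a hypothetical name, and the sketch assumes the decoded grayscale image is a 28x28 grid of 8-bit values.

```python
def reference_preprocess(pixels, label, num_classes=10, one_hot=True):
    """Illustrative mirror of the per-sample steps in mnist_iterator.

    pixels: 28 rows of 28 uint8 grayscale values (the decoded GRAY image
            with its single channel dropped).
    label:  integer class index read from the Caffe2 database.
    """
    # crop_mirror_normalize computes (x - mean) / std; mean defaults to 0 and
    # std is 255 here, so every pixel is scaled into [0, 1].
    normalized = [[p / 255.0 for p in row] for row in pixels]
    # With a single channel, CHW and HWC have the same element order, so
    # reshaping to [image_size * image_size] is a plain row-major flatten.
    flat = [v for row in normalized for v in row]
    if not one_hot:
        # Validation path: keep the integer class index.
        return flat, label
    # One-hot encoding: 1.0 at the label index, 0.0 everywhere else.
    return flat, [1.0 if c == label else 0.0 for c in range(num_classes)]


white = [[255] * 28 for _ in range(28)]
features, target = reference_preprocess(white, 3)
print(len(features), target[3], sum(target))  # 784 1.0 1.0
```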

This example focuses on how to use DALI to train a model defined in JAX. For more information on DALI and JAX integration, see Getting started with JAX and DALI and the pipeline documentation.

[2]:
from nvidia.dali.plugin.jax import data_iterator
import nvidia.dali.fn as fn
import nvidia.dali.types as types

batch_size = 200
image_size = 28
num_classes = 10


@data_iterator(output_map=["images", "labels"], reader_name="caffe2_reader")
def mnist_iterator(data_path, random_shuffle):
    jpegs, labels = fn.readers.caffe2(
        path=data_path, random_shuffle=random_shuffle, name="caffe2_reader"
    )
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
    images = fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, std=[255.0], output_layout="CHW"
    )
    images = fn.reshape(images, shape=[image_size * image_size])

    labels = labels.gpu()

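    # Only the shuffled (training) pipeline one-hot encodes labels;
    # validation labels stay as integer class indices.
    if random_shuffle:
        labels = fn.one_hot(labels, num_classes=num_classes)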

    return images, labels

Next, we use the function to create DALI iterators for training and validation.

[3]:
print("Creating iterators")

training_iterator = mnist_iterator(
    data_path=training_data_path, random_shuffle=True, batch_size=batch_size
)

validation_iterator = mnist_iterator(
    data_path=validation_data_path, random_shuffle=False, batch_size=batch_size
)

print(training_iterator)
print(validation_iterator)
Creating iterators
<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7d7397b4b790>
<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7d739800e530>

With the setup above, the DALI iterators are ready for training.

Next, we import the training utilities implemented in JAX from the local model module. init_model creates the model instance and initializes its parameters; in this simple example, it is an MLP with two hidden layers. update performs one training iteration. accuracy is a helper function that runs validation on the test set after each epoch and returns the current accuracy of the model.
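This page does not show model.py itself. For readers who want something self-contained to experiment with, here is a hypothetical sketch of utilities with the same names and call signatures as used in this notebook, loosely following the structure of JAX's MNIST-from-scratch example. The layer sizes, tanh activations, initialization scale and plain SGD step size are assumptions, not necessarily what the real model.py uses.

```python
import jax
import jax.numpy as jnp

# Assumed hyperparameters (hypothetical): 784 -> 1024 -> 1024 -> 10.
LAYER_SIZES = [784, 1024, 1024, 10]
STEP_SIZE = 0.001


def init_model(layer_sizes=LAYER_SIZES, scale=0.1, seed=0):
    """Return a list of (weights, biases) pairs, one per dense layer."""
    params = []
    key = jax.random.PRNGKey(seed)
    for m, n in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, w_key, b_key = jax.random.split(key, 3)
        params.append((scale * jax.random.normal(w_key, (m, n)),
                       scale * jax.random.normal(b_key, (n,))))
    return params


def predict(params, images):
    """Forward pass of the MLP; returns log-probabilities per class."""
    activations = images
    for w, b in params[:-1]:
        activations = jnp.tanh(activations @ w + b)
    final_w, final_b = params[-1]
    return jax.nn.log_softmax(activations @ final_w + final_b, axis=-1)


def loss(params, batch):
    """Cross-entropy against one-hot float labels (training batches)."""
    log_probs = predict(params, batch["images"])
    return -jnp.mean(jnp.sum(log_probs * batch["labels"], axis=-1))


@jax.jit
def update(params, batch):
    """One plain SGD step; compiled once per input shape/dtype signature."""
    grads = jax.grad(loss)(params, batch)
    return jax.tree_util.tree_map(lambda p, g: p - STEP_SIZE * g, params, grads)


def accuracy(params, iterator):
    """Fraction of correct predictions over every batch in the iterator."""
    correct, total = 0, 0
    for batch in iterator:
        # Validation labels are integer class indices, possibly shaped (B, 1).
        labels = jnp.reshape(batch["labels"], (-1,))
        predicted = jnp.argmax(predict(params, batch["images"]), axis=1)
        correct += int(jnp.sum(predicted == labels))
        total += labels.shape[0]
    return correct / total
```

With definitions like these, the cells in this notebook would run unchanged: init_model() needs no arguments, update(model, batch) expects one-hot float labels, and accuracy accepts the integer labels produced by the validation iterator.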

[4]:
from model import init_model, update, accuracy

jax.jit traces, compiles, and caches a function lazily, on its first invocation for a given input signature (input shapes and dtypes). During this process, XLA may capture CUDA graphs, and CUDA graph capture forbids some CUDA calls that DALI's background thread uses internally. Since subsequent calls with inputs of the same shapes and dtypes don't trigger compilation again, we can work around this by warming up update with dummy inputs before starting any DALI workload:

[5]:
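import jax
import jax.numpy as jnp

model = init_model()
# Dummy inputs with the same shapes and dtypes as the training batches.
dummy_images = jnp.empty(
    (batch_size, image_size * image_size), dtype=jnp.float32
)
dummy_labels = jnp.empty((batch_size, num_classes), dtype=jnp.float32)
# Compile update once and wait for its asynchronous execution to finish
# before DALI starts; the returned parameters are discarded.
jax.block_until_ready(
    update(model, {"images": dummy_images, "labels": dummy_labels})
)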
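The caching behavior this warm-up relies on can be observed directly. The standalone toy below uses hypothetical names (step and trace_count, unrelated to model.py): a Python side effect inside a jax.jit-decorated function runs only while JAX traces it, so counting those side effects shows exactly when recompilation happens.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only when JAX (re)traces the function


@jax.jit
def step(x):
    global trace_count
    trace_count += 1  # Python side effect: runs during tracing, not on cached calls
    return x * 2.0


step(jnp.zeros((4,), dtype=jnp.float32))  # first call: trace + compile
step(jnp.ones((4,), dtype=jnp.float32))   # same shape/dtype: served from the cache
step(jnp.ones((8,), dtype=jnp.float32))   # new shape: traced and compiled again
print(trace_count)  # 2
```

This is why the dummy inputs must match the real batches exactly: a different batch size or dtype would trigger a fresh compilation (and possible CUDA graph capture) during training.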

Warning

If you skip this step, CUDA graph capture will happen on the first call to update and may overlap with DALI’s execution, causing CUDA errors in JAX.

Alternatively, you can disable XLA command buffers entirely by setting XLA_FLAGS="--xla_gpu_enable_command_buffer=" before JAX initializes its GPU backend, at the cost of some performance.
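From a shell, the variable can simply be exported before starting Python. From inside Python, a sketch like the following should work, under the assumption that it runs before JAX has initialized its GPU backend (placing it before import jax is the safest choice). It appends the flag instead of overwriting XLA_FLAGS, so any flags you already set are preserved.

```python
import os

# Must run before JAX initializes its GPU backend; safest is before `import jax`.
flag = "--xla_gpu_enable_command_buffer="  # empty value: no command buffer types enabled
existing = os.environ.get("XLA_FLAGS", "")
if flag not in existing.split():
    # Append to any flags the user already configured instead of replacing them.
    os.environ["XLA_FLAGS"] = f"{existing} {flag}".strip()
print(os.environ["XLA_FLAGS"])
```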

At this point, everything is ready to run the training.

[6]:
print("Starting training")

num_epochs = 5

for epoch in range(num_epochs):
    for batch in training_iterator:
        model = update(model, batch)

    test_acc = accuracy(model, validation_iterator)
    print(f"Epoch {epoch}")
    print(f"Test set accuracy {test_acc}")
Starting training
Epoch 0
Test set accuracy 0.674500048160553
Epoch 1
Test set accuracy 0.7854000329971313
Epoch 2
Test set accuracy 0.8252000212669373
Epoch 3
Test set accuracy 0.847100019454956
Epoch 4
Test set accuracy 0.8618000149726868
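If you also want to report how long each epoch takes, keep in mind that JAX dispatches work asynchronously: update returns before the GPU has finished computing, so the clock should stop only after the final parameters are ready. A minimal sketch, where train_one_epoch is a hypothetical helper rather than part of model.py or DALI:

```python
import time

import jax
import jax.numpy as jnp


def train_one_epoch(update, model, iterator):
    """Run one pass over `iterator`; return (model, elapsed_seconds)."""
    start = time.perf_counter()
    for batch in iterator:
        model = update(model, batch)
    # Wait for the last asynchronously dispatched update before stopping the clock.
    model = jax.block_until_ready(model)
    return model, time.perf_counter() - start


# Toy usage with a stand-in update function and a list of dict batches.
toy_update = jax.jit(lambda m, b: m + b["x"])
batches = [{"x": jnp.ones((3,), jnp.float32)} for _ in range(4)]
final, elapsed = train_one_epoch(toy_update, jnp.zeros((3,), jnp.float32), batches)
print(final)  # [4. 4. 4.]
```

In the training loop above, this could replace the inner for loop, for example model, epoch_time = train_one_epoch(update, model, training_iterator) followed by print(f"Epoch {epoch} in {epoch_time:0.2f} sec").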