Training a neural network with DALI and JAX

This simple example shows how to train a neural network implemented in JAX with DALI pipelines. It builds on the MNIST training example from the JAX codebase.

We will use MNIST in Caffe2 format from DALI_extra.

import os

training_data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/training/')
validation_data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/testing/')
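
The lookup above raises a bare KeyError when DALI_EXTRA_PATH is not set. As a minimal sketch, the same paths can be built with an explicit check; the helper name mnist_paths and its error message are hypothetical and not part of DALI:

```python
import os


def mnist_paths(root=None):
    """Return (training, validation) MNIST directories inside DALI_extra.

    Hypothetical helper: `root` defaults to the DALI_EXTRA_PATH variable.
    """
    if root is None:
        root = os.environ.get("DALI_EXTRA_PATH")
    if not root:
        # Fail early with a readable message instead of a bare KeyError.
        raise RuntimeError(
            "DALI_EXTRA_PATH is not set; point it at a DALI_extra checkout")
    return (os.path.join(root, "db/MNIST/training/"),
            os.path.join(root, "db/MNIST/testing/"))
```

Calling mnist_paths() with no argument reads the environment variable, so it can stand in for the two assignments above.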

The first step is to create a definition function that will later be used to create instances of DALI iterators. It defines all the preprocessing steps.

In this simple example we have fn.readers.caffe2 for reading data in Caffe2 format, fn.decoders.image for image decoding, fn.crop_mirror_normalize to normalize the images and fn.reshape to adjust the shape of the output tensors. We also move the labels from CPU to GPU memory with labels.gpu(). For training, our model expects labels in one-hot encoding, so we use fn.one_hot to convert them. Note that in the code below this conversion is applied only when random_shuffle is set, that is, for the training iterator.

This example focuses on how to use DALI to train a model defined in JAX. For more information on DALI and JAX integration, see Getting started with JAX and DALI and the pipeline documentation.

from nvidia.dali.plugin.jax import data_iterator
import nvidia.dali.fn as fn
import nvidia.dali.types as types

batch_size = 200
image_size = 28
num_classes = 10

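@data_iterator(output_map=["images", "labels"], reader_name="caffe2_reader")
def mnist_iterator(data_path, random_shuffle):
    # Read encoded images and integer labels. The reader name must match
    # reader_name in the decorator above.
    jpegs, labels = fn.readers.caffe2(
        path=data_path, random_shuffle=random_shuffle, name="caffe2_reader")
    # Decode to single-channel images; 'mixed' places the output on the GPU.
    images = fn.decoders.image(
        jpegs, device='mixed', output_type=types.GRAY)
    # Divide pixel values by 255 and convert to float in CHW layout.
    images = fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, std=[255.], output_layout="CHW")
    # Flatten each image into a vector of image_size * image_size values.
    images = fn.reshape(images, shape=[image_size * image_size])

    # Move the labels to GPU memory.
    labels = labels.gpu()

    # One-hot encode the labels only for the shuffled (training) iterator.
    if random_shuffle:
        labels = fn.one_hot(labels, num_classes=num_classes)

    return images, labels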
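
To make the arithmetic of the normalize, reshape and one-hot steps concrete, here is a rough NumPy-only stand-in that runs on a fake batch. It is a sketch, not DALI code: the function name preprocess_reference and the random input are made up for illustration, and it assumes the default mean of zero in fn.crop_mirror_normalize, so that normalization reduces to dividing by 255.

```python
import numpy as np

image_size = 28
num_classes = 10


def preprocess_reference(images_u8, labels, one_hot=True):
    """NumPy stand-in for the normalize / reshape / one-hot steps."""
    # Normalization with mean 0 and std 255: (x - 0) / 255.
    images = images_u8.astype(np.float32) / 255.0
    # Flatten each 1 x 28 x 28 image into a vector of 784 values.
    images = images.reshape(images.shape[0], image_size * image_size)
    if one_hot:
        # Turn each integer class index into a row with a single 1.0.
        labels = np.eye(num_classes, dtype=np.float32)[labels]
    return images, labels


# Fake batch of four grayscale images in CHW layout with integer labels.
rng = np.random.default_rng(0)
fake_images = rng.integers(
    0, 256, size=(4, 1, image_size, image_size), dtype=np.uint8)
fake_labels = np.array([3, 0, 9, 3])

images, labels = preprocess_reference(fake_images, fake_labels)
print(images.shape, images.dtype)  # (4, 784) float32
print(labels.shape)                # (4, 10)
```

Each output row of images lies in the range [0, 1], and each row of labels contains exactly one non-zero entry at the position of the class index.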

Next, we use the function to create DALI iterators for training and validation.

print('Creating iterators')

training_iterator = mnist_iterator(
    data_path=training_data_path, random_shuffle=True, batch_size=batch_size)
print(training_iterator)

validation_iterator = mnist_iterator(
    data_path=validation_data_path, random_shuffle=False, batch_size=batch_size)
print(validation_iterator)

Creating iterators
<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7f2894462ef0>
<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7f28944634c0>

With the setup above, the DALI iterators are ready for training.

Finally, we import training utilities implemented in JAX. init_model creates the model instance and initializes its parameters. In this simple example it is an MLP model with two hidden layers. update performs one training iteration. accuracy is a helper function that runs validation on the test set after each epoch and returns the current accuracy of the model.

from model import init_model, update, accuracy
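
The contents of the model module are not shown here. As a rough idea of what such a module might contain, the sketch below implements an MLP with two hidden layers in plain JAX. Everything beyond the three imported names is an assumption: the layer sizes, step size, parameter scale, tanh activations, plain SGD update, and the helper names predict and loss are illustrative choices, not the tutorial's actual code. It also assumes that batches are dictionaries with "images" and "labels" keys (as set by output_map), that training labels are one-hot, and that validation labels are integer class indices.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Hypothetical hyperparameters chosen for illustration only.
layer_sizes = [784, 1024, 1024, 10]
step_size = 0.001
param_scale = 0.1


def init_model(seed=0):
    """Return a list of (weights, biases) pairs, one pair per layer."""
    keys = jax.random.split(jax.random.PRNGKey(seed), len(layer_sizes) - 1)
    params = []
    for key, n_in, n_out in zip(keys, layer_sizes[:-1], layer_sizes[1:]):
        w_key, b_key = jax.random.split(key)
        params.append((param_scale * jax.random.normal(w_key, (n_in, n_out)),
                       param_scale * jax.random.normal(b_key, (n_out,))))
    return params


def predict(params, inputs):
    """Forward pass that returns log-probabilities for each class."""
    activations = inputs
    for w, b in params[:-1]:
        activations = jnp.tanh(activations @ w + b)
    final_w, final_b = params[-1]
    logits = activations @ final_w + final_b
    return logits - logsumexp(logits, axis=1, keepdims=True)


def loss(params, batch):
    """Cross-entropy between predictions and one-hot labels."""
    log_probs = predict(params, batch["images"])
    return -jnp.mean(jnp.sum(log_probs * batch["labels"], axis=1))


@jax.jit
def update(params, batch):
    """One plain SGD step on a single batch."""
    grads = jax.grad(loss)(params, batch)
    return [(w - step_size * dw, b - step_size * db)
            for (w, b), (dw, db) in zip(params, grads)]


def accuracy(params, iterator):
    """Fraction of correct predictions; expects integer class labels."""
    correct, total = 0, 0
    for batch in iterator:
        predicted = jnp.argmax(predict(params, batch["images"]), axis=1)
        labels = batch["labels"].reshape(-1)
        correct += int(jnp.sum(predicted == labels))
        total += labels.shape[0]
    return correct / total
```

Keeping the parameters as a plain list of tuples lets jax.grad and jax.jit treat the whole model as a pytree, so update can return a new parameter list without any extra bookkeeping.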

At this point, everything is ready to run the training.

print('Starting training')

model = init_model()
num_epochs = 5

for epoch in range(num_epochs):
    for batch in training_iterator:
        model = update(model, batch)

    test_acc = accuracy(model, validation_iterator)
    print(f"Epoch {epoch} sec")
    print(f"Test set accuracy {test_acc}")

Starting training
Epoch 0 sec
Test set accuracy 0.67330002784729
Epoch 1 sec
Test set accuracy 0.7855000495910645
Epoch 2 sec
Test set accuracy 0.8251000642776489
Epoch 3 sec
Test set accuracy 0.8469000458717346
Epoch 4 sec
Test set accuracy 0.8616000413894653