Training a neural network with DALI and JAX

This simple example shows how to train a neural network implemented in JAX with DALI pipelines. It builds on the MNIST training example from the JAX codebase that can be found here.

We will use the MNIST dataset in Caffe2 format from DALI_extra.

[1]:
import os

training_data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/training/')
validation_data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/testing/')
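
The data lives in the DALI_extra repository (https://github.com/NVIDIA/DALI_extra), which stores its binary files with git-lfs; the DALI_EXTRA_PATH environment variable should point to the root of the checkout. As an optional guard (a hypothetical addition, not part of the original example), we can fail early with a clear message if the checkout is missing:

# Hypothetical guard: verify the dataset checkout exists before going further.
assert os.path.isdir(training_data_path), (
    f"MNIST data not found at {training_data_path}. Clone "
    "https://github.com/NVIDIA/DALI_extra (requires git-lfs) and set "
    "DALI_EXTRA_PATH to the checkout.")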

The first step is to create a pipeline definition function that will later be used to create instances of DALI pipelines. It defines all steps of the preprocessing. In this simple example we have fn.readers.caffe2 for reading data in Caffe2 format, fn.decoders.image for image decoding, fn.crop_mirror_normalize to normalize the images, and fn.reshape to adjust the shape of the output tensors. We also move the labels from CPU to GPU memory with labels.gpu() and apply one-hot encoding to them with fn.one_hot. Note that the one-hot encoding is applied only in the training pipeline (guarded by the if random_shuffle: check below); the validation pipeline keeps integer labels, which the accuracy computation expects.

This example focuses on how to use a DALI pipeline with JAX. For more information on DALI pipelines, look into the Getting started and pipeline documentation.

[2]:
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types


batch_size = 200
image_size = 28
num_classes = 10


@pipeline_def(device_id=0, batch_size=batch_size, num_threads=4, seed=0)
def mnist_pipeline(data_path, random_shuffle):
    jpegs, labels = fn.readers.caffe2(
        path=data_path,
        random_shuffle=random_shuffle,
        name="mnist_caffe2_reader")
    images = fn.decoders.image(
        jpegs, device='mixed', output_type=types.GRAY)
    images = fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, std=[255.], output_layout="CHW")
    images = fn.reshape(images, shape=[image_size * image_size])

    labels = labels.gpu()

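    # One-hot encode the labels only in the training pipeline
    # (random_shuffle=True); the validation pipeline keeps integer
    # class ids, which the accuracy computation expects.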
    if random_shuffle:
        labels = fn.one_hot(labels, num_classes=num_classes)

    return images, labels

The next step is to instantiate the DALI pipelines and build them. Building creates and initializes the pipeline internals.

[3]:
print('Creating pipelines')
training_pipeline = mnist_pipeline(data_path=training_data_path, random_shuffle=True)
validation_pipeline = mnist_pipeline(data_path=validation_data_path, random_shuffle=False)

print('Building pipelines')
training_pipeline.build()
validation_pipeline.build()

print(training_pipeline)
print(validation_pipeline)
Creating pipelines
Building pipelines
<nvidia.dali.pipeline.Pipeline object at 0x7f2f8797b7c0>
<nvidia.dali.pipeline.Pipeline object at 0x7f2f8797ba00>
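
Before wrapping the pipelines with iterators, one can optionally run a pipeline directly to inspect its raw outputs. This is a hypothetical sanity check, not part of the original example: Pipeline.run() returns one TensorList per pipeline output (on the GPU here, since the outputs come from GPU operators), and it consumes one batch from the reader, so the first training epoch would be offset accordingly.

# Optional sanity check (note: this consumes one batch from the reader).
images, labels = training_pipeline.run()
print(images.as_cpu().as_array().shape)  # expected: (200, 784)
print(labels.as_cpu().as_array().shape)  # expected: (200, 10), one-hot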

A DALI pipeline needs to be wrapped with an appropriate DALI iterator to work with JAX. To get an iterator compatible with JAX, we need to import it from the DALI JAX plugin. In addition to the DALI pipeline object, we can pass the output_map, reader_name and auto_reset parameters to the iterator.

Here is a quick explanation of how these parameters work:

  • output_map: iterators return a dictionary with the outputs of the pipeline as its values. The keys in this dictionary are defined by output_map. For example, the labels output returned from the DALI pipeline defined above will be accessible as iterator_output['labels'].

  • reader_name: setting this parameter introduces the notion of an epoch to our iterator. The DALI pipeline itself is infinite: it will return data indefinitely, wrapping around the dataset. DALI readers (such as fn.readers.caffe2 used in this example) have access to the information about the size of the dataset. If we want to pass this information to the iterator, we need to point to the operator that should be queried for the dataset size. We do this by naming the operator (note name="mnist_caffe2_reader") and passing the same name as the value of the reader_name argument.

  • auto_reset: this argument controls the behaviour of the iterator after the end of an epoch. If set to True, the iterator will automatically reset its state and prepare to start the next epoch.

If you want to know more about the iterator arguments, look into the JAX iterator documentation.

[4]:
from nvidia.dali.plugin import jax as dax


print('Creating iterators')
training_iterator = dax.DALIGenericIterator(
    training_pipeline,
    output_map=["images", "labels"],
    reader_name="mnist_caffe2_reader",
    auto_reset=True)

validation_iterator = dax.DALIGenericIterator(
    validation_pipeline,
    output_map=["images", "labels"],
    reader_name="mnist_caffe2_reader",
    auto_reset=True)

print(training_iterator)
print(f"Number of batches in training iterator = {len(training_iterator)}")
print(f"Number of batches in validation iterator = {len(validation_iterator)}")
Creating iterators
<nvidia.dali.plugin.jax.DALIGenericIterator object at 0x7f2f840ef700>
Number of batches in training iterator = 300
Number of batches in validation iterator = 50

With the setup above, the DALI iterators are ready for training.
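
For illustration, this is how a single batch looks when pulled from the iterator (a hypothetical snippet, not part of the original example; pulling a batch advances the iterator, so you would reset or recreate it before training):

batch = next(training_iterator)        # dict keyed by output_map
print(batch['images'].shape)           # (200, 784), a JAX array
print(batch['labels'].shape)           # (200, 10), one-hot for training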

Next, we import the training utilities implemented in JAX. init_model creates the model instance and initializes its parameters. In this simple example it is an MLP model with two hidden layers. update performs one iteration of the training. accuracy is a helper function that runs validation on the test set after each epoch and returns the current accuracy of the model.

[5]:
from model import init_model, update, accuracy
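
The model module is distributed alongside this example and its exact contents are not reproduced here. For reference, below is a minimal hypothetical sketch of what such a module could look like, loosely following the MNIST example from the JAX repository; all names, shapes and hyperparameters are illustrative assumptions, not the actual implementation.

# Hypothetical sketch of model.py; the file distributed with this example
# may differ. Loosely based on the MNIST example from the JAX repository.
import jax
import jax.numpy as jnp


def init_model(layer_sizes=(784, 1024, 1024, 10), scale=0.1, seed=0):
    # A list of (weights, biases) pairs: an MLP with two hidden layers.
    keys = jax.random.split(jax.random.PRNGKey(seed), len(layer_sizes) - 1)
    params = []
    for key, m, n in zip(keys, layer_sizes[:-1], layer_sizes[1:]):
        w_key, b_key = jax.random.split(key)
        params.append((scale * jax.random.normal(w_key, (m, n)),
                       scale * jax.random.normal(b_key, (n,))))
    return params


def predict(model, images):
    # Forward pass: ReLU hidden layers, log-softmax over the output logits.
    activations = images
    for w, b in model[:-1]:
        activations = jnp.maximum(jnp.dot(activations, w) + b, 0.0)
    w, b = model[-1]
    logits = jnp.dot(activations, w) + b
    return logits - jax.scipy.special.logsumexp(logits, axis=1, keepdims=True)


@jax.jit
def update(model, batch, learning_rate=0.01):
    # One SGD step; the training labels are one-hot encoded by the pipeline.
    def loss(model):
        log_probs = predict(model, batch['images'])
        return -jnp.mean(jnp.sum(log_probs * batch['labels'], axis=1))

    grads = jax.grad(loss)(model)
    return [(w - learning_rate * dw, b - learning_rate * db)
            for (w, b), (dw, db) in zip(model, grads)]


def accuracy(model, iterator):
    # Validation labels are integer class ids (no one-hot in that pipeline).
    correct = total = 0
    for batch in iterator:
        predicted = jnp.argmax(predict(model, batch['images']), axis=1)
        correct += jnp.sum(predicted == batch['labels'].squeeze(-1))
        total += predicted.shape[0]
    return correct / total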

At this point, everything is ready to run the training.

[6]:
print('Starting training')

model = init_model()
num_epochs = 10

for epoch in range(num_epochs):
    for batch in training_iterator:
        model = update(model, batch)

    test_acc = accuracy(model, validation_iterator)
    print(f"Epoch {epoch} sec")
    print(f"Test set accuracy {test_acc}")
Starting training
Epoch 0
Test set accuracy 0.6741000413894653
Epoch 1
Test set accuracy 0.7850000262260437
Epoch 2
Test set accuracy 0.8251000642776489
Epoch 3
Test set accuracy 0.8468000292778015
Epoch 4
Test set accuracy 0.8614000678062439
Epoch 5
Test set accuracy 0.8722000122070312
Epoch 6
Test set accuracy 0.8781000375747681
Epoch 7
Test set accuracy 0.8837000131607056
Epoch 8
Test set accuracy 0.8876000642776489
Epoch 9
Test set accuracy 0.8914000391960144