Training a neural network with DALI and Flax

This simple example shows how to train a neural network implemented in Flax with DALI pipelines. If you want to learn more about training neural networks with Flax, look into the Flax Getting Started example.

The DALI setup is very similar to the training example with pure JAX; the pipeline below flattens the decoded images with fn.reshape and one-hot encodes the labels so they can be consumed directly by the Flax training code. If you are familiar with how to use DALI with JAX, you can skip this part and move to the training section of this notebook.

We will use MNIST in Caffe2 format from DALI_extra.

[2]:
import os

training_data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/training/')
validation_data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/testing/')

The first step is to create a pipeline definition function that will later be used to create instances of DALI pipelines. It defines all steps of the preprocessing. In this simple example we have fn.readers.caffe2 for reading data in Caffe2 format, fn.decoders.image for image decoding, fn.crop_mirror_normalize used to normalize the images, and fn.reshape to adjust the shape of the output tensors. We also move the labels from the CPU to the GPU memory with labels.gpu() and apply one-hot encoding to them for training with fn.one_hot.

This example focuses on how to use the DALI pipeline with JAX. For more information on the DALI pipeline, look into Getting started and the pipeline documentation.

[3]:
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types


batch_size = 200
image_size = 28
num_classes = 10


@pipeline_def(device_id=0, batch_size=batch_size, num_threads=4, seed=0)
def mnist_pipeline(data_path, random_shuffle):
    jpegs, labels = fn.readers.caffe2(
        path=data_path,
        random_shuffle=random_shuffle,
        name="mnist_caffe2_reader")
    images = fn.decoders.image(
        jpegs, device='mixed', output_type=types.GRAY)
    images = fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, std=[255.])
    images = fn.reshape(images, shape=[-1])  # Flatten the output image

    labels = labels.gpu()

    # One-hot encode labels only in the training pipeline (random_shuffle=True);
    # the validation pipeline keeps integer labels for the accuracy computation.
    if random_shuffle:
        labels = fn.one_hot(labels, num_classes=num_classes)

    return images, labels

The next step is to instantiate the DALI pipelines and build them. Building creates and initializes the pipeline internals.

[4]:
training_pipeline = mnist_pipeline(data_path=training_data_path, random_shuffle=True)
validation_pipeline = mnist_pipeline(data_path=validation_data_path, random_shuffle=False)

print('Building pipelines')
training_pipeline.build()
validation_pipeline.build()

print(training_pipeline)
print(validation_pipeline)
Building pipelines
<nvidia.dali.pipeline.Pipeline object at 0x7fb455793940>
<nvidia.dali.pipeline.Pipeline object at 0x7fb455791390>

A DALI pipeline needs to be wrapped with an appropriate DALI iterator to work with JAX. To get an iterator compatible with JAX, we need to import it from the DALI JAX plugin. In addition to the DALI pipeline object, we can pass the output_map, reader_name and auto_reset parameters to the iterator.

Here is a quick explanation of how these parameters work:

  • output_map: iterators return a dictionary with the outputs of the pipeline as its values. The keys in this dictionary are defined by output_map. For example, the labels output returned from the DALI pipeline defined above will be accessible as iterator_output['labels'],

  • reader_name: setting this parameter introduces the notion of an epoch to our iterator. The DALI pipeline itself is infinite; it will return data indefinitely, wrapping around the dataset. DALI readers (such as fn.readers.caffe2 used in this example) have access to the information about the size of the dataset. If we want to pass this information to the iterator, we need to point to the operator that should be queried for the dataset size. We do it by naming the operator (note name="mnist_caffe2_reader") and passing the same name as the value of the reader_name argument,

  • auto_reset: this argument controls the behaviour of the iterator after the end of an epoch. If set to True, it will automatically reset the state of the iterator and prepare it to start the next epoch.

If you want to know more about iterator arguments you can look into JAX iterator documentation.

[5]:
from nvidia.dali.plugin import jax as dax


print('Creating iterators')
training_iterator = dax.DALIGenericIterator(
    training_pipeline,
    output_map=["images", "labels"],
    reader_name="mnist_caffe2_reader",
    auto_reset=True)

validation_iterator = dax.DALIGenericIterator(
    validation_pipeline,
    output_map=["images", "labels"],
    reader_name="mnist_caffe2_reader",
    auto_reset=True)

print(training_iterator)
print(f"Number of batches in training iterator = {len(training_iterator)}")
print(f"Number of batches in validation iterator = {len(validation_iterator)}")
Creating iterators
<nvidia.dali.plugin.jax.DALIGenericIterator object at 0x7fb450301b70>
Number of batches in training iterator = 300
Number of batches in validation iterator = 50

With the setup above, the DALI iterators are ready for training.
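If you want to check what the iterator returns, the minimal sketch below peeks at a single batch. Note that this consumes one batch from the current epoch, so it is meant for exploration only; the shapes in the comments are what we expect given the pipeline defined above.

# Peek at a single batch; the iterator yields a dictionary keyed by the names from output_map.
batch = next(training_iterator)
print(batch['images'].shape)  # expected: (batch_size, 784) - flattened grayscale images
print(batch['labels'].shape)  # expected: (batch_size, num_classes) - one-hot encoded labels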

Now we need to set up the model and training utilities. The goal of this notebook is not to explain Flax concepts; we want to show how to train models implemented in Flax with DALI as the data loading and preprocessing library. We use standard Flax tools to define a simple neural network. We have functions to create an instance of this network, run one training step on it, and calculate the accuracy of the model at the end of each epoch.

If you want to learn more about Flax and get better understanding of the code below, look into Flax Documentation.

[6]:
import jax
import jax.numpy as jnp

from flax import linen as nn
from flax.training import train_state

import optax


class CNN(nn.Module):
    """A simple feed-forward classifier for flattened 28x28 MNIST images."""
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=784)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1024)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1024)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x


def create_model_state(rng, learning_rate, momentum):
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([784]))['params']
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(
        apply_fn=cnn.apply, params=params, tx=tx)


@jax.jit
def train_step(model_state, batch):
    """Run one optimization step: compute softmax cross-entropy gradients and update the model parameters."""
    def loss_fn(params):
        logits = model_state.apply_fn({'params': params}, batch['images'])
        loss = optax.softmax_cross_entropy(logits=logits, labels=batch['labels']).mean()
        return loss
    grad_fn = jax.grad(loss_fn)
    grads = grad_fn(model_state.params)
    model_state = model_state.apply_gradients(grads=grads)
    return model_state


def accuracy(model_state, iterator):
    """Return the fraction of correctly classified samples over all batches from the iterator (integer labels)."""
    correct_predictions = 0
    for batch in iterator:
        logits = model_state.apply_fn({'params': model_state.params}, batch['images'])
        correct_predictions = correct_predictions + \
            jnp.sum(batch['labels'].ravel() == jnp.argmax(logits, axis=-1))

    return correct_predictions / iterator.size

With the utilities defined above, we can create an instance of the model we want to train.

[7]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

learning_rate = 0.1
momentum = 0.9

model_state = create_model_state(init_rng, learning_rate, momentum)

At this point, everything is ready to run the training.

[8]:
print('Starting training')

num_epochs = 5
for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    for batch in training_iterator:
        model_state = train_step(model_state, batch)

    acc = accuracy(model_state, validation_iterator)
    print(f"Accuracy = {acc}")
Starting training
Epoch 0
Accuracy = 0.9637000560760498
Epoch 1
Accuracy = 0.9690000414848328
Epoch 2
Accuracy = 0.975100040435791
Epoch 3
Accuracy = 0.9761000275611877
Epoch 4
Accuracy = 0.9765000343322754

Multiple GPUs with DALI and Flax

This section shows how to extend the example above to use multiple GPUs.

Again, we start by creating a pipeline definition function. The pipeline is slightly modified to support multiple GPUs.

Note the new arguments passed to fn.readers.caffe2: num_shards and shard_id. They are used to control sharding:

  • num_shards sets the total number of shards,

  • shard_id tells the pipeline which shard of the dataset it is responsible for.

Also, the device_id argument was removed from the decorator. Since we want these pipelines to run on different GPUs, we will pass a particular device_id when creating each pipeline. Most often, device_id and shard_id will have the same value, but it is not a requirement. In this example we want the total batch size to be the same as in the single-GPU version, which is why we define batch_size_per_gpu as batch_size // jax.device_count(). Note that if batch_size is not divisible by the number of devices, this might require some adjustment to make sure all samples are used in every epoch of the training; a simple sanity check is sketched below. If you want to learn more about DALI sharding behaviour, look into the DALI sharding docs page.
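This is a minimal sketch of such a check; failing fast on an uneven split is just one option (dropping or padding the remainder are others), and it is not part of the training code in this notebook.

num_devices = jax.device_count()
# Fail fast if the global batch size cannot be split evenly across the devices;
# otherwise some samples would need special handling at the end of each epoch.
assert batch_size % num_devices == 0, \
    f"batch_size={batch_size} is not divisible by num_devices={num_devices}"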

[9]:
batch_size = 200
image_size = 28
num_classes = 10
batch_size_per_gpu = batch_size // jax.device_count()


@pipeline_def(batch_size=batch_size_per_gpu, num_threads=4, seed=0)
def mnist_sharded_pipeline(data_path, random_shuffle, num_shards, shard_id):
    jpegs, labels = fn.readers.caffe2(
        path=data_path,
        random_shuffle=random_shuffle,
        name="mnist_caffe2_reader",
        num_shards=num_shards,
        shard_id=shard_id)
    images = fn.decoders.image(
        jpegs, device='mixed', output_type=types.GRAY)
    images = fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, std=[255.], output_layout="CHW")
    images = fn.reshape(images, shape=[-1])  # Flatten the output image

    labels = labels.gpu()

    # As before, one-hot encode labels only for the training pipeline.
    if random_shuffle:
        labels = fn.one_hot(labels, num_classes=num_classes)

    return images, labels


Note the device_id values that are passed to place each pipeline on a different device.

[10]:
pipelines = []
for device_id in range(jax.device_count()):
    pipeline = mnist_sharded_pipeline(
        data_path=training_data_path,
        random_shuffle=True,
        num_shards=jax.device_count(),
        shard_id=device_id,
        device_id=device_id)
    print(f'Pipeline {pipeline} working on device {pipeline.device_id}')
    pipelines.append(pipeline)
Pipeline <nvidia.dali.pipeline.Pipeline object at 0x7db28c2b2da0> working on device 0
Pipeline <nvidia.dali.pipeline.Pipeline object at 0x7db28c2b30d0> working on device 1

We created multiple DALI pipelines. Each will run its computations on a different GPU and will start preprocessing from a different shard of the training dataset. In this configuration, each pipeline moves to the next shard in the next epoch. If you want to control this behaviour, you can look into the stick_to_shard argument of the readers; a possible modification is sketched below.
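As a hedged sketch (not used in this notebook), keeping each pipeline on its initial shard would only require an extra argument in the reader call inside mnist_sharded_pipeline:

    jpegs, labels = fn.readers.caffe2(
        path=data_path,
        random_shuffle=random_shuffle,
        name="mnist_caffe2_reader",
        num_shards=num_shards,
        shard_id=shard_id,
        stick_to_shard=True)  # keep this pipeline on its initial shard across epochs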

As in the single-GPU example, we create a training iterator. It will encapsulate all the pipelines that we created and return a dictionary of JAX arrays. With this simple configuration it returns arrays compatible with JAX pmapped functions: the leaves of the returned dictionary have shape (num_devices, batch_per_device, ...), and each slice across the first dimension of an array resides on a different GPU.

[11]:
print('Creating training iterator')
training_iterator = dax.DALIGenericIterator(
    pipelines,
    output_map=["images", "labels"],
    reader_name="mnist_caffe2_reader",
    auto_reset=True)

print(f"Number of batches in training iterator = {len(training_iterator)}")
Creating training iterator
Number of batches in training iterator = 300
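To confirm the (num_devices, batch_per_device, ...) layout described above, you can peek at one batch with the quick sketch below. Again, this consumes a batch from the current epoch and is meant for exploration only; the shapes in the comments are what we expect given the sharded pipeline defined above.

batch = next(training_iterator)
print(batch['images'].shape)  # expected: (num_devices, batch_size_per_gpu, 784)
print(batch['labels'].shape)  # expected: (num_devices, batch_size_per_gpu, num_classes)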

For simplicity, we will run validation on one GPU and reuse the validation iterator from the single-GPU example. The only difference is that we will need to pull the model to the same GPU. In a real-life scenario this might be costly, but for this toy educational example it is sufficient.

For the model to be compatible with pmap-style multi-GPU training, we need to replicate it. If you want to learn more about training on multiple GPUs with pmap, you can look into Parallel Evaluation in JAX from the JAX documentation and Ensembling on multiple devices from the Flax documentation.

[12]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

learning_rate = 0.1
momentum = 0.9

model_state = jax.pmap(
    create_model_state,
    static_broadcasted_argnums=(1, 2))(
        jax.random.split(
            init_rng,
            jax.device_count()),
        learning_rate,
        momentum)

Since we want to run validation on a single GPU, we extract only one replica of the model and pass it to the accuracy function.

Now, we are ready to train Flax model on multiple GPUs with DALI as the data source.

[13]:
import flax

parallel_train_step = jax.pmap(train_step)

num_epochs = 5
for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    for batch in training_iterator:
        model_state = parallel_train_step(model_state, batch)

    acc = accuracy(
        flax.jax_utils.unreplicate(model_state),
        validation_iterator)
    print(f"Accuracy = {acc}")
Epoch 0
Accuracy = 0.9445000290870667
Epoch 1
Accuracy = 0.9641000628471375
Epoch 2
Accuracy = 0.9654000401496887
Epoch 3
Accuracy = 0.9724000692367554
Epoch 4
Accuracy = 0.9760000705718994