Training a neural network with DALI and Flax

This simple example shows how to train a neural network implemented in Flax with DALI pipelines. If you want to learn more about training neural networks with Flax, look into the Flax Getting Started example.

The DALI setup is very similar to the training example with pure JAX; the only adjustment is in how the output images are reshaped so that they match the input expected by the Flax model defined below. If you are not familiar with how to use DALI with JAX, you can learn more in the DALI and JAX Getting Started example.

We use MNIST in Caffe2 format from DALI_extra.
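The code below assumes that the DALI_EXTRA_PATH environment variable points to a local clone of DALI_extra. If it is not set in your environment yet, you can set it from the notebook; the path used here is only a placeholder:

import os

# Placeholder path - replace it with the location of your DALI_extra clone.
os.environ.setdefault("DALI_EXTRA_PATH", "/path/to/DALI_extra")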

[1]:
import os

training_data_path = os.path.join(os.environ["DALI_EXTRA_PATH"], "db/MNIST/training/")
validation_data_path = os.path.join(os.environ["DALI_EXTRA_PATH"], "db/MNIST/testing/")

The first step is to create an iterator definition function that will later be used to create instances of DALI iterators. It defines all steps of the preprocessing. In this simple example we have fn.readers.caffe2 for reading data in Caffe2 format, fn.decoders.image for image decoding, fn.crop_mirror_normalize to normalize the images, and fn.reshape to adjust the shape of the output tensors. We also move the labels from CPU to GPU memory with labels.gpu() and, for training, apply one-hot encoding to them with fn.one_hot.

This example focuses on how to use the DALI pipeline with JAX. For more information on the DALI iterator, look into the DALI and JAX Getting Started example and the pipeline documentation.

[2]:
import nvidia.dali.fn as fn
import nvidia.dali.types as types

from nvidia.dali.plugin.jax import data_iterator


batch_size = 50
image_size = 28
num_classes = 10


@data_iterator(output_map=["images", "labels"], reader_name="mnist_caffe2_reader")
def mnist_iterator(data_path, is_training):
    jpegs, labels = fn.readers.caffe2(
        path=data_path, random_shuffle=is_training, name="mnist_caffe2_reader"
    )
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
    images = fn.crop_mirror_normalize(images, dtype=types.FLOAT, std=[255.0])
    images = fn.reshape(images, shape=[-1])  # Flatten the output image

    labels = labels.gpu()

    if is_training:
        labels = fn.one_hot(labels, num_classes=num_classes)

    return images, labels

With the iterator definition function we can now create DALI iterators.

[3]:
print("Creating iterators")
training_iterator = mnist_iterator(
    data_path=training_data_path, is_training=True, batch_size=batch_size
)
validation_iterator = mnist_iterator(
    data_path=validation_data_path, is_training=False, batch_size=batch_size
)

print(training_iterator)
print(validation_iterator)

print(f"Number of batches in training iterator = {len(training_iterator)}")
print(f"Number of batches in validation iterator = {len(validation_iterator)}")
Creating iterators
<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7fdc240f4e50>
<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7fdc1c78e020>
Number of batches in training iterator = 1200
Number of batches in validation iterator = 200

With the setup above, the DALI iterators are ready for training.
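If you want to see what the iterators produce before starting the training, you can peek at a single batch. Note that this consumes one batch of the current epoch, so it is only a quick, optional sanity check; the shapes in the comments are what we expect from the pipeline defined above:

batch = next(training_iterator)

print(batch["images"].shape)  # expected: (50, 784) - flattened grayscale images
print(batch["labels"].shape)  # expected: (50, 10) - one-hot encoded labels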

Now we need to set up the model and training utilities. The goal of this notebook is not to explain Flax concepts; we want to show how to train models implemented in Flax with DALI as the data loading and preprocessing library. We use standard Flax tools to define a simple neural network. We have functions to create an instance of this network, run one training step on it, and calculate the accuracy of the model at the end of each epoch.

If you want to learn more about Flax and get a better understanding of the code below, look into the Flax documentation.

[4]:
import jax
import jax.numpy as jnp

from flax import linen as nn
from flax.training import train_state

import optax


class CNN(nn.Module):
    """A simple fully connected classifier operating on flattened 28x28 images."""

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=784)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1024)(x)
        x = nn.relu(x)
        x = nn.Dense(features=1024)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x


def create_model_state(rng, learning_rate, momentum):
    cnn = CNN()
    params = cnn.init(rng, jnp.ones([784]))["params"]
    tx = optax.sgd(learning_rate, momentum)
    return train_state.TrainState.create(apply_fn=cnn.apply, params=params, tx=tx)


@jax.jit
def train_step(model_state, batch):
    def loss_fn(params):
        logits = model_state.apply_fn({"params": params}, batch["images"])
        loss = optax.softmax_cross_entropy(logits=logits, labels=batch["labels"]).mean()
        return loss

    grad_fn = jax.grad(loss_fn)
    grads = grad_fn(model_state.params)
    model_state = model_state.apply_gradients(grads=grads)
    return model_state


def accuracy(model_state, iterator):
    # Validation labels are not one-hot encoded (the iterator was created with
    # is_training=False), so we compare them directly with the argmax of the logits.
    correct_predictions = 0
    for batch in iterator:
        logits = model_state.apply_fn({"params": model_state.params}, batch["images"])
        correct_predictions = correct_predictions + jnp.sum(
            batch["labels"].ravel() == jnp.argmax(logits, axis=-1)
        )

    # iterator.size is the number of samples the iterator covers in one epoch,
    # so this is the fraction of correctly classified samples.
    return correct_predictions / iterator.size

With the utilities defined above, we can create an instance of the model we want to train.

[5]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

learning_rate = 0.1
momentum = 0.9

model_state = create_model_state(init_rng, learning_rate, momentum)

At this point, everything is ready to run the training.

[6]:
print("Starting training")

num_epochs = 5
for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    for batch in training_iterator:
        model_state = train_step(model_state, batch)

    acc = accuracy(model_state, validation_iterator)
    print(f"Accuracy = {acc}")
Starting training
Epoch 0
Accuracy = 0.9551000595092773
Epoch 1
Accuracy = 0.9691000580787659
Epoch 2
Accuracy = 0.9738000631332397
Epoch 3
Accuracy = 0.9622000455856323
Epoch 4
Accuracy = 0.9604000449180603

Multiple GPUs with DALI and Flax

This section shows how to extend the example above to use multiple GPUs.

Again, we start by creating an iterator definition function. It is a slightly modified version of the function we have seen before.

Note the new arguments passed to fn.readers.caffe2: num_shards and shard_id. They are used to control sharding:

- num_shards sets the total number of shards
- shard_id tells the pipeline which shard it is responsible for

We add the devices argument to the decorator to specify which devices we want to use. Here we use all GPUs available to JAX on the machine.
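Note that num_shards and shard_id are not passed when the iterator is created later; the decorator creates one pipeline per device from the devices list and fills in these two arguments for each of them. A rough, illustrative sketch of the resulting assignment:

import jax

# Illustration only - the decorator performs this mapping internally.
devices = jax.devices()
for shard_id, device in enumerate(devices):
    print(f"{device}: shard_id={shard_id}, num_shards={len(devices)}")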

[7]:
batch_size = 200
image_size = 28
num_classes = 10


@data_iterator(
    output_map=["images", "labels"], reader_name="mnist_caffe2_reader", devices=jax.devices()
)
def mnist_sharded_iterator(data_path, is_training, num_shards, shard_id):
    jpegs, labels = fn.readers.caffe2(
        path=data_path,
        random_shuffle=is_training,
        name="mnist_caffe2_reader",
        num_shards=num_shards,
        shard_id=shard_id,
    )
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
    images = fn.crop_mirror_normalize(images, dtype=types.FLOAT, std=[255.0], output_layout="CHW")
    images = fn.reshape(images, shape=[-1])  # Flatten the output image

    labels = labels.gpu()

    if is_training:
        labels = fn.one_hot(labels, num_classes=num_classes)

    return images, labels

With the iterator definition function, we can now create a DALI iterator for training on multiple GPUs. It returns outputs compatible with pmapped JAX functions.

[8]:
print("Creating training iterator")
training_iterator = mnist_sharded_iterator(
    data_path=training_data_path, is_training=True, batch_size=batch_size
)

print(f"Number of batches in training iterator = {len(training_iterator)}")
Creating training iterator
Number of batches in training iterator = 300
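As a quick, optional sanity check you can peek at one batch and look at its layout. For pmapped functions the leading dimension of every array should match the number of devices; the per-device batch size depends on how DALI splits the work across devices, so the comments below only sketch the expected shapes (and this peek consumes one batch of the current epoch):

batch = next(training_iterator)

print(jax.device_count())
print(batch["images"].shape)  # expected: (num_devices, per_device_batch, 784)
print(batch["labels"].shape)  # expected: (num_devices, per_device_batch, 10)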

For simplicity, we will run validation on one GPU. We can reuse the validation iterator from the single-GPU example. The only difference is that we will need to pull the model to the same GPU. In a real-life scenario this might be costly, but for this toy educational example it is sufficient.

For the model to be compatible with pmap-style multi-GPU training, we need to replicate it. If you want to learn more about training on multiple GPUs with pmap, you can look into Parallel Evaluation in JAX from the JAX documentation and Ensembling on multiple devices from the Flax documentation.

[9]:
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)

learning_rate = 0.1
momentum = 0.9

model_state = jax.pmap(create_model_state, static_broadcasted_argnums=(1, 2))(
    jax.random.split(init_rng, jax.device_count()), learning_rate, momentum
)

Since we want to run validation on a single GPU, we extract only one replica of the model and pass it to the accuracy function.

Now we are ready to train the Flax model on multiple GPUs with DALI as the data source.

[10]:
import flax

parallel_train_step = jax.pmap(train_step)

num_epochs = 5
for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    for batch in training_iterator:
        model_state = parallel_train_step(model_state, batch)

    acc = accuracy(flax.jax_utils.unreplicate(model_state), validation_iterator)
    print(f"Accuracy = {acc}")
Epoch 0
Accuracy = 0.9509000182151794
Epoch 1
Accuracy = 0.9643000364303589
Epoch 2
Accuracy = 0.9724000692367554
Epoch 3
Accuracy = 0.9701000452041626
Epoch 4
Accuracy = 0.9758000373840332