Training with multiple GPUs

This tutorial shows how to run the training from Training neural network with DALI and JAX on multiple GPUs. We use the same network and the same data pipeline; the only difference is that the computation is spread across several GPUs. To best understand the following content, we recommend going through Training neural network with DALI and JAX first.

To learn how to run the DALI iterator on multiple GPUs, refer to the section about multiple GPU support in Getting started with JAX and DALI. The example below builds on that knowledge.

Training with automatic parallelization

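In this section, we spread the training across the GPUs using JAX's automatic parallelization mechanisms. To do that, we need to define the sharding that we want to apply to the computation. The device mesh below has shape (number of GPUs, 1), so the first (batch) axis of each 2D array is split across the GPUs and the second axis is not split. The output shown comes from a machine with two GPUs.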

To learn more about sharding, refer to the section on distributed arrays and automatic parallelization in the JAX documentation.

[1]:
import jax
from jax.sharding import PositionalSharding, Mesh
from jax.experimental import mesh_utils


mesh = mesh_utils.create_device_mesh((jax.device_count(), 1))
sharding = PositionalSharding(mesh)

print(sharding)
PositionalSharding([[{GPU 0}]
                    [{GPU 1}]])
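To see what this sharding means for a batch, the sketch below computes which rows and columns of an array each device holds when a mesh of shape (2, 1) is applied to a 2D array, such as a batch of 200 flattened 28x28 images. The helper per_device_blocks is hypothetical and is not part of the JAX API: it models an even block split in row-major device order and simply rejects shapes that do not divide evenly.

```python
import itertools


def per_device_blocks(global_shape, mesh_shape):
    """Hypothetical helper: return the (start, stop) index range that each
    device holds along every axis, listing devices in row-major mesh order."""
    if len(global_shape) != len(mesh_shape):
        raise ValueError("the mesh must have one entry per array axis")
    per_axis = []
    for size, parts in zip(global_shape, mesh_shape):
        if size % parts:
            raise ValueError(f"axis of size {size} cannot be split into {parts} equal parts")
        step = size // parts
        per_axis.append([(i * step, (i + 1) * step) for i in range(parts)])
    # itertools.product varies the last axis fastest, matching row-major device order
    return list(itertools.product(*per_axis))


# A batch of 200 flattened 28x28 images on a (2, 1) mesh
for device, block in enumerate(per_device_blocks((200, 28 * 28), (2, 1))):
    print(f"GPU {device}: rows {block[0]}, columns {block[1]}")
```

In this sketch GPU 0 holds rows (0, 100) and GPU 1 holds rows (100, 200), each with all 784 columns, so every device processes half of the batch.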

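Next, we create the DALI iterator function. It is based on the function from the Training neural network with DALI and JAX example, extended for multiple GPUs: the decorator receives the sharding argument, and the function accepts num_shards and shard_id, which it passes to the reader so that each GPU reads a different part of the dataset. We do not pass num_shards and shard_id ourselves; the iterator fills them in for each pipeline instance it creates.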

[2]:
from nvidia.dali.plugin.jax import data_iterator
import nvidia.dali.fn as fn
import nvidia.dali.types as types


image_size = 28
num_classes = 10


@data_iterator(
    output_map=["images", "labels"], reader_name="mnist_caffe2_reader", sharding=sharding
)
def mnist_training_iterator(data_path, num_shards, shard_id):
    jpegs, labels = fn.readers.caffe2(
        path=data_path,
        random_shuffle=True,
        name="mnist_caffe2_reader",
        num_shards=num_shards,
        shard_id=shard_id,
    )
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
    images = fn.crop_mirror_normalize(images, dtype=types.FLOAT, std=[255.0], output_layout="CHW")
    images = fn.reshape(images, shape=[image_size * image_size])

    labels = labels.gpu()
    labels = fn.one_hot(labels, num_classes=num_classes)

    return images, labels

For simplicity, this tutorial runs validation on a single GPU, so we create a separate DALI iterator function for the validation data, without sharding.

[3]:
@data_iterator(output_map=["images", "labels"], reader_name="mnist_caffe2_reader")
def mnist_validation_iterator(data_path):
    jpegs, labels = fn.readers.caffe2(
        path=data_path, random_shuffle=False, name="mnist_caffe2_reader"
    )
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
    images = fn.crop_mirror_normalize(images, dtype=types.FLOAT, std=[255.0], output_layout="CHW")
    images = fn.reshape(images, shape=[image_size * image_size])

    labels = labels.gpu()

    return images, labels

Next, we define the training parameters and create the iterator instances.

[4]:
import os

training_data_path = os.path.join(os.environ["DALI_EXTRA_PATH"], "db/MNIST/training/")
validation_data_path = os.path.join(os.environ["DALI_EXTRA_PATH"], "db/MNIST/testing/")

batch_size = 200
num_epochs = 5


training_iterator = mnist_training_iterator(batch_size=batch_size, data_path=training_data_path)
print(f"Number of batches in training iterator = {len(training_iterator)}")

validation_iterator = mnist_validation_iterator(
    batch_size=batch_size, data_path=validation_data_path
)
print(f"Number of batches in validation iterator = {len(validation_iterator)}")
Number of batches in training iterator = 300
Number of batches in validation iterator = 50
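These numbers follow from the dataset sizes. Assuming the standard MNIST split of 60,000 training and 10,000 test images, the sketch below reproduces them. The reported 300 training batches is consistent with batch_size being the global batch size split across the GPUs, rather than a per-GPU size. The function names are hypothetical, and the sketch does not model a partial last batch, since both sizes divide evenly by 200.

```python
def batches_per_epoch(dataset_size, global_batch_size):
    """Hypothetical helper: number of full batches per epoch."""
    if dataset_size % global_batch_size:
        raise ValueError("this sketch does not model a partial last batch")
    return dataset_size // global_batch_size


def per_device_batch_size(global_batch_size, num_devices):
    """Hypothetical helper: samples each GPU processes per step when the
    batch axis is split evenly across devices."""
    if global_batch_size % num_devices:
        raise ValueError("global batch size must divide evenly across devices")
    return global_batch_size // num_devices


print(batches_per_epoch(60_000, 200))  # 300, training iterator
print(batches_per_epoch(10_000, 200))  # 50, validation iterator
print(per_device_batch_size(200, 2))   # 100 samples per GPU on two GPUs
```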

With this setup ready, we can start the actual training. We import the model-related utilities from the Training neural network with DALI and JAX example and use them to train the model.

Each batch in the training loop contains images and labels sharded according to the sharding argument.

Note how, for validation, we copy the model to one GPU. As mentioned before, this is done for simplicity; in a real-world scenario, you could run validation on all GPUs and average the results.
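If validation were split across GPUs, the per-GPU results would need to be combined. A minimal sketch, assuming each GPU reports its number of correct predictions and its number of samples (the helper combined_accuracy is hypothetical): weighting by sample count gives the same result as evaluating the whole set on one device, whereas a plain mean of per-GPU accuracies matches it only when every GPU sees the same number of samples.

```python
def combined_accuracy(correct_per_device, samples_per_device):
    """Hypothetical helper: overall accuracy from per-device counts."""
    if len(correct_per_device) != len(samples_per_device):
        raise ValueError("need one sample count per device")
    total = sum(samples_per_device)
    if total == 0:
        raise ValueError("no samples were evaluated")
    return sum(correct_per_device) / total


# Two GPUs with unequal shards: 9 of 10 correct and 40 of 50 correct
print(combined_accuracy([9, 40], [10, 50]))  # 49 / 60, about 0.817
print((9 / 10 + 40 / 50) / 2)                # plain mean of accuracies, about 0.85
```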

[5]:
from model import init_model, accuracy
from model import update

model = init_model()

for epoch in range(num_epochs):
    for batch in training_iterator:
        # Each batch is sharded across the GPUs according to `sharding`
        model = update(model, batch)

    # Validation runs on a single GPU, so copy the model parameters to the first device
    model_on_one_device = jax.tree_util.tree_map(
        lambda x: jax.device_put(x, jax.devices()[0]), model
    )
    test_acc = accuracy(model_on_one_device, validation_iterator)

    print(f"Epoch {epoch}")
    print(f"Test set accuracy {test_acc}")
Epoch 0
Test set accuracy 0.6739000082015991
Epoch 1
Test set accuracy 0.7844000458717346
Epoch 2
Test set accuracy 0.8244000673294067
Epoch 3
Test set accuracy 0.8455000519752502
Epoch 4
Test set accuracy 0.860200047492981

Training with a pmapped iterator

JAX offers another mechanism to spread computation across multiple devices: the pmap function. DALI supports this way of parallelization as well.

To learn more about pmap, see the JAX documentation.

To configure the DALI iterator in a way that is compatible with pmapped functions, we pass the devices argument instead of sharding. Here we use all available GPUs, so the iterator returns batches sharded across all of them.

As with sharding, under the hood the iterator creates multiple instances of the DALI pipeline, each assigned to one GPU. When the outputs are requested, DALI synchronizes the instances and returns the results as a single batch.
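pmap maps a function over the leading axis of its inputs and runs one instance per entry of that axis, each on its own device. A sketch of the resulting layout, assuming the pmap-compatible iterator groups each GPU's part of the batch along a new leading axis: the hypothetical helper split_for_devices turns a flat list of samples into one contiguous chunk per device, so a global batch of 200 on two GPUs becomes a 2 x 100 structure instead of 200 rows.

```python
def split_for_devices(samples, num_devices):
    """Hypothetical helper: group a flat batch into one contiguous chunk per
    device, i.e. shape (batch,) -> (num_devices, batch // num_devices)."""
    if len(samples) % num_devices:
        raise ValueError("batch size must divide evenly across devices")
    per_device = len(samples) // num_devices
    return [samples[i * per_device:(i + 1) * per_device] for i in range(num_devices)]


batch = list(range(200))  # stand-in for 200 samples
stacked = split_for_devices(batch, 2)
print(len(stacked), len(stacked[0]))  # 2 100
```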

[6]:
@data_iterator(
    output_map=["images", "labels"], reader_name="mnist_caffe2_reader", devices=jax.devices()
)
def mnist_training_iterator(data_path, num_shards, shard_id):
    jpegs, labels = fn.readers.caffe2(
        path=data_path,
        random_shuffle=True,
        name="mnist_caffe2_reader",
        num_shards=num_shards,
        shard_id=shard_id,
    )
    images = fn.decoders.image(jpegs, device="mixed", output_type=types.GRAY)
    images = fn.crop_mirror_normalize(images, dtype=types.FLOAT, std=[255.0], output_layout="CHW")
    images = fn.reshape(images, shape=[image_size * image_size])

    labels = labels.gpu()
    labels = fn.one_hot(labels, num_classes=num_classes)

    return images, labels

We create an iterator instance the same way as before:

[7]:
print("Creating training iterator")
training_iterator = mnist_training_iterator(batch_size=batch_size, data_path=training_data_path)

print(f"Number of batches in training iterator = {len(training_iterator)}")
Creating training iterator
Number of batches in training iterator = 300

For validation, we use the same iterator as before. Since it runs on a single GPU, nothing needs to change; we again validate a copy of the model that lives on one GPU.

[8]:
print(f"Number of batches in validation iterator = {len(validation_iterator)}")
Number of batches in validation iterator = 50

For the model to be compatible with pmap-style multi-GPU training, we need to replicate it, so that every parameter gets a new leading axis with one copy per device. To learn more about training on multiple GPUs with pmap, see Parallel Evaluation in JAX in the JAX documentation.

[9]:
import jax.numpy as jnp
from model import init_model, accuracy


model = init_model()
# Replicate the model: every parameter gets a leading axis with one copy per device
model = jax.tree_util.tree_map(lambda x: jnp.array([x] * jax.device_count()), model)
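The replication above gives every parameter a new leading axis with one copy per device, and the training loop later recovers a single copy with x[0]. The pure-Python sketch below mirrors both steps on a nested dictionary standing in for the model. Here tree_map is a hypothetical, simplified stand-in for jax.tree_util.tree_map that treats nested dictionaries as containers and everything else as a leaf, and the parameter names are made up.

```python
def tree_map(fn, tree):
    """Simplified, hypothetical stand-in for jax.tree_util.tree_map:
    nested dicts are containers, everything else is a leaf."""
    if isinstance(tree, dict):
        return {key: tree_map(fn, value) for key, value in tree.items()}
    return fn(tree)


num_devices = 2
# Made-up parameters; each leaf stands in for one array
params = {"layer1": {"w": [0.1, 0.2], "b": 0.0}, "layer2": {"w": [0.3], "b": 0.5}}

# Replicate: each leaf gets a leading "device" axis of length num_devices
replicated = tree_map(lambda x: [x] * num_devices, params)
print(replicated["layer1"]["b"])  # [0.0, 0.0]

# Unreplicate: take the copy that belongs to device 0
single = tree_map(lambda x: x[0], replicated)
print(single == params)  # True
```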

For multi-GPU training, we import the update_parallel function. It is the same as the update function, with added gradient synchronization across the devices, which keeps the model replicas on different devices identical.

Since we want to run validation on a single GPU, we extract only one replica of the model and pass it to the accuracy function.
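The body of update_parallel is not shown here. Assuming it averages gradients across devices (in JAX this is commonly done with jax.lax.pmean inside the pmapped function) before every replica applies the same update, the pure-Python sketch below shows why this keeps the replicas identical. It assumes plain SGD on flat lists of floats; sgd_step_with_mean and all values are hypothetical.

```python
def sgd_step_with_mean(replicas, grads_per_device, learning_rate):
    """Hypothetical sketch: average per-device gradients, then apply the same
    SGD step on every replica. Parameters and gradients are flat lists of floats."""
    num_devices = len(replicas)
    if len(grads_per_device) != num_devices:
        raise ValueError("need one gradient per device")
    # Element-wise mean over devices (the role an all-reduce such as pmean plays)
    mean_grad = [sum(values) / num_devices for values in zip(*grads_per_device)]
    return [
        [p - learning_rate * g for p, g in zip(params, mean_grad)]
        for params in replicas
    ]


replicas = [[1.0, 2.0], [1.0, 2.0]]  # identical copies on two devices
grads = [[0.25, 0.5], [0.75, 0.0]]   # each device saw a different part of the batch
new_replicas = sgd_step_with_mean(replicas, grads, learning_rate=0.5)
print(new_replicas)  # [[0.75, 1.875], [0.75, 1.875]]
```

Without the averaging step, each replica would apply only its own gradient and the copies would drift apart after the first update.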

[10]:
from model import update_parallel


for epoch in range(num_epochs):
    for batch in training_iterator:
        # The batch is sharded across all GPUs passed in `devices`
        model = update_parallel(model, batch)

    # The replicas are kept identical, so validate replica 0 on a single GPU
    test_acc = accuracy(jax.tree_util.tree_map(lambda x: x[0], model), validation_iterator)

    print(f"Epoch {epoch}")
    print(f"Test set accuracy {test_acc}")
Epoch 0
Test set accuracy 0.6885000467300415
Epoch 1
Test set accuracy 0.7829000353813171
Epoch 2
Test set accuracy 0.8222000598907471
Epoch 3
Test set accuracy 0.8438000679016113
Epoch 4
Test set accuracy 0.8580000400543213