Training with multiple GPUs

Here we show how to run training from “Training neural network with DALI and JAX” on multiple GPUs. If you haven’t already done so, it is best to start with the single GPU example to better understand the following content.

Again, we start by creating a pipeline definition function. It is slightly modified from the single GPU version to support multiple GPUs.

Note the new arguments passed to fn.readers.caffe2, num_shards and shard_id. They control sharding:
- num_shards sets the total number of shards,
- shard_id tells the pipeline which shard of the dataset it is responsible for.

Also, the device_id argument was removed from the decorator. Since we want these pipelines to run on different GPUs, we pass a specific device_id when creating each pipeline. Most often, device_id and shard_id will have the same value, but this is not a requirement. In this example we want the total batch size to be the same as in the single GPU version. That is why we define batch_size_per_gpu as batch_size // jax.device_count(). Note that if batch_size is not divisible by the number of devices, some adjustment might be required to make sure all samples are used in every epoch of training. To learn more about DALI sharding behavior, see the DALI sharding documentation page.
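The batch-size and sharding arithmetic can be sketched in plain Python. This is a minimal illustration that assumes the dataset is split into contiguous, near-equal shards. The helper names per_device_batch_size and shard_ranges are hypothetical, and DALI's exact shard boundaries may differ.

```python
def per_device_batch_size(global_batch_size: int, num_devices: int) -> int:
    """Split a global batch evenly across devices, refusing uneven splits."""
    if global_batch_size % num_devices != 0:
        raise ValueError(
            f"batch size {global_batch_size} is not divisible by {num_devices} devices")
    return global_batch_size // num_devices


def shard_ranges(dataset_size: int, num_shards: int) -> list[tuple[int, int]]:
    """Contiguous [start, end) sample ranges, one per shard (illustrative model)."""
    return [
        (dataset_size * shard_id // num_shards,
         dataset_size * (shard_id + 1) // num_shards)
        for shard_id in range(num_shards)
    ]


print(per_device_batch_size(200, 2))  # 100 samples per GPU
print(shard_ranges(60000, 2))         # two shards of 30000 samples each
```

With the 60000 MNIST training samples and 2 GPUs, each shard holds 30000 samples. At 100 samples per GPU, that gives 300 batches per epoch, which matches the iterator length reported in this example.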

[1]:
from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types
import jax
import os


training_data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/training/')
validation_data_path = os.path.join(os.environ['DALI_EXTRA_PATH'], 'db/MNIST/testing/')


batch_size = 200
image_size = 28
num_classes = 10
batch_size_per_gpu = batch_size // jax.device_count()


@pipeline_def(batch_size=batch_size_per_gpu, num_threads=4, seed=0)
def mnist_sharded_pipeline(data_path, random_shuffle, num_shards, shard_id):
    jpegs, labels = fn.readers.caffe2(
        path=data_path,
        random_shuffle=random_shuffle,
        name="mnist_caffe2_reader",
        num_shards=num_shards,
        shard_id=shard_id)
    images = fn.decoders.image(
        jpegs, device='mixed', output_type=types.GRAY)
    images = fn.crop_mirror_normalize(
        images, dtype=types.FLOAT, std=[255.], output_layout="CHW")
    images = fn.reshape(images, shape=[image_size * image_size])

    labels = labels.gpu()

    # Training (shuffled) labels are one-hot encoded; validation keeps integer labels.
    if random_shuffle:
        labels = fn.one_hot(labels, num_classes=num_classes)

    return images, labels

Note the device_id values passed at creation time. They place each pipeline on a different device.

[2]:
from nvidia.dali.plugin import jax as dax

print('Creating training pipelines')

pipelines = []
for shard_id, device in enumerate(jax.devices()):
    # One pipeline per GPU; here shard_id and device_id share the same value.
    pipeline = mnist_sharded_pipeline(
        data_path=training_data_path,
        random_shuffle=True,
        num_shards=jax.device_count(),
        shard_id=shard_id,
        device_id=shard_id)
    print(f'Pipeline {pipeline} working on device {pipeline.device_id}')
    pipelines.append(pipeline)

Creating training pipelines
Pipeline <nvidia.dali.pipeline.Pipeline object at 0x7f7e004fae00> working on device 0
Pipeline <nvidia.dali.pipeline.Pipeline object at 0x7f7e004fa350> working on device 1

We created multiple DALI pipelines. Each one runs its computations on a different GPU and starts preprocessing from a different shard of the training dataset. In this configuration, each pipeline moves to the next shard in the next epoch. To control this behavior, see the stick_to_shard argument of the readers.
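The shard rotation described above can be modeled in a few lines. This is a simplified sketch that assumes each reader advances by exactly one shard per epoch unless stick_to_shard is set. shard_for_epoch is a hypothetical helper and is not part of DALI.

```python
def shard_for_epoch(initial_shard_id: int, epoch: int, num_shards: int,
                    stick_to_shard: bool = False) -> int:
    """Shard a reader covers in a given epoch (simplified model)."""
    if stick_to_shard:
        # The reader always reads its initial shard.
        return initial_shard_id
    # Otherwise it advances one shard per epoch, wrapping around.
    return (initial_shard_id + epoch) % num_shards


for epoch in range(3):
    print(epoch, [shard_for_epoch(s, epoch, 2) for s in range(2)])
```

In every epoch the pipelines together still cover all shards exactly once. Only the assignment of shards to GPUs changes.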

As in the single GPU example, we create a training iterator. It encapsulates all the pipelines we created and returns a dictionary of JAX arrays. In this simple configuration, it returns arrays compatible with functions transformed by jax.pmap. Leaves of the returned dictionary have shape (num_devices, batch_per_device, ...), and each slice along the first dimension of the array resides on a different GPU.
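To illustrate how a jax.pmap-transformed function consumes this layout, the sketch below builds a synthetic batch dictionary with the same leading device axis and maps a small function over it. The random data and the per_device_stats function are hypothetical stand-ins for real iterator output.

```python
import jax
import jax.numpy as jnp
import numpy as np

num_devices = jax.local_device_count()
batch_per_device, image_size, num_classes = 100, 28, 10

# Stand-in for one iterator batch: axis 0 indexes devices, axis 1 samples.
rng = np.random.default_rng(0)
batch = {
    "images": rng.random((num_devices, batch_per_device, image_size * image_size),
                         dtype=np.float32),
    "labels": np.eye(num_classes, dtype=np.float32)[
        rng.integers(0, num_classes, (num_devices, batch_per_device))],
}


@jax.pmap
def per_device_stats(b):
    # Inside pmap each device sees only its own (batch_per_device, ...) slice.
    return {
        "mean_pixel": jnp.mean(b["images"]),
        "num_samples": jnp.sum(b["labels"]),  # each one-hot row sums to 1
    }


stats = per_device_stats(batch)
print(stats["num_samples"])  # one entry per device
```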

[3]:
print('Creating training iterator')
training_iterator = dax.DALIGenericIterator(
    pipelines,
    output_map=["images", "labels"],
    reader_name="mnist_caffe2_reader",
    auto_reset=True)

print(f"Number of batches in training iterator = {len(training_iterator)}")
Creating training iterator
Number of batches in training iterator = 300

For simplicity, we will run validation on one GPU. We can pass num_shards=1, shard_id=0, and device_id=0 to mnist_sharded_pipeline. This results in a pipeline identical to the one in the single GPU example, so we can create the validation iterator the same way.

[4]:
print('Creating validation iterator')
validation_pipeline = mnist_sharded_pipeline(data_path=validation_data_path, random_shuffle=False, num_shards=1, shard_id=0, device_id=0)

validation_iterator = dax.DALIGenericIterator(
    validation_pipeline,
    output_map=["images", "labels"],
    reader_name="mnist_caffe2_reader",
    auto_reset=True)

print(f"Number of batches in validation iterator = {len(validation_iterator)}")
Creating validation iterator
Number of batches in validation iterator = 100

For the model to be compatible with pmap-style multiple GPU training, we need to replicate it. To learn more about training on multiple GPUs with pmap, see Parallel Evaluation in JAX in the JAX documentation.
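Replication just adds a leading device axis to every leaf of the parameter pytree. The sketch below does this with jnp.broadcast_to, which has the same effect as the jnp.array([x] * n) used in this example without building a Python list. The toy params dictionary is hypothetical and stands in for the output of init_model().

```python
import jax
import jax.numpy as jnp

num_devices = jax.device_count()

# Hypothetical toy parameters standing in for the output of init_model().
params = {"w": jnp.ones((784, 10)), "b": jnp.zeros((10,))}

# Give every leaf a leading device axis of size num_devices (one copy each).
replicated = jax.tree_util.tree_map(
    lambda x: jnp.broadcast_to(x, (num_devices,) + x.shape), params)

# Indexing that axis recovers a single replica, as done before validation.
single_replica = jax.tree_util.tree_map(lambda x: x[0], replicated)
print(replicated["w"].shape, single_replica["w"].shape)
```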

[5]:
import jax.numpy as jnp
from model import init_model, accuracy


model = init_model()
# Give every leaf a leading device axis holding one identical copy per device.
model = jax.tree_util.tree_map(lambda x: jnp.array([x] * jax.device_count()), model)

For multi-GPU training, we import the update_parallel function. It is the same as the update function, with added gradient synchronization across the devices. This ensures that the model replicas on different devices remain identical.

Since we want to run validation on a single GPU, we extract only one replica of the model and pass it to the accuracy function.
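The model module's update_parallel is not shown here. The following is a hypothetical sketch of what an update with gradient synchronization could look like using jax.pmap and jax.lax.pmean. The toy single-layer loss, LEARNING_RATE, and the axis name "devices" are assumptions, not the tutorial's actual model code.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

LEARNING_RATE = 0.1  # hypothetical hyperparameter


def loss(params, batch):
    # Hypothetical single-layer softmax classifier with one-hot labels.
    logits = batch["images"] @ params["w"] + params["b"]
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.sum(log_probs * batch["labels"], axis=-1))


@functools.partial(jax.pmap, axis_name="devices")
def update_parallel(params, batch):
    # Each device computes gradients on its own slice of the batch...
    grads = jax.grad(loss)(params, batch)
    # ...then the gradients are averaged over all devices, so every replica
    # applies exactly the same SGD step and the replicas stay identical.
    grads = jax.lax.pmean(grads, axis_name="devices")
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


n = jax.local_device_count()
params = {"w": jnp.zeros((784, 10)), "b": jnp.zeros((10,))}
replicated_params = jax.tree_util.tree_map(
    lambda x: jnp.broadcast_to(x, (n,) + x.shape), params)

rng = np.random.default_rng(0)
batch = {
    "images": rng.random((n, 32, 784), dtype=np.float32),
    "labels": np.eye(10, dtype=np.float32)[rng.integers(0, 10, (n, 32))],
}
new_params = update_parallel(replicated_params, batch)
```

Without the pmean call, each device would step along its own local gradient and the replicas would drift apart after the first update.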

[6]:
from model import update_parallel

num_epochs = 10

for epoch in range(num_epochs):
    for it, batch in enumerate(training_iterator):
        model = update_parallel(model, batch)

    test_acc = accuracy(jax.tree_util.tree_map(lambda x: x[0], model), validation_iterator)

    print(f"Epoch {epoch} sec")
    print(f"Test set accuracy {test_acc}")
Epoch 0 sec
Test set accuracy 0.6729000210762024
Epoch 1 sec
Test set accuracy 0.7845000624656677
Epoch 2 sec
Test set accuracy 0.8250000476837158
Epoch 3 sec
Test set accuracy 0.8457000255584717
Epoch 4 sec
Test set accuracy 0.8600000143051147
Epoch 5 sec
Test set accuracy 0.8712000250816345
Epoch 6 sec
Test set accuracy 0.8770000338554382
Epoch 7 sec
Test set accuracy 0.8746000528335571
Epoch 8 sec
Test set accuracy 0.8870000243186951
Epoch 9 sec
Test set accuracy 0.8915000557899475

Automatic parallelization

The following section shows how to apply automatic parallelization mechanisms in training with DALI and JAX. To learn more about these concepts, see the Distributed arrays and automatic parallelization JAX tutorial.

It is possible to pass a jax.sharding.Sharding object to the DALI iterator. The iterator uses it to construct output arrays consistent with that sharding. In this example we use a simple PositionalSharding and pass it to dax.DALIGenericIterator at initialization. Everything else remains the same as in the multiple GPU example with pmap above. We even reuse the same pipeline objects for this new iterator.
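The same placement can also be expressed with NamedSharding, which recent JAX versions favor over PositionalSharding. This is a minimal sketch assuming a one-dimensional mesh, and the axis name "batch" is our own choice. Passing such a sharding to dax.DALIGenericIterator relies on the iterator accepting any jax.sharding.Sharding, as stated above, and is an untested assumption here.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
# One mesh axis named "batch" containing all devices in jax.devices() order,
# the same order used when the pipelines were created.
mesh = Mesh(np.array(devices), axis_names=("batch",))
# Split dimension 0 (samples) across "batch"; remaining dimensions stay whole.
sharding = NamedSharding(mesh, P("batch"))

batch_per_device = 100
images = np.zeros((batch_per_device * len(devices), 784), dtype=np.float32)
sharded_images = jax.device_put(images, sharding)

for shard in sharded_images.addressable_shards:
    print(shard.device, shard.data.shape)
```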

[7]:
from jax.sharding import PositionalSharding
from jax.experimental import mesh_utils


mesh = mesh_utils.create_device_mesh((jax.device_count(), 1))
sharding = PositionalSharding(mesh)

print(sharding)
for pipeline in pipelines:
    print(f'Pipeline {pipeline} working on device {pipeline.device_id}')
PositionalSharding([[{GPU 0}]
                    [{GPU 1}]])
Pipeline <nvidia.dali.pipeline.Pipeline object at 0x7f7e004fae00> working on device 0
Pipeline <nvidia.dali.pipeline.Pipeline object at 0x7f7e004fa350> working on device 1

Note that the sharding and pipelines arguments must match: the devices in the sharding must be the same as the devices the pipelines are working on. Compare the pipeline creation with the sharding creation. In both, we used all available devices in the order provided by jax.devices(). The iterator does not copy outputs between devices. Instead, it assembles a jax.Array from the outputs of the pipelines and the passed sharding. This requirement might be lifted in the future.
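Conceptually, assembling one jax.Array from per-device buffers resembles what jax.make_array_from_single_device_arrays does. The sketch below illustrates the idea with synthetic per-device batches standing in for pipeline outputs. It is not DALI's implementation, and it uses NamedSharding rather than PositionalSharding.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()
num_devices = len(devices)
batch_per_device, features = 100, 784

mesh = Mesh(np.array(devices), axis_names=("batch",))
sharding = NamedSharding(mesh, P("batch"))

# Pretend each pipeline produced one batch that already lives on its device.
# Device i holds a batch filled with the value i so it can be traced afterwards.
per_device_outputs = [
    jax.device_put(np.full((batch_per_device, features), i, dtype=np.float32), d)
    for i, d in enumerate(devices)
]

# Stitch the per-device buffers into one logical array; the global shape is
# the per-device shape stacked along dimension 0.
global_batch = jax.make_array_from_single_device_arrays(
    (num_devices * batch_per_device, features), sharding, per_device_outputs)
print(global_batch.shape)
```

If a buffer sat on a device that does not match its position in the sharding, the buffers and the sharding would disagree. This is why the pipelines and the sharding must use the same devices in the same order.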

[8]:
print('Creating training iterator')
training_iterator = dax.DALIGenericIterator(
    pipelines,
    output_map=["images", "labels"],
    reader_name="mnist_caffe2_reader",
    auto_reset=True,
    sharding=sharding)

print(f"Number of batches in training iterator = {len(training_iterator)}")
Creating training iterator
Number of batches in training iterator = 300

This new iterator is ready for training. This example uses automatic parallelization, where the computation follows the data placement. This means we can use the same update function as in the single GPU training example, and it will automatically run the computations on multiple GPUs.

For simplicity, we use the same validation_iterator as before and run the accuracy calculation on a single GPU. After training, the model is spread across the devices, so we need to move it to one of them for this to work. Otherwise JAX would throw an error. In real-world scenarios this might not be the best choice for performance.
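To see how computation follows data placement, consider a sketch of a jitted single-device update applied to a batch sharded across devices. The toy loss, update, and LEARNING_RATE are hypothetical stand-ins for the tutorial's model module, and NamedSharding is used instead of PositionalSharding.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

LEARNING_RATE = 0.1  # hypothetical hyperparameter


def loss(params, batch):
    # Hypothetical single-layer softmax classifier with one-hot labels.
    logits = batch["images"] @ params["w"] + params["b"]
    return -jnp.mean(jnp.sum(jax.nn.log_softmax(logits) * batch["labels"], axis=-1))


@jax.jit
def update(params, batch):
    # Plain single-device SGD step: no pmap and no explicit collectives.
    grads = jax.grad(loss)(params, batch)
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)


devices = jax.devices()
mesh = Mesh(np.array(devices), axis_names=("batch",))
batch_sharding = NamedSharding(mesh, P("batch"))  # split samples across devices
replicated = NamedSharding(mesh, P())             # full copy on every device

rng = np.random.default_rng(0)
global_batch_size = 100 * len(devices)
batch = jax.device_put({
    "images": rng.random((global_batch_size, 784), dtype=np.float32),
    "labels": np.eye(10, dtype=np.float32)[rng.integers(0, 10, global_batch_size)],
}, batch_sharding)
params = jax.device_put({"w": np.zeros((784, 10), np.float32),
                         "b": np.zeros((10,), np.float32)}, replicated)

# XLA partitions the computation to match the input shardings and inserts
# the cross-device reduction of the gradients itself.
new_params = update(params, batch)
print(new_params["w"].devices())
```

The update code is identical to what one would write for a single device. Only the placement of its inputs changes where and how it runs.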

[9]:
from model import update

model = init_model()

for epoch in range(num_epochs):
    for it, batch in enumerate(training_iterator):
        model = update(model, batch)

    model_on_one_device = jax.tree_util.tree_map(lambda x: jax.device_put(x, jax.devices()[0]), model)
    test_acc = accuracy(model_on_one_device, validation_iterator)

    print(f"Epoch {epoch} sec")
    print(f"Test set accuracy {test_acc}")
Epoch 0 sec
Test set accuracy 0.687000036239624
Epoch 1 sec
Test set accuracy 0.7791000604629517
Epoch 2 sec
Test set accuracy 0.8225000500679016
Epoch 3 sec
Test set accuracy 0.843000054359436
Epoch 4 sec
Test set accuracy 0.8577000498771667
Epoch 5 sec
Test set accuracy 0.8681000471115112
Epoch 6 sec
Test set accuracy 0.8773000240325928
Epoch 7 sec
Test set accuracy 0.8832000494003296
Epoch 8 sec
Test set accuracy 0.8872000575065613
Epoch 9 sec
Test set accuracy 0.8830000162124634