Running JAX augmentations on multiple GPUs

The previous tutorial presented simple usage of JAX functions as augmentations in DALI.

In this tutorial, we show how to benefit from JAX's multi-GPU sharding when writing data-processing augmentations. We will create a distributed data loader that accesses samples from different shards. Using the same setup, you can extend the example to use other JAX distributed operations, such as reductions across the global batch.

Setting up distributed run

In this simplified example, we will assume that data loading is distributed across two processes. The data loader in each process handles its own portion of the data on its own GPU: one processes images of dogs, the other images of cats.

Please note that in a real-life scenario, the code of this example would be launched by a setup script that starts two separate processes, each receiving its unique process id as a launch parameter. Here, for presentation purposes, we run the process with process_id=0 step by step in this notebook, while the other process, with process_id=1, runs in the background.

import subprocess

subprocess.Popen(["python3.10", ""])
<Popen: returncode: None args: ['python3.10', 'jax_operator_multi_gpu_proces...>
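
The launcher pattern described above, one process per GPU with each process receiving its id as an argument, can be sketched as follows. This is only an illustration: the inline worker code and the `launch_workers` helper are hypothetical stand-ins rather than part of DALI or JAX, and the worker merely echoes its id instead of loading data.

```python
import subprocess
import sys

# Hypothetical stand-in for the data-loading script: it only echoes its process id.
WORKER_CODE = "import sys; print('process_id=' + sys.argv[1])"


def launch_workers(num_processes=2):
    """Start one worker per process id and collect what each one printed."""
    procs = [
        subprocess.Popen(
            [sys.executable, "-c", WORKER_CODE, str(pid)],
            stdout=subprocess.PIPE,
            text=True,
        )
        for pid in range(num_processes)
    ]
    # communicate() waits for the worker to exit and returns (stdout, stderr).
    return [proc.communicate()[0].strip() for proc in procs]


if __name__ == "__main__":
    print(launch_workers())
```

All workers are started before any of them is waited on, so they run concurrently, as the two data-loading processes must.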

We start by initializing the JAX distributed workflow.

import os
import jax

process_id = 0  # the other process is launched with process_id = 1

os.environ["CUDA_VISIBLE_DEVICES"] = str(process_id)
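
The cell above only pins the process to one GPU. Connecting the processes to each other is typically done with `jax.distributed.initialize`. The sketch below wraps both steps in a hypothetical helper, `init_distributed`; the coordinator address and port are assumed values, and every process would have to use the same ones.

```python
import os


def init_distributed(process_id, num_processes=2, coordinator_address="localhost:12321"):
    """Pin this process to one GPU, then join the JAX distributed runtime."""
    # Restrict the process to a single GPU before JAX initializes its backends.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(process_id)

    import jax  # imported here so the environment variable above is already set

    # coordinator_address is an assumed value; all processes must agree on it.
    jax.distributed.initialize(
        coordinator_address=coordinator_address,
        num_processes=num_processes,
        process_id=process_id,
    )
    return jax
```

In this notebook, the call would be `init_distributed(process_id=0)`, with the background process calling `init_distributed(process_id=1)`.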


Now we create the mesh description: the global batch is split across two processes.

from jax.sharding import Mesh, PartitionSpec, NamedSharding

assert len(jax.devices()) == 2
assert len(jax.local_devices()) == 1

# axis_names is a tuple of names; the trailing comma makes ("batch",) a one-element tuple
mesh = Mesh(jax.devices(), axis_names=("batch",))
sharding = NamedSharding(mesh, PartitionSpec("batch"))
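
With `PartitionSpec("batch")` on a one-dimensional mesh, the leading axis of the global array is split evenly across the devices. The arithmetic can be sketched in plain Python; `shard_slices` is a hypothetical helper used only for illustration and is not a JAX API.

```python
def shard_slices(global_batch_size, num_shards):
    """Return the slice of the global batch that each shard holds."""
    if global_batch_size % num_shards != 0:
        raise ValueError("the global batch must divide evenly across shards")
    local_batch_size = global_batch_size // num_shards
    return [
        slice(shard_id * local_batch_size, (shard_id + 1) * local_batch_size)
        for shard_id in range(num_shards)
    ]


# A global batch of 16 samples over 2 shards: samples 0..7 and 8..15.
print(shard_slices(16, 2))
```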

Iterator with sharding

Next, we define the iterator and pass it the sharding object defined above. Based on that sharding, iterator_function receives a shard_id, which we use to decide whether to read dog or cat images. We read, decode, and resize the images. Finally, we call the global_mixup JAX operation defined in the next cell.

import nvidia.dali.fn as fn
from nvidia.dali.plugin.jax import data_iterator

dogs = [f"../data/images/dog/dog_{i}.jpg" for i in range(1, 9)]
kittens = [f"../data/images/kitten/cat_{i}.jpg" for i in range(1, 9)]

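# Passing the sharding makes DALI call this function with shard_id and num_shards.
@data_iterator(output_map=["images"], sharding=sharding)
def iterator_function(shard_id, num_shards):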
    assert num_shards == 2
    jpegs, _ = fn.readers.file(files=dogs if shard_id == 0 else kittens, name="image_reader")
    images = fn.decoders.image(jpegs, device="mixed")
    images = fn.resize(images, size=(244, 244))

    # mixup images between shards
    images = global_mixup(images)
    return images

JAX augmentation with sharding

The global_mixup function mixes, sample by sample, the two batches produced by the two concurrently running processes.

The sharding object we passed in the data_iterator call is passed to jax_function as well. It tells DALI how to construct the JAX global array. Please note that support for sharding in jax_function is currently limited to global sharding, i.e. the case where each data-loading process sees a single GPU device, different in each process.

The jax.experimental.shard_map specification allows us to use jax.lax.pshuffle, which exchanges data between shards along the named mesh axis.

from functools import partial
from jax.experimental.shard_map import shard_map

from nvidia.dali.plugin.jax.fn import jax_function

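# jax_function exposes the JAX code to DALI; shard_map binds the "batch" axis name
# that jax.lax.pshuffle refers to.
@jax_function(sharding=sharding)
@jax.jit
@partial(
    shard_map,
    mesh=sharding.mesh,
    in_specs=PartitionSpec("batch"),
    out_specs=PartitionSpec("batch"),
)
def global_mixup(sample):
    # With two shards, the permutation [1, 0] swaps their contents.
    mixed_up = 0.5 * sample + 0.5 * jax.lax.pshuffle(sample, "batch", [1, 0])
    mixed_up = jax.numpy.clip(mixed_up, 0, 255)
    return jax.numpy.array(mixed_up, dtype=jax.numpy.uint8)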
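
To see what the mixup computes without two GPUs, the arithmetic can be emulated on plain Python lists. This is a sketch only: `emulate_global_mixup` is a hypothetical helper, the pixel values are made up, and the shard swap is done by list indexing rather than by a JAX collective.

```python
def emulate_global_mixup(shards):
    """Emulate the two-shard mixup on plain lists of pixel values."""
    perm = [1, 0]  # with two shards, this permutation simply swaps them
    mixed = []
    for i, own in enumerate(shards):
        other = shards[perm[i]]
        blended = [0.5 * a + 0.5 * b for a, b in zip(own, other)]
        # Clip to the valid pixel range, then truncate to an integer.
        mixed.append([int(min(255.0, max(0.0, v))) for v in blended])
    return mixed


dog_shard = [0, 100, 255]  # made-up pixel values standing in for a dog image
cat_shard = [200, 50, 255]  # made-up pixel values standing in for a cat image
print(emulate_global_mixup([dog_shard, cat_shard]))
# prints [[100, 75, 255], [100, 75, 255]]
```

Because both weights are 0.5, the two shards end up holding the same mixed values.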

Now, we are ready to test the dataloader.

local_batch_size = 8
num_shards = 2

iterator = iterator_function(batch_size=num_shards * local_batch_size, num_threads=4)
batch = next(iterator)

Let us define a simple helper function to present the produced batch.

import matplotlib.pyplot as plt
from matplotlib import gridspec

def show_image(images, columns=4, fig_size=24):
    rows = (len(images) + columns - 1) // columns
    plt.figure(figsize=(fig_size, (fig_size // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
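    for j in range(rows * columns):
        if j >= len(images):
            break  # the last row may be only partially filled
        plt.subplot(gs[j])
        plt.axis("off")
        plt.imshow(images[j])

# Each process can only address its own shard of the global batch.
local_batches = [x.data for x in batch["images"].addressable_shards]
assert len(local_batches) == 1
show_image(local_batches[0])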