Running custom JAX augmentations in DALI

This tutorial shows how to run JAX functions inside a DALI pipeline or iterator using plugin.jax.fn.jax_function. This lets you write custom augmentations in JAX and use them alongside other DALI operations.

Setting up the example

We will start with a simple image-processing DALI iterator. You can read more about defining a DALI iterator for JAX in DALI and JAX getting started.

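import nvidia.dali.fn as fn
from nvidia.dali.plugin.jax import data_iterator

image_dir = "../data/images"


# reader_name links the iterator to the reader so it can track the epoch size
@data_iterator(output_map=["images", "labels"], reader_name="image_reader")
def baseline_iterator_fn():
    # Read the encoded JPEG files together with their labels
    jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
    # "mixed" decoding takes CPU input and produces images in GPU memory
    images = fn.decoders.image(jpegs, device="mixed")
    images = fn.resize(images, size=(300, 300))
    return images, labels


baseline_iterator = baseline_iterator_fn(batch_size=4)

# Each batch is a dictionary keyed by the names listed in output_map
baseline_batch = next(baseline_iterator)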

Let us define a simple helper function to display the produced batch; we will use it later.

import matplotlib.pyplot as plt
from matplotlib import gridspec


def show_image(images, columns=4, fig_size=24):
    # Number of grid rows needed to fit all images (ceiling division)
    rows = (len(images) + columns - 1) // columns
    plt.figure(figsize=(fig_size, (fig_size // columns) * rows))
    gs = gridspec.GridSpec(rows, columns)
    # Iterate over the images only, so a partially filled last row
    # does not index past the end of the list
    for j in range(len(images)):
        plt.subplot(gs[j])
        plt.axis("off")
        plt.imshow(images[j])

Adding an augmentation defined with JAX

Now, let us add some JAX processing to the images. As a simple example, we will flip the images horizontally using jax.numpy array indexing.

We import jax and write a function that expects a 4D array: a batch of HWC images. Similarly, the function returns a 4D array of the same shape, with the W dimension reversed.

import jax


def horz_flip(images: jax.Array):
    # images has shape (N, H, W, C); reversing axis 2 (W) mirrors each image
    return images[:, :, ::-1, :]
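Because horz_flip is still a plain JAX function at this point, its indexing can be checked on its own before handing it to DALI. The following is a minimal sketch in plain JAX (no DALI involved); the tiny batch of two 2x3 single-channel images built with jnp.arange is an arbitrary assumption chosen so the values are easy to follow.

```python
import jax
import jax.numpy as jnp


def horz_flip(images: jax.Array):
    # images has shape (N, H, W, C); reversing axis 2 mirrors each image left-to-right
    return images[:, :, ::-1, :]


# Hypothetical test batch: N=2 images, H=2 rows, W=3 columns, C=1 channel
images = jnp.arange(2 * 2 * 3 * 1).reshape(2, 2, 3, 1)
flipped = horz_flip(images)

print(flipped.shape)  # the shape is unchanged: (2, 2, 3, 1)
# First row of the first image before and after the flip
print(images[0, 0, :, 0], flipped[0, 0, :, 0])
```

Only the order of pixels within each row changes; the batch, height and channel axes are left untouched.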

To plug horz_flip into the iterator, we need to transform the function with jax_function.

from nvidia.dali.plugin.jax.fn import jax_function


@jax_function
def horz_flip(images: jax.Array):
    return images[:, :, ::-1, :]

That is it: we can now call the function like a regular DALI operation inside the iterator definition.

@data_iterator(output_map=["images", "labels"], reader_name="image_reader")
def iterator_fn():
    jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
    images = fn.decoders.image(jpegs, device="mixed")
    images = fn.resize(images, size=(300, 300))

    images = horz_flip(images)

    return images, labels


iterator = iterator_fn(batch_size=4)
batch = next(iterator)

Let us compare the output of the baseline iterator and the one that uses horz_flip.

show_image(
    [
        image
        for pair in zip(baseline_batch["images"], batch["images"])
        for image in pair
    ]
)
[Figure: baseline images interleaved with their horizontally flipped counterparts]
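The nested list comprehension passed to show_image interleaves the two batches, so each baseline image is immediately followed by its processed counterpart. The same pattern on plain Python lists, with hypothetical placeholder strings standing in for images, shows the resulting order:

```python
baseline = ["b0", "b1", "b2"]
processed = ["p0", "p1", "p2"]

# zip pairs up the i-th elements; the inner loop then unpacks each pair in order
interleaved = [item for pair in zip(baseline, processed) for item in pair]
print(interleaved)  # ['b0', 'p0', 'b1', 'p1', 'b2', 'p2']

# zip stops at the shorter sequence, so surplus elements of the longer one are dropped
longer = processed + ["p3", "p4"]
truncated = [item for pair in zip(baseline, longer) for item in pair]
print(len(truncated))  # 6
```

The truncation matters later in this tutorial, where an 8-image batch is zipped with the 4-image baseline batch: only the first 4 pairs are displayed.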

JAX function transformations

The jax_function decorator can be combined with common JAX transformations. For example, we can use jax.vmap to vectorize the processing along the batch dimension, jax.jit to benefit from JAX’s just-in-time compilation, or both.

The one thing to note is that jax_function must be the outermost transformation, that is, the topmost decorator.

@jax_function
@jax.jit
def horz_flip(images: jax.Array):
    return images[:, :, ::-1, :]  # batch of HWC images


@jax_function
@jax.vmap
def horz_flip(image: jax.Array):
    # single HWC image (batch is implicit thanks to jax.vmap)
    return image[:, ::-1, :]


@jax_function
@jax.jit
@jax.vmap
def horz_flip(image: jax.Array):
    # single HWC image (batch is implicit thanks to jax.vmap)
    return image[:, ::-1, :]
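Without the jax_function decorator, the three variants above are ordinary JAX functions, so one way to convince yourself that they agree is to compare them on a random batch. This is a minimal sketch in plain JAX; the functions are given hypothetical distinct names so that all three can coexist, and the batch shape (4, 8, 6, 3) is an arbitrary assumption.

```python
import jax
import jax.numpy as jnp


@jax.jit
def flip_batch_jit(images: jax.Array):
    # operates on the whole (N, H, W, C) batch at once
    return images[:, :, ::-1, :]


@jax.vmap
def flip_sample_vmap(image: jax.Array):
    # operates on a single (H, W, C) image; vmap maps it over the leading batch axis
    return image[:, ::-1, :]


@jax.jit
@jax.vmap
def flip_sample_jit_vmap(image: jax.Array):
    # same per-sample function, vectorized by vmap and then compiled by jit
    return image[:, ::-1, :]


key = jax.random.PRNGKey(0)
batch = jax.random.uniform(key, (4, 8, 6, 3))  # hypothetical batch of 4 HWC images

out_jit = flip_batch_jit(batch)
out_vmap = flip_sample_vmap(batch)
out_both = flip_sample_jit_vmap(batch)

# Flipping is pure indexing (no arithmetic), so the results match exactly
print(bool(jnp.array_equal(out_jit, out_vmap)), bool(jnp.array_equal(out_jit, out_both)))  # True True
```

Writing the per-sample version with jax.vmap is often easier to read, since the function body deals with one HWC image and the batch axis is added by the transformation.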

Multiple inputs and outputs

Next, let us add another argument to horz_flip that controls whether the given image should be flipped or left unchanged. We will decide whether to flip each image based on the output of DALI’s fn.random.coin_flip().

@jax_function
@jax.jit
@jax.vmap
def horz_flip(image: jax.Array, should_flip: jax.Array):
    return jax.lax.cond(
        should_flip, lambda x: x[:, ::-1, :], lambda x: x, image
    )


@data_iterator(output_map=["images", "labels"], reader_name="image_reader")
def iterator_fn():
    jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
    images = fn.decoders.image(jpegs, device="mixed")
    images = fn.resize(images, size=(300, 300))
    should_flip = fn.random.coin_flip(seed=45)
    # Note: currently, all inputs must reside on the same backend type.
    # The images are in GPU memory, so we need to move should_flip there as well.
    images = horz_flip(images, should_flip.gpu())
    return images, labels


iterator = iterator_fn(batch_size=8)
batch = next(iterator)  # batch of data ready to be used by JAX

show_image(
    [
        image
        for pair in zip(baseline_batch["images"], batch["images"])
        for image in pair
    ],
    columns=4,
)
[Figure: baseline images interleaved with the outputs of the randomly flipping iterator]

As expected, some of the images are left unchanged.
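To see what the jax.lax.cond-based horz_flip does for each sample, it can be run in plain JAX with a hand-written boolean flag array standing in for fn.random.coin_flip(). This is a sketch under that assumption; the tiny batch of three 1x3 single-channel images is hypothetical.

```python
import jax
import jax.numpy as jnp


@jax.jit
@jax.vmap
def horz_flip(image: jax.Array, should_flip: jax.Array):
    # image: a single (H, W, C) image; should_flip: a scalar flag for that image
    return jax.lax.cond(
        should_flip, lambda x: x[:, ::-1, :], lambda x: x, image
    )


# Hypothetical batch of 3 tiny images (H=1, W=3, C=1), each with pixel values 0, 1, 2
images = jnp.tile(jnp.arange(3).reshape(1, 1, 3, 1), (3, 1, 1, 1))
# Hand-written flags in place of fn.random.coin_flip(): flip samples 0 and 2 only
should_flip = jnp.array([True, False, True])

out = horz_flip(images, should_flip)
print(out[:, 0, :, 0])  # samples 0 and 2 are reversed, sample 1 is unchanged
```

Under jax.vmap, a predicate that varies across the batch makes JAX evaluate both branches and select the result per sample, which is why the decision can differ from image to image.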

We have just seen that the processing function can accept multiple inputs. Similarly, it can return multiple outputs. In that case, however, we need to tell DALI how many outputs to expect, which we do by passing num_outputs to jax_function.

@jax_function(num_outputs=2)
@jax.jit
@jax.vmap
def flip(image: jax.Array):
    horz_flip = image[:, ::-1, :]
    vert_flip = image[::-1, :, :]
    return horz_flip, vert_flip


@data_iterator(
    output_map=["horz", "vert", "labels"], reader_name="image_reader"
)
def iterator_fn():
    jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
    images = fn.decoders.image(jpegs, device="mixed")
    images = fn.resize(images, size=(300, 300))
    horz_flipped, vert_flipped = flip(images)
    return horz_flipped, vert_flipped, labels


iterator = iterator_fn(batch_size=2)
batch = next(iterator)  # batch of data ready to be used by JAX

show_image(
    [
        image
        for triple in zip(
            baseline_batch["images"], batch["horz"], batch["vert"]
        )
        for image in triple
    ],
    columns=3,
)
[Figure: each row shows a baseline image followed by its horizontally and vertically flipped versions]
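Outside DALI, a function like flip simply returns a tuple, and jax.vmap maps over the batch while preserving that tuple structure; num_outputs=2 is what tells DALI to expect two outputs from it. A plain-JAX sketch on a hypothetical 2-image batch (the shape is chosen arbitrarily) shows the two outputs and how they relate to the input:

```python
import jax
import jax.numpy as jnp


@jax.jit
@jax.vmap
def flip(image: jax.Array):
    # image: a single (H, W, C) image
    horz_flip = image[:, ::-1, :]  # mirror left-to-right (W axis)
    vert_flip = image[::-1, :, :]  # mirror top-to-bottom (H axis)
    return horz_flip, vert_flip


batch = jnp.arange(2 * 2 * 3 * 1).reshape(2, 2, 3, 1)  # (N=2, H=2, W=3, C=1)
horz, vert = flip(batch)

# Each output keeps the batch layout of the input
print(horz.shape, vert.shape)  # (2, 2, 3, 1) (2, 2, 3, 1)
```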

JAX augmentations in regular pipelines

JAX augmentations are not limited to JAX iterators; they work with regular DALI pipelines too.

from nvidia.dali import pipeline_def


@pipeline_def(batch_size=4, device_id=0, num_threads=4)
def pipeline():
    jpegs, labels = fn.readers.file(file_root=image_dir, name="image_reader")
    images = fn.decoders.image(jpegs, device="mixed")
    images = fn.resize(images, size=(300, 300))
    should_flip = fn.random.coin_flip(seed=44)
    flipped_images = horz_flip(images, should_flip.gpu())
    return images, flipped_images, labels


p = pipeline()
p.build()
images, flipped_images, labels = p.run()

show_image(
    [
        image
        for pair in zip(images.as_cpu(), flipped_images.as_cpu())
        for image in pair
    ],
    columns=4,
)
[Figure: pipeline images interleaved with their randomly flipped counterparts]