nvidia.dali.plugin.jax.fn.jax_function

nvidia.dali.plugin.jax.fn.jax_function(function=None, num_outputs=1, output_layouts=None, sharding=None, device=None, preserve=True)

Transforms the Python function passed as function, which processes jax.Array objects, into a DALI operator that can be used inside a DALI pipeline definition or a JAX plugin iterator definition. The transformed function accepts and returns the same number of inputs and outputs as the original function, but changes their types from jax.Array to DALI-traced DataNodes. The resulting function is interoperable with other DALI operators.

For example, we could implement a horizontal flipping operation in JAX as follows:

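import jax
from nvidia.dali import pipeline_def, fn, types
from nvidia.dali.plugin import jax as dax

# Placeholder: a directory of JPEG images (one subdirectory per label),
# as expected by fn.readers.file; adjust it to your dataset location.
jpeg_path_dali_extra = "/path/to/jpeg/images"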

@dax.fn.jax_function
def flip_horizontal(image_batch: jax.Array):
    return image_batch[:, :, ::-1, :]  # batch of HWC images

@pipeline_def(batch_size=4, device_id=0, num_threads=4)
def pipeline():
    image, _ = fn.readers.file(file_root=jpeg_path_dali_extra)
    image = fn.decoders.image(image, device="mixed", output_type=types.RGB)
    image = fn.resize(image, size=(244, 244))
    flipped = flip_horizontal(image)
    return image, flipped
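The body of flip_horizontal is plain JAX indexing, so its effect can be checked without DALI. The sketch below assumes a batch laid out as (N, H, W, C), like the decoded and resized images above, applies the same slice to a small synthetic batch, and confirms that only the width axis is reversed. The name flip_horizontal_body is hypothetical: it is the undecorated function body, not part of the DALI API.

```python
import jax
import jax.numpy as jnp


def flip_horizontal_body(image_batch: jax.Array) -> jax.Array:
    # Same slicing as flip_horizontal: axis 2 is W in an (N, H, W, C) batch.
    return image_batch[:, :, ::-1, :]


# Synthetic batch: 2 images, 3 rows, 4 columns, 3 channels.
batch = jnp.arange(2 * 3 * 4 * 3).reshape(2, 3, 4, 3)
flipped = flip_horizontal_body(batch)

# The shape is unchanged; each row's pixels appear in reverse column order.
print(flipped.shape)
print(batch[0, 0, :, 0], flipped[0, 0, :, 0])
```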

The function can be transformed with the usual JAX transformations. For example, we can use JAX’s just-in-time compilation and vectorization by adding the appropriate decorators to the above example:

@dax.fn.jax_function
@jax.jit
@jax.vmap
def flip_horizontal(image: jax.Array):
    return image[:, ::-1, :]  # HWC image
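The jit/vmap version computes the same result as the batched version: jax.vmap maps the per-sample body over the leading (batch) axis, so reversing axis 1 of each HWC sample matches reversing axis 2 of the NHWC batch. The sketch below checks this in plain JAX on a random batch. The function names are hypothetical, and the DALI decorator is omitted so that the check runs without DALI.

```python
import jax
import jax.numpy as jnp


def flip_batch(image_batch: jax.Array) -> jax.Array:
    # Batched form: (N, H, W, C), reverse the W axis (axis 2).
    return image_batch[:, :, ::-1, :]


@jax.jit
@jax.vmap
def flip_sample(image: jax.Array) -> jax.Array:
    # Per-sample form: (H, W, C), reverse the W axis (axis 1).
    # vmap maps this over the batch axis; jit compiles the mapped function.
    return image[:, ::-1, :]


key = jax.random.PRNGKey(0)
batch = jax.random.uniform(key, (4, 8, 6, 3))

same = jnp.array_equal(flip_batch(batch), flip_sample(batch))
print(bool(same))
```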

When the resulting function is run with DALI GPU batches, the internal DALI and JAX streams are synchronized, so the user does not need to synchronize the JAX operations any further.

The jax.Arrays passed to the function must not be accessed after the function completes (for example, they should not be stored in a non-local scope).
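One way to respect this rule is to keep the input arrays local and let anything that must outlive the call be either a returned output or an independent host-side value produced inside the call. The sketch below is a hypothetical example (the callback brighten and the module-level list batch_means are illustrative names) that records a per-batch statistic as a Python float. The float() conversion waits for the value, so it is computed while the input is still valid and holds no reference to the input buffer; the cost is one host synchronization per batch. It is shown in plain JAX; in a pipeline, the function would be decorated with dax.fn.jax_function as above.

```python
import jax
import jax.numpy as jnp

# Hypothetical host-side log that outlives each call.
batch_means: list[float] = []


def brighten(image_batch: jax.Array) -> jax.Array:
    # Do NOT keep image_batch (or views of it) beyond this call, e.g.
    #     batch_means.append(image_batch)   # wrong: stores the input array
    # Instead, materialize what you need as an independent host value.
    # float() blocks until the mean is computed, while the input is valid.
    # (This conversion is not allowed inside a jax.jit-traced function.)
    batch_means.append(float(jnp.mean(image_batch)))
    # Return new data computed from the input.
    return jnp.clip(image_batch * 1.2, 0, 255)


out = brighten(jnp.full((2, 4, 4, 3), 100.0))
print(batch_means, out.shape)
```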

Note

This is an experimental API and may change in future releases.

Note

jax_function requires JAX 0.4.16 or higher with GPU support. JAX 0.4.16, in turn, requires Python 3.9 or higher.
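These requirements can be checked programmatically before building a pipeline. The sketch below is a hypothetical helper, not part of DALI. It interprets "GPU support" as JAX's default backend being a GPU, which is an assumption about how you may want to read the requirement.

```python
import re
import sys

import jax


def version_tuple(version: str) -> tuple[int, ...]:
    # Keep the leading numeric "X.Y.Z" part, ignoring suffixes such as ".dev..." or "rc1".
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def jax_function_requirements_met() -> bool:
    # Python >= 3.9, JAX >= 0.4.16, and JAX running on a GPU backend.
    python_ok = sys.version_info >= (3, 9)
    jax_ok = version_tuple(jax.__version__) >= (0, 4, 16)
    gpu_ok = jax.default_backend() == "gpu"
    return python_ok and jax_ok and gpu_ok


print(jax.__version__, jax_function_requirements_met())
```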

Parameters:
  • function (JaxCallback) – Python callback that accepts and returns zero or more jax.Array objects. The function will receive batches processed by DALI as jax.Array tensors (with the leftmost extent corresponding to the DALI batch). Because each batch is passed as a single array, the transformed function can only receive DALI batches whose samples all have the same shape.

  • num_outputs (int, default=1) –

    The number of outputs returned by the function.

    The function can return no output; in that case, num_outputs must be set to 0. If num_outputs is 1 (the default), the callback should return a single JAX array; for num_outputs > 1, the callback should return a tuple of JAX arrays.

  • output_layouts (Union[str, Tuple[str]], optional) –

    The layouts of returned tensors.

    It can be either a sequence of strings, one for each of the num_outputs outputs, or a single string to be applied to all of the outputs.

    Note that in DALI the outermost batch extent is implicit, so the layout should describe only the sample dimensions.

    If the argument is not specified and the function’s i-th output has the same dimensionality as the i-th input, the layout will be propagated from the input to the corresponding output.

  • sharding (jax.sharding.Sharding, optional) –

    The JAX sharding object (either PositionalSharding or NamedSharding). If specified, the jax.Arrays passed to the function will be global jax.Arrays aware of the sharding.

    Note

    Currently, only global sharding is supported, i.e., each process must have exactly one local device.

  • device (str, optional) – Either “cpu”, “gpu” or None. The device kind on which all of the DALI inputs and outputs to the transformed function will be placed. If not specified, the device will be deduced based on the DALI inputs passed to the resulting function. Currently, the device kind of all the inputs and outputs must be the same.

  • preserve (bool, default=True) – If set to False, the returned DALI function may be optimized out of the DALI pipeline when it returns no outputs or when none of its outputs contribute to the pipeline’s output.
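To illustrate how num_outputs and output_layouts fit together, the sketch below shows the body of a hypothetical callback with two outputs: the horizontally flipped NHWC batch and a per-sample, per-channel mean. Because the batch extent is implicit in DALI, the corresponding sample layouts would be "HWC" and "C". The second output has no matching i-th input, so its layout would not be propagated and has to be given explicitly. Here the callback runs in plain JAX on a stacked batch of uniformly shaped samples, mirroring how DALI passes the batch as one array. In a pipeline, it would presumably be wrapped as @dax.fn.jax_function(num_outputs=2, output_layouts=("HWC", "C")), assuming the decorator-with-arguments form implied by function=None in the signature.

```python
import jax
import jax.numpy as jnp


def flip_and_channel_mean(image_batch: jax.Array):
    # image_batch: (N, H, W, C); the leading axis is the DALI batch.
    flipped = image_batch[:, :, ::-1, :]               # per-sample layout "HWC"
    channel_mean = jnp.mean(image_batch, axis=(1, 2))  # (N, C), per-sample layout "C"
    # num_outputs > 1, so the callback returns a tuple of arrays.
    return flipped, channel_mean


# DALI hands the callback one array per batch, which only works when all
# samples share a shape; samples of different shapes could not be stacked.
samples = [jnp.full((2, 3, 3), float(i)) for i in range(4)]  # four HWC samples
batch = jnp.stack(samples)  # shape (4, 2, 3, 3)

flipped, channel_mean = flip_and_channel_mean(batch)
print(flipped.shape, channel_mean.shape)
```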

Returns:

The transformed function that processes DALI-traced batches (DataNodes).

Return type:

DaliCallback