nvidia.dali.plugin.jax.fn.jax_function
- nvidia.dali.plugin.jax.fn.jax_function(function=None, num_outputs=1, output_layouts=None, sharding=None, device=None, preserve=True)
Transforms the Python function passed as function, which processes jax.Array objects, into a DALI operator that can be used inside a DALI pipeline definition or a JAX plugin iterator definition. The transformed function accepts and returns the same number of inputs and outputs as the original function, but changes their types: from jax.Array to DALI-traced DataNodes. The resulting function is interoperable with other DALI operators.
For example, we could implement a horizontal flipping operation in JAX as follows:
    import jax
    from nvidia.dali import pipeline_def, fn, types
    from nvidia.dali.plugin import jax as dax

    @dax.fn.jax_function
    def flip_horizontal(image_batch: jax.Array):
        return image_batch[:, :, ::-1, :]  # batch of HWC images

    @pipeline_def(batch_size=4, device_id=0, num_threads=4)
    def pipeline():
        image, _ = fn.readers.file(file_root=jpeg_path_dali_extra)
        image = fn.decoders.image(image, device="mixed", output_type=types.RGB)
        image = fn.resize(image, size=(244, 244))
        flipped = flip_horizontal(image)
        return image, flipped
The function can be transformed with the usual JAX transformations. For example, we can use JAX’s just-in-time compilation and vectorization by adding the appropriate decorators to the above example:

    @dax.fn.jax_function
    @jax.jit
    @jax.vmap
    def flip_horizontal(image: jax.Array):
        return image[:, ::-1, :]  # HWC image
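The two variants of flip_horizontal above should compute the same result: one indexes the whole batch, the other is written per sample and mapped over the batch with jax.vmap. The following is a minimal sketch in plain JAX, without DALI, that checks this on a small synthetic batch. The names flip_batch, flip_sample and the array sizes are illustrative only.

```python
import jax
import jax.numpy as jnp


def flip_batch(image_batch):
    # Batch-level variant: input has shape (batch, H, W, C).
    return image_batch[:, :, ::-1, :]


@jax.jit
@jax.vmap
def flip_sample(image):
    # Per-sample variant: input has shape (H, W, C); vmap adds the batch axis.
    return image[:, ::-1, :]


# Synthetic batch of 4 "images" of shape 2x3x3 (illustrative sizes).
batch = jnp.arange(4 * 2 * 3 * 3).reshape(4, 2, 3, 3)
same = bool(jnp.array_equal(flip_batch(batch), flip_sample(batch)))
```

Either variant can then be wrapped with dax.fn.jax_function as shown in the examples above.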
If the resulting function is run with DALI GPU batches, the internal DALI and JAX streams will be synchronized. The JAX operations do not need to be further synchronized by the user.
The jax.Arrays passed to the function must not be accessed after the function completes (for example, they should not be stored in some non-local scope).
Note
This is an experimental API and may change in future releases.
Note
The jax_function requires JAX version 0.4.16 or higher, with GPU support. JAX 0.4.16 requires Python 3.9 or higher.
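A script may want to verify the version requirement before building a pipeline. The sketch below is one possible way to do that; parse_version and meets_minimum are hypothetical helpers, not part of DALI or JAX, and the parsing assumes the usual dotted version strings. It does not check for GPU support.

```python
import re

# Minimum JAX version stated in the note above.
MIN_JAX_VERSION = (0, 4, 16)


def parse_version(version):
    # Keep the leading digits of the first three dot-separated components,
    # so that suffixes such as "dev20230901" or "rc1" are ignored.
    numbers = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        numbers.append(int(match.group()) if match else 0)
    return tuple(numbers)


def meets_minimum(version, minimum=MIN_JAX_VERSION):
    # Tuples compare element by element, which orders versions numerically.
    return parse_version(version) >= minimum
```

With JAX installed, the check could be run as meets_minimum(jax.__version__).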
- Parameters:
function (JaxCallback) – Python callback that accepts and returns zero or more jax.Array objects. The function will receive batches processed by DALI as jax.Array tensors (with the leftmost extent corresponding to DALI batch). For this reason, the transformed function can only receive DALI batches that contain samples of uniform shape.
num_outputs (int, default=1) –
The number of outputs returned by the function.
The function can return no output; in that case, num_outputs must be set to 0. If num_outputs is 1 (the default), the callback should return a single JAX array; for num_outputs > 1, the callback should return a tuple of JAX arrays.
output_layouts (Union[str, Tuple[str]], optional) –
The layouts of the returned tensors.
It can be either a list of strings, one for each of the num_outputs outputs, or a single string to be set for all of the outputs.
Please note that in DALI the outermost batch extent is implicit, so the layout should take into account only the sample dimensions.
If the argument is not specified and the function’s i-th output has the same dimensionality as the i-th input, the layout will be propagated from the input to the corresponding output.
sharding (jax.sharding.Sharding, optional) –
The JAX sharding object (either PositionalSharding or NamedSharding). If specified, the jax.Arrays passed to the function will be a global jax.Array aware of the sharding.
Note
Currently, only the global sharding is supported, i.e. the number of the local devices in the given process must be exactly one.
device (str, optional) – Either “cpu”, “gpu” or None. The device kind on which all of the DALI inputs and outputs to the transformed function will be placed. If not specified, the device will be deduced based on the DALI inputs passed to the resulting function. Currently, the device kind of all the inputs and outputs must be the same.
preserve (bool, default=True) – If set to False, the returned DALI function may be optimized out of the DALI pipeline, if it does not return any outputs or none of the function outputs contribute to the pipeline’s output.
- Returns:
The transformed function that processes DALI-traced batches (DataNodes).
- Return type:
DaliCallback
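To illustrate the num_outputs and batch-shape conventions described in the parameters, here is a minimal sketch of a callback that returns two outputs. It runs in plain JAX, with a synthetic array standing in for a uniformly shaped DALI batch. The name flip_both and the array sizes are illustrative, and the decorator call shown in the comment is an assumption based on the signature above (function defaults to None), not a tested DALI invocation.

```python
import jax
import jax.numpy as jnp


# Assumption: with DALI installed, this callback would be wrapped using the
# keyword arguments documented above, for example:
#   @dax.fn.jax_function(num_outputs=2, output_layouts="HWC")
def flip_both(image_batch: jax.Array):
    # image_batch has shape (batch, H, W, C): the leftmost extent is the batch.
    horizontal = image_batch[:, :, ::-1, :]  # reverse the width axis
    vertical = image_batch[:, ::-1, :, :]    # reverse the height axis
    return horizontal, vertical              # a tuple, because num_outputs > 1


# Stand-in for a DALI batch: 2 samples, each of shape 2x3x1 (illustrative).
batch = jnp.arange(2 * 2 * 3 * 1).reshape(2, 2, 3, 1)
horizontal, vertical = flip_both(batch)
```

Note that the layout string in the comment describes only the sample dimensions ("HWC"), since the batch extent is implicit in DALI.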