nvidia.dali.plugin.jax.fn.jax_function
- nvidia.dali.plugin.jax.fn.jax_function(function=None, num_outputs=1, output_layouts=None, sharding=None, device=None, preserve=True)
- Transforms the Python function `function` that processes `jax.Array` objects into a DALI operator that can be used inside a DALI pipeline definition or a JAX plugin iterator definition. The transformed function accepts and returns the same number of inputs and outputs as the original `function`, but changes their types from `jax.Array` to DALI-traced `DataNode`s. The resulting function is interoperable with other DALI operators.

For example, we could implement a horizontal flip operation in JAX as follows:

```python
import jax
from nvidia.dali import pipeline_def, fn, types
from nvidia.dali.plugin import jax as dax

# Placeholder: a directory containing JPEG images (point it at your own data).
jpeg_path_dali_extra = "/path/to/jpeg/images"


@dax.fn.jax_function
def flip_horizontal(image_batch: jax.Array):
    # The whole batch arrives as one (N, H, W, C) array; reverse the W axis.
    return image_batch[:, :, ::-1, :]  # batch of HWC images


@pipeline_def(batch_size=4, device_id=0, num_threads=4)
def pipeline():
    image, _ = fn.readers.file(file_root=jpeg_path_dali_extra)
    image = fn.decoders.image(image, device="mixed", output_type=types.RGB)
    # Resizing gives all samples the same shape, as the JAX function requires.
    image = fn.resize(image, size=(244, 244))
    flipped = flip_horizontal(image)
    return image, flipped
```

The `function` can be transformed with the usual JAX transformations. For example, we can use JAX's just-in-time compilation and vectorization by adding the appropriate decorators to the above example:

```python
@dax.fn.jax_function
@jax.jit
@jax.vmap
def flip_horizontal(image: jax.Array):
    # vmap maps over the batch axis, so each call sees a single HWC image.
    return image[:, ::-1, :]  # HWC image
```

If the resulting function is run with DALI GPU batches, the internal DALI and JAX streams are synchronized, so the JAX operations do not need any further synchronization by the user.

The `jax.Array`s passed to the `function` must not be accessed after the `function` completes (for example, they should not be stored in a non-local scope).

Note: This is an experimental API and may change in future releases.

Note: `jax_function` requires JAX version 0.4.16 or higher, with GPU support. JAX 0.4.16 requires Python 3.9 or higher.

- Parameters:
- function (JaxCallback) – Python callback that accepts and returns zero or more `jax.Array` objects. The function receives batches processed by DALI as `jax.Array` tensors, with the leftmost extent corresponding to the DALI batch. For this reason, the transformed function can only receive DALI batches that contain samples of uniform shape.
- num_outputs (int, default=1) – The number of outputs returned by the `function`. The function can return no output, in which case `num_outputs` must be set to 0. If `num_outputs` is 1 (the default), the callback should return a single JAX array; for `num_outputs` > 1, it should return a tuple of JAX arrays.
- output_layouts (Union[str, Tuple[str]], optional) – The layouts of the returned tensors. It can be either a sequence of strings, one for each of the `num_outputs` outputs, or a single string that is applied to all of the outputs. Note that in DALI the outermost batch extent is implicit, so the layout should describe only the sample dimensions. If the argument is not specified and the `function`'s i-th output has the same dimensionality as its i-th input, the layout is propagated from the input to the corresponding output.
- sharding (jax.sharding.Sharding, optional) – The JAX sharding object (either `PositionalSharding` or `NamedSharding`). If specified, each `jax.Array` passed to the `function` will be a global `jax.Array` aware of the sharding. Note: Currently, only global sharding is supported, i.e., the number of local devices in the given process must be exactly one.
- device (str, optional) – Either “cpu”, “gpu”, or None. The device kind on which all of the DALI inputs and outputs of the transformed function will be placed. If not specified, the device is deduced from the DALI inputs passed to the resulting function. Currently, all of the inputs and outputs must be of the same device kind.
- preserve (bool, default=True) – If set to False, the returned DALI function may be optimized out of the DALI pipeline if it does not return any outputs or if none of its outputs contribute to the pipeline’s output.
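As a sketch of how `num_outputs` and `output_layouts` fit together, the callback below returns two outputs per batch: the normalized images and their per-sample channel means. The names `normalize_with_stats` and `build_dali_operator` are hypothetical, and the layouts assume the callback receives batches of uint8 HWC images of uniform shape. The DALI wrapping is kept in a lazily imported helper so that the pure-JAX body can be tried without DALI installed.

```python
import jax
import jax.numpy as jnp


def normalize_with_stats(image_batch: jax.Array):
    """Hypothetical callback: receives a whole DALI batch as one (N, H, W, C) array."""
    # Scale uint8 pixels to [0, 1]; the shape stays (N, H, W, C), i.e. HWC per sample.
    normalized = image_batch.astype(jnp.float32) / 255.0
    # Per-sample channel means: shape (N, C), i.e. one 1-D "C" sample per image.
    channel_mean = normalized.mean(axis=(1, 2))
    # With num_outputs=2 the callback returns a tuple of two arrays.
    return normalized, channel_mean


def build_dali_operator():
    # Imported lazily so the pure-JAX part above runs without DALI installed.
    from nvidia.dali.plugin import jax as dax

    # Layouts describe sample dimensions only; the batch extent is implicit.
    return dax.fn.jax_function(
        normalize_with_stats,
        num_outputs=2,
        output_layouts=("HWC", "C"),
    )


if __name__ == "__main__":
    batch = jnp.full((4, 8, 8, 3), 255, dtype=jnp.uint8)
    normalized, channel_mean = normalize_with_stats(batch)
    print(normalized.shape, channel_mean.shape)  # (4, 8, 8, 3) (4, 3)
```

Inside a pipeline definition, the wrapped operator would then be called like any other DALI operator, for example `normalized, mean = build_dali_operator()(images)`. Because the first output keeps the input's dimensionality, its layout would also be propagated automatically; it is spelled out here only for clarity.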
 
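A `sharding` argument that splits the DALI batch across processes could be built with JAX's `Mesh` and `NamedSharding`, as in the sketch below. It assumes one device per process (as the `sharding` note requires), a single mesh axis named `"batch"` (an arbitrary name), and the hypothetical helpers `make_batch_sharding` and `build_sharded_operator`; it is an illustration, not a recipe taken from the DALI documentation.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def make_batch_sharding() -> NamedSharding:
    # One mesh axis named "batch" spanning all global devices; under the
    # one-local-device restriction, each process contributes one device.
    mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
    # Split the leftmost (DALI batch) extent across the "batch" mesh axis.
    return NamedSharding(mesh, PartitionSpec("batch"))


def build_sharded_operator(function):
    # Lazy import so the sharding helper can be tried without DALI installed.
    from nvidia.dali.plugin import jax as dax

    # The callback then receives global jax.Arrays aware of this sharding.
    return dax.fn.jax_function(function, sharding=make_batch_sharding())
```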
- Returns:
- The transformed function that processes DALI-traced batches (DataNodes). 
- Return type:
- DaliCallback
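The two `flip_horizontal` variants in the description are expected to produce the same result: `jax.vmap` maps over the leftmost (batch) axis, so reversing axis 1 of each HWC sample is the same as reversing axis 2 of the NHWC batch. The pure-JAX check below compares both bodies on a small integer batch without involving DALI; the names `flip_batch` and `flip_sample` are hypothetical.

```python
import jax
import jax.numpy as jnp


def flip_batch(image_batch: jax.Array) -> jax.Array:
    # Whole batch (N, H, W, C): reverse the W axis (axis 2).
    return image_batch[:, :, ::-1, :]


@jax.jit
@jax.vmap
def flip_sample(image: jax.Array) -> jax.Array:
    # vmap strips the leading batch axis, so each call sees one HWC image.
    return image[:, ::-1, :]


if __name__ == "__main__":
    batch = jnp.arange(2 * 3 * 4 * 3).reshape(2, 3, 4, 3)
    print(bool(jnp.array_equal(flip_batch(batch), flip_sample(batch))))  # True
```

Either body could be passed to `dax.fn.jax_function`. The `jax.jit` version pays a one-time compilation cost per input shape, which is typically amortized when every batch has the same shape, as the uniform-shape requirement already ensures.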