JAX Plugin API reference
- class nvidia.dali.plugin.jax.DALIGenericIterator(pipelines, output_map, size=-1, reader_name=None, auto_reset=False, last_batch_padded=False, last_batch_policy=LastBatchPolicy.FILL, prepare_first_batch=True, sharding=None)
- General DALI iterator for JAX. It can return any number of outputs from the DALI pipeline in the form of JAX Arrays.
- Parameters:
- pipelines (list of nvidia.dali.Pipeline) – List of pipelines to use.
- output_map (list of str) – List of strings that maps consecutive outputs of the DALI pipelines to user-specified names. Outputs are returned from the iterator as a dictionary keyed by those names. Each name must be distinct.
- size (int, default = -1) – Number of samples in the shard for the wrapped pipeline (if there is more than one pipeline, it is the sum). Providing -1 means that the iterator works until StopIteration is raised from inside iter_setup(). The last_batch_policy and last_batch_padded options don't work in that case, and this mode works with only one pipeline inside the iterator. Mutually exclusive with the reader_name argument.
- reader_name (str, default = None) – Name of the reader that will be queried for the shard size, the number of shards, and all other properties needed to correctly count the relevant and padded samples the iterator has to deal with. It automatically sets last_batch_padded to match the reader's configuration.
- auto_reset (string or bool, optional, default = False) – Whether the iterator resets itself for the next epoch or requires reset() to be called explicitly. It can be one of the following values:
  - "no", False or None – at the end of the epoch, StopIteration is raised and reset() needs to be called.
  - "yes" or "True" – at the end of the epoch, StopIteration is raised, but reset() is called internally automatically.
- last_batch_policy (optional, default = LastBatchPolicy.FILL) – What to do with the last batch when there are not enough samples in the epoch to fill it completely. See nvidia.dali.plugin.base_iterator.LastBatchPolicy(). The JAX iterator does not support LastBatchPolicy.PARTIAL.
- last_batch_padded (bool, optional, default = False) – Whether the last batch provided by DALI is padded with the last sample or just wraps around. In conjunction with last_batch_policy, it determines whether the iterator, when it returns a last batch only partially filled with data from the current epoch, drops padding samples or samples from the next epoch. If set to False, the next epoch ends sooner, because data from it was consumed but dropped. If set to True, the next epoch has the same length as the first one. For this to happen, the pad_last_batch option of the reader must also be set to True. This value is overwritten when the reader_name argument is provided.
- prepare_first_batch (bool, optional, default = True) – Whether DALI should buffer the first batch right after the iterator is created, so that one batch is already prepared when the iterator is first asked for data.
- sharding (jax.sharding.Sharding, optional, default = None) – A jax.sharding.Sharding compatible object that, if present, is used to build an output jax.Array for each category. If it is None and multiple pipelines are provided, the iterator returns values compatible with pmapped JAX functions.
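The two auto_reset modes differ only in who calls reset() between epochs. The snippet below is a toy model of that documented behaviour, not DALI code. The class ToyIterator and its list-of-batches input are hypothetical stand-ins for a real pipeline-backed iterator.

```python
class ToyIterator:
    """Hypothetical stand-in for DALIGenericIterator that models only auto_reset."""

    def __init__(self, batches, auto_reset=False):
        self._batches = list(batches)
        self._pos = 0
        # "yes"/"True" enable automatic reset; "no", False and None disable it.
        # The toy also accepts the bool True, given the documented "string or bool" type.
        self._auto_reset = auto_reset in ("yes", "True", True)

    def __iter__(self):
        return self

    def __next__(self):
        if self._pos >= len(self._batches):
            # End of epoch: StopIteration is raised in both modes...
            if self._auto_reset:
                self.reset()  # ...but with auto_reset the next epoch is prepared here.
            raise StopIteration
        batch = self._batches[self._pos]
        self._pos += 1
        return batch

    def reset(self):
        self._pos = 0


# auto_reset="no": the epoch loop has to call reset() itself.
manual = ToyIterator([[1, 2], [3, 4]], auto_reset="no")
epochs_manual = []
for _ in range(2):
    epochs_manual.append(list(manual))
    manual.reset()  # without this call, the second epoch would be empty

# auto_reset="yes": consecutive epochs work without explicit reset() calls.
automatic = ToyIterator([[1, 2], [3, 4]], auto_reset="yes")
epochs_auto = [list(automatic) for _ in range(2)]
```

With the real iterator, the same loop shape applies. Either call reset() after each epoch, or construct the iterator with auto_reset enabled.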
 
 - Example

   With the data set [1,2,3,4,5,6,7] and the batch size 2:
   - last_batch_policy = LastBatchPolicy.FILL, last_batch_padded = True -> last batch = [7, 7], next iteration will return [1, 2]
   - last_batch_policy = LastBatchPolicy.FILL, last_batch_padded = False -> last batch = [7, 1], next iteration will return [2, 3]
   - last_batch_policy = LastBatchPolicy.DROP, last_batch_padded = True -> last batch = [5, 6], next iteration will return [1, 2]
   - last_batch_policy = LastBatchPolicy.DROP, last_batch_padded = False -> last batch = [5, 6], next iteration will return [2, 3]

 - Note

   The JAX iterator does not support LastBatchPolicy.PARTIAL.

 - checkpoints()
- Returns the current checkpoints of the pipelines. 
 - next()
- Returns the next batch of data. 
 - reset()
- Resets the iterator after the full epoch. DALI iterators do not support resetting before the end of the epoch and will ignore such a request. 
 - property size
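The four cases in the last-batch example can be reproduced with a small pure-Python model. This is a hypothetical sketch of the documented behaviour, not DALI's implementation. The function last_batch_demo is made up, the plain strings "FILL" and "DROP" stand in for the LastBatchPolicy members, and padded=True models a reader with pad_last_batch enabled.

```python
from itertools import cycle, islice


def last_batch_demo(data, batch_size, policy, padded, epochs=2):
    """Toy model of FILL/DROP with and without padding; returns batches per epoch."""
    n = len(data)
    n_batches = -(-n // batch_size)  # ceil(n / batch_size) reader batches per epoch
    has_partial = n % batch_size != 0
    stream = cycle(data)  # used when the reader wraps around into the next epoch
    result = []
    for _ in range(epochs):
        if padded:
            # The reader pads the final batch with the last sample and
            # restarts from the first sample in every epoch.
            samples = list(data) + [data[-1]] * (n_batches * batch_size - n)
        else:
            # The reader continues where it stopped, so the tail of the last
            # batch is taken from the beginning of the next epoch.
            samples = list(islice(stream, n_batches * batch_size))
        batches = [samples[i:i + batch_size] for i in range(0, len(samples), batch_size)]
        if policy == "DROP" and has_partial:
            batches = batches[:-1]  # the incomplete final batch is discarded
        result.append(batches)
    return result


data = [1, 2, 3, 4, 5, 6, 7]
for policy in ("FILL", "DROP"):
    for padded in (True, False):
        first, second = last_batch_demo(data, 2, policy, padded)
        print(policy, padded, "last batch:", first[-1], "next:", second[0])
```

In the padded=False cases, the next epoch starts at 2 rather than 1. The sample 1 was already consumed to fill, or was dropped with, the previous epoch's last batch.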
 
- nvidia.dali.plugin.jax.data_iterator(pipeline_fn=None, output_map=[], size=-1, reader_name=None, auto_reset=False, last_batch_padded=False, last_batch_policy=LastBatchPolicy.FILL, prepare_first_batch=True, sharding=None, devices=None)
- Decorator for the DALI iterator for JAX. When called, the decorated function returns a DALI iterator for JAX. The decorated function should be a DALI pipeline definition function. The decorator accepts all arguments of nvidia.dali.plugin.base_iterator.DALIGenericIterator.__init__() and passes them to the iterator constructor. If no device_id argument is passed to the decorated function, the first device is assumed and device_id is set to 0. If the same argument is passed to both the decorator and the decorated function, an exception is raised.
- Parameters:
- pipeline_fn (function) – Function to be decorated. It should be compatible with the nvidia.dali.pipeline.pipeline_def() decorator. For multi-GPU support, it should accept the device_id, shard_id and num_shards arguments.
- output_map (list of str) – List of strings that maps consecutive outputs of the DALI pipelines to user-specified names. Outputs are returned from the iterator as a dictionary keyed by those names. Each name must be distinct.
- size (int, default = -1) – Number of samples in the shard for the wrapped pipeline (if there is more than one pipeline, it is the sum). Providing -1 means that the iterator works until StopIteration is raised from inside iter_setup(). The last_batch_policy and last_batch_padded options don't work in that case, and this mode works with only one pipeline inside the iterator. Mutually exclusive with the reader_name argument.
- reader_name (str, default = None) – Name of the reader that will be queried for the shard size, the number of shards, and all other properties needed to correctly count the relevant and padded samples the iterator has to deal with. It automatically sets last_batch_padded to match the reader's configuration.
- auto_reset (string or bool, optional, default = False) – Whether the iterator resets itself for the next epoch or requires reset() to be called explicitly. It can be one of the following values:
  - "no", False or None – at the end of the epoch, StopIteration is raised and reset() needs to be called.
  - "yes" or "True" – at the end of the epoch, StopIteration is raised, but reset() is called internally automatically.
- last_batch_policy (optional, default = LastBatchPolicy.FILL) – What to do with the last batch when there are not enough samples in the epoch to fill it completely. See nvidia.dali.plugin.base_iterator.LastBatchPolicy(). The JAX iterator does not support LastBatchPolicy.PARTIAL.
- last_batch_padded (bool, optional, default = False) – Whether the last batch provided by DALI is padded with the last sample or just wraps around. In conjunction with last_batch_policy, it determines whether the iterator, when it returns a last batch only partially filled with data from the current epoch, drops padding samples or samples from the next epoch. If set to False, the next epoch ends sooner, because data from it was consumed but dropped. If set to True, the next epoch has the same length as the first one. For this to happen, the pad_last_batch option of the reader must also be set to True. This value is overwritten when the reader_name argument is provided.
- prepare_first_batch (bool, optional, default = True) – Whether DALI should buffer the first batch right after the iterator is created, so that one batch is already prepared when the iterator is first asked for data.
- sharding (jax.sharding.Sharding, optional, default = None) – A jax.sharding.Sharding compatible object that, if present, is used to build an output jax.Array for each category. The iterator then returns outputs compatible with automatic parallelization in JAX. This argument is mutually exclusive with the devices argument: if devices is provided, sharding should be set to None.
- devices (list of jax.Device, optional, default = None) – List of JAX devices used to run the pipeline in parallel. The iterator then returns outputs compatible with pmapped JAX functions. This argument is mutually exclusive with the sharding argument: if sharding is provided, devices should be set to None.
- checkpoints (list of str, optional, default = None) – Checkpoints obtained with the .checkpoints() method of the iterator. If provided, they are used to restore the state of the pipelines.
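As a sketch of how a value for the sharding argument might be built, the snippet below creates a one-dimensional device mesh. It then builds a NamedSharding that splits the leading (batch) dimension across all visible devices. The mesh axis name "batch" and the array shapes are illustrative assumptions. The DALI part is left as a comment because it needs a GPU pipeline, and the pipeline function name used there is hypothetical.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 1-D mesh over every visible device; the axis name "batch" is arbitrary.
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))

# Split the leading (batch) dimension across the "batch" mesh axis;
# any remaining dimensions are replicated.
sharding = NamedSharding(mesh, PartitionSpec("batch"))

# Illustration of the resulting layout: each device holds 2 rows of the batch.
num_devices = len(jax.devices())
batch = jax.device_put(np.zeros((2 * num_devices, 3), dtype=np.float32), sharding)

# Hypothetical DALI usage (not executed here; my_pipeline is a placeholder):
# @data_iterator(output_map=["images", "labels"], reader_name="reader", sharding=sharding)
# def my_pipeline(shard_id, num_shards):
#     ...
```

Passing devices instead of sharding targets pmapped functions rather than automatic parallelization, and the two arguments must not be combined.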
 
 - Example

   With the data set [1,2,3,4,5,6,7] and the batch size 2:
   - last_batch_policy = LastBatchPolicy.FILL, last_batch_padded = True -> last batch = [7, 7], next iteration will return [1, 2]
   - last_batch_policy = LastBatchPolicy.FILL, last_batch_padded = False -> last batch = [7, 1], next iteration will return [2, 3]
   - last_batch_policy = LastBatchPolicy.DROP, last_batch_padded = True -> last batch = [5, 6], next iteration will return [1, 2]
   - last_batch_policy = LastBatchPolicy.DROP, last_batch_padded = False -> last batch = [5, 6], next iteration will return [2, 3]

 - Note

   The JAX iterator does not support LastBatchPolicy.PARTIAL.
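The argument rules of the data_iterator decorator (options merged from the decorator and the call, device_id defaulting to 0, an error when an argument is given twice) can be illustrated with a toy decorator. The sketch below is hypothetical: toy_data_iterator only returns the merged configuration instead of building pipelines and an iterator. Its use of ValueError is an assumption, since the documentation only says an exception is raised.

```python
import functools


def toy_data_iterator(**decorator_kwargs):
    """Hypothetical stand-in for data_iterator that models only its argument rules."""
    def decorate(pipeline_fn):
        @functools.wraps(pipeline_fn)
        def create_iterator(**call_kwargs):
            # The same argument may not be given to both the decorator and the call.
            clashes = sorted(decorator_kwargs.keys() & call_kwargs.keys())
            if clashes:
                raise ValueError(f"arguments passed twice: {clashes}")
            # Without an explicit device_id, the first device (0) is assumed.
            call_kwargs.setdefault("device_id", 0)
            # The real decorator would build the pipeline(s) and the iterator here;
            # the toy just returns the merged configuration.
            return {**decorator_kwargs, **call_kwargs}
        return create_iterator
    return decorate


@toy_data_iterator(output_map=["images", "labels"], reader_name="reader")
def my_pipeline():  # hypothetical pipeline definition; the toy never runs it
    ...


config = my_pipeline(batch_size=4, num_threads=2)
```

Here config combines output_map and reader_name from the decorator with batch_size, num_threads and the default device_id of 0 from the call.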