nemo_rl.distributed.batched_data_dict#

Module Contents#

Classes#

DynamicBatchingCfg

Configuration settings for dynamic batching.

BatchedDataDict

SlicedDataDict

A specialized subclass of BatchedDataDict that represents a slice or shard of a larger batch.

Data#

API#

nemo_rl.distributed.batched_data_dict.DictT#

‘TypeVar(…)’

class nemo_rl.distributed.batched_data_dict.DynamicBatchingCfg[source]#

Bases: typing.TypedDict

Configuration settings for dynamic batching.

Pass this to ‘shard_by_batch_size()’ to preprocess batches for dynamic batching.

max_tokens_per_microbatch: int#

The maximum number of tokens allowed in a microbatch.

sequence_length_round: int#

Round all sequence lengths up to a multiple of this value.

input_key: str#

The key in the batch which holds input ids.

input_lengths_key: str#

The key in the batch which holds the sequence length per value. The sequence dim index is assumed to be 1.
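
Example (not from the original docs; a minimal sketch with illustrative key names and values):

>>> from nemo_rl.distributed.batched_data_dict import DynamicBatchingCfg
>>> cfg: DynamicBatchingCfg = {
...     "max_tokens_per_microbatch": 2048,    # cap on tokens per microbatch
...     "sequence_length_round": 64,          # pad sequence lengths up to a multiple of 64
...     "input_key": "input_ids",             # illustrative key names; use your batch's keys
...     "input_lengths_key": "input_lengths",
... }
>>> # later: batch.shard_by_batch_size(shards=2, dynamic_batching_cfg=cfg)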

class nemo_rl.distributed.batched_data_dict.BatchedDataDict(*args, **kwargs)[source]#

Bases: collections.UserDict, typing.Generic[nemo_rl.distributed.batched_data_dict.DictT]

classmethod from_batches(
batches: List[Dict],
pad_value_dict: Optional[Dict[str, int]] = None,
) → typing_extensions.Self[source]#

Given a list of batches, stack the tensors/lists within and put them in a single dictionary.

Pad sequences to the max length in the batch using either 0 (the default) or a per-key padding value provided in pad_value_dict.

Parameters:
  • batches (List[Dict]) – A list of dictionaries, each containing a batch of data.

  • pad_value_dict (Optional[Dict[str, int]]) – An optional dict mapping keys to padding values that override the default of 0.

Returns:

A new BatchedDataDict containing the stacked data.

Return type:

BatchedDataDict
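
Example (a hedged sketch; assumes tensors are padded along the trailing sequence dim and lists are concatenated):

>>> import torch
>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict
>>> b1 = {"input_ids": torch.tensor([[1, 2, 3]]), "labels": [0]}
>>> b2 = {"input_ids": torch.tensor([[4, 5]]), "labels": [1]}
>>> merged = BatchedDataDict.from_batches([b1, b2], pad_value_dict={"input_ids": -1})
>>> merged["input_ids"].shape  # both rows padded to the longer sequence
torch.Size([2, 3])
>>> merged["labels"]
[0, 1]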

all_gather(
group: torch.distributed.ProcessGroup,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[source]#

Gathers batches with possibly jagged leading dimensions across the DP ranks.

If resharding is enabled, PP ranks are treated as DP ranks. Works with data that is either tensors or string lists.
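
Sketch of the expected call pattern (illustrative only; assumes an initialized torch.distributed job):

>>> # import torch.distributed as dist
>>> # group = dist.new_group(ranks=list(range(dist.get_world_size())))
>>> # global_batch = local_batch.all_gather(group)
>>> # global_batch.size == sum of the per-rank batch sizes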

chunk(
rank: int,
chunks: int,
) → nemo_rl.distributed.batched_data_dict.SlicedDataDict[source]#

Chunks a global batch into ‘chunks’ splits and returns the ‘rank’-th split. For example, given batch=[A A A B B B D D E], rank=2, chunks=3 -> [D D E].

Requires all tensors to share the same leading dimension and all lists the same length across the batch, and ‘chunks’ must evenly divide the batch size.
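
Example (a minimal sketch, not from the original docs):

>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict
>>> batch = BatchedDataDict({'ids': [0, 1, 2, 3, 4, 5]})
>>> batch.chunk(rank=1, chunks=3)  # three chunks of two; return the second
{'ids': [2, 3]}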

reorder_data(reorded_indices: List[int])[source]#

Reorders the data along the batch dimension by the given indices.
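
A minimal sketch (assumes in-place reordering; the exact index convention follows the implementation, and the parameter name reorded_indices is as in the source):

>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict
>>> batch = BatchedDataDict({'ids': [10, 11, 12]})
>>> batch.reorder_data([2, 0, 1])  # rows reordered in place along the batch dim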

shard_by_batch_size(
shards: int,
batch_size: Optional[int] = None,
allow_uneven_shards: bool = False,
dynamic_batching_cfg: nemo_rl.distributed.batched_data_dict.DynamicBatchingCfg = None,
) → List[nemo_rl.distributed.batched_data_dict.SlicedDataDict][source]#

Shards a batch by first dividing it into chunks of size batch_size, then further dividing each chunk into ‘shards’ equal parts, and finally aggregating the sub-shards by position.

If batch_size is None, no chunking occurs beforehand (batch_size defaults to the total batch size).

For example, with data [A A B B C C D D], batch_size=2, shards=2:

  • Element 0: [A B C D] (first elements from each chunk)

  • Element 1: [A B C D] (second elements from each chunk)

Parameters:
  • shards (int) – The number of shards to divide each batch_size chunk into.

  • batch_size (int) – The size of each initial chunk.

  • allow_uneven_shards (bool) – Whether to allow shards to be unevenly sized. If True, the last shard may be smaller than the others.

  • dynamic_batching_cfg (DynamicBatchingCfg) –

    If passed, preprocess the batch for dynamic batching. This dict requires four keys:

    1. max_tokens_per_microbatch (int): the maximum number of tokens in a microbatch

    2. sequence_length_round (int): round all sequence lengths up to a multiple of this value

    3. input_key (str): the key in the batch which holds input ids.

    4. input_lengths_key (str): the key in the batch which holds the sequence length per value. The sequence dim index is assumed to be 1.

Returns:

A list of SlicedDataDicts, one per shard (length equal to shards).

Return type:

List[SlicedDataDict]

Examples:

>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict
>>> # Create a batch of two message logs with different lengths
>>> batch = BatchedDataDict({
...     'problem_id': [0, 0, 1, 1, 2, 2, 3, 3],
...     'arbitrary_data': [1, 2, 3, 4, 5, 6, 7, 8]
... })
>>> shards = batch.shard_by_batch_size(shards=2)
>>> shards
[{'problem_id': [0, 0, 1, 1], 'arbitrary_data': [1, 2, 3, 4]}, {'problem_id': [2, 2, 3, 3], 'arbitrary_data': [5, 6, 7, 8]}]
>>> # Now say that I'm training with a GBS of 4 and I want to take gradient steps on problems 0 and 1 before 2 and 3 (problems are repeated because GRPO)
>>> # In the current case, problems 0 and 2 will be trained on first since they're the first elements in each DP rank's batch.
>>> # So, we'll use the batch_size argument to split the batch into chunks of size 4 first.
>>> shards = batch.shard_by_batch_size(shards=2, batch_size=4)
>>> shards
[{'problem_id': [0, 0, 2, 2], 'arbitrary_data': [1, 2, 5, 6]}, {'problem_id': [1, 1, 3, 3], 'arbitrary_data': [3, 4, 7, 8]}]
>>> # Now, the ranks have 0 and 1 first so when they split their batches into microbatches (of size 2 since GBS=4 and DP=2), they'll train on 0 and 1 first.
>>> # Another way to use this function is with the 'allow_uneven_shards' flag, which allows the last shard to be smaller than the others when necessary.
>>> # This is necessary in multi-turn rollouts when some sequences terminate early, leaving unclean batch sizes.
>>> batch = BatchedDataDict({
...     'problem_id': [0, 1, 2, 3, 4],
...     'arbitrary_data': [10, 11, 12, 13, 14]
... })
>>> shards = batch.shard_by_batch_size(shards=2, allow_uneven_shards=True)
>>> shards
[{'problem_id': [0, 1, 2], 'arbitrary_data': [10, 11, 12]}, {'problem_id': [3, 4], 'arbitrary_data': [13, 14]}]
>>> # This is incompatible with the batch_size argument
get_batch(
batch_idx,
batch_size,
) → nemo_rl.distributed.batched_data_dict.SlicedDataDict[source]#

Slices a subbatch from the batch.

Parameters:
  • batch_idx – the batch index to slice

  • batch_size – the size of the batch to be sliced

Returns:

A new SlicedDataDict containing the sliced data

Return type:

SlicedDataDict
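
Example (a minimal sketch; assumes batch_idx selects the batch_idx-th contiguous block of batch_size rows):

>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict
>>> batch = BatchedDataDict({'ids': [0, 1, 2, 3, 4, 5]})
>>> batch.get_batch(batch_idx=1, batch_size=2)  # rows [2:4]
{'ids': [2, 3]}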

slice(
start: int,
end: int,
) → nemo_rl.distributed.batched_data_dict.SlicedDataDict[source]#

Slices the batch from start to end.

Parameters:
  • start – Starting index (inclusive)

  • end – Ending index (exclusive)

Returns:

A new SlicedDataDict containing the sliced data

Return type:

SlicedDataDict
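
Example (a minimal sketch):

>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict
>>> batch = BatchedDataDict({'ids': [0, 1, 2, 3]})
>>> batch.slice(1, 3)  # start inclusive, end exclusive
{'ids': [1, 2]}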

repeat_interleave(
num_repeats: int,
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[source]#

Repeats the batch num_repeats times.

For each element in the batch, repeat each value num_repeats times, e.g.: {"key": torch.tensor([1, 2, 3]), "other_key": [1, 2, 3]} -> {"key": torch.tensor([1, 1, 2, 2, 3, 3]), "other_key": [1, 1, 2, 2, 3, 3]}
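
Example (a minimal sketch mirroring the docstring above):

>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict
>>> batch = BatchedDataDict({'ids': [1, 2]})
>>> batch.repeat_interleave(2)
{'ids': [1, 1, 2, 2]}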

truncate_tensors(dim: int, truncated_len: int)[source]#

Truncates tensors in this dict along a given dim to a given length.
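
Example (a minimal sketch; assumes the truncation happens in place):

>>> import torch
>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict
>>> batch = BatchedDataDict({'input_ids': torch.zeros(2, 8)})
>>> batch.truncate_tensors(dim=1, truncated_len=4)
>>> batch['input_ids'].shape
torch.Size([2, 4])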

make_microbatch_iterator_with_dynamic_shapes(
sequence_dim: int = 1,
) → Iterator[nemo_rl.distributed.batched_data_dict.SlicedDataDict][source]#

Makes an iterator that yields microbatches with dynamic batch and sequence sizes.

Parameters:

sequence_dim – the index of the sequence dim for all tensors in the data dict

Returns:

An iterator that yields dynamic microbatches

Return type:

Iterator["SlicedDataDict"]
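
Sketch of the intended workflow (illustrative; assumes the batch was preprocessed via shard_by_batch_size with a dynamic_batching_cfg so microbatch boundaries exist, and 'model' is a placeholder):

>>> # shards = batch.shard_by_batch_size(shards=dp_size, dynamic_batching_cfg=cfg)
>>> # for mb in shards[rank].make_microbatch_iterator_with_dynamic_shapes():
>>> #     loss = model(mb["input_ids"])  # each mb respects max_tokens_per_microbatch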

make_microbatch_iterator(
microbatch_size: int,
) → Iterator[nemo_rl.distributed.batched_data_dict.SlicedDataDict][source]#

Make an iterator over the batch that yields microbatches of size microbatch_size.
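
Example (a minimal sketch; assumes microbatch_size evenly divides the batch size):

>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict
>>> batch = BatchedDataDict({'ids': [0, 1, 2, 3]})
>>> [mb['ids'] for mb in batch.make_microbatch_iterator(2)]
[[0, 1], [2, 3]]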

property size: int#

Get the batch size of the batch.

to(device: torch.device) → typing_extensions.Self[source]#

Move tensors in batched dict to device.

select_indices(
indices: Union[List[int], torch.Tensor],
) → nemo_rl.distributed.batched_data_dict.BatchedDataDict[source]#

Selects specific rows from the batch based on indices.

Parameters:

indices – A list or tensor of integer indices to select.

Returns:

A new BatchedDataDict containing only the selected rows.

Return type:

BatchedDataDict
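
Example (a minimal sketch):

>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict
>>> batch = BatchedDataDict({'ids': [10, 11, 12, 13]})
>>> batch.select_indices([0, 2])
{'ids': [10, 12]}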

get_dict() → dict[source]#

Get the underlying data dictionary.

class nemo_rl.distributed.batched_data_dict.SlicedDataDict(*args, **kwargs)[source]#

Bases: nemo_rl.distributed.batched_data_dict.BatchedDataDict

A specialized subclass of BatchedDataDict that represents a slice or shard of a larger batch.

This class provides a distinct type to differentiate between full batches and sliced/sharded batches, which can be helpful for type checking.
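
Example (a minimal sketch of the type distinction):

>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict, SlicedDataDict
>>> batch = BatchedDataDict({'ids': [0, 1, 2, 3]})
>>> shard = batch.slice(0, 2)
>>> isinstance(shard, SlicedDataDict)
True
>>> isinstance(shard, BatchedDataDict)  # still usable as a BatchedDataDict
True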
