nemo_rl.data_plane.preshard

Driver-side balanced packing + per-rank fan-out helpers.

Shared by the sync and async data-plane trainers. The helpers operate on full BatchedDataDicts and rely on the bin_count_multiple=DP_world behavior of shard_by_batch_size to keep per-rank microbatch counts uniform. Without it, sequence packing and dynamic batching produce variable per-rank bin counts, and Megatron deadlocks at the first cross-DP collective.
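The uniformity requirement above can be illustrated with a minimal sketch. It assumes bins are dealt to ranks round-robin and that the bin count is rounded up to a multiple of the DP world size; `round_up_bin_count` and `bins_per_rank` are hypothetical helpers written for this illustration, not part of nemo_rl, and the real shard_by_batch_size logic may differ.

```python
import math


def round_up_bin_count(num_bins: int, dp_world: int) -> int:
    """Smallest multiple of dp_world that is >= num_bins (hypothetical helper)."""
    if dp_world <= 0:
        raise ValueError("dp_world must be positive")
    return math.ceil(num_bins / dp_world) * dp_world


def bins_per_rank(num_bins: int, dp_world: int) -> list[int]:
    """Deal bins round-robin to ranks and count how many each rank receives."""
    counts = [0] * dp_world
    for bin_index in range(num_bins):
        counts[bin_index % dp_world] += 1
    return counts


# 10 bins over 4 ranks: two ranks run 3 microbatches, two run only 2,
# so the ranks would disagree on how many collectives to enter.
uneven = bins_per_rank(10, 4)

# Rounding the bin count up to a multiple of 4 gives every rank the same count.
even = bins_per_rank(round_up_bin_count(10, 4), 4)

print(uneven)  # [3, 3, 2, 2]
print(even)    # [3, 3, 3, 3]
```

With mismatched counts, the ranks that finish early never join the collective the others are waiting on, which is the deadlock described above.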

Module Contents

Functions

shard_meta_for_dp

Pure key-list split: assign meta.sample_ids to dp_world ranks.

API

nemo_rl.data_plane.preshard.shard_meta_for_dp(
meta: nemo_rl.data_plane.interfaces.KVBatchMeta,
*,
dp_world: int,
batch_size: Optional[int] = None,
sequence_packing_args: Optional[dict[str, Any]] = None,
dynamic_batching_args: Optional[dict[str, Any]] = None,
) -> tuple[list[nemo_rl.data_plane.interfaces.KVBatchMeta], Optional[list[int]]]

Pure key-list split: assign meta.sample_ids to dp_world ranks.

Seq-len-aware on top of shard_by_batch_size. No I/O, no key minting. Used for every dispatch after rollout (logprob, ref-logprob, train); the rollout actor’s first write goes through nemo_rl.experience.sync_rollout_actor.kv_first_write directly.

Per-rank packing metadata (micro_batch_indices / micro_batch_lengths / elem_counts_per_gb) is set in each shard’s extra_info so the *_presharded worker can reattach packing as it does on the legacy fan-out path.
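To make the idea of a sequence-length-aware key split concrete, here is a toy sketch. It is not the algorithm used by shard_by_batch_size: it swaps in a simple greedy longest-first assignment with an equal sample count per rank, and `split_ids_by_length` is a hypothetical name chosen for this example. Like the function documented here, it only moves keys around and performs no I/O.

```python
def split_ids_by_length(
    sample_ids: list[str],
    sequence_lengths: list[int],
    dp_world: int,
) -> tuple[list[list[str]], list[int]]:
    """Assign sample ids to ranks, balancing total tokens (toy greedy stand-in)."""
    if len(sample_ids) != len(sequence_lengths):
        raise ValueError("sample_ids and sequence_lengths must have equal length")
    if len(sample_ids) % dp_world != 0:
        raise ValueError("sample count must be divisible by dp_world")

    per_rank = len(sample_ids) // dp_world
    # Visit samples from longest to shortest.
    order = sorted(
        range(len(sample_ids)), key=lambda i: sequence_lengths[i], reverse=True
    )

    shards: list[list[str]] = [[] for _ in range(dp_world)]
    loads = [0] * dp_world
    for i in order:
        # Only ranks that still have room may receive another sample.
        open_ranks = [r for r in range(dp_world) if len(shards[r]) < per_rank]
        target = min(open_ranks, key=lambda r: loads[r])
        shards[target].append(sample_ids[i])
        loads[target] += sequence_lengths[i]
    return shards, loads


shards, loads = split_ids_by_length(["a", "b", "c", "d"], [100, 10, 90, 20], dp_world=2)
print(shards)  # [['a', 'b'], ['c', 'd']]
print(loads)   # [110, 110]
```

Each rank ends up with two samples and 110 tokens, whereas a naive contiguous split (a, b versus c, d in input order) happens to match here only because of the chosen lengths; in general the greedy pass is what evens out the token load.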

Parameters:
  • meta – Full-batch KVBatchMeta with sequence_lengths populated.

  • dp_world – Number of DP ranks.

  • batch_size – Total samples; None for the logprob path, the global batch size (GBS) for train.

  • sequence_packing_args – Packing config dict for shard_by_batch_size.

  • dynamic_batching_args – Dynamic-batching config dict; mutually exclusive with sequence_packing_args.

Returns:

(per_rank_metas, unsorted_indices). unsorted_indices is the inverse permutation that maps DP-rank-order outputs back to original meta.sample_ids order (feed to BatchedDataDict.reorder_data post-aggregation); None if no reorder occurred.
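The role of unsorted_indices can be sketched with plain lists. This example assumes that reordering is a gather by index (position i of the restored output takes element `unsorted_indices[i]` of the rank-order output); the exact semantics of BatchedDataDict.reorder_data are not shown in this page, and `inverse_permutation` is a hypothetical helper.

```python
def inverse_permutation(rank_order: list[int]) -> list[int]:
    """Return inv such that inv[original_index] == position in rank order."""
    inverse = [0] * len(rank_order)
    for position, original_index in enumerate(rank_order):
        inverse[original_index] = position
    return inverse


original_ids = ["s0", "s1", "s2", "s3"]

# Assumed result of sharding: position p of the concatenated per-rank output
# holds the sample that was at index rank_order[p] in the original batch.
rank_order = [2, 0, 3, 1]
outputs_in_rank_order = [original_ids[i] for i in rank_order]

unsorted_indices = inverse_permutation(rank_order)

# Gather by index to restore the original sample order.
restored = [outputs_in_rank_order[j] for j in unsorted_indices]

print(outputs_in_rank_order)  # ['s2', 's0', 's3', 's1']
print(unsorted_indices)       # [1, 3, 0, 2]
print(restored)               # ['s0', 's1', 's2', 's3']
```

When the split leaves samples in their original order, the permutation is the identity and nothing needs to be undone, which corresponds to the None return value.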