nemo_rl.data_plane.preshard
Driver-side balanced packing + per-rank fan-out helpers.
Shared by sync and async data-plane trainers. Operates on full
BatchedDataDicts and relies on shard_by_batch_size’s
bin_count_multiple=DP_world behavior to keep per-rank microbatch
counts uniform — without that, sequence packing / dynamic batching
produce variable per-rank bin counts and Megatron deadlocks at the
first cross-DP collective.
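The uniformity requirement above can be illustrated with a small, self-contained sketch. It does not call `shard_by_batch_size`; the helper names `padded_bin_count` and `bins_per_rank` are hypothetical, and round-robin assignment is assumed purely for illustration. It shows how rounding the bin count up to a multiple of the DP world size gives every rank the same number of microbatches.

```python
import math


def padded_bin_count(n_bins: int, dp_world: int) -> int:
    """Round n_bins up to the nearest multiple of dp_world (hypothetical helper)."""
    return math.ceil(n_bins / dp_world) * dp_world


def bins_per_rank(n_bins: int, dp_world: int) -> list[int]:
    """Count the bins each rank receives under round-robin assignment (assumed)."""
    counts = [0] * dp_world
    for b in range(n_bins):
        counts[b % dp_world] += 1
    return counts


# 10 bins over 4 ranks: some ranks run 3 microbatches, others only 2.
unpadded = bins_per_rank(10, 4)

# Padding to 12 bins gives every rank exactly 3 microbatches.
padded = bins_per_rank(padded_bin_count(10, 4), 4)

print(unpadded)  # [3, 3, 2, 2]
print(padded)    # [3, 3, 3, 3]
```

In the unpadded case, ranks with fewer bins would reach a cross-DP collective while the others are still working through an extra microbatch, which is the mismatch the module description warns about.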
Module Contents
Functions
shard_meta_for_dp | Pure key-list split: assign meta.sample_ids to dp_world ranks.
API
nemo_rl.data_plane.preshard.shard_meta_for_dp(
    meta: nemo_rl.data_plane.interfaces.KVBatchMeta,
    *,
    dp_world: int,
    batch_size: Optional[int] = None,
    sequence_packing_args: Optional[dict[str, Any]] = None,
    dynamic_batching_args: Optional[dict[str, Any]] = None,
)
Pure key-list split: assign meta.sample_ids to dp_world ranks.
Seq-len-aware on top of shard_by_batch_size. No I/O, no key minting. Used for every dispatch after rollout (logprob, ref-logprob, train); the rollout actor’s first write goes through nemo_rl.experience.sync_rollout_actor.kv_first_write directly.
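As a rough illustration of what a sequence-length-aware key split can look like, the sketch below assigns sample indices to ranks in a snake (boustrophedon) order after sorting by length. This is a stand-in technique, not the algorithm `shard_by_batch_size` uses; `snake_split` and the example lengths are hypothetical. Like the function documented here, it only manipulates indices and performs no I/O.

```python
def snake_split(seq_lens: list[int], dp_world: int) -> list[list[int]]:
    """Split sample indices across dp_world ranks, balancing total tokens.

    Hypothetical sketch: sort by length (longest first), then deal indices
    out left-to-right and right-to-left on alternating passes.
    """
    if len(seq_lens) % dp_world != 0:
        raise ValueError("sample count must be divisible by dp_world")
    order = sorted(range(len(seq_lens)), key=lambda i: seq_lens[i], reverse=True)
    shards: list[list[int]] = [[] for _ in range(dp_world)]
    for pos, idx in enumerate(order):
        row, col = divmod(pos, dp_world)
        # Reverse direction on odd passes so no rank always gets the longest.
        rank = col if row % 2 == 0 else dp_world - 1 - col
        shards[rank].append(idx)
    return shards


seq_lens = [512, 64, 300, 128, 900, 32, 256, 700]
shards = snake_split(seq_lens, dp_world=2)
token_totals = [sum(seq_lens[i] for i in shard) for shard in shards]

print(shards)        # [[4, 2, 6, 5], [7, 0, 3, 1]]
print(token_totals)  # [1488, 1404]
```

Each rank receives the same number of samples and a similar token total, which is the property a length-aware split is after.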
Per-rank packing metadata (micro_batch_indices / micro_batch_lengths / elem_counts_per_gb) is set in each shard’s extra_info so the *_presharded worker can reattach packing as it does on the legacy fan-out path.
- Parameters:
  - meta – Full-batch KVBatchMeta with sequence_lengths populated.
  - dp_world – Number of DP ranks.
  - batch_size – Total samples; None for the logprob path, GBS for train.
  - sequence_packing_args – Packing config dict for shard_by_batch_size.
  - dynamic_batching_args – Dynamic-batching config dict; mutually exclusive with sequence_packing_args.
- Returns:
  (per_rank_metas, unsorted_indices). unsorted_indices is the inverse permutation that maps DP-rank-order outputs back to original meta.sample_ids order (feed to BatchedDataDict.reorder_data post-aggregation); None if no reorder occurred.
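The role of the inverse permutation can be shown with plain Python lists. This sketch does not use `KVBatchMeta` or `BatchedDataDict`; the helper `inverse_permutation`, the sample IDs, and the per-rank assignment are all made up for illustration, and it assumes the reorder step gathers rows by index.

```python
def inverse_permutation(perm: list[int]) -> list[int]:
    """Return inv such that inv[perm[k]] == k for every position k."""
    inv = [0] * len(perm)
    for new_pos, old_pos in enumerate(perm):
        inv[old_pos] = new_pos
    return inv


sample_ids = ["a", "b", "c", "d"]

# Hypothetical shard assignment: rank 0 holds samples 2 and 0, rank 1 holds 3 and 1.
per_rank = [[2, 0], [3, 1]]

# Concatenating rank outputs yields results in DP-rank order, not original order.
rank_order = [i for shard in per_rank for i in shard]
outputs_rank_order = [f"out-{sample_ids[i]}" for i in rank_order]

# Gathering with the inverse permutation restores the original sample order.
unsorted_indices = inverse_permutation(rank_order)
restored = [outputs_rank_order[j] for j in unsorted_indices]

print(unsorted_indices)  # [1, 3, 0, 2]
print(restored)          # ['out-a', 'out-b', 'out-c', 'out-d']
```

When the split leaves samples in their original order, the permutation is the identity and nothing needs restoring, which corresponds to the `None` case described above.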