bridge.data.megatron_mimo.dp_utils
Data parallel utilities for MegatronMIMO data loading.
Module Contents

Functions

| Function | Description |
|---|---|
| `_find_rank_module` | Find which module grid the current rank belongs to. |
| `_needs_data_for_module` | Determine if the current rank needs to load data for the given module. |
| `get_megatron_mimo_dp_info` | Get module-local DP rank, size, data-loading flag, and module name. |
| `get_megatron_mimo_sampling_info` | Get sampler DP rank, size, and data-loading flag for MegatronMIMO. |
| `slice_batch_for_megatron_mimo` | Slice a global micro-batch for this rank’s module-local DP shard. |

API
- bridge.data.megatron_mimo.dp_utils._find_rank_module(grids: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid])
Find which module grid the current rank belongs to.
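Conceptually, this helper just tests which grid’s rank set contains the current global rank. The sketch below illustrates that membership test only; it assumes the grids have already been reduced to a hypothetical module-name → global-ranks mapping, since the accessor `HyperCommGrid` exposes for enumerating its ranks is not shown here.

```python
from typing import Dict, Iterable, Optional

import torch.distributed as dist


def find_rank_module_sketch(grid_ranks: Dict[str, Iterable[int]]) -> Optional[str]:
    """Return the name of the module whose grid contains this rank (sketch).

    grid_ranks is a hypothetical pre-extracted mapping of module name to the
    global ranks its HyperCommGrid spans; the real helper reads this off the
    HyperCommGrid objects directly.
    """
    my_rank = dist.get_rank()  # requires torch.distributed to be initialized
    for module_name, ranks in grid_ranks.items():
        if my_rank in set(ranks):
            return module_name
    return None  # this rank participates in no module grid
```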
- bridge.data.megatron_mimo.dp_utils._needs_data_for_module(grid: megatron.core.hyper_comm_grid.HyperCommGrid, module_name: str)
Determine if the current rank needs to load data for the given module.
LLM: the first and last PP stages need data (`input_ids` and labels, respectively). Encoders: only the first PP stage needs raw modality inputs.
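As an illustration of that rule only (the actual helper derives the pipeline position from the module’s `HyperCommGrid`), here is a hedged sketch in terms of an assumed per-module pipeline rank and size; the `"llm"` module name is hypothetical.

```python
def needs_data_sketch(module_name: str, pp_rank: int, pp_size: int) -> bool:
    """Apply the data-loading rule described above (sketch only).

    pp_rank / pp_size are assumed to be this rank's pipeline stage and the
    total number of pipeline stages within the module's own grid.
    """
    if module_name == "llm":  # hypothetical name for the LLM module
        # First PP stage consumes input_ids; last PP stage consumes labels.
        return pp_rank == 0 or pp_rank == pp_size - 1
    # Encoders: only the first PP stage ingests raw modality inputs.
    return pp_rank == 0
```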
- bridge.data.megatron_mimo.dp_utils.get_megatron_mimo_dp_info(megatron_mimo_cfg: megatron.bridge.models.megatron_mimo.megatron_mimo_config.MegatronMIMOParallelismConfig, grids: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid])
Get module-local DP rank, size, data-loading flag, and module name.
Returns the DP settings for the module that the current rank participates in. These are used by `slice_batch_for_megatron_mimo` to sub-shard a global micro-batch into per-module DP shards.

Note: Do not use these values to construct a DistributedSampler. For sampler construction use `get_megatron_mimo_sampling_info` instead, which returns settings that keep all data-loading ranks synchronised on the same sample order.

- Parameters:
megatron_mimo_cfg – MegatronMIMO parallelism configuration.
grids – Module name to HyperCommGrid mapping from build_hypercomm_grids().
- Returns:
Tuple of (dp_rank, dp_size, needs_data, loader_module).
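A typical forward-step use unpacks the returned tuple and feeds the DP fields to `slice_batch_for_megatron_mimo`, as in the hedged sketch below; `megatron_mimo_cfg`, `grids`, and `global_batch` are assumed to come from the surrounding setup code, and the full import path is inferred from the dotted module name above.

```python
from megatron.bridge.data.megatron_mimo.dp_utils import (
    get_megatron_mimo_dp_info,
    slice_batch_for_megatron_mimo,
)

dp_rank, dp_size, needs_data, loader_module = get_megatron_mimo_dp_info(
    megatron_mimo_cfg, grids
)
if needs_data:
    # Sub-shard the shared global micro-batch for this module-local DP replica.
    local_batch = slice_batch_for_megatron_mimo(global_batch, dp_rank, dp_size)
```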
- bridge.data.megatron_mimo.dp_utils.get_megatron_mimo_sampling_info(megatron_mimo_cfg: megatron.bridge.models.megatron_mimo.megatron_mimo_config.MegatronMIMOParallelismConfig, grids: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid])
Get sampler DP rank, size, and data-loading flag for MegatronMIMO.
In heterogeneous MegatronMIMO, modules may have different DP sizes. The data loader must give every data-loading rank the same global micro-batch so that `slice_batch_for_megatron_mimo` (called in the forward step) can sub-shard it consistently with the `BridgeCommunicator` fan-in / fan-out routing.

This function therefore returns `dp_size=1, dp_rank=0` for all ranks, disabling DP sharding at the sampler level. Per-module DP sharding is deferred to `slice_batch_for_megatron_mimo`.

- Parameters:
megatron_mimo_cfg – MegatronMIMO parallelism configuration.
grids – Module name to HyperCommGrid mapping.
- Returns:
Tuple of (sampler_dp_rank, sampler_dp_size, needs_data).
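Because the returned settings deliberately disable sharding, feeding them into a standard `torch.utils.data.DistributedSampler` keeps every data-loading rank on an identical sample order. A hedged usage sketch follows; `dataset` and `global_micro_batch_size` are assumed names, and the import path is inferred from the module name above.

```python
from torch.utils.data import DataLoader, DistributedSampler

from megatron.bridge.data.megatron_mimo.dp_utils import get_megatron_mimo_sampling_info

sampler_dp_rank, sampler_dp_size, needs_data = get_megatron_mimo_sampling_info(
    megatron_mimo_cfg, grids
)
if needs_data:
    # sampler_dp_size == 1 and sampler_dp_rank == 0, so no sharding happens:
    # every data-loading rank walks the full, identical sample stream.
    sampler = DistributedSampler(
        dataset, num_replicas=sampler_dp_size, rank=sampler_dp_rank
    )
    loader = DataLoader(dataset, sampler=sampler, batch_size=global_micro_batch_size)
```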
- bridge.data.megatron_mimo.dp_utils.slice_batch_for_megatron_mimo(batch: Dict[str, Any], dp_rank: int, dp_size: int)
Slice a global micro-batch for this rank’s module-local DP shard.
All data-loading ranks receive the same global micro-batch (the sampler uses `dp_size=1`). This function contiguously slices it so that each module-local DP replica processes the correct subset. The slicing is contiguous to match the `BridgeCommunicator`’s batch-dimension split / concatenate logic for fan-out and fan-in routing.

Handles nested dicts (e.g. `modality_inputs`) by recursing.

- Parameters:
batch – Global batch dictionary with tensors of shape [global_batch, …]. May contain nested dicts (e.g. modality_inputs → encoder → kwargs).
dp_rank – This rank’s position in its module-local DP group.
dp_size – Size of the module-local DP group.
- Returns:
Dict with tensors sliced to shape [global_batch // dp_size, …].
Example:

```python
global_batch = {"tokens": torch.randn(12, 2048)}
local_batch = slice_batch_for_megatron_mimo(global_batch, dp_rank=1, dp_size=3)
local_batch["tokens"].shape  # torch.Size([4, 2048])
```
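For reference, the contiguous slice-and-recurse behaviour described above could be implemented roughly as follows. This is a sketch written against the documented contract, not the library’s actual source.

```python
from typing import Any, Dict

import torch


def slice_batch_sketch(batch: Dict[str, Any], dp_rank: int, dp_size: int) -> Dict[str, Any]:
    """Contiguously shard each tensor's leading (batch) dim; recurse into dicts."""
    out: Dict[str, Any] = {}
    for key, value in batch.items():
        if isinstance(value, dict):
            # Nested structures such as modality_inputs -> encoder -> kwargs.
            out[key] = slice_batch_sketch(value, dp_rank, dp_size)
        elif isinstance(value, torch.Tensor):
            shard = value.size(0) // dp_size  # assumes global_batch % dp_size == 0
            out[key] = value[dp_rank * shard : (dp_rank + 1) * shard]
        else:
            out[key] = value  # pass non-tensor entries through unchanged
    return out
```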