bridge.data.megatron_mimo.dp_utils#

Data parallel utilities for MegatronMIMO data loading.

Module Contents#

Functions#

_find_rank_module

Find which module grid the current rank belongs to.

_needs_data_for_module

Determine if the current rank needs to load data for the given module.

get_megatron_mimo_dp_info

Get module-local DP rank, size, data-loading flag, and module name.

get_megatron_mimo_sampling_info

Get sampler DP rank, size, and data-loading flag for MegatronMIMO.

slice_batch_for_megatron_mimo

Slice a global micro-batch for this rank’s module-local DP shard.

API#

bridge.data.megatron_mimo.dp_utils._find_rank_module(
grids: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid],
) → Tuple[HyperCommGrid | None, str | None]#

Find which module grid the current rank belongs to.
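A minimal sketch of what this lookup might look like. The membership test via rank_offset and size is a hypothetical reading of HyperCommGrid's rank layout, not its confirmed API:

```python
import torch.distributed as dist

def find_rank_module_sketch(grids):
    """Illustration only: map the current global rank to its module grid."""
    rank = dist.get_rank()
    for name, grid in grids.items():
        # Hypothetical membership test: assumes each grid covers a contiguous
        # block of global ranks starting at `rank_offset` and spanning `size`
        # ranks. The real HyperCommGrid may expose this differently.
        if grid.rank_offset <= rank < grid.rank_offset + grid.size:
            return grid, name
    return None, None
```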

bridge.data.megatron_mimo.dp_utils._needs_data_for_module(
grid: megatron.core.hyper_comm_grid.HyperCommGrid,
module_name: str,
) → bool#

Determine if the current rank needs to load data for the given module.

For the LLM, the first and last PP stages need data (input_ids and labels, respectively). For encoders, only the first PP stage needs raw modality inputs.
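A rough sketch of that gating as pure logic; is_first_stage and is_last_stage are hypothetical stand-ins for the grid's actual pipeline-stage queries:

```python
def needs_data_sketch(module_name: str, is_first_stage: bool, is_last_stage: bool) -> bool:
    # Illustration only; the stage flags stand in for the grid's real
    # pipeline-stage queries.
    if module_name == "llm":
        # First PP stage consumes input_ids; last PP stage consumes labels.
        return is_first_stage or is_last_stage
    # Encoders ingest raw modality inputs only at their first PP stage.
    return is_first_stage
```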

bridge.data.megatron_mimo.dp_utils.get_megatron_mimo_dp_info(
megatron_mimo_cfg: megatron.bridge.models.megatron_mimo.megatron_mimo_config.MegatronMIMOParallelismConfig,
grids: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid],
) → Tuple[int, int, bool, str]#

Get module-local DP rank, size, data-loading flag, and module name.

Returns the DP settings for the module that the current rank participates in. These are used by slice_batch_for_megatron_mimo() to sub-shard a global micro-batch into per-module DP shards.

Note: Do not use these values to construct a DistributedSampler. For sampler construction, use get_megatron_mimo_sampling_info() instead, which returns settings that keep all data-loading ranks synchronised on the same sample order.

Parameters:
  • megatron_mimo_cfg – MegatronMIMO parallelism configuration.

  • grids – Module name to HyperCommGrid mapping from build_hypercomm_grids().

Returns:

Tuple of (dp_rank, dp_size, needs_data, loader_module).
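A hedged usage sketch inside the forward step, assuming megatron_mimo_cfg, grids, and a global_batch from the unsharded data loader are already in scope:

```python
dp_rank, dp_size, needs_data, loader_module = get_megatron_mimo_dp_info(
    megatron_mimo_cfg, grids
)
if needs_data:
    # Sub-shard the replicated global micro-batch for this module's DP group.
    local_batch = slice_batch_for_megatron_mimo(global_batch, dp_rank, dp_size)
```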

bridge.data.megatron_mimo.dp_utils.get_megatron_mimo_sampling_info(
megatron_mimo_cfg: megatron.bridge.models.megatron_mimo.megatron_mimo_config.MegatronMIMOParallelismConfig,
grids: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid],
) → Tuple[int, int, bool]#

Get sampler DP rank, size, and data-loading flag for MegatronMIMO.

In heterogeneous MegatronMIMO, modules may have different DP sizes. The data loader must give every data-loading rank the same global micro-batch so that slice_batch_for_megatron_mimo() (called in the forward step) can sub-shard it consistently with the BridgeCommunicator fan-in / fan-out routing.

This function therefore returns dp_size=1, dp_rank=0 for all ranks, disabling DP sharding at the sampler level. Per-module DP sharding is deferred to slice_batch_for_megatron_mimo().

Parameters:
  • megatron_mimo_cfg – MegatronMIMO parallelism configuration.

  • grids – Module name to HyperCommGrid mapping.

Returns:

Tuple of (sampler_dp_rank, sampler_dp_size, needs_data).
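A hedged sketch of sampler construction with these values; dataset and micro_batch_size are assumptions for illustration, not part of this module:

```python
from torch.utils.data import DataLoader, DistributedSampler

sampler_rank, sampler_size, needs_data = get_megatron_mimo_sampling_info(
    megatron_mimo_cfg, grids
)
if needs_data:
    # sampler_size == 1 and sampler_rank == 0, so every data-loading rank
    # iterates the dataset in the same order with no sampler-level sharding.
    sampler = DistributedSampler(dataset, num_replicas=sampler_size, rank=sampler_rank)
    loader = DataLoader(dataset, batch_size=micro_batch_size, sampler=sampler)
```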

bridge.data.megatron_mimo.dp_utils.slice_batch_for_megatron_mimo(
batch: Dict[str, Any],
dp_rank: int,
dp_size: int,
) → Dict[str, Any]#

Slice a global micro-batch for this rank’s module-local DP shard.

All data-loading ranks receive the same global micro-batch (the sampler uses dp_size=1). This function contiguously slices it so that each module-local DP replica processes the correct subset. The slicing is contiguous to match the BridgeCommunicator's batch-dimension split / concatenate logic for fan-out and fan-in routing.

Handles nested dicts (e.g. modality_inputs) by recursing.

Parameters:
  • batch – Global batch dictionary with tensors of shape [global_batch, …]. May contain nested dicts (e.g. modality_inputs → encoder → kwargs).

  • dp_rank – This rank’s position in its module-local DP group.

  • dp_size – Size of the module-local DP group.

Returns:

Dict with tensors sliced to shape [global_batch // dp_size, …].

Example

```python
import torch

global_batch = {'tokens': torch.randn(12, 2048)}
local_batch = slice_batch_for_megatron_mimo(global_batch, dp_rank=1, dp_size=3)
local_batch['tokens'].shape  # torch.Size([4, 2048])
```
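For intuition, a minimal re-implementation sketch of the contiguous, recursive slicing described above. This is an illustration under the stated [global_batch, …] layout assumption, not the actual implementation:

```python
import torch
from typing import Any, Dict

def slice_batch_sketch(batch: Dict[str, Any], dp_rank: int, dp_size: int) -> Dict[str, Any]:
    out: Dict[str, Any] = {}
    for key, value in batch.items():
        if isinstance(value, dict):
            # Recurse into nested dicts such as modality_inputs.
            out[key] = slice_batch_sketch(value, dp_rank, dp_size)
        elif isinstance(value, torch.Tensor):
            # Contiguous slice along the batch dimension, matching the
            # BridgeCommunicator's split / concatenate routing.
            shard = value.shape[0] // dp_size
            out[key] = value[dp_rank * shard : (dp_rank + 1) * shard]
        else:
            out[key] = value
    return out
```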