bridge.training.megatron_mimo_parallel_utils
Multi-module process group utilities for MegatronMIMO heterogeneous parallel training.
This module provides utilities for building process group structures and handling gradients across modules with different parallelism configurations.
Key functions:
- unwrap_megatron_mimo_model(): Unwrap Float16Module/DDP wrappers to get the underlying MimoModel
- build_pg_collection_for_schedule(): Build a pg_collection compatible with the schedule
- multimodule_no_sync(): Context manager that disables gradient sync during microbatch accumulation
- finalize_model_grads_multimodule(): Finalize gradients for each module
- zero_grad_buffer_for_multimodule(): Reset gradient buffers for all modules
- validate_no_stub_ranks(): Ensure every rank participates in at least one module
- validate_data_loader_contract(): Validate data loading constraints
Module Contents
Functions
| Function | Description |
|---|---|
| _get_dp_size_from_grid | Get the DP dimension size from a grid’s shape metadata. |
| unwrap_megatron_mimo_model | Unwrap Float16Module/DDP wrappers to get the underlying MimoModel. |
| is_current_rank_in_grid | Check if current rank participates in the given grid. |
| get_active_module_pg | Return the (module_name, pg_collection) for the single active module on this rank. |
| get_module_to_grid_tuple | Build list of (module, grid) tuples for all modules the current rank participates in. |
| build_pg_collection_for_schedule | Build pg_collection compatible with schedule. |
| multimodule_no_sync | Context manager to disable gradient sync for all modules during microbatch accumulation. |
| finalize_model_grads_multimodule | Finalize gradients for each module using infra.pg_collections. |
| zero_grad_buffer_for_multimodule | Reset gradient buffers for all DDP-wrapped modules. |
| validate_no_stub_ranks | Ensure every rank participates in at least one module. |
| validate_data_loader_contract | Validate data loading constraints for multimodule training. |
Data
API
- bridge.training.megatron_mimo_parallel_utils.logger
‘getLogger(…)’
- bridge.training.megatron_mimo_parallel_utils._get_dp_size_from_grid(grid: megatron.core.hyper_comm_grid.HyperCommGrid)
Get the DP dimension size from a grid’s shape metadata.
Uses grid.shape / grid.dim_names rather than process groups so that it works on ALL ranks, including those outside the grid.
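A minimal sketch of the shape-based lookup, assuming the data-parallel axis is labeled "dp" in grid.dim_names and that shape and dim_names are parallel sequences. A SimpleNamespace stands in for a real HyperCommGrid, and get_dp_size_from_grid is a hypothetical name, not the module's private helper.

```python
from types import SimpleNamespace


def get_dp_size_from_grid(grid) -> int:
    # Read the DP extent from static metadata only. No process-group call is
    # made, so this works even on ranks that are not members of the grid.
    # Assumption: the DP axis is named "dp"; list.index raises ValueError otherwise.
    return grid.shape[list(grid.dim_names).index("dp")]


# Stand-in for a HyperCommGrid with TP=2, CP=1, DP=4, PP=1 (8 ranks).
grid = SimpleNamespace(shape=[2, 1, 4, 1], dim_names=["tp", "cp", "dp", "pp"])
dp = get_dp_size_from_grid(grid)
```

Reading metadata rather than calling a process-group size query is what lets ranks outside the grid (for example, LLM ranks asking about an encoder grid) still learn its DP size.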
- bridge.training.megatron_mimo_parallel_utils.unwrap_megatron_mimo_model(model)
Unwrap Float16Module/DDP wrappers to get the underlying MimoModel.
When using mixed precision (bf16/fp16), models are wrapped in Float16Module. This function removes such wrappers (including DDP) to expose MimoModel-specific attributes such as role, mimo_config, language_model, and modality_submodules.
- Parameters:
model – A MimoModel or a wrapped version (Float16Module, DDP).
- Returns:
The underlying MimoModel instance.
- Raises:
RuntimeError – If the model cannot be unwrapped to a MimoModel.
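The unwrapping step can be sketched as a loop that peels wrappers until a MimoModel appears. This assumes each wrapper keeps the wrapped model in a .module attribute, as Megatron's Float16Module and DDP wrappers do; MimoModel and Wrapper below are hypothetical stand-ins, and unwrap_to_mimo is an illustrative name.

```python
class MimoModel:
    """Hypothetical stand-in for megatron.core.models.mimo.MimoModel."""

    role = "llm"


class Wrapper:
    """Stand-in for Float16Module / DDP: holds the wrapped model in .module."""

    def __init__(self, module):
        self.module = module


def unwrap_to_mimo(model):
    # Peel wrappers until we reach a MimoModel or run out of .module attributes.
    while not isinstance(model, MimoModel):
        if not hasattr(model, "module"):
            raise RuntimeError(f"cannot unwrap {type(model).__name__} to a MimoModel")
        model = model.module
    return model


inner = MimoModel()
wrapped = Wrapper(Wrapper(inner))  # e.g. DDP(Float16Module(MimoModel))
result = unwrap_to_mimo(wrapped)
```

Checking isinstance on each step, rather than unwrapping a fixed number of times, keeps the helper correct whether the model arrives bare, wrapped once, or wrapped twice.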
- bridge.training.megatron_mimo_parallel_utils.is_current_rank_in_grid(grid: megatron.core.hyper_comm_grid.HyperCommGrid)
Check if current rank participates in the given grid.
- Parameters:
grid – HyperCommGrid to check participation in.
- Returns:
True if current rank is within the grid’s rank range.
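A sketch of the rank-range test, assuming a grid covers the contiguous ranks [rank_offset, rank_offset + prod(shape)). The rank_offset attribute is an assumption here, and rank_in_grid takes the rank as an argument instead of calling torch.distributed.get_rank() so the example runs without a process group.

```python
import math
from types import SimpleNamespace


def rank_in_grid(grid, rank: int) -> bool:
    # Assumption: the grid's ranks are contiguous, starting at rank_offset,
    # and the grid spans prod(shape) ranks in total.
    size = math.prod(grid.shape)
    return grid.rank_offset <= rank < grid.rank_offset + size


# Hypothetical LLM grid: TP=2, CP=1, DP=2, PP=1 placed on ranks 4..7.
llm_grid = SimpleNamespace(shape=[2, 1, 2, 1], rank_offset=4)
```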
- bridge.training.megatron_mimo_parallel_utils.get_active_module_pg(megatron_mimo_infra: megatron.bridge.models.megatron_mimo.megatron_mimo_provider.MegatronMIMOInfra)
Return the (module_name, pg_collection) for the single active module on this rank.
Non-colocated MegatronMIMO assigns each rank to exactly one module. This helper extracts that module’s name and ProcessGroupCollection.
- Raises:
AssertionError – If zero modules, or more than one module, are active on this rank.
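A sketch of the "exactly one active module" check. It assumes, purely for illustration, that the infra's per-module collections are available as a dict in which modules this rank does not belong to map to None; the module names and the string placeholders for ProcessGroupCollection objects are hypothetical.

```python
from typing import Any, Dict, Optional, Tuple


def active_module_pg(pg_collections: Dict[str, Optional[Any]]) -> Tuple[str, Any]:
    # Assumption: inactive modules map to None on this rank.
    active = [(name, pg) for name, pg in pg_collections.items() if pg is not None]
    # Non-colocated layout: every rank must belong to exactly one module.
    assert len(active) == 1, (
        f"expected exactly one active module on this rank, found {len(active)}: "
        f"{[name for name, _ in active]}"
    )
    return active[0]


# On an encoder-only rank in a non-colocated layout:
name, pg = active_module_pg({"images": "encoder_pgs", "language": None})
```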
- bridge.training.megatron_mimo_parallel_utils.get_module_to_grid_tuple(
    megatron_mimo_model: megatron.core.models.mimo.MimoModel,
    infra: megatron.bridge.models.megatron_mimo.megatron_mimo_provider.MegatronMIMOInfra,
  )
Build list of (module, grid) tuples for all modules the current rank participates in.
- Parameters:
megatron_mimo_model – The MimoModel instance.
infra – MegatronMIMOInfra containing module_to_grid_map.
- Returns:
List of (module, grid) tuples for modules this rank participates in.
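How modules are looked up on the MimoModel is not described here, so this sketch takes a hypothetical name-to-module dict and pairs each module with its grid when the rank falls inside that grid. It reuses the contiguous-range assumption (rank_offset and shape attributes); module_to_grid_tuples and the module names are illustrative only.

```python
import math
from types import SimpleNamespace
from typing import Any, Dict, List, Tuple


def module_to_grid_tuples(
    modules: Dict[str, Any], grids: Dict[str, Any], rank: int
) -> List[Tuple[Any, Any]]:
    out = []
    for name, grid in grids.items():
        # Assumption: each grid owns ranks [rank_offset, rank_offset + prod(shape)).
        size = math.prod(grid.shape)
        if grid.rank_offset <= rank < grid.rank_offset + size:
            out.append((modules[name], grid))
    return out


grids = {
    "images": SimpleNamespace(shape=[1, 1, 4, 1], rank_offset=0),    # ranks 0..3
    "language": SimpleNamespace(shape=[2, 1, 2, 1], rank_offset=4),  # ranks 4..7
}
modules = {"images": "vision_encoder", "language": "llm"}
pairs = module_to_grid_tuples(modules, grids, rank=5)
```

The resulting list is the shape of input that multimodule_no_sync, finalize_model_grads_multimodule, and zero_grad_buffer_for_multimodule consume.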
- bridge.training.megatron_mimo_parallel_utils.build_pg_collection_for_schedule(infra: megatron.bridge.models.megatron_mimo.megatron_mimo_provider.MegatronMIMOInfra)
Build pg_collection compatible with schedule.
Primary path: return a MultiModuleProcessGroupCollection, if PR 3212 allows the LLM process group to be missing on encoder-only ranks. Fallback: return a list of ProcessGroupCollections, one per module this rank participates in.
IMPORTANT: This uses infra.pg_collections directly. Do NOT rebuild process groups.
- Parameters:
infra – MegatronMIMOInfra with pg_collections for each module.
- Returns:
MultiModuleProcessGroupCollection or list of ProcessGroupCollections.
- bridge.training.megatron_mimo_parallel_utils.multimodule_no_sync(*, module_to_grid_tuple: List[Tuple])
Context manager to disable gradient sync for all modules during microbatch accumulation.
This function is designed to be used with functools.partial() to pre-bind the module_to_grid_tuple parameter, since the schedule calls no_sync_func() with no arguments.
- Parameters:
module_to_grid_tuple – List of (module, grid) tuples (keyword-only, bound via partial).
- Yields:
None. Gradient sync for all modules is disabled while the context is active.
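A sketch of the pattern: one context manager that enters every module's no_sync() through contextlib.ExitStack, pre-bound with functools.partial so the schedule can call it with no arguments. FakeDDP is a stand-in that assumes each DDP-wrapped module exposes a no_sync() context manager (Megatron's DDP does); no_sync_all is an illustrative name, not this module's implementation.

```python
import contextlib
from functools import partial


class FakeDDP:
    """Stand-in for a DDP-wrapped module with a no_sync() context manager."""

    def __init__(self):
        self.sync_enabled = True

    @contextlib.contextmanager
    def no_sync(self):
        self.sync_enabled = False
        try:
            yield
        finally:
            self.sync_enabled = True


@contextlib.contextmanager
def no_sync_all(*, module_to_grid_tuple):
    # Enter every module's no_sync() so none of them reduces gradients until
    # the final microbatch; ExitStack exits them in reverse order.
    with contextlib.ExitStack() as stack:
        for module, _grid in module_to_grid_tuple:
            stack.enter_context(module.no_sync())
        yield


enc, llm = FakeDDP(), FakeDDP()
no_sync_func = partial(no_sync_all, module_to_grid_tuple=[(enc, None), (llm, None)])

with no_sync_func():  # the schedule calls it with no arguments
    inside = (enc.sync_enabled, llm.sync_enabled)
after = (enc.sync_enabled, llm.sync_enabled)
```

ExitStack handles a variable number of modules and guarantees every entered context is exited, even if one of them raises.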
- bridge.training.megatron_mimo_parallel_utils.finalize_model_grads_multimodule(
    model,
    num_tokens=None,
    pg_collection=None,
    force_all_reduce=None,
    *,
    infra: megatron.bridge.models.megatron_mimo.megatron_mimo_provider.MegatronMIMOInfra,
    module_to_grid_tuple: List[Tuple],
  )
Finalize gradients for each module using infra.pg_collections.
IMPORTANT: The signature matches the schedule’s call pattern: config.finalize_model_grads_func([model], num_tokens, pg_collection, force_all_reduce=flag)
The infra and module_to_grid_tuple parameters are pre-bound via partial(). We ignore the schedule-provided pg_collection and use per-module PGs.
When encoder DP > LLM DP (heterogeneous), the LLM’s loss normalization divides by the token count across ALL samples it processes, but after bridge fan-out each encoder DP rank carries gradient for a factor of encoder_dp / llm_dp fewer samples. This makes encoder gradients too small by a factor of encoder_dp / llm_dp. We compensate after DDP finalization by scaling encoder gradients back up by that factor.
- Parameters:
model – Model list (passed by schedule, ignored - we use module_to_grid_tuple).
num_tokens – Token count for gradient scaling.
pg_collection – Schedule-provided PG (ignored - we use per-module PGs).
force_all_reduce – Schedule-provided flag (ignored - per-module PGs control sync).
infra – MegatronMIMOInfra with per-module pg_collections (keyword-only, bound via partial).
module_to_grid_tuple – List of (module, grid) tuples (keyword-only, bound via partial).
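A worked sketch of the pre-binding and the compensation factor described above. With encoder DP 8 and LLM DP 2, encoder gradients come out 8 / 2 = 4 times too small, so they are multiplied by 4. Floats stand in for gradient tensors, no DDP finalization is performed, and encoder_grad_scale / finalize_stub are hypothetical names. Applying no scaling when encoder_dp <= llm_dp is an assumption, since the text only describes the encoder_dp > llm_dp case.

```python
from functools import partial


def encoder_grad_scale(encoder_dp: int, llm_dp: int) -> float:
    # Factor by which encoder gradients are too small after bridge fan-out.
    # Assumption: no compensation unless encoder DP exceeds LLM DP.
    if encoder_dp > llm_dp:
        return encoder_dp / llm_dp
    return 1.0


def finalize_stub(model, num_tokens=None, pg_collection=None, force_all_reduce=None,
                  *, encoder_dp, llm_dp, encoder_grads):
    # Schedule-provided model and pg_collection are ignored, as in the docstring.
    scale = encoder_grad_scale(encoder_dp, llm_dp)
    for i, g in enumerate(encoder_grads):
        encoder_grads[i] = g * scale  # scale encoder grads back up in place
    return scale


grads = [0.25, 0.5]
finalize_fn = partial(finalize_stub, encoder_dp=8, llm_dp=2, encoder_grads=grads)
# Same call shape the schedule uses: ([model], num_tokens, pg_collection, force_all_reduce=...)
scale = finalize_fn([object()], 128, None, force_all_reduce=False)
```

Because the per-module arguments are keyword-only and bound by partial(), the schedule's positional call never collides with them.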
- bridge.training.megatron_mimo_parallel_utils.zero_grad_buffer_for_multimodule(module_to_grid_tuple: List[Tuple])
Reset gradient buffers for all DDP-wrapped modules.
- Parameters:
module_to_grid_tuple – List of (module, grid) tuples.
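A minimal sketch of the reset loop, assuming DDP-wrapped modules expose a zero_grad_buffer() method (Megatron's DDP does) and that modules lacking it are skipped. FakeDDP and zero_all are hypothetical stand-ins; a Python list plays the role of the contiguous gradient buffer.

```python
class FakeDDP:
    """Stand-in for a DDP-wrapped module with a flat gradient buffer."""

    def __init__(self, grads):
        self.grad_buffer = list(grads)

    def zero_grad_buffer(self):
        self.grad_buffer = [0.0] * len(self.grad_buffer)


def zero_all(module_to_grid_tuple):
    for module, _grid in module_to_grid_tuple:
        # Assumption: only DDP-wrapped modules have zero_grad_buffer().
        if hasattr(module, "zero_grad_buffer"):
            module.zero_grad_buffer()


enc, plain = FakeDDP([0.5, -1.0]), object()
zero_all([(enc, None), (plain, None)])
```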
- bridge.training.megatron_mimo_parallel_utils.validate_no_stub_ranks(module_to_grid_map: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid], world_size: int)
Ensure every rank participates in at least one module.
Stub ranks (ranks not participating in any module) are NOT supported. This validation runs at setup time to fail fast with a clear error.
- Parameters:
module_to_grid_map – Mapping of module names to their HyperCommGrids.
world_size – Total number of ranks in the world.
- Raises:
ValueError – If any rank doesn’t participate in a module.
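A sketch of the fail-fast check: take the union of every grid's rank range and report any world rank left uncovered. It again assumes each grid owns the contiguous ranks [rank_offset, rank_offset + prod(shape)); check_no_stub_ranks and the module names are illustrative.

```python
import math
from types import SimpleNamespace


def check_no_stub_ranks(module_to_grid_map, world_size):
    covered = set()
    for grid in module_to_grid_map.values():
        # Assumption: contiguous rank range starting at rank_offset.
        start = grid.rank_offset
        covered.update(range(start, start + math.prod(grid.shape)))
    stubs = sorted(set(range(world_size)) - covered)
    if stubs:
        raise ValueError(f"ranks {stubs} do not participate in any module")


grids = {
    "images": SimpleNamespace(shape=[4], rank_offset=0),    # ranks 0..3
    "language": SimpleNamespace(shape=[2], rank_offset=4),  # ranks 4..5
}
check_no_stub_ranks(grids, world_size=6)  # every rank is covered
```

With world_size=8, ranks 6 and 7 would be stubs and the check raises at setup time instead of hanging later in a collective.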
- bridge.training.megatron_mimo_parallel_utils.validate_data_loader_contract(
    infra: megatron.bridge.models.megatron_mimo.megatron_mimo_provider.MegatronMIMOInfra,
    global_batch_size: int,
    micro_batch_size: int,
    num_microbatches: int,
  )
Validate data loading constraints for multimodule training.
Checks:
- MIMO micro-batch size divisible by all module DP sizes
- Global batch size divisible by all module DP sizes
- num_microbatches * micro_batch_size == global_batch_size
- Parameters:
infra – MegatronMIMOInfra with module_to_grid_map.
global_batch_size – Total MIMO batch size per optimizer step.
micro_batch_size – Global MIMO batch size per microbatch before module-local DP slicing.
num_microbatches – Number of microbatches per iteration.
- Raises:
ValueError – If any constraint is violated.
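The three checks can be sketched as plain arithmetic over a hypothetical dict of per-module DP sizes (in the real function these come from infra.module_to_grid_map). Note that the global-batch divisibility check follows from the other two, but the sketch keeps all three to mirror the list above.

```python
from typing import Dict


def check_data_loader_contract(
    dp_sizes: Dict[str, int],
    global_batch_size: int,
    micro_batch_size: int,
    num_microbatches: int,
) -> None:
    for name, dp in dp_sizes.items():
        # Each module slices the global MIMO microbatch across its own DP ranks.
        if micro_batch_size % dp != 0:
            raise ValueError(f"micro_batch_size={micro_batch_size} not divisible by {name} DP={dp}")
        if global_batch_size % dp != 0:
            raise ValueError(f"global_batch_size={global_batch_size} not divisible by {name} DP={dp}")
    if num_microbatches * micro_batch_size != global_batch_size:
        raise ValueError(
            f"num_microbatches * micro_batch_size = {num_microbatches * micro_batch_size} "
            f"!= global_batch_size = {global_batch_size}"
        )


dp_sizes = {"images": 4, "language": 2}  # hypothetical module names
check_data_loader_contract(dp_sizes, global_batch_size=32, micro_batch_size=8, num_microbatches=4)
```

Here each encoder DP rank receives 8 / 4 = 2 samples per microbatch and each LLM DP rank receives 8 / 2 = 4, which is the encoder_dp > llm_dp fan-out that finalize_model_grads_multimodule compensates for.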