bridge.training.mimo_parallel_utils#
Multi-module process group utilities for MIMO heterogeneous parallel training.
This module provides utilities for building process group structures and handling gradients across modules with different parallelism configurations.
Key functions:

- unwrap_mimo_model(): Unwrap Float16Module/DDP wrappers to get the underlying MimoModel
- build_pg_collection_for_schedule(): Build a pg_collection compatible with the schedule
- multimodule_no_sync(): Context manager for gradient sync during microbatch accumulation
- finalize_model_grads_multimodule(): Finalize gradients for each module
- zero_grad_buffer_for_multimodule(): Reset gradient buffers for all modules
- validate_no_stub_ranks(): Ensure every rank participates in at least one module
- validate_data_loader_contract(): Validate data loading constraints
Module Contents#
Functions#
- `unwrap_mimo_model`: Unwrap Float16Module/DDP wrappers to get the underlying MimoModel.
- `is_current_rank_in_grid`: Check if current rank participates in the given grid.
- `get_module_to_grid_tuple`: Build list of (module, grid) tuples for all modules the current rank participates in.
- `build_pg_collection_for_schedule`: Build pg_collection compatible with schedule.
- `multimodule_no_sync`: Context manager to disable gradient sync for all modules during microbatch accumulation.
- `finalize_model_grads_multimodule`: Finalize gradients for each module using infra.pg_collections.
- `zero_grad_buffer_for_multimodule`: Reset gradient buffers for all DDP-wrapped modules.
- `validate_no_stub_ranks`: Ensure every rank participates in at least one module.
- `validate_data_loader_contract`: Validate data loading constraints for multimodule training.
Data#
API#
- bridge.training.mimo_parallel_utils.logger#
`getLogger(…)`
- bridge.training.mimo_parallel_utils.unwrap_mimo_model(model) → megatron.core.models.mimo.MimoModel#
Unwrap Float16Module/DDP wrappers to get the underlying MimoModel.
When using mixed precision (bf16/fp16), models are wrapped in Float16Module. This function unwraps the model to access MimoModel-specific attributes such as `role`, `mimo_config`, `language_model`, and `modality_submodules`.

- Parameters:
model – A MimoModel or a wrapped version (Float16Module, DDP).
- Returns:
The underlying MimoModel instance.
- Raises:
RuntimeError – If the model cannot be unwrapped to a MimoModel.
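As an illustration, the unwrapping loop can be sketched as below. `MimoModel` and `Wrapper` here are stand-ins (the real code checks against the actual Megatron classes); the sketch assumes each wrapper exposes its wrapped model as a `.module` attribute, which is how Float16Module and DDP wrap models.

```python
class MimoModel:
    """Stand-in for megatron.core.models.mimo.MimoModel."""


class Wrapper:
    """Stand-in for Float16Module / DDP: holds the wrapped model in .module."""

    def __init__(self, module):
        self.module = module


def unwrap_mimo_model(model, max_depth: int = 8):
    """Peel .module wrappers until a MimoModel is reached."""
    for _ in range(max_depth):
        if isinstance(model, MimoModel):
            return model
        if hasattr(model, "module"):
            model = model.module
        else:
            break
    raise RuntimeError(f"Could not unwrap {type(model).__name__} to a MimoModel")
```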
- bridge.training.mimo_parallel_utils.is_current_rank_in_grid(
- grid: megatron.core.hyper_comm_grid.HyperCommGrid,
- )#
Check if current rank participates in the given grid.
- Parameters:
grid – HyperCommGrid to check participation in.
- Returns:
True if current rank is within the grid’s rank range.
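The "rank range" wording suggests each grid covers a contiguous span of global ranks. A minimal sketch of that check, assuming illustrative `rank_offset`/`size` attributes on the grid (the real HyperCommGrid attributes may differ):

```python
from dataclasses import dataclass


@dataclass
class Grid:
    """Illustrative stand-in for HyperCommGrid's rank span."""
    rank_offset: int  # first global rank covered by this grid
    size: int         # number of ranks in the grid


def is_rank_in_grid(rank: int, grid: Grid) -> bool:
    """True if `rank` falls inside the grid's contiguous rank range."""
    return grid.rank_offset <= rank < grid.rank_offset + grid.size
```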
- bridge.training.mimo_parallel_utils.get_module_to_grid_tuple(
- mimo_model: megatron.core.models.mimo.MimoModel,
- infra: megatron.bridge.models.mimo.mimo_provider.MimoModelInfra,
- )#
Build list of (module, grid) tuples for all modules the current rank participates in.
- Parameters:
mimo_model – The MimoModel instance.
infra – MimoModelInfra containing module_to_grid_map.
- Returns:
List of (module, grid) tuples for modules this rank participates in.
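Conceptually this is a filter over `infra.module_to_grid_map`: keep the modules whose grid contains the current rank, paired with that grid. A hedged sketch, reusing the same illustrative contiguous-rank-range assumption for grids:

```python
from dataclasses import dataclass


@dataclass
class Grid:
    """Illustrative stand-in for HyperCommGrid's rank span."""
    rank_offset: int
    size: int


def get_module_to_grid_tuple(modules, module_to_grid_map, rank):
    """Pair each module this rank participates in with its grid, in map order."""
    return [
        (modules[name], grid)
        for name, grid in module_to_grid_map.items()
        if grid.rank_offset <= rank < grid.rank_offset + grid.size
    ]
```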
- bridge.training.mimo_parallel_utils.build_pg_collection_for_schedule(
- infra: megatron.bridge.models.mimo.mimo_provider.MimoModelInfra,
- )#
Build pg_collection compatible with schedule.
Primary: Use MultiModuleProcessGroupCollection if PR 3212 allows missing LLM PG on encoder-only ranks. Fallback: Return list of ProcessGroupCollections for participating modules.
IMPORTANT: Uses infra.pg_collections directly. Do NOT rebuild PGs.
- Parameters:
infra – MimoModelInfra with pg_collections for each module.
- Returns:
MultiModuleProcessGroupCollection or list of ProcessGroupCollections.
- bridge.training.mimo_parallel_utils.multimodule_no_sync(
- *,
- module_to_grid_tuple: List[Tuple],
- )#
Context manager to disable gradient sync for all modules during microbatch accumulation.
This function is designed to be used with functools.partial() to pre-bind the module_to_grid_tuple parameter, since the schedule calls no_sync_func() with no arguments.
- Parameters:
module_to_grid_tuple – List of (module, grid) tuples (keyword-only, bound via partial).
- Yields:
None - context manager for gradient sync control.
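One way to implement this, assuming each DDP-wrapped module exposes a `no_sync()` context manager (as torch DDP does), is to stack the per-module `no_sync()` contexts with `contextlib.ExitStack`. A sketch:

```python
from contextlib import ExitStack, contextmanager
from functools import partial


@contextmanager
def multimodule_no_sync(*, module_to_grid_tuple):
    """Enter every module's no_sync() for the duration of the with-block."""
    with ExitStack() as stack:
        for module, _grid in module_to_grid_tuple:
            stack.enter_context(module.no_sync())
        yield  # all modules have gradient sync disabled here
```

Since the schedule calls `no_sync_func()` with no arguments, the tuple list is pre-bound: `no_sync_func = partial(multimodule_no_sync, module_to_grid_tuple=tuples)`.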
- bridge.training.mimo_parallel_utils.finalize_model_grads_multimodule(
- model,
- num_tokens=None,
- pg_collection=None,
- force_all_reduce=None,
- *,
- infra: megatron.bridge.models.mimo.mimo_provider.MimoModelInfra,
- module_to_grid_tuple: List[Tuple],
- )#
Finalize gradients for each module using infra.pg_collections.
IMPORTANT: Signature matches schedule’s call pattern: config.finalize_model_grads_func([model], num_tokens, pg_collection, force_all_reduce=flag)
The `infra` and `module_to_grid_tuple` parameters are pre-bound via partial(). We ignore the schedule-provided `pg_collection` and use per-module PGs.

- Parameters:
model – Model list (passed by schedule, ignored - we use module_to_grid_tuple).
num_tokens – Token count for gradient scaling.
pg_collection – Schedule-provided PG (ignored - we use per-module PGs).
force_all_reduce – Schedule-provided flag (ignored - per-module PGs control sync).
infra – MimoModelInfra with per-module pg_collections (keyword-only, bound via partial).
module_to_grid_tuple – List of (module, grid) tuples (keyword-only, bound via partial).
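The signature contortion above (positional parameters matching the schedule's call, keyword-only parameters bound ahead of time) can be illustrated with a skeleton; the gradient reduction itself is elided, and `infra_stub`/`llm_module` are made-up names for illustration:

```python
from functools import partial


def finalize_model_grads_multimodule(model, num_tokens=None, pg_collection=None,
                                     force_all_reduce=None, *,
                                     infra, module_to_grid_tuple):
    """Skeleton: reduce grads per module using per-module PGs from `infra`.

    The schedule-provided pg_collection and force_all_reduce are accepted
    only to match the call pattern; they are ignored here.
    """
    finalized = []
    for module, _grid in module_to_grid_tuple:
        # Real code would all-reduce this module's grads over the process
        # groups that infra.pg_collections holds for it.
        finalized.append(module)
    return finalized


# Pre-bind the keyword-only parameters so the schedule's positional call works:
finalize_fn = partial(finalize_model_grads_multimodule,
                      infra="infra_stub",
                      module_to_grid_tuple=[("llm_module", "llm_grid")])
```

The schedule can then invoke `finalize_fn([model], num_tokens, pg_collection, force_all_reduce=flag)` without knowing about the extra parameters.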
- bridge.training.mimo_parallel_utils.zero_grad_buffer_for_multimodule(
- module_to_grid_tuple: List[Tuple],
- )#
Reset gradient buffers for all DDP-wrapped modules.
- Parameters:
module_to_grid_tuple – List of (module, grid) tuples.
- bridge.training.mimo_parallel_utils.validate_no_stub_ranks(
- module_to_grid_map: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid],
- world_size: int,
- )#
Ensure every rank participates in at least one module.
Stub ranks (ranks not participating in any module) are NOT supported. This validation runs at setup time to fail fast with a clear error.
- Parameters:
module_to_grid_map – Mapping of module names to their HyperCommGrids.
world_size – Total number of ranks in the world.
- Raises:
ValueError – If any rank doesn’t participate in a module.
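The fail-fast check amounts to verifying that the union of all grids' rank ranges covers `range(world_size)`. A sketch under the same illustrative contiguous-range assumption for grids:

```python
from dataclasses import dataclass


@dataclass
class Grid:
    """Illustrative stand-in for HyperCommGrid's rank span."""
    rank_offset: int
    size: int


def validate_no_stub_ranks(module_to_grid_map, world_size):
    """Raise if any global rank is covered by no module's grid."""
    covered = set()
    for grid in module_to_grid_map.values():
        covered.update(range(grid.rank_offset, grid.rank_offset + grid.size))
    stubs = sorted(set(range(world_size)) - covered)
    if stubs:
        raise ValueError(
            f"Stub ranks are not supported: ranks {stubs} "
            f"do not participate in any module")
```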
- bridge.training.mimo_parallel_utils.validate_data_loader_contract(
- infra: megatron.bridge.models.mimo.mimo_provider.MimoModelInfra,
- global_batch_size: int,
- micro_batch_size: int,
- num_microbatches: int,
- )#
Validate data loading constraints for multimodule training.
Checks:

- Global batch size is divisible by every module's DP size
- Micro-batch size is consistent with per-module sharding
- num_microbatches * micro_batch_size == global_batch_size / DP_size (for each module)
- Parameters:
infra – MimoModelInfra with module_to_grid_map.
global_batch_size – Total batch size across all data parallel ranks.
micro_batch_size – Batch size per microbatch.
num_microbatches – Number of microbatches per iteration.
- Raises:
ValueError – If any constraint is violated.
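The listed constraints reduce to two arithmetic checks per module. A sketch, taking the per-module DP sizes as a plain dict for illustration (the real function derives them from `infra.module_to_grid_map`):

```python
def validate_data_loader_contract(module_dp_sizes, global_batch_size,
                                  micro_batch_size, num_microbatches):
    """Raise ValueError if any module's batch-sharding arithmetic is violated."""
    for name, dp_size in module_dp_sizes.items():
        if global_batch_size % dp_size != 0:
            raise ValueError(
                f"global_batch_size={global_batch_size} is not divisible by "
                f"DP size {dp_size} of module '{name}'")
        per_dp_rank = global_batch_size // dp_size
        if num_microbatches * micro_batch_size != per_dp_rank:
            raise ValueError(
                f"Module '{name}': num_microbatches * micro_batch_size = "
                f"{num_microbatches * micro_batch_size} != "
                f"global_batch_size / DP_size = {per_dp_rank}")
```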