bridge.training.mimo_parallel_utils#

Multi-module process group utilities for MIMO heterogeneous parallel training.

This module provides utilities for building process group structures and handling gradients across modules with different parallelism configurations.

Key functions:

  • unwrap_mimo_model(): Unwrap Float16Module/DDP to get underlying MimoModel

  • build_pg_collection_for_schedule(): Build pg_collection compatible with schedule

  • multimodule_no_sync(): Context manager for gradient sync during microbatch accumulation

  • finalize_model_grads_multimodule(): Finalize gradients for each module

  • zero_grad_buffer_for_multimodule(): Reset gradient buffers for all modules

  • validate_no_stub_ranks(): Ensure every rank participates in at least one module

  • validate_data_loader_contract(): Validate data loading constraints

Module Contents#

Functions#

unwrap_mimo_model

Unwrap Float16Module/DDP wrappers to get the underlying MimoModel.

is_current_rank_in_grid

Check if current rank participates in the given grid.

get_module_to_grid_tuple

Build list of (module, grid) tuples for all modules the current rank participates in.

build_pg_collection_for_schedule

Build pg_collection compatible with schedule.

multimodule_no_sync

Context manager to disable gradient sync for all modules during microbatch accumulation.

finalize_model_grads_multimodule

Finalize gradients for each module using infra.pg_collections.

zero_grad_buffer_for_multimodule

Reset gradient buffers for all DDP-wrapped modules.

validate_no_stub_ranks

Ensure every rank participates in at least one module.

validate_data_loader_contract

Validate data loading constraints for multimodule training.

Data#

API#

bridge.training.mimo_parallel_utils.logger#

'getLogger(…)'

bridge.training.mimo_parallel_utils.unwrap_mimo_model(model) → megatron.core.models.mimo.MimoModel#

Unwrap Float16Module/DDP wrappers to get the underlying MimoModel.

When using mixed precision (bf16/fp16), models are wrapped in Float16Module. This function unwraps the model to access MimoModel-specific attributes like role, mimo_config, language_model, modality_submodules, etc.

Parameters:

model – A MimoModel or a wrapped version (Float16Module, DDP).

Returns:

The underlying MimoModel instance.

Raises:

RuntimeError – If the model cannot be unwrapped to a MimoModel.
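The unwrap logic can be sketched as repeatedly stripping the .module attribute until a MimoModel is reached. The classes below are stand-in dummies for illustration, not the real Megatron Float16Module/DDP/MimoModel:

```python
class MimoModel:
    # Stand-in for megatron.core.models.mimo.MimoModel.
    role = "language_model"


class Wrapper:
    # Stands in for Float16Module or DDP; both expose the inner
    # model via a .module attribute.
    def __init__(self, module):
        self.module = module


def unwrap_mimo_model(model):
    """Strip .module wrappers until a MimoModel is reached."""
    while not isinstance(model, MimoModel):
        if not hasattr(model, "module"):
            raise RuntimeError(
                f"Cannot unwrap {type(model).__name__} to MimoModel"
            )
        model = model.module
    return model


# e.g. DDP(Float16Module(model)) — two layers of wrapping:
wrapped = Wrapper(Wrapper(MimoModel()))
inner = unwrap_mimo_model(wrapped)
```

After unwrapping, MimoModel-specific attributes such as role become accessible.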

bridge.training.mimo_parallel_utils.is_current_rank_in_grid(
grid: megatron.core.hyper_comm_grid.HyperCommGrid,
) → bool#

Check if current rank participates in the given grid.

Parameters:

grid – HyperCommGrid to check participation in.

Returns:

True if current rank is within the grid’s rank range.
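A minimal sketch of the participation test, assuming a grid covers a contiguous rank range described by hypothetical rank_offset/size attributes (HyperCommGrid's real API may differ):

```python
from dataclasses import dataclass


@dataclass
class Grid:
    # Stand-in for HyperCommGrid: a contiguous block of ranks.
    rank_offset: int
    size: int


def is_current_rank_in_grid(grid, current_rank):
    # Rank participates iff it falls in [rank_offset, rank_offset + size).
    return grid.rank_offset <= current_rank < grid.rank_offset + grid.size


encoder_grid = Grid(rank_offset=0, size=4)   # ranks 0..3
llm_grid = Grid(rank_offset=4, size=12)      # ranks 4..15
```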

bridge.training.mimo_parallel_utils.get_module_to_grid_tuple(
mimo_model: megatron.core.models.mimo.MimoModel,
infra: megatron.bridge.models.mimo.mimo_provider.MimoModelInfra,
) → List[Tuple]#

Build list of (module, grid) tuples for all modules the current rank participates in.

Parameters:
  • mimo_model – The MimoModel instance.

  • infra – MimoModelInfra containing module_to_grid_map.

Returns:

List of (module, grid) tuples for modules this rank participates in.
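In spirit, this filters module_to_grid_map down to the entries the current rank participates in and pairs each grid with its module. The data shapes below are illustrative stand-ins (grids as inclusive rank-range tuples rather than HyperCommGrid objects):

```python
def get_module_to_grid_tuple(modules, module_to_grid_map, current_rank):
    # Keep only modules whose grid contains current_rank.
    tuples = []
    for name, grid in module_to_grid_map.items():
        lo, hi = grid  # (first_rank, last_rank), inclusive — stand-in for HyperCommGrid
        if lo <= current_rank <= hi:
            tuples.append((modules[name], grid))
    return tuples


modules = {"vision_encoder": "vision_module", "llm": "llm_module"}
grids = {"vision_encoder": (0, 3), "llm": (4, 15)}

# Rank 2 participates only in the vision encoder's grid:
pairs = get_module_to_grid_tuple(modules, grids, current_rank=2)
```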

bridge.training.mimo_parallel_utils.build_pg_collection_for_schedule(
infra: megatron.bridge.models.mimo.mimo_provider.MimoModelInfra,
)#

Build pg_collection compatible with schedule.

Primary path: use MultiModuleProcessGroupCollection, provided PR 3212 allows a missing LLM process group on encoder-only ranks. Fallback: return a list of ProcessGroupCollections for the modules this rank participates in.

IMPORTANT: Uses infra.pg_collections directly. Do NOT rebuild PGs.

Parameters:

infra – MimoModelInfra with pg_collections for each module.

Returns:

MultiModuleProcessGroupCollection or list of ProcessGroupCollections.

bridge.training.mimo_parallel_utils.multimodule_no_sync(
*,
module_to_grid_tuple: List[Tuple],
)#

Context manager to disable gradient sync for all modules during microbatch accumulation.

This function is designed to be used with functools.partial() to pre-bind the module_to_grid_tuple parameter, since the schedule calls no_sync_func() with no arguments.

Parameters:

module_to_grid_tuple – List of (module, grid) tuples (keyword-only, bound via partial).

Yields:

None. The context disables gradient synchronization across all participating modules on entry and restores it on exit.
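The shape of this pattern can be sketched with contextlib.ExitStack entering each module's no_sync() context (as torch DDP provides) and functools.partial pre-binding the keyword-only argument. FakeDDPModule is a stand-in, not the real Megatron DDP wrapper:

```python
import contextlib
import functools


class FakeDDPModule:
    """Stand-in that records whether grad sync is currently disabled."""

    def __init__(self):
        self.sync_disabled = False

    @contextlib.contextmanager
    def no_sync(self):
        # Mirrors torch DDP's no_sync(): disable on entry, restore on exit.
        self.sync_disabled = True
        try:
            yield
        finally:
            self.sync_disabled = False


@contextlib.contextmanager
def multimodule_no_sync(*, module_to_grid_tuple):
    # Enter every module's no_sync() context; ExitStack unwinds them all.
    with contextlib.ExitStack() as stack:
        for module, _grid in module_to_grid_tuple:
            stack.enter_context(module.no_sync())
        yield


mods = [FakeDDPModule(), FakeDDPModule()]
tuples = [(m, None) for m in mods]

# Pre-bind module_to_grid_tuple so the schedule can call no_sync_func():
no_sync_func = functools.partial(multimodule_no_sync, module_to_grid_tuple=tuples)

with no_sync_func():  # schedule calls it with no arguments
    inside = [m.sync_disabled for m in mods]
after = [m.sync_disabled for m in mods]
```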

bridge.training.mimo_parallel_utils.finalize_model_grads_multimodule(
model,
num_tokens=None,
pg_collection=None,
force_all_reduce=None,
*,
infra: megatron.bridge.models.mimo.mimo_provider.MimoModelInfra,
module_to_grid_tuple: List[Tuple],
)#

Finalize gradients for each module using infra.pg_collections.

IMPORTANT: Signature matches schedule’s call pattern: config.finalize_model_grads_func([model], num_tokens, pg_collection, force_all_reduce=flag)

The infra and module_to_grid_tuple parameters are pre-bound via partial(). We ignore the schedule-provided pg_collection and use per-module PGs.

Parameters:
  • model – Model list (passed by schedule, ignored - we use module_to_grid_tuple).

  • num_tokens – Token count for gradient scaling.

  • pg_collection – Schedule-provided PG (ignored - we use per-module PGs).

  • force_all_reduce – Schedule-provided flag (ignored - per-module PGs control sync).

  • infra – MimoModelInfra with per-module pg_collections (keyword-only, bound via partial).

  • module_to_grid_tuple – List of (module, grid) tuples (keyword-only, bound via partial).
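The pre-binding pattern can be illustrated as follows: the schedule calls the function positionally, so the keyword-only infra and module_to_grid_tuple must be bound up front with functools.partial. The body below is a placeholder that only iterates the bound tuples; the real function performs per-module gradient all-reduces over infra.pg_collections:

```python
import functools


def finalize_model_grads_multimodule(model, num_tokens=None, pg_collection=None,
                                     force_all_reduce=None, *,
                                     infra, module_to_grid_tuple):
    finalized = []
    for module, _grid in module_to_grid_tuple:
        # Real code would all-reduce this module's grads using its own
        # process groups from infra.pg_collections; the schedule-provided
        # pg_collection and force_all_reduce are deliberately ignored.
        finalized.append(module)
    return finalized


infra = object()  # stand-in for MimoModelInfra
tuples = [("vision_module", None), ("llm_module", None)]

finalize_func = functools.partial(finalize_model_grads_multimodule,
                                  infra=infra, module_to_grid_tuple=tuples)

# Schedule-style call: config.finalize_model_grads_func([model], num_tokens,
# pg_collection, force_all_reduce=flag)
result = finalize_func(["model"], 1024, None, force_all_reduce=False)
```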

bridge.training.mimo_parallel_utils.zero_grad_buffer_for_multimodule(
module_to_grid_tuple: List[Tuple],
)#

Reset gradient buffers for all DDP-wrapped modules.

Parameters:

module_to_grid_tuple – List of (module, grid) tuples.
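Conceptually this is a loop calling zero_grad_buffer() on each module this rank owns. FakeDDP is a stand-in for a Megatron DDP-wrapped module:

```python
class FakeDDP:
    # Stand-in recording whether its grad buffer was reset.
    def __init__(self):
        self.zeroed = False

    def zero_grad_buffer(self):
        self.zeroed = True


def zero_grad_buffer_for_multimodule(module_to_grid_tuple):
    # Reset the gradient buffer of every participating module.
    for module, _grid in module_to_grid_tuple:
        module.zero_grad_buffer()


mods = [FakeDDP(), FakeDDP()]
zero_grad_buffer_for_multimodule([(m, None) for m in mods])
```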

bridge.training.mimo_parallel_utils.validate_no_stub_ranks(
module_to_grid_map: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid],
world_size: int,
)#

Ensure every rank participates in at least one module.

Stub ranks (ranks not participating in any module) are NOT supported. This validation runs at setup time to fail fast with a clear error.

Parameters:
  • module_to_grid_map – Mapping of module names to their HyperCommGrids.

  • world_size – Total number of ranks in the world.

Raises:

ValueError – If any rank doesn’t participate in a module.
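The fail-fast check amounts to unioning the rank ranges of all grids and rejecting any configuration that leaves a rank uncovered. Grids are represented here as inclusive (first_rank, last_rank) tuples for illustration:

```python
def validate_no_stub_ranks(module_to_grid_map, world_size):
    # Collect every rank covered by at least one module's grid.
    covered = set()
    for _name, (lo, hi) in module_to_grid_map.items():
        covered.update(range(lo, hi + 1))
    stubs = sorted(set(range(world_size)) - covered)
    if stubs:
        raise ValueError(
            f"Stub ranks not supported: ranks {stubs} participate in no module"
        )


# Full coverage of 16 ranks — passes:
validate_no_stub_ranks({"vision": (0, 3), "llm": (4, 15)}, world_size=16)

# Ranks 14-15 uncovered — fails fast:
try:
    validate_no_stub_ranks({"vision": (0, 3), "llm": (4, 13)}, world_size=16)
    failed = False
except ValueError:
    failed = True
```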

bridge.training.mimo_parallel_utils.validate_data_loader_contract(
infra: megatron.bridge.models.mimo.mimo_provider.MimoModelInfra,
global_batch_size: int,
micro_batch_size: int,
num_microbatches: int,
)#

Validate data loading constraints for multimodule training.

Checks:

  • Global batch size divisible by all module DP sizes

  • Micro-batch size consistent with per-module sharding

  • num_microbatches * micro_batch_size == global_batch_size / DP_size (per module)

Parameters:
  • infra – MimoModelInfra with module_to_grid_map.

  • global_batch_size – Total batch size across all data parallel ranks.

  • micro_batch_size – Batch size per microbatch.

  • num_microbatches – Number of microbatches per iteration.

Raises:

ValueError – If any constraint is violated.
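The checks above reduce to per-module arithmetic on batch sizes. In this sketch the per-module data-parallel sizes are passed directly as a plain dict; the real function derives them from each module's HyperCommGrid:

```python
def validate_data_loader_contract(module_dp_sizes, global_batch_size,
                                  micro_batch_size, num_microbatches):
    for name, dp_size in module_dp_sizes.items():
        # Global batch size must shard evenly across this module's DP ranks.
        if global_batch_size % dp_size != 0:
            raise ValueError(
                f"{name}: global_batch_size {global_batch_size} "
                f"not divisible by DP size {dp_size}"
            )
        per_rank = global_batch_size // dp_size
        # Microbatch accumulation must reproduce the per-rank share exactly.
        if num_microbatches * micro_batch_size != per_rank:
            raise ValueError(
                f"{name}: num_microbatches * micro_batch_size "
                f"({num_microbatches * micro_batch_size}) != "
                f"global_batch_size / DP size ({per_rank})"
            )


# 256 = 8 microbatches * micro-batch 4 on each of 8 DP replicas — passes:
validate_data_loader_contract({"llm": 8}, 256, 4, 8)

# 4 microbatches * 4 = 16 != 256 / 8 = 32 — violates the contract:
try:
    validate_data_loader_contract({"llm": 8}, 256, 4, 4)
    violated = False
except ValueError:
    violated = True
```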