bridge.models.mimo.mimo_provider#
MIMO Model Provider for heterogeneous multi-module training.
This module provides MimoModelProvider, which integrates with the standard ModelProviderMixin interface to enable MIMO models in the training loop.
Key differences from standard providers:
- Uses HyperCommGrids for heterogeneous per-module parallelism
- Has a separate build_infra() method for infrastructure metadata
- Overrides provide_distributed_model() for custom DDP handling
Module Contents#
Classes#
| Class | Description |
|---|---|
| MimoModelInfra | MIMO infrastructure metadata (separate from model). |
| MimoModelProvider | MIMO provider with heterogeneous parallelism support. |
API#
- class bridge.models.mimo.mimo_provider.MimoModelInfra#
MIMO infrastructure metadata (separate from model).
This dataclass contains the parallelism infrastructure that MIMO builds, separated from the model itself to maintain the standard provide() contract.
Attributes:
- module_to_grid_map: Mapping of module names to their HyperCommGrids.
- topology: DAG of module data flow (module_name -> list of downstream modules).
- pg_collections: Mapping of module names to ProcessGroupCollections. None for modules this rank doesn't participate in.
- participating_modules: List of module names this rank participates in.
- module_to_grid_map: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid]#
None
- topology: Dict[str, List[str]]#
None
- pg_collections: Dict[str, Optional[megatron.core.process_groups_config.ProcessGroupCollection]]#
None
- participating_modules: List[str]#
None
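For orientation, the shape of this dataclass can be sketched with placeholder types (plain strings stand in for HyperCommGrid and ProcessGroupCollection objects; the field names match the attributes above, but everything else in this sketch is illustrative):

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class MimoModelInfraSketch:
    """Simplified stand-in for MimoModelInfra; the real fields hold
    HyperCommGrid and ProcessGroupCollection objects."""
    module_to_grid_map: Dict[str, object] = field(default_factory=dict)
    topology: Dict[str, List[str]] = field(default_factory=dict)
    pg_collections: Dict[str, Optional[object]] = field(default_factory=dict)
    participating_modules: List[str] = field(default_factory=list)


infra = MimoModelInfraSketch(
    module_to_grid_map={"llm": "grid-llm", "clip_encoder": "grid-clip"},
    topology={"clip_encoder": ["llm"]},  # encoder output flows into the LLM
    # None marks a module this rank does not participate in
    pg_collections={"llm": "pgs-llm", "clip_encoder": None},
    participating_modules=["llm"],
)

# A rank only builds and wraps the modules it participates in.
local_modules = [m for m, pgs in infra.pg_collections.items() if pgs is not None]
```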
- class bridge.models.mimo.mimo_provider.MimoModelProvider#
Bases: megatron.bridge.models.model_provider.ModelProviderMixin[megatron.core.models.mimo.MimoModel]

MIMO provider with heterogeneous parallelism support.
Integrates with the standard training loop via provide_distributed_model(). Use build_infra() to access MIMO-specific infrastructure (grids, topology, pg_collections).
This provider handles:
- HyperCommGrid creation per module (heterogeneous parallelism)
- ProcessGroupCollection extraction from grids
- pg_collection injection into specs
- Rank participation checking
- Freezing logic
Per-Encoder Parallelism: To use different parallelism for each encoder, treat each encoder as a separate module in both `modality_submodules_spec` and `mimo_parallelism_config`.

Example:

```python
mimo_parallelism_config = MimoParallelismConfig(
    module_parallelisms={
        "llm": ModuleParallelismConfig(tensor_model_parallel_size=8),
        "clip_encoder": ModuleParallelismConfig(tensor_model_parallel_size=2),
    }
)
provider = MimoModelProvider(
    language_model_spec=gpt_spec,
    modality_submodules_spec={"clip_encoder": clip_spec},
    mimo_parallelism_config=mimo_parallelism_config,
)
```
For training loop integration:

```python
model = provider.provide_distributed_model(ddp_config=ddp_config)
```

Or for manual usage:

```python
model = provider.provide()
infra = provider.build_infra()
```
- language_model_spec: megatron.core.transformer.spec_utils.ModuleSpec#
None
- modality_submodules_spec: Dict[str, megatron.core.transformer.spec_utils.ModuleSpec]#
field(...)
- special_token_ids: Dict[str, int]#
field(...)
- mimo_parallelism_config: Optional[megatron.bridge.models.mimo.mimo_config.MimoParallelismConfig]#
None
- _cached_infra: Optional[bridge.models.mimo.mimo_provider.MimoModelInfra]#
field(...)
- freeze_language_model: bool#
False
- freeze_modality_encoders: Dict[str, bool]#
field(...)
- freeze_modality_projections: Dict[str, bool]#
field(...)
- fp16: bool#
False
- bf16: bool#
True
- use_cpu_initialization: bool#
False
- init_model_with_meta_device: bool#
False
- virtual_pipeline_model_parallel_size: Optional[int]#
None
- property tensor_model_parallel_size: int#
Return LLM’s tensor parallel size for compatibility with standard code paths.
- property pipeline_model_parallel_size: int#
Return LLM’s pipeline parallel size for compatibility with standard code paths.
- property context_parallel_size: int#
Return LLM’s context parallel size for compatibility with standard code paths.
- build_infra() → bridge.models.mimo.mimo_provider.MimoModelInfra#
Build MIMO parallelism infrastructure.
This method builds HyperCommGrids, ProcessGroupCollections, and topology for MIMO’s heterogeneous parallelism. It does not mutate provider state. Use get_or_build_infra() when cached reuse is desired.
Can be called before or after provide(). Call finalize() first to validate the parallelism configuration.
- Returns:
MimoModelInfra containing grids, topology, pg_collections, and the list of modules this rank participates in.
- get_or_build_infra() → bridge.models.mimo.mimo_provider.MimoModelInfra#
Return cached MIMO infrastructure, building it once if needed.
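The caching contract can be sketched as follows (a simplified stand-in, not the actual implementation; the real provider stores the result in `_cached_infra` and builds real grids and process groups):

```python
class ProviderSketch:
    """Minimal sketch of the get_or_build_infra() caching pattern."""

    def __init__(self):
        self._cached_infra = None
        self.build_count = 0  # instrumentation for this sketch only

    def build_infra(self):
        # The real method builds HyperCommGrids, topology, and
        # pg_collections here without mutating provider state.
        self.build_count += 1
        return {"participating_modules": ["llm"]}

    def get_or_build_infra(self):
        # Build once, then return the same object on every later call.
        if self._cached_infra is None:
            self._cached_infra = self.build_infra()
        return self._cached_infra


p = ProviderSketch()
first = p.get_or_build_infra()
second = p.get_or_build_infra()
```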
- _get_pg_collections_from_grids(
- grids: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid],
- )#
Get ProcessGroupCollections from HyperCommGrids.
Creates all standard process groups plus embedding groups for PP > 1. Returns None for modules this rank doesn’t participate in.
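The participation logic can be illustrated with sets of ranks standing in for HyperCommGrids (`pg_collections_from_grids` here is a hypothetical stand-in; the real method creates ProcessGroupCollection objects from grid process groups):

```python
from typing import Dict, Optional, Set


def pg_collections_from_grids(
    grids: Dict[str, Set[int]], rank: int
) -> Dict[str, Optional[str]]:
    """Sketch of the per-module extraction: return a process-group
    placeholder for modules whose grid contains this rank, else None."""
    return {
        name: (f"pgs:{name}" if rank in ranks else None)
        for name, ranks in grids.items()
    }


# Rank 3 is part of the LLM grid but not the CLIP encoder grid.
grids = {"llm": {0, 1, 2, 3}, "clip_encoder": {0, 1}}
cols = pg_collections_from_grids(grids, rank=3)
```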
- _inject_pg_collection_into_language_spec(
- spec: megatron.core.transformer.spec_utils.ModuleSpec,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
- )#
Deep copy language model spec and inject pg_collection into params.
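The copy-and-inject pattern can be sketched with a simplified spec type (`SpecSketch` is a hypothetical stand-in for ModuleSpec; the key point is that the caller's spec is left untouched):

```python
import copy
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class SpecSketch:
    """Stand-in for ModuleSpec: a module identifier plus constructor params."""
    module: str
    params: Dict[str, Any] = field(default_factory=dict)


def inject_pg_collection(spec: SpecSketch, pg_collection: Any) -> SpecSketch:
    # Deep copy first so the original spec is not mutated, then inject
    # the process groups into the constructor params.
    new_spec = copy.deepcopy(spec)
    new_spec.params["pg_collection"] = pg_collection
    return new_spec


original = SpecSketch(module="GPTModel", params={"num_layers": 4})
injected = inject_pg_collection(original, "pgs:llm")
```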
- _inject_pg_collection_into_modality_spec(
- spec: megatron.core.transformer.spec_utils.ModuleSpec,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
- )#
Inject pg_collection into encoder specs within a modality submodule.
- provide(
- pre_process: Optional[bool] = None,
- post_process: Optional[bool] = None,
- vp_stage: Optional[int] = None,
- )#
Build and return the MimoModel instance.
This method follows the standard ModelProviderMixin.provide() contract, returning only the model instance. For infrastructure metadata (grids, topology, pg_collections), use build_infra() separately.
- Parameters:
pre_process – Unused for MIMO (accepted for API compatibility).
post_process – Unused for MIMO (accepted for API compatibility).
vp_stage – Unused for MIMO (accepted for API compatibility).
- Returns:
MimoModel instance.
Note:
Device/dtype handling is done by provide_distributed_model(), consistent with other providers. This method returns a CPU model.
- Raises:
ValueError – If this rank doesn’t participate in any module (indicates invalid parallelism configuration).
- provide_distributed_model(
- ddp_config: Optional[megatron.core.distributed.DistributedDataParallelConfig] = None,
- model_type=None,
- overlap_param_gather_with_optimizer_step: bool = False,
- fp16: Optional[bool] = None,
- bf16: Optional[bool] = None,
- use_megatron_fsdp: bool = False,
- use_torch_fsdp2: bool = False,
- wrap_with_ddp: bool = True,
- data_parallel_random_init: bool = False,
- use_cpu_initialization: Optional[bool] = False,
- init_model_with_meta_device: Optional[bool] = None,
- pre_wrap_hook: Optional[Union[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]], List[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]]]]] = None,
- post_wrap_hook: Optional[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]]] = None,
- mixed_precision_wrapper: Optional[Callable] = None,
- )#
Build MIMO model with heterogeneous parallelism and DDP wrapping.
This overrides the standard ModelProviderMixin implementation because MIMO:
Uses per-module HyperCommGrids instead of global mpu
Has different pg_collections per module
May have ranks that don’t participate in all modules
Requires per-submodule DDP wrapping for correct gradient sync
The method:
Calls finalize() to validate parallelism config
Calls build_infra() to create grids and pg_collections
Calls provide() to build the model
Applies pre-wrap hooks
Moves to device
Wraps each submodule with DDP using its own pg_collection
Applies mixed precision (Float16Module)
Applies post-wrap hooks
- Parameters:
ddp_config – Configuration for distributed data parallel.
model_type – Type of model (unused for MIMO, accepted for compatibility).
overlap_param_gather_with_optimizer_step – Whether to overlap param gathering.
fp16 – Override FP16 setting.
bf16 – Override BF16 setting.
use_megatron_fsdp – Use Megatron’s Fully Sharded Data Parallel.
use_torch_fsdp2 – Use PyTorch FSDP2.
wrap_with_ddp – Whether to wrap model with DDP.
data_parallel_random_init – Initialize parameters randomly across DP ranks.
use_cpu_initialization – Initialize model on CPU.
init_model_with_meta_device – Initialize model on meta device.
pre_wrap_hook – Callable(s) to modify model before wrapping.
post_wrap_hook – Callable to modify model after wrapping.
mixed_precision_wrapper – Wrapper for mixed precision (e.g., Float16Module).
- Returns:
List containing the wrapped MimoModel.
- Raises:
ValueError – If this rank doesn’t participate in any module (indicates invalid parallelism configuration).
- _resolve_hooks(
- pre_wrap_hook: Optional[Union[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]], List[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]]]]],
- )#
Resolve pre-wrap hooks to a single callable.
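One plausible way to implement this normalization, sketched with simplified types (the real hooks operate on lists of MegatronModule; this is an illustrative assumption, not the actual code):

```python
from typing import Callable, List, Optional, Union

Hook = Callable[[List[object]], List[object]]


def resolve_hooks(pre_wrap_hook: Optional[Union[Hook, List[Hook]]]) -> Optional[Hook]:
    """Collapse a hook, a list of hooks, or None into a single callable
    that applies the hooks in order."""
    if pre_wrap_hook is None:
        return None
    hooks = pre_wrap_hook if isinstance(pre_wrap_hook, list) else [pre_wrap_hook]

    def combined(model: List[object]) -> List[object]:
        for hook in hooks:
            model = hook(model)
        return model

    return combined


# Two toy hooks: append a marker, then upper-case every entry.
add_tag = lambda m: m + ["tagged"]
upper = lambda m: [str(x).upper() for x in m]
hook = resolve_hooks([add_tag, upper])
result = hook(["model"])  # applies add_tag first, then upper
```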
- initialize_model_parallel(
- seed: Optional[int] = None,
- seed_kwargs: Optional[dict] = None,
- **model_parallel_kwargs,
- )#
MIMO uses its own parallelism via MimoParallelismConfig.
This method is a no-op for MIMO. Parallelism is set up in build_infra() using HyperCommGrids, not global mpu state.
Note:
Call finalize() to validate the parallelism configuration, then build_infra() to create the HyperCommGrids.
- _apply_freezing(model: megatron.core.models.mimo.MimoModel) → None#
Apply freezing based on configuration.
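A minimal sketch of what such a freezing pass could look like, with a hypothetical ParamSketch standing in for model parameters and a simple module map instead of a real MimoModel (the actual method walks the model's submodules):

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ParamSketch:
    """Stand-in for a model parameter with a trainability flag."""
    requires_grad: bool = True


def apply_freezing(
    modules: Dict[str, List[ParamSketch]],
    freeze_language_model: bool,
    freeze_modality_encoders: Dict[str, bool],
) -> None:
    """Disable gradients on submodules marked frozen in the config."""
    if freeze_language_model:
        for p in modules["language_model"]:
            p.requires_grad = False
    for name, frozen in freeze_modality_encoders.items():
        if frozen:
            for p in modules[name]:
                p.requires_grad = False


modules = {
    "language_model": [ParamSketch(), ParamSketch()],
    "clip_encoder": [ParamSketch()],
}
# Freeze the LLM, keep the CLIP encoder trainable.
apply_freezing(modules, freeze_language_model=True,
               freeze_modality_encoders={"clip_encoder": False})
```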
- finalize() → None#
Finalize MIMO parallelism configuration.
This validates the parallelism config and should be called before build_infra() or provide(). It is called automatically by provide_distributed_model().
- Raises:
ValueError – If any rank doesn’t participate in at least one module. This indicates the parallelism configuration doesn’t cover all ranks in the world (validated by MimoParallelismConfig.finalize()).