bridge.models.mimo.mimo_provider#
MIMO Model Provider for heterogeneous multi-module training.
This module provides MimoModelProvider, which integrates with the standard ModelProviderMixin interface to enable MIMO models in the training loop.
Key differences from standard providers:

- Uses HyperCommGrids for heterogeneous per-module parallelism
- Has a separate build_infra() method for infrastructure metadata
- Overrides provide_distributed_model() for custom DDP handling
Module Contents#
Classes#
- MimoModelInfra: MIMO infrastructure metadata (separate from model).
- MimoModelProvider: MIMO provider with heterogeneous parallelism support.
API#
- class bridge.models.mimo.mimo_provider.MimoModelInfra#
MIMO infrastructure metadata (separate from model).
This dataclass contains the parallelism infrastructure that MIMO builds, separated from the model itself to maintain the standard provide() contract.
.. attribute:: module_to_grid_map
Mapping of module names to their HyperCommGrids.
.. attribute:: topology
DAG of module data flow (module_name -> list of downstream modules).
.. attribute:: pg_collections
Mapping of module names to ProcessGroupCollections. None for modules this rank doesn’t participate in.
.. attribute:: participating_modules
List of module names this rank participates in.
- module_to_grid_map: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid]#
None
- topology: Dict[str, List[str]]#
None
- pg_collections: Dict[str, Optional[megatron.core.process_groups_config.ProcessGroupCollection]]#
None
- participating_modules: List[str]#
None
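As a hedged illustration of how these fields relate (module names and values below are hypothetical, using plain Python objects in place of real HyperCommGrids and ProcessGroupCollections): on a given rank, pg_collections holds None for modules whose grid excludes that rank, and participating_modules is exactly the set of modules with a non-None entry.

```python
from typing import Dict, List, Optional

# Hypothetical per-rank view: a pg_collections entry is None when
# this rank is outside that module's grid.
pg_collections: Dict[str, Optional[object]] = {
    "llm": object(),       # this rank participates in the LLM grid
    "clip_encoder": None,  # but not in the CLIP encoder grid
}

# topology is a DAG of module data flow: module_name -> downstream modules.
topology: Dict[str, List[str]] = {"clip_encoder": ["llm"], "llm": []}

# participating_modules lists the modules with a non-None collection.
participating_modules = [name for name, pg in pg_collections.items() if pg is not None]
```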
- class bridge.models.mimo.mimo_provider.MimoModelProvider#
Bases: megatron.bridge.models.model_provider.ModelProviderMixin[megatron.core.models.mimo.MimoModel]

MIMO provider with heterogeneous parallelism support.
Integrates with the standard training loop via provide_distributed_model(). Use build_infra() to access MIMO-specific infrastructure (grids, topology, pg_collections).
This provider handles:

- HyperCommGrid creation per module (heterogeneous parallelism)
- ProcessGroupCollection extraction from grids
- pg_collection injection into specs
- Rank participation checking
- Freezing logic
Per-Encoder Parallelism: To use different parallelism for each encoder, treat each encoder as a separate module in both `modality_submodules_spec` and `mimo_parallelism_config`.

Example:

```python
mimo_parallelism_config = MimoParallelismConfig(
    module_parallelisms={
        "llm": ModuleParallelismConfig(tensor_model_parallel_size=8),
        "clip_encoder": ModuleParallelismConfig(tensor_model_parallel_size=2),
    }
)
provider = MimoModelProvider(
    language_model_spec=gpt_spec,
    modality_submodules_spec={"clip_encoder": clip_spec},
    mimo_parallelism_config=mimo_parallelism_config,
)
```
For training loop integration:

```python
model = provider.provide_distributed_model(ddp_config=ddp_config)
```

Or for manual usage:

```python
model = provider.provide()
infra = provider.build_infra()
```
- language_model_spec: megatron.core.transformer.spec_utils.ModuleSpec#
None
- modality_submodules_spec: Dict[str, megatron.core.transformer.spec_utils.ModuleSpec]#
‘field(…)’
- special_token_ids: Dict[str, int]#
‘field(…)’
- mimo_parallelism_config: Optional[megatron.bridge.models.mimo.mimo_config.MimoParallelismConfig]#
None
- freeze_language_model: bool#
False
- freeze_modality_encoders: Dict[str, bool]#
‘field(…)’
- freeze_modality_projections: Dict[str, bool]#
‘field(…)’
- fp16: bool#
False
- bf16: bool#
True
- use_cpu_initialization: bool#
False
- init_model_with_meta_device: bool#
False
- virtual_pipeline_model_parallel_size: Optional[int]#
None
- _cached_infra: Optional[bridge.models.mimo.mimo_provider.MimoModelInfra]#
‘field(…)’
- property tensor_model_parallel_size: int#
Return LLM’s tensor parallel size for compatibility with standard code paths.
- property pipeline_model_parallel_size: int#
Return LLM’s pipeline parallel size for compatibility with standard code paths.
- property context_parallel_size: int#
Return LLM’s context parallel size for compatibility with standard code paths.
- build_infra() bridge.models.mimo.mimo_provider.MimoModelInfra#
Build MIMO parallelism infrastructure.
This method builds HyperCommGrids, ProcessGroupCollections, and topology for MIMO’s heterogeneous parallelism. It is idempotent and does not mutate provider state (results are not cached).
Can be called before or after provide(). Call finalize() first to validate the parallelism configuration.
- Returns:
MimoModelInfra containing grids, topology, pg_collections, and the list of modules this rank participates in.
- _get_pg_collections_from_grids(
- grids: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid],
Get ProcessGroupCollections from HyperCommGrids.
Returns None for modules this rank doesn’t participate in.
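The participate-or-None shape can be sketched as follows. This is a hedged, simplified illustration: here each "grid" is reduced to the set of ranks it spans, whereas real HyperCommGrids carry process groups; the function name and signature are illustrative, not the actual API.

```python
def pg_collections_from_grids(grids, rank):
    # For each module, return a (placeholder) collection when this rank
    # is inside the module's grid, otherwise None.
    return {
        name: {"ranks": ranks} if rank in ranks else None
        for name, ranks in grids.items()
    }

# Rank 4 participates in the 8-way LLM grid but not the 2-way encoder grid.
grids = {"llm": range(0, 8), "clip_encoder": range(0, 2)}
collections = pg_collections_from_grids(grids, rank=4)
```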
- _inject_pg_collection_into_language_spec(
- spec: megatron.core.transformer.spec_utils.ModuleSpec,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
Deep copy language model spec and inject pg_collection into params.
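The copy-then-inject pattern might look like the sketch below. This is a hedged illustration using plain dicts in place of ModuleSpec objects; the real spec structure and the helper name are assumptions, not the actual implementation.

```python
import copy

def inject_pg_collection(spec, pg_collection):
    # Deep-copy so the caller's spec is left untouched, then add the
    # pg_collection to the module's constructor params.
    new_spec = copy.deepcopy(spec)
    new_spec.setdefault("params", {})["pg_collection"] = pg_collection
    return new_spec

spec = {"module": "GPTModel", "params": {"num_layers": 2}}
new_spec = inject_pg_collection(spec, pg_collection="pg0")
```

The deep copy matters: without it, injecting a rank-specific pg_collection would mutate a spec that may be shared across modules or ranks.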
- _inject_pg_collection_into_modality_spec(
- spec: megatron.core.transformer.spec_utils.ModuleSpec,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
Inject pg_collection into encoder specs within a modality submodule.
- provide(
- pre_process: Optional[bool] = None,
- post_process: Optional[bool] = None,
- vp_stage: Optional[int] = None,
Build and return the MimoModel instance.
This method follows the standard ModelProviderMixin.provide() contract, returning only the model instance. For infrastructure metadata (grids, topology, pg_collections), use build_infra() separately.
- Parameters:
pre_process – Unused for MIMO (accepted for API compatibility).
post_process – Unused for MIMO (accepted for API compatibility).
vp_stage – Unused for MIMO (accepted for API compatibility).
- Returns:
MimoModel instance.
.. note::
Device/dtype handling is done by provide_distributed_model(), consistent with other providers. This method returns a CPU model.
- Raises:
ValueError – If this rank doesn’t participate in any module (indicates invalid parallelism configuration).
- provide_distributed_model(
- ddp_config: Optional[megatron.core.distributed.DistributedDataParallelConfig] = None,
- model_type=None,
- overlap_param_gather_with_optimizer_step: bool = False,
- fp16: Optional[bool] = None,
- bf16: Optional[bool] = None,
- use_megatron_fsdp: bool = False,
- use_torch_fsdp2: bool = False,
- wrap_with_ddp: bool = True,
- data_parallel_random_init: bool = True,
- use_cpu_initialization: Optional[bool] = False,
- init_model_with_meta_device: Optional[bool] = None,
- pre_wrap_hook: Optional[Union[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]], List[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]]]]] = None,
- post_wrap_hook: Optional[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]]] = None,
- mixed_precision_wrapper: Optional[Callable] = None,
Build MIMO model with heterogeneous parallelism and DDP wrapping.
This overrides the standard ModelProviderMixin implementation because MIMO:

- Uses per-module HyperCommGrids instead of global mpu
- Has different pg_collections per module
- May have ranks that don't participate in all modules
The method:

1. Calls finalize() to validate the parallelism config
2. Calls build_infra() to create grids and pg_collections
3. Calls provide() to build the model
4. Applies pre-wrap hooks
5. Moves the model to device and applies mixed precision
6. Wraps with DDP using the LLM's pg_collection
7. Applies post-wrap hooks
- Parameters:
ddp_config – Configuration for distributed data parallel.
model_type – Type of model (unused for MIMO, accepted for compatibility).
overlap_param_gather_with_optimizer_step – Whether to overlap param gathering.
fp16 – Override FP16 setting.
bf16 – Override BF16 setting.
use_megatron_fsdp – Use Megatron’s Fully Sharded Data Parallel.
use_torch_fsdp2 – Use PyTorch FSDP2.
wrap_with_ddp – Whether to wrap model with DDP.
data_parallel_random_init – Initialize parameters randomly across DP ranks.
use_cpu_initialization – Initialize model on CPU.
init_model_with_meta_device – Initialize model on meta device.
pre_wrap_hook – Callable(s) to modify model before wrapping.
post_wrap_hook – Callable to modify model after wrapping.
mixed_precision_wrapper – Wrapper for mixed precision (e.g., Float16Module).
- Returns:
List containing the wrapped MimoModel.
- Raises:
ValueError – If this rank doesn’t participate in any module (indicates invalid parallelism configuration).
- _resolve_hooks(
- pre_wrap_hook: Optional[Union[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]], List[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]]]]],
Resolve pre-wrap hooks to a single callable.
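Given that pre_wrap_hook accepts either a single callable or a list of callables, the resolution plausibly composes a list left-to-right into one hook. The sketch below is an assumption about that behavior, not the actual implementation:

```python
def resolve_hooks(pre_wrap_hook):
    # A single callable passes through; a list is composed in order,
    # each hook receiving the previous hook's output model list.
    if pre_wrap_hook is None:
        return None
    if callable(pre_wrap_hook):
        return pre_wrap_hook
    def composed(model_list):
        for hook in pre_wrap_hook:
            model_list = hook(model_list)
        return model_list
    return composed

hook = resolve_hooks([lambda ms: ms + ["a"], lambda ms: ms + ["b"]])
```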
- _wrap_with_ddp(
- model_list: List[megatron.core.transformer.module.MegatronModule],
- ddp_config: megatron.core.distributed.DistributedDataParallelConfig,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
- data_parallel_random_init: bool,
- overlap_param_gather_with_optimizer_step: bool,
Wrap model with DistributedDataParallel.
- initialize_model_parallel(
- seed: Optional[int] = None,
- seed_kwargs: Optional[dict] = None,
- **model_parallel_kwargs,
MIMO uses its own parallelism via MimoParallelismConfig.
This method is a no-op for MIMO. Parallelism is set up in build_infra() using HyperCommGrids, not global mpu state.
.. note::
Call finalize() to validate the parallelism configuration, then build_infra() to create the HyperCommGrids.
- _apply_freezing(model: megatron.core.models.mimo.MimoModel) None#
Apply freezing based on configuration.
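The freezing rules (freeze_language_model plus the per-modality dicts) can be sketched without torch. This is a hedged simplification: Param stands in for torch parameters, and the module-name conventions are assumptions.

```python
class Param:
    def __init__(self):
        self.requires_grad = True

def apply_freezing(params_by_module, freeze_language_model, freeze_modality_encoders):
    # Set requires_grad=False on every parameter of a frozen module;
    # modules absent from the dict default to trainable.
    for name, params in params_by_module.items():
        frozen = (
            freeze_language_model if name == "llm"
            else freeze_modality_encoders.get(name, False)
        )
        if frozen:
            for p in params:
                p.requires_grad = False

params = {"llm": [Param()], "clip_encoder": [Param()]}
apply_freezing(params, freeze_language_model=False,
               freeze_modality_encoders={"clip_encoder": True})
```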
- finalize() None#
Finalize MIMO parallelism configuration.
This validates the parallelism config and should be called before build_infra() or provide(). It is called automatically by provide_distributed_model().
- Raises:
ValueError – If any rank doesn’t participate in at least one module. This indicates the parallelism configuration doesn’t cover all ranks in the world.
- _validate_all_ranks_participate(
- world_size: Optional[int],
Validate that all ranks participate in at least one module.
- Parameters:
world_size – Total number of ranks. If None, validation is skipped.
- Raises:
ValueError – If any rank doesn’t participate in a module.
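The coverage check described above amounts to comparing the union of all grids' ranks against the full world. A minimal sketch, assuming each module's grid is represented by its rank set (the helper name and representation are illustrative):

```python
def validate_all_ranks_participate(module_to_ranks, world_size):
    # Every rank in [0, world_size) must appear in at least one
    # module's grid, otherwise the parallelism config is invalid.
    if world_size is None:
        return  # validation skipped, as documented
    covered = set().union(*module_to_ranks.values())
    missing = set(range(world_size)) - covered
    if missing:
        raise ValueError(f"Ranks {sorted(missing)} participate in no module")

# A 4-rank world fully covered by the LLM grid passes validation.
validate_all_ranks_participate({"llm": {0, 1, 2, 3}, "clip": {0, 1}}, world_size=4)
```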