bridge.models.mimo.mimo_provider#

MIMO Model Provider for heterogeneous multi-module training.

This module provides MimoModelProvider, which integrates with the standard ModelProviderMixin interface to enable MIMO models in the training loop.

Key differences from standard providers:

  • Uses HyperCommGrids for heterogeneous per-module parallelism

  • Has separate build_infra() method for infrastructure metadata

  • Overrides provide_distributed_model() for custom DDP handling

Module Contents#

Classes#

MimoModelInfra

MIMO infrastructure metadata (separate from model).

MimoModelProvider

MIMO provider with heterogeneous parallelism support.

API#

class bridge.models.mimo.mimo_provider.MimoModelInfra#

MIMO infrastructure metadata (separate from model).

This dataclass contains the parallelism infrastructure that MIMO builds, separated from the model itself to maintain the standard provide() contract.

Attributes:
  • module_to_grid_map – Mapping of module names to their HyperCommGrids.

  • topology – DAG of module data flow (module_name -> list of downstream modules).

  • pg_collections – Mapping of module names to ProcessGroupCollections. None for modules this rank doesn’t participate in.

  • participating_modules – List of module names this rank participates in.

module_to_grid_map: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid]#

None

topology: Dict[str, List[str]]#

None

pg_collections: Dict[str, Optional[megatron.core.process_groups_config.ProcessGroupCollection]]#

None

participating_modules: List[str]#

None
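To make the shape of this metadata concrete, here is a minimal pure-Python sketch using a stand-in dataclass with plain values in place of the real HyperCommGrid and ProcessGroupCollection objects; all module names and values are hypothetical:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

# Stand-in with plain types where the real MimoModelInfra holds
# HyperCommGrid / ProcessGroupCollection objects.
@dataclass
class InfraSketch:
    module_to_grid_map: Dict[str, object]
    topology: Dict[str, List[str]]
    pg_collections: Dict[str, Optional[object]]
    participating_modules: List[str]

infra = InfraSketch(
    module_to_grid_map={"llm": "grid-llm", "clip_encoder": "grid-clip"},
    topology={"clip_encoder": ["llm"], "llm": []},
    # None marks a module this rank does not participate in.
    pg_collections={"llm": "pgs-llm", "clip_encoder": None},
    participating_modules=["llm"],
)

# A rank only ever touches modules it participates in.
active = [m for m in infra.participating_modules if infra.pg_collections[m] is not None]
print(active)  # ['llm']
```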

class bridge.models.mimo.mimo_provider.MimoModelProvider#

Bases: megatron.bridge.models.model_provider.ModelProviderMixin[megatron.core.models.mimo.MimoModel]

MIMO provider with heterogeneous parallelism support.

Integrates with the standard training loop via provide_distributed_model(). Use build_infra() to access MIMO-specific infrastructure (grids, topology, pg_collections).

This provider handles:

  • HyperCommGrid creation per module (heterogeneous parallelism)

  • ProcessGroupCollection extraction from grids

  • pg_collection injection into specs

  • Rank participation checking

  • Freezing logic

Per-Encoder Parallelism: To use different parallelism for each encoder, treat each encoder as a separate module in both modality_submodules_spec and mimo_parallelism_config:

Example:

mimo_parallelism_config = MimoParallelismConfig(
    module_parallelisms={
        "llm": ModuleParallelismConfig(tensor_model_parallel_size=8),
        "clip_encoder": ModuleParallelismConfig(tensor_model_parallel_size=2),
    },
)
provider = MimoModelProvider(
    language_model_spec=gpt_spec,
    modality_submodules_spec={"clip_encoder": clip_spec},
    mimo_parallelism_config=mimo_parallelism_config,
)

For training loop integration:

model = provider.provide_distributed_model(ddp_config=ddp_config)

Or for manual usage:

model = provider.provide()
infra = provider.build_infra()

language_model_spec: megatron.core.transformer.spec_utils.ModuleSpec#

None

modality_submodules_spec: Dict[str, megatron.core.transformer.spec_utils.ModuleSpec]#

‘field(…)’

special_token_ids: Dict[str, int]#

‘field(…)’

mimo_parallelism_config: Optional[megatron.bridge.models.mimo.mimo_config.MimoParallelismConfig]#

None

_cached_infra: Optional[bridge.models.mimo.mimo_provider.MimoModelInfra]#

‘field(…)’

freeze_language_model: bool#

False

freeze_modality_encoders: Dict[str, bool]#

‘field(…)’

freeze_modality_projections: Dict[str, bool]#

‘field(…)’

fp16: bool#

False

bf16: bool#

True

use_cpu_initialization: bool#

False

init_model_with_meta_device: bool#

False

virtual_pipeline_model_parallel_size: Optional[int]#

None

property tensor_model_parallel_size: int#

Return LLM’s tensor parallel size for compatibility with standard code paths.

property pipeline_model_parallel_size: int#

Return LLM’s pipeline parallel size for compatibility with standard code paths.

property context_parallel_size: int#

Return LLM’s context parallel size for compatibility with standard code paths.

build_infra() → bridge.models.mimo.mimo_provider.MimoModelInfra#

Build MIMO parallelism infrastructure.

This method builds HyperCommGrids, ProcessGroupCollections, and topology for MIMO’s heterogeneous parallelism. It does not mutate provider state. Use get_or_build_infra() when cached reuse is desired.

Can be called before or after provide(). Call finalize() first to validate the parallelism configuration.

Returns:

MimoModelInfra containing grids, topology, pg_collections, and the list of modules this rank participates in.
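The returned topology field is a DAG from module name to its downstream consumers. As a sketch of how such a dict can be interpreted (module names here are hypothetical, and this traversal is illustrative, not part of the MIMO API), the standard-library graphlib can order modules so that encoders run before the LLM that consumes their output:

```python
from graphlib import TopologicalSorter

# Hypothetical topology in the build_infra() format:
# module_name -> list of downstream modules that consume its output.
topology = {"clip_encoder": ["llm"], "audio_encoder": ["llm"], "llm": []}

# TopologicalSorter expects predecessor sets, so invert the downstream edges.
preds = {m: set() for m in topology}
for src, dsts in topology.items():
    for dst in dsts:
        preds[dst].add(src)

order = list(TopologicalSorter(preds).static_order())
# Both encoders come before the llm that consumes their embeddings.
print(order.index("llm") > order.index("clip_encoder"))  # True
```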

get_or_build_infra() → bridge.models.mimo.mimo_provider.MimoModelInfra#

Return cached MIMO infrastructure, building it once if needed.

_get_pg_collections_from_grids(
grids: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid],
) → Dict[str, Optional[megatron.core.process_groups_config.ProcessGroupCollection]]#

Get ProcessGroupCollections from HyperCommGrids.

Creates all standard process groups plus embedding groups for PP > 1. Returns None for modules this rank doesn’t participate in.

_inject_pg_collection_into_language_spec(
spec: megatron.core.transformer.spec_utils.ModuleSpec,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
) → megatron.core.transformer.spec_utils.ModuleSpec#

Deep copy language model spec and inject pg_collection into params.

_inject_pg_collection_into_modality_spec(
spec: megatron.core.transformer.spec_utils.ModuleSpec,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
) → megatron.core.transformer.spec_utils.ModuleSpec#

Inject pg_collection into encoder specs within a modality submodule.

provide(
pre_process: Optional[bool] = None,
post_process: Optional[bool] = None,
vp_stage: Optional[int] = None,
) → megatron.core.models.mimo.MimoModel#

Build and return the MimoModel instance.

This method follows the standard ModelProviderMixin.provide() contract, returning only the model instance. For infrastructure metadata (grids, topology, pg_collections), use build_infra() separately.

Parameters:
  • pre_process – Unused for MIMO (accepted for API compatibility).

  • post_process – Unused for MIMO (accepted for API compatibility).

  • vp_stage – Unused for MIMO (accepted for API compatibility).

Returns:

MimoModel instance.

Note:

Device/dtype handling is done by provide_distributed_model(), consistent with other providers. This method returns a CPU model.

Raises:

ValueError – If this rank doesn’t participate in any module (indicates invalid parallelism configuration).

provide_distributed_model(
ddp_config: Optional[megatron.core.distributed.DistributedDataParallelConfig] = None,
model_type=None,
overlap_param_gather_with_optimizer_step: bool = False,
fp16: Optional[bool] = None,
bf16: Optional[bool] = None,
use_megatron_fsdp: bool = False,
use_torch_fsdp2: bool = False,
wrap_with_ddp: bool = True,
data_parallel_random_init: bool = False,
use_cpu_initialization: Optional[bool] = False,
init_model_with_meta_device: Optional[bool] = None,
pre_wrap_hook: Optional[Union[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]], List[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]]]]] = None,
post_wrap_hook: Optional[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]]] = None,
mixed_precision_wrapper: Optional[Callable] = None,
) → List[megatron.core.transformer.module.MegatronModule]#

Build MIMO model with heterogeneous parallelism and DDP wrapping.

This overrides the standard ModelProviderMixin implementation because MIMO:

  • Uses per-module HyperCommGrids instead of global mpu

  • Has different pg_collections per module

  • May have ranks that don’t participate in all modules

  • Requires per-submodule DDP wrapping for correct gradient sync

The method:

  1. Calls finalize() to validate parallelism config

  2. Calls build_infra() to create grids and pg_collections

  3. Calls provide() to build the model

  4. Applies pre-wrap hooks

  5. Moves to device

  6. Wraps each submodule with DDP using its own pg_collection

  7. Applies mixed precision (Float16Module)

  8. Applies post-wrap hooks
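The ordering of these eight steps can be sketched as a plain-Python pipeline of stub stages. Every function below is a stand-in that only records its name; it illustrates the call order, not the real implementation:

```python
calls = []

def stage(name, ret=None):
    """Record the stage name and pass the model through (stand-in for real work)."""
    def run(model=None):
        calls.append(name)
        return ret if ret is not None else model
    return run

finalize = stage("finalize")
build_infra = stage("build_infra")
provide = stage("provide", ret="model")
pre_wrap_hook = stage("pre_wrap_hook")
to_device = stage("to_device")
wrap_ddp = stage("wrap_ddp")
float16 = stage("float16_module")
post_wrap_hook = stage("post_wrap_hook")

finalize()                     # 1. validate parallelism config
build_infra()                  # 2. grids and pg_collections
model = provide()              # 3. build the model
model = pre_wrap_hook(model)   # 4. pre-wrap hooks
model = to_device(model)       # 5. move to device
model = wrap_ddp(model)        # 6. per-submodule DDP wrapping
model = float16(model)         # 7. mixed precision (Float16Module)
model = post_wrap_hook(model)  # 8. post-wrap hooks
print(calls)
```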

Parameters:
  • ddp_config – Configuration for distributed data parallel.

  • model_type – Type of model (unused for MIMO, accepted for compatibility).

  • overlap_param_gather_with_optimizer_step – Whether to overlap param gathering.

  • fp16 – Override FP16 setting.

  • bf16 – Override BF16 setting.

  • use_megatron_fsdp – Use Megatron’s Fully Sharded Data Parallel.

  • use_torch_fsdp2 – Use PyTorch FSDP2.

  • wrap_with_ddp – Whether to wrap model with DDP.

  • data_parallel_random_init – Initialize parameters randomly across DP ranks.

  • use_cpu_initialization – Initialize model on CPU.

  • init_model_with_meta_device – Initialize model on meta device.

  • pre_wrap_hook – Callable(s) to modify model before wrapping.

  • post_wrap_hook – Callable to modify model after wrapping.

  • mixed_precision_wrapper – Wrapper for mixed precision (e.g., Float16Module).

Returns:

List containing the wrapped MimoModel.

Raises:

ValueError – If this rank doesn’t participate in any module (indicates invalid parallelism configuration).

_resolve_hooks(
pre_wrap_hook: Optional[Union[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]], List[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]]]]],
) → Optional[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]]]#

Resolve pre-wrap hooks to a single callable.
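One plausible way to fold a list of hooks into a single callable looks like the sketch below; the actual implementation may differ, and the two example hooks are purely hypothetical:

```python
from typing import Callable, List, Optional, Union

Hook = Callable[[List[object]], List[object]]

def resolve_hooks(pre_wrap_hook: Optional[Union[Hook, List[Hook]]]) -> Optional[Hook]:
    """Fold zero, one, or many hooks into a single chained callable."""
    if pre_wrap_hook is None:
        return None
    if callable(pre_wrap_hook):
        return pre_wrap_hook

    def chained(model: List[object]) -> List[object]:
        # Apply each hook in order, threading the model list through.
        for hook in pre_wrap_hook:
            model = hook(model)
        return model

    return chained

# Usage: two hypothetical hooks applied in list order.
add_a = lambda m: m + ["a"]
add_b = lambda m: m + ["b"]
print(resolve_hooks([add_a, add_b])([]))  # ['a', 'b']
```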

initialize_model_parallel(
seed: Optional[int] = None,
seed_kwargs: Optional[dict] = None,
**model_parallel_kwargs,
) → None#

MIMO uses its own parallelism via MimoParallelismConfig.

This method is a no-op for MIMO. Parallelism is set up in build_infra() using HyperCommGrids, not global mpu state.

Note:

Call finalize() to validate the parallelism configuration, then build_infra() to create the HyperCommGrids.

_apply_freezing(model: megatron.core.models.mimo.MimoModel) → None#

Apply freezing based on configuration.
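The freezing logic follows from the three config knobs on the provider (freeze_language_model, freeze_modality_encoders, freeze_modality_projections). A sketch of that pattern over minimal stand-in "modules" (dicts of parameter name to trainable flag, not the real MimoModel structure):

```python
# Each stand-in "module" maps parameter name -> trainable flag;
# flipping to False is analogous to p.requires_grad_(False).
model = {
    "language_model": {"w": True},
    "encoders": {"clip_encoder": {"w": True}},
    "projections": {"clip_encoder": {"w": True}},
}

freeze_language_model = True
freeze_modality_encoders = {"clip_encoder": True}
freeze_modality_projections = {"clip_encoder": False}

def freeze(params):
    for name in params:
        params[name] = False

if freeze_language_model:
    freeze(model["language_model"])
for name, do_freeze in freeze_modality_encoders.items():
    if do_freeze:
        freeze(model["encoders"][name])
for name, do_freeze in freeze_modality_projections.items():
    if do_freeze:
        freeze(model["projections"][name])

print(model["language_model"]["w"], model["projections"]["clip_encoder"]["w"])  # False True
```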

finalize() → None#

Finalize MIMO parallelism configuration.

This validates the parallelism config and should be called before build_infra() or provide(). It is called automatically by provide_distributed_model().

Raises:

ValueError – If any rank doesn’t participate in at least one module. This indicates the parallelism configuration doesn’t cover all ranks in the world (validated by MimoParallelismConfig.finalize()).
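The validated condition, that the per-module rank assignments together cover the whole world, reduces to a set-union check. A sketch with hypothetical rank assignments for an 8-rank world:

```python
# Hypothetical rank assignments per module in a world of 8 ranks.
world_size = 8
module_ranks = {
    "llm": set(range(0, 8)),           # e.g. TP=8 over all ranks
    "clip_encoder": set(range(0, 2)),  # e.g. TP=2 on the first two ranks
}

covered = set().union(*module_ranks.values())
uncovered = set(range(world_size)) - covered
if uncovered:
    # Mirrors the ValueError finalize() raises for an invalid config.
    raise ValueError(f"Ranks {sorted(uncovered)} participate in no module")
print(sorted(covered))  # [0, 1, 2, 3, 4, 5, 6, 7]
```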