bridge.models.mimo.mimo_provider#

MIMO Model Provider for heterogeneous multi-module training.

This module provides MimoModelProvider, which integrates with the standard ModelProviderMixin interface to enable MIMO models in the training loop.

Key differences from standard providers:

  • Uses HyperCommGrids for heterogeneous per-module parallelism

  • Has separate build_infra() method for infrastructure metadata

  • Overrides provide_distributed_model() for custom DDP handling

Module Contents#

Classes#

MimoModelInfra

MIMO infrastructure metadata (separate from model).

MimoModelProvider

MIMO provider with heterogeneous parallelism support.

API#

class bridge.models.mimo.mimo_provider.MimoModelInfra#

MIMO infrastructure metadata (separate from model).

This dataclass contains the parallelism infrastructure that MIMO builds, separated from the model itself to maintain the standard provide() contract.

Attributes:
  • module_to_grid_map – Mapping of module names to their HyperCommGrids.

  • topology – DAG of module data flow (module_name -> list of downstream modules).

  • pg_collections – Mapping of module names to ProcessGroupCollections. None for modules this rank doesn’t participate in.

  • participating_modules – List of module names this rank participates in.

module_to_grid_map: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid]#

None

topology: Dict[str, List[str]]#

None

pg_collections: Dict[str, Optional[megatron.core.process_groups_config.ProcessGroupCollection]]#

None

participating_modules: List[str]#

None
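As an illustration, the shape of this dataclass and a toy topology can be sketched in plain Python (the class and field names below mirror the documentation; the stand-in types and example values are assumptions, not the real megatron.core objects):

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

# Stand-ins for the megatron.core types; illustrative only.
class HyperCommGrid: ...
class ProcessGroupCollection: ...

@dataclass
class MimoModelInfraSketch:
    module_to_grid_map: Dict[str, HyperCommGrid]
    topology: Dict[str, List[str]]
    pg_collections: Dict[str, Optional[ProcessGroupCollection]]
    participating_modules: List[str]

# Toy topology: the CLIP encoder's outputs flow into the LLM.
infra = MimoModelInfraSketch(
    module_to_grid_map={},
    topology={"clip_encoder": ["llm"], "llm": []},
    pg_collections={"clip_encoder": None, "llm": None},  # None: rank not in module
    participating_modules=["llm"],
)
```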

class bridge.models.mimo.mimo_provider.MimoModelProvider#

Bases: megatron.bridge.models.model_provider.ModelProviderMixin[megatron.core.models.mimo.MimoModel]

MIMO provider with heterogeneous parallelism support.

Integrates with the standard training loop via provide_distributed_model(). Use build_infra() to access MIMO-specific infrastructure (grids, topology, pg_collections).

This provider handles:

  • HyperCommGrid creation per module (heterogeneous parallelism)

  • ProcessGroupCollection extraction from grids

  • pg_collection injection into specs

  • Rank participation checking

  • Freezing logic

Per-Encoder Parallelism: To use different parallelism for each encoder, treat each encoder as a separate module in both modality_submodules_spec and mimo_parallelism_config:

Example:

```python
mimo_parallelism_config = MimoParallelismConfig(
    module_parallelisms={
        "llm": ModuleParallelismConfig(tensor_model_parallel_size=8),
        "clip_encoder": ModuleParallelismConfig(tensor_model_parallel_size=2),
    }
)
provider = MimoModelProvider(
    language_model_spec=gpt_spec,
    modality_submodules_spec={"clip_encoder": clip_spec},
    mimo_parallelism_config=mimo_parallelism_config,
)
```

For training loop integration:

model = provider.provide_distributed_model(ddp_config=ddp_config)

Or for manual usage:

model = provider.provide()
infra = provider.build_infra()

language_model_spec: megatron.core.transformer.spec_utils.ModuleSpec#

None

modality_submodules_spec: Dict[str, megatron.core.transformer.spec_utils.ModuleSpec]#

'field(...)'

special_token_ids: Dict[str, int]#

'field(...)'

mimo_parallelism_config: Optional[megatron.bridge.models.mimo.mimo_config.MimoParallelismConfig]#

None

freeze_language_model: bool#

False

freeze_modality_encoders: Dict[str, bool]#

'field(...)'

freeze_modality_projections: Dict[str, bool]#

'field(...)'

fp16: bool#

False

bf16: bool#

True

use_cpu_initialization: bool#

False

init_model_with_meta_device: bool#

False

virtual_pipeline_model_parallel_size: Optional[int]#

None

_cached_infra: Optional[bridge.models.mimo.mimo_provider.MimoModelInfra]#

'field(...)'

property tensor_model_parallel_size: int#

Return LLM’s tensor parallel size for compatibility with standard code paths.

property pipeline_model_parallel_size: int#

Return LLM’s pipeline parallel size for compatibility with standard code paths.

property context_parallel_size: int#

Return LLM’s context parallel size for compatibility with standard code paths.

build_infra() → bridge.models.mimo.mimo_provider.MimoModelInfra#

Build MIMO parallelism infrastructure.

This method builds HyperCommGrids, ProcessGroupCollections, and topology for MIMO’s heterogeneous parallelism. It is idempotent and does not mutate provider state (results are not cached).

Can be called before or after provide(). Call finalize() first to validate the parallelism configuration.

Returns:

MimoModelInfra containing grids, topology, pg_collections, and the list of modules this rank participates in.

_get_pg_collections_from_grids(
grids: Dict[str, megatron.core.hyper_comm_grid.HyperCommGrid],
) → Dict[str, Optional[megatron.core.process_groups_config.ProcessGroupCollection]]#

Get ProcessGroupCollections from HyperCommGrids.

Returns None for modules this rank doesn’t participate in.
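The per-rank filtering this helper performs can be sketched with toy stand-ins (a grid reduced to the set of global ranks it spans; all names here are assumptions, not the real HyperCommGrid API):

```python
from typing import Dict, Optional

# Toy grid: just the set of global ranks it spans.
class ToyGrid:
    def __init__(self, ranks):
        self.ranks = set(ranks)

def pg_collections_from_grids(
    grids: Dict[str, ToyGrid], my_rank: int
) -> Dict[str, Optional[str]]:
    # Build a collection only for modules whose grid contains this rank;
    # other modules map to None, matching the documented contract.
    return {
        name: f"pg_collection[{name}]" if my_rank in grid.ranks else None
        for name, grid in grids.items()
    }

# Rank 4 is in the LLM's 8-rank grid but not the 2-rank encoder grid.
out = pg_collections_from_grids(
    {"llm": ToyGrid(range(8)), "clip_encoder": ToyGrid(range(2))}, my_rank=4
)
```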

_inject_pg_collection_into_language_spec(
spec: megatron.core.transformer.spec_utils.ModuleSpec,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
) → megatron.core.transformer.spec_utils.ModuleSpec#

Deep copy language model spec and inject pg_collection into params.

_inject_pg_collection_into_modality_spec(
spec: megatron.core.transformer.spec_utils.ModuleSpec,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
) → megatron.core.transformer.spec_utils.ModuleSpec#

Inject pg_collection into encoder specs within a modality submodule.

provide(
pre_process: Optional[bool] = None,
post_process: Optional[bool] = None,
vp_stage: Optional[int] = None,
) → megatron.core.models.mimo.MimoModel#

Build and return the MimoModel instance.

This method follows the standard ModelProviderMixin.provide() contract, returning only the model instance. For infrastructure metadata (grids, topology, pg_collections), use build_infra() separately.

Parameters:
  • pre_process – Unused for MIMO (accepted for API compatibility).

  • post_process – Unused for MIMO (accepted for API compatibility).

  • vp_stage – Unused for MIMO (accepted for API compatibility).

Returns:

MimoModel instance.

Note:

Device/dtype handling is done by provide_distributed_model(), consistent with other providers. This method returns a CPU model.

Raises:

ValueError – If this rank doesn’t participate in any module (indicates invalid parallelism configuration).

provide_distributed_model(
ddp_config: Optional[megatron.core.distributed.DistributedDataParallelConfig] = None,
model_type=None,
overlap_param_gather_with_optimizer_step: bool = False,
fp16: Optional[bool] = None,
bf16: Optional[bool] = None,
use_megatron_fsdp: bool = False,
use_torch_fsdp2: bool = False,
wrap_with_ddp: bool = True,
data_parallel_random_init: bool = True,
use_cpu_initialization: Optional[bool] = False,
init_model_with_meta_device: Optional[bool] = None,
pre_wrap_hook: Optional[Union[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]], List[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]]]]] = None,
post_wrap_hook: Optional[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]]] = None,
mixed_precision_wrapper: Optional[Callable] = None,
) → List[megatron.core.transformer.module.MegatronModule]#

Build MIMO model with heterogeneous parallelism and DDP wrapping.

This overrides the standard ModelProviderMixin implementation because MIMO:

  • Uses per-module HyperCommGrids instead of global mpu

  • Has different pg_collections per module

  • May have ranks that don’t participate in all modules

The method:

  1. Calls finalize() to validate parallelism config

  2. Calls build_infra() to create grids and pg_collections

  3. Calls provide() to build the model

  4. Applies pre-wrap hooks

  5. Moves to device and applies mixed precision

  6. Wraps with DDP using LLM’s pg_collection

  7. Applies post-wrap hooks

Parameters:
  • ddp_config – Configuration for distributed data parallel.

  • model_type – Type of model (unused for MIMO, accepted for compatibility).

  • overlap_param_gather_with_optimizer_step – Whether to overlap param gathering.

  • fp16 – Override FP16 setting.

  • bf16 – Override BF16 setting.

  • use_megatron_fsdp – Use Megatron’s Fully Sharded Data Parallel.

  • use_torch_fsdp2 – Use PyTorch FSDP2.

  • wrap_with_ddp – Whether to wrap model with DDP.

  • data_parallel_random_init – Initialize parameters randomly across DP ranks.

  • use_cpu_initialization – Initialize model on CPU.

  • init_model_with_meta_device – Initialize model on meta device.

  • pre_wrap_hook – Callable(s) to modify model before wrapping.

  • post_wrap_hook – Callable to modify model after wrapping.

  • mixed_precision_wrapper – Wrapper for mixed precision (e.g., Float16Module).

Returns:

List containing the wrapped MimoModel.

Raises:

ValueError – If this rank doesn’t participate in any module (indicates invalid parallelism configuration).
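The seven-step sequence above can be traced with stub objects; everything below is a stand-in used to show the call order, not the real megatron classes or the method's actual implementation:

```python
# Stub provider that records the documented call order.
class StubProvider:
    def __init__(self):
        self.calls = []
    def finalize(self):
        self.calls.append("finalize")        # 1. validate parallelism config
    def build_infra(self):
        self.calls.append("build_infra")     # 2. build grids + pg_collections
        return {"pg_collections": {"llm": "pg"}}
    def provide(self):
        self.calls.append("provide")         # 3. build the (CPU) model
        return "mimo_model"

def provide_distributed_model_sketch(provider, pre_wrap_hook=None, post_wrap_hook=None):
    provider.finalize()
    infra = provider.build_infra()
    model_list = [provider.provide()]
    if pre_wrap_hook:
        model_list = pre_wrap_hook(model_list)      # 4. pre-wrap hooks
    # 5. device move + mixed precision would happen here
    model_list = [f"DDP({m})" for m in model_list]  # 6. wrap with DDP (stubbed)
    if post_wrap_hook:
        model_list = post_wrap_hook(model_list)     # 7. post-wrap hooks
    return model_list

provider = StubProvider()
wrapped = provide_distributed_model_sketch(provider)
```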

_resolve_hooks(
pre_wrap_hook: Optional[Union[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]], List[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]]]]],
) → Optional[Callable[[List[megatron.core.transformer.module.MegatronModule]], List[megatron.core.transformer.module.MegatronModule]]]#

Resolve pre-wrap hooks to a single callable.
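One plausible resolution strategy, sketched here as a hypothetical helper (the real method operates on MegatronModule lists): a single callable passes through unchanged, while a list of hooks is chained left to right.

```python
def resolve_hooks(hooks):
    # None and a bare callable pass through; a list is collapsed into
    # one callable that applies each hook in order.
    if hooks is None or callable(hooks):
        return hooks
    def chained(model_list):
        for hook in hooks:
            model_list = hook(model_list)
        return model_list
    return chained

combined = resolve_hooks([lambda ms: ms + ["a"], lambda ms: ms + ["b"]])
```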

_wrap_with_ddp(
model_list: List[megatron.core.transformer.module.MegatronModule],
ddp_config: megatron.core.distributed.DistributedDataParallelConfig,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
data_parallel_random_init: bool,
overlap_param_gather_with_optimizer_step: bool,
) → List[megatron.core.transformer.module.MegatronModule]#

Wrap model with DistributedDataParallel.

initialize_model_parallel(
seed: Optional[int] = None,
seed_kwargs: Optional[dict] = None,
**model_parallel_kwargs,
) → None#

MIMO uses its own parallelism via MimoParallelismConfig.

This method is a no-op for MIMO. Parallelism is set up in build_infra() using HyperCommGrids, not global mpu state.

Note:

Call finalize() to validate the parallelism configuration, then build_infra() to create the HyperCommGrids.

_apply_freezing(model: megatron.core.models.mimo.MimoModel) → None#

Apply freezing based on configuration.
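The configuration-driven freezing (freeze_language_model, freeze_modality_encoders) can be sketched over a toy parameter layout; the dict-of-parameters layout and helper name below are assumptions, not the real method:

```python
# Toy parameter: just a requires_grad flag, like torch.nn.Parameter.
class ToyParam:
    def __init__(self):
        self.requires_grad = True

def apply_freezing_sketch(params_by_module, freeze_language_model, freeze_modality_encoders):
    # Freeze the LLM wholesale if requested.
    if freeze_language_model:
        for p in params_by_module.get("language_model", []):
            p.requires_grad = False
    # Freeze encoders individually, per the per-module flags.
    for name, should_freeze in freeze_modality_encoders.items():
        if should_freeze:
            for p in params_by_module.get(name, []):
                p.requires_grad = False

params = {"language_model": [ToyParam()], "clip_encoder": [ToyParam()]}
apply_freezing_sketch(params, freeze_language_model=True,
                      freeze_modality_encoders={"clip_encoder": False})
```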

finalize() → None#

Finalize MIMO parallelism configuration.

This validates the parallelism config and should be called before build_infra() or provide(). It is called automatically by provide_distributed_model().

Raises:

ValueError – If any rank doesn’t participate in at least one module. This indicates the parallelism configuration doesn’t cover all ranks in the world.

_validate_all_ranks_participate(
world_size: Optional[int],
) → None#

Validate that all ranks participate in at least one module.

Parameters:

world_size – Total number of ranks. If None, validation is skipped.

Raises:

ValueError – If any rank doesn’t participate in a module.
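The coverage check this validator performs can be sketched as follows, using a toy data layout (module name -> iterable of global ranks); the signature and structure are assumptions, not the real implementation:

```python
def validate_all_ranks_participate(ranks_by_module, world_size):
    # If the world size is unknown, validation is skipped,
    # per the documented contract.
    if world_size is None:
        return
    covered = set()
    for ranks in ranks_by_module.values():
        covered.update(ranks)
    missing = sorted(set(range(world_size)) - covered)
    if missing:
        raise ValueError(f"Ranks {missing} do not participate in any module")

# All 8 ranks are covered by the LLM grid: no error.
validate_all_ranks_participate({"llm": range(8)}, world_size=8)
```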