bridge.training.setup_mimo#

MIMO-specific setup for heterogeneous multi-module training.

This module provides the setup logic for MIMO training; it mirrors the standard setup.py but is adapted for per-module parallelism.

Key components:

  • setup_mimo(): MIMO-specific setup helper (analogous to setup())

  • _set_mimo_random_seeds(): Per-module TP/PP seed initialization

  • _update_mimo_model_config_funcs(): Model config hooks (analogous to _update_model_config_funcs)

  • MimoSetupOutput: Dataclass containing all setup outputs

Module Contents#

Classes#

MimoSetupOutput

Output from setup_mimo() containing all components needed for training.

Functions#

_set_mimo_random_seeds

Initialize random seeds with per-module TP/PP awareness.

_update_mimo_model_config_funcs

Set model config hooks for MIMO training.

setup_mimo

MIMO-specific setup helper.

Data#

API#

bridge.training.setup_mimo.logger#

'getLogger(…)'

bridge.training.setup_mimo._set_mimo_random_seeds(
cfg: megatron.bridge.training.config.ConfigContainer,
mimo_infra: megatron.bridge.models.mimo.mimo_provider.MimoModelInfra,
) → None#

Initialize random seeds with per-module TP/PP awareness.

Mirrors the standard path’s _set_random_seed() but derives TP/PP ranks from the per-module HyperCommGrids instead of global MPU state.

Must be called after build_infra() (so the per-module grids exist) and before provide_distributed_model() (weight initialization needs the seeded CUDA RNG tracker).
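For orientation, the ordering reads like the sketch below. Whether build_infra() and provide_distributed_model() hang off cfg.model is an assumption made for illustration; the only contract this docstring asserts is the relative order of the three calls.

```python
from megatron.bridge.training.setup_mimo import _set_mimo_random_seeds

# Sketch only: build_infra()/provide_distributed_model() are shown as
# provider methods for illustration; the documented contract is just
# the call order around _set_mimo_random_seeds().
mimo_infra = cfg.model.build_infra()            # per-module grids now exist
_set_mimo_random_seeds(cfg, mimo_infra)         # seeds the TP/PP-aware CUDA RNG tracker
model = cfg.model.provide_distributed_model()   # weight init uses the seeded tracker
```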

class bridge.training.setup_mimo.MimoSetupOutput#

Output from setup_mimo() containing all components needed for training.

Attributes:

  • model – MimoModel (distributed, DDP-wrapped).

  • mimo_infra – MimoModelInfra (grids, topology, pg_collections).

  • multimodule_pg_collection – PG collection for the pipeline schedule.

  • multimodule_communicator – MultiModulePipelineCommunicator for P2P communication.

  • module_to_grid_tuple – List of (module, grid) tuples for gradient handling.

  • optimizer – MimoOptimizer (None when build_optimizer=False).

  • schedulers – Per-module LR schedulers (empty when build_optimizer=False).

  • train_data_iterator – Training data iterator.

  • valid_data_iterator – Validation data iterator (optional).

  • global_state – GlobalState containing timers, config, and train_state.

model: megatron.core.models.mimo.MimoModel#

None

mimo_infra: megatron.bridge.models.mimo.mimo_provider.MimoModelInfra#

None

multimodule_pg_collection: megatron.core.process_groups_config.MultiModuleProcessGroupCollection#

None

multimodule_communicator: megatron.core.pipeline_parallel.multimodule_communicator.MultiModulePipelineCommunicator#

None

module_to_grid_tuple: List#

None

optimizer: Optional[megatron.core.models.mimo.optimizer.MimoOptimizer]#

None

schedulers: Dict[str, megatron.core.optimizer.optimizer_param_scheduler.OptimizerParamScheduler]#

None

train_data_iterator: Iterator#

None

valid_data_iterator: Optional[Iterator]#

None

global_state: megatron.bridge.training.state.GlobalState#

None
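As a consumer-side sketch (not library code; train_step_fn and num_train_steps are hypothetical), the fields above unpack naturally at the call site:

```python
setup = setup_mimo(cfg)  # -> MimoSetupOutput

num_train_steps = 100  # illustrative only
for _ in range(num_train_steps):
    # train_step_fn is a hypothetical caller-side function; the point is
    # which MimoSetupOutput fields a training loop consumes.
    train_step_fn(
        model=setup.model,
        optimizer=setup.optimizer,
        schedulers=setup.schedulers,
        data_iterator=setup.train_data_iterator,
        communicator=setup.multimodule_communicator,
        state=setup.global_state,
    )
```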

bridge.training.setup_mimo._update_mimo_model_config_funcs(
model: megatron.core.models.mimo.MimoModel,
optimizer: Optional[megatron.core.models.mimo.optimizer.MimoOptimizer],
mimo_infra: megatron.bridge.models.mimo.mimo_provider.MimoModelInfra,
module_to_grid_tuple: List,
) → None#

Set model config hooks for MIMO training.

Mirrors the standard path’s _update_model_config_funcs (in setup.py) but uses per-module gradient operations instead of global ones.

Sets:

  • no_sync_func: per-module no_sync via multimodule_no_sync

  • finalize_model_grads_func: per-module grad all-reduce via finalize_model_grads_multimodule

  • grad_scale_func: loss scaling from MimoOptimizer (if present)
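Conceptually, the hook wiring amounts to the sketch below. It assumes the model exposes a shared config object and that multimodule_no_sync, finalize_model_grads_multimodule, and optimizer.scale_loss play the roles named above; it is an illustration, not the actual implementation.

```python
from functools import partial

def _wire_mimo_hooks(model, optimizer, module_to_grid_tuple):
    # multimodule_no_sync / finalize_model_grads_multimodule are the helpers
    # named above; their import paths and exact signatures are assumed here.
    config = model.config

    # Per-module no-sync context instead of the global DDP no_sync.
    config.no_sync_func = partial(multimodule_no_sync, module_to_grid_tuple)

    # Per-module gradient all-reduce instead of the global finalize step.
    config.finalize_model_grads_func = partial(
        finalize_model_grads_multimodule, module_to_grid_tuple
    )

    # Loss scaling delegated to the MimoOptimizer, when one was built.
    if optimizer is not None:
        config.grad_scale_func = optimizer.scale_loss
```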

bridge.training.setup_mimo.setup_mimo(
cfg: megatron.bridge.training.config.ConfigContainer,
build_data_iterators_fn: Optional[Callable] = None,
build_optimizer: bool = True,
global_state: Optional[megatron.bridge.training.state.GlobalState] = None,
) → bridge.training.setup_mimo.MimoSetupOutput#

MIMO-specific setup helper.

This function sets up all components needed for MIMO training:

  • Builds distributed model via cfg.model (a MimoModelProvider)

  • Builds MIMO infrastructure (grids, topology, pg_collections)

  • Creates MultiModulePipelineCommunicator

  • Creates MimoOptimizer and per-module LR schedulers (when build_optimizer=True)

  • Builds data iterators (if function provided)

  • Validates configuration

Parameters:
  • cfg – ConfigContainer with training configuration. cfg.model must be a MimoModelProvider. cfg.optimizer is used to create the optimizer when build_optimizer=True.

  • build_data_iterators_fn – Optional function to build data iterators. Should have signature: (cfg, mimo_infra) -> (train_iter, valid_iter)

  • build_optimizer – Whether to create optimizer and schedulers. Set to False for inference or evaluation-only callers.

  • global_state – Optional GlobalState. If not provided, creates a new one.

Returns:

MimoSetupOutput containing all components for training.

Reuses from setup.py:

  • Logging setup (via global_state)

  • Timer infrastructure (via global_state)
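End to end, a typical call looks like the sketch below. The data-iterator builder follows the documented (cfg, mimo_infra) -> (train_iter, valid_iter) signature; make_train_dataset and make_valid_dataset are hypothetical placeholders for application-specific dataset construction.

```python
from megatron.bridge.training.setup_mimo import setup_mimo

def build_data_iterators(cfg, mimo_infra):
    # Must return (train_iter, valid_iter) per the documented signature.
    train_iter = iter(make_train_dataset(cfg))   # hypothetical helper
    valid_iter = iter(make_valid_dataset(cfg))   # hypothetical helper
    return train_iter, valid_iter

# cfg: a ConfigContainer whose cfg.model is a MimoModelProvider.
setup = setup_mimo(
    cfg,
    build_data_iterators_fn=build_data_iterators,
    build_optimizer=True,   # set False for inference/eval-only callers
)

# Evaluation-only callers skip optimizer and scheduler creation:
eval_setup = setup_mimo(cfg, build_optimizer=False)
assert eval_setup.optimizer is None and not eval_setup.schedulers
```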