bridge.training.setup_mimo#
MIMO-specific setup for heterogeneous multi-module training.
This module provides the setup logic for MIMO training, mirroring the standard
setup.py but adapted for per-module parallelism.
Key components:
- setup_mimo(): MIMO-specific setup helper (analogous to setup())
- _set_mimo_random_seeds(): Per-module TP/PP seed initialization
- _update_mimo_model_config_funcs(): Model config hooks (analogous to _update_model_config_funcs)
- MimoSetupOutput: Dataclass containing all setup outputs
Module Contents#
Classes#
| Class | Description |
|---|---|
| `MimoSetupOutput` | Output from setup_mimo() containing all components needed for training. |
Functions#
| Function | Description |
|---|---|
| `_set_mimo_random_seeds` | Initialize random seeds with per-module TP/PP awareness. |
| `_update_mimo_model_config_funcs` | Set model config hooks for MIMO training. |
| `setup_mimo` | MIMO-specific setup helper. |
Data#
API#
- bridge.training.setup_mimo.logger#
'getLogger(...)'
- bridge.training.setup_mimo._set_mimo_random_seeds(
- cfg: megatron.bridge.training.config.ConfigContainer,
- mimo_infra: megatron.bridge.models.mimo.mimo_provider.MimoModelInfra,
)#
Initialize random seeds with per-module TP/PP awareness.
Mirrors the standard path's `_set_random_seed()` but derives TP/PP ranks from the per-module HyperCommGrids instead of global MPU state.
Must be called after `build_infra()` (grids exist) and before `provide_distributed_model()` (weight initialization needs the CUDA RNG tracker).
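To make the ordering constraint concrete, here is a minimal sketch. Only the call order comes from the description above; hanging `build_infra()` and `provide_distributed_model()` off `cfg.model` is an assumption about the provider API, not the verbatim implementation.

```python
# Hypothetical ordering sketch; the provider method hookup is an assumption.
from megatron.bridge.training.setup_mimo import _set_mimo_random_seeds

def seed_then_build(cfg):
    provider = cfg.model                          # a MimoModelProvider
    mimo_infra = provider.build_infra()           # grids must exist first (assumed API)
    _set_mimo_random_seeds(cfg, mimo_infra)       # seeds the per-module CUDA RNG tracker
    model = provider.provide_distributed_model()  # weight init reads the tracker (assumed API)
    return model, mimo_infra
```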
- class bridge.training.setup_mimo.MimoSetupOutput#
Output from setup_mimo() containing all components needed for training.
.. attribute:: model
MimoModel (distributed, DDP-wrapped).
.. attribute:: mimo_infra
MimoModelInfra (grids, topology, pg_collections).
.. attribute:: multimodule_pg_collection
PG collection for schedule.
.. attribute:: multimodule_communicator
MultiModulePipelineCommunicator for P2P.
.. attribute:: module_to_grid_tuple
List of (module, grid) tuples for gradient handling.
.. attribute:: optimizer
MimoOptimizer (None when build_optimizer=False).
.. attribute:: schedulers
Per-module LR schedulers (empty when build_optimizer=False).
.. attribute:: train_data_iterator
Training data iterator.
.. attribute:: valid_data_iterator
Validation data iterator (optional).
.. attribute:: global_state
GlobalState containing timers, config, train_state.
- model: megatron.core.models.mimo.MimoModel#
None
- mimo_infra: megatron.bridge.models.mimo.mimo_provider.MimoModelInfra#
None
- multimodule_pg_collection: megatron.core.process_groups_config.MultiModuleProcessGroupCollection#
None
- multimodule_communicator: megatron.core.pipeline_parallel.multimodule_communicator.MultiModulePipelineCommunicator#
None
- module_to_grid_tuple: List#
None
- optimizer: Optional[megatron.core.models.mimo.optimizer.MimoOptimizer]#
None
- schedulers: Dict[str, megatron.core.optimizer.optimizer_param_scheduler.OptimizerParamScheduler]#
None
- train_data_iterator: Iterator#
None
- valid_data_iterator: Optional[Iterator]#
None
- global_state: megatron.bridge.training.state.GlobalState#
None
- bridge.training.setup_mimo._update_mimo_model_config_funcs(
- model: megatron.core.models.mimo.MimoModel,
- optimizer: Optional[megatron.core.models.mimo.optimizer.MimoOptimizer],
- mimo_infra: megatron.bridge.models.mimo.mimo_provider.MimoModelInfra,
- module_to_grid_tuple: List,
)#
Set model config hooks for MIMO training.
Mirrors the standard path's `_update_model_config_funcs` (in `setup.py`) but uses per-module gradient operations instead of global ones.
Sets the following hooks (see the sketch after this list):
- `no_sync_func`: per-module `no_sync` via `multimodule_no_sync`
- `finalize_model_grads_func`: per-module grad all-reduce via `finalize_model_grads_multimodule`
- `grad_scale_func`: loss scaling from `MimoOptimizer` (if present)
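A minimal sketch of this hook wiring. The partial-binding shapes and `optimizer.scale_loss` are assumptions modeled on the standard Megatron path, not the actual function bodies; `multimodule_no_sync` and `finalize_model_grads_multimodule` are the helpers named above, with their import paths omitted.

```python
# Sketch only: argument shapes and attribute paths are assumptions modeled
# on the standard path (_update_model_config_funcs in setup.py).
from functools import partial

def wire_mimo_config_hooks(model, optimizer, module_to_grid_tuple):
    config = model.config  # assumed location of the Megatron model config

    # Per-module no_sync context instead of the single global DDP no_sync.
    config.no_sync_func = partial(multimodule_no_sync, module_to_grid_tuple)

    # Per-module grad all-reduce instead of the global finalize step.
    config.finalize_model_grads_func = partial(
        finalize_model_grads_multimodule, module_to_grid_tuple
    )

    # Loss scaling is delegated to the MimoOptimizer when one was built.
    if optimizer is not None:
        config.grad_scale_func = optimizer.scale_loss  # assumed method name
```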
- bridge.training.setup_mimo.setup_mimo(
- cfg: megatron.bridge.training.config.ConfigContainer,
- build_data_iterators_fn: Optional[Callable] = None,
- build_optimizer: bool = True,
- global_state: Optional[megatron.bridge.training.state.GlobalState] = None,
)#
MIMO-specific setup helper.
This function sets up all components needed for MIMO training:
- Builds the distributed model via `cfg.model` (a `MimoModelProvider`)
- Builds MIMO infrastructure (grids, topology, pg_collections)
- Creates `MultiModulePipelineCommunicator`
- Creates `MimoOptimizer` and per-module LR schedulers (when `build_optimizer=True`)
- Builds data iterators (if a function is provided)
- Validates the configuration
- Parameters:
- cfg – ConfigContainer with the training configuration. `cfg.model` must be a `MimoModelProvider`; `cfg.optimizer` is used to create the optimizer when `build_optimizer=True`.
- build_data_iterators_fn – Optional function to build data iterators. Should have the signature `(cfg, mimo_infra) -> (train_iter, valid_iter)`.
- build_optimizer – Whether to create the optimizer and schedulers. Set to `False` for inference or evaluation-only callers.
- global_state – Optional GlobalState. If not provided, a new one is created.
- Returns:
MimoSetupOutput containing all components for training.
Reuses from setup.py:
- Logging setup (via global_state)
- Timer infrastructure (via global_state)
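Putting the pieces together, an illustrative end-to-end call. Only the `setup_mimo` signature and the `(cfg, mimo_infra) -> (train_iter, valid_iter)` contract come from the documentation above; the data source is a toy stand-in, and `make_config()` is a hypothetical helper that returns a ConfigContainer whose `cfg.model` is a `MimoModelProvider`.

```python
# Illustrative usage; the data source is a toy stand-in and make_config()
# is a hypothetical helper, not part of this module's API.
from itertools import cycle

from megatron.bridge.training.setup_mimo import setup_mimo

def build_data_iterators(cfg, mimo_infra):
    # Documented contract: (cfg, mimo_infra) -> (train_iter, valid_iter)
    train_iter = cycle([{"tokens": None}])  # toy batches
    valid_iter = cycle([{"tokens": None}])
    return train_iter, valid_iter

cfg = make_config()  # hypothetical helper, see lead-in above
out = setup_mimo(
    cfg,
    build_data_iterators_fn=build_data_iterators,
    build_optimizer=True,
)
batch = next(out.train_data_iterator)
state = out.global_state  # timers, config, and train_state live here
```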