bridge.training.setup_megatron_mimo#
MegatronMIMO-specific setup for heterogeneous multi-module training.
This module provides the setup logic for MegatronMIMO training, mirroring the standard
setup.py but adapted for per-module parallelism.
Key components:
- setup_megatron_mimo(): MegatronMIMO-specific setup helper (analogous to setup())
- _set_megatron_mimo_random_seeds(): Per-module TP/PP seed initialization
- _update_megatron_mimo_model_config_funcs(): Model config hooks (analogous to _update_model_config_funcs)
- MegatronMIMOSetupOutput: Dataclass containing all setup outputs
Module Contents#
Classes#
| MegatronMIMOSetupOutput | Output from setup_megatron_mimo() containing all components needed for training. |
Functions#
| _set_megatron_mimo_random_seeds | Initialize random seeds with per-module TP/PP awareness. |
| _update_megatron_mimo_model_config_funcs | Set model config hooks for MegatronMIMO training. |
| setup_megatron_mimo | MegatronMIMO-specific setup helper. |
Data#
API#
- bridge.training.setup_megatron_mimo.logger#
'getLogger(…)'
- bridge.training.setup_megatron_mimo._set_megatron_mimo_random_seeds(
- cfg: megatron.bridge.training.config.ConfigContainer,
- megatron_mimo_infra: megatron.bridge.models.megatron_mimo.megatron_mimo_provider.MegatronMIMOInfra,
- )#
Initialize random seeds with per-module TP/PP awareness.
Mirrors the standard path's _set_random_seed() but derives TP/PP ranks from the per-module HyperCommGrids instead of global MPU state.
Must be called after build_infra() (grids exist) and before provide_distributed_model() (weight init needs the CUDA RNG tracker).
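The idea of deriving a seed from per-module pipeline ranks can be pictured with a small sketch. The grid class, its attributes, and the offset arithmetic below are illustrative assumptions only; the real function derives ranks from the per-module HyperCommGrids.

```python
# Illustrative sketch of per-module TP/PP-aware seeding. FakeGrid and the
# offset arithmetic are hypothetical stand-ins, not the library's actual
# implementation, which reads ranks from HyperCommGrids.
from dataclasses import dataclass


@dataclass
class FakeGrid:
    """Stand-in for a per-module communication grid."""
    pp_rank: int  # this rank's pipeline stage within the module's grid


def per_module_seed(base_seed: int, grid: FakeGrid) -> int:
    # Follow the usual Megatron convention: offset the base seed by the
    # pipeline rank so each PP stage initializes different weights, while
    # data-parallel replicas of the same stage share a seed.
    return base_seed + 100 * grid.pp_rank


# Each module computes its seed from its OWN grid, not the global MPU state.
seeds = {
    "vision": per_module_seed(1234, FakeGrid(pp_rank=0)),
    "llm_stage0": per_module_seed(1234, FakeGrid(pp_rank=0)),
    "llm_stage1": per_module_seed(1234, FakeGrid(pp_rank=1)),
}
```

The point of the per-module derivation is that two modules with different pipeline layouts can coexist: each stage of each module gets a deterministic, distinct seed without consulting any global parallel state.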
- class bridge.training.setup_megatron_mimo.MegatronMIMOSetupOutput#
Output from setup_megatron_mimo() containing all components needed for training.
.. attribute:: model
MimoModel (distributed, DDP-wrapped).
.. attribute:: megatron_mimo_infra
MegatronMIMOInfra (grids, topology, pg_collections).
.. attribute:: multimodule_pg_collection
PG collection for schedule.
.. attribute:: multimodule_communicator
MultiModulePipelineCommunicator for P2P.
.. attribute:: module_to_grid_tuple
List of (module, grid) tuples for gradient handling.
.. attribute:: optimizer
MimoOptimizer.
.. attribute:: schedulers
Per-module LR schedulers.
.. attribute:: train_data_iterator
Training data iterator.
.. attribute:: valid_data_iterator
Validation data iterator (optional).
.. attribute:: global_state
GlobalState containing timers, config, train_state.
- model: megatron.core.models.mimo.MimoModel#
None
- megatron_mimo_infra: megatron.bridge.models.megatron_mimo.megatron_mimo_provider.MegatronMIMOInfra#
None
- multimodule_pg_collection: megatron.core.process_groups_config.MultiModuleProcessGroupCollection#
None
- multimodule_communicator: megatron.core.pipeline_parallel.multimodule_communicator.MultiModulePipelineCommunicator#
None
- module_to_grid_tuple: List#
None
- optimizer: megatron.core.models.mimo.optimizer.MimoOptimizer#
None
- schedulers: Dict[str, megatron.core.optimizer.optimizer_param_scheduler.OptimizerParamScheduler]#
None
- train_data_iterator: Iterator#
None
- valid_data_iterator: Optional[Iterator]#
None
- global_state: megatron.bridge.training.state.GlobalState#
None
- checkpoint_manager: megatron.bridge.training.checkpointing.CheckpointManager#
None
- active_module_name: str#
None
- local_pg_collection: megatron.core.process_groups_config.ProcessGroupCollection#
None
- bridge.training.setup_megatron_mimo._update_megatron_mimo_model_config_funcs(
- model: megatron.core.models.mimo.MimoModel,
- optimizer: megatron.core.models.mimo.optimizer.MimoOptimizer,
- megatron_mimo_infra: megatron.bridge.models.megatron_mimo.megatron_mimo_provider.MegatronMIMOInfra,
- module_to_grid_tuple: List,
- )#
Set model config hooks for MegatronMIMO training.
Mirrors the standard path's _update_model_config_funcs (in setup.py) but uses per-module gradient operations instead of global ones.
Sets:
- no_sync_func: per-module no_sync via multimodule_no_sync
- finalize_model_grads_func: per-module grad all-reduce via finalize_model_grads_multimodule
- grad_scale_func: loss scaling from MimoOptimizer
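The hook wiring described above can be sketched with stand-in objects. The three config attribute names come from this docstring; everything else here (the helper factories, the fake optimizer) is a hypothetical stand-in, not the real MimoModel/MimoOptimizer API.

```python
# Sketch of wiring per-module gradient hooks onto a model config.
# Only the attribute names (no_sync_func, finalize_model_grads_func,
# grad_scale_func) are taken from the docstring; the rest is illustrative.
from types import SimpleNamespace


def multimodule_no_sync(module_to_grid_tuple):
    # Would return a context manager that suppresses gradient
    # synchronization for each module over its own grid.
    def _no_sync():
        return None
    return _no_sync


def finalize_model_grads_multimodule(module_to_grid_tuple):
    # Would all-reduce each module's gradients over that module's grid,
    # instead of one global all-reduce.
    def _finalize(model, num_tokens=None):
        return None
    return _finalize


def update_model_config_funcs(model_config, optimizer, module_to_grid_tuple):
    model_config.no_sync_func = multimodule_no_sync(module_to_grid_tuple)
    model_config.finalize_model_grads_func = finalize_model_grads_multimodule(
        module_to_grid_tuple
    )
    # Loss scaling is delegated to the optimizer's scaler.
    model_config.grad_scale_func = optimizer.scale_loss


cfg = SimpleNamespace()
opt = SimpleNamespace(scale_loss=lambda loss: loss * 2.0)  # fake loss scaler
update_model_config_funcs(cfg, opt, module_to_grid_tuple=[])
```

The design point mirrored here is that the hooks close over the module-to-grid mapping, so the training loop can call them without knowing anything about per-module parallelism.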
- bridge.training.setup_megatron_mimo.setup_megatron_mimo(
- state: megatron.bridge.training.state.GlobalState,
- build_data_iterators_fn: Optional[Callable] = None,
- )#
MegatronMIMO-specific setup helper.
This function sets up all components needed for MegatronMIMO training:
Builds the distributed model via cfg.model (a MegatronMIMOProvider)
Builds MegatronMIMO infrastructure (grids, topology, pg_collections)
Creates MultiModulePipelineCommunicator
Creates MimoOptimizer and per-module LR schedulers
Loads checkpoint (if one exists)
Builds data iterators (if function provided, after checkpoint load)
Validates configuration
- Parameters:
state – GlobalState with state.cfg already set. state.cfg.model must be a MegatronMIMOProvider; state.cfg.optimizer is used to create the optimizer.
build_data_iterators_fn – Optional function to build data iterators. Should have signature: (cfg, megatron_mimo_infra) -> (train_iter, valid_iter)
- Returns:
MegatronMIMOSetupOutput containing all components for training.
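A minimal build_data_iterators_fn matching the documented signature might look as follows. The batch contents are placeholders; a real implementation would shard and batch data according to the per-module grids in megatron_mimo_infra.

```python
# Sketch of a build_data_iterators_fn with the documented signature:
# (cfg, megatron_mimo_infra) -> (train_iter, valid_iter).
# The batches below are placeholders, not a real data pipeline.
from typing import Iterator, Optional, Tuple


def build_data_iterators(cfg, megatron_mimo_infra) -> Tuple[Iterator, Optional[Iterator]]:
    train_iter = iter([{"tokens": [1, 2, 3]}, {"tokens": [4, 5, 6]}])
    valid_iter = iter([{"tokens": [7, 8, 9]}])  # may be None if no validation set
    return train_iter, valid_iter


# It would be passed to setup as:
#   output = setup_megatron_mimo(state, build_data_iterators_fn=build_data_iterators)
# Here we just call it directly with dummy arguments to show the shape.
train_it, valid_it = build_data_iterators(cfg=None, megatron_mimo_infra=None)
first = next(train_it)
```

Because setup_megatron_mimo builds the iterators after the checkpoint load, the function can use the restored train_state (e.g., to skip consumed samples) if the real implementation chooses to.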