bridge.training.setup_megatron_mimo#

MegatronMIMO-specific setup for heterogeneous multi-module training.

This module provides the setup logic for MegatronMIMO training, mirroring the standard setup.py but adapted for per-module parallelism.

Key components:

  • setup_megatron_mimo(): MegatronMIMO-specific setup helper (analogous to setup())

  • _set_megatron_mimo_random_seeds(): Per-module TP/PP seed initialization

  • _update_megatron_mimo_model_config_funcs(): Model config hooks (analogous to _update_model_config_funcs)

  • MegatronMIMOSetupOutput: Dataclass containing all setup outputs

Module Contents#

Classes#

MegatronMIMOSetupOutput

Output from setup_megatron_mimo() containing all components needed for training.

Functions#

_set_megatron_mimo_random_seeds

Initialize random seeds with per-module TP/PP awareness.

_update_megatron_mimo_model_config_funcs

Set model config hooks for MegatronMIMO training.

setup_megatron_mimo

MegatronMIMO-specific setup helper.

Data#

API#

bridge.training.setup_megatron_mimo.logger#

'getLogger(...)'

bridge.training.setup_megatron_mimo._set_megatron_mimo_random_seeds(
cfg: megatron.bridge.training.config.ConfigContainer,
megatron_mimo_infra: megatron.bridge.models.megatron_mimo.megatron_mimo_provider.MegatronMIMOInfra,
) → None#

Initialize random seeds with per-module TP/PP awareness.

Mirrors the standard path’s _set_random_seed() but derives TP/PP ranks from the per-module HyperCommGrids instead of global MPU state.

Must be called after build_infra() (grids exist) and before provide_distributed_model() (weight init needs the CUDA RNG tracker).
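The ordering constraint above can be sketched with stand-in stubs; the real `build_infra()`, `_set_megatron_mimo_random_seeds()`, and `provide_distributed_model()` live in Megatron-Bridge and take real config/infra arguments, so these stubs only record the required call order:

```python
# Illustrative ordering sketch only; the stub bodies are assumptions,
# the comments restate why each step must precede the next.
calls = []

def build_infra():
    # Creates the per-module HyperCommGrids the seed helper reads from.
    calls.append("build_infra")

def set_random_seeds():
    # Derives TP/PP ranks from the grids, so the grids must already exist.
    calls.append("set_random_seeds")

def provide_distributed_model():
    # Weight init consumes the CUDA RNG tracker the seed helper set up.
    calls.append("provide_distributed_model")

build_infra()
set_random_seeds()
provide_distributed_model()
```

Calling the seed helper before `build_infra()` (no grids yet) or after `provide_distributed_model()` (weights already initialized) would violate this contract.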

class bridge.training.setup_megatron_mimo.MegatronMIMOSetupOutput#

Output from setup_megatron_mimo() containing all components needed for training.

Attributes:

  • model – MimoModel (distributed, DDP-wrapped).

  • megatron_mimo_infra – MegatronMIMOInfra (grids, topology, pg_collections).

  • multimodule_pg_collection – PG collection for schedule.

  • multimodule_communicator – MultiModulePipelineCommunicator for P2P.

  • module_to_grid_tuple – List of (module, grid) tuples for gradient handling.

  • optimizer – MimoOptimizer.

  • schedulers – Per-module LR schedulers.

  • train_data_iterator – Training data iterator.

  • valid_data_iterator – Validation data iterator (optional).

  • global_state – GlobalState containing timers, config, train_state.

model: megatron.core.models.mimo.MimoModel#

None

megatron_mimo_infra: megatron.bridge.models.megatron_mimo.megatron_mimo_provider.MegatronMIMOInfra#

None

multimodule_pg_collection: megatron.core.process_groups_config.MultiModuleProcessGroupCollection#

None

multimodule_communicator: megatron.core.pipeline_parallel.multimodule_communicator.MultiModulePipelineCommunicator#

None

module_to_grid_tuple: List#

None

optimizer: megatron.core.models.mimo.optimizer.MimoOptimizer#

None

schedulers: Dict[str, megatron.core.optimizer.optimizer_param_scheduler.OptimizerParamScheduler]#

None

train_data_iterator: Iterator#

None

valid_data_iterator: Optional[Iterator]#

None

global_state: megatron.bridge.training.state.GlobalState#

None

checkpoint_manager: megatron.bridge.training.checkpointing.CheckpointManager#

None

active_module_name: str#

None

local_pg_collection: megatron.core.process_groups_config.ProcessGroupCollection#

None

bridge.training.setup_megatron_mimo._update_megatron_mimo_model_config_funcs(
model: megatron.core.models.mimo.MimoModel,
optimizer: megatron.core.models.mimo.optimizer.MimoOptimizer,
megatron_mimo_infra: megatron.bridge.models.megatron_mimo.megatron_mimo_provider.MegatronMIMOInfra,
module_to_grid_tuple: List,
) → None#

Set model config hooks for MegatronMIMO training.

Mirrors the standard path’s _update_model_config_funcs (in setup.py) but uses per-module gradient operations instead of global ones.

Sets:

  • no_sync_func: per-module no_sync via multimodule_no_sync

  • finalize_model_grads_func: per-module grad all-reduce via finalize_model_grads_multimodule

  • grad_scale_func: loss scaling from MimoOptimizer
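The three hooks can be sketched as attribute assignments on a per-module config object; `ModuleConfig` and the hook bodies below are stand-ins, only the three attribute names mirror the bullets above:

```python
# Hedged sketch; the real config objects and hook implementations come
# from Megatron Core / Megatron-Bridge.
class ModuleConfig:
    no_sync_func = None
    finalize_model_grads_func = None
    grad_scale_func = None

def multimodule_no_sync(model_chunks):
    """Stand-in: would return a context manager suppressing grad sync."""

def finalize_model_grads_multimodule(model_chunks, num_tokens=None):
    """Stand-in: would all-reduce gradients per module grid."""

config = ModuleConfig()
config.no_sync_func = multimodule_no_sync
config.finalize_model_grads_func = finalize_model_grads_multimodule
config.grad_scale_func = lambda loss: loss * 1.0  # stand-in for MimoOptimizer loss scaling
```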

bridge.training.setup_megatron_mimo.setup_megatron_mimo(
state: megatron.bridge.training.state.GlobalState,
build_data_iterators_fn: Optional[Callable] = None,
) → bridge.training.setup_megatron_mimo.MegatronMIMOSetupOutput#

MegatronMIMO-specific setup helper.

This function sets up all components needed for MegatronMIMO training:

  • Builds the distributed model via cfg.model (a MegatronMIMOProvider)

  • Builds MegatronMIMO infrastructure (grids, topology, pg_collections)

  • Creates MultiModulePipelineCommunicator

  • Creates MimoOptimizer and per-module LR schedulers

  • Loads checkpoint (if one exists)

  • Builds data iterators (if function provided, after checkpoint load)

  • Validates configuration

Parameters:
  • state – GlobalState with state.cfg already set. state.cfg.model must be a MegatronMIMOProvider. state.cfg.optimizer is used to create the optimizer.

  • build_data_iterators_fn – Optional function to build data iterators. Should have signature: (cfg, megatron_mimo_infra) -> (train_iter, valid_iter)

Returns:

MegatronMIMOSetupOutput containing all components for training.
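A hypothetical `build_data_iterators_fn` conforming to the documented signature `(cfg, megatron_mimo_infra) -> (train_iter, valid_iter)` might look like the sketch below; the bodies are stand-ins, and a real implementation would build iterators from cfg's dataset settings and the per-module infra:

```python
# Hypothetical conforming callable; everything except the signature
# shape is an assumption for illustration.
def build_data_iterators(cfg, megatron_mimo_infra):
    train_iter = iter([{"step": 0}, {"step": 1}])  # stand-in training batches
    valid_iter = None                              # validation iterator is optional
    return train_iter, valid_iter

train_it, valid_it = build_data_iterators(cfg=None, megatron_mimo_infra=None)
first_batch = next(train_it)
```

Passing such a callable as `build_data_iterators_fn` lets setup defer data-iterator construction until after the checkpoint is loaded, as listed in the steps above.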