bridge.training.train_mimo#

MIMO Training Loop for heterogeneous multi-module training.

This module provides the dedicated training loop for MIMO models with heterogeneous parallelism. It uses MultiModulePipelineCommunicator for cross-module communication and supports per-module gradient handling.

Key differences from standard train():

  • Creates MultiModulePipelineCommunicator for cross-module communication

  • Creates MultiModuleProcessGroupCollection for the schedule

  • Uses forward_backward_pipelining_without_interleaving with multimodule support

  • Uses zero_grad_buffer_for_multimodule() for gradient clearing

  • Supports per-module optimizers

Note: Stub ranks are disallowed; this is validated at setup time.
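
The entry point is train_mimo(); a minimal invocation sketch follows. Every right-hand-side name is a placeholder assumed to be produced by the caller's setup code (model construction, optimizer and scheduler setup, data loading, communicator creation); only the keyword names come from the signature documented below.

```python
from megatron.bridge.training.train_mimo import train_mimo

# All right-hand-side names below are placeholders built by earlier setup code.
train_mimo(
    forward_step_func=forward_step_func,
    model=model,                                # MimoModel
    optimizer=optimizer,                        # MimoOptimizer with per-module optimizers
    schedulers=schedulers,                      # {module_name: OptimizerParamScheduler}
    train_data_iterator=train_data_iterator,
    valid_data_iterator=valid_data_iterator,    # may be None
    global_state=global_state,                  # GlobalState: timers, config, train_state
    mimo_infra=mimo_infra,                      # MimoModelInfra: grids, topology, pg_collections
    multimodule_communicator=multimodule_communicator,
    # multimodule_pg_collection and module_to_grid_tuple are optional;
    # when omitted, they are built from mimo_infra (and the model).
)
```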

Module Contents#

Functions#

train_step_mimo

Single MIMO training step.

train_mimo

Main MIMO training loop.

Data#

logger

API#

bridge.training.train_mimo.logger#

'getLogger(…)'

bridge.training.train_mimo.train_step_mimo(
forward_step_func: Callable,
data_iterator: Iterator,
model: megatron.core.models.mimo.MimoModel,
optimizer: megatron.core.models.mimo.optimizer.MimoOptimizer,
schedulers: Dict[str, megatron.core.optimizer.optimizer_param_scheduler.OptimizerParamScheduler],
global_state: megatron.bridge.training.state.GlobalState,
multimodule_communicator: megatron.core.pipeline_parallel.multimodule_communicator.MultiModulePipelineCommunicator,
multimodule_pg_collection,
infra: megatron.bridge.models.mimo.mimo_provider.MimoModelInfra,
module_to_grid_tuple: List,
num_microbatches: int,
seq_length: int,
micro_batch_size: int,
) → Tuple[Dict[str, torch.Tensor], int, Optional[float], Optional[int]]#

Single MIMO training step.

Parameters:
  • forward_step_func – Forward step function (wrapped with GlobalState).

  • data_iterator – Iterator over the dataset.

  • model – MimoModel instance.

  • optimizer – MimoOptimizer managing per-module optimizers.

  • schedulers – Per-module learning rate schedulers {module_name: scheduler}.

  • global_state – GlobalState containing timers, config, train_state.

  • multimodule_communicator – MultiModulePipelineCommunicator for P2P.

  • multimodule_pg_collection – PG collection for the pipeline schedule.

  • infra – MimoModelInfra with grids, topology, pg_collections.

  • module_to_grid_tuple – List of (module, grid) tuples.

  • num_microbatches – Number of microbatches per iteration.

  • seq_length – Sequence length.

  • micro_batch_size – Micro batch size.

Returns:

Tuple of (loss_dict, skipped_iter, grad_norm, num_zeros_in_grad).
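
A hedged sketch of calling the function and consuming the result; all inputs are placeholders assumed to come from setup code, and the unpacking order mirrors the tuple described above (train_mimo() below normally drives this call):

```python
# Placeholder inputs; in practice train_mimo() drives this step function.
loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step_mimo(
    forward_step_func=forward_step_func,
    data_iterator=train_data_iterator,
    model=model,
    optimizer=optimizer,
    schedulers=schedulers,
    global_state=global_state,
    multimodule_communicator=multimodule_communicator,
    multimodule_pg_collection=multimodule_pg_collection,
    infra=mimo_infra,
    module_to_grid_tuple=module_to_grid_tuple,
    num_microbatches=num_microbatches,
    seq_length=seq_length,
    micro_batch_size=micro_batch_size,
)
if not skipped_iter:
    # grad_norm and num_zeros_in_grad may be None when not computed this step.
    for name, value in loss_dict.items():  # values are torch.Tensor scalars
        print(f"{name}: {value.item():.4f}")
```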

bridge.training.train_mimo.train_mimo(
forward_step_func: Callable,
model: megatron.core.models.mimo.MimoModel,
optimizer: megatron.core.models.mimo.optimizer.MimoOptimizer,
schedulers: Dict[str, megatron.core.optimizer.optimizer_param_scheduler.OptimizerParamScheduler],
train_data_iterator: Iterator,
valid_data_iterator: Optional[Iterator],
global_state: megatron.bridge.training.state.GlobalState,
mimo_infra: megatron.bridge.models.mimo.mimo_provider.MimoModelInfra,
multimodule_communicator: megatron.core.pipeline_parallel.multimodule_communicator.MultiModulePipelineCommunicator,
multimodule_pg_collection: Optional[megatron.core.process_groups_config.MultiModuleProcessGroupCollection] = None,
module_to_grid_tuple: Optional[List] = None,
) → None#

Main MIMO training loop.

Key differences from standard train():

  • Uses MultiModuleProcessGroupCollection for the schedule

  • Uses forward_backward_pipelining_without_interleaving with multimodule support

  • Uses zero_grad_buffer_for_multimodule() for gradient clearing

  • Uses MimoOptimizer for coordinated gradient clipping with global norm

Reuses from existing Bridge training:

  • GlobalState for timers, config, train_state

  • training_log() for metrics reporting

  • handle_profiling_step() and handle_profiling_stop() for profiler lifecycle

  • save_checkpoint() with MimoOptimizer for checkpointing

  • evaluate_and_print_results() for validation with multimodule support

  • maybe_finalize_async_save() for async checkpoint finalization

Parameters:
  • forward_step_func – Forward step function.

  • model – MimoModel instance.

  • optimizer – MimoOptimizer managing per-module optimizers.

  • schedulers – Per-module learning rate schedulers {module_name: scheduler}.

  • train_data_iterator – Training data iterator.

  • valid_data_iterator – Validation data iterator (optional).

  • global_state – GlobalState containing timers, config, train_state.

  • mimo_infra – MimoModelInfra with grids, topology, pg_collections.

  • multimodule_communicator – MultiModulePipelineCommunicator for P2P.

  • multimodule_pg_collection – Pre-built PG collection for the pipeline schedule. If None, built from mimo_infra.

  • module_to_grid_tuple – Pre-built (module, grid) pairs for gradient ops. If None, built from model and mimo_infra.
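
For the schedulers argument, both functions in this module expect a plain dict with one OptimizerParamScheduler per module, keyed by module name. A sketch, where the module names and the make_scheduler() helper are hypothetical stand-ins for the caller's real setup code:

```python
from typing import Dict

from megatron.core.optimizer.optimizer_param_scheduler import OptimizerParamScheduler

# "language_model" and "vision_encoder" are hypothetical module names; the real
# keys must match the modules of the MimoModel being trained. make_scheduler()
# is a stand-in for whatever setup code builds each OptimizerParamScheduler.
schedulers: Dict[str, OptimizerParamScheduler] = {
    "language_model": make_scheduler("language_model"),
    "vision_encoder": make_scheduler("vision_encoder"),
}
```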