bridge.training.train_mimo#
MIMO Training Loop for heterogeneous multi-module training.
This module provides the dedicated training loop for MIMO models with heterogeneous parallelism. It uses MultiModulePipelineCommunicator for cross-module communication and supports per-module gradient handling.
Key differences from standard train():
- Creates MultiModulePipelineCommunicator for cross-module communication
- Creates MultiModuleProcessGroupCollection for the schedule
- Uses forward_backward_pipelining_without_interleaving with multimodule support
- Uses zero_grad_buffer_for_multimodule() for gradient clearing
- Supports per-module optimizers
Note: Stub ranks are disallowed; this is validated at setup time.
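The per-module optimizer point above can be illustrated with plain Python stand-ins. This is a sketch only: the stub classes below are not the real Megatron classes, but they mirror the pattern in which a coordinating optimizer (as `MimoOptimizer` does) fans `zero_grad`/`step` out to one optimizer per module.

```python
# Illustrative stand-ins only; the real coordinating optimizer is
# megatron.core.models.mimo.optimizer.MimoOptimizer.
class StubOptimizer:
    """Records how often it was zeroed and stepped."""
    def __init__(self):
        self.steps = 0
        self.zeroed = 0

    def zero_grad(self):
        self.zeroed += 1

    def step(self):
        self.steps += 1


class CoordinatingOptimizer:
    """Fans zero_grad/step out to one optimizer per module."""
    def __init__(self, per_module):
        self.per_module = per_module  # {module_name: optimizer}

    def zero_grad(self):
        for opt in self.per_module.values():
            opt.zero_grad()

    def step(self):
        for opt in self.per_module.values():
            opt.step()


opts = {"vision_encoder": StubOptimizer(), "language_model": StubOptimizer()}
coord = CoordinatingOptimizer(opts)
coord.zero_grad()  # clears every module's gradients
coord.step()       # steps every module's optimizer
```

The module names here are hypothetical; the real per-module structure comes from the MIMO model's grids.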
Module Contents#
Functions#
| Function | Description |
| --- | --- |
| train_step_mimo | Single MIMO training step. |
| train_mimo | Main MIMO training loop. |
Data#
API#
- bridge.training.train_mimo.logger#
getLogger(...)
- bridge.training.train_mimo.train_step_mimo(
- forward_step_func: Callable,
- data_iterator: Iterator,
- model: megatron.core.models.mimo.MimoModel,
- optimizer: megatron.core.models.mimo.optimizer.MimoOptimizer,
- schedulers: Dict[str, megatron.core.optimizer.optimizer_param_scheduler.OptimizerParamScheduler],
- global_state: megatron.bridge.training.state.GlobalState,
- multimodule_communicator: megatron.core.pipeline_parallel.multimodule_communicator.MultiModulePipelineCommunicator,
- multimodule_pg_collection,
- infra: megatron.bridge.models.mimo.mimo_provider.MimoModelInfra,
- module_to_grid_tuple: List,
- num_microbatches: int,
- seq_length: int,
- micro_batch_size: int,
- )
Single MIMO training step.
- Parameters:
forward_step_func – Forward step function (wrapped with GlobalState).
data_iterator – Iterator over the dataset.
model – MimoModel instance.
optimizer – MimoOptimizer managing per-module optimizers.
schedulers – Per-module learning rate schedulers {module_name: scheduler}.
global_state – GlobalState containing timers, config, train_state.
multimodule_communicator – MultiModulePipelineCommunicator for P2P.
multimodule_pg_collection – PG collection for the pipeline schedule.
infra – MimoModelInfra with grids, topology, pg_collections.
module_to_grid_tuple – List of (module, grid) tuples.
num_microbatches – Number of microbatches per iteration.
seq_length – Sequence length.
micro_batch_size – Micro batch size.
- Returns:
Tuple of (loss_dict, skipped_iter, grad_norm, num_zeros_in_grad).
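A hedged sketch of how a driver loop might consume the return tuple. Only the tuple shape (loss_dict, skipped_iter, grad_norm, num_zeros_in_grad) comes from this page; the stub step function and the loop around it are illustrative, not the actual Bridge training loop.

```python
# Stub standing in for train_step_mimo; only the shape of its return
# value is taken from the docs above.
def fake_train_step_mimo():
    loss_dict = {"lm loss": 2.31}  # per-loss-name reduced values
    skipped_iter = 0               # 1 if the optimizer step was skipped
    grad_norm = 1.7                # global grad norm (may be None)
    num_zeros_in_grad = 0          # zero-count in gradients (may be None)
    return loss_dict, skipped_iter, grad_norm, num_zeros_in_grad


total_skipped = 0
for _ in range(3):
    loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = fake_train_step_mimo()
    total_skipped += skipped_iter
    # A real loop would hand loss_dict and grad_norm to training_log() here.
```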
- bridge.training.train_mimo.train_mimo(
- forward_step_func: Callable,
- model: megatron.core.models.mimo.MimoModel,
- optimizer: megatron.core.models.mimo.optimizer.MimoOptimizer,
- schedulers: Dict[str, megatron.core.optimizer.optimizer_param_scheduler.OptimizerParamScheduler],
- train_data_iterator: Iterator,
- valid_data_iterator: Optional[Iterator],
- global_state: megatron.bridge.training.state.GlobalState,
- mimo_infra: megatron.bridge.models.mimo.mimo_provider.MimoModelInfra,
- multimodule_communicator: megatron.core.pipeline_parallel.multimodule_communicator.MultiModulePipelineCommunicator,
- multimodule_pg_collection: Optional[megatron.core.process_groups_config.MultiModuleProcessGroupCollection] = None,
- module_to_grid_tuple: Optional[List] = None,
- )
Main MIMO training loop.
Key differences from standard train():
- Uses MultiModuleProcessGroupCollection for the schedule
- Uses forward_backward_pipelining_without_interleaving with multimodule support
- Uses zero_grad_buffer_for_multimodule() for gradient clearing
- Uses MimoOptimizer for coordinated gradient clipping with global norm
Reuses from existing Bridge training:
- GlobalState for timers, config, train_state
- training_log() for metrics reporting
- handle_profiling_step() and handle_profiling_stop() for profiler lifecycle
- save_checkpoint() with MimoOptimizer for checkpointing
- evaluate_and_print_results() for validation with multimodule support
- maybe_finalize_async_save() for async checkpoint finalization
- Parameters:
forward_step_func – Forward step function.
model – MimoModel instance.
optimizer – MimoOptimizer managing per-module optimizers.
schedulers – Per-module learning rate schedulers {module_name: scheduler}.
train_data_iterator – Training data iterator.
valid_data_iterator – Validation data iterator (optional).
global_state – GlobalState containing timers, config, train_state.
mimo_infra – MimoModelInfra with grids, topology, pg_collections.
multimodule_communicator – MultiModulePipelineCommunicator for P2P.
multimodule_pg_collection – Pre-built PG collection for the pipeline schedule. If None, built from mimo_infra.
module_to_grid_tuple – Pre-built (module, grid) pairs for gradient ops. If None, built from model and mimo_infra.
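The last two parameters default to being derived from the model and mimo_infra. A minimal sketch of that defaulting pattern, using dictionaries and a hypothetical helper in place of the real Megatron objects and builders:

```python
# Hypothetical helper: resolve_schedule_inputs is NOT a Megatron API; it only
# illustrates the "if None, build from model/mimo_infra" behavior documented
# for train_mimo's last two parameters.
def resolve_schedule_inputs(model, mimo_infra,
                            multimodule_pg_collection=None,
                            module_to_grid_tuple=None):
    if multimodule_pg_collection is None:
        # Stand-in for building a MultiModuleProcessGroupCollection from infra.
        multimodule_pg_collection = {"built_from": mimo_infra}
    if module_to_grid_tuple is None:
        # Pair each module with its parallelism grid for gradient ops.
        module_to_grid_tuple = list(zip(model["modules"], mimo_infra["grids"]))
    return multimodule_pg_collection, module_to_grid_tuple


model = {"modules": ["vision_encoder", "language_model"]}  # illustrative names
infra = {"grids": ["grid0", "grid1"]}
pg, pairs = resolve_schedule_inputs(model, infra)
```

Passing pre-built values skips the derivation, which is useful when the same collection is shared across calls.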