bridge.training.train#
Module Contents#
Functions#
| Function | Description |
|---|---|
| `train` | Main training loop. |
| `train_step` | Single training step. |
| `maybe_synchronize_training_step` | Synchronizes CUDA streams when the configured interval is reached. |
| `maybe_report_stragglers` | Reports straggler metrics if logging is enabled. |
| `maybe_check_weight_hash_across_dp_replicas` | Verifies weight hashes across data-parallel replicas when requested. |
| `maybe_run_manual_gc` | Runs manual garbage collection according to the configured interval. |
| `should_disable_forward_pre_hook` | Determine if forward pre-hooks should be disabled during checkpointing. |
| `enable_forward_pre_hook` | Enable forward pre-hook for all model chunks. |
| `disable_forward_pre_hook` | Disable forward pre-hook for all model chunks. |
| `force_param_sync` | Force parameter synchronization for all model chunks. |
| `get_start_time_from_progress_log` | Gets start time of earliest job with same world size. Also returns the number of floating-point operations completed in last saved checkpoint. |
| `compute_throughputs_and_append_to_progress_log` | Computes job and cumulative throughputs and appends to progress log. |
| `save_checkpoint_and_time` | Saves a checkpoint and logs the timing. |
| `checkpoint_and_decide_exit` | Handles checkpointing decisions and determines if training should exit. |
| `_should_skip_and_handle_iteration` | Check if the current iteration should be skipped and handle it if so. |
| `_dummy_train_step` | Single dummy training step to fast-forward train_data_iterator. |
| `_handle_mxfp8_param_buffer_copy` | Copy main params to param buffer for mxfp8 with grad buffer reuse. |
API#
- bridge.training.train.train(
- forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
- train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
- valid_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
- global_state: megatron.bridge.training.state.GlobalState,
- checkpointing_context: dict[str, Any],
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
- process_non_loss_data_func: Optional[Callable] = None,
- non_loss_data_func: Optional[Callable] = None,
Main training loop.
Handles the overall training process, including the iteration loop, calling train_step, evaluation, checkpointing, logging, and exit conditions.
- Parameters:
forward_step_func – Callable that executes a single forward step.
model – list of model chunks (potentially wrapped in DDP).
optimizer – The optimizer instance.
scheduler – The learning rate scheduler instance.
train_data_iterator – Iterator for the training dataset.
valid_data_iterator – Iterator for the validation dataset.
global_state – The GlobalState object holding various training states.
checkpointing_context – Context dictionary for checkpointing.
pg_collection – Process group collection.
process_non_loss_data_func – Optional function to process non-loss data during evaluation.
non_loss_data_func – Optional function to compute non-loss data during evaluation.
.. warning::
This is an experimental API and is subject to change in backwards incompatible ways without notice.
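A hedged usage sketch follows; the setup objects passed in (model, optimizer, iterators, state, and so on) are assumed to come from the surrounding setup code, and the wrapper name `run_training` is illustrative. The import path follows the fully qualified module name used in the annotations above.

```python
from megatron.bridge.training.train import train


def run_training(forward_step_func, model, optimizer, scheduler,
                 train_iter, valid_iter, global_state, ckpt_ctx, pg_collection) -> None:
    # Thin wrapper showing how the pieces built during setup are passed to `train`.
    train(
        forward_step_func=forward_step_func,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        train_data_iterator=train_iter,
        valid_data_iterator=valid_iter,
        global_state=global_state,
        checkpointing_context=ckpt_ctx,
        pg_collection=pg_collection,
    )
```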
- bridge.training.train.train_step(
- forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
- data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
- global_state: megatron.bridge.training.state.GlobalState,
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
- forward_backward_func: Callable,
Single training step.
- Parameters:
forward_step_func – Function that performs a forward step (already wrapped if needed)
data_iterator – Iterator over training data
model – list of model chunks
optimizer – Optimizer for model parameters
scheduler – Learning rate scheduler
global_state – Global training state
pg_collection – Process group collection
forward_backward_func – forward-backward function
- Returns:
loss_dict: Dictionary of reduced losses
skipped_iter: Whether the iteration was skipped (1) or not (0)
should_checkpoint: Whether a checkpoint should be saved
should_exit: Whether training should exit
exit_code: Exit code if should_exit is True
grad_norm: Gradient norm if available, None otherwise
num_zeros_in_grad: Number of zeros in gradient if available, None otherwise
- Return type:
tuple containing
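A sketch of how the returned 7-tuple is typically unpacked by the caller; the wrapper name `run_one_step` is illustrative, and the exit handling shown is one plausible way to act on `should_exit`/`exit_code`.

```python
import sys

from megatron.bridge.training.train import train_step


def run_one_step(forward_step_func, train_data_iterator, model, optimizer, scheduler,
                 global_state, pg_collection, forward_backward_func):
    # Unpack the tuple documented above; stop the job when train_step signals it.
    (loss_dict, skipped_iter, should_checkpoint, should_exit,
     exit_code, grad_norm, num_zeros_in_grad) = train_step(
        forward_step_func, train_data_iterator, model, optimizer,
        scheduler, global_state, pg_collection, forward_backward_func,
    )
    if should_exit:
        sys.exit(exit_code)
    return loss_dict, skipped_iter, should_checkpoint, grad_norm, num_zeros_in_grad
```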
- bridge.training.train.maybe_synchronize_training_step(
- train_sync_interval: Optional[int],
- iteration: int,
Synchronizes CUDA streams when the configured interval is reached.
- Parameters:
train_sync_interval – Number of iterations between synchronizations; None disables it.
iteration – Zero-based training iteration counter.
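The interval gating amounts to a periodic `torch.cuda.synchronize()`. A minimal sketch of the same idea, not the exact implementation (the precise off-by-one handling may differ):

```python
from typing import Optional

import torch


def sync_every_n_steps(train_sync_interval: Optional[int], iteration: int) -> None:
    # Synchronize CPU and GPU streams every `train_sync_interval` iterations;
    # None disables periodic synchronization.
    if train_sync_interval is not None and (iteration + 1) % train_sync_interval == 0:
        torch.cuda.synchronize()
```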
- bridge.training.train.maybe_report_stragglers(
- log_interval: int,
- log_straggler: bool,
- straggler_timer: Any,
- iteration: int,
- num_floating_point_operations_since_last_log_event: float,
Reports straggler metrics if logging is enabled.
- Parameters:
log_interval – Iteration interval for logging.
log_straggler – Whether straggler logging is enabled.
straggler_timer – Timer utility used to record straggler metrics.
iteration – Zero-based training iteration counter.
num_floating_point_operations_since_last_log_event – FLOPs accumulated since the last logging event.
- Returns:
Updated FLOP counter, reset to 0.0 when a report is emitted; otherwise the original value.
- Return type:
float
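The return value is meant to be assigned back to the caller's FLOP accumulator so the counter resets whenever a report is emitted. A sketch of that caller-side pattern; `num_iterations` and `flops_per_step` are illustrative stand-ins for values tracked by the real loop.

```python
from megatron.bridge.training.train import maybe_report_stragglers


def flop_accounting_loop(num_iterations, flops_per_step, log_interval,
                         log_straggler, straggler_timer):
    # Accumulate FLOPs each step; the helper returns 0.0 after emitting a report.
    flops_since_last_log = 0.0
    for iteration in range(num_iterations):
        flops_since_last_log += flops_per_step
        flops_since_last_log = maybe_report_stragglers(
            log_interval=log_interval,
            log_straggler=log_straggler,
            straggler_timer=straggler_timer,
            iteration=iteration,
            num_floating_point_operations_since_last_log_event=flops_since_last_log,
        )
```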
- bridge.training.train.maybe_check_weight_hash_across_dp_replicas(
- model: list[megatron.core.transformer.MegatronModule],
- check_weight_hash_across_dp_replicas_interval: Optional[int],
- iteration: int,
- should_toggle_forward_pre_hook: bool,
Verifies weight hashes across data-parallel replicas when requested.
- Parameters:
model – List of model chunks to validate.
check_weight_hash_across_dp_replicas_interval – Interval at which to verify; None to skip.
iteration – Zero-based training iteration counter.
should_toggle_forward_pre_hook – Whether the pre-hook must be disabled during the check.
- bridge.training.train.maybe_run_manual_gc(
- manual_gc_enabled: bool,
- manual_gc_interval: int,
- iteration: int,
Runs manual garbage collection according to the configured interval.
- Parameters:
manual_gc_enabled – Whether manual garbage collection is enabled.
manual_gc_interval – Number of iterations between collections; 0 disables periodic runs.
iteration – Zero-based training iteration counter.
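The behaviour described above reduces to a periodic `gc.collect()`. A minimal sketch, assuming the interval check is simple modulo arithmetic (the real helper's off-by-one handling may differ):

```python
import gc


def run_manual_gc(manual_gc_enabled: bool, manual_gc_interval: int, iteration: int) -> None:
    # Collect only when manual GC is enabled and the configured interval has elapsed;
    # an interval of 0 disables the periodic runs.
    if manual_gc_enabled and manual_gc_interval > 0 and (iteration + 1) % manual_gc_interval == 0:
        gc.collect()
```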
- bridge.training.train.should_disable_forward_pre_hook(
- use_megatron_fsdp: bool,
- use_distributed_optimizer: bool,
- overlap_param_gather: bool,
Determine if forward pre-hooks should be disabled during checkpointing.
Forward pre-hooks need to be disabled during checkpoint saving when using the distributed optimizer with overlapped parameter gathering.
- Parameters:
use_megatron_fsdp – Whether Megatron FSDP is enabled.
use_distributed_optimizer – Whether distributed optimizer is enabled.
overlap_param_gather – Whether parameter gathering is overlapped.
- Returns:
True if forward pre-hooks should be disabled, False otherwise.
.. note::
This is needed to prevent autograd issues during checkpoint saving when using distributed optimizer with parameter gathering overlap.
- bridge.training.train.enable_forward_pre_hook(
- model: list[megatron.core.distributed.DistributedDataParallel],
Enable forward pre-hook for all model chunks.
- Parameters:
model – list of model chunks wrapped in DDP
- bridge.training.train.disable_forward_pre_hook(
- model: list[megatron.core.distributed.DistributedDataParallel],
- param_sync: bool = True,
Disable forward pre-hook for all model chunks.
- Parameters:
model – list of model chunks wrapped in DDP
param_sync – Whether to synchronize parameters across model chunks
- bridge.training.train.force_param_sync(
- model: list[megatron.core.distributed.DistributedDataParallel],
Force parameter synchronization for all model chunks.
- Parameters:
model – list of model chunks wrapped in DDP.
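Taken together, these helpers typically bracket a checkpoint save. A hedged sketch of that pattern; `ddp_cfg` and its attribute names are illustrative stand-ins for wherever the three flags live in the configuration, and argument plumbing is simplified.

```python
from megatron.bridge.training.train import (
    disable_forward_pre_hook,
    enable_forward_pre_hook,
    save_checkpoint_and_time,
    should_disable_forward_pre_hook,
)


def save_with_pre_hook_toggled(state, model, optimizer, opt_param_scheduler,
                               flops_so_far, ckpt_ctx, ddp_cfg):
    # Decide whether the forward pre-hook must be turned off around the save.
    toggle = should_disable_forward_pre_hook(
        use_megatron_fsdp=ddp_cfg.use_megatron_fsdp,
        use_distributed_optimizer=ddp_cfg.use_distributed_optimizer,
        overlap_param_gather=ddp_cfg.overlap_param_gather,
    )
    if toggle:
        disable_forward_pre_hook(model, param_sync=True)  # also syncs params across chunks
    save_checkpoint_and_time(state, model, optimizer, opt_param_scheduler,
                             flops_so_far, ckpt_ctx)
    if toggle:
        enable_forward_pre_hook(model)
```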
- bridge.training.train.get_start_time_from_progress_log(
- cfg: megatron.bridge.training.config.ConfigContainer,
Gets start time of earliest job with same world size. Also returns the number of floating-point operations completed in last saved checkpoint.
- bridge.training.train.compute_throughputs_and_append_to_progress_log(
- state: megatron.bridge.training.state.GlobalState,
- num_floating_point_operations_so_far: float,
Computes job and cumulative throughputs and appends to progress log.
Calculates Model TFLOP/s/GPU based on floating-point operations and elapsed time. Appends the computed throughputs, total FLOPs, and processed tokens to the progress log file.
- Parameters:
state – The GlobalState object.
num_floating_point_operations_so_far – Total floating-point operations completed.
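The per-GPU throughput is the usual FLOPs-over-time ratio. A sketch of the arithmetic only, under the assumption that the caller tracks the job's starting FLOP count, elapsed wall-clock time, and world size; the exact bookkeeping in the real function may differ.

```python
def job_tflops_per_gpu(flops_now: float, flops_at_job_start: float,
                       job_elapsed_seconds: float, world_size: int) -> float:
    # Floating-point operations completed during this job, divided by elapsed
    # wall-clock time, the number of GPUs, and 10^12.
    return (flops_now - flops_at_job_start) / (job_elapsed_seconds * world_size * 1e12)
```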
- bridge.training.train.save_checkpoint_and_time(
- state: megatron.bridge.training.state.GlobalState,
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- opt_param_scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
- num_floating_point_operations_so_far: float,
- checkpointing_context: dict[str, Any],
- non_persistent_ckpt: bool = False,
- train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]] = None,
Saves a checkpoint and logs the timing.
Wraps the save_checkpoint function with timers and forces parameter synchronization when using distributed optimizer with overlapped parameter gather to ensure checkpoint correctness.
- Parameters:
state – The global state object.
model – list of model chunks (MegatronModule instances).
optimizer – The optimizer instance.
opt_param_scheduler – The optimizer parameter scheduler instance.
num_floating_point_operations_so_far – Cumulative number of floating-point operations completed up to this point.
checkpointing_context – Dictionary holding checkpointing-related state.
non_persistent_ckpt – Flag indicating if this is a non-persistent (local) checkpoint. Defaults to False.
train_data_iterator – Optional training data iterator to save its state.
- bridge.training.train.checkpoint_and_decide_exit(
- state: megatron.bridge.training.state.GlobalState,
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- opt_param_scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
- num_floating_point_operations_so_far: float,
- checkpointing_context: dict[str, Any],
- train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
Handles checkpointing decisions and determines if training should exit.
Checks various conditions for saving a checkpoint (signal received, interval, duration) and determines if the training loop should terminate based on exit conditions (signal, duration, iteration interval).
- Parameters:
state – The global state object.
model – list of model chunks (MegatronModule instances).
optimizer – The optimizer instance.
opt_param_scheduler – The optimizer parameter scheduler instance.
num_floating_point_operations_so_far – Cumulative number of floating-point operations completed up to this point.
checkpointing_context – Dictionary holding checkpointing-related state.
train_data_iterator – Optional training data iterator to save its state.
- Returns:
True if the training loop should exit, False otherwise.
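In the training loop, the boolean result drives loop termination. A caller-side sketch; `train_iters` and the wrapper name are illustrative.

```python
from megatron.bridge.training.train import checkpoint_and_decide_exit


def loop_with_exit_checks(train_iters, state, model, optimizer, opt_param_scheduler,
                          flops_so_far, ckpt_ctx, train_data_iterator):
    for _ in range(train_iters):
        ...  # train_step, logging, evaluation, etc.
        if checkpoint_and_decide_exit(state, model, optimizer, opt_param_scheduler,
                                      flops_so_far, ckpt_ctx, train_data_iterator):
            break  # an exit condition was met; any required checkpoint was already handled
```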
- bridge.training.train._finish_train(
- global_state: megatron.bridge.training.state.GlobalState,
- bridge.training.train._should_skip_and_handle_iteration(
- global_state: megatron.bridge.training.state.GlobalState,
- train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
Check if the current iteration should be skipped and handle it if so.
This function checks if the current training step is in the iterations_to_skip list, and if so, performs a dummy training step to consume data and update counters.
- Parameters:
global_state – Global state containing training state and configuration
train_data_iterator – Iterator over training data
pg_collection – Process group collection
- Returns:
True if the iteration was skipped, False otherwise
- Return type:
bool
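A caller-side sketch of how the skip check short-circuits the loop body; the loop bounds and wrapper name are illustrative, and this is a private helper, so calling it directly is only for illustration.

```python
from megatron.bridge.training.train import _should_skip_and_handle_iteration


def loop_with_skips(train_iters, global_state, train_data_iterator, pg_collection):
    for _ in range(train_iters):
        if _should_skip_and_handle_iteration(global_state, train_data_iterator, pg_collection):
            continue  # the dummy step already consumed data and updated counters
        ...  # regular train_step path
```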
- bridge.training.train._dummy_train_step(
- global_state: megatron.bridge.training.state.GlobalState,
- train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
- pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
Single dummy training step to fast-forward train_data_iterator.
This function consumes data from the iterator without performing any actual computation, effectively skipping the iteration while maintaining data iterator consistency.
The data iterator is advanced only on the first and last pipeline-parallel stages, and only when it is not None.
- Parameters:
global_state – Global state containing configuration
train_data_iterator – Iterator over training data
pg_collection – Process group collection
- bridge.training.train._handle_mxfp8_param_buffer_copy(
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- reuse_grad_buf_for_mxfp8_param_ag: bool,
- overlap_param_gather: bool,
Copy main params to param buffer for mxfp8 with grad buffer reuse.
For mxfp8_param with reuse_grad_buf_for_mxfp8_param_ag and dp_ag_overlap, _copy_main_params_to_param_buffer() must be called after the grad buffer is zeroed because the param and grad buffers are shared.
- Parameters:
optimizer – The MegatronOptimizer instance
reuse_grad_buf_for_mxfp8_param_ag – Config flag for grad buffer reuse
overlap_param_gather – Config flag for overlapping param gathering
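The guard described above reduces to a two-flag check before invoking the optimizer hook. A minimal sketch, assuming the check is a simple conjunction of the two config flags:

```python
def handle_mxfp8_param_buffer_copy(optimizer, reuse_grad_buf_for_mxfp8_param_ag: bool,
                                   overlap_param_gather: bool) -> None:
    # When the mxfp8 param all-gather reuses the grad buffer and param gather is
    # overlapped, the shared buffer has just been zeroed, so the main params must
    # be copied back into it before the next all-gather.
    if reuse_grad_buf_for_mxfp8_param_ag and overlap_param_gather:
        optimizer._copy_main_params_to_param_buffer()
```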