bridge.training.train#
Module Contents#
Functions#
| Function | Description |
|---|---|
| train | Main training loop. |
| train_step | Single training step. |
| post_training_step_callbacks | Run all post-training-step functions (e.g., FT heartbeats, GC). |
| should_disable_forward_pre_hook | Determine if forward pre-hooks should be disabled during checkpointing. |
| enable_forward_pre_hook | Enable forward pre-hook for all model chunks. |
| disable_forward_pre_hook | Disable forward pre-hook for all model chunks. |
| get_start_time_from_progress_log | Gets start time of earliest job with same world size. Also returns the number of floating-point operations completed in the last saved checkpoint. |
| compute_throughputs_and_append_to_progress_log | Computes job and cumulative throughputs and appends to progress log. |
| save_checkpoint_and_time | Saves a checkpoint and logs the timing. |
| checkpoint_and_decide_exit | Handles checkpointing decisions and determines if training should exit. |
| _should_skip_and_handle_iteration | Check if the current iteration should be skipped and handle it if so. |
| _dummy_train_step | Single dummy training step to fast-forward train_data_iterator. |
| _handle_mxfp8_param_buffer_copy | Copy main params to param buffer for mxfp8 with grad buffer reuse. |
API#
- bridge.training.train.train(
- forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
- train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
- valid_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
- global_state: megatron.bridge.training.state.GlobalState,
- checkpointing_context: dict[str, Any],
- process_non_loss_data_func: Optional[Callable] = None,
- non_loss_data_func: Optional[Callable] = None,
- )
Main training loop.
Handles the overall training process, including the iteration loop, calling train_step, evaluation, checkpointing, logging, and exit conditions.
- Parameters:
forward_step_func – Callable that executes a single forward step.
model – list of model chunks (potentially wrapped in DDP).
optimizer – The optimizer instance.
scheduler – The learning rate scheduler instance.
train_data_iterator – Iterator for the training dataset.
valid_data_iterator – Iterator for the validation dataset.
global_state – The GlobalState object holding various training states.
checkpointing_context – Context dictionary for checkpointing.
process_non_loss_data_func – Optional function to process non-loss data during evaluation.
non_loss_data_func – Optional function to compute non-loss data during evaluation.
.. warning::
This is an experimental API and is subject to change in backwards incompatible ways without notice.
- bridge.training.train.train_step(
- forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
- data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
- global_state: megatron.bridge.training.state.GlobalState,
- )
Single training step.
- Parameters:
forward_step_func – Function that performs a forward step (already wrapped if needed)
data_iterator – Iterator over training data
model – list of model chunks
optimizer – Optimizer for model parameters
scheduler – Learning rate scheduler
global_state – Global training state
- Returns:
loss_dict: Dictionary of reduced losses
skipped_iter: Whether the iteration was skipped (1) or not (0)
should_checkpoint: Whether a checkpoint should be saved
should_exit: Whether training should exit
exit_code: Exit code if should_exit is True
grad_norm: Gradient norm if available, None otherwise
num_zeros_in_grad: Number of zeros in gradient if available, None otherwise
- Return type:
tuple
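A minimal inner-loop sketch follows (hypothetical driver code, not from the library), assuming the returned tuple is ordered exactly as listed above:

```python
from megatron.bridge.training.train import train_step


def run_one_iteration(forward_step_func, data_iterator, model, optimizer,
                      scheduler, global_state):
    """Hypothetical wrapper: run one step and react to its outputs."""
    (loss_dict, skipped_iter, should_checkpoint, should_exit,
     exit_code, grad_norm, num_zeros_in_grad) = train_step(
        forward_step_func, data_iterator, model, optimizer, scheduler, global_state
    )
    if should_exit:
        return exit_code  # caller is expected to stop the loop
    return loss_dict, skipped_iter, should_checkpoint, grad_norm, num_zeros_in_grad
```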
- bridge.training.train.post_training_step_callbacks(
- model: list[megatron.core.transformer.MegatronModule],
- num_floating_point_operations_since_last_log_event: float,
- straggler_timer: Any,
- iteration: int,
- prof: Optional[torch.profiler.profile],
- config: megatron.bridge.training.config.ConfigContainer,
- should_toggle_forward_pre_hook: bool,
- nsys_nvtx_context: Optional[megatron.bridge.training.profiling.TNvtxContext] = None,
- )
Run all post-training-step functions (e.g., FT heartbeats, GC).
- Parameters:
model – list of model chunks wrapped in DDP
num_floating_point_operations_since_last_log_event – Number of floating point operations since last log
straggler_timer – Timer for straggler detection
iteration – Current training iteration
prof – PyTorch profiler instance
config – Configuration container
should_toggle_forward_pre_hook – Whether to toggle forward pre-hook
nsys_nvtx_context – NVTX context for nsys profiling (if active)
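As an illustration only (the placement and argument sources are assumptions, not the library's actual call site), these callbacks would typically run once per iteration, right after the optimizer step:

```python
from megatron.bridge.training.train import post_training_step_callbacks


def after_optimizer_step(model_chunks, flops_since_last_log, straggler_timer,
                         iteration, prof, config, toggle_pre_hook):
    """Hypothetical call site for the per-iteration callbacks."""
    post_training_step_callbacks(
        model_chunks,
        num_floating_point_operations_since_last_log_event=flops_since_last_log,
        straggler_timer=straggler_timer,
        iteration=iteration,
        prof=prof,                                   # may be None when not profiling
        config=config,
        should_toggle_forward_pre_hook=toggle_pre_hook,
    )
```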
- bridge.training.train.should_disable_forward_pre_hook(
- use_megatron_fsdp: bool,
- use_distributed_optimizer: bool,
- overlap_param_gather: bool,
- )
Determine if forward pre-hooks should be disabled during checkpointing.
Forward pre-hooks need to be disabled during checkpoint saving when using the distributed optimizer with overlapped parameter gathering.
- Parameters:
use_megatron_fsdp – Whether Megatron FSDP is enabled.
use_distributed_optimizer – Whether distributed optimizer is enabled.
overlap_param_gather – Whether parameter gathering is overlapped.
- Returns:
True if forward pre-hooks should be disabled, False otherwise.
.. note::
This is needed to prevent autograd issues during checkpoint saving when using distributed optimizer with parameter gathering overlap.
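Restated as a standalone sketch, the decision rule described above looks roughly like this (it mirrors the prose, not necessarily the exact implementation):

```python
def needs_pre_hook_disable(use_megatron_fsdp: bool,
                           use_distributed_optimizer: bool,
                           overlap_param_gather: bool) -> bool:
    # Pre-hooks only interfere with checkpoint saving when the distributed
    # optimizer overlaps the parameter all-gather with the forward pass;
    # Megatron FSDP manages its own parameter gathering.
    return (not use_megatron_fsdp) and use_distributed_optimizer and overlap_param_gather
```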
- bridge.training.train.enable_forward_pre_hook(
- model: list[megatron.core.distributed.DistributedDataParallel],
- )
Enable forward pre-hook for all model chunks.
- Parameters:
model – list of model chunks wrapped in DDP
- bridge.training.train.disable_forward_pre_hook(
- model: list[megatron.core.distributed.DistributedDataParallel],
- param_sync: bool = True,
- )
Disable forward pre-hook for all model chunks.
- Parameters:
model – list of model chunks wrapped in DDP
param_sync – Whether to synchronize parameters across model chunks
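A hypothetical usage pattern pairing the two helpers around a checkpoint save (save_fn is a placeholder for whatever actually writes the checkpoint):

```python
from megatron.bridge.training.train import (
    disable_forward_pre_hook,
    enable_forward_pre_hook,
)


def save_with_pre_hooks_disabled(model_chunks, save_fn):
    """Hypothetical wrapper: keep pre-hooks off only for the duration of the save."""
    disable_forward_pre_hook(model_chunks, param_sync=True)
    try:
        save_fn()
    finally:
        enable_forward_pre_hook(model_chunks)
```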
- bridge.training.train.get_start_time_from_progress_log(
- cfg: megatron.bridge.training.config.ConfigContainer,
- )
Gets the start time of the earliest job with the same world size. Also returns the number of floating-point operations completed as of the last saved checkpoint.
- bridge.training.train.compute_throughputs_and_append_to_progress_log(
- state: megatron.bridge.training.state.GlobalState,
- num_floating_point_operations_so_far: float,
- )
Computes job and cumulative throughputs and appends to progress log.
Calculates TFLOP/s/GPU based on floating-point operations and elapsed time. Appends the computed throughputs, total FLOPs, and processed tokens to the progress log file.
- Parameters:
state – The GlobalState object.
num_floating_point_operations_so_far – Total floating-point operations completed.
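The per-GPU throughput described above boils down to a formula along these lines (a sketch only; the logged value may differ in details such as exactly which FLOP counter and timestamps are used):

```python
def tflops_per_gpu(num_floating_point_operations: float,
                   elapsed_seconds: float,
                   world_size: int) -> float:
    # Floating-point operations per second per GPU, expressed in TFLOP/s.
    return num_floating_point_operations / (elapsed_seconds * world_size * 1e12)
```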
- bridge.training.train.save_checkpoint_and_time(
- state: megatron.bridge.training.state.GlobalState,
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- opt_param_scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
- num_floating_point_operations_so_far: float,
- checkpointing_context: dict[str, Any],
- non_persistent_ckpt: bool = False,
- train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]] = None,
- )
Saves a checkpoint and logs the timing.
Wraps the save_checkpoint function with timers and temporarily disables/re-enables forward pre-hooks if the distributed optimizer with overlapped parameter gather is used.
- Parameters:
state – The global state object.
model – list of model chunks (MegatronModule instances).
optimizer – The optimizer instance.
opt_param_scheduler – The optimizer parameter scheduler instance.
num_floating_point_operations_so_far – Cumulative number of floating-point operations up to this point.
checkpointing_context – Dictionary holding checkpointing-related state.
non_persistent_ckpt – Flag indicating if this is a non-persistent (local) checkpoint. Defaults to False.
train_data_iterator – Optional training data iterator to save its state.
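A hypothetical periodic-save call is sketched below; the interval check is illustrative, since the real loop derives its save cadence from the checkpoint configuration:

```python
from megatron.bridge.training.train import save_checkpoint_and_time


def maybe_save(state, model_chunks, optimizer, opt_param_scheduler,
               num_flops_so_far, checkpointing_context, iteration, save_interval):
    """Hypothetical helper: save every `save_interval` iterations."""
    if save_interval and iteration % save_interval == 0:
        save_checkpoint_and_time(
            state, model_chunks, optimizer, opt_param_scheduler,
            num_flops_so_far, checkpointing_context,
        )
```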
- bridge.training.train.checkpoint_and_decide_exit(
- state: megatron.bridge.training.state.GlobalState,
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- opt_param_scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
- num_floating_point_operations_so_far: float,
- checkpointing_context: dict[str, Any],
- train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
- )
Handles checkpointing decisions and determines if training should exit.
Checks various conditions for saving a checkpoint (signal received, interval, duration) and determines if the training loop should terminate based on exit conditions (signal, duration, iteration interval).
- Parameters:
state – The global state object.
model – list of model chunks (MegatronModule instances).
optimizer – The optimizer instance.
opt_param_scheduler – The optimizer parameter scheduler instance.
num_floating_point_operations_so_far – Cumulative number of floating-point operations up to this point.
checkpointing_context – Dictionary holding checkpointing-related state.
train_data_iterator – Optional training data iterator to save its state.
- Returns:
True if the training loop should exit, False otherwise.
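In driver code, the boolean return would typically gate the loop, as in this hypothetical sketch (step_fn is a placeholder for one training step plus its FLOP count):

```python
from megatron.bridge.training.train import checkpoint_and_decide_exit


def training_loop(step_fn, state, model_chunks, optimizer, opt_param_scheduler,
                  checkpointing_context, train_data_iterator, max_iters):
    """Hypothetical loop tail: stop iterating as soon as the helper says so."""
    num_flops_so_far = 0.0
    for _ in range(max_iters):
        num_flops_so_far += step_fn()  # placeholder: run one step, return its FLOPs
        if checkpoint_and_decide_exit(
            state, model_chunks, optimizer, opt_param_scheduler,
            num_flops_so_far, checkpointing_context, train_data_iterator,
        ):
            break
```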
- bridge.training.train._finish_train(
- global_state: megatron.bridge.training.state.GlobalState,
- )
- bridge.training.train._should_skip_and_handle_iteration(
- global_state: megatron.bridge.training.state.GlobalState,
- train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
- )
Check if the current iteration should be skipped and handle it if so.
This function checks if the current training step is in the iterations_to_skip list, and if so, performs a dummy training step to consume data and update counters.
- Parameters:
global_state – Global state containing training state and configuration
train_data_iterator – Iterator over training data
- Returns:
True if the iteration was skipped, False otherwise
- Return type:
bool
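A sketch of how the skip check would sit at the top of an iteration (the helper is private; this is only to illustrate the control flow, and do_real_step is a placeholder):

```python
from megatron.bridge.training.train import _should_skip_and_handle_iteration


def step_or_skip(global_state, train_data_iterator, do_real_step):
    """Hypothetical: a skip consumes data and advances counters without compute."""
    if _should_skip_and_handle_iteration(global_state, train_data_iterator):
        return None
    return do_real_step()
```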
- bridge.training.train._dummy_train_step(
- global_state: megatron.bridge.training.state.GlobalState,
- train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
- )
Single dummy training step to fast-forward train_data_iterator.
This function consumes data from the iterator without performing any actual computation, effectively skipping the iteration while maintaining data iterator consistency.
Advances the data iterator on the first and last pipeline-parallel (PP) stages when data_iterator is not None.
- Parameters:
global_state – Global state containing configuration
train_data_iterator – Iterator over training data
- bridge.training.train._handle_mxfp8_param_buffer_copy(
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- reuse_grad_buf_for_mxfp8_param_ag: bool,
- overlap_param_gather: bool,
- )
Copy main params to param buffer for mxfp8 with grad buffer reuse.
For mxfp8_param with reuse_grad_buf_for_mxfp8_param_ag and dp_ag_overlap, _copy_main_params_to_param_buffer() must be called after the grad buffer is zeroed, because the param and grad buffers share storage.
- Parameters:
optimizer – The MegatronOptimizer instance
reuse_grad_buf_for_mxfp8_param_ag – Config flag for grad buffer reuse
overlap_param_gather – Config flag for overlapping param gathering
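An ordering sketch for the mxfp8 copy (illustrative only; the flag values would come from the training configuration, and the helper name below is a placeholder):

```python
from megatron.bridge.training.train import _handle_mxfp8_param_buffer_copy


def after_grad_buffer_zero(optimizer, reuse_grad_buf_for_mxfp8_param_ag: bool,
                           overlap_param_gather: bool) -> None:
    """Hypothetical call site: must run only after the grad buffer has been zeroed,
    because the param and grad buffers share storage in this mode."""
    _handle_mxfp8_param_buffer_copy(
        optimizer,
        reuse_grad_buf_for_mxfp8_param_ag=reuse_grad_buf_for_mxfp8_param_ag,
        overlap_param_gather=overlap_param_gather,
    )
```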