bridge.training.train#
Module Contents#
Functions#
| Function | Description |
|---|---|
| `train` | Main training loop. |
| `train_step` | Single training step. |
| `post_training_step_callbacks` | Run all post-training-step functions (e.g., FT heartbeats, GC). |
| `should_disable_forward_pre_hook` | Determine if forward pre-hooks should be disabled during checkpointing. |
| `enable_forward_pre_hook` | Enable forward pre-hook for all model chunks. |
| `disable_forward_pre_hook` | Disable forward pre-hook for all model chunks. |
| `get_start_time_from_progress_log` | Gets start time of earliest job with same world size. Also returns the number of floating-point operations completed in last saved checkpoint. |
| `compute_throughputs_and_append_to_progress_log` | Computes job and cumulative throughputs and appends to progress log. |
| `save_checkpoint_and_time` | Saves a checkpoint and logs the timing. |
| `checkpoint_and_decide_exit` | Handles checkpointing decisions and determines if training should exit. |
| `_handle_mxfp8_param_buffer_copy` | Copy main params to param buffer for mxfp8 with grad buffer reuse. |
API#
- bridge.training.train.train(
- forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
- train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
- valid_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
- global_state: megatron.bridge.training.state.GlobalState,
- checkpointing_context: dict[str, Any],
- process_non_loss_data_func: Optional[Callable] = None,
- non_loss_data_func: Optional[Callable] = None,
Main training loop.
Handles the overall training process, including the iteration loop, calling train_step, evaluation, checkpointing, logging, and exit conditions.
- Parameters:
forward_step_func – Callable that executes a single forward step.
model – list of model chunks (potentially wrapped in DDP).
optimizer – The optimizer instance.
scheduler – The learning rate scheduler instance.
train_data_iterator – Iterator for the training dataset.
valid_data_iterator – Iterator for the validation dataset.
global_state – The GlobalState object holding various training states.
checkpointing_context – Context dictionary for checkpointing.
process_non_loss_data_func – Optional function to process non-loss data during evaluation.
non_loss_data_func – Optional function to compute non-loss data during evaluation.
Warning: This is an experimental API and is subject to change in backwards-incompatible ways without notice.
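A minimal usage sketch follows. It is illustrative only: `my_forward_step`, `model`, `optimizer`, `scheduler`, the data iterators, and `global_state` are placeholders assumed to come from your own initialization code, and the import path assumes the module is importable as `megatron.bridge.training.train` (as the type references above suggest).

```python
from megatron.bridge.training.train import train

# Illustrative call; all objects below are assumed to exist already.
train(
    forward_step_func=my_forward_step,        # your forward-step callable
    model=model,                              # list of (DDP-wrapped) model chunks
    optimizer=optimizer,
    scheduler=scheduler,
    train_data_iterator=train_data_iterator,
    valid_data_iterator=valid_data_iterator,
    global_state=global_state,
    checkpointing_context={},                 # dict reused across checkpoint calls
)
```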
- bridge.training.train.train_step(
- forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
- data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
- global_state: megatron.bridge.training.state.GlobalState,
Single training step.
- Parameters:
forward_step_func – Function that performs a forward step (already wrapped if needed)
data_iterator – Iterator over training data
model – list of model chunks
optimizer – Optimizer for model parameters
scheduler – Learning rate scheduler
global_state – Global training state
- Returns:
loss_dict: Dictionary of reduced losses
skipped_iter: Whether the iteration was skipped (1) or not (0)
should_checkpoint: Whether a checkpoint should be saved
should_exit: Whether training should exit
exit_code: Exit code if should_exit is True
grad_norm: Gradient norm if available, None otherwise
num_zeros_in_grad: Number of zeros in gradient if available, None otherwise
- Return type:
tuple containing
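A sketch of consuming the returned tuple in a custom loop, assuming the elements come back in the documented order; the surrounding objects (`my_forward_step`, `train_data_iterator`, `model`, `optimizer`, `scheduler`, `global_state`) are placeholders.

```python
from megatron.bridge.training.train import train_step

(
    loss_dict,
    skipped_iter,
    should_checkpoint,
    should_exit,
    exit_code,
    grad_norm,
    num_zeros_in_grad,
) = train_step(
    forward_step_func=my_forward_step,
    data_iterator=train_data_iterator,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    global_state=global_state,
)

if should_exit:
    # exit_code is only meaningful when should_exit is True
    raise SystemExit(exit_code)
```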
- bridge.training.train.post_training_step_callbacks(
- model: list[megatron.core.transformer.MegatronModule],
- num_floating_point_operations_since_last_log_event: float,
- straggler_timer: Any,
- iteration: int,
- prof: Optional[torch.profiler.profile],
- config: megatron.bridge.training.config.ConfigContainer,
- should_toggle_forward_pre_hook: bool,
- nsys_nvtx_context: Optional[megatron.bridge.training.profiling.TNvtxContext] = None,
Run all post-training-step functions (e.g., FT heartbeats, GC).
- Parameters:
model – list of model chunks wrapped in DDP
num_floating_point_operations_since_last_log_event – Number of floating point operations since last log
straggler_timer – Timer for straggler detection
iteration – Current training iteration
prof – PyTorch profiler instance
config – Configuration container
should_toggle_forward_pre_hook – Whether to toggle forward pre-hook
nsys_nvtx_context – NVTX context for nsys profiling (if active)
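An illustrative placement immediately after a training step; `flops_since_last_log`, `straggler_timer`, `iteration`, `config`, and `toggle_pre_hook` are placeholders from the surrounding loop, and `prof`/`nsys_nvtx_context` may be None when the corresponding profilers are not active.

```python
from megatron.bridge.training.train import post_training_step_callbacks

post_training_step_callbacks(
    model=model,
    num_floating_point_operations_since_last_log_event=flops_since_last_log,
    straggler_timer=straggler_timer,
    iteration=iteration,
    prof=None,                            # or a torch.profiler.profile instance
    config=config,
    should_toggle_forward_pre_hook=toggle_pre_hook,
    nsys_nvtx_context=None,               # only set when nsys profiling is active
)
```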
- bridge.training.train.should_disable_forward_pre_hook(
- use_megatron_fsdp: bool,
- use_distributed_optimizer: bool,
- overlap_param_gather: bool,
Determine if forward pre-hooks should be disabled during checkpointing.
Forward pre-hooks need to be disabled during checkpoint saving when using the distributed optimizer with overlapped parameter gathering.
- Parameters:
use_megatron_fsdp – Whether Megatron FSDP is enabled.
use_distributed_optimizer – Whether distributed optimizer is enabled.
overlap_param_gather – Whether parameter gathering is overlapped.
- Returns:
True if forward pre-hooks should be disabled, False otherwise.
Note: This is needed to prevent autograd issues during checkpoint saving when using distributed optimizer with parameter gathering overlap.
- bridge.training.train.enable_forward_pre_hook(
- model: list[megatron.core.distributed.DistributedDataParallel],
Enable forward pre-hook for all model chunks.
- Parameters:
model – list of model chunks wrapped in DDP
- bridge.training.train.disable_forward_pre_hook(
- model: list[megatron.core.distributed.DistributedDataParallel],
- param_sync: bool = True,
Disable forward pre-hook for all model chunks.
- Parameters:
model – list of model chunks wrapped in DDP
param_sync – Whether to synchronize parameters across model chunks
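The three pre-hook helpers above are typically used together around a checkpoint save; below is a sketch of that pattern, with illustrative flag values that would normally come from the training config.

```python
from megatron.bridge.training.train import (
    disable_forward_pre_hook,
    enable_forward_pre_hook,
    should_disable_forward_pre_hook,
)

# Illustrative flag values; normally read from the config.
toggle = should_disable_forward_pre_hook(
    use_megatron_fsdp=False,
    use_distributed_optimizer=True,
    overlap_param_gather=True,
)

if toggle:
    disable_forward_pre_hook(model)   # param_sync defaults to True
# ... save the checkpoint here ...
if toggle:
    enable_forward_pre_hook(model)
```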
- bridge.training.train.get_start_time_from_progress_log(
- cfg: megatron.bridge.training.config.ConfigContainer,
Gets start time of earliest job with same world size. Also returns the number of floating-point operations completed in last saved checkpoint.
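A small usage sketch, assuming (as the summary suggests, but not confirmed here) that the function returns a pair of the start time and the floating-point operations recorded at the last saved checkpoint; `cfg` is a placeholder ConfigContainer.

```python
from megatron.bridge.training.train import get_start_time_from_progress_log

# Assumed 2-tuple: (start time, FLOPs at the last saved checkpoint).
start_time, flops_at_last_checkpoint = get_start_time_from_progress_log(cfg)
```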
- bridge.training.train.compute_throughputs_and_append_to_progress_log(
- state: megatron.bridge.training.state.GlobalState,
- num_floating_point_operations_so_far: float,
Computes job and cumulative throughputs and appends to progress log.
Calculates TFLOP/s/GPU based on floating-point operations and elapsed time. Appends the computed throughputs, total FLOPs, and processed tokens to the progress log file.
- Parameters:
state – The GlobalState object.
num_floating_point_operations_so_far – Total floating-point operations completed.
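The per-GPU throughput it logs is of the following general form; this is a simplified sketch of the described calculation, not the exact implementation, which derives its inputs from `state` and the progress log.

```python
def tflops_per_sec_per_gpu(
    num_floating_point_operations: float,
    elapsed_time_seconds: float,
    world_size: int,
) -> float:
    """Simplified sketch of a TFLOP/s/GPU calculation as described above."""
    return num_floating_point_operations / (elapsed_time_seconds * world_size * 1e12)
```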
- bridge.training.train.save_checkpoint_and_time(
- state: megatron.bridge.training.state.GlobalState,
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- opt_param_scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
- num_floating_point_operations_so_far: float,
- checkpointing_context: dict[str, Any],
- non_persistent_ckpt: bool = False,
- train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]] = None,
Saves a checkpoint and logs the timing.
Wraps the `save_checkpoint` function with timers and, when the distributed optimizer with overlapped parameter gather is used, temporarily disables and re-enables forward pre-hooks around the save.
- Parameters:
state – The global state object.
model – list of model chunks (MegatronModule instances).
optimizer – The optimizer instance.
opt_param_scheduler – The optimizer parameter scheduler instance.
num_floating_point_operations_so_far – Cumulative number of floating-point operations completed up to this point.
checkpointing_context – Dictionary holding checkpointing-related state.
non_persistent_ckpt – Flag indicating if this is a non-persistent (local) checkpoint. Defaults to False.
train_data_iterator – Optional training data iterator to save its state.
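An illustrative call for a non-persistent (local) checkpoint; `global_state`, `model`, `optimizer`, `scheduler`, `total_flops`, `checkpointing_context`, and `train_data_iterator` are placeholders from the surrounding training loop.

```python
from megatron.bridge.training.train import save_checkpoint_and_time

save_checkpoint_and_time(
    state=global_state,
    model=model,
    optimizer=optimizer,
    opt_param_scheduler=scheduler,
    num_floating_point_operations_so_far=total_flops,
    checkpointing_context=checkpointing_context,
    non_persistent_ckpt=True,             # local, non-persistent checkpoint
    train_data_iterator=train_data_iterator,
)
```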
- bridge.training.train.checkpoint_and_decide_exit(
- state: megatron.bridge.training.state.GlobalState,
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- opt_param_scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
- num_floating_point_operations_so_far: float,
- checkpointing_context: dict[str, Any],
- train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
Handles checkpointing decisions and determines if training should exit.
Checks various conditions for saving a checkpoint (signal received, interval, duration) and determines if the training loop should terminate based on exit conditions (signal, duration, iteration interval).
- Parameters:
state – The global state object.
model – list of model chunks (MegatronModule instances).
optimizer – The optimizer instance.
opt_param_scheduler – The optimizer parameter scheduler instance.
num_floating_point_operations_so_far – Cumulative number of floating-point operations completed up to this point.
checkpointing_context – Dictionary holding checkpointing-related state.
train_data_iterator – Optional training data iterator to save its state.
- Returns:
True if the training loop should exit, False otherwise.
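Because the return value signals whether to stop, the call typically sits at the end of each iteration; a sketch, with `train_iters`, `iteration`, `global_state`, `total_flops`, and the other objects as placeholders:

```python
from megatron.bridge.training.train import checkpoint_and_decide_exit

while iteration < train_iters:
    # ... run train_step(...) and logging here ...
    iteration += 1
    if checkpoint_and_decide_exit(
        state=global_state,
        model=model,
        optimizer=optimizer,
        opt_param_scheduler=scheduler,
        num_floating_point_operations_so_far=total_flops,
        checkpointing_context=checkpointing_context,
        train_data_iterator=train_data_iterator,
    ):
        break
```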
- bridge.training.train._finish_train(
- global_state: megatron.bridge.training.state.GlobalState,
- bridge.training.train._handle_mxfp8_param_buffer_copy(
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- reuse_grad_buf_for_mxfp8_param_ag: bool,
- overlap_param_gather: bool,
Copy main params to param buffer for mxfp8 with grad buffer reuse.
For mxfp8_param with reuse_grad_buf_for_mxfp8_param_ag and dp_ag_overlap, _copy_main_params_to_param_buffer() must be called after the grad buffer is zeroed, because the param and grad buffers are shared.
- Parameters:
optimizer – The MegatronOptimizer instance
reuse_grad_buf_for_mxfp8_param_ag – Config flag for grad buffer reuse
overlap_param_gather – Config flag for overlapping param gathering
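An ordering sketch based on the description above. This is an internal helper normally invoked from the training step rather than by user code; the flag values are illustrative and `optimizer` is a placeholder.

```python
from megatron.bridge.training.train import _handle_mxfp8_param_buffer_copy

# The shared grad buffer is assumed to have just been zeroed at this point,
# since the param and grad buffers are shared in this mode.
_handle_mxfp8_param_buffer_copy(
    optimizer=optimizer,
    reuse_grad_buf_for_mxfp8_param_ag=True,   # illustrative config flag values
    overlap_param_gather=True,
)
```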