bridge.training.train#

Module Contents#

Functions#

train

Main training loop.

train_step

Single training step.

post_training_step_callbacks

Run all post-training-step functions (e.g., FT heartbeats, GC).

should_disable_forward_pre_hook

Determine if forward pre-hooks should be disabled during checkpointing.

enable_forward_pre_hook

Enable forward pre-hook for all model chunks.

disable_forward_pre_hook

Disable forward pre-hook for all model chunks.

get_start_time_from_progress_log

Gets the start time of the earliest job with the same world size, along with the number of floating-point operations completed in the last saved checkpoint.

compute_throughputs_and_append_to_progress_log

Computes job and cumulative throughputs and appends to progress log.

save_checkpoint_and_time

Saves a checkpoint and logs the timing.

checkpoint_and_decide_exit

Handles checkpointing decisions and determines if training should exit.

_finish_train

_handle_mxfp8_param_buffer_copy

Copy main params to param buffer for mxfp8 with grad buffer reuse.

API#

bridge.training.train.train(
forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
model: list[megatron.core.transformer.MegatronModule],
optimizer: megatron.core.optimizer.MegatronOptimizer,
scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
valid_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
global_state: megatron.bridge.training.state.GlobalState,
checkpointing_context: dict[str, Any],
process_non_loss_data_func: Optional[Callable] = None,
non_loss_data_func: Optional[Callable] = None,
) None#

Main training loop.

Handles the overall training process, including the iteration loop, calling train_step, evaluation, checkpointing, logging, and exit conditions.

Parameters:
  • forward_step_func – Callable that executes a single forward step.

  • model – list of model chunks (potentially wrapped in DDP).

  • optimizer – The optimizer instance.

  • scheduler – The learning rate scheduler instance.

  • train_data_iterator – Iterator for the training dataset.

  • valid_data_iterator – Iterator for the validation dataset.

  • global_state – The GlobalState object holding various training states.

  • checkpointing_context – Context dictionary for checkpointing.

  • process_non_loss_data_func – Optional function to process non-loss data during evaluation.

  • non_loss_data_func – Optional function to compute non-loss data during evaluation.

Warning:

This is an experimental API and is subject to change in backwards incompatible ways without notice.
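
A minimal calling sketch, assuming the model, optimizer, scheduler, data iterators, GlobalState, and checkpointing context have already been built by your own setup code. The `components` dictionary and its keys are illustrative, not part of the library API, and the import path assumes the module is importable as `megatron.bridge.training.train`, following the fully qualified type names above.

```python
from typing import Any

from megatron.bridge.training.train import train


def run_training(components: dict[str, Any]) -> None:
    """Drive the main loop with pre-built training components.

    `components` is assumed to hold objects produced by your own setup code;
    the keys used here are illustrative only.
    """
    train(
        forward_step_func=components["forward_step_func"],
        model=components["model"],  # list of (DDP-wrapped) model chunks
        optimizer=components["optimizer"],
        scheduler=components["scheduler"],
        train_data_iterator=components["train_data_iterator"],
        valid_data_iterator=components["valid_data_iterator"],
        global_state=components["global_state"],
        checkpointing_context=components["checkpointing_context"],
        process_non_loss_data_func=None,
        non_loss_data_func=None,
    )
```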

bridge.training.train.train_step(
forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
model: list[megatron.core.transformer.MegatronModule],
optimizer: megatron.core.optimizer.MegatronOptimizer,
scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
global_state: megatron.bridge.training.state.GlobalState,
) tuple[dict[str, torch.Tensor], int, bool, bool, int, Optional[float], Optional[int]]#

Single training step.

Parameters:
  • forward_step_func – Function that performs a forward step (already wrapped if needed)

  • data_iterator – Iterator over training data

  • model – list of model chunks

  • optimizer – Optimizer for model parameters

  • scheduler – Learning rate scheduler

  • global_state – Global training state

Returns:

  • loss_dict: Dictionary of reduced losses

  • skipped_iter: Whether the iteration was skipped (1) or not (0)

  • should_checkpoint: Whether a checkpoint should be saved

  • should_exit: Whether training should exit

  • exit_code: Exit code if should_exit is True

  • grad_norm: Gradient norm if available, None otherwise

  • num_zeros_in_grad: Number of zeros in gradient if available, None otherwise

Return type:

A tuple containing the elements listed above, in order.
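
Because the function returns a seven-element tuple, callers usually unpack it positionally. The sketch below only illustrates that unpacking inside a hypothetical helper; it is not the library's own training loop, and the objects passed in are assumed to exist.

```python
from megatron.bridge.training.train import train_step


def one_step(forward_step_func, train_data_iterator, model, optimizer,
             scheduler, global_state):
    """Run a single step and surface the pieces most callers care about."""
    (
        loss_dict,           # dict[str, torch.Tensor] of reduced losses
        skipped_iter,        # 1 if the iteration was skipped, else 0
        should_checkpoint,   # save a checkpoint after this step?
        should_exit,         # terminate training?
        exit_code,           # meaningful only when should_exit is True
        grad_norm,           # Optional[float]
        num_zeros_in_grad,   # Optional[int]
    ) = train_step(
        forward_step_func=forward_step_func,
        data_iterator=train_data_iterator,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        global_state=global_state,
    )
    return loss_dict, should_checkpoint, should_exit, exit_code
```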

bridge.training.train.post_training_step_callbacks(
model: list[megatron.core.transformer.MegatronModule],
num_floating_point_operations_since_last_log_event: float,
straggler_timer: Any,
iteration: int,
prof: Optional[torch.profiler.profile],
config: megatron.bridge.training.config.ConfigContainer,
should_toggle_forward_pre_hook: bool,
nsys_nvtx_context: Optional[megatron.bridge.training.profiling.TNvtxContext] = None,
) None#

Run all post-training-step functions (e.g., FT heartbeats, GC).

Parameters:
  • model – list of model chunks wrapped in DDP

  • num_floating_point_operations_since_last_log_event – Number of floating point operations since last log

  • straggler_timer – Timer for straggler detection

  • iteration – Current training iteration

  • prof – PyTorch profiler instance

  • config – Configuration container

  • should_toggle_forward_pre_hook – Whether to toggle forward pre-hook

  • nsys_nvtx_context – NVTX context for nsys profiling (if active)

bridge.training.train.should_disable_forward_pre_hook(
use_megatron_fsdp: bool,
use_distributed_optimizer: bool,
overlap_param_gather: bool,
) bool#

Determine if forward pre-hooks should be disabled during checkpointing.

Forward pre-hooks need to be disabled during checkpoint saving when using the distributed optimizer with overlapped parameter gathering.

Parameters:
  • use_megatron_fsdp – Whether Megatron FSDP is enabled.

  • use_distributed_optimizer – Whether distributed optimizer is enabled.

  • overlap_param_gather – Whether parameter gathering is overlapped.

Returns:

True if forward pre-hooks should be disabled, False otherwise.

Note:

This is needed to prevent autograd issues during checkpoint saving when using distributed optimizer with parameter gathering overlap.
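
A usage sketch; based on the description above, the combination of a distributed optimizer with overlapped parameter gathering (and without Megatron FSDP) is the case in which disabling would be expected.

```python
from megatron.bridge.training.train import should_disable_forward_pre_hook

# Example flags: distributed optimizer with overlapped param gather, no Megatron FSDP.
if should_disable_forward_pre_hook(
    use_megatron_fsdp=False,
    use_distributed_optimizer=True,
    overlap_param_gather=True,
):
    print("Disable forward pre-hooks around the checkpoint save.")
```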

bridge.training.train.enable_forward_pre_hook(
model: list[megatron.core.distributed.DistributedDataParallel],
) None#

Enable forward pre-hook for all model chunks.

Parameters:
  • model – list of model chunks wrapped in DDP

bridge.training.train.disable_forward_pre_hook(
model: list[megatron.core.distributed.DistributedDataParallel],
param_sync: bool = True,
) None#

Disable forward pre-hook for all model chunks.

Parameters:
  • model – list of model chunks wrapped in DDP

  • param_sync – Whether to synchronize parameters across model chunks
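
These two helpers are naturally used as a pair around work that must not trigger the pre-hooks, such as a checkpoint save. The context manager below is an illustrative wrapper, not part of the library.

```python
from contextlib import contextmanager

from megatron.bridge.training.train import (
    disable_forward_pre_hook,
    enable_forward_pre_hook,
)


@contextmanager
def forward_pre_hooks_disabled(model, param_sync: bool = True):
    """Temporarily disable forward pre-hooks on all DDP model chunks."""
    disable_forward_pre_hook(model, param_sync=param_sync)
    try:
        yield
    finally:
        enable_forward_pre_hook(model)


# Usage (sketch): `save_fn` stands in for whatever must run without pre-hooks.
# with forward_pre_hooks_disabled(model):
#     save_fn()
```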

bridge.training.train.get_start_time_from_progress_log(
cfg: megatron.bridge.training.config.ConfigContainer,
) tuple[datetime.datetime, float]#

Gets the start time of the earliest job with the same world size, along with the number of floating-point operations completed in the last saved checkpoint.

bridge.training.train.compute_throughputs_and_append_to_progress_log(
state: megatron.bridge.training.state.GlobalState,
num_floating_point_operations_so_far: float,
) None#

Computes job and cumulative throughputs and appends to progress log.

Calculates TFLOP/s/GPU based on floating-point operations and elapsed time. Appends the computed throughputs, total FLOPs, and processed tokens to the progress log file.

Parameters:
  • state – The GlobalState object.

  • num_floating_point_operations_so_far – Total floating-point operations completed.
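
The bookkeeping (reading the progress log and measuring elapsed time) is internal to the function, but the unit conversion it describes can be made concrete. The sketch below shows the conventional TFLOP/s/GPU arithmetic and is illustrative only; it assumes elapsed time is measured from the logged start time.

```python
def tflops_per_gpu(total_flops: float, elapsed_seconds: float, world_size: int) -> float:
    """TFLOP/s/GPU = FLOPs / (seconds * GPUs * 1e12). Illustrative only."""
    return total_flops / (elapsed_seconds * world_size * 1e12)


# e.g. 3.6e18 FLOPs over one hour on 64 GPUs -> ~15.6 TFLOP/s/GPU
print(tflops_per_gpu(3.6e18, 3600.0, 64))
```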

bridge.training.train.save_checkpoint_and_time(
state: megatron.bridge.training.state.GlobalState,
model: list[megatron.core.transformer.MegatronModule],
optimizer: megatron.core.optimizer.MegatronOptimizer,
opt_param_scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
num_floating_point_operations_so_far: float,
checkpointing_context: dict[str, Any],
non_persistent_ckpt: bool = False,
train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]] = None,
) None#

Saves a checkpoint and logs the timing.

Wraps the save_checkpoint function with timers and temporarily disables and re-enables forward pre-hooks when the distributed optimizer with overlapped parameter gather is used.

Parameters:
  • state – The global state object.

  • model – list of model chunks (MegatronModule instances).

  • optimizer – The optimizer instance.

  • opt_param_scheduler – The optimizer parameter scheduler instance.

  • num_floating_point_operations_so_far – Cumulative number of floating-point operations up to this point.

  • checkpointing_context – Dictionary holding checkpointing-related state.

  • non_persistent_ckpt – Flag indicating if this is a non-persistent (local) checkpoint. Defaults to False.

  • train_data_iterator – Optional training data iterator to save its state.
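
A calling sketch; the surrounding objects are assumed to come from a live training run, and only the keyword arguments mirror the signature above.

```python
from megatron.bridge.training.train import save_checkpoint_and_time


def save_now(state, model, optimizer, opt_param_scheduler,
             num_flops_so_far, checkpointing_context, train_data_iterator=None):
    # Persistent checkpoint; pass non_persistent_ckpt=True for a local,
    # non-persistent save instead.
    save_checkpoint_and_time(
        state=state,
        model=model,
        optimizer=optimizer,
        opt_param_scheduler=opt_param_scheduler,
        num_floating_point_operations_so_far=num_flops_so_far,
        checkpointing_context=checkpointing_context,
        non_persistent_ckpt=False,
        train_data_iterator=train_data_iterator,
    )
```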

bridge.training.train.checkpoint_and_decide_exit(
state: megatron.bridge.training.state.GlobalState,
model: list[megatron.core.transformer.MegatronModule],
optimizer: megatron.core.optimizer.MegatronOptimizer,
opt_param_scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
num_floating_point_operations_so_far: float,
checkpointing_context: dict[str, Any],
train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
) bool#

Handles checkpointing decisions and determines if training should exit.

Checks various conditions for saving a checkpoint (signal received, interval, duration) and determines if the training loop should terminate based on exit conditions (signal, duration, iteration interval).

Parameters:
  • state – The global state object.

  • model – list of model chunks (MegatronModule instances).

  • optimizer – The optimizer instance.

  • opt_param_scheduler – The optimizer parameter scheduler instance.

  • num_floating_point_operations_so_far – Cumulative number of floating-point operations up to this point.

  • checkpointing_context – Dictionary holding checkpointing-related state.

  • train_data_iterator – Optional training data iterator to save its state.

Returns:

True if the training loop should exit, False otherwise.
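
A sketch of how the return value might drive a custom loop; everything other than the call itself is illustrative.

```python
from megatron.bridge.training.train import checkpoint_and_decide_exit


def maybe_checkpoint_and_stop(state, model, optimizer, opt_param_scheduler,
                              num_flops_so_far, checkpointing_context,
                              train_data_iterator) -> bool:
    """Return True when the training loop should break out."""
    return checkpoint_and_decide_exit(
        state=state,
        model=model,
        optimizer=optimizer,
        opt_param_scheduler=opt_param_scheduler,
        num_floating_point_operations_so_far=num_flops_so_far,
        checkpointing_context=checkpointing_context,
        train_data_iterator=train_data_iterator,
    )


# Inside a custom loop (sketch):
# if maybe_checkpoint_and_stop(...):
#     break
```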

bridge.training.train._finish_train(
global_state: megatron.bridge.training.state.GlobalState,
)#
bridge.training.train._handle_mxfp8_param_buffer_copy(
optimizer: megatron.core.optimizer.MegatronOptimizer,
reuse_grad_buf_for_mxfp8_param_ag: bool,
overlap_param_gather: bool,
) None#

Copy main params to param buffer for mxfp8 with grad buffer reuse.

For mxfp8 params with reuse_grad_buf_for_mxfp8_param_ag and dp_ag_overlap, _copy_main_params_to_param_buffer() must be called after the grad buffer is zeroed, because the param and grad buffers are shared.

Parameters:
  • optimizer – The MegatronOptimizer instance

  • reuse_grad_buf_for_mxfp8_param_ag – Config flag for grad buffer reuse

  • overlap_param_gather – Config flag for overlapping param gathering
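
An illustrative sketch of the ordering constraint described above; this is a private helper, so calling it directly is shown only to make the constraint concrete.

```python
from megatron.bridge.training.train import _handle_mxfp8_param_buffer_copy


def after_grad_buffer_zeroed(optimizer, reuse_grad_buf_for_mxfp8_param_ag: bool,
                             overlap_param_gather: bool) -> None:
    # Must run only after the grad buffer has been zeroed: with grad buffer
    # reuse, the param and grad buffers are shared, so a copy made earlier
    # would be clobbered.
    _handle_mxfp8_param_buffer_copy(
        optimizer=optimizer,
        reuse_grad_buf_for_mxfp8_param_ag=reuse_grad_buf_for_mxfp8_param_ag,
        overlap_param_gather=overlap_param_gather,
    )
```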