bridge.training.train#

Module Contents#

Functions#

train – Main training loop.

train_step – Single training step.

post_training_step_callbacks – Run all post-training-step functions (e.g., FT heartbeats, GC).

should_disable_forward_pre_hook – Determine if forward pre-hooks should be disabled during checkpointing.

enable_forward_pre_hook – Enable forward pre-hook for all model chunks.

disable_forward_pre_hook – Disable forward pre-hook for all model chunks.

get_start_time_from_progress_log – Gets the start time of the earliest job with the same world size. Also returns the number of floating-point operations completed as of the last saved checkpoint.

compute_throughputs_and_append_to_progress_log – Computes job and cumulative throughputs and appends them to the progress log.

save_checkpoint_and_time – Saves a checkpoint and logs the timing.

checkpoint_and_decide_exit – Handles checkpointing decisions and determines if training should exit.

_finish_train

_should_skip_and_handle_iteration – Check if the current iteration should be skipped and handle it if so.

_dummy_train_step – Single dummy training step to fast-forward train_data_iterator.

_handle_mxfp8_param_buffer_copy – Copy main params to the param buffer for mxfp8 with grad buffer reuse.

API#

bridge.training.train.train(
forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
model: list[megatron.core.transformer.MegatronModule],
optimizer: megatron.core.optimizer.MegatronOptimizer,
scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
valid_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
global_state: megatron.bridge.training.state.GlobalState,
checkpointing_context: dict[str, Any],
process_non_loss_data_func: Optional[Callable] = None,
non_loss_data_func: Optional[Callable] = None,
) → None#

Main training loop.

Handles the overall training process, including the iteration loop, calling train_step, evaluation, checkpointing, logging, and exit conditions.

Parameters:
  • forward_step_func – Callable that executes a single forward step.

  • model – list of model chunks (potentially wrapped in DDP).

  • optimizer – The optimizer instance.

  • scheduler – The learning rate scheduler instance.

  • train_data_iterator – Iterator for the training dataset.

  • valid_data_iterator – Iterator for the validation dataset.

  • global_state – The GlobalState object holding various training states.

  • checkpointing_context – Context dictionary for checkpointing.

  • process_non_loss_data_func – Optional function to process non-loss data during evaluation.

  • non_loss_data_func – Optional function to compute non-loss data during evaluation.

Warning:

This is an experimental API and is subject to change in backwards incompatible ways without notice.
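
A minimal usage sketch: it assumes the forward step callable, model chunks, optimizer, scheduler, data iterators, and GlobalState instance have already been built by the surrounding setup code, and that the module is importable as megatron.bridge.training.train. Only the train() call itself follows the documented signature.

```python
from typing import Any

from megatron.bridge.training.train import train


def run_training(forward_step_func, model, optimizer, scheduler,
                 train_data_iterator, valid_data_iterator, global_state) -> None:
    # All arguments are assumed to come from setup code: `model` is a list of
    # (DDP-wrapped) model chunks, `global_state` is a populated GlobalState, etc.
    checkpointing_context: dict[str, Any] = {}

    train(
        forward_step_func=forward_step_func,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        train_data_iterator=train_data_iterator,
        valid_data_iterator=valid_data_iterator,
        global_state=global_state,
        checkpointing_context=checkpointing_context,
    )
```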

bridge.training.train.train_step(
forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
model: list[megatron.core.transformer.MegatronModule],
optimizer: megatron.core.optimizer.MegatronOptimizer,
scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
global_state: megatron.bridge.training.state.GlobalState,
) → tuple[dict[str, torch.Tensor], int, bool, bool, int, Optional[float], Optional[int]]#

Single training step.

Parameters:
  • forward_step_func – Function that performs a forward step (already wrapped if needed)

  • data_iterator – Iterator over training data

  • model – list of model chunks

  • optimizer – Optimizer for model parameters

  • scheduler – Learning rate scheduler

  • global_state – Global training state

Returns:

  • loss_dict: Dictionary of reduced losses

  • skipped_iter: Whether the iteration was skipped (1) or not (0)

  • should_checkpoint: Whether a checkpoint should be saved

  • should_exit: Whether training should exit

  • exit_code: Exit code if should_exit is True

  • grad_norm: Gradient norm if available, None otherwise

  • num_zeros_in_grad: Number of zeros in gradient if available, None otherwise

Return type:

tuple
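
A sketch of unpacking the returned tuple inside the training loop; the surrounding objects (forward_step_func, train_data_iterator, model, optimizer, scheduler, global_state) are assumed to exist in scope, and only the element order follows the documented return value.

```python
(loss_dict, skipped_iter, should_checkpoint, should_exit,
 exit_code, grad_norm, num_zeros_in_grad) = train_step(
    forward_step_func, train_data_iterator, model,
    optimizer, scheduler, global_state,
)

if should_exit:
    # exit_code is only meaningful when should_exit is True.
    print(f"train_step requested exit with code {exit_code}")
```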

bridge.training.train.post_training_step_callbacks(
model: list[megatron.core.transformer.MegatronModule],
num_floating_point_operations_since_last_log_event: float,
straggler_timer: Any,
iteration: int,
prof: Optional[torch.profiler.profile],
config: megatron.bridge.training.config.ConfigContainer,
should_toggle_forward_pre_hook: bool,
nsys_nvtx_context: Optional[megatron.bridge.training.profiling.TNvtxContext] = None,
) → None#

Run all post-training-step functions (e.g., FT heartbeats, GC).

Parameters:
  • model – list of model chunks wrapped in DDP

  • num_floating_point_operations_since_last_log_event – Number of floating point operations since last log

  • straggler_timer – Timer for straggler detection

  • iteration – Current training iteration

  • prof – PyTorch profiler instance

  • config – Configuration container

  • should_toggle_forward_pre_hook – Whether to toggle forward pre-hook

  • nsys_nvtx_context – NVTX context for nsys profiling (if active)
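
A hedged sketch of a call at the end of an iteration; every name other than the function itself is assumed to come from the caller's training loop, and both optional profiling arguments are left inactive.

```python
post_training_step_callbacks(
    model=model,
    num_floating_point_operations_since_last_log_event=flops_since_last_log,
    straggler_timer=straggler_timer,   # straggler-detection timer from setup
    iteration=iteration,
    prof=None,                         # no torch.profiler session in this sketch
    config=config,
    should_toggle_forward_pre_hook=False,
    nsys_nvtx_context=None,            # nsys profiling not active here
)
```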

bridge.training.train.should_disable_forward_pre_hook(
use_megatron_fsdp: bool,
use_distributed_optimizer: bool,
overlap_param_gather: bool,
) → bool#

Determine if forward pre-hooks should be disabled during checkpointing.

Forward pre-hooks need to be disabled during checkpoint saving when using the distributed optimizer with overlapped parameter gathering.

Parameters:
  • use_megatron_fsdp – Whether Megatron FSDP is enabled.

  • use_distributed_optimizer – Whether distributed optimizer is enabled.

  • overlap_param_gather – Whether parameter gathering is overlapped.

Returns:

True if forward pre-hooks should be disabled, False otherwise.

Note:

This is needed to prevent autograd issues during checkpoint saving when using distributed optimizer with parameter gathering overlap.

bridge.training.train.enable_forward_pre_hook(
model: list[megatron.core.distributed.DistributedDataParallel],
) → None#

Enable forward pre-hook for all model chunks.

Parameters:

model – list of model chunks wrapped in DDP

bridge.training.train.disable_forward_pre_hook(
model: list[megatron.core.distributed.DistributedDataParallel],
param_sync: bool = True,
) → None#

Disable forward pre-hook for all model chunks.

Parameters:
  • model – list of model chunks wrapped in DDP

  • param_sync – Whether to synchronize parameters across model chunks
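
These three functions are typically used together around a checkpoint save. A sketch of that pattern, assuming `model` is a list of DDP-wrapped chunks, the three flags mirror the training configuration, and save_model_checkpoint is a placeholder for the actual save call:

```python
from megatron.bridge.training.train import (
    disable_forward_pre_hook,
    enable_forward_pre_hook,
    should_disable_forward_pre_hook,
)

should_disable = should_disable_forward_pre_hook(
    use_megatron_fsdp=False,         # illustrative config values
    use_distributed_optimizer=True,
    overlap_param_gather=True,
)

if should_disable:
    disable_forward_pre_hook(model, param_sync=True)

save_model_checkpoint(model)         # placeholder for the real checkpoint save

if should_disable:
    enable_forward_pre_hook(model)
```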

bridge.training.train.get_start_time_from_progress_log(
cfg: megatron.bridge.training.config.ConfigContainer,
) → tuple[datetime.datetime, float]#

Gets the start time of the earliest job with the same world size. Also returns the number of floating-point operations completed as of the last saved checkpoint.

bridge.training.train.compute_throughputs_and_append_to_progress_log(
state: megatron.bridge.training.state.GlobalState,
num_floating_point_operations_so_far: float,
) → None#

Computes job and cumulative throughputs and appends them to the progress log.

Calculates TFLOP/s/GPU based on floating-point operations and elapsed time. Appends the computed throughputs, total FLOPs, and processed tokens to the progress log file.

Parameters:
  • state – The GlobalState object.

  • num_floating_point_operations_so_far – Total floating-point operations completed.
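
For orientation, the logged quantity has roughly the following shape; this is an illustrative approximation, not the function's exact implementation, which derives its inputs from GlobalState and the progress log file.

```python
def approx_tflops_per_gpu(num_flops: float, elapsed_seconds: float, world_size: int) -> float:
    # TFLOP/s/GPU: floating-point operations divided by wall-clock seconds,
    # number of GPUs, and 10^12.
    return num_flops / elapsed_seconds / world_size / 1.0e12
```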

bridge.training.train.save_checkpoint_and_time(
state: megatron.bridge.training.state.GlobalState,
model: list[megatron.core.transformer.MegatronModule],
optimizer: megatron.core.optimizer.MegatronOptimizer,
opt_param_scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
num_floating_point_operations_so_far: float,
checkpointing_context: dict[str, Any],
non_persistent_ckpt: bool = False,
train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]] = None,
) → None#

Saves a checkpoint and logs the timing.

Wraps the save_checkpoint function with timers and, when the distributed optimizer with overlapped parameter gathering is used, disables forward pre-hooks before saving and re-enables them afterwards.

Parameters:
  • state – The global state object.

  • model – list of model chunks (MegatronModule instances).

  • optimizer – The optimizer instance.

  • opt_param_scheduler – The optimizer parameter scheduler instance.

  • num_floating_point_operations_so_far – Cumulative number of floating-point operations up to this point.

  • checkpointing_context – Dictionary holding checkpointing-related state.

  • non_persistent_ckpt – Flag indicating if this is a non-persistent (local) checkpoint. Defaults to False.

  • train_data_iterator – Optional training data iterator to save its state.
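
A hedged sketch of a periodic save; the objects and the interval bookkeeping (iteration, save_interval, total_flops) are assumed to be maintained by the caller.

```python
if save_interval and iteration % save_interval == 0:
    save_checkpoint_and_time(
        state=global_state,
        model=model,
        optimizer=optimizer,
        opt_param_scheduler=scheduler,
        num_floating_point_operations_so_far=total_flops,
        checkpointing_context=checkpointing_context,
        non_persistent_ckpt=False,        # regular (persistent) checkpoint
        train_data_iterator=train_data_iterator,
    )
```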

bridge.training.train.checkpoint_and_decide_exit(
state: megatron.bridge.training.state.GlobalState,
model: list[megatron.core.transformer.MegatronModule],
optimizer: megatron.core.optimizer.MegatronOptimizer,
opt_param_scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
num_floating_point_operations_so_far: float,
checkpointing_context: dict[str, Any],
train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
) → bool#

Handles checkpointing decisions and determines if training should exit.

Checks various conditions for saving a checkpoint (signal received, interval, duration) and determines if the training loop should terminate based on exit conditions (signal, duration, iteration interval).

Parameters:
  • state – The global state object.

  • model – list of model chunks (MegatronModule instances).

  • optimizer – The optimizer instance.

  • opt_param_scheduler – The optimizer parameter scheduler instance.

  • num_floating_point_operations_so_far – Cumulative number of floating-point operations up to this point.

  • checkpointing_context – Dictionary holding checkpointing-related state.

  • train_data_iterator – Optional training data iterator to save its state.

Returns:

True if the training loop should exit, False otherwise.
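
A sketch of how the return value is typically consumed at the end of each iteration; the loop bounds and all names other than the function itself are assumptions standing in for the caller's state.

```python
for iteration in range(start_iteration, train_iters):  # assumed loop bounds
    ...  # one full training iteration

    if checkpoint_and_decide_exit(
        global_state, model, optimizer, scheduler,
        total_flops, checkpointing_context, train_data_iterator,
    ):
        break  # a checkpoint was saved if needed; leave the training loop
```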

bridge.training.train._finish_train(
global_state: megatron.bridge.training.state.GlobalState,
)#

bridge.training.train._should_skip_and_handle_iteration(
global_state: megatron.bridge.training.state.GlobalState,
train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
) → bool#

Check if the current iteration should be skipped and handle it if so.

This function checks if the current training step is in the iterations_to_skip list, and if so, performs a dummy training step to consume data and update counters.

Parameters:
  • global_state – Global state containing training state and configuration

  • train_data_iterator – Iterator over training data

Returns:

True if the iteration was skipped, False otherwise

Return type:

bool
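
Inside the module's training loop this acts as an early-out at the top of each iteration, along the lines of the following sketch (loop scaffolding and names are assumptions):

```python
for iteration in range(start_iteration, train_iters):  # assumed loop bounds
    if _should_skip_and_handle_iteration(global_state, train_data_iterator):
        # A dummy step already consumed the data and updated the counters,
        # so skip the real work for this iteration.
        continue

    ...  # regular train_step(...) call and bookkeeping
```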

bridge.training.train._dummy_train_step(
global_state: megatron.bridge.training.state.GlobalState,
train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
) → None#

Single dummy training step to fast-forward train_data_iterator.

This function consumes data from the iterator without performing any actual computation, effectively skipping the iteration while maintaining data iterator consistency.

The data iterator is advanced only on the first and last pipeline-parallel (PP) stages, and only when it is not None.

Parameters:
  • global_state – Global state containing configuration

  • train_data_iterator – Iterator over training data

bridge.training.train._handle_mxfp8_param_buffer_copy(
optimizer: megatron.core.optimizer.MegatronOptimizer,
reuse_grad_buf_for_mxfp8_param_ag: bool,
overlap_param_gather: bool,
) → None#

Copy main params to param buffer for mxfp8 with grad buffer reuse.

When mxfp8 params are used with reuse_grad_buf_for_mxfp8_param_ag and overlapped data-parallel all-gather (dp_ag_overlap), _copy_main_params_to_param_buffer() must be called after the grad buffer is zeroed, because the param and grad buffers share the same storage.

Parameters:
  • optimizer – The MegatronOptimizer instance

  • reuse_grad_buf_for_mxfp8_param_ag – Config flag for grad buffer reuse

  • overlap_param_gather – Config flag for overlapping param gathering
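
A hedged sketch of where the call sits in the step sequence; the two flag expressions are placeholders for wherever the corresponding settings live in the configuration.

```python
# After the grad buffer has been zeroed for the next iteration:
_handle_mxfp8_param_buffer_copy(
    optimizer,
    reuse_grad_buf_for_mxfp8_param_ag=reuse_grad_buf_for_mxfp8_param_ag,  # assumed flag
    overlap_param_gather=overlap_param_gather,                            # assumed flag
)
```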