bridge.training.train#

Module Contents#

Functions#

train

Main training loop.

train_step

Single training step.

maybe_synchronize_training_step

Synchronizes CUDA streams when the configured interval is reached.

maybe_report_stragglers

Reports straggler metrics if logging is enabled.

maybe_check_weight_hash_across_dp_replicas

Verifies weight hashes across data-parallel replicas when requested.

maybe_run_manual_gc

Runs manual garbage collection according to the configured interval.

should_disable_forward_pre_hook

Determine if forward pre-hooks should be disabled during checkpointing.

enable_forward_pre_hook

Enable forward pre-hook for all model chunks.

disable_forward_pre_hook

Disable forward pre-hook for all model chunks.

force_param_sync

Force parameter synchronization for all model chunks.

get_start_time_from_progress_log

Gets the start time of the earliest job with the same world size, along with the number of floating-point operations completed as of the last saved checkpoint.

compute_throughputs_and_append_to_progress_log

Computes job and cumulative throughputs and appends to progress log.

save_checkpoint_and_time

Saves a checkpoint and logs the timing.

checkpoint_and_decide_exit

Handles checkpointing decisions and determines if training should exit.

_finish_train

_should_skip_and_handle_iteration

Check if the current iteration should be skipped and handle it if so.

_dummy_train_step

Single dummy training step to fast forward train_data_iterator.

_handle_mxfp8_param_buffer_copy

Copy main params to param buffer for mxfp8 with grad buffer reuse.

API#

bridge.training.train.train(
forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
model: list[megatron.core.transformer.MegatronModule],
optimizer: megatron.core.optimizer.MegatronOptimizer,
scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
valid_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
global_state: megatron.bridge.training.state.GlobalState,
checkpointing_context: dict[str, Any],
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
process_non_loss_data_func: Optional[Callable] = None,
non_loss_data_func: Optional[Callable] = None,
) None#

Main training loop.

Handles the overall training process, including the iteration loop, calling train_step, evaluation, checkpointing, logging, and exit conditions.

Parameters:
  • forward_step_func – Callable that executes a single forward step.

  • model – list of model chunks (potentially wrapped in DDP).

  • optimizer – The optimizer instance.

  • scheduler – The learning rate scheduler instance.

  • train_data_iterator – Iterator for the training dataset.

  • valid_data_iterator – Iterator for the validation dataset.

  • global_state – The GlobalState object holding various training states.

  • checkpointing_context – Context dictionary for checkpointing.

  • pg_collection – Process group collection.

  • process_non_loss_data_func – Optional function to process non-loss data during evaluation.

  • non_loss_data_func – Optional function to compute non-loss data during evaluation.

.. warning::

This is an experimental API and is subject to change in backwards incompatible ways without notice.
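
A minimal sketch of how the objects described above are assembled and handed to train. The model, optimizer, scheduler, iterators, GlobalState, and ProcessGroupCollection instances are assumed to be built by the caller's own setup code; only the train call itself comes from this module.

```python
# Sketch only: every argument except the empty checkpointing_context dict is
# assumed to be constructed elsewhere by the caller's setup code.
from megatron.bridge.training.train import train


def run_training(
    forward_step_func,
    model,                 # list of (DDP-wrapped) model chunks
    optimizer,
    scheduler,
    train_data_iterator,
    valid_data_iterator,
    global_state,
    pg_collection,
):
    checkpointing_context = {}  # shared dict threaded through checkpoint calls
    train(
        forward_step_func=forward_step_func,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        train_data_iterator=train_data_iterator,
        valid_data_iterator=valid_data_iterator,
        global_state=global_state,
        checkpointing_context=checkpointing_context,
        pg_collection=pg_collection,
        process_non_loss_data_func=None,  # optional eval-time hooks
        non_loss_data_func=None,
    )
```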

bridge.training.train.train_step(
forward_step_func: megatron.bridge.training.forward_step_func_types.ForwardStepCallable,
data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
model: list[megatron.core.transformer.MegatronModule],
optimizer: megatron.core.optimizer.MegatronOptimizer,
scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
global_state: megatron.bridge.training.state.GlobalState,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
forward_backward_func: Callable,
) tuple[dict[str, torch.Tensor], int, bool, bool, int, Optional[float], Optional[int]]#

Single training step.

Parameters:
  • forward_step_func – Function that performs a forward step (already wrapped if needed)

  • data_iterator – Iterator over training data

  • model – list of model chunks

  • optimizer – Optimizer for model parameters

  • scheduler – Learning rate scheduler

  • global_state – Global training state

  • pg_collection – Process group collection

  • forward_backward_func – Forward-backward function used to execute the forward and backward passes

Returns:

  • loss_dict: Dictionary of reduced losses

  • skipped_iter: Whether the iteration was skipped (1) or not (0)

  • should_checkpoint: Whether a checkpoint should be saved

  • should_exit: Whether training should exit

  • exit_code: Exit code if should_exit is True

  • grad_norm: Gradient norm if available, None otherwise

  • num_zeros_in_grad: Number of zeros in gradient if available, None otherwise

Return type:

tuple
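
A sketch of unpacking the seven-element tuple in a hand-rolled iteration loop. Every argument is assumed to already exist in the caller's scope (the same objects passed to train), and forward_backward_func is whichever forward-backward callable the caller obtained for the pipeline schedule.

```python
# Sketch only: all arguments are assumed to be the objects already built for train().
from megatron.bridge.training.train import train_step


def run_one_step(forward_step_func, train_data_iterator, model, optimizer,
                 scheduler, global_state, pg_collection, forward_backward_func):
    (loss_dict, skipped_iter, should_checkpoint, should_exit,
     exit_code, grad_norm, num_zeros_in_grad) = train_step(
        forward_step_func,
        train_data_iterator,
        model,
        optimizer,
        scheduler,
        global_state,
        pg_collection,
        forward_backward_func,
    )
    if should_exit:
        print(f"training should exit with code {exit_code}")
    elif not skipped_iter and grad_norm is not None:
        print(f"losses: {loss_dict}, grad norm: {grad_norm:.4f}")
    return should_checkpoint, should_exit
```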

bridge.training.train.maybe_synchronize_training_step(
train_sync_interval: Optional[int],
iteration: int,
) None#

Synchronizes CUDA streams when the configured interval is reached.

Parameters:
  • train_sync_interval – Number of iterations between synchronizations; None disables it.

  • iteration – Zero-based training iteration counter.

bridge.training.train.maybe_report_stragglers(
log_interval: int,
log_straggler: bool,
straggler_timer: Any,
iteration: int,
num_floating_point_operations_since_last_log_event: float,
) float#

Reports straggler metrics if logging is enabled.

Parameters:
  • log_interval – Iteration interval for logging.

  • log_straggler – Whether straggler logging is enabled.

  • straggler_timer – Timer utility used to record straggler metrics.

  • iteration – Zero-based training iteration counter.

  • num_floating_point_operations_since_last_log_event – FLOPs accumulated since the last logging event.

Returns:

Updated FLOP counter, reset to 0.0 when a report is emitted; otherwise the original value.

Return type:

float

bridge.training.train.maybe_check_weight_hash_across_dp_replicas(
model: list[megatron.core.transformer.MegatronModule],
check_weight_hash_across_dp_replicas_interval: Optional[int],
iteration: int,
should_toggle_forward_pre_hook: bool,
) None#

Verifies weight hashes across data-parallel replicas when requested.

Parameters:
  • model – List of model chunks to validate.

  • check_weight_hash_across_dp_replicas_interval – Interval at which to verify; None to skip.

  • iteration – Zero-based training iteration counter.

  • should_toggle_forward_pre_hook – Whether the pre-hook must be disabled during the check.

bridge.training.train.maybe_run_manual_gc(
manual_gc_enabled: bool,
manual_gc_interval: int,
iteration: int,
) None#

Runs manual garbage collection according to the configured interval.

Parameters:
  • manual_gc_enabled – Whether manual garbage collection is enabled.

  • manual_gc_interval – Number of iterations between collections; 0 disables periodic runs.

  • iteration – Zero-based training iteration counter.
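
A sketch of how the four maybe_* helpers above might be wired into an iteration loop. The cfg object and its attribute names (train_sync_interval, log_interval, and so on) are assumptions standing in for wherever the caller keeps these settings; note that the FLOP counter returned by maybe_report_stragglers must be reassigned by the caller.

```python
# Sketch only: cfg is a hypothetical settings object; its attribute names are assumptions.
from megatron.bridge.training.train import (
    maybe_check_weight_hash_across_dp_replicas,
    maybe_report_stragglers,
    maybe_run_manual_gc,
    maybe_synchronize_training_step,
)


def run_housekeeping(cfg, model, straggler_timer, iteration,
                     flops_since_last_log, should_toggle_forward_pre_hook):
    maybe_synchronize_training_step(cfg.train_sync_interval, iteration)

    # Returns the counter reset to 0.0 when a straggler report was emitted,
    # otherwise the value passed in, so it must be reassigned.
    flops_since_last_log = maybe_report_stragglers(
        cfg.log_interval,
        cfg.log_straggler,
        straggler_timer,
        iteration,
        flops_since_last_log,
    )

    maybe_check_weight_hash_across_dp_replicas(
        model,
        cfg.check_weight_hash_across_dp_replicas_interval,
        iteration,
        should_toggle_forward_pre_hook,
    )
    maybe_run_manual_gc(cfg.manual_gc, cfg.manual_gc_interval, iteration)
    return flops_since_last_log
```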

bridge.training.train.should_disable_forward_pre_hook(
use_megatron_fsdp: bool,
use_distributed_optimizer: bool,
overlap_param_gather: bool,
) bool#

Determine if forward pre-hooks should be disabled during checkpointing.

Forward pre-hooks need to be disabled during checkpoint saving when using the distributed optimizer with overlapped parameter gathering.

Parameters:
  • use_megatron_fsdp – Whether Megatron FSDP is enabled.

  • use_distributed_optimizer – Whether distributed optimizer is enabled.

  • overlap_param_gather – Whether parameter gathering is overlapped.

Returns:

True if forward pre-hooks should be disabled, False otherwise.

.. note::

This is needed to prevent autograd issues during checkpoint saving when using the distributed optimizer with overlapped parameter gathering.
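
A sketch of the intended pairing with the hook helpers documented below: query should_disable_forward_pre_hook once from the parallelism and optimizer settings, then bracket the actual checkpoint write with disable_forward_pre_hook and enable_forward_pre_hook. The save_fn callable is a placeholder for whatever performs the save.

```python
# Sketch only: save_fn is a hypothetical callable supplied by the caller.
from megatron.bridge.training.train import (
    disable_forward_pre_hook,
    enable_forward_pre_hook,
    should_disable_forward_pre_hook,
)


def save_with_pre_hook_handling(model, save_fn, use_megatron_fsdp,
                                use_distributed_optimizer, overlap_param_gather):
    toggle = should_disable_forward_pre_hook(
        use_megatron_fsdp, use_distributed_optimizer, overlap_param_gather
    )
    if toggle:
        # param_sync=True also synchronizes parameters across model chunks.
        disable_forward_pre_hook(model, param_sync=True)
    save_fn(model)
    if toggle:
        enable_forward_pre_hook(model)
```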

bridge.training.train.enable_forward_pre_hook(
model: list[megatron.core.distributed.DistributedDataParallel],
) None#

Enable forward pre-hook for all model chunks.

Parameters:

model – list of model chunks wrapped in DDP

bridge.training.train.disable_forward_pre_hook(
model: list[megatron.core.distributed.DistributedDataParallel],
param_sync: bool = True,
) None#

Disable forward pre-hook for all model chunks.

Parameters:
  • model – list of model chunks wrapped in DDP

  • param_sync – Whether to synchronize parameters across model chunks

bridge.training.train.force_param_sync(
model: list[megatron.core.distributed.DistributedDataParallel],
) None#

Force parameter synchronization for all model chunks.

Parameters:

model – list of model chunks wrapped in DDP.

bridge.training.train.get_start_time_from_progress_log(
cfg: megatron.bridge.training.config.ConfigContainer,
) tuple[datetime.datetime, float]#

Gets the start time of the earliest job with the same world size, along with the number of floating-point operations completed as of the last saved checkpoint.
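
A small sketch of consuming the returned pair, for example to derive elapsed wall-clock time for throughput calculations; cfg is assumed to be the caller's ConfigContainer instance.

```python
# Sketch only: cfg is assumed to be the caller's ConfigContainer.
from datetime import datetime

from megatron.bridge.training.train import get_start_time_from_progress_log


def elapsed_seconds_and_checkpoint_flops(cfg):
    start_time, flops_at_last_checkpoint = get_start_time_from_progress_log(cfg)
    elapsed = (datetime.now() - start_time).total_seconds()
    return elapsed, flops_at_last_checkpoint
```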

bridge.training.train.compute_throughputs_and_append_to_progress_log(
state: megatron.bridge.training.state.GlobalState,
num_floating_point_operations_so_far: float,
) None#

Computes job and cumulative throughputs and appends to progress log.

Calculates Model TFLOP/s/GPU based on floating-point operations and elapsed time. Appends the computed throughputs, total FLOPs, and processed tokens to the progress log file.

Parameters:
  • state – The GlobalState object.

  • num_floating_point_operations_so_far – Total floating-point operations completed.

bridge.training.train.save_checkpoint_and_time(
state: megatron.bridge.training.state.GlobalState,
model: list[megatron.core.transformer.MegatronModule],
optimizer: megatron.core.optimizer.MegatronOptimizer,
opt_param_scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
num_floating_point_operations_so_far: float,
checkpointing_context: dict[str, Any],
non_persistent_ckpt: bool = False,
train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]] = None,
) None#

Saves a checkpoint and logs the timing.

Wraps the save_checkpoint function with timers and, when using the distributed optimizer with overlapped parameter gathering, forces parameter synchronization to ensure checkpoint correctness.

Parameters:
  • state – The global state object.

  • model – list of model chunks (MegatronModule instances).

  • optimizer – The optimizer instance.

  • opt_param_scheduler – The optimizer parameter scheduler instance.

  • num_floating_point_operations_so_far – Cumulative number of floating-point operations up to this point.

  • checkpointing_context – Dictionary holding checkpointing-related state.

  • non_persistent_ckpt – Flag indicating if this is a non-persistent (local) checkpoint. Defaults to False.

  • train_data_iterator – Optional training data iterator to save its state.
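
A sketch distinguishing a regular save from a local, non-persistent one via the non_persistent_ckpt flag; all other arguments are assumed to be the same objects threaded through the training loop.

```python
# Sketch only: arguments are assumed to be the objects used by the training loop.
from megatron.bridge.training.train import save_checkpoint_and_time


def save_now(state, model, optimizer, opt_param_scheduler, flops_so_far,
             checkpointing_context, train_data_iterator, local_only=False):
    save_checkpoint_and_time(
        state,
        model,
        optimizer,
        opt_param_scheduler,
        flops_so_far,
        checkpointing_context,
        non_persistent_ckpt=local_only,            # True -> local, non-persistent checkpoint
        train_data_iterator=train_data_iterator,   # also saves the iterator state
    )
```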

bridge.training.train.checkpoint_and_decide_exit(
state: megatron.bridge.training.state.GlobalState,
model: list[megatron.core.transformer.MegatronModule],
optimizer: megatron.core.optimizer.MegatronOptimizer,
opt_param_scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
num_floating_point_operations_so_far: float,
checkpointing_context: dict[str, Any],
train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
) bool#

Handles checkpointing decisions and determines if training should exit.

Checks various conditions for saving a checkpoint (signal received, interval, duration) and determines if the training loop should terminate based on exit conditions (signal, duration, iteration interval).

Parameters:
  • state – The global state object.

  • model – list of model chunks (MegatronModule instances).

  • optimizer – The optimizer instance.

  • opt_param_scheduler – The optimizer parameter scheduler instance.

  • num_floating_point_operations_so_far – Cumulative number of floating-point operations up to this point.

  • checkpointing_context – Dictionary holding checkpointing-related state.

  • train_data_iterator – Optional training data iterator to save its state.

Returns:

True if the training loop should exit, False otherwise.
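
A sketch of the boolean return value driving loop termination in a hand-rolled loop; run_one_iteration is a hypothetical callable that performs one training step and returns the updated cumulative FLOP count.

```python
# Sketch only: run_one_iteration is a hypothetical callable supplied by the caller.
from megatron.bridge.training.train import checkpoint_and_decide_exit


def run_loop(run_one_iteration, state, model, optimizer, opt_param_scheduler,
             checkpointing_context, train_data_iterator):
    flops_so_far = 0.0
    while True:
        flops_so_far = run_one_iteration()
        # Saves a checkpoint if any trigger fired and reports whether to stop.
        if checkpoint_and_decide_exit(
            state,
            model,
            optimizer,
            opt_param_scheduler,
            flops_so_far,
            checkpointing_context,
            train_data_iterator,
        ):
            break
```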

bridge.training.train._finish_train(
global_state: megatron.bridge.training.state.GlobalState,
)#
bridge.training.train._should_skip_and_handle_iteration(
global_state: megatron.bridge.training.state.GlobalState,
train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
) bool#

Check if the current iteration should be skipped and handle it if so.

This function checks if the current training step is in the iterations_to_skip list, and if so, performs a dummy training step to consume data and update counters.

Parameters:
  • global_state – Global state containing training state and configuration

  • train_data_iterator – Iterator over training data

  • pg_collection – Process group collection

Returns:

True if the iteration was skipped, False otherwise

Return type:

bool
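
A sketch of the skip-and-continue pattern this helper enables at the top of an iteration; real_step is a hypothetical callable that runs the actual training step.

```python
# Sketch only: real_step is a hypothetical callable supplied by the caller.
from megatron.bridge.training.train import _should_skip_and_handle_iteration


def step_or_skip(global_state, train_data_iterator, pg_collection, real_step):
    # If the current iteration is in iterations_to_skip, a dummy step has already
    # consumed the data and advanced the counters, so simply return.
    if _should_skip_and_handle_iteration(global_state, train_data_iterator, pg_collection):
        return
    real_step()
```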

bridge.training.train._dummy_train_step(
global_state: megatron.bridge.training.state.GlobalState,
train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
) None#

Single dummy training step to fast forward train_data_iterator.

This function consumes data from the iterator without performing any actual computation, effectively skipping the iteration while maintaining data iterator consistency.

Advances the data iterator on the first and last pipeline-parallel (PP) stages when the iterator is not None.

Parameters:
  • global_state – Global state containing configuration

  • train_data_iterator – Iterator over training data

  • pg_collection – Process group collection

bridge.training.train._handle_mxfp8_param_buffer_copy(
optimizer: megatron.core.optimizer.MegatronOptimizer,
reuse_grad_buf_for_mxfp8_param_ag: bool,
overlap_param_gather: bool,
) None#

Copy main params to param buffer for mxfp8 with grad buffer reuse.

When mxfp8_param is used with reuse_grad_buf_for_mxfp8_param_ag and dp_ag_overlap, _copy_main_params_to_param_buffer() must be called after the grad buffer is zeroed, because the param and grad buffers are shared.

Parameters:
  • optimizer – The MegatronOptimizer instance

  • reuse_grad_buf_for_mxfp8_param_ag – Config flag for grad buffer reuse

  • overlap_param_gather – Config flag for overlapping param gathering
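
A sketch emphasizing the ordering constraint described above: the helper is called only after the shared grad buffer has been zeroed. zero_grad_buffers is a placeholder for whatever zeroes the buffers, and the cfg attribute names are assumptions.

```python
# Sketch only: zero_grad_buffers is a hypothetical callable and cfg is a
# hypothetical settings object; its attribute names are assumptions.
from megatron.bridge.training.train import _handle_mxfp8_param_buffer_copy


def after_grad_zero(optimizer, zero_grad_buffers, cfg):
    zero_grad_buffers()  # the shared param/grad buffer must be zeroed first
    _handle_mxfp8_param_buffer_copy(
        optimizer,
        cfg.reuse_grad_buf_for_mxfp8_param_ag,
        cfg.overlap_param_gather,
    )
```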