bridge.training.train#

Module Contents#

Functions#

train

Main training loop.

train_step

Single training step.

post_training_step_callbacks

Run all post-training-step functions (e.g., fault tolerance (FT) heartbeats, garbage collection (GC)).

enable_forward_pre_hook

Enable forward pre-hook for all model chunks.

disable_forward_pre_hook

Disable forward pre-hook for all model chunks.

get_start_time_from_progress_log

Gets the start time of the earliest job with the same world size. Also returns the number of floating-point operations completed as of the last saved checkpoint.

compute_throughputs_and_append_to_progress_log

Computes job and cumulative throughputs and appends to progress log.

save_checkpoint_and_time

Saves a checkpoint and logs the timing.

checkpoint_and_decide_exit

Handles checkpointing decisions and determines if training should exit.

_finish_train

API#

bridge.training.train.train(
forward_step_func: Callable,
model: list[megatron.core.transformer.MegatronModule],
optimizer: megatron.core.optimizer.MegatronOptimizer,
scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
valid_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
global_state: megatron.bridge.training.state.GlobalState,
checkpointing_context: dict[str, Any],
process_non_loss_data_func: Optional[Callable] = None,
non_loss_data_func: Optional[Callable] = None,
) → None#

Main training loop.

Handles the overall training process, including the iteration loop, calling train_step, evaluation, checkpointing, logging, and exit conditions.

Parameters:
  • forward_step_func – Callable that executes a single forward step.

  • model – list of model chunks (potentially wrapped in DDP).

  • optimizer – The optimizer instance.

  • scheduler – The learning rate scheduler instance.

  • train_data_iterator – Iterator for the training dataset.

  • valid_data_iterator – Iterator for the validation dataset.

  • global_state – The GlobalState object holding various training states.

  • checkpointing_context – Context dictionary for checkpointing.

  • process_non_loss_data_func – Optional function to process non-loss data during evaluation.

  • non_loss_data_func – Optional function to compute non-loss data during evaluation.

Warning

This is an experimental API and is subject to change in backwards-incompatible ways without notice.
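
For illustration, a minimal sketch of how train might be invoked. The forward_step helper, the build_* setup calls, config, and global_state below are hypothetical placeholders standing in for the surrounding training setup; only the train(...) call itself follows the signature above.

```python
# Hypothetical usage sketch; the setup helpers and forward_step are placeholders.
from megatron.bridge.training.train import train


def forward_step(data_iterator, model):
    """Placeholder forward step: compute and return the loss for one micro-batch."""
    ...


# Placeholder setup (not part of this module):
model, optimizer, scheduler = build_model_and_optimizer(config)
train_iter, valid_iter = build_data_iterators(config)

train(
    forward_step_func=forward_step,
    model=model,                    # list of (DDP-wrapped) model chunks
    optimizer=optimizer,
    scheduler=scheduler,
    train_data_iterator=train_iter,
    valid_data_iterator=valid_iter,
    global_state=global_state,      # megatron.bridge.training.state.GlobalState
    checkpointing_context={},       # filled in/updated by checkpointing utilities
)
```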

bridge.training.train.train_step(
forward_step_func: Callable,
num_fw_args: int,
data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
model: list[megatron.core.transformer.MegatronModule],
optimizer: megatron.core.optimizer.MegatronOptimizer,
scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
global_state: megatron.bridge.training.state.GlobalState,
) → tuple[dict[str, torch.Tensor], int, bool, bool, int, Optional[float], Optional[int]]#

Single training step.

Parameters:
  • forward_step_func – Function that performs a forward step

  • num_fw_args – Number of arguments expected by forward_step_func

  • data_iterator – Iterator over training data

  • model – list of model chunks

  • optimizer – Optimizer for model parameters

  • scheduler – Learning rate scheduler

  • global_state – Global training state

Returns:

  • loss_dict: Dictionary of reduced losses

  • skipped_iter: Whether the iteration was skipped (1) or not (0)

  • should_checkpoint: Whether a checkpoint should be saved

  • should_exit: Whether training should exit

  • exit_code: Exit code if should_exit is True

  • grad_norm: Gradient norm if available, None otherwise

  • num_zeros_in_grad: Number of zeros in gradient if available, None otherwise

Return type:

A tuple containing the elements listed above.
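
A hedged sketch of calling train_step directly and unpacking its return tuple; the loop scaffolding and the forward_step, train_iter, and global_state names are assumed to come from the surrounding setup (as in the sketch above) and are not defined by this module.

```python
# Sketch: one step of a hand-rolled loop using train_step (placeholder setup assumed).
import sys

(
    loss_dict,            # dict of reduced losses
    skipped_iter,         # 1 if the iteration was skipped, else 0
    should_checkpoint,    # request a checkpoint
    should_exit,          # request loop termination
    exit_code,            # meaningful only when should_exit is True
    grad_norm,            # Optional[float]
    num_zeros_in_grad,    # Optional[int]
) = train_step(
    forward_step_func=forward_step,
    num_fw_args=2,        # forward_step takes (data_iterator, model)
    data_iterator=train_iter,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    global_state=global_state,
)

if should_exit:
    sys.exit(exit_code)
```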

bridge.training.train.post_training_step_callbacks(
model: list[megatron.core.transformer.MegatronModule],
num_floating_point_operations_since_last_log_event: float,
straggler_timer: Any,
iteration: int,
prof: Optional[torch.profiler.profile],
config: megatron.bridge.training.config.ConfigContainer,
should_toggle_forward_pre_hook: bool,
) → None#

Run all post-training-step functions (e.g., fault tolerance (FT) heartbeats, garbage collection (GC)).

Parameters:
  • model – list of model chunks wrapped in DDP

  • num_floating_point_operations_since_last_log_event – Number of floating point operations since last log

  • straggler_timer – Timer for straggler detection

  • iteration – Current training iteration

  • prof – PyTorch profiler instance

  • config – Configuration container

  • should_toggle_forward_pre_hook – Whether to toggle forward pre-hook
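
A minimal sketch of invoking the post-step callbacks once per iteration; straggler_timer, flops_since_last_log, config, and the pre-hook toggle condition are assumptions about the surrounding setup, not values defined here.

```python
# Sketch: per-iteration post-step housekeeping (placeholders from surrounding setup).
post_training_step_callbacks(
    model=model,
    num_floating_point_operations_since_last_log_event=flops_since_last_log,
    straggler_timer=straggler_timer,
    iteration=iteration,
    prof=None,                             # pass a torch.profiler.profile when profiling
    config=config,
    should_toggle_forward_pre_hook=False,  # e.g., True with overlapped param gather
)
```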

bridge.training.train.enable_forward_pre_hook(
model: list[megatron.core.distributed.DistributedDataParallel],
) → None#

Enable forward pre-hook for all model chunks.

Parameters:

model – list of model chunks wrapped in DDP

bridge.training.train.disable_forward_pre_hook(
model: list[megatron.core.distributed.DistributedDataParallel],
param_sync: bool = True,
) → None#

Disable forward pre-hook for all model chunks.

Parameters:
  • model – list of model chunks wrapped in DDP

  • param_sync – Whether to synchronize parameters across model chunks
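
For illustration, the two hooks are typically used as a pair: disable before an operation that should not trigger the pre-hook (optionally synchronizing parameters first), then re-enable afterwards. run_side_computation below is a hypothetical placeholder.

```python
# Sketch: temporarily disabling forward pre-hooks around a side computation.
disable_forward_pre_hook(model, param_sync=True)
try:
    run_side_computation(model)   # hypothetical placeholder
finally:
    enable_forward_pre_hook(model)
```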

bridge.training.train.get_start_time_from_progress_log(
cfg: megatron.bridge.training.config.ConfigContainer,
) → tuple[datetime.datetime, float]#

Gets the start time of the earliest job with the same world size. Also returns the number of floating-point operations completed as of the last saved checkpoint.

bridge.training.train.compute_throughputs_and_append_to_progress_log(
state: megatron.bridge.training.state.GlobalState,
num_floating_point_operations_so_far: float,
) → None#

Computes job and cumulative throughputs and appends to progress log.

Calculates TFLOP/s/GPU based on floating-point operations and elapsed time. Appends the computed throughputs, total FLOPs, and processed tokens to the progress log file.

Parameters:
  • state – The GlobalState object.

  • num_floating_point_operations_so_far – Total floating-point operations completed.
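
As a back-of-the-envelope illustration of the throughput arithmetic (the function above performs the authoritative bookkeeping and writes to the progress log), TFLOP/s/GPU is the number of floating-point operations divided by elapsed seconds, world size, and 10^12:

```python
# Rough TFLOP/s/GPU estimate; illustrative only.
def tflops_per_gpu(num_flops: float, elapsed_seconds: float, world_size: int) -> float:
    return num_flops / (elapsed_seconds * world_size * 1e12)


# e.g., 3.6e18 FLOPs over one hour on 64 GPUs -> ~15.6 TFLOP/s/GPU
print(tflops_per_gpu(3.6e18, 3600.0, 64))
```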

bridge.training.train.save_checkpoint_and_time(
state: megatron.bridge.training.state.GlobalState,
model: list[megatron.core.transformer.MegatronModule],
optimizer: megatron.core.optimizer.MegatronOptimizer,
opt_param_scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
num_floating_point_operations_so_far: float,
checkpointing_context: dict[str, Any],
non_persistent_ckpt: bool = False,
train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]] = None,
) → None#

Saves a checkpoint and logs the timing.

Wraps the save_checkpoint function with timers, and disables/re-enables forward pre-hooks when a distributed optimizer with overlapped parameter gather is used.

Parameters:
  • state – The global state object.

  • model – list of model chunks (MegatronModule instances).

  • optimizer – The optimizer instance.

  • opt_param_scheduler – The optimizer parameter scheduler instance.

  • num_floating_point_operations_so_far – Cumulative number of floating-point operations up to this point.

  • checkpointing_context – Dictionary holding checkpointing-related state.

  • non_persistent_ckpt – Flag indicating if this is a non-persistent (local) checkpoint. Defaults to False.

  • train_data_iterator – Optional training data iterator to save its state.
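
A hedged sketch of saving a non-persistent (local) checkpoint mid-run; all arguments are assumed to come from the surrounding training setup.

```python
# Sketch: saving a non-persistent checkpoint (placeholder objects assumed).
save_checkpoint_and_time(
    state=global_state,
    model=model,
    optimizer=optimizer,
    opt_param_scheduler=scheduler,
    num_floating_point_operations_so_far=num_flops_so_far,
    checkpointing_context=checkpointing_context,
    non_persistent_ckpt=True,
    train_data_iterator=train_iter,
)
```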

bridge.training.train.checkpoint_and_decide_exit(
state: megatron.bridge.training.state.GlobalState,
model: list[megatron.core.transformer.MegatronModule],
optimizer: megatron.core.optimizer.MegatronOptimizer,
opt_param_scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
num_floating_point_operations_so_far: float,
checkpointing_context: dict[str, Any],
train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
) → bool#

Handles checkpointing decisions and determines if training should exit.

Checks various conditions for saving a checkpoint (signal received, interval, duration) and determines if the training loop should terminate based on exit conditions (signal, duration, iteration interval).

Parameters:
  • state – The global state object.

  • model – list of model chunks (MegatronModule instances).

  • optimizer – The optimizer instance.

  • opt_param_scheduler – The optimizer parameter scheduler instance.

  • num_floating_point_operations_so_far – Cumulative number of floating-point operations up to this point.

  • checkpointing_context – Dictionary holding checkpointing-related state.

  • train_data_iterator – Optional training data iterator to save its state.

Returns:

True if the training loop should exit, False otherwise.
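
A minimal sketch of using this helper at the end of each iteration of a custom loop; the loop scaffolding and placeholder objects are assumptions, not part of this module.

```python
# Sketch: per-iteration checkpoint/exit handling in a hand-rolled loop.
while iteration < train_iters:
    ...  # forward/backward/optimizer step for this iteration

    if checkpoint_and_decide_exit(
        state=global_state,
        model=model,
        optimizer=optimizer,
        opt_param_scheduler=scheduler,
        num_floating_point_operations_so_far=num_flops_so_far,
        checkpointing_context=checkpointing_context,
        train_data_iterator=train_iter,
    ):
        break  # a checkpoint (if requested) has been saved; stop training

    iteration += 1
```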

bridge.training.train._finish_train(
global_state: megatron.bridge.training.state.GlobalState,
)#