bridge.training.train#
Module Contents#
Functions#
| Function | Description |
| --- | --- |
| train | Main training loop. |
| train_step | Single training step. |
| post_training_step_callbacks | Run all post-training-step functions (e.g., FT heartbeats, GC). |
| enable_forward_pre_hook | Enable forward pre-hook for all model chunks. |
| disable_forward_pre_hook | Disable forward pre-hook for all model chunks. |
| get_start_time_from_progress_log | Gets start time of earliest job with same world size. Also returns the number of floating-point operations completed in last saved checkpoint. |
| compute_throughputs_and_append_to_progress_log | Computes job and cumulative throughputs and appends to progress log. |
| save_checkpoint_and_time | Saves a checkpoint and logs the timing. |
| checkpoint_and_decide_exit | Handles checkpointing decisions and determines if training should exit. |
API#
- bridge.training.train.train(
- forward_step_func: Callable,
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
- train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
- valid_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
- global_state: megatron.bridge.training.state.GlobalState,
- checkpointing_context: dict[str, Any],
- process_non_loss_data_func: Optional[Callable] = None,
- non_loss_data_func: Optional[Callable] = None,
Main training loop.
Handles the overall training process, including the iteration loop, calling train_step, evaluation, checkpointing, logging, and exit conditions.
- Parameters:
forward_step_func – Callable that executes a single forward step.
model – list of model chunks (potentially wrapped in DDP).
optimizer – The optimizer instance.
scheduler – The learning rate scheduler instance.
train_data_iterator – Iterator for the training dataset.
valid_data_iterator – Iterator for the validation dataset.
global_state – The GlobalState object holding various training states.
checkpointing_context – Context dictionary for checkpointing.
process_non_loss_data_func – Optional function to process non-loss data during evaluation.
non_loss_data_func – Optional function to compute non-loss data during evaluation.
.. warning::
This is an experimental API and is subject to change in backwards incompatible ways without notice.
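The sketch below shows how this entry point is typically driven once setup is done. The forward-step signature, the placeholder batch fields, and the pre-built model, optimizer, scheduler, data iterators, and GlobalState are assumptions for illustration, not guarantees about this module's API.
```python
import torch

# Hypothetical forward step. The exact signature expected by forward_step_func
# (here: data_iterator, model) is an assumption -- adapt it to your pipeline.
def forward_step(data_iterator, model):
    batch = next(data_iterator)  # placeholder batch layout
    output_tensor = model(batch["tokens"], labels=batch["labels"])

    def loss_func(output_tensor: torch.Tensor):
        loss = output_tensor.mean()
        return loss, {"lm loss": loss.detach()}

    return output_tensor, loss_func

# model, optimizer, scheduler, train_iter, valid_iter, and global_state are
# assumed to come from your own setup code (not shown).
train(
    forward_step_func=forward_step,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_data_iterator=train_iter,
    valid_data_iterator=valid_iter,
    global_state=global_state,
    checkpointing_context={},
)
```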
- bridge.training.train.train_step(
- forward_step_func: Callable,
- num_fw_args: int,
- data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
- global_state: megatron.bridge.training.state.GlobalState,
Single training step.
- Parameters:
forward_step_func – Function that performs a forward step
num_fw_args – Number of arguments expected by forward_step_func
data_iterator – Iterator over training data
model – list of model chunks
optimizer – Optimizer for model parameters
scheduler – Learning rate scheduler
global_state – Global training state
- Returns:
loss_dict: Dictionary of reduced losses
skipped_iter: Whether the iteration was skipped (1) or not (0)
should_checkpoint: Whether a checkpoint should be saved
should_exit: Whether training should exit
exit_code: Exit code if should_exit is True
grad_norm: Gradient norm if available, None otherwise
num_zeros_in_grad: Number of zeros in gradient if available, None otherwise
- Return type:
tuple containing
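A sketch of how the returned tuple is usually consumed; the element order follows the list above, and every variable other than train_step itself is assumed to exist in the surrounding loop.
```python
# Illustrative only: unpack the step result and react to the exit flags.
(
    loss_dict,
    skipped_iter,
    should_checkpoint,
    should_exit,
    exit_code,
    grad_norm,
    num_zeros_in_grad,
) = train_step(
    forward_step_func=forward_step,
    num_fw_args=2,  # forward_step takes (data_iterator, model) in this sketch
    data_iterator=train_iter,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    global_state=global_state,
)

if should_exit:
    raise SystemExit(exit_code)  # propagate the step's requested exit code
if not skipped_iter:
    print(f"step ok, grad_norm={grad_norm}, losses={loss_dict}")
```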
- bridge.training.train.post_training_step_callbacks(
- model: list[megatron.core.transformer.MegatronModule],
- num_floating_point_operations_since_last_log_event: float,
- straggler_timer: Any,
- iteration: int,
- prof: Optional[torch.profiler.profile],
- config: megatron.bridge.training.config.ConfigContainer,
- should_toggle_forward_pre_hook: bool,
Run all post-training-step functions (e.g., FT heartbeats, GC).
- Parameters:
model – list of model chunks wrapped in DDP
num_floating_point_operations_since_last_log_event – Number of floating point operations since last log
straggler_timer – Timer for straggler detection
iteration – Current training iteration
prof – PyTorch profiler instance
config – Configuration container
should_toggle_forward_pre_hook – Whether to toggle forward pre-hook
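An illustrative call site, placed right after a training step; every variable besides the function itself is assumed to come from the surrounding loop, and the pre-hook toggle condition shown is only one plausible choice.
```python
# Hypothetical end-of-step bookkeeping (variable names are placeholders).
post_training_step_callbacks(
    model=model,
    num_floating_point_operations_since_last_log_event=num_fp_ops_since_last_log,
    straggler_timer=straggler_timer,
    iteration=iteration,
    prof=profiler,  # or None when not profiling
    config=config,
    should_toggle_forward_pre_hook=use_distributed_optimizer and overlap_param_gather,
)
```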
- bridge.training.train.enable_forward_pre_hook(
- model: list[megatron.core.distributed.DistributedDataParallel],
Enable forward pre-hook for all model chunks.
- Parameters:
model – list of model chunks wrapped in DDP
- bridge.training.train.disable_forward_pre_hook(
- model: list[megatron.core.distributed.DistributedDataParallel],
- param_sync: bool = True,
Disable forward pre-hook for all model chunks.
- Parameters:
model – list of model chunks wrapped in DDP
param_sync – Whether to synchronize parameters across model chunks
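These two helpers are typically used as a matched pair around work that needs fully gathered, up-to-date parameters (for example evaluation or a checkpoint save) when the distributed optimizer overlaps its parameter all-gather. The pairing below is a sketch with a placeholder workload, not a prescribed pattern.
```python
# Force a parameter sync, run the parameter-sensitive work, then restore the hook.
disable_forward_pre_hook(model, param_sync=True)
try:
    run_evaluation(model)  # placeholder for whatever needs synced parameters
finally:
    enable_forward_pre_hook(model)
```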
- bridge.training.train.get_start_time_from_progress_log(
- cfg: megatron.bridge.training.config.ConfigContainer,
Gets start time of earliest job with same world size. Also returns the number of floating-point operations completed in last saved checkpoint.
- bridge.training.train.compute_throughputs_and_append_to_progress_log(
- state: megatron.bridge.training.state.GlobalState,
- num_floating_point_operations_so_far: float,
Computes job and cumulative throughputs and appends to progress log.
Calculates TFLOP/s/GPU based on floating-point operations and elapsed time. Appends the computed throughputs, total FLOPs, and processed tokens to the progress log file.
- Parameters:
state – The GlobalState object.
num_floating_point_operations_so_far – Total floating-point operations completed.
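The TFLOP/s/GPU figure it logs follows the usual normalization by elapsed time, world size, and 1e12; the snippet below restates that arithmetic with illustrative variable names (it is not the function's actual implementation).
```python
# Illustrative arithmetic only -- variable names are placeholders.
job_throughput_tflops_per_gpu = (
    (num_floating_point_operations_so_far - flops_at_job_start)
    / (elapsed_seconds_this_job * world_size * 1e12)
)
# The cumulative variant uses the start time of the earliest same-world-size job
# (see get_start_time_from_progress_log) in place of this job's start time.
```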
- bridge.training.train.save_checkpoint_and_time(
- state: megatron.bridge.training.state.GlobalState,
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- opt_param_scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
- num_floating_point_operations_so_far: float,
- checkpointing_context: dict[str, Any],
- non_persistent_ckpt: bool = False,
- train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]] = None,
Saves a checkpoint and logs the timing.
Wraps the save_checkpoint function with timers and potentially disables/enables forward pre-hooks if the distributed optimizer with overlapped parameter gather is used.
- Parameters:
state – The global state object.
model – list of model chunks (MegatronModule instances).
optimizer – The optimizer instance.
opt_param_scheduler – The optimizer parameter scheduler instance.
num_floating_point_operations_so_far – Cumulative number of floating-point operations up to this point.
checkpointing_context – Dictionary holding checkpointing-related state.
non_persistent_ckpt – Flag indicating if this is a non-persistent (local) checkpoint. Defaults to False.
train_data_iterator – Optional training data iterator to save its state.
- bridge.training.train.checkpoint_and_decide_exit(
- state: megatron.bridge.training.state.GlobalState,
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- opt_param_scheduler: megatron.core.optimizer_param_scheduler.OptimizerParamScheduler,
- num_floating_point_operations_so_far: float,
- checkpointing_context: dict[str, Any],
- train_data_iterator: Optional[Union[megatron.core.rerun_state_machine.RerunDataIterator, list[megatron.core.rerun_state_machine.RerunDataIterator]]],
Handles checkpointing decisions and determines if training should exit.
Checks various conditions for saving a checkpoint (signal received, interval, duration) and determines if the training loop should terminate based on exit conditions (signal, duration, iteration interval).
- Parameters:
state – The global state object.
model – list of model chunks (MegatronModule instances).
optimizer – The optimizer instance.
opt_param_scheduler – The optimizer parameter scheduler instance.
num_floating_point_operations_so_far – Cumulative number of floating-point operations up to this point.
checkpointing_context – Dictionary holding checkpointing-related state.
train_data_iterator – Optional training data iterator to save its state.
- Returns:
True if the training loop should exit, False otherwise.
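Inside the training loop, the boolean this returns is what drives the early-exit path. A sketch of that control flow follows; everything except checkpoint_and_decide_exit itself is assumed to come from the surrounding loop.
```python
# Hypothetical end-of-iteration control flow (setup and logging omitted).
for iteration in range(start_iteration, train_iters):
    # ... train_step, logging, post-step callbacks ...
    if checkpoint_and_decide_exit(
        state=global_state,
        model=model,
        optimizer=optimizer,
        opt_param_scheduler=scheduler,
        num_floating_point_operations_so_far=num_fp_ops_so_far,
        checkpointing_context=checkpointing_context,
        train_data_iterator=train_iter,
    ):
        break  # a signal, duration limit, or exit interval was hit
```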
- bridge.training.train._finish_train(
- global_state: megatron.bridge.training.state.GlobalState,