bridge.training.fault_tolerance#
Fault Tolerance (FT) package integration for Megatron-Hub, using the FT section-based API.
The FT package is included in “nvidia-resiliency-ext” (https://github.com/NVIDIA/nvidia-resiliency-ext).
NOTE: The workload must be run using the ft_launcher tool provided by nvidia-resiliency-ext.
NOTE: Calls to the public API of this module are no-ops if FT is not initialized (ft_integration.setup was not called).
NOTE: The default distributed process group must be initialized before calling ft_integration.setup.
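For example, a minimal sketch of the required initialization order could look like the following (the import path and the NCCL backend choice are assumptions inferred from the type annotations below, not taken verbatim from this page):

```python
import torch.distributed as dist

from megatron.bridge.training import fault_tolerance
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.state import GlobalState


def init_with_ft(config: ConfigContainer, global_state: GlobalState) -> None:
    # The default process group must exist before fault_tolerance.setup() is called.
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")  # backend choice is an assumption
    # Opens the 'setup' FT section; requires the job to have been started via ft_launcher.
    fault_tolerance.setup(config, global_state)
```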
The “setup” FT section is opened during FT initialization and closed before the first training or eval iteration. Training and evaluation steps are wrapped in the “step” section, but only after a few warmup iterations. This is because the initial iterations may be slower, and we want the “step” timeout to be short. These warmup steps, which are not wrapped in the “step” section, will fall into the out-of-section area. All checkpoint-saving-related operations (including asynchronous checkpointing finalization) are wrapped in the “checkpointing” section.
If timeout calculation is enabled (--calc-ft-timeouts), FT timeouts are updated after each checkpoint and at the end of the run. Updated values are based on observed intervals.
ft_launcher command example:
ft_launcher --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} --nnodes=${NUM_NODES} --nproc-per-node=${NUM_GPUS_PER_NODE} --ft-param-rank_section_timeouts=setup:600,step:180,checkpointing:420 --ft-param-rank_out_of_section_timeout=300 train_script_with_ft.py
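Once the job is running under ft_launcher, the callbacks in this module are meant to bracket the training loop. The sketch below shows one plausible wiring; train_step and save_checkpoint are user-supplied placeholder callables, and the config.ft attribute holding the FaultToleranceConfig is an assumption:

```python
from megatron.bridge.training import fault_tolerance


def train_with_ft(config, global_state, train_step, save_checkpoint, num_iters, save_interval):
    fault_tolerance.setup(config, global_state)              # opens the 'setup' section
    fault_tolerance.maybe_setup_simulated_fault(config.ft)   # 'ft' attribute name is an assumption

    for iteration in range(1, num_iters + 1):
        fault_tolerance.on_training_step_start(global_state)  # closes 'setup'; opens 'step' after warmup
        train_step()                                           # user-supplied callable
        fault_tolerance.on_training_step_end(global_state)    # closes 'step' if it was opened

        if iteration % save_interval == 0:
            fault_tolerance.on_checkpointing_start(global_state)
            save_checkpoint()                                  # user-supplied callable
            fault_tolerance.on_checkpointing_end(
                is_async_finalization=False, global_state=global_state
            )

    fault_tolerance.shutdown(global_state)  # final timeout update (if enabled) and FT client close
```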
Module Contents#
Functions#
setup | Initialize fault tolerance integration.
on_training_step_start | Callback executed at the start of each training step.
on_training_step_end | Callback executed at the end of each training step.
on_eval_step_start | Callback executed at the start of each evaluation step.
on_eval_step_end | Callback executed at the end of each evaluation step.
on_checkpointing_start | Callback executed before checkpoint-saving related operations.
on_checkpointing_end | Callback executed after checkpoint-saving related operations.
on_checkpoint_loaded | Callback executed after a checkpoint is loaded.
shutdown | Shuts down fault tolerance monitoring.
maybe_setup_simulated_fault | Sets up a simulated fault for fault tolerance testing, if configured.
_load_state_if_exists | Load fault tolerance state from file if it exists.
_update_timeouts | Update fault tolerance timeouts based on observed intervals.
_maybe_update_timeouts | Update timeouts if conditions are met.
Data#
API#
- bridge.training.fault_tolerance._NUM_WARMUP_ITERS: int#
1
- bridge.training.fault_tolerance._MIN_ITERS_FOR_STEP_TIMEOUT_UPDATE: int#
16
- bridge.training.fault_tolerance.setup(config: megatron.bridge.training.config.ConfigContainer, global_state: megatron.bridge.training.state.GlobalState)#
Initialize fault tolerance integration.
Opens the ‘setup’ FT section.
- Parameters:
config – Configuration container.
global_state – Global training state.
- Raises:
ValueError – If checkpoint save directory is not configured.
- bridge.training.fault_tolerance.on_training_step_start(global_state: megatron.bridge.training.state.GlobalState)#
Callback executed at the start of each training step.
Closes the ‘setup’ section if open, and starts the ‘step’ section after warmup iterations.
- Parameters:
global_state – Global training state.
- bridge.training.fault_tolerance.on_training_step_end(global_state: megatron.bridge.training.state.GlobalState)#
Callback executed at the end of each training step.
Ends the ‘step’ section if it was started.
- Parameters:
global_state – Global training state.
- bridge.training.fault_tolerance.on_eval_step_start(global_state: megatron.bridge.training.state.GlobalState)#
Callback executed at the start of each evaluation step.
Closes the ‘setup’ section if open, and starts the ‘step’ section after warmup iterations.
- Parameters:
global_state – Global training state.
- bridge.training.fault_tolerance.on_eval_step_end(global_state: megatron.bridge.training.state.GlobalState)#
Callback executed at the end of each evaluation step.
Ends the ‘step’ section if it was started.
- Parameters:
global_state – Global training state.
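A minimal sketch of how these evaluation callbacks might wrap an eval loop (eval_step is a user-supplied placeholder callable, not part of this module):

```python
from megatron.bridge.training import fault_tolerance


def evaluate_with_ft(global_state, eval_step, num_eval_iters):
    for _ in range(num_eval_iters):
        fault_tolerance.on_eval_step_start(global_state)  # same 'step' section as training
        eval_step()                                        # user-supplied callable
        fault_tolerance.on_eval_step_end(global_state)     # closes 'step' if it was opened
```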
- bridge.training.fault_tolerance.on_checkpointing_start(global_state: megatron.bridge.training.state.GlobalState)#
Callback executed before checkpoint-saving related operations.
Starts the ‘checkpointing’ FT section.
- Parameters:
global_state – Global training state.
- bridge.training.fault_tolerance.on_checkpointing_end(is_async_finalization: bool, global_state: megatron.bridge.training.state.GlobalState)#
Callback executed after checkpoint-saving related operations.
Ends the ‘checkpointing’ FT section and potentially updates timeouts.
- Parameters:
is_async_finalization – True if called after async checkpoint finalization.
global_state – Global training state.
- bridge.training.fault_tolerance.on_checkpoint_loaded(is_local_chkpt: bool, global_state: megatron.bridge.training.state.GlobalState)#
Callback executed after a checkpoint is loaded.
Records whether a persistent checkpoint was loaded for timeout calculation.
- Parameters:
is_local_chkpt – True if a local (non-persistent) checkpoint was loaded.
global_state – Global training state.
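The checkpoint-related callbacks above can be used as in the following sketch; save_fn, finalize_fn, and load_fn are placeholder callables, and the split between a synchronous save and an async-finalization call is illustrative:

```python
from megatron.bridge.training import fault_tolerance


def save_with_ft(global_state, save_fn):
    fault_tolerance.on_checkpointing_start(global_state)
    save_fn()  # user-supplied synchronous checkpoint save
    fault_tolerance.on_checkpointing_end(is_async_finalization=False, global_state=global_state)


def finalize_async_save_with_ft(global_state, finalize_fn):
    fault_tolerance.on_checkpointing_start(global_state)
    finalize_fn()  # user-supplied async checkpoint finalization
    fault_tolerance.on_checkpointing_end(is_async_finalization=True, global_state=global_state)


def load_with_ft(global_state, load_fn, is_local_chkpt: bool):
    load_fn()  # user-supplied checkpoint load
    fault_tolerance.on_checkpoint_loaded(is_local_chkpt=is_local_chkpt, global_state=global_state)
```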
- bridge.training.fault_tolerance.shutdown(global_state: megatron.bridge.training.state.GlobalState)#
Shuts down fault tolerance monitoring.
Updates timeouts if applicable and closes the FT client.
- Parameters:
global_state – Global training state.
- bridge.training.fault_tolerance.maybe_setup_simulated_fault(config: megatron.bridge.training.config.FaultToleranceConfig)#
Sets up a simulated fault for fault tolerance testing, if configured.
Starts a background thread that will hang or kill a specific rank after a delay.
- Parameters:
config – Fault tolerance configuration object.
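For illustration only, a simulated fault might be configured as below. The FaultToleranceConfig field names used here (enable_ft_package, simulate_fault, simulated_fault_type, simulated_fault_rank, simulated_fault_base_delay) are assumptions and should be checked against the actual config class:

```python
from megatron.bridge.training import fault_tolerance
from megatron.bridge.training.config import FaultToleranceConfig

# All field names below are assumed, not confirmed by this page.
ft_config = FaultToleranceConfig(
    enable_ft_package=True,            # assumed field: turn FT integration on
    simulate_fault=True,               # assumed field: enable the simulated fault
    simulated_fault_type="rank_hung",  # assumed field: hang (rather than kill) a rank
    simulated_fault_rank=0,            # assumed field: rank to affect
    simulated_fault_base_delay=120.0,  # assumed field: base delay in seconds
)
fault_tolerance.maybe_setup_simulated_fault(ft_config)
```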
- bridge.training.fault_tolerance._load_state_if_exists(global_state: megatron.bridge.training.state.GlobalState)#
Load fault tolerance state from file if it exists.
- bridge.training.fault_tolerance._update_timeouts(selected_sections: List[str], calc_out_of_section: bool, global_state: megatron.bridge.training.state.GlobalState)#
Update fault tolerance timeouts based on observed intervals.
- bridge.training.fault_tolerance._maybe_update_timeouts(global_state: megatron.bridge.training.state.GlobalState, is_closing_ft: bool = False)#
Update timeouts if conditions are met.