# Resiliency

Stable docs: `docs/training/resiliency.md`, `docs/training/checkpointing.md`
Card: `card.yaml` (co-located)
## Enablement

### Fault tolerance (Slurm only)

#### Option 1: NeMo Run plugin (recommended)
```python
import nemo_run as run

from megatron.bridge.recipes.run_plugins import FaultTolerancePlugin

task = run.Script(...)
run_plugins = [
    FaultTolerancePlugin(
        enable_ft_package=True,
        calc_ft_timeouts=True,
        num_in_job_restarts=3,
        num_job_retries_on_failure=2,
        initial_rank_heartbeat_timeout=1800,
        rank_heartbeat_timeout=300,
    )
]
run.run(task, plugins=run_plugins, executor=executor)
```
| Plugin parameter | Default | Description |
|---|---|---|
| `num_in_job_restarts` | 3 | Max restarts within same job |
| `num_job_retries_on_failure` | 2 | Max new job launches on failure |
| `initial_rank_heartbeat_timeout` | 1800 | First heartbeat timeout (seconds) |
| `rank_heartbeat_timeout` | 300 | Subsequent heartbeat timeout (seconds) |
#### Option 2: Direct config + ft_launcher
```python
from megatron.bridge.training.config import FaultToleranceConfig

cfg.ft = FaultToleranceConfig(
    enable_ft_package=True,
    calc_ft_timeouts=True,
    simulate_fault=False,
    simulated_fault_type="random",
)
```
Launch with `ft_launcher` (not `torchrun`):

```bash
export GROUP_RANK=0  # required for non-Slurm
ft_launcher \
    --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} \
    --nnodes=${NUM_NODES} --nproc-per-node=${NUM_GPUS_PER_NODE} \
    --ft-rank_section_timeouts=setup:600,step:180,checkpointing:420 \
    --ft-rank_out_of_section_timeout=300 \
    your_training_script.py
```
| Config parameter | Default | Description |
|---|---|---|
| `enable_ft_package` | False | Enable fault tolerance |
| `calc_ft_timeouts` | False | Auto-compute optimal timeouts |
| `simulate_fault` | False | Enable fault simulation for testing |
| `simulated_fault_type` | "random" | Type of fault to simulate |
| `simulated_fault_rank` | None | Specific rank to fault (random if None) |
| `simulated_fault_base_delay` | 0 | Base delay before simulating fault |
Section-based timeout monitoring covers setup, training steps, checkpointing,
and out-of-section time independently. When `calc_ft_timeouts=True`, the computed
timeouts are saved to `ft_state.json` and reused on subsequent runs.
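The `--ft-rank_section_timeouts` flag takes a comma-separated `section:seconds` list. A minimal sketch of parsing that format (the helper name is hypothetical and not part of `ft_launcher`):

```python
def parse_section_timeouts(spec: str) -> dict[str, int]:
    """Parse a 'section:seconds' comma list, e.g. 'setup:600,step:180'."""
    timeouts = {}
    for item in spec.split(","):
        section, seconds = item.split(":")
        timeouts[section.strip()] = int(seconds)
    return timeouts

print(parse_section_timeouts("setup:600,step:180,checkpointing:420"))
# {'setup': 600, 'step': 180, 'checkpointing': 420}
```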
### NVRx straggler detection
```python
from megatron.bridge.training.config import NVRxStragglerDetectionConfig

cfg.nvrx_straggler = NVRxStragglerDetectionConfig(
    enabled=True,
    report_time_interval=300.0,
    calc_relative_gpu_perf=True,
    calc_individual_gpu_perf=True,
    num_gpu_perf_scores_to_print=5,
    gpu_relative_perf_threshold=0.7,
    gpu_individual_perf_threshold=0.7,
    stop_if_detected=False,
    enable_logging=True,
)
```
| Parameter | Default | Description |
|---|---|---|
| `enabled` | False | Enable straggler detection |
| `report_time_interval` | 300.0 | Seconds between straggler checks |
| `calc_relative_gpu_perf` | True | Compare ranks against each other |
| `calc_individual_gpu_perf` | True | Track per-rank degradation over time |
| `gpu_relative_perf_threshold` | 0.7 | Threshold for relative performance (0-1) |
| `gpu_individual_perf_threshold` | 0.7 | Threshold for individual performance (0-1) |
| `stop_if_detected` | False | Terminate training on straggler |
| `num_gpu_perf_scores_to_print` | 5 | Number of best/worst scores to print |
| `profiling_interval` | 1 | Profiling interval for detector |
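Conceptually, a rank is flagged as a straggler when its performance score (0-1 scale) drops below the configured threshold. A toy sketch of that comparison, illustrative only and not the library's implementation:

```python
def flag_stragglers(scores: dict[int, float], threshold: float = 0.7) -> list[int]:
    """Return ranks whose perf score (0-1 scale) falls below the threshold."""
    return sorted(rank for rank, score in scores.items() if score < threshold)

# Rank 2 is markedly slower than its peers, so only it is flagged.
print(flag_stragglers({0: 0.98, 1: 0.95, 2: 0.55, 3: 0.97}))  # [2]
```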
### Preemption

#### Plugin (Slurm)
```python
from megatron.bridge.recipes.run_plugins import PreemptionPlugin

plugins = [
    PreemptionPlugin(
        preempt_time=60,
        enable_exit_handler=True,
        enable_exit_handler_for_data_loader=False,
    )
]
```
| Plugin parameter | Default | Description |
|---|---|---|
| `preempt_time` | 60 | Seconds before job limit to send signal |
| `enable_exit_handler` | True | Enable signal handler in training |
| `enable_exit_handler_for_data_loader` | False | Enable for dataloader workers |
#### Direct config
```python
import signal

cfg.train.exit_signal_handler = True
cfg.train.exit_signal = signal.SIGTERM
cfg.train.exit_signal_handler_for_dataloader = False
```
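At its core, an exit signal handler records the preemption request so the training loop can checkpoint and exit cleanly. A minimal standalone sketch of the idea (illustrative only, not the bridge's `sig_utils` implementation):

```python
import os
import signal

should_exit = False

def _on_preempt(signum, frame):
    # Record the request; a training loop would check this flag each
    # iteration and save a checkpoint before exiting.
    global should_exit
    should_exit = True

signal.signal(signal.SIGTERM, _on_preempt)

# Simulate Slurm delivering the preemption signal to this process.
os.kill(os.getpid(), signal.SIGTERM)
print(should_exit)  # True
```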
### Re-run state machine (experimental)
```python
from megatron.bridge.training.config import RerunStateMachineConfig

cfg.rerun_state_machine = RerunStateMachineConfig(
    rerun_mode="validate_results",
    check_for_nan_in_loss=True,
    check_for_spiky_loss=False,
    spiky_loss_factor=10.0,
)
```
| Parameter | Default | Description |
|---|---|---|
| `rerun_mode` | "disabled" | Rerun mode (e.g. "validate_results") |
| `check_for_nan_in_loss` | True | Check for NaN in loss |
| `check_for_spiky_loss` | False | Check for unexpectedly large loss |
| `spiky_loss_factor` | 10.0 | Loss flagged if > factor * max observed (increase for large models) |
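The spiky-loss rule is a simple multiplicative comparison against the largest loss observed so far. A toy sketch of the check (not the library's code):

```python
def is_spiky(loss: float, max_observed: float, factor: float = 10.0) -> bool:
    """Flag a loss that exceeds factor * the max loss observed so far."""
    return loss > factor * max_observed

print(is_spiky(loss=55.0, max_observed=5.0))  # True: 55 > 10 * 5
print(is_spiky(loss=12.0, max_observed=5.0))  # False: 12 <= 50
```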
Exit codes: `16` = resume to disambiguate, `17` = failed validation.
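A job wrapper can key off these exit codes to decide whether to relaunch. A hypothetical sketch of one such policy (the helper is not part of the bridge):

```python
def next_action(exit_code: int) -> str:
    """Map rerun-state-machine exit codes to a relaunch decision."""
    if exit_code == 16:
        return "resume"   # re-run to disambiguate a suspect result
    if exit_code == 17:
        return "abort"    # result validation failed
    return "default"      # normal exit handling

print(next_action(16), next_action(17), next_action(0))
# resume abort default
```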
### In-process restart (experimental)
```python
from megatron.bridge.training.config import InProcessRestartConfig

cfg.inprocess_restart = InProcessRestartConfig(
    enabled=True,
    granularity="node",
    soft_timeout=60.0,
    hard_timeout=90.0,
)
```
| Parameter | Default | Description |
|---|---|---|
| `enabled` | False | Enable in-process restart |
| `active_world_size` | None | Ranks executing workload (rest are warm reserves) |
| `granularity` | "node" | Restart granularity |
| `max_iterations` | None | Max restart attempts (None = unlimited) |
| `soft_timeout` | 60.0 | Detect GIL-released hangs (seconds) |
| `hard_timeout` | 90.0 | Force-terminate hung ranks (seconds) |
| `heartbeat_interval` | 30.0 | Heartbeat interval (seconds) |
| `heartbeat_timeout` | 60.0 | Missing heartbeat timeout (seconds) |
| `barrier_timeout` | 120.0 | Distributed barrier timeout (seconds) |
| `completion_timeout` | 120.0 | Completion barrier timeout (seconds) |
| `empty_cuda_cache` | True | Clear CUDA cache during restart |
| `max_rank_faults` | None | Max rank faults before terminating |
| `monitor_process_logdir` | None | Directory for monitor logs |
Required environment variables:

```bash
export TORCH_CPP_LOG_LEVEL=error
export TORCH_NCCL_RETHROW_CUDA_ERRORS=0
export NCCL_NVLS_ENABLE=0
```
The PyTorch NCCL watchdog timeout must exceed `hard_timeout`. NeMo-Run's
Slurm executor is not supported; launch directly with `srun --kill-on-bad-exit=0`.
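The timeout ordering is easy to get wrong, so it can be worth asserting before launch. A small sketch of the documented constraint (the helper and the watchdog argument are illustrative, not a bridge API):

```python
def check_ipr_timeouts(soft_timeout: float, hard_timeout: float,
                       nccl_watchdog_timeout: float) -> None:
    """Enforce soft < hard < NCCL watchdog, per the in-process restart notes."""
    if not soft_timeout < hard_timeout:
        raise ValueError("soft_timeout must be below hard_timeout")
    if not hard_timeout < nccl_watchdog_timeout:
        raise ValueError("NCCL watchdog timeout must exceed hard_timeout, "
                         "or PyTorch kills the process before recovery")

# Matches the defaults above with a comfortably larger watchdog timeout.
check_ipr_timeouts(soft_timeout=60.0, hard_timeout=90.0, nccl_watchdog_timeout=600.0)
```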
### Async checkpoint save

```python
cfg.checkpoint.async_save = True
cfg.checkpoint.ckpt_format = "torch_dist"
```
### Local checkpointing (NVRx)

```python
cfg.checkpoint.non_persistent_local_ckpt_dir = "/local/scratch/ckpt"
cfg.checkpoint.non_persistent_local_ckpt_algo = "fully_parallel"
```
## Code Anchors
### Fault tolerance

- Config: `src/megatron/bridge/training/config.py` → `FaultToleranceConfig`
- Runtime: `src/megatron/bridge/training/fault_tolerance.py`
- Plugin: `src/megatron/bridge/recipes/run_plugins.py` → `FaultTolerancePlugin`
- Perf plugin: `scripts/performance/resiliency_plugins.py`
- Tests: `tests/unit_tests/training/test_fault_tolerance.py`
- Example: `examples/resiliency/fault_tolerance/`
### Straggler detection

- Config: `src/megatron/bridge/training/config.py` → `NVRxStragglerDetectionConfig`
- Runtime: `src/megatron/bridge/training/nvrx_straggler.py`
- Train loop: `src/megatron/bridge/training/train.py` → `check_nvrx_straggler_detection`
- Tests: `tests/unit_tests/training/test_nvrx_straggler.py`, `tests/functional_tests/training/test_nvrx_straggler.py`
- Example: `examples/resiliency/straggler_detection/`
### In-process restart

- Config: `src/megatron/bridge/training/config.py` → `InProcessRestartConfig`
- Runtime: `src/megatron/bridge/training/inprocess_restart.py`
- Entry point: `src/megatron/bridge/training/pretrain.py` → `maybe_wrap_for_inprocess_restart`
- Tests: `tests/unit_tests/training/test_inprocess_restart.py`, `tests/functional_tests/training/test_inprocess_restart.py`
### Preemption

- Plugin: `src/megatron/bridge/recipes/run_plugins.py` → `PreemptionPlugin`
- Signal handler: `src/megatron/bridge/training/utils/sig_utils.py`
- Tests: `tests/unit_tests/recipes/test_run_plugins.py`
### Re-run state machine

- Config: `src/megatron/bridge/training/config.py` → `RerunStateMachineConfig`
- Init: `src/megatron/bridge/training/initialize.py` → `init_rerun_state`
### Checkpointing

- Async save: `src/megatron/bridge/training/checkpointing.py` → `schedule_async_save`
- Local ckpt: `src/megatron/bridge/training/checkpointing.py` → `LocalCheckpointManager`
- Tests: `tests/functional_tests/training/test_local_checkpointing.py`
## Pitfalls
- **ft_launcher, not torchrun:** Direct `FaultToleranceConfig` requires `ft_launcher`. Using `torchrun` silently disables FT. For non-Slurm, set `GROUP_RANK=0`.
- **Async save requires torch_dist:** `async_save=True` only works with `ckpt_format="torch_dist"`. Other formats silently fail or error.
- **IPR + NeMo-Run:** In-process restart is not compatible with NeMo-Run or Slurm preemption plugins. It requires specific PyTorch/NCCL versions and environment variables.
- **NVRx vs legacy straggler:** Two detectors exist. Use NVRx (`nvrx_straggler`); do not enable both.
- **stop_if_detected default:** NVRx logs but does not stop training by default. Set `stop_if_detected=True` for automatic termination.
- **NCCL watchdog vs hard_timeout:** For IPR, the NCCL watchdog timeout must exceed `hard_timeout`, or PyTorch kills the process before recovery.
- **Rerun state machine is alpha:** Use `check_for_nan_in_loss=True` for NaN detection, but don't rely on full rerun workflows yet.
## Verification

### Fault tolerance

```bash
./examples/resiliency/fault_tolerance/run_fault_tolerance.sh
./examples/resiliency/fault_tolerance/run_fault_tolerance.sh --simulate-fault
```

Look for `[FaultTolerance]` / `[RankMonitorServer]` log lines with section
timeouts. A simulated fault should trigger a restart from checkpoint.
### Straggler detection

```bash
uv run python -m torch.distributed.run --nproc_per_node=2 \
    examples/resiliency/straggler_detection/straggler_detection_example.py
```

Look for "GPU relative performance" and "GPU individual performance" reports
with per-rank scores.
### Async checkpoint

Look for `Scheduling async checkpoint save` in logs. Training iterations
should continue while checkpoint files are being written.
### In-process restart

```bash
pytest tests/functional_tests/training/test_inprocess_restart.py -v
```

Requires compatible PyTorch/NCCL versions.