bridge.training.nvrx_straggler
#
Module Contents#
Classes#
Manager for NVIDIA Resiliency Extension straggler detection in lightning-free training loops. |
Functions#
Check NVRx straggler detection and determine if training should exit. |
|
Safely shut down the NVRx straggler detection manager with error handling. |
API#
- class bridge.training.nvrx_straggler.NVRxStragglerDetectionManager(
- config: megatron.bridge.training.config.NVRxStragglerDetectionConfig,
Manager for NVIDIA Resiliency Extension straggler detection in lightning-free training loops.
Initialization
Initialize the NVRx straggler detection manager.
- Parameters:
config – Configuration for NVRx straggler detection.
- Raises:
ImportError – If nvidia-resiliency-ext is not available.
ValueError – If invalid configuration is provided.
- initialize() None #
Initialize the straggler detector.
- Raises:
RuntimeError – If already initialized.
- wrap_train_step_function(
- train_step_func: Callable,
Wrap the training step function with straggler detection monitoring.
- Parameters:
train_step_func – The actual training step function to wrap for monitoring.
- Returns:
The wrapped training step function.
- check_stragglers(global_rank: int) bool #
Check for stragglers and handle reporting.
- Parameters:
global_rank – The global rank of the current process.
- Returns:
True if stragglers were detected and stop_if_detected is True, False otherwise.
- _handle_straggler_report(report) bool #
Handle the straggler report from the detector.
- Parameters:
report – The straggler detection report.
- Returns:
True if stragglers were found, False otherwise.
- _print_stragglers(stragglers) None #
Print straggler detection warnings.
- static _format_gpu_scores(
- rank_to_score,
- rank_to_node,
- num_best=3,
- num_worst=3,
Format GPU performance scores for logging.
- _print_gpu_scores(report) None #
Print GPU performance scores.
- _log_gpu_scores(report) None #
Log GPU performance scores as structured data.
- _log_gpu_perf_scores(
- rank_to_score,
- rank_to_node,
- score_prefix,
Log GPU performance scores with statistics.
- _gather_flag_from_rank0(flag: bool) bool #
Broadcast a boolean flag from rank 0 to all ranks.
- shutdown() None #
Shut down the straggler detector.
- bridge.training.nvrx_straggler.check_nvrx_straggler_detection(
- nvrx_straggler_manager: Optional[bridge.training.nvrx_straggler.NVRxStragglerDetectionManager],
Check NVRx straggler detection and determine if training should exit.
- Parameters:
nvrx_straggler_manager – The NVRx straggler detection manager, or None if disabled.
- Returns:
True if stragglers were detected and training should exit, False otherwise.
- Return type:
bool
- bridge.training.nvrx_straggler.safe_shutdown_nvrx_straggler_manager(
- manager: Optional[bridge.training.nvrx_straggler.NVRxStragglerDetectionManager],
- logger_name: str = 'nvrx_straggler',
Safely shut down the NVRx straggler detection manager with error handling.
- Parameters:
manager – The NVRx straggler detection manager to shut down, can be None.
logger_name – Logger name for error reporting.