bridge.training.nvrx_straggler

Module Contents

Classes

NVRxStragglerDetectionManager

Manager for NVIDIA Resiliency Extension (NVRx) straggler detection in training loops that do not use PyTorch Lightning.

Functions

check_nvrx_straggler_detection

Check NVRx straggler detection and determine if training should exit.

safe_shutdown_nvrx_straggler_manager

Safely shut down the NVRx straggler detection manager with error handling.

API

class bridge.training.nvrx_straggler.NVRxStragglerDetectionManager(
config: megatron.bridge.training.config.NVRxStragglerDetectionConfig,
)

Manager for NVIDIA Resiliency Extension (NVRx) straggler detection in training loops that do not use PyTorch Lightning.

Initialization

Initialize the NVRx straggler detection manager.

Parameters:

config – Configuration for NVRx straggler detection.

Raises:
  • ImportError – If nvidia-resiliency-ext is not available.

  • ValueError – If invalid configuration is provided.
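A minimal sketch of the lifecycle implied by the public methods: initialize once, wrap the step function, check for stragglers as training proceeds, and always shut down. FakeStragglerManager is a hypothetical stand-in so the example runs without nvidia-resiliency-ext; the call order and the per-step check are assumptions inferred from the method names, not taken from the Megatron-Bridge training loop.

```python
from typing import Callable, List


class FakeStragglerManager:
    """Hypothetical stand-in with the same public method names as NVRxStragglerDetectionManager."""

    def __init__(self, straggler_at_step: int = -1) -> None:
        # Index of the check at which a straggler is "reported" (-1 means never).
        self.straggler_at_step = straggler_at_step
        self.calls: List[str] = []
        self._checks = 0

    def initialize(self) -> None:
        self.calls.append("initialize")

    def wrap_train_step_function(self, train_step_func: Callable) -> Callable:
        self.calls.append("wrap")
        return train_step_func  # a real manager would add monitoring here

    def check_stragglers(self, global_rank: int) -> bool:
        self.calls.append("check")
        detected = self._checks == self.straggler_at_step
        self._checks += 1
        return detected

    def shutdown(self) -> None:
        self.calls.append("shutdown")


def train(manager: FakeStragglerManager, num_steps: int, global_rank: int = 0) -> int:
    """Run a toy training loop; return how many steps completed before exiting."""
    manager.initialize()
    step_fn = manager.wrap_train_step_function(lambda step: step * 2)
    steps_run = 0
    try:
        for step in range(num_steps):
            step_fn(step)
            steps_run += 1
            # Assumed per-step check: True means stragglers were found and stop_if_detected is set.
            if manager.check_stragglers(global_rank):
                break
    finally:
        manager.shutdown()
    return steps_run
```

Placing shutdown() in a finally block is a design choice of the sketch so the detector is released even if a step raises.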

initialize() → None

Initialize the straggler detector.

Raises:

RuntimeError – If already initialized.
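The documented RuntimeError implies that initialize() is a one-shot operation. A hypothetical guard might look like the sketch below; DetectorLifecycle is not part of the module, and whether the real manager can be re-initialized after shutdown() is not documented, so the reset in shutdown() is an assumption of the sketch.

```python
class DetectorLifecycle:
    """Hypothetical sketch of a one-shot initialize() guard; not the real manager."""

    def __init__(self) -> None:
        self.initialized = False

    def initialize(self) -> None:
        if self.initialized:
            raise RuntimeError("Straggler detector is already initialized")
        # The real manager would set up the NVRx detector here (assumption).
        self.initialized = True

    def shutdown(self) -> None:
        # Clearing the flag lets a later initialize() succeed (a choice made by this sketch).
        self.initialized = False
```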

wrap_train_step_function(
train_step_func: Callable,
) → Callable

Wrap the training step function with straggler detection monitoring.

Parameters:

train_step_func – The actual training step function to wrap for monitoring.

Returns:

The wrapped training step function.
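One generic way to build such a wrapper is a functools.wraps decorator that times each call. This is a hypothetical sketch: wrap_with_timing and the durations list are illustrative names, and the real manager presumably hands the function to NVRx, which collects its own timing data rather than appending to a Python list.

```python
import functools
import time
from typing import Any, Callable, List


def wrap_with_timing(train_step_func: Callable, durations: List[float]) -> Callable:
    """Return a wrapper that records the wall-clock duration of each call (hypothetical sketch)."""

    @functools.wraps(train_step_func)  # keep the original __name__ and docstring
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        start = time.perf_counter()
        try:
            return train_step_func(*args, **kwargs)
        finally:
            # Record the elapsed time even if the step raises.
            durations.append(time.perf_counter() - start)

    return wrapped
```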

check_stragglers(global_rank: int) → bool

Check for stragglers and handle reporting.

Parameters:

global_rank – The global rank of the current process.

Returns:

True if stragglers were detected and the configuration's stop_if_detected option is enabled; False otherwise.
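The documented return value is a conjunction of two conditions. The sketch below uses a hypothetical ToyReport with a straggler_ranks list (the real NVRx report has a different shape) and omits the collective step; the presence of _gather_flag_from_rank0 suggests the real method makes the decision on rank 0 and broadcasts it to every rank.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyReport:
    """Hypothetical report shape; the real NVRx report type differs."""

    straggler_ranks: List[int] = field(default_factory=list)


def should_stop(report: ToyReport, stop_if_detected: bool) -> bool:
    """True only if stragglers were found AND stop_if_detected is enabled."""
    stragglers_found = len(report.straggler_ranks) > 0
    return stragglers_found and stop_if_detected
```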

_handle_straggler_report(report) → bool

Handle the straggler report from the detector.

Parameters:

report – The straggler detection report.

Returns:

True if stragglers were found, False otherwise.

_print_stragglers(stragglers) → None

Print straggler detection warnings.

static _format_gpu_scores(
rank_to_score,
rank_to_node,
num_best=3,
num_worst=3,
) → str

Format GPU performance scores for logging.
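A sketch of how best/worst formatting could work, assuming higher scores mean faster GPUs and that rank_to_score and rank_to_node are plain dicts keyed by global rank. The function name and output layout are hypothetical, not the module's actual log format.

```python
from typing import Dict


def format_gpu_scores(
    rank_to_score: Dict[int, float],
    rank_to_node: Dict[int, str],
    num_best: int = 3,
    num_worst: int = 3,
) -> str:
    """List the highest- and lowest-scoring ranks (hypothetical layout; assumes higher = faster)."""
    ranked = sorted(rank_to_score.items(), key=lambda item: item[1], reverse=True)

    def fmt(rank: int, score: float) -> str:
        node = rank_to_node.get(rank, "?")
        return f"  rank={rank} node={node} score={score:.2f}"

    lines = [f"Best {num_best}:"]
    lines += [fmt(r, s) for r, s in ranked[:num_best]]
    lines.append(f"Worst {num_worst}:")
    # Guard num_worst == 0 (ranked[-0:] would be the whole list); print the slowest rank first.
    worst = ranked[-num_worst:][::-1] if num_worst > 0 else []
    lines += [fmt(r, s) for r, s in worst]
    return "\n".join(lines)
```

If the job has fewer ranks than num_best + num_worst, the two lists overlap; the sketch does not deduplicate them.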

_print_gpu_scores(report) → None

Print GPU performance scores.

_log_gpu_scores(report) → None

Log GPU performance scores as structured data.

_log_gpu_perf_scores(
rank_to_score,
rank_to_node,
score_prefix,
) → None

Log GPU performance scores with statistics.
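Summary statistics of this kind can be computed with the standard statistics module. In this hypothetical sketch the returned keys (score_prefix followed by _min, _max, _mean, _median) are illustrative; the real method also receives rank_to_node, which the sketch ignores, and it logs its values (its documented return type is None) rather than returning them.

```python
import statistics
from typing import Dict


def gpu_score_stats(rank_to_score: Dict[int, float], score_prefix: str) -> Dict[str, float]:
    """Summarize per-rank scores into prefixed min/max/mean/median fields (hypothetical keys)."""
    values = list(rank_to_score.values())
    if not values:
        return {}  # nothing to summarize
    return {
        f"{score_prefix}_min": min(values),
        f"{score_prefix}_max": max(values),
        f"{score_prefix}_mean": statistics.fmean(values),
        f"{score_prefix}_median": statistics.median(values),
    }
```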

_gather_flag_from_rank0(flag: bool) → bool

Broadcast a boolean flag from rank 0 to all ranks.

shutdown() → None

Shut down the straggler detector.

bridge.training.nvrx_straggler.check_nvrx_straggler_detection(
nvrx_straggler_manager: Optional[bridge.training.nvrx_straggler.NVRxStragglerDetectionManager],
) → bool

Check NVRx straggler detection and determine if training should exit.

Parameters:

nvrx_straggler_manager – The NVRx straggler detection manager, or None if disabled.

Returns:

True if stragglers were detected and training should exit, False otherwise.

Return type:

bool
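The Optional manager in the signature suggests the helper is a no-op when detection is disabled. A sketch under that assumption follows; the real function takes only the manager and presumably resolves the global rank itself, whereas here global_rank is an explicit, hypothetical parameter so the example runs without torch.distributed.

```python
from typing import Optional, Protocol


class SupportsStragglerCheck(Protocol):
    """Hypothetical structural type for anything exposing check_stragglers()."""

    def check_stragglers(self, global_rank: int) -> bool: ...


def check_straggler_detection_sketch(
    manager: Optional[SupportsStragglerCheck], global_rank: int = 0
) -> bool:
    """Return True if training should exit; a disabled (None) manager never requests an exit."""
    if manager is None:
        return False
    return manager.check_stragglers(global_rank)
```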

bridge.training.nvrx_straggler.safe_shutdown_nvrx_straggler_manager(
manager: Optional[bridge.training.nvrx_straggler.NVRxStragglerDetectionManager],
logger_name: str = 'nvrx_straggler',
) → None

Safely shut down the NVRx straggler detection manager with error handling.

Parameters:
  • manager – The NVRx straggler detection manager to shut down; may be None.

  • logger_name – Logger name for error reporting.
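A minimal sketch of shutdown with error handling, assuming failures are logged rather than re-raised so that teardown of the training job can continue. SupportsShutdown and safe_shutdown_sketch are hypothetical names; only the default logger name comes from the documented signature.

```python
import logging
from typing import Optional, Protocol


class SupportsShutdown(Protocol):
    """Hypothetical structural type for anything exposing shutdown()."""

    def shutdown(self) -> None: ...


def safe_shutdown_sketch(
    manager: Optional[SupportsShutdown], logger_name: str = "nvrx_straggler"
) -> None:
    """Shut down the manager if present; log failures instead of re-raising."""
    if manager is None:
        return  # detection was disabled, nothing to release
    try:
        manager.shutdown()
    except Exception:
        # logger.exception logs at ERROR level and includes the traceback.
        logging.getLogger(logger_name).exception("Error shutting down NVRx straggler detection")
```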