bridge.training.config#

Module Contents#

Classes#

DistributedDataParallelConfig

Megatron Core DistributedDataParallelConfig with deferred post-init.

OptimizerConfig

Megatron Core OptimizerConfig with deferred post-init.

RNGConfig

Configuration settings for random number generation.

DistributedInitConfig

Configuration settings for distributed training initialization.

RerunStateMachineConfig

Configuration for the rerun state machine used for result validation or stats.

DataloaderConfig

Base configuration for data loading.

DatasetBuildContext

Interface that encapsulates framework internals.

DatasetProvider

Abstract base class for custom dataset configurations.

GPTDatasetConfig

Megatron Core GPTDatasetConfig with deferred post-init.

MockGPTDatasetConfig

Modifies GPTDatasetConfig to enforce necessary options for creating a mock dataset.

FinetuningDatasetConfig

Configuration specific to finetuning datasets, inheriting from DataloaderConfig.

SchedulerConfig

Configuration settings for the learning rate scheduler and weight decay.

TrainingConfig

Configuration settings related to the training loop and validation.

CheckpointConfig

Configuration settings for model checkpointing (saving and loading).

LoggerConfig

Configuration settings for logging, including TensorBoard and WandB.

ProfilingConfig

Configuration settings for profiling the training process.

FaultToleranceConfig

Configuration settings related to fault tolerance mechanisms (NVIDIA internal use).

StragglerDetectionConfig

Configuration settings for detecting and logging GPU stragglers.

NVRxStragglerDetectionConfig

Configuration settings for NVIDIA Resiliency Extension straggler detection.

InProcessRestartConfig

Configuration settings for NVIDIA Resiliency Extension in-process restart functionality.

ConfigContainer

Top-level container holding all configuration objects.

Functions#

runtime_config_update

Apply runtime configuration updates prior to initialization.

_validate_and_sync_distributed_optimizer_settings

Validate and synchronize distributed optimizer settings between DDP and optimizer configs.

API#

class bridge.training.config.DistributedDataParallelConfig#

Bases: megatron.core.distributed.DistributedDataParallelConfig

Megatron Core DistributedDataParallelConfig with deferred post-init.

This class inherits from Megatron Core’s DistributedDataParallelConfig but defers the execution of post_init() until finalize() is explicitly called. This allows for field modifications after construction but before computed fields are calculated.

__post_init__() None#

Skip MCore post_init during initial construction.

The original post_init logic is deferred until finalize() is called.

finalize() None#

Execute the deferred MCore post-init logic.

This method calls the original Megatron Core DistributedDataParallelConfig.post_init() to compute derived fields based on the current field values.
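A minimal usage sketch of the construct → modify → finalize() flow. The import path and the Megatron Core field names (use_distributed_optimizer, overlap_grad_reduce) come from Megatron Core rather than this page and are shown for illustration only:

    from megatron.bridge.training.config import DistributedDataParallelConfig

    # __post_init__ is skipped at construction time, so fields can still be edited.
    ddp_config = DistributedDataParallelConfig()
    ddp_config.use_distributed_optimizer = True   # illustrative Megatron Core field
    ddp_config.overlap_grad_reduce = True         # illustrative Megatron Core field

    # Run the deferred Megatron Core post-init to compute derived fields.
    ddp_config.finalize()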

class bridge.training.config.OptimizerConfig#

Bases: megatron.core.optimizer.OptimizerConfig

Megatron Core OptimizerConfig with deferred post-init.

This class inherits from Megatron Core’s OptimizerConfig but defers the execution of post_init() until finalize() is explicitly called. This allows for field modifications after construction but before computed fields are calculated.

__post_init__() None#

Skip MCore post_init during initial construction.

The original post_init logic is deferred until finalize() is called.

finalize() None#

Execute the deferred MCore post-init logic.

This method calls the original Megatron Core OptimizerConfig.post_init() to compute derived fields based on the current field values.

class bridge.training.config.RNGConfig#

Configuration settings for random number generation.

seed: int#

1234

Random seed used for python, numpy, pytorch, and cuda.

te_rng_tracker: bool#

False

Use the Transformer Engine version of the random number generator. Required for CUDA graphs support.

inference_rng_tracker: bool#

False

Use a random number generator configured for inference.

data_parallel_random_init: bool#

False

Enable random initialization of params across data parallel ranks

class bridge.training.config.DistributedInitConfig#

Configuration settings for distributed training initialization.

distributed_backend: Literal[nccl, gloo]#

‘nccl’

Which backend to use for distributed training.

distributed_timeout_minutes: int#

10

Timeout minutes for torch.distributed.

align_grad_reduce: bool#

True

If True, all PP stages launch gradient reduces simultaneously; otherwise, each PP stage launches them independently as needed.

local_rank: int#

‘field(…)’

Local rank passed from the distributed launcher.

lazy_init: bool#

False

If set to True, initialize_megatron() skips DDP initialization and returns a function to complete it later. Also enables use_cpu_initialization. This is intended for use with an external DDP manager.

use_megatron_fsdp: bool#

False

Use Megatron’s Fully Sharded Data Parallel. Cannot be used together with use_torch_fsdp2.

use_torch_fsdp2: bool#

False

Use the torch FSDP2 implementation. FSDP2 is not currently working with Pipeline Parallel. It is still not in a stable release stage, and may therefore contain bugs or other potential issues.

nccl_communicator_config_path: Optional[str]#

None

Path to the yaml file with NCCL communicator configurations. The number of min/max thread groups and thread group cluster size of each communicator can be configured by setting min_ctas, max_ctas, and cga_cluster_size.

use_tp_pp_dp_mapping: bool#

False

If set, distributed ranks initialize order is changed from tp-dp-pp to tp-pp-dp. Make sure EP and CP aren’t used with this option enabled.

use_gloo_process_groups: bool#

True

If set, create Gloo process groups for communications.

use_sharp: bool#

False

Set the use of SHARP for the collective communications of data-parallel process groups. When True, run barrier within each data-parallel process group, which specifies the SHARP application target groups.

sharp_enabled_group: Optional[Literal[dp, dp_replica]]#

None

IB SHARP can be enabled from only one communication group. If not specified and use_sharp=True, it is enabled for the dp group by default. Available options: [dp, dp_replica].

high_priority_stream_groups: Optional[list[str]]#

None

Specify which communicator groups should use high priority streams during creation. Assigning high priority to communication streams ensures that communication kernels are scheduled with higher priority, minimizing the exposed communication when it is overlapped with other computation kernels.

external_gpu_device_mapping: bool#

False

If True, indicates that GPU device mapping has been externally managed (e.g., via CUDA_VISIBLE_DEVICES environment variable). When True, uses device 0 instead of local rank for CUDA device selection. This is useful when launching with external process managers that handle GPU visibility.

enable_megatron_core_experimental: bool#

False

Enable experimental features for Megatron Core.

class bridge.training.config.RerunStateMachineConfig#

Configuration for the rerun state machine used for result validation or stats.

error_injection_rate: int#

0

Rate at which to inject unexpected results; e.g., 1000 means once every 1000 result validations.

error_injection_type: Literal[correct_result, transient_error, persistent_error]#

‘transient_error’

Type of error to inject.

rerun_mode: Literal[disabled, validate_results, report_stats]#

‘disabled’

Use the re-run engine to validate results or to emit stats on the variability of computations due to non-deterministic algorithms.

check_for_nan_in_loss: bool#

True

Check for NaN in the loss.

check_for_spiky_loss: bool#

False

Check for spiky loss.

class bridge.training.config.DataloaderConfig#

Base configuration for data loading.

dataloader_type: Optional[Literal[single, cyclic, external]]#

None

Single pass vs multiple pass data loader

num_workers: int#

8

Dataloader number of workers.

data_sharding: bool#

True

Enable data sharding.

pin_memory: bool#

True

Whether to pin memory during data loading for faster GPU training.

persistent_workers: bool#

False

Whether to keep data loading workers persistent across epochs.

class bridge.training.config.DatasetBuildContext#

Interface that encapsulates framework internals.

This context provides metadata needed to build datasets while hiding implementation details of the framework.

Attributes:

train_samples – Number of samples for the training dataset

valid_samples – Number of samples for the validation dataset

test_samples – Number of samples for the test dataset

tokenizer – Optional tokenizer instance for text processing

train_samples: int#

None

valid_samples: int#

None

test_samples: int#

None

tokenizer: Optional[megatron.bridge.training.tokenizers.tokenizer.MegatronTokenizer]#

None

class bridge.training.config.DatasetProvider#

Bases: bridge.training.config.DataloaderConfig, abc.ABC

Abstract base class for custom dataset configurations.

Provides an interface for users to implement their own dataset builders while automatically inheriting all DataloaderConfig functionality.

Users must:

  1. Inherit from this class

  2. Implement the build_datasets() method

Example:

    @dataclass
    class S3DatasetConfig(DatasetProvider):
        bucket_name: str
        data_prefix: str
        seq_length: int

        def build_datasets(
            self, context: DatasetBuildContext
        ) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]:
            # Custom implementation to load data from S3
            train_ds = load_s3_dataset(self.bucket_name, f"{self.data_prefix}/train", context.tokenizer)
            valid_ds = load_s3_dataset(self.bucket_name, f"{self.data_prefix}/valid", context.tokenizer)
            test_ds = load_s3_dataset(self.bucket_name, f"{self.data_prefix}/test", context.tokenizer)
            return train_ds, valid_ds, test_ds

abstractmethod build_datasets(
context: bridge.training.config.DatasetBuildContext,
) Tuple[Optional[Any], Optional[Any], Optional[Any]]#

Build train, validation, and test datasets.

This method is called by the framework during dataset initialization. Implementations should use the provided context to create appropriate datasets for each split.

Parameters:

context – Build context with sample counts and tokenizer

Returns:

Tuple of (train_dataset, valid_dataset, test_dataset) Any element can be None if that split shouldn’t be created.

Raises:

NotImplementedError – Must be implemented by subclasses
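The following is a hedged sketch of the provider contract exercised in isolation; InMemoryProvider and the sample counts are hypothetical, and it assumes DatasetBuildContext is a plain dataclass over the four attributes listed above. In real training the framework constructs the context and calls build_datasets for you.

    from dataclasses import dataclass
    from typing import Any, Optional, Tuple

    from megatron.bridge.training.config import DatasetBuildContext, DatasetProvider


    @dataclass
    class InMemoryProvider(DatasetProvider):
        """Toy provider returning pre-built in-memory datasets (illustration only)."""

        seq_length: int = 512

        def build_datasets(
            self, context: DatasetBuildContext
        ) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]:
            # Use the sample counts supplied by the framework; skip the test split.
            train_ds = list(range(context.train_samples))
            valid_ds = list(range(context.valid_samples))
            return train_ds, valid_ds, None


    # Constructed by hand here purely for illustration.
    ctx = DatasetBuildContext(train_samples=1_000, valid_samples=100, test_samples=0, tokenizer=None)
    train_ds, valid_ds, test_ds = InMemoryProvider(seq_length=512).build_datasets(ctx)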

class bridge.training.config.GPTDatasetConfig#

Bases: megatron.core.datasets.gpt_dataset.GPTDatasetConfig, bridge.training.config.DataloaderConfig

Megatron Core GPTDatasetConfig with deferred post-init.

This class inherits from MCore’s GPTDatasetConfig and DataloaderConfig but defers the execution of post_init() until finalize() is explicitly called. This allows for field modifications after construction but before computed fields are calculated.

skip_getting_attention_mask_from_dataset: bool#

True

If set, the dataset passes a None attention mask and the attention mask is auto-generated by the attention backend.

__post_init__() None#

Skip MCore post_init during initial construction.

The original post_init logic is deferred until finalize() is called.

finalize() None#

Execute the deferred MCore post-init logic and Bridge-specific checks.

This method calls the original Megatron Core GPTDatasetConfig.post_init() and then performs Bridge-specific validation.

class bridge.training.config.MockGPTDatasetConfig#

Bases: bridge.training.config.GPTDatasetConfig

Modifies GPTDatasetConfig to enforce necessary options for creating a mock dataset.

blend: None#

‘field(…)’

blend_per_split: None#

‘field(…)’

class bridge.training.config.FinetuningDatasetConfig#

Bases: bridge.training.config.DataloaderConfig

Configuration specific to finetuning datasets, inheriting from DataloaderConfig.

dataset_root: Optional[Union[str, pathlib.Path]]#

None

seq_length: int#

None

seed: int#

1234

memmap_workers: int#

1

max_train_samples: Optional[int]#

None

packed_sequence_specs: Optional[megatron.bridge.data.datasets.packed_sequence.PackedSequenceSpecs]#

None

dataset_kwargs: Optional[dict[str, Any]]#

None

do_validation: bool#

True

do_test: bool#

True

class bridge.training.config.SchedulerConfig#

Configuration settings for the learning rate scheduler and weight decay.

lr_decay_style: Literal[constant, linear, cosine, inverse-square-root, WSD]#

‘linear’

Learning rate decay function.

lr_wsd_decay_style: Literal[exponential, linear, cosine]#

‘exponential’

Decay style for the annealing phase of WSD

lr_decay_iters: Optional[int]#

None

Number of iterations to decay the learning rate over. If None, defaults to train.train_iters.

lr_decay_samples: Optional[int]#

None

Number of samples to decay the learning rate over. If None, defaults to train.train_samples.

lr_wsd_decay_iters: Optional[int]#

None

Number of iterations for the annealing phase of the WSD schedule.

lr_wsd_decay_samples: Optional[int]#

None

Number of samples for the annealing phase of the WSD schedule.

lr_warmup_fraction: Optional[float]#

None

Fraction of the learning rate decay iterations/samples to use for warmup (as a float).

lr_warmup_iters: int#

0

Number of iterations over which to linearly warm up the learning rate.

lr_warmup_samples: int#

0

Number of samples over which to linearly warm up the learning rate.

lr_warmup_init: float#

0.0

Initial value for learning rate warmup. The scheduler starts warmup from this value.

override_opt_param_scheduler: bool#

False

Reset the scheduler values (learning rate, warmup iterations, minimum learning rate, maximum number of iterations, and decay style) from the input arguments and ignore values from the checkpoint. Note that all of the above values will be reset.

use_checkpoint_opt_param_scheduler: bool#

False

Use the checkpoint to set the scheduler values (learning rate, warmup iterations, minimum learning rate, maximum number of iterations, and decay style) and ignore the input arguments.

start_weight_decay: Optional[float]#

None

Initial weight decay coefficient for L2 regularization.

end_weight_decay: Optional[float]#

None

End of run weight decay coefficient for L2 regularization.

weight_decay_incr_style: Literal[constant, linear, cosine]#

‘constant’

Weight decay increment function.

lr_warmup_steps: Optional[int]#

‘field(…)’

lr_decay_steps: Optional[int]#

‘field(…)’

wd_incr_steps: Optional[int]#

‘field(…)’

wsd_decay_steps: Optional[int]#

‘field(…)’

finalize() None#

Post-initialization checks for scheduler config.
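A hedged illustration of an iteration-based cosine schedule; the values are placeholders and only fields documented above are used.

    from megatron.bridge.training.config import SchedulerConfig

    scheduler = SchedulerConfig(
        lr_decay_style="cosine",          # decay function
        lr_warmup_iters=500,              # linear warmup over the first 500 iterations
        lr_decay_iters=100_000,           # decay horizon; None falls back to train.train_iters
        weight_decay_incr_style="constant",
    )
    # The derived *_steps fields (lr_warmup_steps, lr_decay_steps, ...) are filled in
    # later, when the surrounding ConfigContainer is validated.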

class bridge.training.config.TrainingConfig#

Configuration settings related to the training loop and validation.

micro_batch_size: Optional[int]#

None

Batch size per model instance (local batch size). Global batch size is local batch size times data parallel size times number of micro batches.

global_batch_size: Optional[int]#

None

Training batch size. If set, it should be a multiple of micro-batch-size times data-parallel-size. If this value is None, then use micro-batch-size * data-parallel-size as the global batch size. This choice will result in 1 for number of micro-batches.

rampup_batch_size: Optional[list[int]]#

None

Batch size ramp-up, specified as [&lt;start batch size&gt;, &lt;batch size increment&gt;, &lt;ramp-up samples&gt;]. For example, rampup_batch_size = [16, 8, 300000] with global_batch_size = 1024 starts at a global batch size of 16 and increases it linearly to 1024 over (1024 - 16) / 8 = 126 intervals, using approximately 300000 / 126 = 2380 samples per interval.
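The interval arithmetic in the description above can be reproduced with a few lines of plain Python; the numbers mirror the documented example and nothing here touches the config classes.

    # rampup_batch_size = [16, 8, 300000] with a target global batch size of 1024.
    start_bs, bs_increment, rampup_samples = 16, 8, 300_000
    global_batch_size = 1024

    num_intervals = (global_batch_size - start_bs) // bs_increment   # 126 intervals
    samples_per_interval = rampup_samples // num_intervals           # ~2380 samples per interval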

decrease_batch_size_if_needed: bool#

False

If set, decrease batch size if microbatch_size * dp_size does not divide batch_size. Useful for KSO (Keep Soldiering On) to continue making progress if number of healthy GPUs (and corresponding dp_size) does not support current batch_size. Old batch_size will be restored if training is re-started with dp_size that divides batch_size // microbatch_size.

empty_unused_memory_level: Literal[0, 1, 2]#

0

Call torch.cuda.empty_cache() each iteration (training and eval), to reduce fragmentation. 0=off, 1=moderate, 2=aggressive.

check_weight_hash_across_dp_replicas_interval: Optional[int]#

None

Interval at which to check that weight hashes are the same across DP replicas. If not specified, weight hashes are not checked.

train_sync_interval: Optional[int]#

None

Training CPU-GPU synchronization interval, to ensure that CPU is not running too far ahead of GPU.

train_iters: Optional[int]#

None

Total number of iterations to train over all training runs. Note that either train_iters or train_samples should be provided.

train_samples: Optional[int]#

None

Total number of samples to train over all training runs. Note that either train_iters or train_samples should be provided.

exit_interval: Optional[int]#

None

Exit the program after the iteration is divisible by this value.

exit_duration_in_mins: Optional[int]#

None

Exit the program after this many minutes.

exit_signal_handler: bool#

False

Dynamically save a checkpoint and shut down training if SIGTERM is received.

exit_signal: int#

None

Signal for the signal handler to detect.

exit_signal_handler_for_dataloader: bool#

False

Use signal handler for dataloader workers

manual_gc: bool#

False

Disable the threshold-based default garbage collector and trigger the garbage collection manually. Manual garbage collection helps to align the timing of the collection across ranks which mitigates the impact of CPU-associated jitters. When the manual gc is enabled, garbage collection is performed only at the start and the end of the validation routine by default.

manual_gc_interval: int#

0

Training step interval to trigger manual garbage collection. When the value is set to 0, garbage collection is not triggered between training steps.

manual_gc_eval: bool#

True

When using manual garbage collection, trigger garbage collection at the start and the end of each evaluation run.

eval_iters: int#

100

Number of iterations to run for evaluation (validation/test).

eval_interval: Optional[int]#

1000

Interval between running evaluation on validation set.

skip_train: bool#

False

If set, bypass the training loop, optionally do evaluation for validation/test, and exit.

finalize() None#

Validate training mode specification and calculate train_iters from train_samples if needed.

class bridge.training.config.CheckpointConfig#

Configuration settings for model checkpointing (saving and loading).

save: Optional[str]#

None

Output directory to save checkpoints to.

save_interval: Optional[int]#

None

Number of iterations between persistent checkpoint saves.

most_recent_k: Optional[int]#

None

Number of most recent checkpoints to keep.

save_optim: bool#

True

Save the current optimizer state in checkpoints.

save_rng: bool#

True

Save the current RNG state in checkpoints.

load: Optional[str]#

None

Directory containing a model checkpoint.

load_optim: bool#

True

Load the optimizer state when loading a checkpoint.

load_main_params_from_ckpt: bool#

False

Load main parameters from checkpoint. When loading a model from a checkpoint without loading the optimizer, the model parameters are updated but for fp16 optimizer with main parameters, the main parameters need to also be updated.

load_rng: bool#

True

Load the RNG state when loading a checkpoint.

non_persistent_save_interval: Optional[int]#

None

Number of iterations between non-persistent saves.

non_persistent_ckpt_type: Optional[typing.Literal[global, local, in_memory, None]]#

None

Type of non-persistent model checkpoints. “global” - Saved as a standard checkpoint (e.g., on Lustre) with old checkpoints being removed. “local” - [TBD] Each rank saves a portion of the checkpoint locally (e.g., on SSD/ramdisk). “in_memory” - [TBD] A special kind of local checkpoint that avoids serialization. None - No non-persistent checkpointing (default option).

non_persistent_global_ckpt_dir: Optional[str]#

None

Directory containing global non-persistent model checkpoints.

non_persistent_local_ckpt_dir: Optional[str]#

None

Directory containing local non-persistent model checkpoints.

non_persistent_local_ckpt_algo: Literal[fully_parallel, atomic]#

‘fully_parallel’

Algorithm for local non-persistent checkpointing.

finetune: bool#

False

Load model for finetuning. Do not load optimizer or rng state from checkpoint and set iteration to 0. Assumed when loading a release checkpoint.

pretrained_checkpoint: Optional[str]#

None

Directory containing a pretrained model checkpoint for finetuning.

ckpt_step: Optional[int]#

None

Checkpoint step to load model from.

use_checkpoint_args: bool#

False

Override any command line arguments with arguments from the checkpoint.

exit_on_missing_checkpoint: bool#

False

If ‘load’ is set, but checkpoint is not found (e.g., path typo), then exit instead of random initialization.

ckpt_format: Literal[torch_dist, zarr, fsdp_dtensor]#

‘torch_dist’

Checkpoint format to use.

ckpt_convert_format: Optional[Literal[torch, torch_dist, zarr]]#

None

Checkpoint format for conversion.

ckpt_convert_save: Optional[str]#

None

Save directory for converted checkpoint.

fully_parallel_save: bool#

True

Apply full save parallelization across DP for distributed checkpoints. Disabling this may reduce the number of files in the checkpoint (depending on the checkpoint format) but makes the DistributedOptimizer checkpoint non-reshardable.

async_save: bool#

False

Apply async checkpointing save. Currently works only with torch_dist distributed checkpoint format.

use_persistent_ckpt_worker: bool#

True

Use a persistent background worker for async checkpoint saves. When enabled, creates a dedicated worker thread/process for handling async saves. When disabled, uses temporal workers that are created and destroyed for each save operation.

fully_parallel_load: bool#

False

Apply full load parallelization across DP for distributed checkpoints.

ckpt_assume_constant_structure: bool#

False

Assume the checkpoint structure is constant across saves to enable optimizations.

strict_fsdp_dtensor_load: bool#

False

Whether to enforce strict loading for FSDP DTensor checkpoints. When False, allows partial loading.

dist_ckpt_strictness: Literal[assume_ok_unexpected, log_unexpected, log_all, raise_unexpected, raise_all, return_unexpected, return_all, ignore_all]#

‘assume_ok_unexpected’

Determines how key mismatches are handled during checkpoint load. See the StrictHandling docs for the meaning of each flag. NOTE: This flag controls only the distributed checkpoint load from storage, not loading the state dict into the model.

replication: bool#

False

If set, replication of local checkpoints is enabled. Needs to be enabled on all ranks.

replication_jump: Optional[int]#

None

Specifies J, the spacing between ranks storing replicas of a given rank's data. Replicas for rank n may be on ranks n+J, n+2J, …, or n-J, n-2J, etc. This flag has an effect only if replication is enabled, and it must be consistent across all ranks.

replication_factor: int#

2

Number of machines storing the replica of a given rank’s data.

finalize() None#

Post-initialization checks for checkpoint config.
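A hedged sketch of a common checkpointing setup using the fields above; the paths and intervals are placeholders.

    from megatron.bridge.training.config import CheckpointConfig

    checkpoint = CheckpointConfig(
        save="/results/checkpoints",   # output directory for persistent saves
        load="/results/checkpoints",   # resume from the same directory when present
        save_interval=1000,            # iterations between persistent saves
        most_recent_k=3,               # keep only the three latest checkpoints
        ckpt_format="torch_dist",      # distributed checkpoint format
        async_save=True,               # async saves are supported only with torch_dist
    )
    checkpoint.finalize()              # post-initialization checks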

class bridge.training.config.LoggerConfig#

Configuration settings for logging, including TensorBoard and WandB.

log_interval: int#

100

Report loss and timing interval.

log_params_norm: bool#

False

If set, calculate and log parameters norm.

log_throughput: bool#

False

If set, calculate and log throughput per GPU.

log_progress: bool#

False

If set, log progress (in terms of number of processed tokens and number of floating-point operations) to progress.txt file in checkpoint directory.

timing_log_level: Literal[0, 1, 2]#

0

Granularity level to measure and report timing. 0: report only iteration time and ensure timing does not introduce extra overhead. 1: report timing for operations that are executed a very limited number of times (basically once) during each iteration, such as gradient all-reduce. 2: report timing for operations that might be executed numerous times during each iteration. Note that setting the level to 1 or 2 may increase iteration time.

timing_log_option: Literal[max, minmax, all]#

‘minmax’

Options for logging timing. max: report the max timing across all ranks; minmax: report the min and max timings across all ranks; all: report timings of all ranks.

tensorboard_dir: Optional[str]#

None

Write TensorBoard logs to this directory.

tensorboard_log_interval: int#

1

Report to tensorboard interval.

tensorboard_queue_size: int#

1000

Size of the tensorboard queue for pending events and summaries before one of the ‘add’ calls forces a flush to disk.

log_timers_to_tensorboard: bool#

False

If set, write timers to tensorboard.

log_loss_scale_to_tensorboard: bool#

True

If set, write the loss scale to tensorboard.

log_validation_ppl_to_tensorboard: bool#

False

If set, write validation perplexity to tensorboard.

log_memory_to_tensorboard: bool#

False

Enable memory logging to tensorboard.

log_world_size_to_tensorboard: bool#

False

Enable world size logging to tensorboard.

wandb_project: Optional[str]#

None

The wandb project name. If None (the default), wandb logging is disabled.

wandb_exp_name: Optional[str]#

None

The wandb experiment name.

wandb_save_dir: Optional[str]#

None

Path to save the wandb results locally.

wandb_entity: Optional[str]#

None

The wandb entity name.

logging_level: int#

None

Set the default logging level.

filter_warnings: bool#

True

Filter out warning messages

modules_to_filter: Optional[list[str]]#

None

List of modules to filter out from the logs

set_level_for_all_loggers: bool#

False

Set the logging level for all loggers. If False, only level for NeMo loggers will be set.

log_energy: bool#

False

If set, log energy consumption (in Joules).

class bridge.training.config.ProfilingConfig#

Configuration settings for profiling the training process.

use_nsys_profiler: bool#

False

Enable nsys profiling. When using this option, nsys options should be specified in commandline. An example nsys commandline is nsys profile -s none -t nvtx,cuda -o <path/to/output_file> --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop.

profile_step_start: int#

10

Global step to start profiling.

profile_step_end: int#

12

Global step to stop profiling.

use_pytorch_profiler: bool#

False

Use the built-in pytorch profiler. Useful if you wish to view profiles in tensorboard.

profile_ranks: list[int]#

‘field(…)’

Global ranks to profile.

record_memory_history: bool#

False

Record memory history in last rank.

memory_snapshot_path: str#

‘snapshot.pickle’

Specifies where to dump the memory history pickle.

record_shapes: bool#

False

Record shapes of tensors.

finalize() None#

Validate profiling configuration.
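A hedged sketch that profiles global rank 0 between steps 10 and 12 with the built-in PyTorch profiler; only fields documented above are used.

    from megatron.bridge.training.config import ProfilingConfig

    profiling = ProfilingConfig(
        use_pytorch_profiler=True,   # traces can be viewed in TensorBoard
        profile_step_start=10,
        profile_step_end=12,
        profile_ranks=[0],           # only global rank 0 records a trace
        record_shapes=True,
    )
    profiling.finalize()             # validate the profiling configuration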

class bridge.training.config.FaultToleranceConfig#

Configuration settings related to fault tolerance mechanisms (NVIDIA internal use).

enable_ft_package: bool#

False

If set, Fault Tolerance package is enabled. Note: This feature is for Nvidia internal use only.

calc_ft_timeouts: bool#

False

If set, FT package will try to automatically compute the timeouts. Note: This feature is for Nvidia internal use only.

simulate_fault: bool#

False

Sets a simulated fault for fault tolerance. NOTE: This is for fault tolerance testing only.

simulated_fault_type: Literal[rank_hung, rank_killed, random]#

‘random’

How the simulated fault should behave. ‘random’ will randomly choose one of the other two options.

simulated_fault_rank: Optional[int]#

None

Rank on which simulated fault should occur.

simulated_fault_base_delay: int#

0

Base delay before simulated fault thread is started. A small random delay is added to this.

class bridge.training.config.StragglerDetectionConfig#

Configuration settings for detecting and logging GPU stragglers.

log_straggler: bool#

False

If set, tracks and logs straggler per GPU.

enable_straggler_on_startup: bool#

True

If set, StragglerDetector is enabled on startup.

straggler_ctrlr_port: int#

65535

Port number to toggle StragglerDetector on/off at runtime

straggler_minmax_count: int#

1

Number of ranks to report with high/low estimated throughput

disable_straggler_on_startup: bool#

False

If set, StragglerDetector is disabled on startup.

class bridge.training.config.NVRxStragglerDetectionConfig#

Configuration settings for NVIDIA Resiliency Extension straggler detection.

enabled: bool#

False

Enable NVRx straggler detection.

report_time_interval: float#

300.0

Interval [seconds] of the straggler check.

calc_relative_gpu_perf: bool#

True

Calculate relative GPU performance scores.

calc_individual_gpu_perf: bool#

True

Calculate individual GPU performance scores.

num_gpu_perf_scores_to_print: int#

5

How many best and worst perf scores to print (0 - does not print periodically, but only if stragglers are detected).

gpu_relative_perf_threshold: float#

0.7

Threshold for relative GPU performance scores.

gpu_individual_perf_threshold: float#

0.7

Threshold for individual GPU performance scores.

stop_if_detected: bool#

False

Set to True to terminate the workload if stragglers are detected.

enable_logging: bool#

True

Set to True to log GPU performance scores.

profiling_interval: int#

1

Profiling interval passed to straggler.Detector.initialize.

logger_name: str#

‘megatron.bridge.NVRxStragglerDetection’

Logger name for straggler detection messages.

finalize() None#

Validate NVRx straggler detection configuration.

class bridge.training.config.InProcessRestartConfig#

Configuration settings for NVIDIA Resiliency Extension in-process restart functionality.

enabled: bool#

False

Enable in-process restart mechanism from nvidia-resiliency-ext.

max_iterations: Optional[int]#

None

Maximum number of in-process restart iterations.

monitor_thread_interval: float#

1.0

Monitoring interval (in seconds) for the monitoring thread.

monitor_process_interval: float#

1.0

Monitoring interval (in seconds) for the monitoring process.

progress_watchdog_interval: float#

1.0

Interval (in seconds) for automatic progress watchdog timestamp updates.

heartbeat_interval: float#

30.0

Monitoring interval (in seconds) for detecting unresponsive ranks.

soft_timeout: float#

60.0

Soft progress timeout (in seconds).

hard_timeout: float#

90.0

Hard progress timeout (in seconds).

heartbeat_timeout: float#

60.0

Timeout (in seconds) for a missing rank detection heartbeat.

barrier_timeout: float#

120.0

Timeout (in seconds) for internal distributed barrier.

completion_timeout: float#

120.0

Timeout (in seconds) for barrier on completion on all ranks.

last_call_wait: float#

1.0

Time interval (in seconds) for other ranks to report concurrent terminal failures.

termination_grace_time: float#

1.0

Interval (in seconds) between SIGTERM and SIGKILL issued on hard timeout.

granularity: Literal[node, rank]#

‘node’

Granularity for in-process restart.

active_world_size: Optional[int]#

None

The number of ranks initially executing the workload. The remaining ranks from the allocation are set aside as warm reserve. If None, defaults to WORLD_SIZE environment variable.

empty_cuda_cache: bool#

True

Empty CUDA cache during restart finalization.

max_rank_faults: Optional[int]#

None

Maximum number of rank faults allowed before terminating the job.

monitor_process_logdir: Optional[str]#

None

Directory for monitor process log files. If None, monitor process logging is disabled.

class bridge.training.config.ConfigContainer#

Bases: megatron.bridge.training.utils.config_utils._ConfigContainerBase

Top-level container holding all configuration objects.

rng: bridge.training.config.RNGConfig#

‘field(…)’

rerun_state_machine: bridge.training.config.RerunStateMachineConfig#

‘field(…)’

train: bridge.training.config.TrainingConfig#

None

model: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider | megatron.bridge.models.mamba.mamba_provider.MambaModelProvider#

None

optimizer: bridge.training.config.OptimizerConfig#

None

ddp: bridge.training.config.DistributedDataParallelConfig#

‘field(…)’

scheduler: bridge.training.config.SchedulerConfig#

None

dataset: bridge.training.config.GPTDatasetConfig | bridge.training.config.FinetuningDatasetConfig | bridge.training.config.DatasetProvider#

None

logger: bridge.training.config.LoggerConfig#

None

tokenizer: megatron.bridge.training.tokenizers.config.TokenizerConfig#

None

checkpoint: bridge.training.config.CheckpointConfig#

None

dist: bridge.training.config.DistributedInitConfig#

‘field(…)’

ft: Optional[bridge.training.config.FaultToleranceConfig]#

None

straggler: Optional[bridge.training.config.StragglerDetectionConfig]#

None

nvrx_straggler: Optional[bridge.training.config.NVRxStragglerDetectionConfig]#

None

profiling: Optional[bridge.training.config.ProfilingConfig]#

None

peft: Optional[megatron.bridge.peft.base.PEFT]#

None

comm_overlap: Optional[megatron.bridge.training.comm_overlap.CommOverlapConfig]#

None

mixed_precision: Optional[Union[megatron.bridge.training.mixed_precision.MixedPrecisionConfig, str]]#

None

inprocess_restart: Optional[bridge.training.config.InProcessRestartConfig]#

None

get_data_parallel_size(world_size: int) int#

Calculate the data parallel size based on the model configuration.

set_data_parallel_size() None#

Calculate and set data_parallel_size for this config and comm_overlap config.

This method calculates the data parallel size needed by setup methods, without triggering full validation or finalization of Megatron Core configs.

_sync_and_validate_external_cuda_graph() None#

Sync the necessary configs for external CUDA Graphs and validate them.

validate() None#

Performs validation checks on the combined configuration.

Calculates dependent values like data_parallel_size and scheduler steps. Ensures compatibility between different configuration settings.

_validate_training_scheduler_compatibility() None#

Cross-validation between training and scheduler configs.

_calculate_scheduler_steps() None#

Calculate scheduler steps for both iteration-based and sample-based training.

bridge.training.config.runtime_config_update(
cfg: bridge.training.config.ConfigContainer,
) None#

Apply runtime configuration updates prior to initialization.

This function handles all configuration modifications that need to happen after initial config creation but before final validation and model setup.

Steps:

  1. Resolve mixed precision configuration from string if needed

  2. Apply mixed precision settings to model, optimizer, and DDP configs

  3. Calculate data parallel size (needed for comm overlap)

  4. Apply communication overlap configuration

  5. Validate configuration after all modifications

Parameters:

cfg – Configuration container to update
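Assuming cfg is an already-populated ConfigContainer (construction not shown), a hedged usage sketch is simply:

    from megatron.bridge.training.config import ConfigContainer, runtime_config_update

    def prepare(cfg: ConfigContainer) -> None:
        # Resolves mixed precision, applies it to the model/optimizer/DDP configs,
        # computes data_parallel_size, applies comm-overlap settings, then validates.
        runtime_config_update(cfg)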

bridge.training.config._validate_and_sync_distributed_optimizer_settings(
config: bridge.training.config.ConfigContainer,
) None#

Validate and synchronize distributed optimizer settings between DDP and optimizer configs.

This function ensures that distributed optimizer settings are consistent across DDP and optimizer configurations. If either setting is enabled, both will be enabled to maintain consistency.

Parameters:

config – The configuration container to validate and potentially modify.