bridge.training.config#
Module Contents#
Classes#
- RNGConfig: Configuration settings for random number generation.
- DistributedInitConfig: Configuration settings for distributed training initialization.
- RerunStateMachineConfig: Configuration for the rerun state machine used for result validation or stats.
- DataloaderConfig: Base configuration for data loading.
- GPTDatasetConfig: Configuration specific to GPT datasets, inheriting from MCore and base DataloaderConfig.
- MockGPTDatasetConfig: Modifies GPTDatasetConfig to enforce necessary options for creating a mock dataset.
- FinetuningDatasetConfig: Configuration specific to finetuning datasets, inheriting from DataloaderConfig.
- SchedulerConfig: Configuration settings for the learning rate scheduler and weight decay.
- TrainingConfig: Configuration settings related to the training loop and validation.
- CheckpointConfig: Configuration settings for model checkpointing (saving and loading).
- LoggerConfig: Configuration settings for logging, including TensorBoard and WandB.
- ProfilingConfig: Configuration settings for profiling the training process.
- FaultToleranceConfig: Configuration settings related to fault tolerance mechanisms (NVIDIA internal use).
- StragglerDetectionConfig: Configuration settings for detecting and logging GPU stragglers.
- NVRxStragglerDetectionConfig: Configuration settings for NVIDIA Resiliency Extension straggler detection.
- ConfigContainer: Top-level container holding all configuration objects.
API#
- class bridge.training.config.RNGConfig#
Configuration settings for random number generation.
- seed: int#
1234
Random seed used for python, numpy, pytorch, and cuda.
- te_rng_tracker: bool#
False
Use the Transformer Engine version of the random number generator. Required for CUDA graphs support.
- inference_rng_tracker: bool#
False
Use a random number generator configured for inference.
- data_parallel_random_init: bool#
False
Enable random initialization of params across data parallel ranks.
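As a minimal illustration, a sketch constructing RNGConfig for a reproducible run. The import path is assumed from the megatron.bridge namespace referenced elsewhere on this page; field names and defaults are as listed above.

```python
# Sketch: reproducible RNG settings (import path assumed from the
# megatron.bridge namespace referenced elsewhere on this page).
from megatron.bridge.training.config import RNGConfig

rng_cfg = RNGConfig(
    seed=1234,                        # shared seed for python, numpy, pytorch, and cuda
    te_rng_tracker=False,             # set True only when CUDA graphs support is needed
    data_parallel_random_init=False,  # keep params identical across data-parallel ranks at init
)
```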
- class bridge.training.config.DistributedInitConfig#
Configuration settings for distributed training initialization.
- distributed_backend: Literal[nccl, gloo]#
‘nccl’
Which backend to use for distributed training.
- distributed_timeout_minutes: int#
10
Timeout minutes for torch.distributed.
- align_grad_reduce: bool#
True
If set, all PP stages launch gradient reduces simultaneously. Otherwise, each PP stage launches them independently as needed.
- local_rank: int#
‘field(…)’
Local rank passed from the distributed launcher.
- lazy_init: bool#
False
If set to True, initialize_megatron() skips DDP initialization and returns a function to complete it instead. Also turns on the --use-cpu-initialization flag. This is for an external DDP manager.
- use_torch_fsdp2: bool#
False
Use the torch FSDP2 implementation. FSDP2 is not currently working with Pipeline Parallel. It is still not in a stable release stage, and may therefore contain bugs or other potential issues.
- nccl_communicator_config_path: Optional[str]#
None
Path to the yaml file with NCCL communicator configurations. The number of min/max thread groups and thread group cluster size of each communicator can be configured by setting min_ctas, max_ctas, and cga_cluster_size.
- use_tp_pp_dp_mapping: bool#
False
If set, the initialization order of distributed ranks is changed from tp-dp-pp to tp-pp-dp. Make sure EP and CP aren’t used with this option enabled.
- use_gloo_process_groups: bool#
True
If set, create Gloo process groups for communications.
- use_sharp: bool#
False
Enable the use of SHARP for the collective communications of data-parallel process groups. When True, a barrier is run within each data-parallel process group, which specifies the SHARP application target groups.
- high_priority_stream_groups: Optional[list[str]]#
None
Specify which communicator groups should use high priority streams during creation. Assigning high priority to communication streams ensures that communication kernels are scheduled with higher priority, minimizing the exposed communication when it is overlapped with other computation kernels.
- external_gpu_device_mapping: bool#
False
If True, indicates that GPU device mapping has been externally managed (e.g., via CUDA_VISIBLE_DEVICES environment variable). When True, uses device 0 instead of local rank for CUDA device selection. This is useful when launching with external process managers that handle GPU visibility.
- enable_megatron_core_experimental: bool#
False
Enable experimental features for Megatron Core.
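A short sketch of a DistributedInitConfig for an NCCL-backed run with a longer timeout. The import path is an assumption, as above; the values shown are illustrative.

```python
# Sketch: NCCL-backed distributed init with a longer timeout (assumed import path).
from megatron.bridge.training.config import DistributedInitConfig

dist_cfg = DistributedInitConfig(
    distributed_backend="nccl",      # 'nccl' for GPU collectives, 'gloo' for CPU
    distributed_timeout_minutes=30,  # raise from the default 10 for large or busy clusters
    use_gloo_process_groups=True,    # keep Gloo groups for CPU-side communication
)
```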
- class bridge.training.config.RerunStateMachineConfig#
Configuration for the rerun state machine used for result validation or stats.
- error_injection_rate: int#
0
Rate at which to inject unexpected results, e.g. 1000 means once every 1000 result validations.
- error_injection_type: Literal[correct_result, transient_error, persistent_error]#
‘transient_error’
Type of error to inject.
- rerun_mode: Literal[disabled, validate_results, report_stats]#
‘disabled’
Use the re-run engine to validate results or to emit stats on the variability of computations due to non-deterministic algorithms.
- class bridge.training.config.DataloaderConfig#
Base configuration for data loading.
- dataloader_type: Optional[Literal[single, cyclic, external]]#
None
Single-pass vs. multiple-pass data loader.
- num_workers: int#
8
Dataloader number of workers.
- data_sharding: bool#
True
Enable data sharding.
- pin_memory: bool#
True
Whether to pin memory during data loading for faster GPU training.
- persistent_workers: bool#
False
Whether to keep data loading workers persistent across epochs.
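A minimal sketch of a DataloaderConfig for a single-pass loader with persistent pinned-memory workers. The import path is assumed as before.

```python
# Sketch: single-pass dataloader with persistent pinned-memory workers (assumed import path).
from megatron.bridge.training.config import DataloaderConfig

loader_cfg = DataloaderConfig(
    dataloader_type="single",  # single pass over the data
    num_workers=8,
    pin_memory=True,           # faster host-to-GPU copies
    persistent_workers=True,   # keep workers alive across epochs
)
```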
- class bridge.training.config.GPTDatasetConfig#
Bases:
megatron.core.datasets.gpt_dataset.GPTDatasetConfig
,bridge.training.config.DataloaderConfig
Configuration specific to GPT datasets, inheriting from MCore and base DataloaderConfig.
- __post_init__() None #
Post-initialization checks for GPT dataset config.
- class bridge.training.config.MockGPTDatasetConfig#
Bases:
bridge.training.config.GPTDatasetConfig
Modifies GPTDatasetConfig to enforce necessary options for creating a mock dataset.
- blend: None#
‘field(…)’
- blend_per_split: None#
‘field(…)’
- class bridge.training.config.FinetuningDatasetConfig#
Bases:
bridge.training.config.DataloaderConfig
Configuration specific to finetuning datasets, inheriting from DataloaderConfig.
- dataset_root: Optional[Union[str, pathlib.Path]]#
None
- seq_length: int#
None
- seed: int#
1234
- memmap_workers: int#
1
- max_train_samples: Optional[int]#
None
- packed_sequence_specs: Optional[megatron.bridge.data.datasets.packed_sequence.PackedSequenceSpecs]#
None
- dataset_kwargs: Optional[dict[str, Any]]#
None
- do_validation: bool#
True
- do_test: bool#
True
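A sketch of a FinetuningDatasetConfig rooted at a local path. The import path and the example directory are assumptions; the fields, including those inherited from DataloaderConfig, are as listed above.

```python
# Sketch: finetuning dataset config (assumed import path; dataset_root is a hypothetical path).
from pathlib import Path
from megatron.bridge.training.config import FinetuningDatasetConfig

ft_data_cfg = FinetuningDatasetConfig(
    dataset_root=Path("/data/my_finetune_set"),  # hypothetical dataset location
    seq_length=4096,
    seed=1234,
    memmap_workers=2,
    do_validation=True,
    do_test=False,
    # Fields inherited from DataloaderConfig are also available here:
    num_workers=4,
    pin_memory=True,
)
```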
- class bridge.training.config.SchedulerConfig#
Configuration settings for the learning rate scheduler and weight decay.
- lr_decay_style: Literal[constant, linear, cosine, inverse-square-root, WSD]#
‘linear’
Learning rate decay function.
- lr_wsd_decay_style: Literal[exponential, linear, cosine]#
‘exponential’
Decay style for the annealing phase of WSD.
- lr_decay_iters: Optional[int]#
None
Number of iterations to decay the learning rate over. If None, defaults to --train-iters.
- lr_wsd_decay_iters: Optional[int]#
None
Number of iterations for the annealing phase in the WSD schedule.
- lr_warmup_fraction: Optional[float]#
None
Fraction of lr-warmup-(iters/samples) to use for warmup (as a float).
- lr_warmup_iters: int#
0
Number of iterations to linearly warm up the learning rate over.
- lr_warmup_init: float#
0.0
Initial value for learning rate warmup. The scheduler starts warmup from this value.
- override_opt_param_scheduler: bool#
False
Reset the values of the scheduler (learning rate, warmup iterations, minimum learning rate, maximum number of iterations, and decay style) from input arguments and ignore values from checkpoints. Note that all the above values will be reset.
- use_checkpoint_opt_param_scheduler: bool#
False
Use the checkpoint to set the values of the scheduler (learning rate, warmup iterations, minimum learning rate, maximum number of iterations, and decay style) from the checkpoint and ignore input arguments.
- start_weight_decay: Optional[float]#
None
Initial weight decay coefficient for L2 regularization.
- end_weight_decay: Optional[float]#
None
End of run weight decay coefficient for L2 regularization.
- weight_decay_incr_style: Literal[constant, linear, cosine]#
‘constant’
Weight decay increment function.
- lr_warmup_steps: Optional[int]#
‘field(…)’
- lr_decay_steps: Optional[int]#
‘field(…)’
- wd_incr_steps: Optional[int]#
‘field(…)’
- wsd_decay_steps: Optional[int]#
‘field(…)’
- __post_init__()#
Post-initialization checks for scheduler config.
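A sketch of a SchedulerConfig for cosine decay with a linear warmup and constant weight decay. The import path is assumed; the values are illustrative only.

```python
# Sketch: cosine decay with linear warmup and constant weight decay (assumed import path).
from megatron.bridge.training.config import SchedulerConfig

sched_cfg = SchedulerConfig(
    lr_decay_style="cosine",
    lr_warmup_iters=2000,     # linear warmup over the first 2000 iterations
    lr_warmup_init=0.0,       # start warmup from 0
    lr_decay_iters=100_000,   # decay over 100k iterations (defaults to train_iters if None)
    start_weight_decay=0.1,
    end_weight_decay=0.1,
    weight_decay_incr_style="constant",
)
```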
- class bridge.training.config.TrainingConfig#
Configuration settings related to the training loop and validation.
- micro_batch_size: Optional[int]#
None
Batch size per model instance (local batch size). Global batch size is local batch size times data parallel size times number of micro batches.
- global_batch_size: Optional[int]#
None
Training batch size. If set, it should be a multiple of micro-batch-size times data-parallel-size. If this value is None, then use micro-batch-size * data-parallel-size as the global batch size. This choice will result in 1 for number of micro-batches.
- rampup_batch_size: Optional[list[int]]#
None
Batch size ramp-up with the following values: start batch size, batch size increment, and ramp-up samples. For example: rampup-batch-size = [16, 8, 300000] with global-batch-size 1024 will start with global batch size 16 and, over (1024 - 16) / 8 = 126 intervals, will increase the batch size linearly to 1024. In each interval approximately 300000 / 126 = 2380 samples are used.
- decrease_batch_size_if_needed: bool#
False
If set, decrease batch size if microbatch_size * dp_size does not divide batch_size. Useful for KSO (Keep Soldiering On) to continue making progress if number of healthy GPUs (and corresponding dp_size) does not support current batch_size. Old batch_size will be restored if training is re-started with dp_size that divides batch_size // microbatch_size.
- empty_unused_memory_level: Literal[0, 1, 2]#
0
Call torch.cuda.empty_cache() each iteration (training and eval), to reduce fragmentation. 0=off, 1=moderate, 2=aggressive.
- check_weight_hash_across_dp_replicas_interval: Optional[int]#
None
Interval to check that weight hashes are the same across DP replicas. If not specified, weight hashes are not checked.
- train_sync_interval: Optional[int]#
None
Training CPU-GPU synchronization interval, to ensure that CPU is not running too far ahead of GPU.
- train_iters: Optional[int]#
None
Total number of iterations to train over all training runs. Note that either train-iters or train-samples should be provided.
- exit_interval: Optional[int]#
None
Exit the program when the iteration number is divisible by this value.
- exit_duration_in_mins: Optional[int]#
None
Exit the program after this many minutes.
- exit_signal_handler: bool#
False
Dynamically save the checkpoint and shut down training if SIGTERM is received.
- exit_signal: int#
None
Signal for the signal handler to detect.
- exit_signal_handler_for_dataloader: bool#
False
Use the signal handler for dataloader workers.
- manual_gc: bool#
False
Disable the threshold-based default garbage collector and trigger the garbage collection manually. Manual garbage collection helps to align the timing of the collection across ranks which mitigates the impact of CPU-associated jitters. When the manual gc is enabled, garbage collection is performed only at the start and the end of the validation routine by default.
- manual_gc_interval: int#
0
Training step interval to trigger manual garbage collection. When the value is set to 0, garbage collection is not triggered between training steps.
- manual_gc_eval: bool#
True
When using manual garbage collection, trigger garbage collection at the start and the end of each evaluation run.
- eval_iters: int#
100
Number of iterations to run for validation/test evaluation.
- eval_interval: Optional[int]#
1000
Interval between running evaluation on validation set.
- skip_train: bool#
False
If set, bypass the training loop, optionally do evaluation for validation/test, and exit.
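To make the batch-size relationships above concrete, here is a sketch of a TrainingConfig together with the ramp-up arithmetic from the rampup_batch_size description. The import path is assumed; the numbers simply reproduce the documented example.

```python
# Sketch: batch sizes plus the documented ramp-up arithmetic (assumed import path).
from megatron.bridge.training.config import TrainingConfig

train_cfg = TrainingConfig(
    micro_batch_size=1,
    global_batch_size=1024,
    rampup_batch_size=[16, 8, 300_000],  # start batch size, increment, ramp-up samples
    train_iters=500_000,
    eval_interval=1000,
    eval_iters=100,
)

# Ramp-up arithmetic from the rampup_batch_size description above:
start, increment, rampup_samples = train_cfg.rampup_batch_size
num_intervals = (train_cfg.global_batch_size - start) // increment  # (1024 - 16) / 8 = 126
samples_per_interval = rampup_samples // num_intervals              # 300000 / 126 ≈ 2380
```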
- class bridge.training.config.CheckpointConfig#
Configuration settings for model checkpointing (saving and loading).
- save: Optional[str]#
None
Output directory to save checkpoints to.
- save_interval: Optional[int]#
None
Number of iterations between persistent checkpoint saves.
- save_optim: bool#
True
Save the current optimizer state.
- save_rng: bool#
True
Save the current RNG state.
- load: Optional[str]#
None
Directory containing a model checkpoint.
- load_optim: bool#
True
Load the optimizer state when loading a checkpoint.
- load_main_params_from_ckpt: bool#
False
Load main parameters from the checkpoint. When loading a model from a checkpoint without loading the optimizer, the model parameters are updated, but for an fp16 optimizer with main parameters, the main parameters also need to be updated.
- load_rng: bool#
True
Load the RNG state when loading a checkpoint.
- non_persistent_save_interval: Optional[int]#
None
Number of iterations between non-persistent saves.
- non_persistent_ckpt_type: Optional[typing.Literal[global, local, in_memory, None]]#
None
Type of non-persistent model checkpoints. “global” - Saved as a standard checkpoint (e.g., on Lustre) with old checkpoints being removed. “local” - [TBD] Each rank saves a portion of the checkpoint locally (e.g., on SSD/ramdisk). “in_memory” - [TBD] A special kind of local checkpoint that avoids serialization. None - No non-persistent checkpointing (default option).
- non_persistent_global_ckpt_dir: Optional[str]#
None
Directory containing global non-persistent model checkpoints.
- non_persistent_local_ckpt_dir: Optional[str]#
None
Directory containing local non-persistent model checkpoints.
- non_persistent_local_ckpt_algo: Literal[fully_parallel, atomic]#
‘fully_parallel’
Algorithm for local non-persistent checkpointing.
- finetune: bool#
False
Load model for finetuning. Do not load optimizer or rng state from checkpoint and set iteration to 0. Assumed when loading a release checkpoint.
- pretrained_checkpoint: Optional[str]#
None
Directory containing a pretrained model checkpoint for finetuning.
- ckpt_step: Optional[int]#
None
Checkpoint step to load model from.
- use_checkpoint_args: bool#
False
Override any command line arguments with arguments from the checkpoint.
- exit_on_missing_checkpoint: bool#
False
If ‘load’ is set, but checkpoint is not found (e.g., path typo), then exit instead of random initialization.
- ckpt_format: Literal[torch_dist, zarr]#
‘torch_dist’
Checkpoint format to use.
- ckpt_convert_format: Optional[Literal[torch, torch_dist, zarr]]#
None
Checkpoint format for conversion.
- ckpt_convert_save: Optional[str]#
None
Save directory for converted checkpoint.
- fully_parallel_save: bool#
True
Apply full save parallelization across DP for distributed checkpoints. Disabling this can, depending on the checkpoint format, decrease the number of files in the checkpoint and makes the DistributedOptimizer checkpoint non-reshardable.
- async_save: bool#
False
Apply async checkpointing save. Currently works only with the torch_dist distributed checkpoint format.
- use_persistent_ckpt_worker: bool#
True
Use a persistent background worker for async checkpoint saves. When enabled, creates a dedicated worker thread/process for handling async saves. When disabled, uses temporal workers that are created and destroyed for each save operation.
- fully_parallel_load: bool#
False
Apply full load parallelization across DP for distributed checkpoints.
- ckpt_assume_constant_structure: bool#
False
If the model and optimizer state dict structure is constant throughout a single training job, this allows for different checkpointing performance optimizations.
- dist_ckpt_strictness: Literal[assume_ok_unexpected, log_unexpected, log_all, raise_unexpected, raise_all, return_unexpected, return_all, ignore_all]#
‘assume_ok_unexpected’
Determine handling of key mismatch during checkpoint load. Check StrictHandling docs for flags meaning. NOTE: This flag controls only distributed checkpoint load from storage, not loading state dict into the model.
- replication: bool#
False
If set, replication of local checkpoints is enabled. Needs to be enabled on all ranks.
- replication_jump: Optional[int]#
None
Specifies J, the spacing between ranks storing replicas of a given rank’s data. Replicas for rank n may be on ranks n+J, n+2J, …, or n-J, n-2J, etc. This flag has an effect only if replication is enabled, and must be consistent across all ranks.
- replication_factor: int#
2
Number of machines storing the replica of a given rank’s data.
- __post_init__() None #
Post-initialization checks for checkpoint config.
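A sketch of a CheckpointConfig for periodic async checkpointing in the torch_dist format. The import path and the directory paths are assumptions.

```python
# Sketch: periodic async checkpointing in the torch_dist format
# (assumed import path; directories are hypothetical).
from megatron.bridge.training.config import CheckpointConfig

ckpt_cfg = CheckpointConfig(
    save="/checkpoints/run-001",      # hypothetical output directory
    save_interval=1000,               # persistent checkpoint every 1000 iterations
    ckpt_format="torch_dist",         # async_save currently requires torch_dist
    async_save=True,
    use_persistent_ckpt_worker=True,  # dedicated background worker for async saves
    load="/checkpoints/run-001",      # resume from the same directory if present
)
```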
- class bridge.training.config.LoggerConfig#
Configuration settings for logging, including TensorBoard and WandB.
- log_interval: int#
100
Interval at which to report loss and timing.
- log_params_norm: bool#
False
If set, calculate and log parameters norm.
- log_throughput: bool#
False
If set, calculate and log throughput per GPU.
- log_progress: bool#
False
If set, log progress (in terms of number of processed tokens and number of floating-point operations) to progress.txt file in checkpoint directory.
- timing_log_level: Literal[0, 1, 2]#
0
Granularity level to measure and report timing. 0: report only iteration time and make sure timing does not introduce extra overhead. 1: report timing for operations that are executed a very limited number of times (basically once) during each iteration, such as gradient all-reduce. 2: report timing for operations that might be executed numerous times during each iteration. Note that setting the level to 1 or 2 might cause an increase in iteration time.
- timing_log_option: Literal[max, minmax, all]#
‘minmax’
Options for logging timing: max: report the max timing across all ranks; minmax: report min and max timings across all ranks; all: report timings of all ranks.
- tensorboard_dir: Optional[str]#
None
Write TensorBoard logs to this directory.
- tensorboard_log_interval: int#
1
Interval at which to report to tensorboard.
- tensorboard_queue_size: int#
1000
Size of the tensorboard queue for pending events and summaries before one of the ‘add’ calls forces a flush to disk.
- log_timers_to_tensorboard: bool#
False
If set, write timers to tensorboard.
- log_loss_scale_to_tensorboard: bool#
True
Enable loss-scale logging to tensorboard.
- log_validation_ppl_to_tensorboard: bool#
False
If set, write validation perplexity to tensorboard.
- log_memory_to_tensorboard: bool#
False
Enable memory logging to tensorboard.
- log_world_size_to_tensorboard: bool#
False
Enable world size logging to tensorboard.
- wandb_project: Optional[str]#
None
The wandb project name. If not set, wandb logging is disabled.
- wandb_exp_name: Optional[str]#
None
The wandb experiment name.
- wandb_save_dir: Optional[str]#
None
Path to save the wandb results locally.
- wandb_entity: Optional[str]#
None
The wandb entity name.
- logging_level: int#
None
Set the default logging level.
- filter_warnings: bool#
True
Filter out warning messages.
- modules_to_filter: Optional[list[str]]#
None
List of modules to filter out from the logs.
- set_level_for_all_loggers: bool#
False
Set the logging level for all loggers. If False, only level for NeMo loggers will be set.
- log_energy: bool#
False
If set, log energy consumption (in Joules).
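A sketch of a LoggerConfig combining TensorBoard logging with optional WandB. The import path, directory paths, and project/experiment names are placeholders.

```python
# Sketch: TensorBoard plus optional WandB logging
# (assumed import path; paths and names are placeholders).
from megatron.bridge.training.config import LoggerConfig

logger_cfg = LoggerConfig(
    log_interval=10,
    tensorboard_dir="/logs/run-001/tb",  # hypothetical path
    log_timers_to_tensorboard=True,
    wandb_project="my-project",          # leave None to disable wandb
    wandb_exp_name="gpt-baseline",
    wandb_save_dir="/logs/run-001/wandb",
)
```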
- class bridge.training.config.ProfilingConfig#
Configuration settings for profiling the training process.
- use_nsys_profiler: bool#
False
Enable nsys profiling. When using this option, nsys options should be specified on the command line. An example nsys command line is: nsys profile -s none -t nvtx,cuda -o <path/to/output_file> --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop.
- profile_step_start: int#
10
Global step to start profiling.
- profile_step_end: int#
12
Global step to stop profiling.
- use_pytorch_profiler: bool#
False
Use the built-in pytorch profiler. Useful if you wish to view profiles in tensorboard.
- profile_ranks: list[int]#
‘field(…)’
Global ranks to profile.
- record_memory_history: bool#
False
Record memory history on the last rank.
- memory_snapshot_path: str#
‘snapshot.pickle’
Specifies where to dump the memory history pickle.
- record_shapes: bool#
False
Record shapes of tensors.
- __post_init__() None #
Validate profiling configuration.
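A sketch of a ProfilingConfig that profiles global steps 10 through 12 on rank 0, intended to be launched under the nsys command line shown in the use_nsys_profiler description. The import path is assumed.

```python
# Sketch: profile global steps 10-12 on rank 0 with nsys (assumed import path).
# Launch the job under the nsys command line from the use_nsys_profiler description.
from megatron.bridge.training.config import ProfilingConfig

prof_cfg = ProfilingConfig(
    use_nsys_profiler=True,   # or use_pytorch_profiler=True for TensorBoard traces
    profile_step_start=10,
    profile_step_end=12,
    profile_ranks=[0],
    record_shapes=False,
)
```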
- class bridge.training.config.FaultToleranceConfig#
Configuration settings related to fault tolerance mechanisms (NVIDIA internal use).
- enable_ft_package: bool#
False
If set, Fault Tolerance package is enabled. Note: This feature is for Nvidia internal use only.
- calc_ft_timeouts: bool#
False
If set, FT package will try to automatically compute the timeouts. Note: This feature is for Nvidia internal use only.
- simulate_fault: bool#
False
Sets a simulated fault for fault tolerance. NOTE: This is for fault tolerance testing only.
- simulated_fault_type: Literal[rank_hung, rank_killed, random]#
‘random’
How the simulated fault should behave. ‘random’ will randomly choose one of the other two options.
- simulated_fault_rank: Optional[int]#
None
Rank on which simulated fault should occur.
- simulated_fault_base_delay: int#
0
Base delay before simulated fault thread is started. A small random delay is added to this.
- class bridge.training.config.StragglerDetectionConfig#
Configuration settings for detecting and logging GPU stragglers.
- log_straggler: bool#
False
If set, tracks and logs stragglers per GPU.
- enable_straggler_on_startup: bool#
True
If set, StragglerDetector is enabled on startup.
- straggler_ctrlr_port: int#
65535
Port number to toggle StragglerDetector on/off at runtime.
- straggler_minmax_count: int#
1
Number of ranks to report with high/low estimated throughput.
- disable_straggler_on_startup: bool#
False
If set, StragglerDetector is disabled on startup.
- class bridge.training.config.NVRxStragglerDetectionConfig#
Configuration settings for NVIDIA Resiliency Extension straggler detection.
- enabled: bool#
False
Enable NVRx straggler detection.
- report_time_interval: float#
300.0
Interval [seconds] of the straggler check.
- calc_relative_gpu_perf: bool#
True
Calculate relative GPU performance scores.
- calc_individual_gpu_perf: bool#
True
Calculate individual GPU performance scores.
- num_gpu_perf_scores_to_print: int#
5
How many best and worst perf scores to print (0 - does not print periodically, but only if stragglers are detected).
- gpu_relative_perf_threshold: float#
0.7
Threshold for relative GPU performance scores.
- gpu_individual_perf_threshold: float#
0.7
Threshold for individual GPU performance scores.
- stop_if_detected: bool#
False
Set to True to terminate the workload if stragglers are detected.
- enable_logging: bool#
True
Set to True to log GPU performance scores.
- profiling_interval: int#
1
Profiling interval passed to straggler.Detector.initialize.
- logger_name: str#
‘megatron_hub.NVRxStragglerDetection’
Logger name for straggler detection messages.
- __post_init__() None #
Validate NVRx straggler detection configuration.
- class bridge.training.config.ConfigContainer#
Bases:
megatron.bridge.training.utils.config_utils._ConfigContainerBase
Top-level container holding all configuration objects.
- rng: bridge.training.config.RNGConfig#
‘field(…)’
- rerun_state_machine: bridge.training.config.RerunStateMachineConfig#
‘field(…)’
- train: bridge.training.config.TrainingConfig#
None
- model: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider | megatron.bridge.models.mamba.mamba_provider.MambaProvider#
None
- optimizer: megatron.core.optimizer.OptimizerConfig#
None
- ddp: megatron.core.distributed.DistributedDataParallelConfig#
‘field(…)’
- scheduler: bridge.training.config.SchedulerConfig#
None
- dataset: bridge.training.config.GPTDatasetConfig | bridge.training.config.FinetuningDatasetConfig#
None
- logger: bridge.training.config.LoggerConfig#
None
- tokenizer: megatron.bridge.training.tokenizers.config.TokenizerConfig#
None
- checkpoint: bridge.training.config.CheckpointConfig#
None
- dist: bridge.training.config.DistributedInitConfig#
‘field(…)’
- ft: Optional[bridge.training.config.FaultToleranceConfig]#
None
- straggler: Optional[bridge.training.config.StragglerDetectionConfig]#
None
- nvrx_straggler: Optional[bridge.training.config.NVRxStragglerDetectionConfig]#
None
- profiling: Optional[bridge.training.config.ProfilingConfig]#
None
- peft: Optional[megatron.bridge.peft.base.PEFT]#
None
- comm_overlap: Optional[megatron.bridge.training.comm_overlap.CommOverlapConfig]#
None
- mixed_precision: Optional[Union[megatron.bridge.training.mixed_precision.MixedPrecisionConfig, str]]#
None
- get_data_parallel_size(world_size: int) int #
Calculate the data parallel size based on the model configuration.
- validate() None #
Performs validation checks on the combined configuration.
Calculates dependent values like data_parallel_size and scheduler steps. Ensures compatibility between different configuration settings.
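Putting it together, a heavily abridged sketch of assembling a ConfigContainer and validating it. The import paths are assumed, and the model, optimizer, tokenizer, and dataset placeholders must be constructed first from their respective megatron.core and megatron.bridge modules, whose required fields are not documented on this page.

```python
# Sketch: assembling a ConfigContainer (assumed import paths; model_provider,
# optimizer_config, tokenizer_config, and dataset_config are placeholders that
# must be constructed elsewhere from megatron.core / megatron.bridge).
from megatron.bridge.training.config import (
    CheckpointConfig,
    ConfigContainer,
    LoggerConfig,
    SchedulerConfig,
    TrainingConfig,
)

cfg = ConfigContainer(
    model=model_provider,        # e.g. a GPTModelProvider instance
    optimizer=optimizer_config,  # megatron.core.optimizer.OptimizerConfig
    tokenizer=tokenizer_config,  # megatron.bridge TokenizerConfig
    dataset=dataset_config,      # GPTDatasetConfig or FinetuningDatasetConfig
    train=TrainingConfig(micro_batch_size=1, global_batch_size=256, train_iters=1000),
    scheduler=SchedulerConfig(lr_decay_style="cosine", lr_warmup_iters=100),
    checkpoint=CheckpointConfig(save="/checkpoints/run-001", save_interval=500),
    logger=LoggerConfig(log_interval=10),
)

cfg.validate()  # checks cross-config compatibility and fills in dependent values
```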