bridge.training.config#

Module Contents#

Classes#

RNGConfig

Configuration settings for random number generation.

DistributedInitConfig

Configuration settings for distributed training initialization.

RerunStateMachineConfig

Configuration for the rerun state machine used for result validation or stats.

DataloaderConfig

Base configuration for data loading.

GPTDatasetConfig

Configuration specific to GPT datasets, inheriting from MCore and base DataloaderConfig.

MockGPTDatasetConfig

Modifies GPTDatasetConfig to enforce necessary options for creating a mock dataset.

FinetuningDatasetConfig

Configuration specific to finetuning datasets, inheriting from DataloaderConfig.

SchedulerConfig

Configuration settings for the learning rate scheduler and weight decay.

TrainingConfig

Configuration settings related to the training loop and validation.

CheckpointConfig

Configuration settings for model checkpointing (saving and loading).

LoggerConfig

Configuration settings for logging, including TensorBoard and WandB.

ProfilingConfig

Configuration settings for profiling the training process.

FaultToleranceConfig

Configuration settings related to fault tolerance mechanisms (NVIDIA internal use).

StragglerDetectionConfig

Configuration settings for detecting and logging GPU stragglers.

NVRxStragglerDetectionConfig

Configuration settings for NVIDIA Resiliency Extension straggler detection.

ConfigContainer

Top-level container holding all configuration objects.

API#

class bridge.training.config.RNGConfig#

Configuration settings for random number generation.

seed: int#

1234

Random seed used for Python, NumPy, PyTorch, and CUDA.

te_rng_tracker: bool#

False

Use the Transformer Engine version of the random number generator. Required for CUDA graphs support.

inference_rng_tracker: bool#

False

Use a random number generator configured for inference.

data_parallel_random_init: bool#

False

Enable random initialization of parameters across data-parallel ranks.
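A minimal usage sketch. The import path is assumed from the fully qualified base-class names elsewhere on this page and may differ in your installation:

```python
# Illustrative sketch only; the import path is inferred from the qualified
# names on this page and may differ in your installation.
from megatron.bridge.training.config import RNGConfig

rng_cfg = RNGConfig(
    seed=1234,                       # shared seed for Python/NumPy/PyTorch/CUDA
    te_rng_tracker=False,            # set True when CUDA graphs are used
    data_parallel_random_init=True,  # randomize params across DP ranks
)
```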

class bridge.training.config.DistributedInitConfig#

Configuration settings for distributed training initialization.

distributed_backend: Literal[nccl, gloo]#

‘nccl’

Which backend to use for distributed training.

distributed_timeout_minutes: int#

10

Timeout minutes for torch.distributed.

align_grad_reduce: bool#

True

If True, all PP stages launch gradient reduces simultaneously; otherwise, each PP stage launches its gradient reduces independently as needed.

local_rank: int#

‘field(…)’

Local rank passed from the distributed launcher.

lazy_init: bool#

False

If set to True, initialize_megatron() skips DDP initialization and returns a function to complete it instead. Also turns on the --use-cpu-initialization flag. This is for external DDP managers.

use_torch_fsdp2: bool#

False

Use the torch FSDP2 implementation. FSDP2 is not currently working with Pipeline Parallel. It is still not in a stable release stage, and may therefore contain bugs or other potential issues.

nccl_communicator_config_path: Optional[str]#

None

Path to the yaml file with NCCL communicator configurations. The number of min/max thread groups and thread group cluster size of each communicator can be configured by setting min_ctas, max_ctas, and cga_cluster_size.

use_tp_pp_dp_mapping: bool#

False

If set, distributed ranks initialize order is changed from tp-dp-pp to tp-pp-dp. Make sure EP and CP aren’t used with this option enabled.

use_gloo_process_groups: bool#

True

If set, create Gloo process groups for communications.

use_sharp: bool#

False

Enable SHARP for the collective communications of data-parallel process groups. When True, a barrier is run within each data-parallel process group, which specifies the SHARP application target groups.

high_priority_stream_groups: Optional[list[str]]#

None

Specify which communicator groups should use high priority streams during creation. Assigning high priority to communication streams ensures that communication kernels are scheduled with higher priority, minimizing the exposed communication when it is overlapped with other computation kernels.

external_gpu_device_mapping: bool#

False

If True, indicates that GPU device mapping has been externally managed (e.g., via CUDA_VISIBLE_DEVICES environment variable). When True, uses device 0 instead of local rank for CUDA device selection. This is useful when launching with external process managers that handle GPU visibility.

enable_megatron_core_experimental: bool#

False

Enable experimental features for Megatron Core.
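A hedged sketch of a typical distributed-init configuration. The NCCL YAML path is a placeholder and the import path is assumed as above:

```python
# Illustrative sketch; paths are placeholders.
from megatron.bridge.training.config import DistributedInitConfig

dist_cfg = DistributedInitConfig(
    distributed_backend="nccl",
    distributed_timeout_minutes=30,
    # Optional NCCL tuning file (min_ctas, max_ctas, cga_cluster_size per communicator).
    nccl_communicator_config_path="/path/to/nccl_comm.yaml",
)
```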

class bridge.training.config.RerunStateMachineConfig#

Configuration for the rerun state machine used for result validation or stats.

error_injection_rate: int#

0

Rate at which to inject unexpected results, e.g. 1000 means once every 1000 result validations.

error_injection_type: Literal[correct_result, transient_error, persistent_error]#

‘transient_error’

Type of error to inject.

rerun_mode: Literal[disabled, validate_results, report_stats]#

‘disabled’

Use the re-run engine to validate results or to emit stats on the variability of computations due to non-deterministic algorithms.

class bridge.training.config.DataloaderConfig#

Base configuration for data loading.

dataloader_type: Optional[Literal[single, cyclic, external]]#

None

Single-pass vs. multi-pass data loader.

num_workers: int#

8

Dataloader number of workers.

data_sharding: bool#

True

Enable data sharding.

pin_memory: bool#

True

Whether to pin memory during data loading for faster GPU training.

persistent_workers: bool#

False

Whether to keep data loading workers persistent across epochs.
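For illustration, a dataloader configuration tuned for host-side throughput (field names as documented above; import path assumed):

```python
from megatron.bridge.training.config import DataloaderConfig  # assumed import path

loader_cfg = DataloaderConfig(
    dataloader_type="single",
    num_workers=8,
    pin_memory=True,          # faster host-to-device copies
    persistent_workers=True,  # keep worker processes alive between epochs
)
```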

class bridge.training.config.GPTDatasetConfig#

Bases: megatron.core.datasets.gpt_dataset.GPTDatasetConfig, bridge.training.config.DataloaderConfig

Configuration specific to GPT datasets, inheriting from MCore and base DataloaderConfig.

__post_init__() None#

Post-initialization checks for GPT dataset config.

class bridge.training.config.MockGPTDatasetConfig#

Bases: bridge.training.config.GPTDatasetConfig

Modifies GPTDatasetConfig to enforce necessary options for creating a mock dataset.

blend: None#

‘field(…)’

blend_per_split: None#

‘field(…)’

class bridge.training.config.FinetuningDatasetConfig#

Bases: bridge.training.config.DataloaderConfig

Configuration specific to finetuning datasets, inheriting from DataloaderConfig.

dataset_root: Optional[Union[str, pathlib.Path]]#

None

seq_length: int#

None

seed: int#

1234

memmap_workers: int#

1

max_train_samples: Optional[int]#

None

packed_sequence_specs: Optional[megatron.bridge.data.datasets.packed_sequence.PackedSequenceSpecs]#

None

dataset_kwargs: Optional[dict[str, Any]]#

None

do_validation: bool#

True

do_test: bool#

True
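A hedged example of a finetuning dataset configuration; the dataset root and sequence length are placeholders, and the import path is assumed as above:

```python
from megatron.bridge.training.config import FinetuningDatasetConfig  # assumed import path

ft_data_cfg = FinetuningDatasetConfig(
    dataset_root="/data/my_finetune_set",  # placeholder path
    seq_length=4096,
    memmap_workers=2,
    do_validation=True,
    do_test=False,
)
```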

class bridge.training.config.SchedulerConfig#

Configuration settings for the learning rate scheduler and weight decay.

lr_decay_style: Literal[constant, linear, cosine, inverse-square-root, WSD]#

‘linear’

Learning rate decay function.

lr_wsd_decay_style: Literal[exponential, linear, cosine]#

‘exponential’

Decay style for the annealing phase of the WSD schedule.

lr_decay_iters: Optional[int]#

None

Number of iterations to decay the learning rate over. If None, defaults to train_iters.

lr_wsd_decay_iters: Optional[int]#

None

Number of iterations for the annealing phase of the WSD schedule.

lr_warmup_fraction: Optional[float]#

None

Fraction of lr-warmup-(iters/samples) to use for warmup (as a float).

lr_warmup_iters: int#

0

Number of iterations to linearly warm up the learning rate over.

lr_warmup_init: float#

0.0

Initial value for learning rate warmup. The scheduler starts warmup from this value.

override_opt_param_scheduler: bool#

False

Reset the values of the scheduler (learning rate, warmup iterations, minimum learning rate, maximum number of iterations, and decay style) from input arguments and ignore values from checkpoints. Note that all of the above values will be reset.

use_checkpoint_opt_param_scheduler: bool#

False

Use the checkpoint to set the values of the scheduler (learning rate, warmup iterations, minimum learning rate, maximum number of iterations, and decay style) and ignore input arguments.

start_weight_decay: Optional[float]#

None

Initial weight decay coefficient for L2 regularization.

end_weight_decay: Optional[float]#

None

End of run weight decay coefficient for L2 regularization.

weight_decay_incr_style: Literal[constant, linear, cosine]#

‘constant’

Weight decay increment function.

lr_warmup_steps: Optional[int]#

‘field(…)’

lr_decay_steps: Optional[int]#

‘field(…)’

wd_incr_steps: Optional[int]#

‘field(…)’

wsd_decay_steps: Optional[int]#

‘field(…)’

__post_init__()#

Post-initialization checks for scheduler config.
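For illustration, a cosine-decay schedule with a short linear warmup (values are arbitrary; import path assumed as above):

```python
from megatron.bridge.training.config import SchedulerConfig  # assumed import path

sched_cfg = SchedulerConfig(
    lr_decay_style="cosine",
    lr_decay_iters=100_000,      # defaults to train_iters when left as None
    lr_warmup_iters=2_000,
    lr_warmup_init=0.0,
    start_weight_decay=0.01,
    end_weight_decay=0.01,
    weight_decay_incr_style="constant",
)
```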

class bridge.training.config.TrainingConfig#

Configuration settings related to the training loop and validation.

micro_batch_size: Optional[int]#

None

Batch size per model instance (local batch size). Global batch size is local batch size times data parallel size times number of micro batches.

global_batch_size: Optional[int]#

None

Training (global) batch size. If set, it should be a multiple of micro_batch_size times data-parallel size. If None, micro_batch_size * data-parallel-size is used as the global batch size, which results in a single micro-batch per step.

rampup_batch_size: Optional[list[int]]#

None

Batch size ramp-up, given as three values: the start global batch size, the batch size increment, and the number of ramp-up samples. For example, rampup_batch_size = [16, 8, 300000] with global_batch_size = 1024 starts at a global batch size of 16 and increases it linearly to 1024 over (1024 - 16) / 8 = 126 intervals, using approximately 300000 / 126 = 2380 samples per interval.
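The arithmetic in the example above can be reproduced directly (plain Python, independent of the library):

```python
# Worked example of the ramp-up arithmetic described above.
start, increment, ramp_samples = 16, 8, 300_000
global_batch_size = 1024

intervals = (global_batch_size - start) // increment  # (1024 - 16) / 8 = 126
samples_per_interval = ramp_samples // intervals      # 300000 // 126 = 2380

print(intervals, samples_per_interval)  # -> 126 2380
```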

decrease_batch_size_if_needed: bool#

False

If set, decrease batch size if microbatch_size * dp_size does not divide batch_size. Useful for KSO (Keep Soldiering On) to continue making progress if number of healthy GPUs (and corresponding dp_size) does not support current batch_size. Old batch_size will be restored if training is re-started with dp_size that divides batch_size // microbatch_size.

empty_unused_memory_level: Literal[0, 1, 2]#

0

Call torch.cuda.empty_cache() each iteration (training and eval), to reduce fragmentation. 0=off, 1=moderate, 2=aggressive.

check_weight_hash_across_dp_replicas_interval: Optional[int]#

None

Interval at which to check that weight hashes are the same across DP replicas. If not specified, weight hashes are not checked.

train_sync_interval: Optional[int]#

None

Training CPU-GPU synchronization interval, to ensure that CPU is not running too far ahead of GPU.

train_iters: Optional[int]#

None

Total number of iterations to train over all training runs. Note that either train-iters or train-samples should be provided.

exit_interval: Optional[int]#

None

Exit the program when the iteration count is divisible by this value.

exit_duration_in_mins: Optional[int]#

None

Exit the program after this many minutes.

exit_signal_handler: bool#

False

Dynamically save a checkpoint and shut down training if SIGTERM is received.

exit_signal: int#

None

Signal for the signal handler to detect.

exit_signal_handler_for_dataloader: bool#

False

Use the signal handler for dataloader workers.

manual_gc: bool#

False

Disable the threshold-based default garbage collector and trigger garbage collection manually. Manual garbage collection helps align the timing of collection across ranks, which mitigates the impact of CPU-associated jitter. When manual GC is enabled, garbage collection is performed only at the start and the end of the validation routine by default.

manual_gc_interval: int#

0

Training step interval to trigger manual garbage collection. When the value is set to 0, garbage collection is not triggered between training steps.

manual_gc_eval: bool#

True

When using manual garbage collection, trigger garbage collection at the start and the end of each evaluation run.

eval_iters: int#

100

Number of iterations to run for evaluation (validation/test).

eval_interval: Optional[int]#

1000

Interval between running evaluation on validation set.

skip_train: bool#

False

If set, bypass the training loop, optionally do evaluation for validation/test, and exit.

class bridge.training.config.CheckpointConfig#

Configuration settings for model checkpointing (saving and loading).

save: Optional[str]#

None

Output directory to save checkpoints to.

save_interval: Optional[int]#

None

Number of iterations between persistent checkpoint saves.

save_optim: bool#

True

Save the current optimizer state.

save_rng: bool#

True

Save the current RNG state.

load: Optional[str]#

None

Directory containing a model checkpoint.

load_optim: bool#

True

Load the optimizer state when loading a checkpoint.

load_main_params_from_ckpt: bool#

False

Load main parameters from the checkpoint. When loading a model from a checkpoint without loading the optimizer, the model parameters are updated, but for an fp16 optimizer with main parameters, the main parameters also need to be updated.

load_rng: bool#

True

Load the RNG state when loading a checkpoint.

non_persistent_save_interval: Optional[int]#

None

Number of iterations between non-persistent saves.

non_persistent_ckpt_type: Optional[typing.Literal[global, local, in_memory, None]]#

None

Type of non-persistent model checkpoints. “global” - Saved as a standard checkpoint (e.g., on Lustre) with old checkpoints being removed. “local” - [TBD] Each rank saves a portion of the checkpoint locally (e.g., on SSD/ramdisk). “in_memory” - [TBD] A special kind of local checkpoint that avoids serialization. None - No non-persistent checkpointing (default option).

non_persistent_global_ckpt_dir: Optional[str]#

None

Directory containing global non-persistent model checkpoints.

non_persistent_local_ckpt_dir: Optional[str]#

None

Directory containing local non-persistent model checkpoints.

non_persistent_local_ckpt_algo: Literal[fully_parallel, atomic]#

‘fully_parallel’

Algorithm for local non-persistent checkpointing.

finetune: bool#

False

Load model for finetuning. Do not load optimizer or rng state from checkpoint and set iteration to 0. Assumed when loading a release checkpoint.

pretrained_checkpoint: Optional[str]#

None

Directory containing a pretrained model checkpoint for finetuning.

ckpt_step: Optional[int]#

None

Checkpoint step to load model from.

use_checkpoint_args: bool#

False

Override any command-line arguments with arguments from the checkpoint.

exit_on_missing_checkpoint: bool#

False

If 'load' is set but the checkpoint is not found (e.g., a path typo), exit instead of falling back to random initialization.

ckpt_format: Literal[torch_dist, zarr]#

‘torch_dist’

Checkpoint format to use.

ckpt_convert_format: Optional[Literal[torch, torch_dist, zarr]]#

None

Checkpoint format for conversion.

ckpt_convert_save: Optional[str]#

None

Save directory for converted checkpoint.

fully_parallel_save: bool#

True

Apply full save parallelization across DP for distributed checkpoints. Depending on the checkpoint format, this might decrease the number of files in the checkpoint. Makes the DistributedOptimizer checkpoint non-reshardable.

async_save: bool#

False

Apply async checkpointing save. Currently works only with torch_dist distributed checkpoint format.

use_persistent_ckpt_worker: bool#

True

Use a persistent background worker for async checkpoint saves. When enabled, creates a dedicated worker thread/process for handling async saves. When disabled, uses temporal workers that are created and destroyed for each save operation.

fully_parallel_load: bool#

False

Apply full load parallelization across DP for distributed checkpoints.

ckpt_assume_constant_structure: bool#

False

If the model and optimizer state dict structure is constant throughout a single training job, this allows for different checkpointing performance optimizations.

dist_ckpt_strictness: Literal[assume_ok_unexpected, log_unexpected, log_all, raise_unexpected, raise_all, return_unexpected, return_all, ignore_all]#

‘assume_ok_unexpected’

Determine handling of key mismatch during checkpoint load. Check StrictHandling docs for flags meaning. NOTE: This flag controls only distributed checkpoint load from storage, not loading state dict into the model.

replication: bool#

False

If set, replication of local checkpoints is enabled. Needs to be enabled on all ranks.

replication_jump: Optional[int]#

None

Specifies J, the spacing between ranks storing replicas of a given rank's data. Replicas for rank n may be on ranks n+J, n+2J, ..., or n-J, n-2J, etc. This flag has an effect only if replication is enabled, and must be consistent across all ranks.

replication_factor: int#

2

Number of machines storing the replica of a given rank’s data.

__post_init__() None#

Post-initialization checks for checkpoint config.
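A hedged sketch of a checkpoint configuration for periodic async distributed checkpointing while finetuning from a pretrained checkpoint (paths are placeholders; import path assumed):

```python
from megatron.bridge.training.config import CheckpointConfig  # assumed import path

ckpt_cfg = CheckpointConfig(
    save="/checkpoints/run-01",            # placeholder output directory
    save_interval=1000,
    ckpt_format="torch_dist",
    async_save=True,                       # async save works only with torch_dist
    pretrained_checkpoint="/checkpoints/base-model",
    finetune=True,                         # skip optimizer/RNG restore, reset iteration to 0
)
```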

class bridge.training.config.LoggerConfig#

Configuration settings for logging, including TensorBoard and WandB.

log_interval: int#

100

Report loss and timing interval.

log_params_norm: bool#

False

If set, calculate and log parameters norm.

log_throughput: bool#

False

If set, calculate and log throughput per GPU.

log_progress: bool#

False

If set, log progress (in terms of number of processed tokens and number of floating-point operations) to progress.txt file in checkpoint directory.

timing_log_level: Literal[0, 1, 2]#

0

Granularity level to measure and report timing. 0: report only the iteration time and ensure that timing does not introduce extra overhead. 1: report timing for operations that are executed a very limited number of times (basically once) per iteration, such as gradient all-reduce. 2: report timing for operations that might be executed numerous times per iteration. Note that setting the level to 1 or 2 might increase iteration time.

timing_log_option: Literal[max, minmax, all]#

‘minmax’

Options for logging timing: 'max' reports the max timing across all ranks; 'minmax' reports the min and max timings across all ranks; 'all' reports the timings of all ranks.

tensorboard_dir: Optional[str]#

None

Write TensorBoard logs to this directory.

tensorboard_log_interval: int#

1

Report to tensorboard interval.

tensorboard_queue_size: int#

1000

Size of the tensorboard queue for pending events and summaries before one of the ‘add’ calls forces a flush to disk.

log_timers_to_tensorboard: bool#

False

If set, write timers to tensorboard.

log_loss_scale_to_tensorboard: bool#

True

If set, write the loss scale to tensorboard.

log_validation_ppl_to_tensorboard: bool#

False

If set, write validation perplexity to tensorboard.

log_memory_to_tensorboard: bool#

False

Enable memory logging to tensorboard.

log_world_size_to_tensorboard: bool#

False

Enable world size logging to tensorboard.

wandb_project: Optional[str]#

None

The wandb project name. Ignore wandb by default.

wandb_exp_name: Optional[str]#

None

The wandb experiment name.

wandb_save_dir: Optional[str]#

None

Path to save the wandb results locally.

wandb_entity: Optional[str]#

None

The wandb entity name.

logging_level: int#

None

Set the default logging level.

filter_warnings: bool#

True

Filter out warning messages.

modules_to_filter: Optional[list[str]]#

None

List of modules to filter out from the logs.

set_level_for_all_loggers: bool#

False

Set the logging level for all loggers. If False, only level for NeMo loggers will be set.

log_energy: bool#

False

If set, log energy consumption (in Joules).
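For illustration, a logger configuration that writes both TensorBoard and WandB logs (names and paths are placeholders; import path assumed):

```python
from megatron.bridge.training.config import LoggerConfig  # assumed import path

logger_cfg = LoggerConfig(
    log_interval=10,
    tensorboard_dir="/logs/tb/run-01",     # placeholder path
    log_timers_to_tensorboard=True,
    wandb_project="my-project",            # placeholder project name
    wandb_exp_name="run-01",
    wandb_save_dir="/logs/wandb",
)
```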

class bridge.training.config.ProfilingConfig#

Configuration settings for profiling the training process.

use_nsys_profiler: bool#

False

Enable nsys profiling. When using this option, nsys options should be specified on the command line. An example nsys command line: nsys profile -s none -t nvtx,cuda -o <path/to/output_file> --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop.

profile_step_start: int#

10

Global step to start profiling.

profile_step_end: int#

12

Global step to stop profiling.

use_pytorch_profiler: bool#

False

Use the built-in pytorch profiler. Useful if you wish to view profiles in tensorboard.

profile_ranks: list[int]#

‘field(…)’

Global ranks to profile.

record_memory_history: bool#

False

Record memory history in last rank.

memory_snapshot_path: str#

‘snapshot.pickle’

Specifies where to dump the memory history pickle.

record_shapes: bool#

False

Record shapes of tensors.

__post_init__() None#

Validate profiling configuration.
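A hedged sketch that profiles global steps 10 through 12 on rank 0 with nsys; the training job itself must be launched under the nsys command shown in use_nsys_profiler above (import path assumed):

```python
from megatron.bridge.training.config import ProfilingConfig  # assumed import path

prof_cfg = ProfilingConfig(
    use_nsys_profiler=True,  # the job must be launched under `nsys profile ...`
    profile_step_start=10,
    profile_step_end=12,
    profile_ranks=[0],
)
```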

class bridge.training.config.FaultToleranceConfig#

Configuration settings related to fault tolerance mechanisms (NVIDIA internal use).

enable_ft_package: bool#

False

If set, Fault Tolerance package is enabled. Note: This feature is for Nvidia internal use only.

calc_ft_timeouts: bool#

False

If set, FT package will try to automatically compute the timeouts. Note: This feature is for Nvidia internal use only.

simulate_fault: bool#

False

Sets a simulated fault for fault tolerance. NOTE: This is for fault tolerance testing only.

simulated_fault_type: Literal[rank_hung, rank_killed, random]#

‘random’

How the simulated fault should behave. ‘random’ will randomly choose one of the other two options.

simulated_fault_rank: Optional[int]#

None

Rank on which simulated fault should occur.

simulated_fault_base_delay: int#

0

Base delay before simulated fault thread is started. A small random delay is added to this.

class bridge.training.config.StragglerDetectionConfig#

Configuration settings for detecting and logging GPU stragglers.

log_straggler: bool#

False

If set, tracks and logs straggler per GPU.

enable_straggler_on_startup: bool#

True

If set, StragglerDetector is enabled on startup.

straggler_ctrlr_port: int#

65535

Port number to toggle StragglerDetector on/off at runtime.

straggler_minmax_count: int#

1

Number of ranks to report with high/low estimated throughput.

disable_straggler_on_startup: bool#

False

If set, StragglerDetector is disabled on startup.

class bridge.training.config.NVRxStragglerDetectionConfig#

Configuration settings for NVIDIA Resiliency Extension straggler detection.

enabled: bool#

False

Enable NVRx straggler detection.

report_time_interval: float#

300.0

Interval [seconds] of the straggler check.

calc_relative_gpu_perf: bool#

True

Calculate relative GPU performance scores.

calc_individual_gpu_perf: bool#

True

Calculate individual GPU performance scores.

num_gpu_perf_scores_to_print: int#

5

How many best and worst perf scores to print (0 means do not print periodically, only if stragglers are detected).

gpu_relative_perf_threshold: float#

0.7

Threshold for relative GPU performance scores.

gpu_individual_perf_threshold: float#

0.7

Threshold for individual GPU performance scores.

stop_if_detected: bool#

False

Set to True to terminate the workload if stragglers are detected.

enable_logging: bool#

True

Set to True to log GPU performance scores.

profiling_interval: int#

1

Profiling interval passed to straggler.Detector.initialize.

logger_name: str#

‘megatron_hub.NVRxStragglerDetection’

Logger name for straggler detection messages.

__post_init__() None#

Validate NVRx straggler detection configuration.
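For illustration, a configuration that checks for stragglers every five minutes and terminates the workload when one is detected (import path assumed):

```python
from megatron.bridge.training.config import NVRxStragglerDetectionConfig  # assumed import path

nvrx_cfg = NVRxStragglerDetectionConfig(
    enabled=True,
    report_time_interval=300.0,        # seconds between straggler checks
    gpu_relative_perf_threshold=0.7,
    stop_if_detected=True,
)
```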

class bridge.training.config.ConfigContainer#

Bases: megatron.bridge.training.utils.config_utils._ConfigContainerBase

Top-level container holding all configuration objects.

rng: bridge.training.config.RNGConfig#

‘field(…)’

rerun_state_machine: bridge.training.config.RerunStateMachineConfig#

‘field(…)’

train: bridge.training.config.TrainingConfig#

None

model: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider | megatron.bridge.models.mamba.mamba_provider.MambaProvider#

None

optimizer: megatron.core.optimizer.OptimizerConfig#

None

ddp: megatron.core.distributed.DistributedDataParallelConfig#

‘field(…)’

scheduler: bridge.training.config.SchedulerConfig#

None

dataset: bridge.training.config.GPTDatasetConfig | bridge.training.config.FinetuningDatasetConfig#

None

logger: bridge.training.config.LoggerConfig#

None

tokenizer: megatron.bridge.training.tokenizers.config.TokenizerConfig#

None

checkpoint: bridge.training.config.CheckpointConfig#

None

dist: bridge.training.config.DistributedInitConfig#

‘field(…)’

ft: Optional[bridge.training.config.FaultToleranceConfig]#

None

straggler: Optional[bridge.training.config.StragglerDetectionConfig]#

None

nvrx_straggler: Optional[bridge.training.config.NVRxStragglerDetectionConfig]#

None

profiling: Optional[bridge.training.config.ProfilingConfig]#

None

peft: Optional[megatron.bridge.peft.base.PEFT]#

None

comm_overlap: Optional[megatron.bridge.training.comm_overlap.CommOverlapConfig]#

None

mixed_precision: Optional[Union[megatron.bridge.training.mixed_precision.MixedPrecisionConfig, str]]#

None

get_data_parallel_size(world_size: int) int#

Calculate the data parallel size based on the model configuration.

validate() None#

Performs validation checks on the combined configuration.

Calculates dependent values like data_parallel_size and scheduler steps. Ensures compatibility between different configuration settings.
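A hedged end-to-end sketch of assembling a ConfigContainer and validating it. Which sub-configs are strictly required, and the concrete model/optimizer/dataset/tokenizer objects, depend on your setup; the placeholders below must be replaced before validate() can succeed:

```python
from megatron.bridge.training.config import (  # assumed import path
    CheckpointConfig,
    ConfigContainer,
    LoggerConfig,
    SchedulerConfig,
    TrainingConfig,
)

cfg = ConfigContainer(
    train=TrainingConfig(micro_batch_size=1, global_batch_size=256, train_iters=10_000),
    scheduler=SchedulerConfig(lr_decay_style="cosine", lr_warmup_iters=500),
    checkpoint=CheckpointConfig(save="/checkpoints/run-01", save_interval=1000),
    logger=LoggerConfig(log_interval=10),
    model=...,      # e.g. a GPTModelProvider instance (placeholder)
    optimizer=...,  # a megatron.core.optimizer.OptimizerConfig (placeholder)
    dataset=...,    # GPTDatasetConfig or FinetuningDatasetConfig (placeholder)
    tokenizer=...,  # TokenizerConfig (placeholder)
)

# Cross-checks the combined settings and derives values such as
# data_parallel_size and scheduler step counts.
cfg.validate()
```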