bridge.training.config#

Module Contents#

Classes#

DistributedDataParallelConfig

Megatron Core DistributedDataParallelConfig with deferred post-init.

OptimizerConfig

Megatron Core OptimizerConfig with deferred post-init.

DistributedInitConfig

Configuration settings for distributed training initialization.

RerunStateMachineConfig

Configuration for the rerun state machine used for result validation or stats.

DataloaderConfig

Base configuration for data loading.

DatasetBuildContext

Interface that encapsulates framework internals.

OptimizerConfigOverrideProviderContext

Context for providing config overrides.

OptimizerConfigOverrideProvider

Abstract base class for providing config overrides.

DatasetProvider

Abstract base class for custom dataset configurations.

GPTDatasetConfig

Megatron Core GPTDatasetConfig with deferred post-init.

GPTFIMDatasetConfig

Configuration object for GPT FIM datasets.

MockGPTDatasetConfig

Modifies GPTDatasetConfig to enforce necessary options for creating a mock dataset.

FinetuningDatasetConfig

Configuration specific to finetuning datasets, inheriting from DataloaderConfig.

SchedulerConfig

Configuration settings for the learning rate scheduler and weight decay.

TrainingConfig

Configuration settings related to the training loop and validation.

CheckpointConfig

Configuration settings for model checkpointing (saving and loading).

LoggerConfig

Configuration settings for logging, including TensorBoard and WandB.

ProfilingConfig

Configuration settings for profiling the training process.

TensorInspectConfig

Configuration for Nvidia-DL-Framework-Inspect integration.

FaultToleranceConfig

Configuration settings related to fault tolerance mechanisms (NVIDIA internal use).

StragglerDetectionConfig

Configuration settings for detecting and logging GPU stragglers.

NVRxStragglerDetectionConfig

Configuration settings for NVIDIA Resiliency Extension straggler detection.

InProcessRestartConfig

Configuration settings for NVIDIA Resiliency Extension in-process restart functionality.

ConfigContainer

Top-level container holding all configuration objects.

Functions#

_get_mcore_transformer_parent

Determine the correct Mcore TransformerConfig parent class for a model.

_get_non_default_values

Get values that differ from Mcore parent class defaults.

_get_key_config_values

Get key configuration values for non-Mcore configs.

runtime_config_update

Apply runtime configuration updates prior to initialization.

mimo_runtime_config_update

MIMO-equivalent of runtime_config_update.

_validate_and_sync_distributed_optimizer_settings

Validate and synchronize distributed optimizer settings between DDP and optimizer configs.

_validate_mixed_precision_consistency

Validate that mixed precision settings are consistent between model and optimizer configs.

_validate_fine_grained_activation_offloading

Validate fine-grained activation offloading configuration.

API#

class bridge.training.config.DistributedDataParallelConfig#

Bases: megatron.core.distributed.DistributedDataParallelConfig

Megatron Core DistributedDataParallelConfig with deferred post-init.

This class inherits from Megatron Core's DistributedDataParallelConfig but defers the execution of __post_init__() until finalize() is explicitly called. This allows for field modifications after construction but before computed fields are calculated.

param_name_patterns_for_fp32_local_accumulation: Tuple[str, ...]#

()

fnmatch patterns selecting parameters whose gradients should be locally accumulated in FP32. The special pattern 'all' matches every parameter. Synced from MCore c586f6d56 (#4028); field will be inherited from the base class after the next mcore bump.

__post_init__() None#

Skip MCore post_init during initial construction.

The original post_init logic is deferred until finalize() is called.

finalize() None#

Execute the deferred MCore post-init logic.

This method calls the original Megatron Core DistributedDataParallelConfig.__post_init__() to compute derived fields based on the current field values.
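The deferred post-init pattern can be sketched with plain dataclasses; the class names and fields below are stand-ins for illustration, not the real Megatron classes:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class BaseConfig:
    """Stand-in for a Megatron Core config with derived fields."""

    grad_reduce_in_fp32: bool = False
    bucket_size: Optional[int] = None

    def __post_init__(self) -> None:
        # The parent normally computes derived fields here.
        if self.bucket_size is None:
            self.bucket_size = 40_000_000


@dataclass
class DeferredConfig(BaseConfig):
    """Stand-in for the Bridge subclass with deferred post-init."""

    def __post_init__(self) -> None:
        # Skip the parent's post-init during construction.
        pass

    def finalize(self) -> None:
        # Run the deferred parent logic once fields are settled.
        BaseConfig.__post_init__(self)


cfg = DeferredConfig()
cfg.grad_reduce_in_fp32 = True  # safe: derived fields not computed yet
cfg.finalize()                  # bucket_size is now filled in
```

The same construct-modify-finalize sequence applies to the other deferred-post-init configs on this page.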

class bridge.training.config.OptimizerConfig#

Bases: megatron.core.optimizer.OptimizerConfig

Megatron Core OptimizerConfig with deferred post-init.

This class inherits from Megatron Core's OptimizerConfig but defers the execution of __post_init__() until finalize() is explicitly called. This allows for field modifications after construction but before computed fields are calculated.

__post_init__() None#

Skip MCore post_init during initial construction.

The original post_init logic is deferred until finalize() is called.

finalize() None#

Execute the deferred MCore post-init logic.

This method calls the original Megatron Core OptimizerConfig.__post_init__() to compute derived fields based on the current field values.

class bridge.training.config.DistributedInitConfig#

Bases: megatron.training.config.DistributedInitConfig

Configuration settings for distributed training initialization.

external_gpu_device_mapping: bool#

False

If True, indicates that GPU device mapping has been externally managed (e.g., via CUDA_VISIBLE_DEVICES environment variable). When True, uses device 0 instead of local rank for CUDA device selection. This is useful when launching with external process managers that handle GPU visibility.

enable_megatron_core_experimental: bool#

False

Enable experimental features for Megatron Core.

use_decentralized_pg: bool#

False

Use ProcessGroupCollection passed through functions instead of relying on mcore's global parallel state (mpu) variables. When True, parallel groups are obtained from the pg_collection object rather than the global megatron.core.parallel_state module.

property lazy_init: bool#

class bridge.training.config.RerunStateMachineConfig#

Bases: megatron.training.config.RerunStateMachineConfig

Configuration for the rerun state machine used for result validation or stats.

rerun_mode: Literal[disabled, validate_results, report_determinism_stats]#

'disabled'

Use re-run engine to validate results (default) or to emit stats on variability of computations due to non-deterministic algorithms.

spiky_loss_factor: float#

10.0

Factor for detecting spiky loss. A loss is considered spiky if it exceeds this multiple of the max observed loss over the sample window.

class bridge.training.config.DataloaderConfig#

Base configuration for data loading.

dataloader_type: Optional[Literal[single, cyclic, batch, external]]#

None

Dataloader type: 'single' for single pass, 'cyclic' for multiple passes with shuffling, 'batch' for global batch sampling (used in fine-tuning), or 'external' for custom dataloaders.

num_workers: int#

2

Dataloader number of workers.

data_sharding: bool#

True

Whether to enable data sharding.

pin_memory: bool#

True

Whether to pin memory during data loading for faster GPU training.

drop_last: bool#

True

Whether to drop the last incomplete batch.

persistent_workers: bool#

True

Whether to keep data loading workers persistent across epochs. Automatically set to False when num_workers is 0.

trust_remote_code: Optional[bool]#

None

Whether remote code execution should be trusted for a given HF path.

finalize()#

Finalize dataloader config field constraints.
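The persistent_workers constraint described above (the flag is forced off when num_workers is 0) can be illustrated with a stand-in class, not the real config:

```python
from dataclasses import dataclass


@dataclass
class LoaderSketch:
    """Stand-in illustrating the documented num_workers constraint."""

    num_workers: int = 2
    persistent_workers: bool = True

    def finalize(self) -> None:
        # Persistent workers are meaningless with zero worker processes,
        # so the flag is forced off when num_workers is 0.
        if self.num_workers == 0:
            self.persistent_workers = False


cfg = LoaderSketch(num_workers=0)
cfg.finalize()
# cfg.persistent_workers is now False
```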

class bridge.training.config.DatasetBuildContext#

Interface that encapsulates framework internals.

This context provides metadata needed to build datasets while hiding implementation details of the framework.

Attributes:
  • train_samples – Number of samples for training dataset

  • valid_samples – Number of samples for validation dataset

  • test_samples – Number of samples for test dataset

  • tokenizer – Optional tokenizer instance for text processing

  • pg_collection – Optional process group collection for distributed training

train_samples: int#

None

valid_samples: int#

None

test_samples: int#

None

tokenizer: Optional[megatron.bridge.training.tokenizers.tokenizer.MegatronTokenizer]#

None

pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection]#

None

class bridge.training.config.OptimizerConfigOverrideProviderContext#

Context for providing config overrides.

scheduler_config: SchedulerConfig#

None

optimizer_config: bridge.training.config.OptimizerConfig#

None

model: Union[megatron.core.transformer.module.MegatronModule, list[megatron.core.transformer.module.MegatronModule]]#

None

class bridge.training.config.OptimizerConfigOverrideProvider#

Abstract base class for providing config overrides.

build_config_overrides(
context: bridge.training.config.OptimizerConfigOverrideProviderContext,
) dict[megatron.core.optimizer.ParamKey, megatron.core.optimizer.ParamGroupOverride] | None#

Build config overrides for weight decay based on scheduler configuration.

This function creates parameter-specific overrides for weight decay behavior. By default, weight decay is skipped for bias parameters and 1D parameters. For Qwen3-Next models, weight decay is applied to q_layernorm and k_layernorm.

Parameters:

context – OptimizerConfigOverrideProviderContext which packages the scheduler configuration, optimizer configuration, and model.

Returns:

Dictionary of ParamKey to ParamGroupOverride for the optimizer
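The default policy described above (no weight decay for biases and 1D parameters) can be sketched with plain dicts standing in for ParamKey and ParamGroupOverride; the function and its inputs are hypothetical:

```python
def build_weight_decay_overrides(param_shapes: dict) -> dict:
    """Skip weight decay for biases and 1D params (norms, scales).

    param_shapes maps parameter name -> tensor shape; the returned dict
    stands in for {ParamKey: ParamGroupOverride}.
    """
    overrides = {}
    for name, shape in param_shapes.items():
        if name.endswith(".bias") or len(shape) == 1:
            overrides[name] = {"weight_decay": 0.0}
    return overrides


overrides = build_weight_decay_overrides(
    {
        "layers.0.linear.weight": (1024, 1024),
        "layers.0.linear.bias": (1024,),
        "layers.0.norm.weight": (1024,),
    }
)
# 1D params and biases get weight_decay 0.0; the 2D weight is untouched
```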

class bridge.training.config.DatasetProvider#

Bases: bridge.training.config.DataloaderConfig, abc.ABC

Abstract base class for custom dataset configurations.

Provides an interface for users to implement their own dataset builders while automatically inheriting all DataloaderConfig functionality.

Users must:

  1. Inherit from this class

  2. Implement the build_datasets() method

Example:

@dataclass
class S3DatasetConfig(DatasetProvider):
    bucket_name: str
    data_prefix: str
    seq_length: int

    def build_datasets(self, context: DatasetBuildContext) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]:
        # Custom implementation to load data from S3
        train_ds = load_s3_dataset(self.bucket_name, f"{self.data_prefix}/train", context.tokenizer)
        valid_ds = load_s3_dataset(self.bucket_name, f"{self.data_prefix}/valid", context.tokenizer)
        test_ds = load_s3_dataset(self.bucket_name, f"{self.data_prefix}/test", context.tokenizer)
        return train_ds, valid_ds, test_ds
abstractmethod build_datasets(
context: bridge.training.config.DatasetBuildContext,
) Tuple[Optional[Any], Optional[Any], Optional[Any]]#

Build train, validation, and test datasets.

This method is called by the framework during dataset initialization. Implementations should use the provided context to create appropriate datasets for each split.

Parameters:

context – Build context with sample counts and tokenizer

Returns:

Tuple of (train_dataset, valid_dataset, test_dataset). Any element can be None if that split shouldn't be created.

Raises:

NotImplementedError – Must be implemented by subclasses

class bridge.training.config.GPTDatasetConfig(
seq_length: int | None = None,
skip_getting_attention_mask_from_dataset: bool = True,
data_path: str | list[str] | None = None,
*args,
**kwargs,
)#

Bases: megatron.core.datasets.gpt_dataset.GPTDatasetConfig, bridge.training.config.DataloaderConfig

Megatron Core GPTDatasetConfig with deferred post-init.

This class inherits from MCore's GPTDatasetConfig and DataloaderConfig but defers the execution of __post_init__() until finalize() is explicitly called. This allows for field modifications after construction but before computed fields are calculated.

Initialization

Parameters:
  • seq_length (int | None) – the sequence length. If not provided, sequence_length must be in kwargs.

  • skip_getting_attention_mask_from_dataset (bool) – if set, the dataset passes a None attention mask, and the attention mask is autogenerated by the attention backend.

  • data_path – CLI-friendly data path(s). Converted to blend in finalize().

data_path: str | list[str] | None#

None

CLI-friendly alternative to blend. Accepts a single path string, a space-separated multi-path string, or a list of paths (with optional interleaved weights, matching Megatron-LM --data-path semantics). Converted to blend automatically during finalize().

__post_init__() None#

Skip MCore post_init during initial construction.

The original post_init logic is deferred until finalize() is called.

property seq_length#

Alias for MCore's sequence_length field.

finalize() None#

Execute the deferred MCore post-init logic and Bridge-specific checks.

This method calls the original Megatron Core GPTDatasetConfig.__post_init__() and then performs Bridge-specific validation.
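For illustration, the data_path forms described above can be parsed roughly as follows. This is an approximation of the Megatron-LM --data-path convention, not the library's actual conversion logic (which runs in finalize()):

```python
def parse_data_path(data_path):
    """Approximate the --data-path forms: a single path, a space-separated
    list of paths, or alternating weight/path pairs."""
    parts = data_path.split() if isinstance(data_path, str) else list(data_path)
    if len(parts) % 2 == 0:
        try:
            # Weighted form alternates weight, path, weight, path, ...
            weights = [float(w) for w in parts[0::2]]
            return parts[1::2], weights
        except ValueError:
            pass  # even number of plain paths, no weights
    return parts, None


print(parse_data_path("corpus_a"))
print(parse_data_path("0.7 corpus_a 0.3 corpus_b"))
```

Note this heuristic would misread an even-length list of paths whose entries all look numeric; the real implementation should be treated as authoritative.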

class bridge.training.config.GPTFIMDatasetConfig(
fim_rate: float = None,
fim_spm_rate: float = None,
fim_extra_tokens: Dict = None,
fim_split_sample: Optional[str] = None,
fim_fragment_rate: Optional[float] = None,
fim_no_prefix: Optional[str] = None,
**kwargs,
)#

Bases: bridge.training.config.GPTDatasetConfig

Configuration object for GPT FIM datasets.

Initialization

Parameters:
  • fim_rate (float) – probability of converting a training sample into FIM format.

  • fim_spm_rate (float) – probability that a FIM sample uses the SPM format rather than the PSM format.

  • fim_extra_tokens (Dict) – should consist of prefix, middle, suffix, PAD, and EOD tokens.

  • fim_split_sample (str) – string around which to split the sample for FIM.

  • fim_fragment_rate (float) – rate of FIM on each fragment when split_sample is not None.

  • fim_no_prefix (str) – do not apply FIM to fragments that start with this prefix.

class bridge.training.config.MockGPTDatasetConfig(seq_length: int, **kwargs)#

Bases: bridge.training.config.GPTDatasetConfig

Modifies GPTDatasetConfig to enforce necessary options for creating a mock dataset.

Initialization

Parameters:
  • seq_length (int | None) – the sequence length. If not provided, sequence_length must be in kwargs.

  • skip_getting_attention_mask_from_dataset (bool) – if set, the dataset passes a None attention mask, and the attention mask is autogenerated by the attention backend.

  • data_path – CLI-friendly data path(s). Converted to blend in finalize().

finalize()#

class bridge.training.config.FinetuningDatasetConfig#

Bases: bridge.training.config.DataloaderConfig

Configuration specific to finetuning datasets, inheriting from DataloaderConfig.

Note: For fine-tuning, dataloader_type defaults to 'batch', which ensures sequences within each global batch are padded to the same length.

dataloader_type: Optional[Literal[single, cyclic, batch, external]]#

'batch'

Dataloader type for fine-tuning. Defaults to 'batch' for optimal padding behavior.

dataset_root: Optional[Union[str, pathlib.Path]]#

None

seq_length: int#

None

seed: int#

1234

memmap_workers: int#

1

max_train_samples: Optional[int]#

None

packed_sequence_specs: Optional[megatron.bridge.data.datasets.packed_sequence.PackedSequenceSpecs]#

None

dataset_kwargs: Optional[dict[str, Any]]#

None

do_validation: bool#

True

do_test: bool#

True

class bridge.training.config.SchedulerConfig#

Bases: megatron.training.config.SchedulerConfig

Configuration settings for the learning rate scheduler and weight decay.

finalize() None#

Post-initialization checks for scheduler config.

class bridge.training.config.TrainingConfig#

Bases: megatron.training.config.TrainingConfig

Configuration settings related to the training loop and validation.

check_optimizer_step_success: bool#

True

Checks that optimizer.step() succeeded at each training step.

skip_sync_grad_norm_across_mp: bool#

False

Skips syncing the grad norm across the model parallel group.

eval_iters: int | None#

None

Number of iterations to run for validation/test evaluation. Deprecated in favor of ValidationConfig.

eval_interval: int | None#

None

Interval between running evaluation on validation set. Deprecated in favor of ValidationConfig.

skip_train: bool | None#

None

If set, bypass the training loop, optionally do evaluation for validation/test, and exit. Deprecated in favor of ValidationConfig.

finalize() None#

Validate training mode specification and calculate train_iters from train_samples if needed.

class bridge.training.config.CheckpointConfig#

Bases: megatron.training.config.CheckpointConfig

Configuration settings for model checkpointing (saving and loading).

pretrained_checkpoint: Optional[str]#

None

Directory containing a pretrained model checkpoint for finetuning.

This can be either:

  • A parent checkpoint directory (e.g. /checkpoints/my_model/) that contains tracker files (latest_train_state.pt) and iter_* subdirectories.

  • A specific iteration directory (e.g. /checkpoints/my_model/iter_0001000/) that directly contains the checkpoint payload (run_config.yaml, weight shards, etc.).

storage_writers_per_rank: int#

1

Number of storage writers per rank for torch_dist checkpoint format. Affects the number of checkpoint files: saving_ranks * storage_writers_per_rank.

use_persistent_ckpt_worker: bool#

True

Use a persistent background worker for async checkpoint saves. When enabled, creates a dedicated worker thread/process for handling async saves. When disabled, uses temporal workers that are created and destroyed for each save operation.

async_strategy: str#

'nvrx'

Async checkpoint strategy to use. Options: "nvrx" (default) or "mcore". The "nvrx" strategy uses nvidia_resiliency_ext for async checkpointing and falls back to "mcore" if the package is not installed.

async_write_results_mp_mode: str#

'fork'

Multiprocessing start method for the async write results queue. Options: "fork" (default), "spawn", "forkserver".

strict_fsdp_dtensor_load: bool#

False

Whether to enforce strict loading for FSDP DTensor checkpoints. When False, allows partial loading.

custom_manager_class: str | None#

None

Fully qualified class name for a custom CheckpointManager implementation.

When set, checkpoint operations will instantiate and delegate to this class instead of the default checkpoint manager. The custom class must implement the CheckpointManager protocol defined in megatron.bridge.training.checkpointing.

Example: 'mypackage.checkpoint.MyCheckpointManager'

finalize() None#

Post-initialization checks for checkpoint config.

class bridge.training.config.LoggerConfig#

Bases: megatron.training.config.LoggerConfig

Configuration settings for logging, including TensorBoard and WandB.

skip_train_metrics_log: bool#

False

Skips logging of training metrics to all logging backends, including the console.

timing_log_level: Literal[-1, 0, 1, 2]#

0

Granularity level to measure and report timing. -1: disable timing logging entirely. 0: report only iteration time; timing at this level introduces no extra overhead. 1: report timing for operations executed a very limited number of times (basically once) per iteration, such as gradient all-reduce. 2: report timing for operations that might be executed numerous times per iteration. Note that setting the level to 1 or 2 might increase iteration time.

mlflow_experiment: Optional[str]#

None

The MLFlow experiment name.

mlflow_run_name: Optional[str]#

None

The MLFlow run name.

mlflow_tracking_uri: Optional[str]#

None

Optional MLFlow tracking URI.

mlflow_tags: Optional[dict[str, str]]#

None

Optional tags to apply to the MLFlow run.

comet_project: Optional[str]#

None

The Comet ML project name. Comet logging is disabled when this is None.

comet_experiment_name: Optional[str]#

None

The Comet ML experiment name.

comet_workspace: Optional[str]#

None

The Comet ML workspace. If not set, uses the default workspace for the API key.

comet_api_key: Optional[str]#

None

The Comet ML API key. Can also be set via COMET_API_KEY environment variable.

comet_tags: Optional[list[str]]#

None

Optional list of tags to apply to the Comet ML experiment.

logging_level: int#

None

Sets the default logging level.

finalize() None#

Validate logger settings and optional MLFlow dependency.

class bridge.training.config.ProfilingConfig#

Bases: megatron.training.config.ProfilingConfig

Configuration settings for profiling the training process.

finalize() None#

Validate profiling configuration.

class bridge.training.config.TensorInspectConfig#

Configuration for Nvidia-DL-Framework-Inspect integration.

enabled: bool#

False

Enable tensor inspection and statistics collection.

features: dict[str, Any] | str | pathlib.Path | None#

None

Feature configuration as a Python dict or a YAML file path.

feature_dirs: list[str] | None#

None

Directories containing feature implementations (searched recursively).

log_dir: str | None#

None

Root directory to store inspection logs/statistics. Defaults to checkpoint save dir if unset.

init_training_step: int#

0

Initial training step for the inspector (used when resuming).

finalize() None#

Populate sensible defaults when inspection is enabled.

  • If feature_dirs is unset, default to the installed TransformerEngine debug features package path (transformer_engine.debug.features), when available.

class bridge.training.config.FaultToleranceConfig#

Configuration settings related to fault tolerance mechanisms (NVIDIA internal use).

enable_ft_package: bool#

False

If set, Fault Tolerance package is enabled. Note: This feature is for Nvidia internal use only.

calc_ft_timeouts: bool#

False

If set, FT package will try to automatically compute the timeouts. Note: This feature is for Nvidia internal use only.

simulate_fault: bool#

False

Sets a simulated fault for fault tolerance. NOTE: This is for fault tolerance testing only.

simulated_fault_type: Literal[rank_hung, rank_killed, random]#

'random'

How the simulated fault should behave. 'random' randomly chooses one of the other two options.

simulated_fault_rank: Optional[int]#

None

Rank on which simulated fault should occur.

simulated_fault_base_delay: int#

0

Base delay before simulated fault thread is started. A small random delay is added to this.

class bridge.training.config.StragglerDetectionConfig#

Bases: megatron.training.config.StragglerDetectionConfig

Configuration settings for detecting and logging GPU stragglers.

enable_straggler_on_startup: bool#

True

If set, StragglerDetector is enabled on startup.

class bridge.training.config.NVRxStragglerDetectionConfig#

Configuration settings for NVIDIA Resiliency Extension straggler detection.

enabled: bool#

False

Enable NVRx straggler detection.

report_time_interval: float#

300.0

Interval [seconds] of the straggler check.

calc_relative_gpu_perf: bool#

True

Calculate relative GPU performance scores.

calc_individual_gpu_perf: bool#

True

Calculate individual GPU performance scores.

num_gpu_perf_scores_to_print: int#

5

How many best and worst perf scores to print (0 disables periodic printing; scores are then printed only if stragglers are detected).

gpu_relative_perf_threshold: float#

0.7

Threshold for relative GPU performance scores.

gpu_individual_perf_threshold: float#

0.7

Threshold for individual GPU performance scores.

stop_if_detected: bool#

False

Set to True to terminate the workload if stragglers are detected.

enable_logging: bool#

True

Set to True to log GPU performance scores.

profiling_interval: int#

1

Profiling interval passed to straggler.Detector.initialize.

logger_name: str#

'megatron.bridge.NVRxStragglerDetection'

Logger name for straggler detection messages.

finalize() None#

Validate NVRx straggler detection configuration.

class bridge.training.config.InProcessRestartConfig#

Configuration settings for NVIDIA Resiliency Extension in-process restart functionality.

enabled: bool#

False

Enable in-process restart mechanism from nvidia-resiliency-ext.

max_iterations: Optional[int]#

None

Maximum number of in-process restart iterations.

monitor_thread_interval: float#

1.0

Monitoring interval (in seconds) for the monitoring thread.

monitor_process_interval: float#

1.0

Monitoring interval (in seconds) for the monitoring process.

progress_watchdog_interval: float#

1.0

Interval (in seconds) for automatic progress watchdog timestamp updates.

heartbeat_interval: float#

30.0

Monitoring interval (in seconds) for detecting unresponsive ranks.

soft_timeout: float#

60.0

Soft progress timeout (in seconds).

hard_timeout: float#

90.0

Hard progress timeout (in seconds).

heartbeat_timeout: float#

60.0

Timeout (in seconds) for a missing rank detection heartbeat.

barrier_timeout: float#

120.0

Timeout (in seconds) for internal distributed barrier.

completion_timeout: float#

120.0

Timeout (in seconds) for barrier on completion on all ranks.

last_call_wait: float#

1.0

Time interval (in seconds) for other ranks to report concurrent terminal failures.

termination_grace_time: float#

1.0

Interval (in seconds) between SIGTERM and SIGKILL issued on hard timeout.

granularity: Literal[node, rank]#

'node'

Granularity for in-process restart.

active_world_size: Optional[int]#

None

The number of ranks initially executing the workload. The remaining ranks from the allocation are set aside as warm reserve. If None, defaults to WORLD_SIZE environment variable.

empty_cuda_cache: bool#

True

Empty CUDA cache during restart finalization.

max_rank_faults: Optional[int]#

None

Maximum number of rank faults allowed before terminating the job.

monitor_process_logdir: Optional[str]#

None

Directory for monitor process log files. If None, monitor process logging is disabled.

class bridge.training.config.ConfigContainer#

Bases: megatron.bridge.training.utils.config_utils._ConfigContainerBase

Top-level container holding all configuration objects.

rng: megatron.training.config.RNGConfig#

'field(…)'

rerun_state_machine: bridge.training.config.RerunStateMachineConfig#

'field(…)'

train: bridge.training.config.TrainingConfig#

None

model: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider | megatron.bridge.models.mamba.mamba_provider.MambaModelProvider | megatron.bridge.models.mimo.mimo_provider.MimoModelProvider | megatron.bridge.models.gpt.gpt_builder.GPTModelConfig | megatron.bridge.models.mamba.mamba_builder.MambaModelConfig#

None

optimizer: bridge.training.config.OptimizerConfig#

None

optimizer_config_override_provider: bridge.training.config.OptimizerConfigOverrideProvider#

'field(…)'

ddp: bridge.training.config.DistributedDataParallelConfig#

'field(…)'

validation: megatron.training.config.ValidationConfig#

'field(…)'

scheduler: bridge.training.config.SchedulerConfig#

None

dataset: bridge.training.config.GPTDatasetConfig | bridge.training.config.FinetuningDatasetConfig | bridge.training.config.DatasetProvider#

None

logger: bridge.training.config.LoggerConfig#

None

tokenizer: megatron.bridge.training.tokenizers.config.TokenizerConfig#

None

checkpoint: bridge.training.config.CheckpointConfig#

None

dist: bridge.training.config.DistributedInitConfig#

'field(…)'

ft: Optional[bridge.training.config.FaultToleranceConfig]#

None

straggler: Optional[bridge.training.config.StragglerDetectionConfig]#

None

nvrx_straggler: Optional[bridge.training.config.NVRxStragglerDetectionConfig]#

None

profiling: bridge.training.config.ProfilingConfig#

'field(…)'

peft: Optional[megatron.bridge.peft.base.PEFT]#

None

comm_overlap: Optional[megatron.bridge.training.comm_overlap.CommOverlapConfig]#

None

mixed_precision: Optional[Union[megatron.bridge.training.mixed_precision.MixedPrecisionConfig, str]]#

None

tensor_inspect: bridge.training.config.TensorInspectConfig | None#

None

inprocess_restart: Optional[bridge.training.config.InProcessRestartConfig]#

None

get_data_parallel_size(world_size: int) int#

Calculate the data parallel size based on the model configuration.

set_data_parallel_size() None#

Calculate and set data_parallel_size for this config and comm_overlap config.

This method calculates the data parallel size needed by setup methods, without triggering full validation or finalization of Megatron Core configs.

_validate_and_apply_deterministic_mode() None#

Apply and validate deterministic mode requirements.

This enforces restrictions and settings that must hold when the model is configured to run in deterministic mode.

validate() None#

Performs validation checks on the combined configuration.

Calculates dependent values like data_parallel_size and scheduler steps. Ensures compatibility between different configuration settings.

_validate_cp_comm_type() None#

Validate cp_comm_type and hierarchical_context_parallel_sizes consistency.

_validate_training_scheduler_compatibility() None#

Cross-validation between training and scheduler configs.

_calculate_scheduler_steps() None#

Calculate scheduler steps for both iteration-based and sample-based training.

log_non_default_values() None#

Log configuration values that differ from Megatron Core defaults.

For configs that inherit from Megatron Core (e.g., OptimizerConfig, DDPConfig, TransformerConfig), this method logs only the values that differ from the Mcore defaults. This makes it easier to spot unintended deviations from baseline settings.

For configs that don't inherit from Mcore, key values are logged via _get_key_config_values, which excludes None values and callables.

bridge.training.config._get_mcore_transformer_parent(model_config: Any) type#

Determine the correct Mcore TransformerConfig parent class for a model.

Some models (e.g., DeepSeek v2/v3) inherit from MLATransformerConfig instead of the base TransformerConfig. This function checks the inheritance chain to find the appropriate Mcore class to use as the baseline for comparison.

Parameters:

model_config – The model configuration object.

Returns:

The appropriate Mcore TransformerConfig class (MCoreMLATransformerConfig or MCoreTransformerConfig).

bridge.training.config._get_non_default_values(
config_obj: Any,
mcore_class: type,
) Dict[str, Tuple[Any, Any]]#

Get values that differ from Mcore parent class defaults.

Parameters:
  • config_obj – The config object to compare.

  • mcore_class – The Megatron Core parent class to compare against.

Returns:

Dictionary mapping field name to (current_value, default_value) for non-default fields.
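The comparison can be sketched in a few lines of dataclass introspection; the class here is a stand-in, not a real Megatron config:

```python
import dataclasses


@dataclasses.dataclass
class McoreOptimizerSketch:
    """Stand-in for a Megatron Core parent config."""

    lr: float = 1e-4
    weight_decay: float = 0.01


def non_default_values(config_obj, mcore_class) -> dict:
    """Map field name -> (current_value, default_value) for changed fields."""
    diffs = {}
    for f in dataclasses.fields(mcore_class):
        if f.default is dataclasses.MISSING:
            continue  # skip fields without a simple default
        current = getattr(config_obj, f.name)
        if current != f.default:
            diffs[f.name] = (current, f.default)
    return diffs


cfg = McoreOptimizerSketch(lr=3e-4)
diffs = non_default_values(cfg, McoreOptimizerSketch)
# diffs == {'lr': (0.0003, 0.0001)}
```

A real implementation also has to handle fields declared with default_factory; this sketch simply skips anything without a plain default.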

bridge.training.config._get_key_config_values(
config_obj: Any,
) Dict[str, Any]#

Get key configuration values for non-Mcore configs.

Parameters:

config_obj – The config object to extract values from.

Returns:

Dictionary mapping field name to value for key fields.

bridge.training.config.runtime_config_update(
cfg: bridge.training.config.ConfigContainer,
) None#

Apply runtime configuration updates prior to initialization.

This function handles all configuration modifications that need to happen after initial config creation but before final validation and model setup.

Steps:

  1. Resolve mixed precision configuration from string if needed

  2. Apply mixed precision settings to model, optimizer, and DDP configs

  3. Calculate data parallel size (needed for comm overlap)

  4. Apply communication overlap configuration

  5. Validate configuration after all modifications

Parameters:

cfg – Configuration container to update

bridge.training.config.mimo_runtime_config_update(
cfg: bridge.training.config.ConfigContainer,
) None#

MIMO-equivalent of runtime_config_update.

The standard runtime_config_update cannot be used directly because it accesses cfg.model attributes (bf16, tensor_model_parallel_size, cuda_graph_impl, …) that do not exist on MimoModelProvider.

This function cherry-picks the safe, model-agnostic parts:

Keeps (safe for MIMO):

  • data_parallel_size = 1 (MIMO-specific hard-code)

  • Sub-config finalization (optimizer, ddp, logger, train, scheduler, checkpoint)

  • Distributed optimizer sync validation

  • Deterministic mode validation

Skips (would crash or is N/A):

  • Mixed precision resolution (per-module, not container-level)

  • Communication overlap setup (not supported for MIMO)

  • Model-level validations (FSDP, CUDA graphs, TE RNG tracker sync, etc.)

See playground/runtime_config_update_analysis.md for the full analysis.

bridge.training.config._validate_and_sync_distributed_optimizer_settings(
config: bridge.training.config.ConfigContainer,
) None#

Validate and synchronize distributed optimizer settings between DDP and optimizer configs.

This function ensures that distributed optimizer settings are consistent across DDP and optimizer configurations. If either setting is enabled, both will be enabled to maintain consistency.

Parameters:

config – The configuration container to validate and potentially modify.

bridge.training.config._validate_mixed_precision_consistency(
config: bridge.training.config.ConfigContainer,
) None#

Validate that mixed precision settings are consistent between model and optimizer configs.

Parameters:

config – The configuration container to validate.

Raises:

AssertionError – If precision settings are inconsistent in a way that would indicate ambiguous behavior.

bridge.training.config._validate_fine_grained_activation_offloading(
config: bridge.training.config.ConfigContainer,
) None#

Validate fine-grained activation offloading configuration.

This function ensures that fine-grained activation offloading is only enabled with compatible configurations (transformer_engine implementation) and that necessary environment variables are set for newer TE versions.

Parameters:

config – The configuration container to validate.

Raises:

ValueError – If fine-grained activation offloading is enabled with incompatible settings.