bridge.training.config#
Module Contents#
Classes#
- Megatron Core DistributedDataParallelConfig with deferred post-init.
- Megatron Core OptimizerConfig with deferred post-init.
- Configuration settings for distributed training initialization.
- Configuration for the rerun state machine used for result validation or stats.
- Base configuration for data loading.
- Interface that encapsulates framework internals.
- Context for providing config overrides.
- Abstract base class for providing config overrides.
- Abstract base class for custom dataset configurations.
- Megatron Core GPTDatasetConfig with deferred post-init.
- Configuration object for GPT FIM datasets.
- Modifies GPTDatasetConfig to enforce necessary options for creating a mock dataset.
- Configuration specific to finetuning datasets, inheriting from DataloaderConfig.
- Configuration settings for the learning rate scheduler and weight decay.
- Configuration settings related to the training loop and validation.
- Configuration settings for model checkpointing (saving and loading).
- Configuration settings for logging, including TensorBoard and WandB.
- Configuration settings for profiling the training process.
- Configuration for Nvidia-DL-Framework-Inspect integration.
- Configuration settings related to fault tolerance mechanisms (NVIDIA internal use).
- Configuration settings for detecting and logging GPU stragglers.
- Configuration settings for NVIDIA Resiliency Extension straggler detection.
- Configuration settings for NVIDIA Resiliency Extension in-process restart functionality.
- Top-level container holding all configuration objects.
Functions#
- Determine the correct Mcore TransformerConfig parent class for a model.
- Get values that differ from Mcore parent class defaults.
- Get key configuration values for non-Mcore configs.
- Apply runtime configuration updates prior to initialization.
- MIMO-equivalent of runtime_config_update.
- Validate and synchronize distributed optimizer settings between DDP and optimizer configs.
- Validate that mixed precision settings are consistent between model and optimizer configs.
- Validate fine-grained activation offloading configuration.
API#
- class bridge.training.config.DistributedDataParallelConfig#
Bases:
megatron.core.distributed.DistributedDataParallelConfig

Megatron Core DistributedDataParallelConfig with deferred post-init.
This class inherits from Megatron Core's DistributedDataParallelConfig but defers the execution of post_init() until finalize() is explicitly called. This allows for field modifications after construction but before computed fields are calculated.
- param_name_patterns_for_fp32_local_accumulation: Tuple[str, ...]#
()
fnmatch patterns selecting parameters whose gradients should be locally accumulated in FP32. The special pattern 'all' matches every parameter. Synced from MCore c586f6d56 (#4028); field will be inherited from the base class after the next mcore bump.
- __post_init__() None#
Skip MCore post_init during initial construction.
The original post_init logic is deferred until finalize() is called.
- finalize() None#
Execute the deferred MCore post-init logic.
This method calls the original Megatron Core DistributedDataParallelConfig.post_init() to compute derived fields based on the current field values.
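The deferral described above can be sketched in plain Python. The class and field names below (BaseConfig, grad_bucket_mb, bucket_bytes) are hypothetical stand-ins for illustration, not the real Megatron Core fields:

```python
from dataclasses import dataclass


@dataclass
class BaseConfig:
    """Stand-in for a Megatron Core config that computes derived fields."""

    grad_bucket_mb: int = 40
    bucket_bytes: int = 0  # derived in __post_init__

    def __post_init__(self) -> None:
        self.bucket_bytes = self.grad_bucket_mb * 1024 * 1024


@dataclass
class DeferredConfig(BaseConfig):
    """Defers the parent's post-init until finalize() is called."""

    def __post_init__(self) -> None:
        pass  # skip: fields may still be edited after construction

    def finalize(self) -> None:
        BaseConfig.__post_init__(self)  # now compute derived fields


cfg = DeferredConfig()
cfg.grad_bucket_mb = 120  # edit after construction, before finalize()
cfg.finalize()
```

Without the deferral, bucket_bytes would have been computed from the default at construction time and the later field edit would be silently ignored.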
- class bridge.training.config.OptimizerConfig#
Bases:
megatron.core.optimizer.OptimizerConfig

Megatron Core OptimizerConfig with deferred post-init.
This class inherits from Megatron Core's OptimizerConfig but defers the execution of post_init() until finalize() is explicitly called. This allows for field modifications after construction but before computed fields are calculated.
- __post_init__() None#
Skip MCore post_init during initial construction.
The original post_init logic is deferred until finalize() is called.
- finalize() None#
Execute the deferred MCore post-init logic.
This method calls the original Megatron Core OptimizerConfig.post_init() to compute derived fields based on the current field values.
- class bridge.training.config.DistributedInitConfig#
Bases:
megatron.training.config.DistributedInitConfig

Configuration settings for distributed training initialization.
- external_gpu_device_mapping: bool#
False
If True, indicates that GPU device mapping has been externally managed (e.g., via CUDA_VISIBLE_DEVICES environment variable). When True, uses device 0 instead of local rank for CUDA device selection. This is useful when launching with external process managers that handle GPU visibility.
- enable_megatron_core_experimental: bool#
False
Enable experimental features for Megatron Core.
- use_decentralized_pg: bool#
False
Use ProcessGroupCollection passed through functions instead of relying on mcore's global parallel state (mpu) variables. When True, parallel groups are obtained from the pg_collection object rather than the global megatron.core.parallel_state module.
- property lazy_init: bool#
- class bridge.training.config.RerunStateMachineConfig#
Bases:
megatron.training.config.RerunStateMachineConfig

Configuration for the rerun state machine used for result validation or stats.
- rerun_mode: Literal[disabled, validate_results, report_determinism_stats]#
'disabled'
Use re-run engine to validate results (default) or to emit stats on variability of computations due to non-deterministic algorithms.
- spiky_loss_factor: float#
10.0
Factor for detecting spiky loss. A loss is considered spiky if it exceeds this multiple of the max observed loss over the sample window.
- class bridge.training.config.DataloaderConfig#
Base configuration for data loading.
- dataloader_type: Optional[Literal[single, cyclic, batch, external]]#
None
Dataloader type: 'single' for single pass, 'cyclic' for multiple passes with shuffling, 'batch' for global batch sampling (used in fine-tuning), or 'external' for custom dataloaders.
- num_workers: int#
2
Dataloader number of workers.
- data_sharding: bool#
True
Enable data sharding.
- pin_memory: bool#
True
Whether to pin memory during data loading for faster GPU training.
- drop_last: bool#
True
Whether to drop the last incomplete batch.
- persistent_workers: bool#
True
Whether to keep data loading workers persistent across epochs. Automatically set to False when num_workers is 0.
- trust_remote_code: Optional[bool]#
None
Whether remote code execution should be trusted for a given HF path.
- finalize()#
Finalize dataloader config field constraints.
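One concrete instance of these constraints: persistent workers cannot be combined with num_workers == 0 (PyTorch's DataLoader rejects that combination), so finalize() can be pictured as in the sketch below. The field subset and logic are illustrative, not the exact implementation:

```python
from dataclasses import dataclass


@dataclass
class DataloaderConfigSketch:
    """Illustrative subset of the DataloaderConfig fields."""

    num_workers: int = 2
    persistent_workers: bool = True

    def finalize(self) -> None:
        # Workers cannot persist if there are no worker processes at all.
        if self.num_workers == 0:
            self.persistent_workers = False


cfg = DataloaderConfigSketch(num_workers=0)
cfg.finalize()  # persistent_workers is forced to False

cfg2 = DataloaderConfigSketch()  # num_workers=2 keeps persistent workers
cfg2.finalize()
```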
- class bridge.training.config.DatasetBuildContext#
Interface that encapsulates framework internals.
This context provides metadata needed to build datasets while hiding implementation details of the framework.
.. attribute:: train_samples
Number of samples for training dataset
.. attribute:: valid_samples
Number of samples for validation dataset
.. attribute:: test_samples
Number of samples for test dataset
.. attribute:: tokenizer
Optional tokenizer instance for text processing
.. attribute:: pg_collection
Optional process group collection for distributed training
- train_samples: int#
None
- valid_samples: int#
None
- test_samples: int#
None
- tokenizer: Optional[megatron.bridge.training.tokenizers.tokenizer.MegatronTokenizer]#
None
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection]#
None
- class bridge.training.config.OptimizerConfigOverrideProviderContext#
Context for providing config overrides.
- scheduler_config: SchedulerConfig#
None
- optimizer_config: bridge.training.config.OptimizerConfig#
None
- model: Union[megatron.core.transformer.module.MegatronModule, list[megatron.core.transformer.module.MegatronModule]]#
None
- class bridge.training.config.OptimizerConfigOverrideProvider#
Abstract base class for providing config overrides.
- build_config_overrides( ) dict[megatron.core.optimizer.ParamKey, megatron.core.optimizer.ParamGroupOverride] | None#
Build config overrides for weight decay based on scheduler configuration.
This function creates parameter-specific overrides for weight decay behavior. By default, weight decay is skipped for bias parameters and 1D parameters. For Qwen3-Next models, weight decay is applied to q_layernorm and k_layernorm.
- Parameters:
context – OptimizerConfigOverrideProviderContext which packages the scheduler configuration, optimizer configuration, and model.
- Returns:
Dictionary of ParamKey to ParamGroupOverride for the optimizer
- class bridge.training.config.DatasetProvider#
Bases:
bridge.training.config.DataloaderConfig, abc.ABC

Abstract base class for custom dataset configurations.
Provides an interface for users to implement their own dataset builders while automatically inheriting all DataloaderConfig functionality.
Users must:
Inherit from this class
Implement the build_datasets() method
.. rubric:: Example
    @dataclass
    class S3DatasetConfig(DatasetProvider):
        bucket_name: str
        data_prefix: str
        seq_length: int

        def build_datasets(self, context: DatasetBuildContext) -> Tuple[Optional[Any], Optional[Any], Optional[Any]]:
            # Custom implementation to load data from S3
            train_ds = load_s3_dataset(self.bucket_name, f"{self.data_prefix}/train", context.tokenizer)
            valid_ds = load_s3_dataset(self.bucket_name, f"{self.data_prefix}/valid", context.tokenizer)
            test_ds = load_s3_dataset(self.bucket_name, f"{self.data_prefix}/test", context.tokenizer)
            return train_ds, valid_ds, test_ds

- abstractmethod build_datasets( ) Tuple[Optional[Any], Optional[Any], Optional[Any]]#
Build train, validation, and test datasets.
This method is called by the framework during dataset initialization. Implementations should use the provided context to create appropriate datasets for each split.
- Parameters:
context – Build context with sample counts and tokenizer.
- Returns:
Tuple of (train_dataset, valid_dataset, test_dataset). Any element can be None if that split shouldn't be created.
- Raises:
NotImplementedError – Must be implemented by subclasses.
- class bridge.training.config.GPTDatasetConfig(
- seq_length: int | None = None,
- skip_getting_attention_mask_from_dataset: bool = True,
- data_path: str | list[str] | None = None,
- *args,
- **kwargs,
- )#
Bases:
megatron.core.datasets.gpt_dataset.GPTDatasetConfig, bridge.training.config.DataloaderConfig

Megatron Core GPTDatasetConfig with deferred post-init.
This class inherits from MCore's GPTDatasetConfig and DataloaderConfig but defers the execution of post_init() until finalize() is explicitly called. This allows for field modifications after construction but before computed fields are calculated.
Initialization
- Parameters:
seq_length (int | None) – the sequence length. If not provided, sequence_length must be in kwargs.
skip_getting_attention_mask_from_dataset (bool) – if set, the dataset will pass a None attention mask and the attention mask is autogenerated from the attn backend.
data_path – CLI-friendly data path(s). Converted to blend in finalize().
- data_path: str | list[str] | None#
None
CLI-friendly alternative to blend. Accepts a single path string, a space-separated multi-path string, or a list of paths (with optional interleaved weights, matching Megatron-LM --data-path semantics). Converted to blend automatically during finalize().
- __post_init__() None#
Skip MCore post_init during initial construction.
The original post_init logic is deferred until finalize() is called.
- property seq_length#
Alias for MCore's sequence_length field.
- finalize() None#
Execute the deferred MCore post-init logic and Bridge-specific checks.
This method calls the original Megatron Core GPTDatasetConfig.post_init() and then performs Bridge-specific validation.
- class bridge.training.config.GPTFIMDatasetConfig(
- fim_rate: float = None,
- fim_spm_rate: float = None,
- fim_extra_tokens: Dict = None,
- fim_split_sample: Optional[str] = None,
- fim_fragment_rate: Optional[float] = None,
- fim_no_prefix: Optional[str] = None,
- **kwargs,
- )#
Bases:
bridge.training.config.GPTDatasetConfig

Configuration object for GPT FIM datasets.
Initialization
- Parameters:
fim_rate (float) – probability to convert a training sample into a FIM format.
fim_spm_rate (float) – probability that a FIM sample uses the SPM format over the PSM format.
fim_extra_tokens (Dict) – should consist of prefix, middle, suffix, PAD, and EOD tokens.
fim_split_sample (str) – string around which to split the sample for FIM.
fim_fragment_rate (float) – rate of FIM on each fragment when split_sample is not None.
fim_no_prefix (str) – do not apply FIM to fragments that start with this prefix.
- class bridge.training.config.MockGPTDatasetConfig(seq_length: int, **kwargs)#
Bases:
bridge.training.config.GPTDatasetConfig

Modifies GPTDatasetConfig to enforce necessary options for creating a mock dataset.
Initialization
- Parameters:
seq_length (int | None) – the sequence length. If not provided, sequence_length must be in kwargs.
skip_getting_attention_mask_from_dataset (bool) – if set, the dataset will pass a None attention mask and the attention mask is autogenerated from the attn backend.
data_path – CLI-friendly data path(s). Converted to blend in finalize().
- finalize()#
- class bridge.training.config.FinetuningDatasetConfig#
Bases:
bridge.training.config.DataloaderConfig

Configuration specific to finetuning datasets, inheriting from DataloaderConfig.
Note: For fine-tuning, dataloader_type defaults to 'batch' which ensures sequences within each global batch are padded to the same length.
- dataloader_type: Optional[Literal[single, cyclic, batch, external]]#
'batch'
Dataloader type for fine-tuning. Defaults to 'batch' for optimal padding behavior.
- dataset_root: Optional[Union[str, pathlib.Path]]#
None
- seq_length: int#
None
- seed: int#
1234
- memmap_workers: int#
1
- max_train_samples: Optional[int]#
None
- packed_sequence_specs: Optional[megatron.bridge.data.datasets.packed_sequence.PackedSequenceSpecs]#
None
- dataset_kwargs: Optional[dict[str, Any]]#
None
- do_validation: bool#
True
- do_test: bool#
True
- class bridge.training.config.SchedulerConfig#
Bases:
megatron.training.config.SchedulerConfig

Configuration settings for the learning rate scheduler and weight decay.
- finalize() None#
Post-initialization checks for scheduler config.
- class bridge.training.config.TrainingConfig#
Bases:
megatron.training.config.TrainingConfig

Configuration settings related to the training loop and validation.
- check_optimizer_step_success: bool#
True
Checks that optimizer.step() succeeded at each training step.
- skip_sync_grad_norm_across_mp: bool#
False
Skips syncing the grad norm across the model parallel group.
- eval_iters: int | None#
None
Number of iterations to run evaluation for (validation/test). Deprecated in favor of ValidationConfig.
- eval_interval: int | None#
None
Interval between running evaluation on validation set. Deprecated in favor of ValidationConfig.
- skip_train: bool | None#
None
If set, bypass the training loop, optionally do evaluation for validation/test, and exit. Deprecated in favor of ValidationConfig.
- finalize() None#
Validate training mode specification and calculate train_iters from train_samples if needed.
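The samples-to-iterations conversion can be sketched as follows, assuming one iteration consumes one global batch; calc_train_iters is an illustrative helper, and the real finalize() may handle remainders differently:

```python
def calc_train_iters(train_samples: int, global_batch_size: int) -> int:
    """Each training iteration consumes one global batch of samples."""
    if global_batch_size <= 0:
        raise ValueError("global_batch_size must be positive")
    # Drop any trailing partial batch, mirroring drop_last semantics.
    return train_samples // global_batch_size


print(calc_train_iters(1_000_000, 512))  # 1953
```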
- class bridge.training.config.CheckpointConfig#
Bases:
megatron.training.config.CheckpointConfig

Configuration settings for model checkpointing (saving and loading).
- pretrained_checkpoint: Optional[str]#
None
Directory containing a pretrained model checkpoint for finetuning.
This can be either:
A parent checkpoint directory (e.g. /checkpoints/my_model/) that contains tracker files (latest_train_state.pt) and iter_* subdirectories.
A specific iteration directory (e.g. /checkpoints/my_model/iter_0001000/) that directly contains the checkpoint payload (run_config.yaml, weight shards, etc.).
- storage_writers_per_rank: int#
1
Number of storage writers per rank for torch_dist checkpoint format. Affects the number of checkpoint files: saving_ranks * storage_writers_per_rank.
- use_persistent_ckpt_worker: bool#
True
Use a persistent background worker for async checkpoint saves. When enabled, creates a dedicated worker thread/process for handling async saves. When disabled, uses temporal workers that are created and destroyed for each save operation.
- async_strategy: str#
'nvrx'
Async checkpoint strategy to use. Options: "nvrx" (default) or "mcore". The "nvrx" strategy uses nvidia_resiliency_ext for async checkpointing and falls back to "mcore" if the package is not installed.
- async_write_results_mp_mode: str#
'fork'
Multiprocessing start method for the async write results queue. Options: "fork" (default), "spawn", "forkserver".
- strict_fsdp_dtensor_load: bool#
False
Whether to enforce strict loading for FSDP DTensor checkpoints. When False, allows partial loading.
- custom_manager_class: str | None#
None
Fully qualified class name for a custom CheckpointManager implementation.
When set, checkpoint operations will instantiate and delegate to this class instead of the default checkpoint manager. The custom class must implement the CheckpointManager protocol defined in megatron.bridge.training.checkpointing.
Example: 'mypackage.checkpoint.MyCheckpointManager'
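A fully qualified name like this is typically resolved with importlib; load_class below is a minimal sketch, not necessarily how the framework resolves it, and the stdlib class stands in for a real manager:

```python
import importlib


def load_class(dotted_path: str) -> type:
    """Resolve 'package.module.ClassName' to the class object."""
    module_name, _, class_name = dotted_path.rpartition(".")
    if not module_name:
        raise ValueError(f"not a fully qualified name: {dotted_path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Stdlib example standing in for 'mypackage.checkpoint.MyCheckpointManager':
manager_cls = load_class("collections.OrderedDict")
print(manager_cls.__name__)  # OrderedDict
```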
- finalize() None#
Post-initialization checks for checkpoint config.
- class bridge.training.config.LoggerConfig#
Bases:
megatron.training.config.LoggerConfig

Configuration settings for logging, including TensorBoard and WandB.
- skip_train_metrics_log: bool#
False
Skips logging of training metrics to all logging backends and to the console as well.
- timing_log_level: Literal[-1, 0, 1, 2]#
0
Granularity level to measure and report timing. -1: disable timing logging. 0: report only iteration time and ensure timing introduces no extra overhead. 1: report timing for operations that are executed a very limited number of times (basically once) during each iteration (such as gradient all-reduce). 2: report timing for operations that might be executed numerous times during each iteration. Note that setting the level to 1 or 2 might increase iteration time.
- mlflow_experiment: Optional[str]#
None
The MLFlow experiment name.
- mlflow_run_name: Optional[str]#
None
The MLFlow run name.
- mlflow_tracking_uri: Optional[str]#
None
Optional MLFlow tracking URI.
- mlflow_tags: Optional[dict[str, str]]#
None
Optional tags to apply to the MLFlow run.
- comet_project: Optional[str]#
None
The Comet ML project name. Comet logging is disabled when this is None.
- comet_experiment_name: Optional[str]#
None
The Comet ML experiment name.
- comet_workspace: Optional[str]#
None
The Comet ML workspace. If not set, uses the default workspace for the API key.
- comet_api_key: Optional[str]#
None
The Comet ML API key. Can also be set via COMET_API_KEY environment variable.
- comet_tags: Optional[list[str]]#
None
Optional list of tags to apply to the Comet ML experiment.
- logging_level: int#
None
Set default logging level
- finalize() None#
Validate logger settings and optional MLFlow dependency.
- class bridge.training.config.ProfilingConfig#
Bases:
megatron.training.config.ProfilingConfig

Configuration settings for profiling the training process.
- finalize() None#
Validate profiling configuration.
- class bridge.training.config.TensorInspectConfig#
Configuration for Nvidia-DL-Framework-Inspect integration.
- enabled: bool#
False
Enable tensor inspection and statistics collection.
- features: dict[str, Any] | str | pathlib.Path | None#
None
Feature configuration as a Python dict or a YAML file path.
- feature_dirs: list[str] | None#
None
Directories containing feature implementations (searched recursively).
- log_dir: str | None#
None
Root directory to store inspection logs/statistics. Defaults to checkpoint save dir if unset.
- init_training_step: int#
0
Initial training step for the inspector (used when resuming).
- finalize() None#
Populate sensible defaults when inspection is enabled.
If feature_dirs is unset, default to the installed TransformerEngine debug features package path (transformer_engine.debug.features), when available.
- class bridge.training.config.FaultToleranceConfig#
Configuration settings related to fault tolerance mechanisms (NVIDIA internal use).
- enable_ft_package: bool#
False
If set, Fault Tolerance package is enabled. Note: This feature is for Nvidia internal use only.
- calc_ft_timeouts: bool#
False
If set, FT package will try to automatically compute the timeouts. Note: This feature is for Nvidia internal use only.
- simulate_fault: bool#
False
Sets a simulated fault for fault tolerance. NOTE: This is for fault tolerance testing only.
- simulated_fault_type: Literal[rank_hung, rank_killed, random]#
'random'
How the simulated fault should behave. 'random' will randomly choose one of the other two options.
- simulated_fault_rank: Optional[int]#
None
Rank on which simulated fault should occur.
- simulated_fault_base_delay: int#
0
Base delay before simulated fault thread is started. A small random delay is added to this.
- class bridge.training.config.StragglerDetectionConfig#
Bases:
megatron.training.config.StragglerDetectionConfig

Configuration settings for detecting and logging GPU stragglers.
- enable_straggler_on_startup: bool#
True
If set, StragglerDetector is enabled on startup.
- class bridge.training.config.NVRxStragglerDetectionConfig#
Configuration settings for NVIDIA Resiliency Extension straggler detection.
- enabled: bool#
False
Enable NVRx straggler detection.
- report_time_interval: float#
300.0
Interval [seconds] of the straggler check.
- calc_relative_gpu_perf: bool#
True
Calculate relative GPU performance scores.
- calc_individual_gpu_perf: bool#
True
Calculate individual GPU performance scores.
- num_gpu_perf_scores_to_print: int#
5
How many best and worst perf scores to print (0: do not print periodically, only when stragglers are detected).
- gpu_relative_perf_threshold: float#
0.7
Threshold for relative GPU performance scores.
- gpu_individual_perf_threshold: float#
0.7
Threshold for individual GPU performance scores.
- stop_if_detected: bool#
False
Set to True to terminate the workload if stragglers are detected.
- enable_logging: bool#
True
Set to True to log GPU performance scores.
- profiling_interval: int#
1
Profiling interval passed to straggler.Detector.initialize.
- logger_name: str#
'megatron.bridge.NVRxStragglerDetection'
Logger name for straggler detection messages.
- finalize() None#
Validate NVRx straggler detection configuration.
- class bridge.training.config.InProcessRestartConfig#
Configuration settings for NVIDIA Resiliency Extension in-process restart functionality.
- enabled: bool#
False
Enable in-process restart mechanism from nvidia-resiliency-ext.
- max_iterations: Optional[int]#
None
Maximum number of in-process restart iterations.
- monitor_thread_interval: float#
1.0
Monitoring interval (in seconds) for the monitoring thread.
- monitor_process_interval: float#
1.0
Monitoring interval (in seconds) for the monitoring process.
- progress_watchdog_interval: float#
1.0
Interval (in seconds) for automatic progress watchdog timestamp updates.
- heartbeat_interval: float#
30.0
Monitoring interval (in seconds) for detecting unresponsive ranks.
- soft_timeout: float#
60.0
Soft progress timeout (in seconds).
- hard_timeout: float#
90.0
Hard progress timeout (in seconds).
- heartbeat_timeout: float#
60.0
Timeout (in seconds) for a missing rank detection heartbeat.
- barrier_timeout: float#
120.0
Timeout (in seconds) for internal distributed barrier.
- completion_timeout: float#
120.0
Timeout (in seconds) for barrier on completion on all ranks.
- last_call_wait: float#
1.0
Time interval (in seconds) for other ranks to report concurrent terminal failures.
- termination_grace_time: float#
1.0
Interval (in seconds) between SIGTERM and SIGKILL issued on hard timeout.
- granularity: Literal[node, rank]#
'node'
Granularity for in-process restart.
- active_world_size: Optional[int]#
None
The number of ranks initially executing the workload. The remaining ranks from the allocation are set aside as warm reserve. If None, defaults to WORLD_SIZE environment variable.
- empty_cuda_cache: bool#
True
Empty CUDA cache during restart finalization.
- max_rank_faults: Optional[int]#
None
Maximum number of rank faults allowed before terminating the job.
- monitor_process_logdir: Optional[str]#
None
Directory for monitor process log files. If None, monitor process logging is disabled.
- class bridge.training.config.ConfigContainer#
Bases:
megatron.bridge.training.utils.config_utils._ConfigContainerBase

Top-level container holding all configuration objects.
- rng: megatron.training.config.RNGConfig#
'field(...)'
- rerun_state_machine: bridge.training.config.RerunStateMachineConfig#
'field(...)'
- train: bridge.training.config.TrainingConfig#
None
- model: megatron.bridge.models.GPTModelProvider | megatron.bridge.models.T5ModelProvider | megatron.bridge.models.mamba.mamba_provider.MambaModelProvider | megatron.bridge.models.mimo.mimo_provider.MimoModelProvider | megatron.bridge.models.gpt.gpt_builder.GPTModelConfig | megatron.bridge.models.mamba.mamba_builder.MambaModelConfig#
None
- optimizer: bridge.training.config.OptimizerConfig#
None
- optimizer_config_override_provider: bridge.training.config.OptimizerConfigOverrideProvider#
'field(...)'
- ddp: bridge.training.config.DistributedDataParallelConfig#
'field(...)'
- validation: megatron.training.config.ValidationConfig#
'field(...)'
- scheduler: bridge.training.config.SchedulerConfig#
None
- dataset: bridge.training.config.GPTDatasetConfig | bridge.training.config.FinetuningDatasetConfig | bridge.training.config.DatasetProvider#
None
- logger: bridge.training.config.LoggerConfig#
None
- tokenizer: megatron.bridge.training.tokenizers.config.TokenizerConfig#
None
- checkpoint: bridge.training.config.CheckpointConfig#
None
- dist: bridge.training.config.DistributedInitConfig#
'field(...)'
- ft: Optional[bridge.training.config.FaultToleranceConfig]#
None
- straggler: Optional[bridge.training.config.StragglerDetectionConfig]#
None
- nvrx_straggler: Optional[bridge.training.config.NVRxStragglerDetectionConfig]#
None
- profiling: bridge.training.config.ProfilingConfig#
'field(...)'
- peft: Optional[megatron.bridge.peft.base.PEFT]#
None
- comm_overlap: Optional[megatron.bridge.training.comm_overlap.CommOverlapConfig]#
None
- mixed_precision: Optional[Union[megatron.bridge.training.mixed_precision.MixedPrecisionConfig, str]]#
None
- tensor_inspect: bridge.training.config.TensorInspectConfig | None#
None
- inprocess_restart: Optional[bridge.training.config.InProcessRestartConfig]#
None
- get_data_parallel_size(world_size: int) int#
Calculate the data parallel size based on the model configuration.
- set_data_parallel_size() None#
Calculate and set data_parallel_size for this config and comm_overlap config.
This method calculates the data parallel size needed by setup methods, without triggering full validation or finalization of Megatron Core configs.
- _validate_and_apply_deterministic_mode() None#
Apply and validate deterministic mode requirements.
This enforces restrictions and settings that must hold when the model is configured to run in deterministic mode.
- validate() None#
Performs validation checks on the combined configuration.
Calculates dependent values like data_parallel_size and scheduler steps. Ensures compatibility between different configuration settings.
- _validate_cp_comm_type() None#
Validate cp_comm_type and hierarchical_context_parallel_sizes consistency.
- _validate_training_scheduler_compatibility() None#
Cross-validation between training and scheduler configs.
- _calculate_scheduler_steps() None#
Calculate scheduler steps for both iteration-based and sample-based training.
- log_non_default_values() None#
Log configuration values that differ from Megatron Core defaults.
For configs that inherit from Megatron Core (e.g., OptimizerConfig, DDPConfig, TransformerConfig), this method logs only the values that differ from the Mcore defaults. This makes it easier to spot unintended deviations from baseline settings.
For configs that donβt inherit from Mcore, key values are logged via
_get_key_config_values, which excludes None values and callables.
- bridge.training.config._get_mcore_transformer_parent(model_config: Any) type#
Determine the correct Mcore TransformerConfig parent class for a model.
Some models (e.g., DeepSeek v2/v3) inherit from MLATransformerConfig instead of the base TransformerConfig. This function checks the inheritance chain to find the appropriate Mcore class to use as the baseline for comparison.
- Parameters:
model_config – The model configuration object.
- Returns:
The appropriate Mcore TransformerConfig class (MCoreMLATransformerConfig or MCoreTransformerConfig).
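The inheritance-chain check can be sketched with stand-in classes; every class name below is a local stand-in for the corresponding Mcore/provider class, and the real function compares against Mcore's actual types:

```python
class TransformerConfig:
    """Stand-in for Mcore's base TransformerConfig."""


class MLATransformerConfig(TransformerConfig):
    """Stand-in for Mcore's MLA variant (used by e.g. DeepSeek v2/v3)."""


class DeepSeekProviderSketch(MLATransformerConfig):
    pass


class LlamaProviderSketch(TransformerConfig):
    pass


def get_transformer_parent(model_config: object) -> type:
    # Prefer the most specific known baseline in the inheritance chain.
    if isinstance(model_config, MLATransformerConfig):
        return MLATransformerConfig
    return TransformerConfig


print(get_transformer_parent(DeepSeekProviderSketch()).__name__)  # MLATransformerConfig
print(get_transformer_parent(LlamaProviderSketch()).__name__)  # TransformerConfig
```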
- bridge.training.config._get_non_default_values(
- config_obj: Any,
- mcore_class: type,
Get values that differ from Mcore parent class defaults.
- Parameters:
config_obj – The config object to compare.
mcore_class – The Megatron Core parent class to compare against.
- Returns:
Dictionary mapping field name to (current_value, default_value) for non-default fields.
- bridge.training.config._get_key_config_values(
- config_obj: Any,
Get key configuration values for non-Mcore configs.
- Parameters:
config_obj – The config object to extract values from.
- Returns:
Dictionary mapping field name to value for key fields.
- bridge.training.config.runtime_config_update( ) None#
Apply runtime configuration updates prior to initialization.
This function handles all configuration modifications that need to happen after initial config creation but before final validation and model setup.
Steps:
Resolve mixed precision configuration from string if needed
Apply mixed precision settings to model, optimizer, and DDP configs
Calculate data parallel size (needed for comm overlap)
Apply communication overlap configuration
Validate configuration after all modifications
- Parameters:
cfg – Configuration container to update.
- bridge.training.config.mimo_runtime_config_update( ) None#
MIMO-equivalent of runtime_config_update.
The standard runtime_config_update cannot be used directly because it accesses cfg.model attributes (bf16, tensor_model_parallel_size, cuda_graph_impl, ...) that do not exist on MimoModelProvider.
This function cherry-picks the safe, model-agnostic parts:
Keeps (safe for MIMO):
data_parallel_size = 1 (MIMO-specific hard-code)
Sub-config finalization (optimizer, ddp, logger, train, scheduler, checkpoint)
Distributed optimizer sync validation
Deterministic mode validation
Skips (would crash or is N/A):
Mixed precision resolution (per-module, not container-level)
Communication overlap setup (not supported for MIMO)
Model-level validations (FSDP, CUDA graphs, TE RNG tracker sync, etc.)
See playground/runtime_config_update_analysis.md for the full analysis.
- bridge.training.config._validate_and_sync_distributed_optimizer_settings( ) None#
Validate and synchronize distributed optimizer settings between DDP and optimizer configs.
This function ensures that distributed optimizer settings are consistent across DDP and optimizer configurations. If either setting is enabled, both will be enabled to maintain consistency.
- Parameters:
config – The configuration container to validate and potentially modify.
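The "if either is enabled, enable both" rule can be sketched with stand-in dataclasses; the field subset and helper name below are illustrative:

```python
from dataclasses import dataclass


@dataclass
class DDPSketch:
    """Stand-in for the DDP config's distributed optimizer flag."""

    use_distributed_optimizer: bool = False


@dataclass
class OptimizerSketch:
    """Stand-in for the optimizer config's distributed optimizer flag."""

    use_distributed_optimizer: bool = False


def sync_distributed_optimizer(ddp: DDPSketch, opt: OptimizerSketch) -> None:
    # If either side requests the distributed optimizer, enable it on both.
    enabled = ddp.use_distributed_optimizer or opt.use_distributed_optimizer
    ddp.use_distributed_optimizer = enabled
    opt.use_distributed_optimizer = enabled


ddp, opt = DDPSketch(use_distributed_optimizer=True), OptimizerSketch()
sync_distributed_optimizer(ddp, opt)
print(ddp.use_distributed_optimizer, opt.use_distributed_optimizer)  # True True
```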
- bridge.training.config._validate_mixed_precision_consistency( ) None#
Validate that mixed precision settings are consistent between model and optimizer configs.
- Parameters:
config – The configuration container to validate.
- Raises:
AssertionError – If precision settings are inconsistent in a way that would indicate ambiguous behavior.
- bridge.training.config._validate_fine_grained_activation_offloading( ) None#
Validate fine-grained activation offloading configuration.
This function ensures that fine-grained activation offloading is only enabled with compatible configurations (transformer_engine implementation) and that necessary environment variables are set for newer TE versions.
- Parameters:
config – The configuration container to validate.
- Raises:
ValueError – If fine-grained activation offloading is enabled with incompatible settings.