Megatron-LM to Megatron Bridge Guide#
Megatron Bridge is Python-first: configure models, data, and training via typed Python APIs. All configuration lives in a structured `ConfigContainer` (see the Configuration overview). Any field can be overridden from the command line using Hydra/OmegaConf syntax in the example training scripts.
Quick start#
Run your example training entrypoint and override config keys directly:
```bash
python examples/recipes/llama/pretrain_llama3_8b.py \
  train.micro_batch_size=2 \
  train.global_batch_size=128 \
  model.num_layers=32 model.hidden_size=4096 model.num_attention_heads=32 \
  model.max_position_embeddings=4096 \
  dataset.sequence_length=4096 \
  checkpoint.save=/workspace/ckpts checkpoint.save_interval=1000 \
  logger.wandb_project=my_proj logger.wandb_exp_name=exp1
```
Notes:

- Config groups are nested: `rng`, `train`, `model`, `optimizer`, `ddp`, `scheduler`, `dataset`, `logger`, `tokenizer`, `checkpoint`, `dist`, `profiling`, `peft`, `comm_overlap`, `mixed_precision`, `inprocess_restart`.
- After overrides are applied, runtime validation computes any dependent fields (e.g., data-parallel size, scheduler steps) and checks consistency.
Mapping Megatron-LM arguments to Megatron Bridge config#
Below is a concise mapping from common `megatron-lm/megatron/training/arguments.py` flags to the new dataclass fields. If a field is not listed here (e.g., highly model-specific knobs), it typically lives under `model.*`, `optimizer.*`, `dataset.*`, or `tokenizer.*` with similar names.
Model topology and parallelisms#
| megatron-lm arguments | Megatron Bridge config | Description |
|---|---|---|
|  |  | TP degree. |
|  |  | PP degree. |
|  |  | CP degree. |
|  |  | EP degree. |
|  |  | Expert TP degree. |
|  |  | Enable sequence parallelism. |
|  |  | Asymmetric PP: embeddings. |
|  |  | Asymmetric PP: loss. |
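As a sketch, parallelism degrees would be overridden under the `model` group, assuming field names that mirror the Megatron-LM flags and Megatron Core's transformer config (none of these names are confirmed by the table above, so check your recipe's model dataclass):

```bash
# Assumed field names mirroring the Megatron-LM flags; verify them against
# the model config dataclass in your Megatron Bridge version.
python examples/recipes/llama/pretrain_llama3_8b.py \
  model.tensor_model_parallel_size=2 \
  model.pipeline_model_parallel_size=4 \
  model.sequence_parallel=true
```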
Model architecture knobs#
| megatron-lm arguments | Megatron Bridge config | Description |
|---|---|---|
|  |  | Untie embeddings/output. |
|  |  | Fraction of rotary dims. |
|  |  | RoPE base. |
|  |  | RoPE interpolation factor. |
|  |  | LayerNorm/RMSNorm, etc. |
|  |  | Enable SwiGLU MLP. |
|  |  | Epsilon for norm layers. |
|  |  | Number of transformer layers. |
|  |  | Model hidden size. |
|  |  | MLP expansion size. |
|  |  | Attention heads. |
|  |  | Key/Value channels per head. |
|  |  | Set groups (enable GQA). |
|  |  | Number of query groups. |
|  |  | Enable QK LayerNorm. |
|  |  | Max model sequence length. |
|  |  | Alias used by HF conversions. |
|  |  | TP padding multiple. |
|  |  | Disable linear bias. |
|  |  | Use FlashAttention backend. |
|  |  | Weight init standard deviation. |
|  |  | Attention dropout. |
|  |  | Hidden dropout. |
MoE#
| megatron-lm arguments | Megatron Bridge config | Description |
|---|---|---|
|  |  | Experts per MoE layer. |
|  |  | Expert MLP hidden size. |
|  |  | Router load-balancing type (e.g., aux_loss or seq_aux_loss). |
|  |  | Top-k experts per token. |
|  |  | Pre-softmax routing. |
|  |  | Grouped GEMM for MoE. |
|  |  | Aux loss coefficient. |
|  |  | Token dispatcher: alltoall or flex. |
|  |  | Enable DeepEP optimizations. |
|  |  | Enable MoE permute fusion. |
|  |  | Enable MoE router fusion. |
|  |  | Router dtype (e.g., fp32). |
Mixed precision#
| megatron-lm arguments | Megatron Bridge config | Description |
|---|---|---|
|  |  | Select a mixed-precision recipe; sets the related precision fields on `model`, `optimizer`, and `ddp`. |
Mixed precision is selected via the `mixed_precision` config key (e.g., preset names like `bf16_mixed`, `bf16`, or `fp16`, depending on your codebase) and is applied to `model`, `optimizer`, and `ddp` during `runtime_config_update`.
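For example, to pick a recipe from the command line, override the `mixed_precision` key with one of the preset names mentioned above (availability depends on your codebase):

```bash
# bf16_mixed is one of the preset names listed above; substitute the preset
# your codebase actually provides.
python examples/recipes/llama/pretrain_llama3_8b.py mixed_precision=bf16_mixed
```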
Training#
| megatron-lm arguments | Megatron Bridge config | Description |
|---|---|---|
|  |  | Per-rank batch size before gradient accumulation. |
|  |  | Total batch across DP and micro-batches. |
|  |  | Total training samples (sample-based mode). |
|  |  | Start size, increment, and sample count for linear batch ramp-up. |
|  |  | Adjust GBS to remain divisible when DP changes. |
|  |  | PyTorch CUDA `empty_cache` cadence (0, 1, or 2). |
|  |  | Interval to validate DP weight consistency. |
|  |  | Number of training iterations. |
|  |  | Exit when iteration % interval == 0. |
|  |  | Exit after N minutes. |
|  |  | Save and shut down on SIGTERM. |
|  |  | Enable manual Python GC scheduling. |
|  |  | Steps between manual GC runs. |
|  |  | Disable GC at eval boundaries. |
|  |  | Eval iterations per validation run. |
|  |  | Steps between validations. |
|  |  | Skip training loop (eval-only). |
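The batch-size keys from the quick start belong to this group. A minimal sketch, with the iteration count written as an assumed field name (`train.train_iters`, modeled on the Megatron-LM flag):

```bash
python examples/recipes/llama/pretrain_llama3_8b.py \
  train.micro_batch_size=2 \
  train.global_batch_size=128 \
  train.train_iters=100000  # assumed field name; verify in the train config dataclass
```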
Scheduler / Regularization#
| megatron-lm arguments | Megatron Bridge config | Description |
|---|---|---|
|  |  | LR schedule: constant/linear/cosine/ISR/WSD. |
|  |  | Iterations over which to decay LR. |
|  |  | WSD anneal style. |
|  |  | Iterations for WSD anneal phase. |
|  |  | Warmup as fraction of decay span. |
|  |  | Warmup iterations (absolute). |
|  |  | Initial LR at start of warmup. |
|  |  | Samples over which to decay LR (sample-based training). |
|  |  | Warmup samples (sample-based training). |
|  |  | Base learning rate. |
|  |  | Minimum learning rate. |
|  |  | Gradient clipping value. |
|  |  | Weight decay. |
|  |  | Adam beta1. |
|  |  | Adam beta2. |
|  |  | Ignore ckpt scheduler and use config. |
|  |  | Load scheduler from checkpoint. |
|  |  | WD at start (non-constant modes). |
|  |  | WD at end (non-constant modes). |
|  |  | WD schedule: constant/linear/cosine. |
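As a sketch, learning-rate and weight-decay knobs would be overridden under `optimizer` and `scheduler`; every field name below is an assumption modeled on the Megatron-LM flags, so confirm them against the dataclasses before relying on them.

```bash
# Assumed field names (optimizer.lr, optimizer.weight_decay, scheduler.lr_warmup_iters);
# confirm them in the optimizer/scheduler config dataclasses.
python examples/recipes/llama/pretrain_llama3_8b.py \
  optimizer.lr=3e-4 \
  optimizer.weight_decay=0.1 \
  scheduler.lr_warmup_iters=2000
```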
Checkpointing#
| megatron-lm arguments | Megatron Bridge config | Description |
|---|---|---|
|  |  | Directory to write checkpoints. |
|  |  | Iterations between persistent saves. |
|  |  | Do not save optimizer state. |
|  |  | Do not save RNG state. |
|  |  | Directory to load from. |
|  |  | Do not load optimizer state. |
|  |  | Load FP32 main params directly. |
|  |  | Do not load RNG state. |
|  |  | Frequency for ephemeral saves. |
|  |  | Kind of ephemeral checkpoint (global/local/memory). |
|  |  | Dir for global ephemeral saves. |
|  |  | Dir for local-per-rank ephemeral saves. |
|  |  | Local save algorithm selection. |
|  |  | Load weights, reset iters, no optim/rng. |
|  |  | Path to pretrained weights for finetune/SFT. |
|  |  | Explicit step to load. |
|  |  | Override model args from checkpoint metadata. |
|  |  | Exit if the checkpoint to load is not found. |
|  |  | Format: torch_dist/zarr/fsdp_dtensor. |
|  |  | Conversion target format. |
|  |  | Output dir for converted ckpt. |
|  |  | Disable DP-parallel save. |
|  |  | Enable async saves (torch_dist only). |
|  |  | Background worker for async saves. |
|  |  | Enable DP-parallel load. |
|  |  | Optimize for fixed structure. |
|  |  | Handling of key mismatches on load. |
|  |  | Auto-detect checkpoint format on load. |
|  |  | Enable replication of local checkpoints. |
|  |  | Spacing between replica ranks. |
|  |  | Number of replicas. |
|  |  | Relax FSDP-DTensor strict load. |
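`checkpoint.save` and `checkpoint.save_interval` appear in the quick start above; the sketch below also assumes a `checkpoint.load` field for the load directory (name not confirmed here).

```bash
python examples/recipes/llama/pretrain_llama3_8b.py \
  checkpoint.save=/workspace/ckpts \
  checkpoint.save_interval=1000 \
  checkpoint.load=/workspace/ckpts  # assumed field name for the load directory
```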
Logging#
| megatron-lm arguments | Megatron Bridge config | Description |
|---|---|---|
|  |  | Steps between console logs. |
|  |  | Compute and log parameter L2 norm. |
|  |  | Log tokens/sec per GPU. |
|  |  | Write `progress.txt` with tokens and FLOPs. |
|  |  | Timing log level: 0=min; 1=coarse ops; 2=many ops. |
|  |  | Timing reduction across ranks: max/minmax/all. |
|  |  | TensorBoard log directory. |
|  |  | Steps between TB events. |
|  |  | Pending TB event queue size. |
|  |  | Write timers to TB. |
|  |  | Disable loss-scale TB logs. |
|  |  | Write validation perplexity (ppl) to TB. |
|  |  | Enable memory stats in TB. |
|  |  | Log world size in TB. |
|  |  | Weights & Biases project. |
|  |  | Weights & Biases entity/team. |
|  |  | Run name in W&B. |
|  |  | Local directory for W&B artifacts. |
|  |  | Python logging level (e.g., 20=INFO). |
|  |  | Log energy in Joules (if available). |
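The W&B keys from the quick start live in this group. The TensorBoard directory field below is an assumption (`logger.tensorboard_dir`); check the logger dataclass for the real name.

```bash
python examples/recipes/llama/pretrain_llama3_8b.py \
  logger.wandb_project=my_proj \
  logger.wandb_exp_name=exp1 \
  logger.tensorboard_dir=/workspace/tb  # assumed field name; verify in the logger config
```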
RNG / Initialization#
| megatron-lm arguments | Megatron Bridge config | Description |
|---|---|---|
|  |  | Global random seed. |
|  |  | Enable per-DP-rank random init. |
|  |  | Use TE RNG (needed for CUDA graphs). |
|  |  | RNG tuned for inference stability. |
Distributed init and topology#
| megatron-lm arguments | Megatron Bridge config | Description |
|---|---|---|
|  |  | Process group backend (nccl/gloo). |
|  |  | PG init and collective timeout. |
|  |  | Launch DP reduces independently per PP stage. |
|  |  | Disable auxiliary Gloo PG creation. |
|  |  | Enable SHARP collectives for DP PG. |
|  |  | Which DP group enables SHARP. |
|  |  | Use high-priority comm streams for groups. |
|  |  | Use TP-PP-DP rank ordering at init. |
Additional distributed/optimizer overlap settings:
| megatron-lm arguments | Megatron Bridge config | Description |
|---|---|---|
|  |  | Enable distributed optimizer; settings are synchronized. |
|  |  | Overlap DP gradient reduce-scatter. |
|  |  | Overlap parameter all-gather with fprop. |
Profiling#
| megatron-lm arguments | Megatron Bridge config | Description |
|---|---|---|
|  |  | Enable nsys profiling (capture is controlled via external CLI). |
|  |  | Enable PyTorch profiler (TB-friendly). |
|  |  | Global step to start profiling. |
|  |  | Global step to stop profiling. |
|  |  | Global ranks to profile. |
|  |  | Track memory history. |
|  |  | Output path for memory snapshot. |
| (shapes) |  | Record tensor shapes (overhead). |
In-process restart#
| megatron-lm arguments | Megatron Bridge config | Description |
|---|---|---|
|  |  | Enable nvrx in-process restart. |
|  |  | Max restart attempts. |
|  |  | Monitor thread polling interval. |
|  |  | Monitor process polling interval. |
|  |  | Auto progress timestamp update cadence. |
|  |  | Unresponsive-rank heartbeat cadence. |
|  |  | Soft progress timeout. |
|  |  | Hard timeout until kill. |
|  |  | Missing heartbeat timeout. |
|  |  | Timeout for internal barriers. |
|  |  | Timeout for completion barrier. |
|  |  | Delay to collect terminal failures. |
|  |  | SIGTERM→SIGKILL grace period. |
|  |  | Restart granularity (node/rank). |
|  |  | Active ranks count; rest are reserve. |
|  |  | Empty CUDA cache on restart finalize. |
Straggler detection#
| megatron-lm arguments | Megatron Bridge config | Description |
|---|---|---|
|  |  | Track and log straggler GPUs. |
|  |  | Start with straggler detector disabled. |
|  |  | Controller port for toggling. |
|  |  | Num ranks to report for min/max throughput. |
Rerun state machine#
| megatron-lm arguments | Megatron Bridge config | Description |
|---|---|---|
|  |  | Frequency of injected validation perturbations. |
|  |  | Kind of injection (correct/transient/persistent). |
|  |  | Rerun mode: disabled/validate_results/report_stats. |
Data / Tokenizer args#
| megatron-lm arguments | Megatron Bridge config | Description |
|---|---|---|
|  |  | Tokenizer implementation (e.g., HuggingFaceTokenizer). |
|  |  | Model name/path for tokenizer. |
|  |  | DataLoader workers. |
|  |  | Use backend-generated masks. |
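A tokenizer override sketch: `HuggingFaceTokenizer` is the implementation named in the table above, while the field names (`tokenizer.tokenizer_type`, `tokenizer.tokenizer_model`) are assumptions modeled on the Megatron-LM flags.

```bash
# Assumed field names; confirm them in the tokenizer config dataclass.
python examples/recipes/llama/pretrain_llama3_8b.py \
  tokenizer.tokenizer_type=HuggingFaceTokenizer \
  tokenizer.tokenizer_model=/path/to/tokenizer_or_hf_id
```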