MoE Training Optimization#
This page covers the optimization framework, techniques, and best practices for training Mixture-of-Experts models with Megatron-Core. It is based on Scalable Training of MoE Models with Megatron Core.
For tuning knobs, hardware-specific configs, and benchmark-oriented guidance, see:
The Three Walls#
MoE training is constrained by three tightly coupled barriers:
| Wall | Root Cause | Metric |
|---|---|---|
| Memory | All E experts stored but only K active per token | GB per GPU |
| Communication | EP all-to-all dispatches tokens across GPUs | % of step time |
| Compute Efficiency | Small expert GEMMs + host overhead | GPU SM utilization |
These walls interact: solving one often exposes another. FP8 reduces memory and can improve GEMM throughput, but it can also shift more of the remaining cost to quantization kernels and host overhead. Communication overlap can hide latency, but it may add scheduling and buffer constraints. Effective optimization treats all three as a unified system.
The Optimization Workflow#
Phase 1: Establish Memory-Feasible Parallelism#
Memory is the hard gate — if it doesn’t fit, nothing runs.
| Strategy | Activation | Weight | Optimizer | Comm |
|---|---|---|---|---|
| TP | 1/d (with SP) | 1/d | 1/d | High |
| EP | ~1 (load-dependent) | 1/d (MoE only) | 1/d | Medium |
| PP | 1 (>1 with VPP) | 1/d | 1/d | Medium |
| CP | 1/d | 1 | 1/d† | Medium |
| DP | 1 | 1 | 1/d† | Low |
†Requires --use-distributed-optimizer.
Quick test: use --fake-init-process-group to emulate distributed training on a
single GPU for rapid parallelism iteration before spending cluster time.
Phase 2: Select Optimal Parallelism#
Five guidelines, in order of priority:
1. Minimize model parallelism, maximize DP. Model parallelism adds communication; use distributed optimizer to free memory for larger DP.
2. Keep EP × TP within the fast interconnect domain. EP and TP are communication-intensive, so keeping the hot path on the fastest available links usually matters more than theoretical FLOPs.
3. Use PP for multi-node scaling. PP’s point-to-point comms scale better across nodes than TP/EP. Enable VPP to reduce bubbles.
4. Prefer EP over TP for expert layers. EP gives better GEMM efficiency, lower communication, and eliminates local permutation when EP = num_experts. Use Parallel Folding to decouple attention TP from expert EP.
5. Enable CP once sequence length makes attention memory dominant. In practice that often starts around the 8K-class regime, but the exact point depends on model size and hardware. Use hierarchical CP (a2a+p2p) on NVL72-class systems when appropriate.
Phase 3: Profile and Optimize Bottlenecks#
Profile the training run and identify which wall dominates:
Memory bottleneck — forced into full recompute or excessive parallelism:
| Optimization | Overhead | Flag |
|---|---|---|
| FP8 training | Low | |
| Selective recompute | Low | |
| Precision-aware optimizer | Low | |
| Activation offloading | Medium | |
| Optimizer offloading | Medium | |
Communication bottleneck — profiling shows time in collectives:
| Comm Type | Fix |
|---|---|
| DP grad/param | |
| TP | |
| EP dispatcher | |
| EP all-to-all | |
| PP send/recv | |
CPU overhead — gaps between GPU kernels in Nsight traces:
| Fix | Flag |
|---|---|
| Disable Python GC | |
| CUDA Graphs | |
| Reduce kernel launches | Decrease TP or increase MBS |
Compute inefficiency — low SM utilization despite no comm/CPU issues:
| Fix | Flag |
|---|---|
| Grouped GEMM | |
| Kernel fusions | |
| FP8 precision | |
This process is iterative: fitting the model, choosing parallelism, and profiling the dominant wall usually matter more than any single micro-optimization.
Parallel Folding#
Attention and MoE layers have conflicting optimal parallelisms. Parallel Folding decouples their configurations.
Attention layers: TP × CP × DP × PP
MoE layers: ETP × EP × EDP × PP (PP must match)
Key benefits:
Breaks EP ≤ DP constraint: EP can “fold” across TP×CP groups
Independent optimization: Attention uses high TP; MoE uses ETP=1
Fewer GPUs needed: CP=8 and EP=8 share the same 8 GPUs
NVLink-local comms: Both CP and EP stay in high-bandwidth domain
Example: 256 GPUs with attention TP=4, CP=2, DP=8, PP=4. Traditional: EP ≤ DP = 8. With folding: EP=64, ETP=1, EDP=1.
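The arithmetic behind the example above can be checked with a short sketch. The helper name `folded_edp` is hypothetical and not part of Megatron-Core; it only assumes, as the two layout lines state, that attention and MoE layers share the same PP degree and therefore the same number of GPUs per pipeline stage.

```python
from math import prod


def folded_edp(tp: int, cp: int, dp: int, etp: int, ep: int) -> int:
    """Return the expert data-parallel size (EDP) implied by folding.

    Per pipeline stage, attention occupies TP x CP x DP GPUs and the MoE
    layers must tile the same GPUs as ETP x EP x EDP.
    """
    gpus_per_stage = tp * cp * dp
    if gpus_per_stage % (etp * ep) != 0:
        raise ValueError("ETP x EP must divide TP x CP x DP")
    return gpus_per_stage // (etp * ep)


# Layout from the example: 256 GPUs, attention TP=4, CP=2, DP=8, PP=4.
attention = {"tp": 4, "cp": 2, "dp": 8, "pp": 4}
total_gpus = prod(attention.values())

# Without folding, EP is capped by DP.
max_ep_traditional = attention["dp"]

# With folding, EP=64 and ETP=1 tile the same 64 GPUs per stage.
edp = folded_edp(tp=4, cp=2, dp=8, etp=1, ep=64)
moe_gpus = 1 * 64 * edp * attention["pp"]

print(total_gpus, max_ep_traditional, edp, moe_gpus)
```

Both layouts multiply out to the same 256 GPUs, which is the condition that lets EP grow past DP without adding hardware.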
Memory Optimization Stack#
Ordered by overhead (lowest first):
Memory-efficient permutation (zero overhead): Absorbs routing weights into activations before FC2, eliminating saved tensors for router backward.
Standard: y = Σ_i p_i · W2_i · φ(W1_i · x)
Memory-efficient: y = Σ_i W2_i · (p_i · φ(W1_i · x))
The two forms are mathematically equivalent when experts have no bias. The second eliminates saving each expert output for router backward; the activation is recomputed from already-saved inputs.
FP8/FP4 activations: Store linear-layer inputs in lower precision than BF16. This usually gives a modest but useful activation-memory reduction.
Fine-grained recompute: Recompute only cheap operations such as LayerNorm, activation functions, or model-specific up-projection modules. This often recovers much of the needed memory while keeping overhead much lower than full-layer recompute.
Fine-grained offloading: Module-level D2H/H2D with stream overlap. This can free a meaningful amount of memory at a small throughput cost and may allow a better parallelism layout that more than repays the offload overhead.
Optimizer state offloading: Move optimizer states to CPU between steps. This is especially attractive on GB200-class systems, where the host-device path is strong enough to make the trade practical.
FSDP for MoE: Dual DeviceMesh — primary mesh for attention, expert mesh for MoE. AllGather/ReduceScatter stay within small EDP groups. Zero-copy comms via NCCL User Buffer Registration.
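The equivalence claimed for memory-efficient permutation at the top of this list can be checked numerically. This is a minimal pure-Python sketch, not the Megatron-Core kernel: the sizes, the routing weights, and the choice of ReLU for φ are all illustrative assumptions.

```python
import random


def matvec(W, x):
    """Dense matrix-vector product on nested lists."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def phi(v):
    """Stand-in activation (ReLU); the real model's activation may differ."""
    return [max(0.0, a) for a in v]


def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
hidden, ffn, top_k = 4, 8, 2          # illustrative sizes
x = [rng.uniform(-1.0, 1.0) for _ in range(hidden)]
W1 = [rand_matrix(ffn, hidden, rng) for _ in range(top_k)]
W2 = [rand_matrix(hidden, ffn, rng) for _ in range(top_k)]
p = [0.7, 0.3]                        # illustrative routing weights

# Standard: scale each expert's FC2 output by its routing weight.
standard = [0.0] * hidden
for i in range(top_k):
    out = matvec(W2[i], phi(matvec(W1[i], x)))
    standard = [s + p[i] * o for s, o in zip(standard, out)]

# Memory-efficient: fold the routing weight into the activation before FC2.
efficient = [0.0] * hidden
for i in range(top_k):
    scaled_act = [p[i] * a for a in phi(matvec(W1[i], x))]
    out = matvec(W2[i], scaled_act)
    efficient = [e + o for e, o in zip(efficient, out)]

max_diff = max(abs(a - b) for a, b in zip(standard, efficient))
print(max_diff)
```

The two results agree up to floating-point rounding because FC2 is linear. With a bias b on FC2 the forms would differ by (1 − p_i) · b per expert, which is why the no-bias condition matters.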
FP8 Recipe Selection#
| Recipe | Platform | Granularity | Recommended |
|---|---|---|---|
| Per-tensor FP8 | Hopper/Blackwell | 1 scale/tensor | Starting point |
| Blockwise FP8 | Hopper | 128×128 blocks | Production on Hopper |
| MXFP8 | Blackwell | 1×32 elements | Default on Blackwell |
| NVFP4 | Blackwell | 16 elements, 2-level | Maximum throughput |
Key rules:
Router stays in FP32 always
Embeddings, output layer, gradients, optimizer stay in original precision
Expert GEMMs are the primary quantization target
MXFP8 on Blackwell communicates params in BF16 (can’t save on AllGather)
NVFP4 requires Random Hadamard Transforms, 2D scaling, stochastic rounding
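To make the granularity column of the recipe table concrete, the sketch below counts how many block-level scale factors each recipe would attach to one weight matrix. The function and recipe keys are hypothetical names, and the NVFP4 branch assumes 1×16 blocks along a row and does not count the second-level scale.

```python
from math import ceil


def block_scale_count(rows: int, cols: int, recipe: str) -> int:
    """Number of block-level scale factors for a rows x cols tensor."""
    if recipe == "per_tensor":
        return 1                                    # 1 scale/tensor
    if recipe == "blockwise":
        return ceil(rows / 128) * ceil(cols / 128)  # 128x128 blocks
    if recipe == "mxfp8":
        return rows * ceil(cols / 32)               # 1x32 elements
    if recipe == "nvfp4":
        # Assumption: 16-element blocks along a row; the second scaling
        # level mentioned in the table is not counted here.
        return rows * ceil(cols / 16)
    raise ValueError(f"unknown recipe: {recipe}")


counts = {
    name: block_scale_count(4096, 4096, name)
    for name in ("per_tensor", "blockwise", "mxfp8", "nvfp4")
}
print(counts)
```

Finer granularity means more scales to compute and store per tensor, which is one reason quantization kernels can become a visible cost once the GEMMs themselves get faster.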
CUDA Graphs for MoE#
Two modes, different trade-offs:
| Mode | What’s Captured | When to Use |
|---|---|---|
| Full CUDA Graphs | Entire fwd+bwd | Drop-and-pad MoE only |
| Partial (layer-wise) | attn + router + moe_preprocess | Dropless MoE (default) |
Partial CUDA graphs capture static components while leaving dynamic expert computation outside the graph, which is why they are the safer default for dropless MoE.
For full CUDA Graphs on dropless MoE, three techniques are needed:
Device-initiated Grouped GEMM: Reads shapes from GPU memory. cuBLASLt (CUDA 13.1+) or cuteDSL with fused activation/quantization.
ECHO (Elastic Cloning for Hot Experts): Clones hot experts to underutilized ranks via bin-packing. Reduces load variance so worst-case buffer sizing is closer to actual.
Paged Stashing: Single worst-case tmp buffer shared across layers for computation; paged stashing buffer stores only actual tokens. Reduces memory from O(layers × worst_case) to O(worst_case + actual). 64 tokens per page, free list via circular buffer.
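The paged stashing item describes 64-token pages handed out from a free list kept in a circular buffer. The class below is a host-side toy model of that bookkeeping only, under the assumption that pages are identified by integer ids. `PagedStash` and its methods are hypothetical names, and the real implementation runs on the GPU.

```python
class PagedStash:
    """Toy page allocator: fixed pool of pages, free list in a ring buffer."""

    PAGE_TOKENS = 64  # tokens per page, as stated in the text

    def __init__(self, num_pages: int):
        self.capacity = num_pages
        self.ring = list(range(num_pages))  # ring of free page ids
        self.head = 0                       # position of the next free page
        self.free_count = num_pages

    def alloc(self, num_tokens: int) -> list:
        """Reserve just enough pages to stash num_tokens tokens."""
        needed = -(-num_tokens // self.PAGE_TOKENS)  # ceiling division
        if needed > self.free_count:
            raise MemoryError("stash buffer exhausted")
        pages = []
        for _ in range(needed):
            pages.append(self.ring[self.head])
            self.head = (self.head + 1) % self.capacity
        self.free_count -= needed
        return pages

    def free(self, pages: list) -> None:
        """Return pages to the tail of the ring once backward has used them."""
        if self.free_count + len(pages) > self.capacity:
            raise ValueError("freeing more pages than were allocated")
        tail = (self.head + self.free_count) % self.capacity
        for page in pages:
            self.ring[tail] = page
            tail = (tail + 1) % self.capacity
        self.free_count += len(pages)


stash = PagedStash(num_pages=8)
layer0 = stash.alloc(100)     # 100 tokens -> 2 pages
layer1 = stash.alloc(64)      # 64 tokens  -> 1 page
stash.free(layer0)            # layer 0 backward done, pages recycled
layer2 = stash.alloc(64 * 6)  # wraps around the ring and reuses a freed page
print(layer0, layer1, layer2, stash.free_count)
```

Because each layer reserves pages for the tokens it actually received, the stash grows with real load rather than with the worst case per layer.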
Flexible Asymmetric VPP#
PP layout string controls per-stage layer distribution:
--pipeline-model-parallel-layout "Et*3|(tt|)*29m|L"
E = embedding, t = transformer, m = MTP, L = loss, | = stage boundary
Balance workload: embedding + N dense layers ≈ fewer MoE layers
Place MTP and loss on dedicated stages for memory isolation
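To see what the layout string above expands to, here is a small sketch that unrolls the repetition syntax. It is an illustrative reading of the grammar (`x*N` repeats one symbol, `(...)*N` repeats a group), not Megatron-Core's own parser, and `expand_layout` is a hypothetical helper.

```python
import re


def expand_layout(layout: str) -> list[str]:
    """Unroll a PP layout string into one string of layer symbols per stage."""
    # Expand parenthesised groups first: "(tt|)*29" -> "tt|" repeated 29 times.
    expanded = re.sub(
        r"\(([^()]*)\)\*(\d+)",
        lambda m: m.group(1) * int(m.group(2)),
        layout,
    )
    # Then expand single-symbol repeats: "t*3" -> "ttt".
    expanded = re.sub(
        r"([A-Za-z])\*(\d+)",
        lambda m: m.group(1) * int(m.group(2)),
        expanded,
    )
    return expanded.split("|")


stages = expand_layout("Et*3|(tt|)*29m|L")
num_transformer_layers = sum(stage.count("t") for stage in stages)
print(len(stages), stages[0], stages[1], stages[-2], stages[-1])
print(num_transformer_layers)
```

Under this reading the example yields 32 stages: the first holds the embedding plus three transformer layers, the middle stages hold two transformer layers each, and MTP and loss each get a dedicated stage.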
EP Communication Overlap#
For exact overlap constraints and verification guidance, see docs/training/communication-overlap.md and the related MoE overlap skills.
Two overlap patterns for 1F1B:
Merged FWD-FWD / BWD-BWD: Same-type passes from two microbatches run in parallel. Costs 2× activation memory. Less overlap (fwd compute is half of bwd).
Merged FWD-BWD (preferred): Forward of microbatch i+1 overlaps with backward of microbatch i. No extra memory. Matches DualPipe design. Limited: first FWD and last BWD can’t be hidden.
Key optimization: W/D split. Split backward MLP work into weight-gradient and data-gradient pieces so the weight-gradient portion can overlap with forward compute when forward MLP alone is too short to hide dispatch cost.
MoE Token Dispatchers#
MoE models route tokens to experts via all-to-all communication. The dispatcher backend controls how this communication is implemented:
| Dispatcher | Backend | Mechanism |
|---|---|---|
| | Standard MoE | Torch-native all-to-all collectives |
| | DeepEP library | Low-latency SM-based dispatch with GPU-side routing |
| | HybridEP library | Fused intra-node NVLink + inter-node IB dispatch |
Hardware affinity#
| Hardware | Recommended Dispatcher | Rationale |
|---|---|---|
| H100 / B200 (NVL8) | DeepEP | Optimized for node-based topologies |
| GB200 / GB300 (NVL72) | HybridEP | Exploits NVLink domain for lower latency |
HybridEP advantage usually grows with EP degree because it fuses intra-node NVLink transfers with inter-node IB work, avoiding much of the two-phase overhead of standard all-to-all at large EP sizes.
Long-Context MoE Training#
At long sequences (64K+), SDPA dominates FLOPs. Context parallelism (CP) is the primary mechanism for scaling sequence length.
CP sizing rules of thumb:
Start with CP ≈ seq_len / 4096: then round to a practical layout.
Keep DP ≥ 1: CP × EP × TP × PP must not exceed total GPUs.
Prefer selective recompute over full: Recompute up_proj, norm, moe, mlp rather than full recompute for better throughput.
TP can sometimes substitute for some CP on NVLink systems: on NVL72 systems, higher TP can be competitive with a more CP-heavy plan.
Optimizer CPU offload is often critical at long context because activation pressure consumes so much of the memory budget.
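The first two sizing rules above can be turned into a quick starting-point calculator. This is a sketch under stated assumptions: "round to a practical layout" is interpreted here as rounding down to a power of two, and `suggest_cp` is a hypothetical helper rather than a Megatron-Core function.

```python
def suggest_cp(seq_len: int, total_gpus: int, tp: int = 1, ep: int = 1,
               pp: int = 1, tokens_per_shard: int = 4096) -> int:
    """Suggest a CP degree from CP ~ seq_len / 4096, capped by the GPU budget."""
    target = max(1, seq_len // tokens_per_shard)

    # Round down to a power of two (an assumption about "practical layout").
    cp = 1
    while cp * 2 <= target:
        cp *= 2

    # Shrink CP until CP x EP x TP x PP fits, so that DP stays >= 1.
    while cp > 1 and cp * ep * tp * pp > total_gpus:
        cp //= 2
    if cp * ep * tp * pp > total_gpus:
        raise ValueError("EP x TP x PP alone exceeds the GPU budget")
    return cp


cp_fits = suggest_cp(65536, total_gpus=256, tp=2, ep=8)
cp_capped = suggest_cp(131072, total_gpus=64, ep=8)
print(cp_fits, cp_capped)
```

The result is only a first guess; the remaining rules (selective recompute, trading some CP for TP on NVLink systems) still need to be weighed against profiling data.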
Long-context recommendations:
Keep sub-sequence length ~4096–8192 per CP/TP shard
Don’t recompute SDPA at long context: SDPA recompute adds significant compute overhead while saving relatively little memory. Recompute non-SDPA modules instead.
TP preferred within node (fast comms, reduces param memory)
P2P CP preferred across nodes (natural overlap with attention)
a2a CP + TP within node when ring exchange is undesirable
Dynamic Context Parallelism#
For variable-length training (RL, SFT):
Per-microbatch CP sizing instead of static CP for all
Pre-constructs multiple CP groups during init (powers of 2)
Scheduler selects effective cp_size per microbatch
Works with packed sequences (THD format)
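The per-microbatch selection described above can be sketched as follows. The 8192-token budget per rank and the helper name `pick_cp_size` are illustrative assumptions; the actual scheduler policy is not specified here.

```python
def pick_cp_size(seq_len: int, max_cp: int,
                 max_tokens_per_rank: int = 8192) -> int:
    """Pick the smallest power-of-two CP size that keeps each rank's share
    of the sequence within the token budget, capped at max_cp."""
    cp = 1
    while cp < max_cp and seq_len > cp * max_tokens_per_rank:
        cp *= 2
    return cp


# CP groups for sizes 1, 2, 4, 8 are assumed to exist from initialization.
microbatch_lengths = [1_000, 12_000, 70_000]
cp_sizes = [pick_cp_size(n, max_cp=8) for n in microbatch_lengths]
print(cp_sizes)
```

Short microbatches then skip CP communication entirely, while long ones use the largest pre-built group.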
MoE VLM Training#
MoE vision-language models combine a vision encoder with a MoE language decoder. Training requires choosing between two strategies:
| Approach | Mechanism | When to Use |
|---|---|---|
| FSDP | Shards params, grads, and optimizer across all GPUs | Simpler setup and a better first bring-up path |
| 3D Parallel | TP + PP + EP + DP | Higher throughput ceiling once the multimodal path is already stable |
Key principles:
Always benchmark with real vision data — image-free mock runs can significantly overestimate throughput.
Freezing vision encoder saves compute when fine-tuning only the decoder.
MBS is critical for 3D-parallel VLM — larger micro-batch sizes often matter more than they do for text-only MoE.
FSDP is simpler and often competitive for initial bring-up.
Production Features Summary#
| Feature | Purpose |
|---|---|
| Force-balance routing | Even token distribution; best for benchmarking |
| Aux-loss-free balancing | Learnable expert bias; adapts over time |
| Shared expert overlap | Hides shared expert latency behind dispatch/combine |
| LatentMoE | Reduces comms and per-expert params by compression ratio α |
| Distributed checkpoint | Parallelism-agnostic save/load with automatic resharding |
| Upcycling | Convert dense checkpoint to MoE without retraining |
| MTP | Multi-token prediction with flexible VPP placement |
| Muon optimizer | Matrix-aware updates; fewer steps than AdamW |