Mixed-Precision Training
NeMo AutoModel uses FSDP2’s MixedPrecisionPolicy to control compute precision during the forward and backward passes, and the model’s storage dtype (model.torch_dtype) to control the precision of the resident sharded parameters. Together, these settings determine the numeric precision of the optimizer state, which in turn determines whether long full-parameter training runs converge cleanly.
This page describes the precision patterns we recommend, and the trap that sits between them. For any long full-parameter training run (pre-training or extended fine-tuning), the key rule is: do not combine torch.optim.AdamW with bf16 resident parameters unless you have explicitly accepted bf16 Adam state.
Storage Dtype vs. Compute Dtype
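These precision settings have distinct effects. FSDP2’s MixedPrecisionPolicy sets the dtype used for forward/backward compute and for gradient reduction. model.torch_dtype sets the dtype of the resident sharded parameters, and torch.optim.AdamW creates its optimizer state in that same dtype.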
When model.torch_dtype: bfloat16 is used with torch.optim.AdamW, the AdamW EMA buffers (exp_avg and exp_avg_sq) are also stored in bf16. This is fragile for long full-parameter training runs: bf16 has only a 7-bit mantissa, so small EMA updates can be rounded away even though the values themselves are still in range. Symptoms range from silent degradation (slower convergence, i.e., higher final loss at the same step count, with no visible instability) to overt failure: unstable grad_norm, loss bumps, loss spikes, or divergence.
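To make the rounding concrete, here is a minimal pure-Python sketch. It is not AutoModel or PyTorch code. It emulates bf16 storage by rounding an fp32 bit pattern to its top 16 bits (round-to-nearest-even), then runs the Adam second-moment EMA update with beta2 = 0.999. The sketch assumes a simplified model in which each step's result is rounded once, and Python floats stand in for fp32. The helper names to_bf16 and ema_sq are illustrative.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even).

    Finite inputs only; NaN/Inf handling is omitted for brevity.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # fp32 bit pattern
    # bf16 keeps the top 16 bits of fp32; the bias makes truncation round to nearest-even
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def ema_sq(g2: float, steps: int, beta2: float = 0.999, v0: float = 1.0, bf16: bool = False) -> float:
    """Run v <- beta2 * v + (1 - beta2) * g2 for `steps` steps, optionally storing v in bf16."""
    v = v0
    for _ in range(steps):
        v = beta2 * v + (1.0 - beta2) * g2
        if bf16:
            v = to_bf16(v)  # emulate bf16 storage of exp_avg_sq
    return v


# The squared gradient has doubled (1.0 -> 2.0), so the EMA should drift upward.
print(ema_sq(2.0, 1000))             # high precision: about 1.632
print(ema_sq(2.0, 1000, bf16=True))  # bf16 storage: stuck at 1.0
```

Each step moves v by only about 0.001. That is less than half the bf16 spacing just above 1.0 (half of 2^-7, about 0.0039), so every update is rounded away even though 1.0 and 2.0 are comfortably in range.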
Recommended fp32 Master Weight and Optimizer-State Patterns
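Use one of these two patterns for long full-parameter training (pre-training or extended fine-tuning): Pattern A (TE FusedAdam with bf16 model storage) or Pattern B (torch AdamW with fp32 model storage).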
Both patterns keep forward/backward compute in bf16 through FSDP2 mixed precision. They differ in where the fp32 master weight lives, how much peak memory they use for a specific model, and what dtype is written to the training checkpoint.
Pattern A: TE FusedAdam and bf16 Model Storage
Use Transformer Engine FusedAdam when it has been validated for the model. The resident model parameters remain bf16, so model checkpoints can stay bf16, while TE keeps the optimizer’s master weights and Adam EMA buffers in fp32.
TE FusedAdam is the cleanest way to request fp32 optimizer state without making the resident model parameter fp32.
A common misconception is that TE costs extra memory for a second fp32 master. With store_param_remainders: true, TE does not keep a full extra fp32 master. Instead, it stores the bf16 parameter plus a 16-bit remainder that together reconstruct the fp32 master, costing the same ~4 bytes/param as torch AdamW’s fp32 resident parameter. The fp32 Adam EMA buffers (exp_avg, exp_avg_sq) are the same in both patterns, so the two optimizers’ steady-state footprints are essentially equal. The practical difference is then simplicity (torch AdamW: no TE dependency, fewer moving parts) versus keeping model storage and checkpoints in bf16 (TE).
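The bf16-plus-remainder idea can be illustrated with a small pure-Python sketch. The 32 bits of an fp32 master value are split into a 16-bit high half, which is itself a valid (truncated) bf16 value, and a 16-bit low half. Together the two halves rebuild the fp32 value exactly, for 4 bytes in total. This sketch illustrates only the storage idea. TE's kernels may split and round the bits differently (for example, to keep the bf16 half round-to-nearest), and all helper names here are hypothetical.

```python
import struct


def fp32(x: float) -> float:
    """Round a Python float to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def split_master(x: float) -> tuple[int, int]:
    """Split an fp32 master value into (bf16 bits, 16-bit remainder) by plain truncation."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return bits >> 16, bits & 0xFFFF


def bf16_value(hi: int) -> float:
    """Interpret the high 16 bits as a bf16 value (what the resident parameter would hold)."""
    return struct.unpack("<f", struct.pack("<I", hi << 16))[0]


def join_master(hi: int, lo: int) -> float:
    """Rebuild the exact fp32 master value from the two 16-bit halves."""
    return struct.unpack("<f", struct.pack("<I", (hi << 16) | lo))[0]


w = fp32(0.1234567)
hi, lo = split_master(w)
print(bf16_value(hi))            # the bf16 parameter used for compute
print(join_master(hi, lo) == w)  # 2 + 2 bytes reconstruct the fp32 master exactly
```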
Where the two can differ is the gradient buffer: TE keeps the resident parameter in bf16, so its gradients are bf16, whereas torch AdamW with fp32 storage keeps parameters and gradients in fp32 (~2 bytes/param more on the gradient buffer). In practice, what we measure is peak memory, which is usually dominated by activations rather than the optimizer step, so depending on the model (fraction of intrinsically-fp32 params, fragmentation, where the peak falls), TE can come out lower, equal, or higher. Validate memory per model before making it the default.
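As a rough cross-check of the memory discussion above, the sketch below tallies per-parameter bytes for persistent training state only, using the dtypes described on this page. It assumes that gradients take the resident parameter dtype. It ignores activations, temporary buffers, allocator fragmentation, and parameters kept in fp32 by design, so it is not a peak-memory estimate. The 70B-parameter / 64-rank scenario is a made-up example.

```python
# Per-parameter byte counts for persistent state, per pattern (assumptions listed above).
BYTES_PER_PARAM = {
    "A: TE FusedAdam, bf16 storage": {
        "param (bf16)": 2, "remainder (16-bit)": 2,
        "exp_avg (fp32)": 4, "exp_avg_sq (fp32)": 4, "grad (bf16)": 2,
    },
    "B: torch AdamW, fp32 storage": {
        "param (fp32)": 4, "exp_avg (fp32)": 4, "exp_avg_sq (fp32)": 4, "grad (fp32)": 4,
    },
    "risky: torch AdamW, bf16 storage": {
        "param (bf16)": 2, "exp_avg (bf16)": 2, "exp_avg_sq (bf16)": 2, "grad (bf16)": 2,
    },
}


def total_bytes(pattern: str) -> int:
    return sum(BYTES_PER_PARAM[pattern].values())


def without_grad(pattern: str) -> int:
    """Parameter plus optimizer state only (the 'steady-state optimizer footprint')."""
    return sum(v for k, v in BYTES_PER_PARAM[pattern].items() if not k.startswith("grad"))


def state_gib_per_rank(pattern: str, n_params: float, world_size: int) -> float:
    """FSDP2 shards params, grads, and optimizer state, so divide by the number of ranks."""
    return total_bytes(pattern) * n_params / world_size / 2**30


for name in BYTES_PER_PARAM:
    print(f"{name}: {total_bytes(name)} B/param total, {without_grad(name)} B/param without grad, "
          f"{state_gib_per_rank(name, 70e9, 64):.1f} GiB/rank for 70B params on 64 ranks")
```

Patterns A and B come out equal without gradients (12 bytes/param). The 2-byte gap appears only in the gradient buffer, which is why measured peak memory can go either way once activations are included.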
Pattern B: torch AdamW and fp32 Model Storage
This is the PyTorch AdamW version of the master-weights pattern. Forward/backward run in bf16, the all-reduce / reduce-scatter runs in fp32, and the optimizer applies updates against fp32 resident parameters. With torch.optim.AdamW, the resident fp32 sharded parameter is the master weight, so there is no separate fp32 master-weight copy.
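A master weight matters because the same rounding effect applies to the parameter itself. With torch.optim.AdamW and fp32 storage, the resident sharded parameter is that high-precision copy. In this pure-Python sketch, a parameter of 1.0 receives 100 updates of -1e-4 each. bf16 is emulated with round-to-nearest-even, Python floats stand in for fp32, and the helper names are illustrative. Stored in fp32, the parameter accumulates the change. Stored in bf16, each update is smaller than half the bf16 spacing just below 1.0 (2^-9) and is rounded away.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even; finite inputs only)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def apply_updates(w0: float, delta: float, steps: int, bf16_storage: bool) -> float:
    """Apply `steps` identical optimizer updates to a parameter stored in fp32 or bf16."""
    w = w0
    for _ in range(steps):
        w = w + delta
        if bf16_storage:
            w = to_bf16(w)  # the resident parameter is the only copy, so rounding is permanent
    return w


print(apply_updates(1.0, -1e-4, 100, bf16_storage=False))  # about 0.99
print(apply_updates(1.0, -1e-4, 100, bf16_storage=True))   # 1.0: every update was rounded away
```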
The trade-off is the dtype of the training checkpoint: AutoModel’s training (DCP) checkpoint stores the resident model parameters, so this pattern writes them in fp32, which you keep for exact resume. This concerns only the training checkpoint; a consolidated HF checkpoint for inference or release is exported separately and follows the model’s intended dtype (for fine-tuning, this matches the original HF checkpoint, typically bf16). See checkpointing for details.
Precision and Robustness
Both patterns keep the master weights and Adam EMA in fp32, so for most models they converge equivalently. The remaining difference is the gradient: under TE (Pattern A) the resident parameter is bf16, so the gradient feeding the Adam update carries bf16 rounding, while torch AdamW with fp32 storage (Pattern B) keeps the parameter, gradient, and optimizer state all in fp32.
This is usually negligible. When bringing up a new model, fp32 storage (Pattern B) is a robust starting point because every part of the update path is fp32; move to TE (Pattern A) after it is validated for that model and you want bf16 checkpoints.
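One way to see why bf16 gradient rounding is usually small: round-to-nearest into bf16 changes any normal value by at most 2^-8 of its magnitude (half of the 2^-7 relative spacing), about 0.39%. The pure-Python sketch below checks that bound empirically on random fp32 values, using the same emulated bf16 rounding as a stand-in. It says nothing about a specific model's convergence, which is why validation per model still applies.

```python
import random
import struct


def to_fp32(x: float) -> float:
    """Round a Python float to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round an fp32 value to the nearest bfloat16 value (round-to-nearest-even; finite inputs only)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def max_relative_rounding_error(n: int = 20_000, seed: int = 0) -> float:
    """Largest |bf16(g) - g| / |g| over n random fp32 'gradient' values."""
    rng = random.Random(seed)
    worst = 0.0
    for _ in range(n):
        g = to_fp32(rng.choice((-1.0, 1.0)) * rng.uniform(1e-3, 10.0))
        worst = max(worst, abs(to_bf16(g) - g) / abs(g))
    return worst


print(max_relative_rounding_error() <= 2 ** -8)  # True
```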
Risky Pattern: torch AdamW with bf16 Model Storage
This pattern is easy to enter accidentally. In several AutoModel paths, leaving model.torch_dtype unset, or setting it to auto, resolves the resident model parameter dtype to bf16. If that is paired with torch.optim.AdamW, the AdamW EMA buffers are also bf16 because the optimizer initializes state from the parameter dtype.
This keeps the resident model parameters in bf16 instead of fp32, so it can reduce memory usage compared with the torch AdamW and model.torch_dtype: float32 pattern. It is common in existing fine-tuning example configs and is probably acceptable for short fine-tuning runs or LoRA/PEFT. It is not recommended for long full-parameter training (pre-training or extended fine-tuning): bf16 EMA quantization can quietly slow convergence (higher final loss at the same step count) and, in worse cases, produce unstable grad_norm, loss bumps, loss spikes, or divergence.
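Because this combination is easy to enter by accident, a pre-flight check can encode the rule from the top of this page. The sketch below is hypothetical and is not an AutoModel API. It assumes that an unset or auto model.torch_dtype resolves to bf16, as the text notes happens in several AutoModel paths, and it uses plain strings such as "torch.optim.AdamW" as stand-in optimizer identifiers.

```python
from typing import Optional

# Assumption: unset (None) or "auto" resolves to bf16 storage, as in several AutoModel paths.
BF16_ALIASES = {None, "auto", "bfloat16", "bf16"}


def resolved_storage_dtype(torch_dtype: Optional[str]) -> str:
    return "bfloat16" if torch_dtype in BF16_ALIASES else torch_dtype


def precision_warnings(torch_dtype: Optional[str], optimizer: str, long_full_param_run: bool) -> list[str]:
    """Return warnings for the torch AdamW + bf16 storage combination (hypothetical check)."""
    storage = resolved_storage_dtype(torch_dtype)
    warnings = []
    # torch.optim.AdamW creates exp_avg / exp_avg_sq in the parameter's dtype,
    # so bf16 storage means bf16 Adam state.
    if optimizer == "torch.optim.AdamW" and storage == "bfloat16" and long_full_param_run:
        warnings.append(
            "torch AdamW with bf16 resident params keeps exp_avg/exp_avg_sq in bf16; "
            "use model.torch_dtype: float32 or TE FusedAdam for long full-parameter runs"
        )
    return warnings


print(precision_warnings(None, "torch.optim.AdamW", long_full_param_run=True))       # a list with one warning
print(precision_warnings("float32", "torch.optim.AdamW", long_full_param_run=True))  # []
```

Short fine-tuning or LoRA/PEFT runs pass the check when long_full_param_run is False, matching the carve-out in the paragraph above.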
Example Configs
examples/llm_pretrain/llama3_70b_pretrain.yaml: TE FusedAdam example. This keeps model storage and training checkpoints in bf16 while using fp32 master weights and optimizer state.
examples/llm_pretrain/megatron_pretrain_moonlight_16b_te_slurm.yaml: torch AdamW example. This uses model.torch_dtype: float32 so AdamW state stays fp32 while compute remains bf16 through FSDP2 mixed precision.