Mixed-Precision Training#
NeMo AutoModel uses FSDP2’s MixedPrecisionPolicy to control compute precision during forward and backward, and the model’s storage dtype (model.torch_dtype) to control the precision of the resident sharded parameter. Together these decide what numeric precision the optimizer state ends up in, which is the part that determines whether long full-parameter training runs converge cleanly.
This page describes the precision patterns we recommend, and the trap that sits between them. For any long full-parameter training run (pre-training or extended fine-tuning), the key rule is: do not combine torch.optim.AdamW with bf16 resident parameters unless you have explicitly accepted bf16 Adam state.
Storage Dtype vs. Compute Dtype#
Precision settings have distinct effects:
| Setting | Controls | Effect on optimizer state |
|---|---|---|
| model.torch_dtype | Storage dtype of the sharded parameter that PyTorch holds. | The optimizer reads this dtype when it initializes its state, so torch.optim.AdamW state is allocated in the same dtype. |
| mp_policy.param_dtype | Compute dtype FSDP2 casts to during forward/backward. | None directly; this only affects matmul / activation precision. |
| mp_policy.reduce_dtype | Dtype used for gradient reduce-scatter / all-reduce across DP ranks. | None directly; only affects how gradients are summed. |
| mp_policy.output_dtype | Dtype FSDP2 casts module outputs to. | None directly; this affects activation tensors, including tensors that cross pipeline-parallel boundaries. |
When model.torch_dtype: bfloat16 is used with torch.optim.AdamW, the AdamW EMA buffers (exp_avg and exp_avg_sq) are also stored in bf16. This is fragile for long full-parameter training runs: bf16 has only a 7-bit mantissa (roughly two to three significant decimal digits), so an EMA update that is small relative to the stored value can be rounded away even though the values themselves are still in range. Symptoms range from silent degradation (slower convergence, i.e., higher final loss at the same step count, with no visible instability) to overt failure: unstable grad_norm, loss bumps, loss spikes, or divergence.
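The rounding effect described above can be reproduced without a GPU. The following is a minimal sketch, not AutoModel or PyTorch code: it emulates bf16 and fp32 rounding with the standard library and applies the second-moment EMA update once per step. The helper names (to_bf16, to_f32, run_ema), the constant squared gradient of 2.0, and the single rounding per update are assumptions made for illustration; beta2 = 0.999 is the torch.optim.AdamW default.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round a finite Python float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
    # Adding this bias before dropping the low 16 bits gives round-to-nearest-even.
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def run_ema(round_fn, steps=5000, beta2=0.999, start=1.0, grad_sq=2.0):
    """Apply v = beta2 * v + (1 - beta2) * grad_sq, rounding the stored value each step."""
    v = round_fn(start)
    for _ in range(steps):
        v = round_fn(beta2 * v + (1.0 - beta2) * grad_sq)
    return v


# Hypothetical scenario: the squared gradient jumps from 1.0 to 2.0 and stays there.
ema_fp32 = run_ema(to_f32)
ema_bf16 = run_ema(to_bf16)
print(f"fp32 EMA after 5000 steps: {ema_fp32:.4f}")
print(f"bf16 EMA after 5000 steps: {ema_bf16:.4f}")
```

In this setup the fp32 buffer moves most of the way toward 2.0, while the bf16 buffer stays at 1.0: each per-step change is about 0.1% of the stored value, which is below half of the bf16 spacing near 1.0. Real training has noisy gradients and different rounding details, so this illustrates the mechanism rather than predicting a specific run.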
Recommended fp32 Optimizer-State Patterns#
Use one of these patterns for long pre-training:
| Pattern | Model storage / checkpoint dtype | Optimizer state | When to use |
|---|---|---|---|
| TE FusedAdam with bf16 model storage | bf16 | fp32 master weights + fp32 Adam EMA buffers | Use when the model has a validated TE memory/runtime path and you want training checkpoints to stay bf16. |
| torch AdamW with model.torch_dtype: float32 | fp32 | fp32 Adam EMA buffers | Use when the PyTorch optimizer path is the validated lower-memory or more stable path for that model. |
Both patterns keep forward/backward compute in bf16 through FSDP2 mixed precision. They differ in where the fp32 master weight lives, how much memory the optimizer uses for a specific model, and what dtype is written to model checkpoints.
Pattern A: TE FusedAdam + bf16 Model Storage#
Use Transformer Engine FusedAdam when it has been validated for the model. The resident model parameters remain bf16, so model checkpoints can stay bf16, while TE keeps the optimizer’s master weights and Adam EMA buffers in fp32.
model:
  torch_dtype: bfloat16  # resident sharded parameter + model checkpoint in bf16
optimizer:
  _target_: transformer_engine.pytorch.optimizers.fused_adam.FusedAdam
  lr: 3.0e-4
  adam_w_mode: true
  bias_correction: true
  master_weights: true
  store_param_remainders: true
  exp_avg_dtype: torch.float32
  exp_avg_sq_dtype: torch.float32
distributed:
  strategy: fsdp2
  # Defaults already provide bf16 forward/backward + fp32 gradient reduction; this block is shown explicitly for clarity.
  mp_policy:
    _target_: torch.distributed.fsdp.MixedPrecisionPolicy
    param_dtype: bfloat16
    reduce_dtype: float32
    output_dtype: bfloat16
    cast_forward_inputs: true
TE FusedAdam is the cleanest way to request fp32 optimizer state without making the resident model parameter fp32. It is not a guaranteed memory reduction: depending on model architecture, sharding, and optimizer implementation details, TE can use either less or more memory than torch AdamW with fp32 resident parameters. Validate memory per model before making it the default.
Pattern B: torch AdamW + fp32 Model Storage#
model:
  torch_dtype: float32  # sharded parameter + AdamW state in fp32
optimizer:
  _target_: torch.optim.AdamW
  lr: 3.0e-4
  betas: [0.9, 0.95]
  weight_decay: 0.1
distributed:
  strategy: fsdp2
  # Defaults already provide bf16 forward/backward + fp32 gradient reduction; this block is shown explicitly for clarity.
  mp_policy:
    _target_: torch.distributed.fsdp.MixedPrecisionPolicy
    param_dtype: bfloat16  # forward/backward compute in bf16 (fast)
    reduce_dtype: float32  # safe gradient reduction
    output_dtype: bfloat16
    cast_forward_inputs: true
This is the PyTorch AdamW version of the master-weights pattern. Forward/backward run in bf16, the all-reduce / reduce-scatter runs in fp32, and the optimizer applies updates against fp32 resident parameters. With torch.optim.AdamW, the resident fp32 sharded parameter is the master weight, so there is no separate fp32 master-weight copy.
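To make the "no separate master-weight copy" point concrete, here is a back-of-the-envelope sketch of nominal bytes per parameter for weights plus Adam state. It is an illustration under stated assumptions, not a measurement: gradients, activations, buffers, and allocator overhead are ignored, sharding is assumed to be perfectly even, and the function names, parameter count, and rank count are hypothetical. The third layout is a generic "bf16 weights plus a full fp32 master copy" design; it is not a measurement of TE FusedAdam, whose options (such as store_param_remainders in the Pattern A config) may change the footprint.

```python
# Nominal storage cost of each dtype, in bytes per element.
BYTES = {"bf16": 2, "fp32": 4}


def state_bytes_per_param(param_dtype, ema_dtype, separate_fp32_master=False):
    """Bytes per parameter for resident weights + exp_avg + exp_avg_sq."""
    total = BYTES[param_dtype] + 2 * BYTES[ema_dtype]
    if separate_fp32_master:
        total += BYTES["fp32"]  # extra fp32 copy kept by the optimizer
    return total


def per_rank_gib(num_params, bytes_per_param, dp_ranks):
    """Per-rank footprint in GiB, assuming perfectly even sharding."""
    return num_params * bytes_per_param / dp_ranks / 2**30


layouts = {
    "torch AdamW + fp32 storage": state_bytes_per_param("fp32", "fp32"),
    "torch AdamW + bf16 storage": state_bytes_per_param("bf16", "bf16"),
    "generic bf16 storage + separate fp32 master": state_bytes_per_param(
        "bf16", "fp32", separate_fp32_master=True
    ),
}

# Hypothetical sizes chosen only for illustration.
NUM_PARAMS = 8_000_000_000
DP_RANKS = 8

for name, nbytes in layouts.items():
    gib = per_rank_gib(NUM_PARAMS, nbytes, DP_RANKS)
    print(f"{name}: {nbytes} bytes/param, about {gib:.1f} GiB per rank")
```

Under these assumptions the fp32-storage layout needs fewer bytes per parameter than a layout that keeps both bf16 weights and a full fp32 master copy, and twice as many as an all-bf16 layout. Measured memory can differ in either direction, so validate per model.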
The main artifact trade-off is checkpoint dtype: because AutoModel checkpoints the resident model parameters, this pattern writes model checkpoints in fp32. Keep those fp32 training checkpoints if you need exact resume. If you need a bf16 checkpoint for inference or release, export a separate model-only bf16 checkpoint after training.
Risky Pattern: torch AdamW with bf16 Model Storage#
This pattern is easy to enter accidentally. In several AutoModel paths, leaving model.torch_dtype unset, or setting it to auto, resolves the resident model parameter dtype to bf16. If that is paired with torch.optim.AdamW, the AdamW EMA buffers are also bf16 because the optimizer initializes state from the parameter dtype.
model:
  # torch_dtype omitted, or torch_dtype: auto / bfloat16
optimizer:
  _target_: torch.optim.AdamW
distributed:
  strategy: fsdp2
This keeps the resident model parameters in bf16 instead of fp32, so it can reduce memory usage compared with the torch AdamW + model.torch_dtype: float32 pattern. It is common in existing fine-tuning example configs and is probably acceptable for short fine-tuning runs or LoRA / PEFT. It is not recommended for long full-parameter training (pre-training or extended fine-tuning): bf16 EMA quantization can quietly slow convergence (higher final loss at the same step count) and, in worse cases, produce unstable grad_norm, loss bumps, loss spikes, or divergence.
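Because this combination is easy to enter by omission, a small pre-flight check on the parsed config can help. The sketch below is a hypothetical helper, not part of NeMo AutoModel: the function name bf16_adam_state_warning, the set of dtype spellings it treats as bf16-like, and the plain-dict config shape are all assumptions. It treats an unset or auto dtype as potentially bf16, which matches the behavior described above for several paths but may not hold for every model.

```python
from typing import Optional

# Values of model.torch_dtype that may resolve to bf16 resident parameters.
# None stands for "key omitted"; the exact spellings are an assumption.
BF16_LIKE = {None, "auto", "bfloat16", "bf16", "torch.bfloat16"}


def bf16_adam_state_warning(cfg: dict) -> Optional[str]:
    """Return a warning if the config pairs torch AdamW with (possibly) bf16 storage."""
    model = cfg.get("model") or {}          # "model:" with only a comment parses as None
    optimizer = cfg.get("optimizer") or {}
    dtype = model.get("torch_dtype")
    if optimizer.get("_target_") == "torch.optim.AdamW" and dtype in BF16_LIKE:
        shown = "unset" if dtype is None else dtype
        return (
            f"model.torch_dtype is {shown} with torch.optim.AdamW: "
            "Adam EMA buffers may be stored in bf16. Set model.torch_dtype: float32 "
            "or use an optimizer that keeps fp32 state for long full-parameter runs."
        )
    return None


# Configs as they might look after YAML parsing (hypothetical examples).
risky_cfg = {"model": None, "optimizer": {"_target_": "torch.optim.AdamW"}}
pattern_b_cfg = {
    "model": {"torch_dtype": "float32"},
    "optimizer": {"_target_": "torch.optim.AdamW"},
}
pattern_a_cfg = {
    "model": {"torch_dtype": "bfloat16"},
    "optimizer": {
        "_target_": "transformer_engine.pytorch.optimizers.fused_adam.FusedAdam",
        "master_weights": True,
    },
}

for name, cfg in [("risky", risky_cfg), ("pattern B", pattern_b_cfg), ("pattern A", pattern_a_cfg)]:
    print(name, "->", bf16_adam_state_warning(cfg) or "ok")
```

A check like this only inspects the config text. It does not confirm what dtype the optimizer state actually has at runtime, so inspecting the optimizer state after the first step remains the more reliable test.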
Example Configs#
- examples/llm_pretrain/llama3_70b_pretrain.yaml: TE FusedAdam example. This keeps model storage and training checkpoints in bf16 while using fp32 optimizer state.
- examples/llm_pretrain/megatron_pretrain_moonlight_16b_te_slurm.yaml: torch AdamW example. This uses model.torch_dtype: float32 so AdamW state stays fp32 while compute remains bf16 through FSDP2 mixed precision.