Mixed-Precision Training

NeMo AutoModel uses FSDP2’s MixedPrecisionPolicy to control compute precision during forward and backward, and the model’s storage dtype (model.torch_dtype) to control the precision of the resident sharded parameters. Together, these settings determine the numeric precision of the optimizer state, which in turn largely decides whether long full-parameter training runs converge cleanly.

This page describes the precision patterns we recommend, and the trap that sits between them. For any long full-parameter training run (pre-training or extended fine-tuning), the key rule is: do not combine torch.optim.AdamW with bf16 resident parameters unless you have explicitly accepted bf16 Adam state.

Storage Dtype vs. Compute Dtype

Precision settings have distinct effects:

| Setting | Controls | Effect on optimizer state |
| --- | --- | --- |
| model.torch_dtype | Storage dtype of the sharded parameter that PyTorch holds. | For torch.optim, the optimizer reads param.data, so the EMA buffers (exp_avg, exp_avg_sq for AdamW) end up in this dtype. (TE FusedAdam keeps fp32 state regardless.) |
| mp_policy.param_dtype | Compute dtype FSDP2 casts to during forward/backward. | None directly; this only affects matmul / activation precision. |
| mp_policy.reduce_dtype | Dtype used for gradient reduce-scatter / all-reduce across DP ranks. | None directly; only affects how gradients are summed. |
| mp_policy.output_dtype | Dtype FSDP2 casts module outputs to. | None directly; this affects activation tensors, including tensors that cross pipeline-parallel boundaries. |

When model.torch_dtype: bfloat16 is used with torch.optim.AdamW, the AdamW EMA buffers (exp_avg and exp_avg_sq) are also stored in bf16. This is fragile for long full-parameter training runs: bf16 stores only 7 explicit mantissa bits, so adjacent representable values are roughly 0.4% to 0.8% apart, and an EMA update smaller than about half that gap is rounded away even though the values themselves are still in range. Symptoms range from silent degradation (slower convergence, i.e., higher final loss at the same step count, with no visible instability) to overt failure: unstable grad_norm, loss bumps, loss spikes, or divergence.
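The rounding effect can be illustrated without a GPU. The following is a minimal sketch in plain Python, not the actual AdamW kernel: it emulates bf16 and fp32 storage by rounding once per step, and the helper names (to_bf16, to_fp32, run_ema) and the chosen values (beta2 = 0.999, a constant squared gradient of 2.0, a buffer starting at 1.0) are assumptions made for illustration.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x: float) -> float:
    """Round to the nearest bf16 value (round-to-nearest-even on the fp32 bits)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bias = 0x7FFF + ((bits >> 16) & 1)  # ties go to the even bf16 value
    bits = ((bits + bias) & 0xFFFFFFFF) & 0xFFFF0000  # keep the top 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def run_ema(steps: int, beta2: float, grad_sq: float, start: float, store) -> float:
    """Apply v = beta2 * v + (1 - beta2) * grad_sq, rounding to the storage dtype each step."""
    v = start
    for _ in range(steps):
        v = store(beta2 * v + (1.0 - beta2) * grad_sq)
    return v


beta2 = 0.999
v_fp32 = run_ema(5000, beta2, grad_sq=2.0, start=1.0, store=to_fp32)
v_bf16 = run_ema(5000, beta2, grad_sq=2.0, start=1.0, store=to_bf16)

print(f"fp32 buffer after 5000 steps: {v_fp32:.4f}")  # moves toward 2.0
print(f"bf16 buffer after 5000 steps: {v_bf16:.4f}")  # never leaves 1.0 in this setup
```

Each step tries to move the buffer by about 0.001, which is smaller than half the bf16 spacing near 1.0, so the bf16 buffer stays at its starting value while the fp32 buffer tracks the true average. Real training has noisy gradients, so the effect is usually partial rather than a complete stall.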

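Use one of these patterns for long full-parameter training (pre-training or extended fine-tuning). The sections below call the TE FusedAdam pattern Pattern A and the torch AdamW with fp32 storage pattern Pattern B: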

| Pattern | Model storage dtype | Optimizer state | When to use |
| --- | --- | --- | --- |
| TE FusedAdam with bf16 model storage | bf16 | fp32 master weights and fp32 Adam EMA buffers | When TE has a validated memory/runtime path for the model and you want training checkpoints to stay bf16. |
| torch AdamW with model.torch_dtype: float32 | fp32 | fp32 master weight (the resident param) and fp32 Adam EMA buffers | Robust starting point for new or precision-sensitive models: params, gradients, and optimizer state are all fp32. Trade-off: writes fp32 training checkpoints. |

Both patterns keep forward/backward compute in bf16 through FSDP2 mixed precision. They differ in where the fp32 master weight lives, how much peak memory they use for a specific model, and what dtype is written to the training checkpoint.

Pattern A: TE FusedAdam and bf16 Model Storage

Use Transformer Engine FusedAdam when it has been validated for the model. The resident model parameters remain bf16, so model checkpoints can stay bf16, while TE keeps the optimizer’s master weights and Adam EMA buffers in fp32.

```yaml
model:
  torch_dtype: bfloat16  # resident sharded parameter + model checkpoint in bf16

optimizer:
  _target_: transformer_engine.pytorch.optimizers.fused_adam.FusedAdam
  lr: 3.0e-4
  adam_w_mode: true
  bias_correction: true
  master_weights: true
  store_param_remainders: true
  exp_avg_dtype: torch.float32
  exp_avg_sq_dtype: torch.float32

distributed:
  strategy: fsdp2
  # Defaults already provide bf16 forward/backward + fp32 gradient reduction; this block is shown explicitly for clarity.
  mp_policy:
    _target_: torch.distributed.fsdp.MixedPrecisionPolicy
    param_dtype: bfloat16
    reduce_dtype: float32
    output_dtype: bfloat16
    cast_forward_inputs: true
```

TE FusedAdam is the cleanest way to request fp32 optimizer state without making the resident model parameter fp32.

It is a common misconception that TE costs extra memory for a second fp32 master. With store_param_remainders: true (as above), TE does not keep a full extra fp32 master: it stores the bf16 parameter plus a 16-bit remainder that together reconstruct the fp32 master, costing the same ~4 bytes/param as torch AdamW’s fp32 resident parameter. The fp32 Adam EMA buffers (exp_avg, exp_avg_sq) are the same in both patterns, so the two optimizers’ steady-state footprints are essentially equal. The practical difference is then simplicity (torch AdamW: no TE dependency, fewer moving parts) vs. keeping model storage and checkpoints in bf16 (TE).
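To see why a bf16 value plus a 16-bit remainder can stand in for a full fp32 master, consider the sketch below. It splits the 32 bits of an fp32 number into a high half (which is a valid bf16 bit pattern) and a low half, then rejoins them. This truncation-based split and the helper names are assumptions for illustration only; TE's actual encoding may differ, for example in how it rounds the bf16 half.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def split_fp32(x: float) -> tuple[int, int]:
    """Split fp32 bits into (high 16 bits = bf16 pattern, low 16 bits = remainder)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return bits >> 16, bits & 0xFFFF


def join_fp32(high: int, low: int) -> float:
    """Rebuild the fp32 value from the two 16-bit halves."""
    return struct.unpack("<f", struct.pack("<I", (high << 16) | low))[0]


master = to_fp32(0.1234567)            # hypothetical fp32 master weight
high, low = split_fp32(master)

bf16_only = join_fp32(high, 0)         # what the bf16 half alone represents
rebuilt = join_fp32(high, low)         # bf16 half + remainder

print(f"fp32 master : {master!r}")
print(f"bf16 half   : {bf16_only!r}")
print(f"rebuilt     : {rebuilt!r}  (exact match: {rebuilt == master})")
```

The two halves together occupy 4 bytes, the same as one fp32 value, which is the point of the footprint comparison above.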

Where the two can differ is the gradient buffer: TE keeps the resident parameter in bf16, so its gradients are bf16, whereas torch AdamW with fp32 storage keeps parameters and gradients in fp32 (~2 bytes/param more on the gradient buffer). In practice, what we measure is peak memory, which is usually dominated by activations rather than the optimizer step, so depending on the model (fraction of intrinsically-fp32 params, fragmentation, where the peak falls), TE can come out lower, equal, or higher. Validate memory per model before making it the default.
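The per-parameter accounting above can be tallied as follows. This is a nominal sketch using only the dtypes stated on this page (2 bytes for bf16, 4 bytes for fp32); it ignores activations, FSDP sharding, fragmentation, and any intrinsically-fp32 parameters, so it does not predict measured peak memory. The dictionary layout and labels are illustrative.

```python
# Nominal bytes per parameter for each configuration described on this page.
patterns = {
    "A: TE FusedAdam, bf16 storage": {
        "resident param (bf16)": 2,
        "16-bit remainder": 2,
        "exp_avg (fp32)": 4,
        "exp_avg_sq (fp32)": 4,
        "gradient (bf16)": 2,
    },
    "B: torch AdamW, fp32 storage": {
        "resident param (fp32)": 4,
        "exp_avg (fp32)": 4,
        "exp_avg_sq (fp32)": 4,
        "gradient (fp32)": 4,
    },
    "Risky: torch AdamW, bf16 storage": {
        "resident param (bf16)": 2,
        "exp_avg (bf16)": 2,
        "exp_avg_sq (bf16)": 2,
        "gradient (bf16)": 2,
    },
}

totals = {name: sum(parts.values()) for name, parts in patterns.items()}
without_grad = {
    name: sum(size for label, size in parts.items() if not label.startswith("gradient"))
    for name, parts in patterns.items()
}

for name in patterns:
    print(f"{name}: {without_grad[name]} B/param without gradients, {totals[name]} B/param with")
```

Patterns A and B match once gradients are excluded, and differ by the 2 bytes/param of gradient storage. The risky pattern is smaller only because its Adam state is bf16.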

Pattern B: torch AdamW and fp32 Model Storage

```yaml
model:
  torch_dtype: float32  # sharded parameter + AdamW state in fp32

optimizer:
  _target_: torch.optim.AdamW
  lr: 3.0e-4
  betas: [0.9, 0.95]
  weight_decay: 0.1

distributed:
  strategy: fsdp2
  # Defaults already provide bf16 forward/backward + fp32 gradient reduction; this block is shown explicitly for clarity.
  mp_policy:
    _target_: torch.distributed.fsdp.MixedPrecisionPolicy
    param_dtype: bfloat16  # forward/backward compute in bf16 (fast)
    reduce_dtype: float32  # safe gradient reduction
    output_dtype: bfloat16
    cast_forward_inputs: true
```

This is the PyTorch AdamW version of the master-weights pattern. Forward/backward run in bf16, the all-reduce / reduce-scatter runs in fp32, and the optimizer applies updates against fp32 resident parameters. With torch.optim.AdamW, the resident fp32 sharded parameter is the master weight, so there is no separate fp32 master-weight copy.

The trade-off is the dtype of the training checkpoint: AutoModel’s training (DCP) checkpoint stores the resident model parameters, so this pattern writes them in fp32, which you keep for exact resume. This concerns only the training checkpoint; a consolidated HF checkpoint for inference or release is exported separately and follows the model’s intended dtype (for fine-tuning, this matches the original HF checkpoint, typically bf16). See checkpointing for details.

Precision and Robustness

Both patterns keep the master weights and Adam EMA in fp32, so for most models they converge equivalently. The remaining difference is the gradient: under TE (Pattern A) the resident parameter is bf16, so the gradient feeding the Adam update carries bf16 rounding, while torch AdamW with fp32 storage (Pattern B) keeps the parameter, gradient, and optimizer state all in fp32.

This is usually negligible. When bringing up a new model, fp32 storage (Pattern B) is a robust starting point because every part of the update path is fp32; move to TE (Pattern A) after it is validated for that model and you want bf16 checkpoints.

Risky Pattern: torch AdamW with bf16 Model Storage

This pattern is easy to enter accidentally. In several AutoModel paths, leaving model.torch_dtype unset, or setting it to auto, resolves the resident model parameter dtype to bf16. If that is paired with torch.optim.AdamW, the AdamW EMA buffers are also bf16 because the optimizer initializes state from the parameter dtype.

```yaml
model:
  # torch_dtype omitted, or torch_dtype: auto / bfloat16

optimizer:
  _target_: torch.optim.AdamW

distributed:
  strategy: fsdp2
```

This keeps the resident model parameters and the AdamW state in bf16 instead of fp32, so it can reduce memory usage compared with Pattern B (torch AdamW with model.torch_dtype: float32). It is common in existing fine-tuning example configs and is probably acceptable for short fine-tuning runs or LoRA/PEFT. It is not recommended for long full-parameter training (pre-training or extended fine-tuning): bf16 EMA quantization can quietly slow convergence (higher final loss at the same step count) and, in worse cases, produce unstable grad_norm, loss bumps, loss spikes, or divergence.
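Because this combination is easy to enter by omission, a small pre-flight check on the loaded config can help. The sketch below is a hypothetical helper, not part of NeMo AutoModel: it takes the config as a plain dict and conservatively treats an unset, auto, or bfloat16 torch_dtype as bf16 storage, which is an assumption based on the behavior described above and may not hold on every code path.

```python
# Hypothetical pre-flight check; not a NeMo AutoModel API.
# Assumption: an unset or "auto" torch_dtype may resolve to bf16 storage.
BF16_LIKE = {None, "auto", "bf16", "bfloat16", "torch.bfloat16"}


def is_risky_precision(cfg: dict) -> bool:
    """Return True when torch.optim.AdamW is paired with (possibly) bf16 storage."""
    model = cfg.get("model") or {}          # an empty YAML section loads as None
    optimizer = cfg.get("optimizer") or {}
    storage = model.get("torch_dtype")
    target = optimizer.get("_target_", "")
    return storage in BF16_LIKE and target == "torch.optim.AdamW"


risky = {"model": None, "optimizer": {"_target_": "torch.optim.AdamW"}}
pattern_b = {
    "model": {"torch_dtype": "float32"},
    "optimizer": {"_target_": "torch.optim.AdamW"},
}
pattern_a = {
    "model": {"torch_dtype": "bfloat16"},
    "optimizer": {
        "_target_": "transformer_engine.pytorch.optimizers.fused_adam.FusedAdam"
    },
}

for name, cfg in [("risky", risky), ("pattern B", pattern_b), ("pattern A", pattern_a)]:
    print(f"{name}: risky={is_risky_precision(cfg)}")
```

A check like this is only a guard against accidents. For short fine-tuning runs or LoRA/PEFT, where this combination may be an accepted choice, it should warn rather than fail.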

Example Configs