Mixed-Precision Training

NeMo AutoModel uses FSDP2’s MixedPrecisionPolicy to control the compute precision of the forward and backward passes, and the model’s storage dtype (model.torch_dtype) to control the precision of the resident sharded parameters. The storage dtype also sets the precision of the optimizer state, and for long full-parameter training runs that state precision is a key factor in whether training converges cleanly.

This page describes the precision patterns we recommend and the trap that is easy to fall into between them. For any long full-parameter training run (pre-training or extended fine-tuning), the key rule is: do not combine torch.optim.AdamW with bf16 resident parameters unless you have explicitly accepted bf16 AdamW optimizer state.

Storage Dtype vs. Compute Dtype

Precision settings have distinct effects:

| Setting | Controls | Effect on optimizer state |
|---|---|---|
| model.torch_dtype | Storage dtype of the sharded parameters that PyTorch holds. | The optimizer reads param.data, so the EMA buffers (exp_avg, exp_avg_sq for AdamW) end up in this dtype. |
| mp_policy.param_dtype | Compute dtype FSDP2 casts to during forward/backward. | None directly; this only affects matmul and activation precision. |
| mp_policy.reduce_dtype | Dtype used for the gradient reduce-scatter / all-reduce across DP ranks. | None directly; this only affects how gradients are summed. |
| mp_policy.output_dtype | Dtype FSDP2 casts module outputs to. | None directly; this affects activation tensors, including tensors that cross pipeline-parallel boundaries. |
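The first row of the table can be checked directly. The sketch below is a minimal single-process PyTorch example, assuming only a local PyTorch install (no FSDP2 and no AutoModel involved): it takes one AdamW step on a bf16 parameter and on an fp32 parameter and prints the dtype of the resulting EMA buffers.

```python
import torch

# A single bf16 parameter standing in for a shard stored with model.torch_dtype: bfloat16.
param = torch.nn.Parameter(torch.randn(8, dtype=torch.bfloat16))
optimizer = torch.optim.AdamW([param], lr=1e-3)

# Set a gradient by hand instead of running a backward pass, then take one step.
param.grad = torch.randn_like(param)
optimizer.step()

state = optimizer.state[param]
# AdamW allocates its EMA buffers with zeros_like(param), so they inherit bf16.
print(state["exp_avg"].dtype)     # torch.bfloat16
print(state["exp_avg_sq"].dtype)  # torch.bfloat16

# The same experiment with an fp32 parameter yields fp32 optimizer state.
param32 = torch.nn.Parameter(torch.randn(8, dtype=torch.float32))
optimizer32 = torch.optim.AdamW([param32], lr=1e-3)
param32.grad = torch.randn_like(param32)
optimizer32.step()
print(optimizer32.state[param32]["exp_avg"].dtype)  # torch.float32
```

Nothing in this example sets a compute dtype, which mirrors the table: the state dtype follows the stored parameter, not the precision used for matmuls.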

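When model.torch_dtype: bfloat16 is used with torch.optim.AdamW, the AdamW EMA buffers (exp_avg and exp_avg_sq) are also stored in bf16. This is fragile for long full-parameter training runs. bf16 has only a 7-bit mantissa, so consecutive representable values are roughly 0.4% to 0.8% of the value apart, and an update smaller than about half that gap is rounded away entirely even though the values themselves are well within range. With PyTorch's default beta2 of 0.999, each step moves exp_avg_sq by only 0.1% of the gap between its current value and the new squared gradient, so many of these updates are lost. Symptoms range from silent degradation (slower convergence, i.e., higher final loss at the same step count, with no visible instability) to overt failure: unstable grad_norm, loss bumps, loss spikes, or divergence.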
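The rounding effect can be reproduced without a GPU or PyTorch. The sketch below simulates bf16 storage in plain Python by rounding the float32 bit pattern to its upper 16 bits with round-to-nearest-even, which we assume matches how bf16 tensors store results. It then applies AdamW's second-moment update, exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad², rounding after each in-place operation. The starting value (1.0), the constant squared gradient (1.5), and the step count (1000) are arbitrary illustration choices.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round-to-nearest-even).

    bf16 keeps the sign, the 8 exponent bits, and the top 7 mantissa bits of a
    float32, so rounding the float32 bit pattern to its upper 16 bits simulates
    bf16 storage. NaN and overflow handling are omitted for brevity.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add 0x7FFF plus the lowest kept bit so that exact ties round to even.
    bits += 0x7FFF + ((bits >> 16) & 1)
    bits = (bits >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def ema_sq(v: float, grad_sq: float, beta2: float, steps: int, bf16: bool) -> float:
    """Apply exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad**2 for `steps` steps."""
    rnd = to_bf16 if bf16 else (lambda t: t)
    for _ in range(steps):
        v = rnd(v * beta2)                    # like exp_avg_sq.mul_(beta2)
        v = rnd(v + (1.0 - beta2) * grad_sq)  # like .addcmul_(grad, grad, value=1 - beta2)
    return v


v_full = ema_sq(1.0, grad_sq=1.5, beta2=0.999, steps=1000, bf16=False)
v_bf16 = ema_sq(1.0, grad_sq=1.5, beta2=0.999, steps=1000, bf16=True)
print(f"high-precision reference: {v_full:.4f}")  # 1.3162
print(f"bf16-stored state:        {v_bf16:.4f}")  # 1.0000
```

In the high-precision run the second moment moves most of the way toward 1.5, while the bf16 copy never leaves 1.0: each 0.1% step is smaller than half the bf16 spacing near 1.0 (2^-7), so it is rounded back every time.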

Risky Pattern: torch AdamW with bf16 Model Storage

This pattern is easy to fall into by accident. In several AutoModel paths, leaving model.torch_dtype unset or setting it to auto resolves the resident model parameter dtype to bf16. If that is paired with torch.optim.AdamW, the AdamW EMA buffers are also bf16, because the optimizer initializes its state from the parameter dtype.

```yaml
model:
  # torch_dtype omitted, or torch_dtype: auto / bfloat16

optimizer:
  _target_: torch.optim.AdamW

distributed:
  strategy: fsdp2
```
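The rule from the top of this page can be expressed as a small config check. The sketch below is a hypothetical helper (uses_bf16_adamw_state is not an AutoModel API) that operates on a plain dict shaped like the YAML above. It assumes that an omitted or auto torch_dtype resolves to bf16, which the page says happens in several AutoModel paths but not necessarily all of them.

```python
from typing import Any, Mapping

# Assumption: on the code path in use, an omitted or "auto" torch_dtype resolves
# to bf16. The page notes this happens in several AutoModel paths, not all.
BF16_RESIDENT_VALUES = {None, "auto", "bfloat16"}


def uses_bf16_adamw_state(cfg: Mapping[str, Any]) -> bool:
    """Hypothetical lint: True if cfg pairs torch.optim.AdamW with bf16 storage.

    Not part of NeMo AutoModel; cfg is a plain dict shaped like the YAML above.
    A "model:" key followed only by a comment parses to None, hence the `or {}`.
    """
    optimizer = cfg.get("optimizer") or {}
    model = cfg.get("model") or {}
    return (
        optimizer.get("_target_") == "torch.optim.AdamW"
        and model.get("torch_dtype") in BF16_RESIDENT_VALUES
    )


risky = {
    "model": None,  # torch_dtype omitted
    "optimizer": {"_target_": "torch.optim.AdamW"},
    "distributed": {"strategy": "fsdp2"},
}
fp32_storage = {
    "model": {"torch_dtype": "float32"},
    "optimizer": {"_target_": "torch.optim.AdamW"},
}
print(uses_bf16_adamw_state(risky))         # True
print(uses_bf16_adamw_state(fp32_storage))  # False
```

A check like this only flags the combination; whether bf16 AdamW state is acceptable still depends on run length and whether the run is full-parameter, as described below.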

Because the resident model parameters, and therefore the AdamW state, stay in bf16 instead of fp32, this pattern can use less memory than the torch AdamW + model.torch_dtype: float32 pattern. It is common in existing fine-tuning example configs and is likely acceptable for short fine-tuning runs or for LoRA / PEFT. It is not recommended for long full-parameter training (pre-training or extended fine-tuning): bf16 EMA quantization can quietly slow convergence (higher final loss at the same step count) and, in worse cases, produce unstable grad_norm, loss bumps, loss spikes, or divergence.
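To put the memory trade-off in rough numbers, the sketch below counts only the three tensors whose dtype follows model.torch_dtype: the parameter and AdamW's exp_avg and exp_avg_sq. It assumes all three are sharded evenly across FSDP2 ranks and ignores gradients, activations, the step counter, and sharding padding; the 8B-parameter model and 8 ranks are hypothetical.

```python
def resident_bytes(num_params: int, storage_bytes: int, world_size: int = 1) -> float:
    """Approximate per-rank bytes for parameters + AdamW exp_avg + exp_avg_sq.

    Assumes all three tensors use the storage dtype (model.torch_dtype) and are
    evenly sharded across `world_size` ranks. Gradients, activations, the step
    counter, and sharding padding are ignored.
    """
    tensors_per_param = 3  # parameter, exp_avg, exp_avg_sq
    return num_params * storage_bytes * tensors_per_param / world_size


NUM_PARAMS = 8_000_000_000  # hypothetical 8B-parameter model
for name, nbytes in [("float32", 4), ("bfloat16", 2)]:
    total = resident_bytes(NUM_PARAMS, nbytes)
    per_rank = resident_bytes(NUM_PARAMS, nbytes, world_size=8)
    print(f"{name}: {total / 1e9:.0f} GB total, {per_rank / 1e9:.0f} GB per rank on 8 ranks")
# float32: 96 GB total, 12 GB per rank on 8 ranks
# bfloat16: 48 GB total, 6 GB per rank on 8 ranks
```

Under these assumptions bf16 storage halves this part of the footprint, and that saving is exactly what is traded against the EMA precision loss described above.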

Example Configs