nemo_automodel.components.optim.precision_warnings
Module Contents
Functions
Data
API
Read key from a config node or dict, returning None when absent.
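A minimal sketch of how such a reader could behave, assuming the config is either a plain dict or an attribute-style node. The name `cfg_get` is hypothetical and is not taken from this module.

```python
from types import SimpleNamespace


def cfg_get(cfg, key):
    """Return cfg[key] or cfg.key, or None when the key is absent.

    Hypothetical helper: the real function's name and signature are not
    shown in this reference page.
    """
    if cfg is None:
        return None
    if isinstance(cfg, dict):
        return cfg.get(key)
    # Attribute-style config nodes: fall back to None for a missing attribute.
    return getattr(cfg, key, None)


# Works the same way for both config shapes.
dict_cfg = {"torch_dtype": "bfloat16"}
node_cfg = SimpleNamespace(torch_dtype="bfloat16")
dict_value = cfg_get(dict_cfg, "torch_dtype")
node_value = cfg_get(node_cfg, "torch_dtype")
missing_value = cfg_get(dict_cfg, "_target_")
```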
True when the optimizer _target_ is a built-in torch.optim optimizer.
These optimizers update the resident parameter in place and keep no internal fp32
master copy, so the storage dtype is the master-weight dtype and fp32 storage is
always safe. Optimizers outside the torch.optim namespace (TE FusedAdam,
DeepSpeed, bitsandbytes 8-bit, Muon, …) manage their own master / state precision
and are deliberately excluded.
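The namespace check described above can be sketched as follows. This is an illustration only: the function name `is_torch_optim_target` is hypothetical, and it assumes `_target_` is either a dotted import path or a class object whose `__module__` reveals where it lives.

```python
def is_torch_optim_target(target):
    """True when target points into the torch.optim namespace.

    Hypothetical sketch; accepts a dotted-path string or a class/callable.
    """
    if target is None:
        return False
    if isinstance(target, str):
        path = target
    else:
        module = getattr(target, "__module__", "") or ""
        name = getattr(target, "__qualname__", "") or ""
        path = f"{module}.{name}"
    # Require the trailing dot so look-alike packages such as
    # "torch.optimizers_extra" do not match.
    return path.startswith("torch.optim.")


builtin = is_torch_optim_target("torch.optim.AdamW")
external = is_torch_optim_target("transformer_engine.pytorch.optimizers.FusedAdam")
```

Matching on the namespace rather than on the optimizer's class name keeps third-party optimizers excluded even when they reuse names such as `Adam`.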
Set key on a config node or dict.
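A matching setter might look like the sketch below, again assuming the config is either a dict or an attribute-style node. The name `cfg_set` is hypothetical.

```python
from types import SimpleNamespace


def cfg_set(cfg, key, value):
    """Assign value under key on a dict or an attribute-style config node.

    Hypothetical helper, not the module's actual function.
    """
    if isinstance(cfg, dict):
        cfg[key] = value
    else:
        setattr(cfg, key, value)


dict_cfg = {}
node_cfg = SimpleNamespace()
cfg_set(dict_cfg, "torch_dtype", "float32")
cfg_set(node_cfg, "torch_dtype", "float32")
```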
Default model storage dtype to float32 for full-parameter torch.optim training.
Built-in torch.optim optimizers update the resident parameter in place and keep
no internal fp32 master copy, so the model parameters are the master copy. Storing
them in bf16 therefore makes optimizer updates and state bf16, which degrades training
precision relative to frameworks that keep an fp32 master. To avoid that, when the user
has not explicitly chosen a storage dtype we default cfg_model.torch_dtype to
float32 so the parameters act as the fp32 master copy. fp32 storage is never
numerically worse than bf16 for these optimizers; the only cost is memory, which an
explicit model.torch_dtype=bfloat16 opts out of.
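To illustrate why bf16 storage degrades updates when the parameters are the only copy, the toy sketch below emulates bf16 rounding in pure Python and applies a constant small step to a weight of 1.0. It is not Adam and does not use torch; the helper names `to_bf16` and `to_fp32` and the step size are assumptions chosen for illustration.

```python
import struct


def to_fp32(x):
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16(x):
    """Round a Python float to bf16 (round-to-nearest-even), returned as a float.

    bf16 keeps the top 16 bits of the float32 bit pattern.
    Special values (NaN, inf) are not handled in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + rounding_bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


step = 1e-3  # smaller than half the bf16 spacing near 1.0 (2**-7 / 2)
w_bf16 = 1.0
w_fp32 = 1.0
for _ in range(100):
    w_bf16 = to_bf16(w_bf16 + step)  # update is rounded away every time
    w_fp32 = to_fp32(w_fp32 + step)  # update accumulates

# w_bf16 is still exactly 1.0, while w_fp32 has moved to roughly 1.1.
```

With fp32 storage the same 100 steps are retained, which is the behaviour the float32 default is meant to preserve.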
No-ops (leaving the dtype unchanged) when:
- is_peft is True (base weights are frozen; only small adapters train), or
- the optimizer is not a torch.optim optimizer (e.g. TE FusedAdam, DeepSpeed, or bitsandbytes optimizers, which manage their own master / state precision and so live outside the torch.optim namespace), or
- model.torch_dtype is already set to a concrete (non-auto) value.
The config is mutated on every rank (so all ranks agree), but the decision is logged only on rank zero. The function is idempotent: once the dtype is set, a second call sees the explicit value and returns.
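The decision logic can be sketched as below. This is a simplified illustration, not the module's implementation: the function name `default_storage_dtype` is hypothetical, configs are plain dicts, dtypes are represented as strings rather than torch dtypes, an unset value is assumed to be either missing or the string "auto", and rank-zero logging is omitted.

```python
def default_storage_dtype(cfg_model, cfg_opt, is_peft):
    """Default cfg_model["torch_dtype"] to "float32" when appropriate.

    Returns True if the dtype was defaulted, False if the call was a no-op.
    Hypothetical sketch under the assumptions stated above.
    """
    # No-op 1: PEFT runs freeze the base weights.
    if is_peft:
        return False
    # No-op 2: optimizers outside torch.optim manage their own precision.
    target = cfg_opt.get("_target_")
    if not (isinstance(target, str) and target.startswith("torch.optim.")):
        return False
    # No-op 3: the user already chose a concrete dtype.
    if cfg_model.get("torch_dtype") not in (None, "auto"):
        return False
    cfg_model["torch_dtype"] = "float32"
    return True


model_cfg = {"torch_dtype": "auto"}
opt_cfg = {"_target_": "torch.optim.AdamW"}
first_call = default_storage_dtype(model_cfg, opt_cfg, is_peft=False)
second_call = default_storage_dtype(model_cfg, opt_cfg, is_peft=False)
```

The second call is a no-op because the first call left an explicit value behind, which is the idempotence property described above.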
Parameters:
The model config node/dict (must expose/accept torch_dtype).
The optimizer config node/dict (read for _target_).
Whether this is a PEFT/LoRA run.
Short label used for log de-duplication.
Optional logger; defaults to this module’s logger.
Warn about full-parameter bf16 training with vanilla torch Adam optimizers.