nemo_automodel.components.optim.precision_warnings

Module Contents

Functions

| Name | Description |
| --- | --- |
| _get_cfg_attr | Read key from a config node or dict, returning None when absent. |
| _has_trainable_bf16_param | - |
| _is_rank_zero | - |
| _is_torch_adam_config | - |
| _is_torch_adam_optimizer | - |
| _is_torch_optim_config | True when the optimizer _target_ is a built-in torch.optim optimizer. |
| _iter_optimizer_params | - |
| _set_cfg_attr | Set key on a config node or dict. |
| resolve_storage_dtype | Default model storage dtype to float32 for full-parameter torch.optim training. |
| warn_if_torch_adam_with_bf16_params | Warn about full-parameter bf16 training with vanilla torch Adam optimizers. |

Data

_DTYPE_RESOLVED_CONTEXTS

_TORCH_ADAM_TARGETS

_WARNED_CONTEXTS

API

nemo_automodel.components.optim.precision_warnings._get_cfg_attr(
cfg: typing.Any,
key: str
) -> typing.Any

Read key from a config node or dict, returning None when absent.

nemo_automodel.components.optim.precision_warnings._has_trainable_bf16_param(
parameters: collections.abc.Iterable[torch.nn.Parameter]
) -> bool

nemo_automodel.components.optim.precision_warnings._is_rank_zero() -> bool

nemo_automodel.components.optim.precision_warnings._is_torch_adam_config(
optimizer_cfg: typing.Any | None
) -> bool

nemo_automodel.components.optim.precision_warnings._is_torch_adam_optimizer(
optimizer: torch.optim.Optimizer | collections.abc.Iterable[torch.optim.Optimizer] | None
) -> bool

nemo_automodel.components.optim.precision_warnings._is_torch_optim_config(
optimizer_cfg: typing.Any | None
) -> bool

True when the optimizer _target_ is a built-in torch.optim optimizer.

These optimizers update the resident parameter in place and keep no internal fp32 master copy, so the storage dtype is the master-weight dtype and fp32 storage is always safe. Optimizers outside the torch.optim namespace (TE FusedAdam, DeepSpeed, bitsandbytes 8-bit, Muon, …) manage their own master / state precision and are deliberately excluded.
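As a rough illustration of the namespace rule described above, the sketch below classifies a config by its `_target_` value. The helper name `looks_like_torch_optim` and the assumption that `_target_` is stored as a dotted-path string are hypothetical; the library's own check may differ (for example, it might also accept callables).

```python
from typing import Any


def looks_like_torch_optim(optimizer_cfg: Any) -> bool:
    """Hypothetical stand-in: True when `_target_` is a dotted path under torch.optim."""
    if optimizer_cfg is None:
        return False
    # Accept either a plain dict or an attribute-style config node.
    if isinstance(optimizer_cfg, dict):
        target = optimizer_cfg.get("_target_")
    else:
        target = getattr(optimizer_cfg, "_target_", None)
    # Assumption: the target is a string such as "torch.optim.AdamW".
    return isinstance(target, str) and target.startswith("torch.optim.")


print(looks_like_torch_optim({"_target_": "torch.optim.AdamW"}))            # True
print(looks_like_torch_optim({"_target_": "bitsandbytes.optim.Adam8bit"}))  # False
```

A prefix test like this treats every optimizer outside the `torch.optim` namespace as managing its own precision, which matches the exclusion described above.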

nemo_automodel.components.optim.precision_warnings._iter_optimizer_params(
optimizer: torch.optim.Optimizer | collections.abc.Iterable[torch.optim.Optimizer] | None
) -> collections.abc.Iterable[torch.nn.Parameter]

nemo_automodel.components.optim.precision_warnings._set_cfg_attr(
cfg: typing.Any,
key: str,
value: typing.Any
) -> None

Set key on a config node or dict.
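One way to picture the two config accessors (`_get_cfg_attr` and `_set_cfg_attr`) is as a pair of helpers that work on both plain dicts and attribute-style config nodes. The sketch below is an assumption about that behaviour, not the module's code: `get_cfg_attr` and `set_cfg_attr` are hypothetical names, and `SimpleNamespace` stands in for a config node.

```python
from types import SimpleNamespace
from typing import Any


def get_cfg_attr(cfg: Any, key: str) -> Any:
    """Hypothetical reader: returns None when the key is absent."""
    if cfg is None:
        return None
    if isinstance(cfg, dict):
        return cfg.get(key)
    return getattr(cfg, key, None)


def set_cfg_attr(cfg: Any, key: str, value: Any) -> None:
    """Hypothetical writer: item assignment for dicts, attribute assignment otherwise."""
    if isinstance(cfg, dict):
        cfg[key] = value
    else:
        setattr(cfg, key, value)


dict_cfg = {"_target_": "torch.optim.AdamW"}
node_cfg = SimpleNamespace(lr=1e-4)  # stand-in for an attribute-style config node

set_cfg_attr(dict_cfg, "torch_dtype", "float32")
set_cfg_attr(node_cfg, "torch_dtype", "float32")
print(get_cfg_attr(dict_cfg, "torch_dtype"))  # float32
print(get_cfg_attr(node_cfg, "missing"))      # None
```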

nemo_automodel.components.optim.precision_warnings.resolve_storage_dtype(
cfg_model: typing.Any | None,
cfg_opt: typing.Any | None,
is_peft: bool = False,
context: str = 'recipe',
logger: logging.Logger | None = None
) -> None

Default model storage dtype to float32 for full-parameter torch.optim training.

Built-in torch.optim optimizers update the resident parameter in place and keep no internal fp32 master copy, so the model parameters are the master copy. Storing them in bf16 therefore makes optimizer updates and state bf16, which degrades training precision relative to frameworks that keep an fp32 master. To avoid that, when the user has not explicitly chosen a storage dtype we default cfg_model.torch_dtype to float32 so the parameters act as the fp32 master copy. fp32 storage is never numerically worse than bf16 for these optimizers; the only cost is memory, which an explicit model.torch_dtype=bfloat16 opts out of.
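The precision argument above can be checked with a small stdlib-only experiment. The sketch emulates bfloat16 rounding by hand (it does not use torch), and the step size and starting weight are made-up values chosen so that each update is smaller than half the bf16 spacing near 1.0.

```python
import struct


def round_to_bf16(x: float) -> float:
    """Round a value to the nearest bfloat16 (round-to-nearest-even).

    Sketch only: NaN and infinity are not handled specially.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits of the float32 bit pattern.
    bias = 0x7FFF + ((bits >> 16) & 1)
    bits = (bits + bias) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


step = 1e-3          # hypothetical per-step update (learning rate times gradient)
bf16_weight = 1.0    # parameter stored in bf16 and updated in place
master_weight = 1.0  # higher-precision master copy (a Python float here)

for _ in range(100):
    bf16_weight = round_to_bf16(bf16_weight - step)
    master_weight -= step

print(bf16_weight)    # 1.0: every update is rounded away
print(master_weight)  # close to 0.9
```

With bf16 storage the weight never moves, because each in-place update rounds back to the same representable value, while the higher-precision copy accumulates all 100 steps.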

No-ops (leaving the dtype unchanged) when:

  • is_peft is True (base weights are frozen; only small adapters train), or
  • the optimizer is not a torch.optim optimizer (e.g. TE FusedAdam, DeepSpeed, or bitsandbytes optimizers, which manage their own master / state precision and so live outside the torch.optim namespace), or
  • model.torch_dtype is already set to a concrete (non-auto) value.

The config is mutated on every rank (so all ranks agree), but the decision is logged only on rank zero. The call is idempotent: once the dtype is set, a second call sees the explicit value and returns.
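The no-op rules and the idempotence property can be restated as a short sketch over plain dict configs. This is an illustrative re-statement under assumptions, not the library's implementation: the function name `sketch_resolve_storage_dtype`, the use of the strings "auto" and "float32" as dtype values, and the early return on a missing config are all hypothetical.

```python
from typing import Optional


def sketch_resolve_storage_dtype(
    cfg_model: Optional[dict], cfg_opt: Optional[dict], is_peft: bool = False
) -> None:
    """Illustrative re-statement of the documented rules, for dict configs only."""
    if cfg_model is None or cfg_opt is None:
        return  # assumption: nothing to resolve without both configs
    if is_peft:
        return  # base weights are frozen; leave the dtype alone
    target = cfg_opt.get("_target_")
    if not (isinstance(target, str) and target.startswith("torch.optim.")):
        return  # optimizer manages its own master / state precision
    if cfg_model.get("torch_dtype") not in (None, "auto"):
        return  # a concrete dtype was chosen explicitly
    cfg_model["torch_dtype"] = "float32"  # assumed representation of the default


full = {"torch_dtype": "auto"}
sketch_resolve_storage_dtype(full, {"_target_": "torch.optim.AdamW"})
print(full["torch_dtype"])  # float32

explicit = {"torch_dtype": "bfloat16"}
sketch_resolve_storage_dtype(explicit, {"_target_": "torch.optim.AdamW"})
print(explicit["torch_dtype"])  # bfloat16
```

Calling the sketch a second time on `full` changes nothing, since the third rule now sees a concrete value.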

Parameters:

cfg_model
Any | None

The model config node/dict (must expose/accept torch_dtype).

cfg_opt
Any | None

The optimizer config node/dict (read for _target_).

is_peft
bool, defaults to False

Whether this is a PEFT/LoRA run.

context
str, defaults to 'recipe'

Short label used for log de-duplication.

logger
logging.Logger | None, defaults to None

Optional logger; defaults to this module’s logger.

nemo_automodel.components.optim.precision_warnings.warn_if_torch_adam_with_bf16_params(
optimizer: torch.optim.Optimizer | collections.abc.Iterable[torch.optim.Optimizer] | None = None,
optimizer_cfg: typing.Any | None = None,
parameters: collections.abc.Iterable[torch.nn.Parameter] | None = None,
is_peft: bool = False,
context: str = 'recipe',
logger: logging.Logger | None = None
) -> None

Warn about full-parameter bf16 training with vanilla torch Adam optimizers.

nemo_automodel.components.optim.precision_warnings._DTYPE_RESOLVED_CONTEXTS: set[str] = set()

nemo_automodel.components.optim.precision_warnings._TORCH_ADAM_TARGETS = {'torch.optim.Adam', 'torch.optim.AdamW', 'torch.optim.adam.Adam', 'torch.optim....

nemo_automodel.components.optim.precision_warnings._WARNED_CONTEXTS: set[str] = set()