nemo_automodel.components.optim.optimizer

Typed optimizer + LR scheduler configs (TorchTitan-style).

Each optimizer config is a plain dataclass exposing the full parameter surface as named fields (no opaque **kwargs). Reading the dataclass tells you exactly what you can configure.

Every config owns its own construction via config.build(model, ...), which loops over model.parts and applies the shared per-part concern (disabling foreach under tensor parallelism). Subclasses implement only the small _build_optimizer(params) hook; configs with bespoke construction needs (e.g. the Dion parameter grouping of :class:MuonConfig) override build directly. Megatron-FSDP optimizer sharding is not applied by build; the recipe layer applies it separately.
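The split between the shared build loop and the per-part hook can be sketched without torch. This is a minimal illustration, not the library's code: ToyConfig, ToySGDConfig, the params attribute, and the dict standing in for an optimizer are all hypothetical. Only the build / _build_optimizer division and the loop over model.parts come from the description above.

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class ToyConfig:
    """Hypothetical stand-in for a base optimizer config."""

    lr: float = 1e-4

    def _build_optimizer(self, params, foreach=None):
        # Per-part hook: subclasses turn a flat parameter list into one optimizer.
        raise NotImplementedError

    def build(self, model):
        # Shared orchestration: one optimizer per model part, or one for the
        # whole model when it has no `.parts` attribute.
        parts = getattr(model, "parts", None) or [model]
        return [self._build_optimizer(list(part.params)) for part in parts]


@dataclass
class ToySGDConfig(ToyConfig):
    momentum: float = 0.9

    def _build_optimizer(self, params, foreach=None):
        # A dict stands in for a real optimizer object in this sketch.
        return {"params": params, "lr": self.lr, "momentum": self.momentum}


# Two hypothetical model parts, e.g. pipeline stages.
part_a = SimpleNamespace(params=["w1", "w2"])
part_b = SimpleNamespace(params=["w3"])
pipeline_model = SimpleNamespace(parts=[part_a, part_b])

optimizers = ToySGDConfig(lr=3e-4).build(pipeline_model)
```

Because every field is a named dataclass attribute, a subclass only has to forward its own fields inside the hook; the loop over parts is written once in the base class.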

:func:build_optimizer is a thin dispatcher: it normalizes its config argument to an :class:OptimizerConfig and returns config.build(model, ...). The argument is either:

  • a typed :class:OptimizerConfig instance — the Automodel-native path; or
  • a (name_or_path, kwargs) tuple, where name_or_path is a short registry name ("adam", "adamw", "muon", …) or a dotted import path ("torch.optim.AdamW"). The name is resolved and then constructed with kwargs: a name that resolves to an :class:OptimizerConfig subclass is built from its typed fields; any other callable is wrapped in an :class:OptimizerFromFactoryConfig, the escape hatch for external integrations (e.g. veRL).

Module Contents

Classes

| Name | Description |
| --- | --- |
| AdamConfig | torch.optim.Adam. |
| AdamWConfig | torch.optim.AdamW. |
| Dion2Config | dion.Dion2, the recommended successor to the legacy Dion optimizer. |
| DionConfig | dion.Dion, the legacy low-rank optimizer (prefer :class:Dion2Config). |
| FlashAdamWConfig | flashoptim.FlashAdamW. |
| FusedAdamConfig | transformer_engine.pytorch.optimizers.FusedAdam. |
| LRSchedulerConfig | LR scheduler configuration. None fields are computed by :meth:build from the training schedule. |
| MuonConfig | dion.Muon: matrix-aware update for 2D+ params, scalar fallback for 1D. |
| NorMuonConfig | dion.NorMuon: Muon variant with neuron-wise normalization. |
| OptimizerConfig | Base optimizer config. |
| OptimizerFromFactoryConfig | Build an optimizer from an arbitrary factory callable plus kwargs. |
| _DionConfigBase | Shared base for the dion-family typed configs (Muon / NorMuon / Dion2 / Dion). |

Functions

| Name | Description |
| --- | --- |
| _factory_accepts_foreach | Return True if factory accepts a foreach kwarg. |
| _foreach_for_mesh | Return False when TP > 1 (foreach is unsupported), else None. |
| _import_from_path | Import an object from a dotted path, e.g. "torch.optim.AdamW". |
| build_optimizer | Build one optimizer per model.parts (or [model]). |
| build_optimizer_config | Normalize an optimizer target plus kwargs into an :class:OptimizerConfig. |

Data

OPTIMIZER_CONFIG_REGISTRY

_DION_CONFIG_FOR

_DION_GROUPING_FIELDS

_DTYPE_FIELDS

__all__

logger

API

class nemo_automodel.components.optim.optimizer.AdamConfig(
lr: float = 0.0001,
weight_decay: float = 0.01,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-08,
amsgrad: bool = False
)
Dataclass

Bases: OptimizerConfig

torch.optim.Adam.

amsgrad
bool = False
betas
tuple[float, float] = (0.9, 0.999)
eps
float = 1e-08
lr
float = 0.0001
weight_decay
float = 0.01
nemo_automodel.components.optim.optimizer.AdamConfig._build_optimizer(
params,
foreach: bool | None = None
) -> torch.optim.Optimizer
class nemo_automodel.components.optim.optimizer.AdamWConfig(
lr: float = 0.0001,
weight_decay: float = 0.01,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-08,
amsgrad: bool = False,
fused: bool = False
)
Dataclass

Bases: OptimizerConfig

torch.optim.AdamW.

amsgrad
bool = False
betas
tuple[float, float] = (0.9, 0.999)
eps
float = 1e-08
fused
bool = False
lr
float = 0.0001
weight_decay
float = 0.01
nemo_automodel.components.optim.optimizer.AdamWConfig._build_optimizer(
params,
foreach: bool | None = None
) -> torch.optim.Optimizer
class nemo_automodel.components.optim.optimizer.Dion2Config(
lr: float = 0.0005,
weight_decay: float = 0.0,
scalar_opt: str = 'adamw',
scalar_betas: tuple[float, float] = (0.9, 0.999),
scalar_eps: float = 1e-08,
scalar_lr: float | None = None,
embed_lr: float | None = None,
lm_head_lr: float | None = None,
fraction: float = 0.25,
ef_decay: float = 0.95,
betas: tuple[float, float] = (0.9, 0.95),
epsilon: float = 1e-08,
adjust_lr: str = 'spectral_norm'
)
Dataclass

Bases: _DionConfigBase

dion.Dion2 — recommended successor to the legacy Dion optimizer.

adjust_lr
str = 'spectral_norm'
betas
tuple[float, float] = (0.9, 0.95)
ef_decay
float = 0.95
epsilon
float = 1e-08
fraction
float = 0.25
nemo_automodel.components.optim.optimizer.Dion2Config._make_optimizer(
param_groups,
ctor_kwargs
)
class nemo_automodel.components.optim.optimizer.DionConfig(
lr: float = 0.0005,
weight_decay: float = 0.0,
scalar_opt: str = 'adamw',
scalar_betas: tuple[float, float] = (0.9, 0.999),
scalar_eps: float = 1e-08,
scalar_lr: float | None = None,
embed_lr: float | None = None,
lm_head_lr: float | None = None,
mu: float = 0.95,
betas: tuple[float, float] = (0.9, 0.95),
epsilon: float = 1e-08,
rank_fraction: float = 1.0,
rank_multiple_of: int = 1,
power_iters: int = 1,
qr_method: str = 'rcqr'
)
Dataclass

Bases: _DionConfigBase

dion.Dion — legacy low-rank optimizer (prefer :class:Dion2Config).

Legacy Dion takes separate replicate/outer/inner shard meshes; for FSDP2 the resolved 1-D shard submesh maps to outer_shard_mesh.

_mesh_kwarg
str = 'outer_shard_mesh'
betas
tuple[float, float] = (0.9, 0.95)
epsilon
float = 1e-08
mu
float = 0.95
power_iters
int = 1
qr_method
str = 'rcqr'
rank_fraction
float = 1.0
rank_multiple_of
int = 1
nemo_automodel.components.optim.optimizer.DionConfig._make_optimizer(
param_groups,
ctor_kwargs
)
class nemo_automodel.components.optim.optimizer.FlashAdamWConfig(
lr: float = 0.0001,
weight_decay: float = 0.01,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-08,
master_weight_bits: int = 24
)
Dataclass

Bases: OptimizerConfig

flashoptim.FlashAdamW.

betas
tuple[float, float] = (0.9, 0.999)
eps
float = 1e-08
lr
float = 0.0001
master_weight_bits
int = 24
weight_decay
float = 0.01
nemo_automodel.components.optim.optimizer.FlashAdamWConfig._build_optimizer(
params,
foreach: bool | None = None
) -> torch.optim.Optimizer
class nemo_automodel.components.optim.optimizer.FusedAdamConfig(
lr: float = 0.0001,
weight_decay: float = 0.01,
betas: tuple[float, float] = (0.9, 0.999),
eps: float = 1e-08,
adam_w_mode: bool = True,
bias_correction: bool = True,
master_weights: bool = True,
master_weight_dtype: str | None = None
)
Dataclass

Bases: OptimizerConfig

transformer_engine.pytorch.optimizers.FusedAdam.

adam_w_mode
bool = True
betas
tuple[float, float] = (0.9, 0.999)
bias_correction
bool = True
eps
float = 1e-08
lr
float = 0.0001
master_weight_dtype
str | None = None
master_weights
bool = True
weight_decay
float = 0.01
nemo_automodel.components.optim.optimizer.FusedAdamConfig._build_optimizer(
params,
foreach: bool | None = None
) -> torch.optim.Optimizer
class nemo_automodel.components.optim.optimizer.LRSchedulerConfig(
lr_warmup_steps: int | None = None,
lr_decay_steps: int | None = None,
lr_decay_style: str = 'cosine',
init_lr: float | None = None,
max_lr: float | None = None,
min_lr: float | None = None,
start_wd: float | None = None,
end_wd: float | None = None,
wd_incr_steps: int | None = None,
wd_incr_style: str = 'constant',
use_checkpoint_opt_param_scheduler: bool = True,
override_opt_param_scheduler: bool = False,
wsd_decay_steps: int | None = None,
lr_wsd_decay_style: str | None = None
)
Dataclass

LR scheduler configuration. None fields are computed by :meth:build from the training schedule (total steps, optimizer base LR/WD).

end_wd
float | None = None
init_lr
float | None = None
lr_decay_steps
int | None = None
lr_decay_style
str = 'cosine'
lr_warmup_steps
int | None = None
lr_wsd_decay_style
str | None = None
max_lr
float | None = None
min_lr
float | None = None
override_opt_param_scheduler
bool = False
start_wd
float | None = None
use_checkpoint_opt_param_scheduler
bool = True
wd_incr_steps
int | None = None
wd_incr_style
str = 'constant'
wsd_decay_steps
int | None = None
nemo_automodel.components.optim.optimizer.LRSchedulerConfig.build(
optimizer: list[torch.optim.Optimizer] | torch.optim.Optimizer,
step_scheduler: nemo_automodel.components.training.step_scheduler.StepScheduler
) -> list[nemo_automodel.components.optim.scheduler.OptimizerParamScheduler]

Build one LR scheduler per optimizer.

None fields are filled from the training schedule and each optimizer’s base LR/WD.

Parameters:

optimizer
list[torch.optim.Optimizer] | torch.optim.Optimizer

The optimizer(s) to schedule.

step_scheduler
StepScheduler

The step scheduler, used to derive total steps.

Returns: list[OptimizerParamScheduler]

One :class:OptimizerParamScheduler per optimizer.
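The idea of filling None fields from the training schedule can be illustrated with a small stand-alone sketch. Everything here is hypothetical: ToySchedule carries only a subset of the fields, and the fallback rules (warmup of 0, decay over all steps, max_lr equal to the optimizer's base LR, min_lr of 0) are assumptions chosen for illustration, not the library's documented behavior.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass
class ToySchedule:
    """Hypothetical subset of the scheduler fields."""

    lr_warmup_steps: Optional[int] = None
    lr_decay_steps: Optional[int] = None
    max_lr: Optional[float] = None
    min_lr: Optional[float] = None


def resolve(cfg, total_steps, base_lr):
    """Return a copy of cfg with every None field replaced by a fallback."""
    # Assumed fallbacks, for illustration only.
    defaults = {
        "lr_warmup_steps": 0,
        "lr_decay_steps": total_steps,
        "max_lr": base_lr,
        "min_lr": 0.0,
    }
    # Only fields the user left as None are filled; explicit values win.
    missing = {k: v for k, v in defaults.items() if getattr(cfg, k) is None}
    return replace(cfg, **missing)


resolved = resolve(ToySchedule(min_lr=1e-5), total_steps=1000, base_lr=3e-4)
```

The point of the pattern is that an explicit value always takes precedence, so a config can be partially specified and completed once the schedule is known.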

class nemo_automodel.components.optim.optimizer.MuonConfig(
lr: float = 0.0005,
weight_decay: float = 0.0,
scalar_opt: str = 'adamw',
scalar_betas: tuple[float, float] = (0.9, 0.999),
scalar_eps: float = 1e-08,
scalar_lr: float | None = None,
embed_lr: float | None = None,
lm_head_lr: float | None = None,
mu: float = 0.95,
betas: tuple[float, float] = (0.9, 0.95),
epsilon: float = 1e-08,
adjust_lr: str = 'spectral_norm',
nesterov: bool = False,
flatten: bool = False,
use_triton: bool = False
)
Dataclass

Bases: _DionConfigBase

dion.Muon — matrix-aware update for 2D+ params, scalar fallback for 1D.

adjust_lr
str = 'spectral_norm'
betas
tuple[float, float] = (0.9, 0.95)
epsilon
float = 1e-08
flatten
bool = False
mu
float = 0.95
nesterov
bool = False
use_triton
bool = False
nemo_automodel.components.optim.optimizer.MuonConfig._make_optimizer(
param_groups,
ctor_kwargs
)
class nemo_automodel.components.optim.optimizer.NorMuonConfig(
lr: float = 0.0005,
weight_decay: float = 0.0,
scalar_opt: str = 'adamw',
scalar_betas: tuple[float, float] = (0.9, 0.999),
scalar_eps: float = 1e-08,
scalar_lr: float | None = None,
embed_lr: float | None = None,
lm_head_lr: float | None = None,
mu: float = 0.95,
muon_beta2: float = 0.95,
betas: tuple[float, float] = (0.9, 0.95),
epsilon: float = 1e-08,
adjust_lr: str = 'spectral_norm'
)
Dataclass

Bases: _DionConfigBase

dion.NorMuon — Muon variant with neuron-wise normalization.

adjust_lr
str = 'spectral_norm'
betas
tuple[float, float] = (0.9, 0.95)
epsilon
float = 1e-08
mu
float = 0.95
muon_beta2
float = 0.95
nemo_automodel.components.optim.optimizer.NorMuonConfig._make_optimizer(
param_groups,
ctor_kwargs
)
class nemo_automodel.components.optim.optimizer.OptimizerConfig()
Dataclass

Base optimizer config.

Subclasses expose their full field surface and implement :meth:_build_optimizer, the per-part hook that constructs a single optimizer from a list of parameters. :meth:build owns the shared orchestration (per-part loop, TP foreach) and is rarely overridden — only by configs whose construction does not fit the parameters -> optimizer shape (e.g. :class:MuonConfig). Megatron-FSDP optimizer sharding is no longer applied here; the recipe layer re-applies it via shard_optimizers_for_megatron_fsdp(...).

supports_megatron_fsdp_sharding
bool = True
nemo_automodel.components.optim.optimizer.OptimizerConfig._build_optimizer(
params,
foreach: bool | None = None
) -> torch.optim.Optimizer

Construct a single optimizer for params (one model part).

nemo_automodel.components.optim.optimizer.OptimizerConfig.build(
model: torch.nn.Module,
device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
is_peft: bool = False
) -> list[torch.optim.Optimizer]

Build one optimizer per model.parts (or [model]).

Applies the shared per-part concern (TP foreach disabling) and delegates the actual optimizer instantiation to :meth:_build_optimizer. Megatron-FSDP optimizer sharding is applied by the recipe layer, not here.

Parameters:

model
torch.nn.Module

Model (or model with .parts) to optimize.

device_mesh
DeviceMesh | None. Defaults to None

Device mesh used for tensor/data parallelism.

is_peft
bool. Defaults to False

Whether the model is being trained with PEFT (suppresses the bf16 torch-Adam precision warning).

Returns: list[torch.optim.Optimizer]

One optimizer per model part.

nemo_automodel.components.optim.optimizer.OptimizerConfig.build_from_param_groups(
param_groups: list[dict[str, typing.Any]],
device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None
) -> torch.optim.Optimizer

Build one optimizer from caller-defined parameter groups.

class nemo_automodel.components.optim.optimizer.OptimizerFromFactoryConfig(
factory: collections.abc.Callable[..., torch.optim.Optimizer] | None = None,
kwargs: dict[str, typing.Any] = dict()
)
Dataclass

Bases: OptimizerConfig

Build an optimizer from an arbitrary factory callable plus kwargs.

The integration escape hatch (e.g. veRL): rather than exposing typed fields, it wraps an optimizer class/callable and the **kwargs to construct it. This keeps the factory path on the same config.build(model, ...) contract as the typed configs, so :func:build_optimizer never has to special-case it.

Hyperparameters live in :attr:kwargs; the inherited lr/weight_decay fields are unused. The factory is called as factory(params=..., **kwargs); Dion-family optimizers (which need parameter grouping) should use the typed :class:MuonConfig instead.
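The calling convention factory(params=..., **kwargs) can be sketched as follows. ToyFactoryConfig, build_one, and toy_optimizer are hypothetical names for this illustration; the real class builds through build(model, ...) and handles concerns not shown here.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Optional


@dataclass
class ToyFactoryConfig:
    """Hypothetical stand-in: wraps a callable and the kwargs to call it with."""

    factory: Optional[Callable[..., Any]] = None
    kwargs: dict = field(default_factory=dict)

    def build_one(self, params):
        if self.factory is None:
            raise ValueError("factory must be set before building")
        # Parameters are passed by keyword; hyperparameters come only from kwargs.
        return self.factory(params=params, **self.kwargs)


def toy_optimizer(params, lr=1e-3, weight_decay=0.0):
    # Stand-in for an external optimizer class.
    return {"params": list(params), "lr": lr, "weight_decay": weight_decay}


cfg = ToyFactoryConfig(factory=toy_optimizer, kwargs={"lr": 2e-5})
opt = cfg.build_one(["w", "b"])
```

Since the parameters are passed by keyword, a wrapped factory needs a parameter literally named params for this convention to work.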

factory
Callable[..., Optimizer] | None = None
kwargs
dict[str, Any] = field(default_factory=dict)
nemo_automodel.components.optim.optimizer.OptimizerFromFactoryConfig.build(
model: torch.nn.Module,
device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
is_peft: bool = False
) -> list[torch.optim.Optimizer]
nemo_automodel.components.optim.optimizer.OptimizerFromFactoryConfig.build_from_param_groups(
param_groups: list[dict[str, typing.Any]],
device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None
) -> torch.optim.Optimizer
class nemo_automodel.components.optim.optimizer._DionConfigBase(
lr: float = 0.0005,
weight_decay: float = 0.0,
scalar_opt: str = 'adamw',
scalar_betas: tuple[float, float] = (0.9, 0.999),
scalar_eps: float = 1e-08,
scalar_lr: float | None = None,
embed_lr: float | None = None,
lm_head_lr: float | None = None
)
Dataclass

Bases: OptimizerConfig

Shared base for the dion-family typed configs (Muon / NorMuon / Dion2 / Dion).

Dion optimizers need Dion’s parameter grouping (built from the model) and the device mesh rather than a flat parameter list, so :meth:build runs grouping per model part. The grouping-only fields below (scalar_* / *_lr) are consumed by :func:build_dion_optimizer and stripped from the constructor kwargs. Dion is incompatible with Megatron-FSDP optimizer sharding; this is enforced at the recipe layer (supports_megatron_fsdp_sharding = False drives an allow=False sharding call that asserts rather than silently returning an unsharded optimizer).
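Separating grouping-only fields from constructor kwargs can be sketched with plain dataclasses. ToyMuonConfig, GROUPING_FIELDS, and split_fields are hypothetical names; the set below simply follows the scalar_* / *_lr pattern named above and is not a copy of the module's own constant.

```python
from dataclasses import asdict, dataclass
from typing import Optional

# Hypothetical set for this sketch, following the scalar_* / *_lr pattern.
GROUPING_FIELDS = frozenset(
    {"scalar_opt", "scalar_betas", "scalar_eps", "scalar_lr", "embed_lr", "lm_head_lr"}
)


@dataclass
class ToyMuonConfig:
    lr: float = 5e-4
    weight_decay: float = 0.0
    scalar_opt: str = "adamw"
    scalar_betas: tuple = (0.9, 0.999)
    scalar_eps: float = 1e-8
    scalar_lr: Optional[float] = None
    embed_lr: Optional[float] = None
    lm_head_lr: Optional[float] = None
    mu: float = 0.95


def split_fields(cfg):
    """Separate grouping-only fields from optimizer constructor kwargs."""
    values = asdict(cfg)
    grouping = {k: v for k, v in values.items() if k in GROUPING_FIELDS}
    ctor_kwargs = {k: v for k, v in values.items() if k not in GROUPING_FIELDS}
    return grouping, ctor_kwargs


grouping, ctor_kwargs = split_fields(ToyMuonConfig(embed_lr=1e-3))
```

In this sketch the grouping dict would steer how parameters are bucketed, while ctor_kwargs holds only what the optimizer constructor itself accepts.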

_mesh_kwarg
str = 'distributed_mesh'
embed_lr
float | None = None
lm_head_lr
float | None = None
lr
float = 0.0005
scalar_betas
tuple[float, float] = (0.9, 0.999)
scalar_eps
float = 1e-08
scalar_lr
float | None = None
scalar_opt
str = 'adamw'
supports_megatron_fsdp_sharding
bool = False
weight_decay
float = 0.0
nemo_automodel.components.optim.optimizer._DionConfigBase._make_optimizer(
param_groups: typing.Any,
ctor_kwargs: dict[str, typing.Any]
) -> torch.optim.Optimizer

Instantiate the concrete dion optimizer from grouped params + filtered kwargs.

nemo_automodel.components.optim.optimizer._DionConfigBase.build(
model: torch.nn.Module,
device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
is_peft: bool = False
) -> list[torch.optim.Optimizer]
nemo_automodel.components.optim.optimizer._factory_accepts_foreach(
factory: collections.abc.Callable[..., typing.Any]
) -> bool

Return True if factory accepts a foreach kwarg.

torch.optim optimizers take foreach; external factories such as TE FusedAdam do not, so passing it would raise TypeError.
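One way to perform such a check is signature introspection. The sketch below is an assumption about the approach, not the module's implementation: accepts_foreach and the two toy factories are hypothetical, and the sketch only recognizes an explicitly named foreach parameter (a bare **kwargs does not count).

```python
import inspect


def accepts_foreach(factory):
    """Return True if `factory` declares a parameter named `foreach`."""
    try:
        params = inspect.signature(factory).parameters
    except (TypeError, ValueError):
        # Some builtins and C extensions expose no introspectable signature.
        return False
    return "foreach" in params


def with_foreach(params, lr=1e-3, foreach=None):
    return ("with", foreach)


def without_foreach(params, lr=1e-3):
    return ("without",)


# Only forward `foreach` when the factory can take it.
extra = {"foreach": False} if accepts_foreach(without_foreach) else {}
result = without_foreach(["w"], **extra)
```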

nemo_automodel.components.optim.optimizer._foreach_for_mesh(
device_mesh: torch.distributed.device_mesh.DeviceMesh | None
) -> bool | None

Return False when TP > 1 (foreach is unsupported), else None.
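The three-valued return can be sketched with a plain mapping in place of a DeviceMesh. The function name, the mapping, and the "tp" dimension name are assumptions for illustration; how the real helper reads the TP size from the mesh is not shown in this reference.

```python
from typing import Mapping, Optional


def foreach_for_tp(mesh_sizes: Optional[Mapping[str, int]]) -> Optional[bool]:
    """Return False when the 'tp' dimension is larger than 1, else None."""
    if mesh_sizes is None:
        # No mesh at all: leave the choice to the optimizer.
        return None
    return False if mesh_sizes.get("tp", 1) > 1 else None


no_mesh = foreach_for_tp(None)
data_parallel_only = foreach_for_tp({"dp": 8})
tensor_parallel = foreach_for_tp({"dp": 2, "tp": 4})
```

Returning None rather than True matters: torch.optim optimizers treat foreach=None as "pick the default implementation", so the helper only ever forces the flag off.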

nemo_automodel.components.optim.optimizer._import_from_path(
path: str
) -> typing.Any

Import an object from a dotted path, e.g. "torch.optim.AdamW".
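A dotted-path import is commonly written with importlib, as in the sketch below. This is a generic illustration under the assumption that the last path component is an attribute of the module named by the rest; import_from_path here is a hypothetical re-implementation, and the stdlib paths are used only so the example runs without torch.

```python
import importlib


def import_from_path(path: str):
    """Import `pkg.module.attr` and return `attr`."""
    module_name, _, attr = path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected a dotted path, got {path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr)


ordered_dict_cls = import_from_path("collections.OrderedDict")
join = import_from_path("os.path.join")
```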

nemo_automodel.components.optim.optimizer.build_optimizer(
model: torch.nn.Module,
config: nemo_automodel.components.optim.optimizer.OptimizerConfig | tuple[str, dict[str, typing.Any]],
device_mesh: torch.distributed.device_mesh.DeviceMesh | None = None
) -> list[torch.optim.Optimizer]

Build one optimizer per model.parts (or [model]).

Thin dispatcher: it normalizes config to an :class:OptimizerConfig and returns config.build(model, ...). Per-part concerns (TP foreach, Dion param grouping) live on the config. Megatron-FSDP optimizer sharding is re-applied separately by the recipe layer.

config is one of:

  • a typed :class:OptimizerConfig instance — the Automodel-native path.
  • a (name_or_path, kwargs) tuple, where name_or_path is a short registry name (see :data:OPTIMIZER_CONFIG_REGISTRY, e.g. "adamw") or a dotted import path (e.g. "torch.optim.AdamW"), and kwargs are the constructor arguments. A registry/import-path that resolves to an :class:OptimizerConfig subclass is built from its typed fields; any other callable is wrapped in an :class:OptimizerFromFactoryConfig (the escape hatch for external integrations, e.g. veRL).

Parameters:

model
torch.nn.Module

Model (or model with .parts) to optimize.

config
OptimizerConfig | tuple[str, dict[str, Any]]

An :class:OptimizerConfig instance or a (name_or_path, kwargs) tuple.

device_mesh
DeviceMesh | None. Defaults to None

Device mesh used for tensor/data parallelism.

Returns: list[torch.optim.Optimizer]

One optimizer per model part.

nemo_automodel.components.optim.optimizer.build_optimizer_config(
target: nemo_automodel.components.optim.optimizer.OptimizerConfig | str | type[nemo_automodel.components.optim.optimizer.OptimizerConfig] | collections.abc.Callable[..., torch.optim.Optimizer],
kwargs: dict[str, typing.Any] | None = None
) -> nemo_automodel.components.optim.optimizer.OptimizerConfig

Normalize an optimizer target plus kwargs into an :class:OptimizerConfig.

This is the single normalization entry point shared by the recipe layer (which resolves a YAML _target_ to a Python object) and :func:build_optimizer (which accepts (name_or_path, kwargs) tuples).

target is one of:

  • an :class:OptimizerConfig instance — returned as-is (kwargs ignored, since the instance already carries its typed fields).
  • an :class:OptimizerConfig subclass — instantiated from its typed fields with **kwargs.
  • a string — a registry short name (see :data:OPTIMIZER_CONFIG_REGISTRY, e.g. "adamw") or a dotted import path (e.g. "torch.optim.AdamW"); it is resolved and then handled as a subclass or callable.
  • any other optimizer callable/class — wrapped in an :class:OptimizerFromFactoryConfig (the escape hatch for external integrations, e.g. veRL).
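The four cases above amount to a small dispatch function. The sketch below mirrors that order with toy classes; ToyConfig, ToyAdamWConfig, ToyFactoryConfig, TOY_REGISTRY, and normalize are hypothetical names, and error handling for unknown names is left out.

```python
import importlib
from dataclasses import dataclass, field
from typing import Any, Callable, Optional


@dataclass
class ToyConfig:
    """Hypothetical base class for typed configs."""


@dataclass
class ToyAdamWConfig(ToyConfig):
    lr: float = 1e-4


@dataclass
class ToyFactoryConfig(ToyConfig):
    factory: Optional[Callable[..., Any]] = None
    kwargs: dict = field(default_factory=dict)


TOY_REGISTRY = {"adamw": ToyAdamWConfig}  # hypothetical registry


def normalize(target, kwargs=None):
    kwargs = dict(kwargs or {})
    if isinstance(target, ToyConfig):
        return target  # already typed; kwargs are ignored
    if isinstance(target, str):
        if target in TOY_REGISTRY:
            target = TOY_REGISTRY[target]  # short registry name
        else:
            # Dotted import path: module part, then attribute.
            module_name, _, attr = target.rpartition(".")
            target = getattr(importlib.import_module(module_name), attr)
    if isinstance(target, type) and issubclass(target, ToyConfig):
        return target(**kwargs)  # typed config built from its fields
    # Anything else is treated as an optimizer factory and wrapped.
    return ToyFactoryConfig(factory=target, kwargs=kwargs)


typed = normalize("adamw", {"lr": 1e-3})
wrapped = normalize("collections.OrderedDict", {"a": 1})
```

Resolving strings first and then falling through to the subclass/callable checks keeps a single code path for both registry names and import paths.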

Parameters:

target
OptimizerConfig | str | type[OptimizerConfig] | Callable[..., torch.optim.Optimizer]

The optimizer config instance/subclass, registry name or import path, or optimizer callable to normalize.

kwargs
dict[str, Any] | None. Defaults to None

Constructor arguments for the resolved config/callable.

Returns: OptimizerConfig

An :class:OptimizerConfig ready to build(...).

nemo_automodel.components.optim.optimizer.OPTIMIZER_CONFIG_REGISTRY: dict[str, type[OptimizerConfig]] = {'adam': AdamConfig, 'adamw': AdamWConfig, 'fused_adam': FusedAdamConfig, 'flash...
nemo_automodel.components.optim.optimizer._DION_CONFIG_FOR: dict[str, type[OptimizerConfig]] = {'Muon': MuonConfig, 'NorMuon': NorMuonConfig, 'Dion2': Dion2Config, 'Dion': Dio...
nemo_automodel.components.optim.optimizer._DION_GROUPING_FIELDS = frozenset({'scalar_opt', 'scalar_betas', 'scalar_eps', 'scalar_lr', 'embed_lr', ...
nemo_automodel.components.optim.optimizer._DTYPE_FIELDS = ('master_weight_dtype', 'exp_avg_dtype', 'exp_avg_sq_dtype')
nemo_automodel.components.optim.optimizer.__all__ = ['OPTIMIZER_CONFIG_REGISTRY', 'AdamConfig', 'AdamWConfig', 'Dion2Config', 'DionC...
nemo_automodel.components.optim.optimizer.logger = logging.getLogger(__name__)