nemo_automodel.components.loss.loss
Typed loss configs + builder (TorchTitan-style).
Each loss config is a plain dataclass exposing its full parameter surface as
named fields (no opaque **kwargs) and a build() method that constructs
the loss module directly — lazy imports keep optional kernel deps out of module
load. Reading the dataclass tells you exactly what you can configure.
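As a rough illustration of the pattern described above, the following sketch shows a dataclass config whose fields are the whole parameter surface and whose build() defers its imports. `ToyL2LossConfig` and its fields are hypothetical names invented for this example and are not part of this module; the `math` import only stands in for an optional kernel dependency.

```python
from dataclasses import dataclass, fields


@dataclass
class ToyL2LossConfig:
    """Hypothetical typed config: every tunable is a named field."""

    reduction: str = "mean"  # "mean" or "sum"
    scale: float = 1.0

    def build(self):
        # Lazy import: stand-in for an optional kernel dependency that
        # should not be imported when the config module is loaded.
        import math

        def loss_fn(pred, target):
            total = math.fsum((p - t) ** 2 for p, t in zip(pred, target))
            if self.reduction == "mean":
                total /= len(pred)
            return self.scale * total

        return loss_fn


config = ToyL2LossConfig(reduction="sum")
loss_fn = config.build()

# The dataclass itself documents what can be configured.
field_names = [f.name for f in fields(config)]
value = loss_fn([1.0, 2.0], [0.0, 0.0])
```

Because there is no `**kwargs` catch-all, a misspelled field such as `ToyL2LossConfig(reducton="sum")` fails at construction time with a `TypeError` instead of being silently ignored.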
:func:build_loss_config normalizes any of these into a :class:LossConfig,
and :func:build_loss_module is the single build entry point — a thin wrapper that
returns build_loss_config(...).build(). Both dispatch on the argument:
- a typed :class:LossConfig instance: the Automodel-native path; per-loss construction delegates to config.build().
- a registry name or a known loss class (see :data:LOSS_CONFIG_REGISTRY, e.g. MaskedCrossEntropy → :class:MaskedCrossEntropyConfig): upgraded to its typed config so the YAML recipe path gets the same field validation.
- any other loss class / callable plus arbitrary **loss_kwargs: the escape hatch (wrapped in :class:LossFromFactoryConfig) for external integrations (e.g. veRL).
Adding a new typed config never forces the caller to change.
Module Contents
Classes
Functions
Data
API
Bases: LossConfig
KDLoss (knowledge distillation).
Base loss config. Subclasses expose their full field surface and
implement :meth:build.
Construct the loss module.
Bases: LossConfig
Escape hatch for external integrations (e.g. veRL) and the YAML recipe path.
Rather than exposing typed fields, it wraps a loss class/callable and the
**kwargs to construct it, keeping the factory path on the same
config.build() contract as the typed configs so callers never have to
special-case it. The factory is called as factory(**kwargs).
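A minimal sketch of the factory-wrapping idea, assuming a config that stores a callable and its kwargs and calls `factory(**kwargs)` in build(). `ToyFactoryConfig` and `ExternalLoss` are hypothetical stand-ins; the real :class:LossFromFactoryConfig field names may differ.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass
class ToyFactoryConfig:
    """Hypothetical escape-hatch config: wraps a callable and its kwargs."""

    factory: Callable[..., Any]
    kwargs: dict = field(default_factory=dict)

    def build(self):
        # Same build() contract as a typed config, so callers need no branch.
        return self.factory(**self.kwargs)


class ExternalLoss:
    """Stand-in for a loss class owned by an external integration."""

    def __init__(self, temperature=1.0):
        self.temperature = temperature

    def __call__(self, pred, target):
        return abs(pred - target) / self.temperature


loss = ToyFactoryConfig(ExternalLoss, {"temperature": 2.0}).build()
result = loss(3.0, 1.0)
```

The trade-off is that the kwargs are opaque here: nothing is validated until the factory itself runs.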
Normalize a loss target plus kwargs into a :class:LossConfig.
The single normalization entry point shared by the recipe layer (which
resolves a YAML _target_ to a Python object) and :func:build_loss_module.
It dispatches on loss:
- a typed :class:LossConfig instance: returned as-is; **loss_kwargs must be empty (hyperparameters live on the config).
- a :class:LossConfig subclass (the class itself): rejected; pass an instance.
- a registry name (see :data:LOSS_CONFIG_REGISTRY, e.g. "MaskedCrossEntropy"): built from the matching typed config.
- a loss class / callable plus **loss_kwargs: if it is a registered loss class and the kwargs fit the config's fields, it is upgraded to that typed config; otherwise it is wrapped in a :class:LossFromFactoryConfig.
The caller resolves any dotted path to a callable; the component never does dotted-path resolution.
Returns: LossConfig
A :class:LossConfig ready to build().
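The dispatch rules listed above can be sketched as follows. This is an illustrative reimplementation under stated assumptions, not the library's code: all `Toy*` names are hypothetical, the exception types are guesses, registered classes are matched by `__name__`, and kwargs passed alongside a registry name are assumed to go to the typed config.

```python
import inspect
from dataclasses import dataclass, field, fields


class ToyLossConfig:
    """Hypothetical base config; subclasses implement build()."""

    def build(self):
        raise NotImplementedError


class ToyMaskedCE:
    """Stand-in loss class."""

    def __init__(self, ignore_index=-100):
        self.ignore_index = ignore_index


@dataclass
class ToyMaskedCEConfig(ToyLossConfig):
    ignore_index: int = -100

    def build(self):
        return ToyMaskedCE(ignore_index=self.ignore_index)


@dataclass
class ToyFactoryConfig(ToyLossConfig):
    factory: object = None
    kwargs: dict = field(default_factory=dict)

    def build(self):
        return self.factory(**self.kwargs)


# Assumption: the registry maps a loss name to its typed config class.
TOY_REGISTRY = {"ToyMaskedCE": ToyMaskedCEConfig}


def toy_build_loss_config(loss, **loss_kwargs):
    # 1. Typed config instance: returned as-is, kwargs must be empty.
    if isinstance(loss, ToyLossConfig):
        if loss_kwargs:
            raise TypeError("kwargs must be empty when passing a typed config")
        return loss
    # 2. Config class rather than instance: rejected.
    if inspect.isclass(loss) and issubclass(loss, ToyLossConfig):
        raise TypeError("pass a config instance, not the config class")
    # 3. Registry name: built from the matching typed config.
    if isinstance(loss, str):
        return TOY_REGISTRY[loss](**loss_kwargs)
    # 4. Registered loss class whose kwargs fit the config fields: upgraded.
    config_cls = TOY_REGISTRY.get(getattr(loss, "__name__", ""))
    if config_cls is not None:
        allowed = {f.name for f in fields(config_cls)}
        if set(loss_kwargs) <= allowed:
            return config_cls(**loss_kwargs)
    # 5. Anything else: wrapped in the factory escape hatch.
    return ToyFactoryConfig(factory=loss, kwargs=loss_kwargs)


typed = toy_build_loss_config(ToyMaskedCE, ignore_index=0)
fallback = toy_build_loss_config(ToyMaskedCE, unknown_option=1)
```

In this sketch `typed` is a `ToyMaskedCEConfig`, while `fallback` is a `ToyFactoryConfig` because `unknown_option` is not a field of the typed config.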
Build a loss function.
Thin dispatcher: it normalizes loss to a :class:LossConfig via
:func:build_loss_config and returns config.build(). Dispatches on
loss:
- Typed config (:class:LossConfig instance): the Automodel-native path. Hyperparameters come from the config; **loss_kwargs must be empty.
- Loss class / callable (e.g. MaskedCrossEntropy) plus **loss_kwargs: the integration / YAML escape hatch.
The caller resolves any dotted path to a callable; the component never does string resolution.
Parameters:
loss: Typed :class:LossConfig instance, or a loss class/callable to
construct with **loss_kwargs.
**loss_kwargs: Constructor kwargs for the class/callable form. Must be
empty when loss is a typed config.
Returns: nn.Module
Instantiated loss function.
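To illustrate the division of labor around dotted paths, here is a small sketch in which the caller resolves the path and the builder only ever receives a callable. `resolve_dotted_path` and `toy_build_loss_module` are hypothetical helpers written for this example, the toy builder omits the typed-config and registry-name branches, and `statistics.NormalDist` is only a stdlib stand-in for a loss class.

```python
import importlib


def resolve_dotted_path(path):
    """Caller-side helper (hypothetical): turn 'pkg.mod.Name' into an object."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def toy_build_loss_module(loss, **loss_kwargs):
    """Toy builder: accepts a callable and never imports anything by string."""
    if not callable(loss):
        raise TypeError("expected a callable; resolve dotted paths in the caller")
    return loss(**loss_kwargs)


# The recipe layer would do this step for a YAML _target_ entry.
target = resolve_dotted_path("statistics.NormalDist")
module = toy_build_loss_module(target, mu=0.0, sigma=2.0)
```

Keeping import-by-string out of the builder means the component's behavior depends only on the objects it is handed, not on what happens to be importable.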