nemo_automodel.components.loss.loss

Typed loss configs + builder (TorchTitan-style).

Each loss config is a plain dataclass exposing its full parameter surface as named fields (no opaque **kwargs) and a build() method that constructs the loss module directly — lazy imports keep optional kernel deps out of module load. Reading the dataclass tells you exactly what you can configure.
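The pattern described above can be sketched with the standard library alone. The names ToyLoss and ToyLossConfig are hypothetical stand-ins, not classes from nemo_automodel, and the stdlib statistics module plays the role of an optional kernel dependency that is imported only inside build().

```python
from dataclasses import asdict, dataclass


class ToyLoss:
    """Hypothetical stand-in for a loss module; not part of nemo_automodel."""

    def __init__(self, ignore_index: int = -100, reduction: str = "sum"):
        self.ignore_index = ignore_index
        self.reduction = reduction


@dataclass
class ToyLossConfig:
    """Every tunable is a named field, so a misspelled name fails at construction."""

    ignore_index: int = -100
    reduction: str = "sum"

    def build(self) -> ToyLoss:
        # Lazy import: the dependency is loaded only when a loss is built,
        # not when the config module is imported. `statistics` is a stand-in
        # for an optional kernel package.
        import statistics  # noqa: F401

        return ToyLoss(**asdict(self))


config = ToyLossConfig(reduction="mean")
loss = config.build()
print(loss.reduction, loss.ignore_index)

try:
    ToyLossConfig(reducton="mean")  # misspelled field name
except TypeError as err:
    print("rejected:", err)
```

Because the config has no **kwargs, the misspelled field is rejected by the dataclass constructor rather than silently ignored.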

build_loss_config normalizes any of these into a LossConfig, and build_loss_module is the single build entry point: a thin wrapper that returns build_loss_config(...).build(). Both dispatch on the argument:

  • a typed LossConfig instance: the Automodel-native path; per-loss construction delegates to config.build().
  • a registry name or a known loss class (see LOSS_CONFIG_REGISTRY, e.g. MaskedCrossEntropy → MaskedCrossEntropyConfig): upgraded to its typed config so the YAML recipe path gets the same field validation.
  • any other loss class / callable plus arbitrary **loss_kwargs: the escape hatch (wrapped in LossFromFactoryConfig) for external integrations (e.g. veRL). Adding a new typed config never forces the caller to change.

Module Contents

Classes

| Name | Description |
| --- | --- |
| FusedLinearCEConfig | FusedLinearCrossEntropy. |
| KDLossConfig | KDLoss (knowledge distillation). |
| LossConfig | Base loss config. Subclasses expose their full field surface and implement build(). |
| LossFromFactoryConfig | Escape hatch for external integrations (e.g. veRL) and the YAML recipe path. |
| MaskedCrossEntropyConfig | MaskedCrossEntropy. |
| TEParallelCEConfig | TEParallelCrossEntropy. |

Functions

| Name | Description |
| --- | --- |
| build_loss_config | Normalize a loss target plus kwargs into a LossConfig. |
| build_loss_module | Build a loss function. |

Data

LOSS_CONFIG_REGISTRY

__all__

API

class nemo_automodel.components.loss.loss.FusedLinearCEConfig(
ignore_index: int = -100,
logit_softcapping: float = 0.0,
reduction: str = 'sum'
)
Dataclass

Bases: LossConfig

FusedLinearCrossEntropy.

ignore_index
int = -100
logit_softcapping
float = 0.0
reduction
str = 'sum'

nemo_automodel.components.loss.loss.FusedLinearCEConfig.build() -> torch.nn.Module

class nemo_automodel.components.loss.loss.KDLossConfig(
ignore_index: int = -100,
temperature: float = 1.0,
fp32_upcast: bool = True,
tp_group: typing.Any = None,
chunk_size: int = 0
)
Dataclass

Bases: LossConfig

KDLoss (knowledge distillation).

chunk_size
int = 0
fp32_upcast
bool = True
ignore_index
int = -100
temperature
float = 1.0

nemo_automodel.components.loss.loss.KDLossConfig.build() -> torch.nn.Module

class nemo_automodel.components.loss.loss.LossConfig()
Dataclass

Base loss config. Subclasses expose their full field surface and implement build().

nemo_automodel.components.loss.loss.LossConfig.build() -> torch.nn.Module

Construct the loss module.

class nemo_automodel.components.loss.loss.LossFromFactoryConfig(
factory: collections.abc.Callable[..., torch.nn.Module] | None = None,
kwargs: dict[str, typing.Any] = dict()
)
Dataclass

Bases: LossConfig

Escape hatch for external integrations (e.g. veRL) and the YAML recipe path.

Rather than exposing typed fields, it wraps a loss class/callable and the **kwargs to construct it, keeping the factory path on the same config.build() contract as the typed configs so callers never have to special-case it. The factory is called as factory(**kwargs).

factory
Callable[..., Module] | None = None
kwargs
dict[str, Any] = field(default_factory=dict)

nemo_automodel.components.loss.loss.LossFromFactoryConfig.build() -> torch.nn.Module
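
The factory(**kwargs) contract described for this config can be illustrated with a small stdlib-only sketch. ToyFactoryConfig and ExternalLoss are hypothetical names, and raising ValueError when factory is None is an assumption of the sketch; the documented behavior for a missing factory is not stated here.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Optional


@dataclass
class ToyFactoryConfig:
    """Hypothetical mirror of the factory/kwargs shape documented above."""

    factory: Optional[Callable[..., Any]] = None
    kwargs: dict[str, Any] = field(default_factory=dict)

    def build(self) -> Any:
        if self.factory is None:
            # Assumption: fail loudly; the real behavior is not documented here.
            raise ValueError("factory must be set before build()")
        # Same build() contract as a typed config, but the arguments are opaque.
        return self.factory(**self.kwargs)


class ExternalLoss:
    """Stand-in for a loss class owned by another project."""

    def __init__(self, beta: float = 0.1):
        self.beta = beta


loss = ToyFactoryConfig(factory=ExternalLoss, kwargs={"beta": 0.5}).build()
print(type(loss).__name__, loss.beta)
```

The caller only ever invokes build(), so code that consumes configs does not need to know whether it received a typed config or a wrapped factory.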
class nemo_automodel.components.loss.loss.MaskedCrossEntropyConfig(
fp32_upcast: bool = True,
ignore_index: int = -100,
reduction: str = 'sum'
)
Dataclass

Bases: LossConfig

MaskedCrossEntropy.

fp32_upcast
bool = True
ignore_index
int = -100
reduction
str = 'sum'

nemo_automodel.components.loss.loss.MaskedCrossEntropyConfig.build() -> torch.nn.Module

class nemo_automodel.components.loss.loss.TEParallelCEConfig(
ignore_index: int = -100,
reduction: str = 'sum',
tp_group: typing.Any = None
)
Dataclass

Bases: LossConfig

TEParallelCrossEntropy.

ignore_index
int = -100
reduction
str = 'sum'

nemo_automodel.components.loss.loss.TEParallelCEConfig.build() -> torch.nn.Module

nemo_automodel.components.loss.loss.build_loss_config(
loss: nemo_automodel.components.loss.loss.LossConfig | str | collections.abc.Callable[..., torch.nn.Module],
**loss_kwargs: typing.Any
) -> nemo_automodel.components.loss.loss.LossConfig

Normalize a loss target plus kwargs into a LossConfig.

The single normalization entry point shared by the recipe layer (which resolves a YAML _target_ to a Python object) and build_loss_module. It dispatches on loss:

  • a typed LossConfig instance: returned as-is; **loss_kwargs must be empty (hyperparameters live on the config).
  • a LossConfig subclass (the class itself): rejected; pass an instance.
  • a registry name (see LOSS_CONFIG_REGISTRY, e.g. "MaskedCrossEntropy"): built from the matching typed config.
  • a loss class / callable plus **loss_kwargs: if it is a registered loss class and the kwargs fit the config's fields, it is upgraded to that typed config; otherwise it is wrapped in a LossFromFactoryConfig. The caller resolves any dotted path to a callable; the component never does dotted-path resolution.

Returns: LossConfig

A LossConfig ready to build().
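
The four dispatch rules listed for build_loss_config can be modeled in a few lines of plain Python. This is a sketch under stated assumptions, not the library's implementation: every name (ToyCE, ToyConfig, REGISTRY, normalize) is hypothetical, registered classes are matched by __name__, and the exception types are guesses.

```python
from dataclasses import dataclass, field, fields
from typing import Any, Callable


class ToyCE:
    """Stand-in for a registered loss class."""

    def __init__(self, ignore_index: int = -100, reduction: str = "sum"):
        self.ignore_index, self.reduction = ignore_index, reduction


@dataclass
class ToyConfig:
    """Base config for the sketch; subclasses implement build()."""

    def build(self) -> Any:
        raise NotImplementedError


@dataclass
class ToyCEConfig(ToyConfig):
    ignore_index: int = -100
    reduction: str = "sum"

    def build(self) -> Any:
        return ToyCE(self.ignore_index, self.reduction)


@dataclass
class ToyFactoryConfig(ToyConfig):
    factory: Callable[..., Any] = None
    kwargs: dict = field(default_factory=dict)

    def build(self) -> Any:
        return self.factory(**self.kwargs)


REGISTRY = {"ToyCE": ToyCEConfig}  # loss name -> typed config class


def normalize(loss, **loss_kwargs) -> ToyConfig:
    # 1. Typed config instance: returned as-is, kwargs must be empty.
    if isinstance(loss, ToyConfig):
        if loss_kwargs:
            raise TypeError("kwargs must be empty when passing a config instance")
        return loss
    # 2. Config class instead of an instance: rejected.
    if isinstance(loss, type) and issubclass(loss, ToyConfig):
        raise TypeError("pass a config instance, not the config class")
    # 3. Registry name: built from the matching typed config.
    if isinstance(loss, str):
        if loss not in REGISTRY:
            raise ValueError(f"unknown loss name: {loss!r}")
        return REGISTRY[loss](**loss_kwargs)
    # 4. Class / callable: upgrade if registered and the kwargs fit the
    #    typed fields, otherwise wrap in the factory config.
    typed = REGISTRY.get(getattr(loss, "__name__", ""))
    if typed is not None and set(loss_kwargs) <= {f.name for f in fields(typed)}:
        return typed(**loss_kwargs)
    return ToyFactoryConfig(factory=loss, kwargs=loss_kwargs)


a = normalize(ToyCEConfig(reduction="mean"))  # instance, returned as-is
b = normalize("ToyCE", reduction="mean")      # registry name
c = normalize(ToyCE, ignore_index=0)          # registered class, kwargs fit
d = normalize(dict, x=1)                      # unregistered callable
print(type(a).__name__, type(b).__name__, type(c).__name__, type(d).__name__)
```

Checking the config-class case before the registry lookup keeps a caller who forgets the parentheses from being silently routed into the factory path.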

nemo_automodel.components.loss.loss.build_loss_module(
loss: nemo_automodel.components.loss.loss.LossConfig | collections.abc.Callable[..., torch.nn.Module],
**loss_kwargs: typing.Any
) -> torch.nn.Module

Build a loss function.

Thin dispatcher: it normalizes loss to a LossConfig via build_loss_config and returns config.build(). Dispatches on loss:

  • Typed config (LossConfig instance): the Automodel-native path. Hyperparameters come from the config; **loss_kwargs must be empty.
  • Loss class / callable (e.g. MaskedCrossEntropy) plus **loss_kwargs: the integration / YAML escape hatch. The caller resolves any dotted path to a callable; the component never does string resolution.

Parameters:

loss
LossConfig | Callable[..., nn.Module]

Typed LossConfig instance, or a loss class/callable to construct with **loss_kwargs.

**loss_kwargs
Any. Defaults to {}

Constructor kwargs for the class/callable form. Must be empty when loss is a typed config.

Returns: nn.Module

Instantiated loss function.

nemo_automodel.components.loss.loss.LOSS_CONFIG_REGISTRY: dict[str, type[LossConfig]] = {'MaskedCrossEntropy': MaskedCrossEntropyConfig, 'FusedLinearCrossEntropy': Fuse...
nemo_automodel.components.loss.loss.__all__ = ['LOSS_CONFIG_REGISTRY', 'FusedLinearCEConfig', 'KDLossConfig', 'LossConfig', 'L...