nemo_automodel.components.models.nemotron_parse.nemotron_parse_loss

Module Contents

Classes

Name | Description
NemotronParseLoss | Cross-entropy loss with coordinate token weighting for NemotronParse.

API

class nemo_automodel.components.models.nemotron_parse.nemotron_parse_loss.NemotronParseLoss(
coordinate_weight: float = 10.0,
class_token_start_idx: int = 50000,
num_heads: int = 1,
ignore_index: int = -100,
reduction: str = 'sum',
fp32_upcast: bool = True
)

Bases: Module

Cross-entropy loss with coordinate token weighting for NemotronParse.

This loss computes cross-entropy over all prediction heads and applies a configurable weight to coordinate tokens, that is, tokens whose label ID is >= class_token_start_idx. When num_heads > 1, it shifts the labels separately for each head to support multi-task output predictions.
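The weighting described above can be illustrated without any framework. The sketch below is not the library's implementation; it assumes the weight is a plain multiplier on each token's cross-entropy and that the weighted terms are summed. The function name `weighted_cross_entropy` and the tiny three-entry vocabulary are hypothetical, chosen only to keep the arithmetic easy to follow.

```python
import math


def weighted_cross_entropy(logits, labels, coordinate_weight=10.0,
                           class_token_start_idx=50000, ignore_index=-100):
    """Sum per-token cross-entropy, up-weighting coordinate tokens.

    logits: list of per-position score lists (one score per vocabulary entry).
    labels: list of target token IDs, same length as logits.
    """
    total = 0.0
    for scores, label in zip(logits, labels):
        if label == ignore_index:
            continue  # ignored positions contribute nothing
        # Numerically stable log-sum-exp over the vocabulary.
        max_score = max(scores)
        log_norm = max_score + math.log(sum(math.exp(s - max_score) for s in scores))
        token_loss = log_norm - scores[label]
        # Assumption: the weight simply multiplies the per-token loss.
        weight = coordinate_weight if label >= class_token_start_idx else 1.0
        total += weight * token_loss
    return total


# Toy setup: vocabulary of 3, token ID 2 plays the role of a coordinate token.
logits = [[0.0, 0.0, 0.0]] * 3
labels = [0, 2, -100]
loss = weighted_cross_entropy(logits, labels,
                              coordinate_weight=10.0, class_token_start_idx=2)
```

With uniform logits every scored token has loss log(3), so the regular token contributes log(3), the coordinate token contributes 10 * log(3), and the ignored position contributes nothing.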

Parameters:

coordinate_weight
float, defaults to 10.0

Weight multiplier for coordinate tokens. Tokens with label IDs >= class_token_start_idx will have their loss multiplied by this factor.

class_token_start_idx
int, defaults to 50000

Token index threshold for coordinate tokens. Tokens with label IDs >= this value are treated as coordinate/class tokens and have their loss scaled by coordinate_weight.

num_heads
int, defaults to 1

Number of prediction heads (main + extra). Must match the model’s num_extra_heads + 1.

ignore_index
int, defaults to -100

Label value to ignore in loss computation.

reduction
str, defaults to 'sum'

Loss reduction strategy ('sum' or 'mean').

fp32_upcast
bool, defaults to True

Cast logits to fp32 before computing the loss, for numerical stability.
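The reference does not spell out how per-head label shifting works when num_heads > 1. As one plausible reading, and purely as an assumption, the sketch below shifts the labels left by the head index so that each extra head is scored one position further ahead, and pads the tail with ignore_index. The helper `shift_labels_for_head` is hypothetical and is not part of the library.

```python
def shift_labels_for_head(labels, head_idx, ignore_index=-100):
    """Return the labels one head is scored against.

    Assumption: head 0 uses the labels unchanged, and head k uses them
    shifted left by k positions, with the vacated tail set to ignore_index.
    """
    if head_idx == 0:
        return list(labels)
    pad = [ignore_index] * min(head_idx, len(labels))
    return list(labels[head_idx:]) + pad


labels = [11, 12, 13, 14]
num_heads = 3  # main head + 2 extra heads
per_head_labels = [shift_labels_for_head(labels, k) for k in range(num_heads)]
# head 0: [11, 12, 13, 14]
# head 1: [12, 13, 14, -100]
# head 2: [13, 14, -100, -100]
```

Padding with ignore_index keeps every head's label sequence the same length as the logits while excluding positions that have no target.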

nemo_automodel.components.models.nemotron_parse.nemotron_parse_loss.NemotronParseLoss.forward(
logits: torch.Tensor,
labels: torch.Tensor,
decoder_inputs_embeds: typing.Optional[torch.Tensor] = None,
num_label_tokens: typing.Optional[int] = None
) -> torch.Tensor

Compute loss with coordinate token weighting.

Parameters:

logits
torch.Tensor

Model logits with shape [batch_size, seq_len, vocab_size]

labels
torch.Tensor

Ground truth labels with shape [batch_size, seq_len]

decoder_inputs_embeds
typing.Optional[torch.Tensor], defaults to None

Decoder input embeddings. Currently unused but kept for API compatibility.

num_label_tokens
typing.Optional[int], defaults to None

Total number of valid tokens, used to normalize the loss across gradient accumulation steps. If provided, the loss is divided by this value instead of the actual token count. Only supported with reduction='sum'.

Returns: torch.Tensor

Computed loss value as a scalar tensor.
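To see why num_label_tokens is useful under gradient accumulation, the arithmetic below compares two ways of normalizing across micro-batches. It is a standalone illustration with made-up per-token losses, not a call into the library: dividing each micro-batch's summed loss by the global token count reproduces the true per-token mean, whereas averaging per-micro-batch means does not when the micro-batches differ in size.

```python
# Hypothetical per-token losses for two micro-batches of unequal size.
micro_batches = [[2.0, 4.0, 6.0], [8.0]]

# Global count of valid tokens across all accumulation steps.
num_label_tokens = sum(len(mb) for mb in micro_batches)  # 4

# Sum-reduced loss per micro-batch, divided by the global count.
normalized = [sum(mb) / num_label_tokens for mb in micro_batches]
accumulated = sum(normalized)

# Reference value: mean over every token in the full batch.
global_mean = sum(sum(mb) for mb in micro_batches) / num_label_tokens

# Alternative: average the per-micro-batch means (over-weights the short batch).
mean_of_means = sum(sum(mb) / len(mb) for mb in micro_batches) / len(micro_batches)

print(accumulated, global_mean, mean_of_means)  # 5.0 5.0 6.0
```

Passing the global token count therefore makes the accumulated loss independent of how tokens happen to be split across micro-batches.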