nemo_automodel.components.models.nemotron_parse.nemotron_parse_loss#

Module Contents#

Classes#

NemotronParseLoss

Cross-entropy loss with coordinate token weighting for NemotronParse.

API#

class nemo_automodel.components.models.nemotron_parse.nemotron_parse_loss.NemotronParseLoss(
coordinate_weight: float = 10.0,
class_token_start_idx: int = 50000,
num_heads: int = 1,
ignore_index: int = -100,
reduction: str = 'sum',
fp32_upcast: bool = True,
)#

Bases: torch.nn.Module

Cross-entropy loss with coordinate token weighting for NemotronParse.

This loss function computes cross-entropy across prediction heads with configurable weighting for coordinate tokens (tokens >= class_token_start_idx). When num_heads > 1, it implements per-head label shifting for multi-task output predictions.

Parameters:
  • coordinate_weight (float) – Weight multiplier for coordinate tokens. Tokens with label IDs >= class_token_start_idx will have their loss multiplied by this factor. Default: 10.0

  • class_token_start_idx (int) – Token index threshold for coordinate tokens. Tokens with label IDs >= this value are considered coordinate/class tokens and receive higher loss weight. Default: 50000

  • num_heads (int) – Number of prediction heads (main + extra). Must match the model’s num_extra_heads + 1. Default: 1

  • ignore_index (int) – Label value to ignore in loss computation. Default: -100

  • reduction (str) – Loss reduction strategy ("sum" or "mean"). Default: "sum"

  • fp32_upcast (bool) – Cast logits to fp32 for numerical stability. Default: True
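Conceptually, the weighting described above amounts to the following sketch (single-head case with reduction="sum"; this is an illustration of the documented behavior, not the module's actual implementation):

import torch
import torch.nn.functional as F

def weighted_ce(logits, labels, coordinate_weight=10.0,
                class_token_start_idx=50000, ignore_index=-100):
    # Per-token cross-entropy in fp32 (mirrors fp32_upcast=True),
    # one loss value per position; ignored positions contribute 0.
    per_token = F.cross_entropy(
        logits.float().reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        ignore_index=ignore_index,
        reduction="none",
    )
    # Up-weight coordinate/class tokens (label id >= class_token_start_idx).
    weights = torch.ones_like(per_token)
    weights[labels.reshape(-1) >= class_token_start_idx] = coordinate_weight
    return (per_token * weights).sum()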

Example

loss_fn = NemotronParseLoss(
    coordinate_weight=10.0,
    class_token_start_idx=50000,
    num_heads=1,
)
# logits shape: [batch, seq_len, vocab_size]
# labels shape: [batch, seq_len]
loss = loss_fn(logits=logits, labels=labels)
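For a self-contained smoke test (the vocabulary size below is illustrative, not the real tokenizer size):

import torch
from nemo_automodel.components.models.nemotron_parse.nemotron_parse_loss import NemotronParseLoss

vocab_size = 50304  # illustrative value only
logits = torch.randn(2, 16, vocab_size)
labels = torch.randint(0, vocab_size, (2, 16))
labels[:, :4] = -100  # ignore_index positions contribute no loss
loss_fn = NemotronParseLoss(coordinate_weight=10.0, class_token_start_idx=50000)
loss = loss_fn(logits=logits, labels=labels)  # scalar tensor (reduction="sum")

Labels at or above class_token_start_idx are the ones that receive the coordinate weight.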


forward(
logits: torch.Tensor,
labels: torch.Tensor,
decoder_inputs_embeds: Optional[torch.Tensor] = None,
num_label_tokens: Optional[int] = None,
) → torch.Tensor#

Compute loss with coordinate token weighting.

Parameters:
  • logits (torch.Tensor) – Model logits with shape [batch_size, seq_len, vocab_size]

  • labels (torch.Tensor) – Ground truth labels with shape [batch_size, seq_len]

  • decoder_inputs_embeds (torch.Tensor, optional) – Decoder input embeddings. Currently unused but kept for API compatibility. Default: None

  • num_label_tokens (int, optional) – Total number of valid tokens for normalization across gradient accumulation steps. If provided, loss is normalized by this value instead of the actual token count. Only supported with reduction="sum". Default: None

Returns:

Computed loss value as a scalar tensor.

Return type:

torch.Tensor
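The num_label_tokens argument keeps per-token normalization consistent when one logical batch is split into micro-batches. A sketch, assuming reduction="sum" and illustrative names (micro_batches, model, optimizer, and a loss_fn built as in the example above):

# Count valid tokens across ALL micro-batches first, then reuse that total
# so every micro-batch loss is scaled by the same denominator.
total_tokens = sum((mb["labels"] != -100).sum().item() for mb in micro_batches)
for mb in micro_batches:
    logits = model(**mb["inputs"]).logits  # illustrative forward pass
    loss = loss_fn(
        logits=logits,
        labels=mb["labels"],
        num_label_tokens=total_tokens,
    )
    loss.backward()
optimizer.step()
optimizer.zero_grad()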