nemo_automodel.components.models.nemotron_parse.nemotron_parse_loss#
Module Contents#
Classes#
NemotronParseLoss: Cross-entropy loss with coordinate token weighting for NemotronParse.
API#
- class nemo_automodel.components.models.nemotron_parse.nemotron_parse_loss.NemotronParseLoss(
- coordinate_weight: float = 10.0,
- class_token_start_idx: int = 50000,
- num_heads: int = 1,
- ignore_index: int = -100,
- reduction: str = 'sum',
- fp32_upcast: bool = True,
- )
Bases: torch.nn.Module
Cross-entropy loss with coordinate token weighting for NemotronParse.
This loss function computes cross-entropy across prediction heads with configurable weighting for coordinate tokens (tokens >= class_token_start_idx). When num_heads > 1, it implements per-head label shifting for multi-task output predictions.
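A minimal sketch of the per-head label shifting, assuming the common multi-head convention that extra head h is trained against the token h positions further ahead (this offset scheme is an assumption for illustration, not confirmed by this API):

import torch

def shifted_labels(labels: torch.Tensor, num_heads: int, ignore_index: int = -100):
    """Yield the label tensor each prediction head is trained against."""
    for h in range(num_heads):
        if h == 0:
            yield labels  # main head: unshifted labels
        else:
            out = torch.full_like(labels, ignore_index)
            out[:, :-h] = labels[:, h:]  # head h targets tokens h steps ahead
            yield out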
- Parameters:
coordinate_weight (float) – Weight multiplier for coordinate tokens. Tokens with label IDs >= class_token_start_idx will have their loss multiplied by this factor. Default: 10.0
class_token_start_idx (int) – Token index threshold for coordinate tokens. Tokens with label IDs >= this value are considered coordinate/class tokens and receive higher loss weight. Default: 50000
num_heads (int) – Number of prediction heads (main + extra). Must match the model's num_extra_heads + 1. Default: 1
ignore_index (int) – Label value to ignore in loss computation. Default: -100
reduction (str) – Loss reduction strategy ('sum' or 'mean'). Default: 'sum'
fp32_upcast (bool) – Cast logits to fp32 for numerical stability. Default: True
Example

loss_fn = NemotronParseLoss(
    coordinate_weight=10.0,
    class_token_start_idx=50000,
    num_heads=1,
)

# logits shape: [batch, seq_len, vocab_size]
# labels shape: [batch, seq_len]
loss = loss_fn(logits=logits, labels=labels)
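For reference, a minimal sketch of the weighting described above for the single-head case. This illustrates the documented behavior (per-token cross-entropy, scaled by coordinate_weight wherever the label ID is >= class_token_start_idx, reduced with a sum); it is not the library's actual implementation:

import torch
import torch.nn.functional as F

def coordinate_weighted_ce(
    logits: torch.Tensor,              # [batch, seq_len, vocab_size]
    labels: torch.Tensor,              # [batch, seq_len]
    coordinate_weight: float = 10.0,
    class_token_start_idx: int = 50000,
    ignore_index: int = -100,
    fp32_upcast: bool = True,
) -> torch.Tensor:
    if fp32_upcast:
        logits = logits.float()        # upcast for numerical stability
    flat_labels = labels.reshape(-1)
    per_token = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), flat_labels,
        ignore_index=ignore_index, reduction="none",
    )                                  # 0.0 at ignored positions
    weights = torch.where(
        flat_labels >= class_token_start_idx,
        torch.full_like(per_token, coordinate_weight),
        torch.ones_like(per_token),
    )
    return (per_token * weights).sum()  # reduction='sum'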
Initialization
- forward(
- logits: torch.Tensor,
- labels: torch.Tensor,
- decoder_inputs_embeds: Optional[torch.Tensor] = None,
- num_label_tokens: Optional[int] = None,
- )
Compute loss with coordinate token weighting.
- Parameters:
logits (torch.Tensor) – Model logits with shape [batch_size, seq_len, vocab_size]
labels (torch.Tensor) – Ground truth labels with shape [batch_size, seq_len]
decoder_inputs_embeds (torch.Tensor, optional) – Decoder input embeddings. Currently unused but kept for API compatibility. Default: None
num_label_tokens (int, optional) – Total number of valid tokens for normalization across gradient accumulation steps. If provided, loss is normalized by this value instead of the actual token count. Only supported with reduction='sum'. Default: None
- Returns:
Computed loss value as a scalar tensor.
- Return type:
torch.Tensor
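The num_label_tokens argument keeps the effective loss scale constant across gradient accumulation steps. A hedged usage sketch; the model, micro_batches, and optimizer names are illustrative, not part of this API:

# Count valid (non-ignored) label tokens across all micro-batches first, then
# pass the total so each micro-batch contributes sum_loss / total_valid.
total_valid = sum((mb["labels"] != -100).sum().item() for mb in micro_batches)

for mb in micro_batches:
    logits = model(input_ids=mb["input_ids"]).logits  # illustrative forward pass
    loss = loss_fn(logits=logits, labels=mb["labels"],
                   num_label_tokens=total_valid)
    loss.backward()  # gradients accumulate with a consistent normalization
optimizer.step()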