nemo_automodel.components.loss.soft_ce


Loss utilities for EAGLE-3 training.

Module Contents

Functions

| Name | Description |
| --- | --- |
| masked_soft_cross_entropy | Compute masked soft-target cross entropy. |

API

nemo_automodel.components.loss.soft_ce.masked_soft_cross_entropy(
logits: torch.Tensor,
target_probs: torch.Tensor,
position_mask: torch.Tensor
) -> torch.Tensor

Compute masked soft-target cross entropy.
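The formula is not written out in this docstring. Assuming the standard definition of soft-target cross entropy, it reads as follows. Here $z_{b,t}$ are the draft logits, $p_{b,t,v}$ the target probabilities, $m_{b,t}$ the position mask, $B$ the batch size, $T$ the sequence length, and $V$ the draft vocabulary size.

```latex
\mathcal{L} = \frac{1}{\sum_{b,t} m_{b,t}}
  \sum_{b=1}^{B} \sum_{t=1}^{T} m_{b,t}
  \left( -\sum_{v=1}^{V} p_{b,t,v} \, \log \operatorname{softmax}(z_{b,t})_v \right)
```

The reference EAGLE / SpecForge reduction discussed in note 2 keeps the same numerator but divides by $B \cdot T$ instead of $\sum_{b,t} m_{b,t}$.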

Important implementation notes:

  1. Inputs are expected to be aligned already. As in the original EAGLE-3 training flow and the reference SpecForge implementation, target logits and input IDs are shifted during target preparation.
  2. We intentionally do not preserve the loss reduction from the original EAGLE / SpecForge code. The reference implementation averages over batch * seq_len even when only a small subset of positions is valid. That is not a sound masked-loss definition, because the loss scale then changes with the amount of padding and the density of supervised positions. Here we normalize by the number of valid supervised positions, which is the correct masked objective.
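The scale problem in note 2 can be checked numerically. The sketch below pads a 4-position sequence with 12 masked-out positions and compares the two reductions. The helper names are hypothetical. For brevity, it uses a plain [batch, seq_len] mask rather than the [batch, seq_len, 1] shape this function takes.

```python
import torch
import torch.nn.functional as F


def per_position_soft_ce(logits, target_probs):
    # -sum_v p_v * log q_v at every (batch, position) -> [batch, seq_len]
    return -(target_probs * F.log_softmax(logits, dim=-1)).sum(dim=-1)


def mean_over_all_positions(per_pos, mask):
    # Reference-style reduction from note 2: divide by batch * seq_len.
    return (per_pos * mask).sum() / per_pos.numel()


def mean_over_valid_positions(per_pos, mask):
    # Masked reduction: divide by the number of valid supervised positions.
    return (per_pos * mask).sum() / mask.sum()


torch.manual_seed(0)
vocab = 8
logits = torch.randn(1, 4, vocab)
target_probs = torch.softmax(torch.randn(1, 4, vocab), dim=-1)
mask = torch.ones(1, 4)

# Append 12 padded positions; the mask excludes them from supervision.
pad = 12
logits_padded = torch.cat([logits, torch.randn(1, pad, vocab)], dim=1)
target_padded = torch.cat([target_probs, torch.full((1, pad, vocab), 1.0 / vocab)], dim=1)
mask_padded = torch.cat([mask, torch.zeros(1, pad)], dim=1)

per_pos = per_position_soft_ce(logits, target_probs)
per_pos_padded = per_position_soft_ce(logits_padded, target_padded)

for name, reduce in [("batch*seq_len", mean_over_all_positions),
                     ("valid positions", mean_over_valid_positions)]:
    print(name, reduce(per_pos, mask).item(), reduce(per_pos_padded, mask_padded).item())
```

Padding adds no supervised positions. Even so, the batch * seq_len mean drops to 4/16 of its unpadded value, while the valid-position mean does not change.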

Parameters:

logits
torch.Tensor

Draft logits of shape [batch, seq_len, draft_vocab_size].

target_probs
torch.Tensor

Target distributions aligned to the draft vocabulary.

position_mask
torch.Tensor

Boolean or 0/1 mask of shape [batch, seq_len, 1]. Nonzero entries mark the valid supervised positions.
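The following is a minimal PyTorch sketch of a function with this contract. It is not the library's implementation, and the name masked_soft_ce_sketch is hypothetical. It makes two assumptions. First, target_probs has the same [batch, seq_len, draft_vocab_size] shape as logits. Second, the denominator is clamped to at least 1 so that a fully masked batch returns 0 rather than NaN, which is a guess about edge-case handling.

```python
import torch
import torch.nn.functional as F


def masked_soft_ce_sketch(
    logits: torch.Tensor,         # [batch, seq_len, vocab]
    target_probs: torch.Tensor,   # [batch, seq_len, vocab] (assumed same shape as logits)
    position_mask: torch.Tensor,  # [batch, seq_len, 1], bool or 0/1
) -> torch.Tensor:
    # Draft log-probabilities, computed in float32 for numerical stability.
    log_probs = F.log_softmax(logits.float(), dim=-1)
    # Soft cross entropy per position: -sum_v p_v * log q_v -> [batch, seq_len]
    per_pos = -(target_probs.float() * log_probs).sum(dim=-1)
    # Drop the trailing singleton dim so the mask lines up with per_pos.
    mask = position_mask.squeeze(-1).to(per_pos.dtype)
    # Normalize by the count of valid positions (clamp is an assumed guard against 0/0).
    num_valid = mask.sum().clamp(min=1.0)
    return (per_pos * mask).sum() / num_valid
```

With one-hot target_probs, this reduces to ordinary hard-label cross entropy averaged over the valid positions.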

Returns: torch.Tensor

Scalar loss normalized by the number of valid positions.