nemo_automodel.components.loss.soft_ce
Loss utilities for EAGLE-3 training.
Module Contents
Functions
API
Compute masked soft-target cross entropy.
Important implementation notes:
- Target alignment still follows the original EAGLE-3 training flow: target logits and input ids are shifted during target preparation, as in the reference SpecForge implementation.
- We intentionally do not preserve the loss reduction used in the original EAGLE / SpecForge code. The reference implementation averages over batch * seq_len even when only a small subset of positions is valid. That reduction is not a sound masked-loss definition, because the loss scale then changes with the amount of padding and with the density of supervised positions. Here we normalize by the number of valid supervised positions, which is the correct masked objective.
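To see why the reduction matters, the sketch below compares the two normalizations on a single flattened sequence of per-position losses. It is an illustration only: the helper names `mean_over_all_positions` and `mean_over_valid_positions` are hypothetical, and every supervised position is given the same loss (2.0) so that only the supervision density differs between the two cases.

```python
def mean_over_all_positions(losses, mask):
    # Reference-style reduction: masked sum divided by the total number of positions
    # (batch * seq_len in the real tensors; here a single flattened sequence).
    total = sum(loss * m for loss, m in zip(losses, mask))
    return total / len(losses)


def mean_over_valid_positions(losses, mask):
    # Masked-mean reduction: divide by the number of supervised positions only.
    # max(..., 1) is an assumed guard against an all-masked input.
    total = sum(loss * m for loss, m in zip(losses, mask))
    return total / max(sum(mask), 1)


losses = [2.0, 2.0, 2.0, 2.0]
dense_mask = [1, 1, 1, 1]   # every position supervised
sparse_mask = [1, 0, 0, 0]  # only one position supervised

print(mean_over_all_positions(losses, dense_mask))     # 2.0
print(mean_over_all_positions(losses, sparse_mask))    # 0.5
print(mean_over_valid_positions(losses, dense_mask))   # 2.0
print(mean_over_valid_positions(losses, sparse_mask))  # 2.0
```

With the batch * seq_len reduction, the same per-position error reports a loss four times smaller when three quarters of the positions are unsupervised; the valid-count reduction reports 2.0 in both cases.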
Parameters:
logits
Draft logits of shape [batch, seq_len, draft_vocab_size].
target_probs
Target distributions aligned to the draft vocabulary.
position_mask
Boolean or 0/1 mask of shape [batch, seq_len, 1] marking the valid (supervised) positions.
Returns: torch.Tensor
Scalar loss normalized by the number of valid positions.
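The documented contract can be sketched in PyTorch as follows. This is a minimal sketch, not the module's actual code: the function name `masked_soft_cross_entropy` is hypothetical, and it assumes `target_probs` has the same shape as `logits` and that a batch with no valid positions should yield a zero loss rather than a division by zero.

```python
import torch
import torch.nn.functional as F


def masked_soft_cross_entropy(
    logits: torch.Tensor,         # [batch, seq_len, draft_vocab_size]
    target_probs: torch.Tensor,   # assumed same shape as logits
    position_mask: torch.Tensor,  # [batch, seq_len, 1], bool or 0/1
) -> torch.Tensor:
    # Log-probabilities of the draft model, computed in float32 for stability.
    log_probs = F.log_softmax(logits.float(), dim=-1)
    # Soft-target cross entropy per position: -sum_v p(v) * log q(v) -> [batch, seq_len, 1].
    per_position = -(target_probs.float() * log_probs).sum(dim=-1, keepdim=True)
    mask = position_mask.to(per_position.dtype)
    # Normalize by the number of valid positions, not by batch * seq_len.
    # The clamp is an assumed guard for the all-masked case (returns 0 instead of NaN).
    num_valid = torch.clamp(mask.sum(), min=1.0)
    return (per_position * mask).sum() / num_valid
```

For example, uniform logits over 4 classes with uniform targets give a per-position loss of log 4; masking out one of two positions still returns log 4, whereas averaging over batch * seq_len would return half of that.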