nemo_automodel.components.loss.soft_ce
Loss utilities for EAGLE-3 training.
Module Contents
Functions
masked_soft_cross_entropy – Compute masked soft-target cross entropy.
API
nemo_automodel.components.loss.soft_ce.masked_soft_cross_entropy(
    logits: torch.Tensor,
    target_probs: torch.Tensor,
    position_mask: torch.Tensor,
)
Compute masked soft-target cross entropy.
Important implementation notes:
Target alignment still follows the original EAGLE-3 training flow: target logits and input IDs are shifted during target preparation, as in the reference SpecForge implementation, rather than inside this function.
We intentionally do not preserve the original loss reduction from the EAGLE / SpecForge code. The reference implementation averages over batch * seq_len even when only a small subset of positions is valid. That reduction is not a sound masked-loss definition, because the loss scale then changes with the amount of padding and with how densely positions are supervised. Here we normalize by the number of valid supervised positions, which gives the correct masked objective.
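The difference between the two reductions is easiest to see with numbers. The sketch below assumes a made-up batch of 2 sequences of length 4 in which only 2 positions are supervised, each with a per-position loss of 1.0; the tensors are illustrative only and do not come from this module.

```python
import torch

# Hypothetical per-position losses, shape [batch=2, seq_len=4].
per_pos_loss = torch.ones(2, 4)
# Only the first position of each sequence is supervised; the rest is padding.
mask = torch.tensor([[1.0, 0.0, 0.0, 0.0],
                     [1.0, 0.0, 0.0, 0.0]])

masked = per_pos_loss * mask  # sum over all entries is 2.0

# Reference-style reduction: divide by batch * seq_len (= 8).
ref_loss = masked.sum() / mask.numel()
# Masked-mean reduction: divide by the number of valid positions (= 2).
masked_mean = masked.sum() / mask.sum().clamp(min=1)

print(ref_loss.item(), masked_mean.item())  # 0.25 1.0
```

If the same two supervised positions sat in sequences of length 8 instead, the reference value would halve to 0.125 while the masked mean would stay at 1.0, which is the scale dependence the note describes.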
- Parameters:
logits – Draft logits of shape [batch, seq_len, draft_vocab_size].
target_probs – Target distributions aligned to the draft vocabulary.
position_mask – Boolean/0-1 mask of shape [batch, seq_len, 1].
- Returns:
Scalar loss normalized by the number of valid positions.
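For orientation, here is a minimal sketch of a masked soft-target cross entropy that follows the documented shapes and reduction. It is not the module's actual implementation: the name `masked_soft_cross_entropy_sketch`, the assumption that `target_probs` has the same shape as `logits`, the float32 upcast, and the clamp that makes a fully masked batch return 0 are all illustrative choices.

```python
import torch
import torch.nn.functional as F


def masked_soft_cross_entropy_sketch(
    logits: torch.Tensor,         # [batch, seq_len, vocab]
    target_probs: torch.Tensor,   # [batch, seq_len, vocab] (assumed same shape as logits)
    position_mask: torch.Tensor,  # [batch, seq_len, 1], boolean or 0/1
) -> torch.Tensor:
    # Log-probabilities of the draft distribution, computed in float32.
    log_q = F.log_softmax(logits.float(), dim=-1)
    # Per-position soft cross entropy: -sum_v p(v) * log q(v) -> [batch, seq_len].
    per_pos = -(target_probs.float() * log_q).sum(dim=-1)
    # Drop the trailing singleton dim and convert the mask to float weights.
    mask = position_mask.squeeze(-1).to(per_pos.dtype)
    # Average over valid positions only; clamp avoids division by zero.
    return (per_pos * mask).sum() / mask.sum().clamp(min=1.0)


# Usage with random tensors: batch=2, seq_len=3, vocab=5.
torch.manual_seed(0)
logits = torch.randn(2, 3, 5)
target_probs = torch.softmax(torch.randn(2, 3, 5), dim=-1)
position_mask = torch.tensor([[[True], [True], [False]],
                              [[True], [False], [False]]])
loss = masked_soft_cross_entropy_sketch(logits, target_probs, position_mask)
print(loss.shape)  # torch.Size([])
```

On the valid positions this matches `F.cross_entropy` with probability targets and mean reduction, so the masked mean is just the standard soft cross entropy restricted to supervised positions. Upcasting to float32 is a common safeguard for bf16 logits; whether the real function does this, and what it returns when no position is valid, is not stated in the docstring.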