nemo_automodel.components.loss.soft_ce

Loss utilities for EAGLE-3 training.

Module Contents

Functions

masked_soft_cross_entropy

Compute masked soft-target cross entropy.

API

nemo_automodel.components.loss.soft_ce.masked_soft_cross_entropy(
logits: torch.Tensor,
target_probs: torch.Tensor,
position_mask: torch.Tensor,
) -> torch.Tensor

Compute masked soft-target cross entropy.
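For reference, the objective can be written as below. This assumes the standard soft-target cross entropy computed from a log-softmax over the draft logits. The symbols are our own notation, not names from the module: z_{b,t} is the logits row at batch index b and position t, p_{b,t,v} is target_probs, m_{b,t} is position_mask, and V is draft_vocab_size. The denominator is the number of valid positions, matching the masked reduction described in the notes.

```latex
\mathcal{L}
  = \frac{\sum_{b,t} m_{b,t}\,\ell_{b,t}}{\sum_{b,t} m_{b,t}},
\qquad
\ell_{b,t} = -\sum_{v=1}^{V} p_{b,t,v}\,\log \operatorname{softmax}(z_{b,t})_v
```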

Important implementation notes:

  1. Target alignment still follows the original EAGLE-3 training flow: target logits and input IDs are shifted during target preparation, as in the reference SpecForge implementation, not inside this loss function.

  2. We intentionally do not preserve the original loss reduction from the EAGLE / SpecForge code. The reference implementation averages over batch * seq_len even when only a small subset of positions is valid. That reduction is not a sound masked-loss definition, because the loss scale then changes with the amount of padding and with how densely positions are supervised. Here we normalize by the number of valid supervised positions instead, which is the correct masked objective.
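The difference between the two reductions shows up on a toy example. The per-position loss values below are made up for illustration (every supervised position has loss 2.0), and the helper `reductions` is hypothetical, not part of this module. Padding each sequence with four extra masked-out positions halves the batch * seq_len average, while the valid-count average stays at the true per-position loss.

```python
import torch
import torch.nn.functional as F

# Made-up per-position losses for 2 sequences of length 4.
per_pos = torch.tensor([[2.0, 2.0, 0.0, 0.0],
                        [2.0, 0.0, 0.0, 0.0]])
# 1.0 marks a supervised position, 0.0 marks padding.
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0],
                     [1.0, 0.0, 0.0, 0.0]])


def reductions(per_pos, mask):
    """Hypothetical helper: returns (batch*seq_len average, valid-count average)."""
    masked_sum = (per_pos * mask).sum()
    return (masked_sum / mask.numel()).item(), (masked_sum / mask.sum()).item()


short = reductions(per_pos, mask)
# Append 4 padded (masked-out) positions to every sequence.
padded = reductions(F.pad(per_pos, (0, 4)), F.pad(mask, (0, 4)))
print(short)   # (0.75, 2.0)
print(padded)  # (0.375, 2.0)
```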

Parameters:
  • logits – Draft logits of shape [batch, seq_len, draft_vocab_size].

  • target_probs – Target probability distributions aligned to the draft vocabulary, i.e. with last dimension draft_vocab_size.

  • position_mask – Boolean or 0/1 mask of shape [batch, seq_len, 1] marking the valid (supervised) positions.

Returns:

Scalar loss normalized by the number of valid positions.
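A minimal sketch of how this contract could be implemented in PyTorch, assuming the standard per-position soft cross entropy, -sum_v p(v) * log softmax(logits)_v. It uses a hypothetical name, `masked_soft_cross_entropy_sketch`, and is an illustration, not the library's source. Clamping the denominator for an all-padding batch is also an assumption; the documented function may handle that case differently.

```python
import torch
import torch.nn.functional as F


def masked_soft_cross_entropy_sketch(
    logits: torch.Tensor,         # [batch, seq_len, draft_vocab_size]
    target_probs: torch.Tensor,   # [batch, seq_len, draft_vocab_size]
    position_mask: torch.Tensor,  # [batch, seq_len, 1], bool or 0/1
) -> torch.Tensor:
    # Hypothetical sketch, not the nemo_automodel implementation.
    # Draft log-probabilities, computed in float32 for numerical stability.
    log_probs = F.log_softmax(logits.float(), dim=-1)
    # Soft cross entropy per position: -sum_v p(v) * log q(v) -> [batch, seq_len]
    per_pos = -(target_probs.float() * log_probs).sum(dim=-1)
    # Drop the trailing singleton dim of the mask and make it numeric.
    mask = position_mask.squeeze(-1).to(per_pos.dtype)
    # Normalize by the number of valid positions, not batch * seq_len.
    # Assumption: clamp to 1 so an all-padding batch returns 0 instead of NaN.
    num_valid = mask.sum().clamp(min=1.0)
    return (per_pos * mask).sum() / num_valid


# Usage: 2 sequences of length 5 over an 8-token draft vocabulary.
torch.manual_seed(0)
logits = torch.randn(2, 5, 8)
target_probs = torch.softmax(torch.randn(2, 5, 8), dim=-1)
position_mask = torch.zeros(2, 5, 1, dtype=torch.bool)
position_mask[0, :3] = True  # 3 supervised positions in the first sequence
position_mask[1, :1] = True  # 1 supervised position in the second
loss = masked_soft_cross_entropy_sketch(logits, target_probs, position_mask)
print(loss.shape)  # torch.Size([])
```

With one-hot target_probs, the result reduces to ordinary hard-label cross entropy averaged over the masked-in positions, which makes a convenient sanity check for any implementation of this objective.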