nemo_automodel.components.loss.kd_loss#

Module Contents#

Classes#

KDLoss

Forward KL divergence loss for knowledge distillation.

Functions#

_infer_tp_group_from_dtensor

If logits is a DTensor sharded on the vocab (last) dimension, return its TP process group.

_kl_forward_tp

Compute per-token negative cross-entropy sum(P * log Q) with tensor parallelism.

API#

nemo_automodel.components.loss.kd_loss._infer_tp_group_from_dtensor(
logits: torch.Tensor,
) → Optional[torch.distributed.ProcessGroup]#

If logits is a DTensor sharded on the vocab (last) dimension, return its TP process group.

Iterates over the DTensor placements to find the mesh dimension that holds a vocab-dim Shard and returns the corresponding process group. Returns None for plain tensors or DTensors that are not vocab-sharded.

nemo_automodel.components.loss.kd_loss._kl_forward_tp(
t_logits: torch.Tensor,
s_logits: torch.Tensor,
tp_group: torch.distributed.ProcessGroup,
) → torch.Tensor#

Compute per-token negative cross-entropy sum(P * log Q) with tensor parallelism.

Both t_logits and s_logits are local vocab-sharded tensors of shape [valid_tokens, local_vocab_size]. A numerically stable global softmax / log-softmax is computed via all_reduce over tp_group, avoiding the need to gather the full vocab.

Parameters:
  • t_logits – Local teacher logit shard, shape [valid_tokens, local_vocab_size].

  • s_logits – Local student logit shard, shape [valid_tokens, local_vocab_size].

  • tp_group – Process group spanning the tensor-parallel ranks.

Returns:

Per-token sum(P * log Q), shape [valid_tokens]. This is the negative cross-entropy term of the KL divergence; the caller negates it, adds the teacher-entropy term if needed, and averages to obtain the final loss.
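A single-rank reference for the quantity this function computes may help. The sketch below is hypothetical (the name `kl_forward_reference` is invented here) and omits the tensor-parallel machinery: in the real TP path, the global max and the sum of exponentials needed for a stable softmax are obtained with `all_reduce(MAX)` and `all_reduce(SUM)` over `tp_group` instead of being computed locally over the full vocabulary.

```python
import torch
import torch.nn.functional as F


def kl_forward_reference(
    t_logits: torch.Tensor, s_logits: torch.Tensor
) -> torch.Tensor:
    """Non-TP reference for the per-token quantity sum_v P(v) * log Q(v).

    In the TP version, the max-subtraction and exp-sum inside softmax /
    log_softmax each become an all_reduce over the vocab-sharded ranks.
    """
    p = F.softmax(t_logits.float(), dim=-1)          # teacher probabilities P
    log_q = F.log_softmax(s_logits.float(), dim=-1)  # student log-probs log Q
    return (p * log_q).sum(dim=-1)  # shape [valid_tokens]; always <= 0
```

Because `log Q <= 0` everywhere and `P >= 0`, each per-token value is non-positive, consistent with its role as a negative cross-entropy.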

class nemo_automodel.components.loss.kd_loss.KDLoss(
ignore_index: int = -100,
temperature: float = 1.0,
fp32_upcast: bool = True,
tp_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Bases: torch.nn.Module

Forward KL divergence loss for knowledge distillation.

Computes KL(P_teacher ‖ P_student) averaged over valid (non-padding) tokens.

Supports tensor-parallel (TP) training: when logits are vocab-sharded DTensors, the TP group is inferred automatically and a distributed softmax is used to avoid gathering the full vocabulary on each rank. A tp_group can also be supplied explicitly.

Parameters:
  • ignore_index – Label value marking padding tokens (default -100).

  • temperature – Softmax temperature T. Both teacher and student logits are divided by T before computing probabilities. The loss is then multiplied by T² so that gradient magnitudes remain independent of the chosen temperature (Hinton et al., 2015).

  • fp32_upcast – Cast logits to float32 before computing softmax / log-softmax for numerical stability (default True).

  • tp_group – Explicit TP process group. When None (default), the group is inferred from the DTensor placement of student_logits, or the non-TP path is used for plain tensors.

Initialization

forward(
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: torch.Tensor,
num_batch_labels: int | None = None,
) → torch.Tensor#

Compute the KD loss.

Parameters:
  • student_logits – Shape [*, vocab_size] or [*, local_vocab_size] for TP.

  • teacher_logits – Same shape as student_logits.

  • labels – Shape [*]. Positions equal to ignore_index are excluded from the loss.

  • num_batch_labels – Total number of valid tokens across all gradient-accumulation steps. When provided, the loss is sum(kl_per_token) / num_batch_labels; otherwise it is mean(kl_per_token) over the valid tokens in this micro-batch.

Returns:

Scalar KD loss.