nemo_automodel.components.loss.kd_loss#
Module Contents#
Classes#
`KDLoss` – Forward KL divergence loss for knowledge distillation.
Functions#
`_infer_tp_group_from_dtensor` – If logits is a DTensor sharded on the vocab (last) dimension, return its TP process group.

`_kl_forward_tp` – Compute per-token negative cross-entropy `sum(P * log Q)` with tensor parallelism.
API#
- nemo_automodel.components.loss.kd_loss._infer_tp_group_from_dtensor(
- logits: torch.Tensor,
If logits is a DTensor sharded on the vocab (last) dimension, return its TP process group.
Iterates over the DTensor placements to find the mesh dimension that holds a vocab-dim `Shard` and returns the corresponding process group. Returns `None` for plain tensors or for DTensors that are not vocab-sharded.
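The placement scan described above can be sketched as follows. This is an illustration of the logic, not the library's implementation: the `Shard`, `Replicate`, and `infer_tp_group` names here are minimal stand-ins for `torch.distributed.tensor`'s placement types and for the real helper, so the scan can be shown without an initialized distributed environment.

```python
# Stand-in placement types mirroring torch.distributed.tensor's Shard/Replicate.
class Shard:
    def __init__(self, dim):
        self.dim = dim  # tensor dimension this mesh dim shards

class Replicate:
    pass

def infer_tp_group(placements, ndim, get_group):
    """Return get_group(mesh_dim) for the mesh dim that shards the last
    (vocab) tensor dimension, or None if no such placement exists."""
    vocab_dim = ndim - 1
    for mesh_dim, p in enumerate(placements):
        if isinstance(p, Shard) and p.dim == vocab_dim:
            return get_group(mesh_dim)
    return None  # plain tensor path / not vocab-sharded

# 2-D mesh: dim 0 replicated (e.g. data parallel), dim 1 shards the vocab dim.
groups = {0: "dp_group", 1: "tp_group"}
print(infer_tp_group([Replicate(), Shard(1)], ndim=2, get_group=groups.get))
```

In the real helper, `placements` and `get_group` come from the DTensor itself (`logits.placements` and `logits.device_mesh.get_group(mesh_dim)`).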
- nemo_automodel.components.loss.kd_loss._kl_forward_tp(
- t_logits: torch.Tensor,
- s_logits: torch.Tensor,
- tp_group: torch.distributed.ProcessGroup,
Compute per-token negative cross-entropy `sum(P * log Q)` with tensor parallelism.

Both `t_logits` and `s_logits` are local vocab-sharded tensors of shape `[valid_tokens, local_vocab_size]`. A numerically stable global softmax / log-softmax is computed via `all_reduce` over `tp_group`, avoiding the need to gather the full vocab.

- Parameters:
t_logits – Local teacher logit shard, shape `[valid_tokens, local_vocab_size]`.

s_logits – Local student logit shard, shape `[valid_tokens, local_vocab_size]`.

tp_group – Process group spanning the tensor-parallel ranks.
- Returns:
Per-token `sum(P * log Q)`, shape `[valid_tokens]`. This is the negative KL term; negate and average in the caller to obtain the final loss.
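The quantity this function computes can be illustrated single-process in pure Python. The sketch below shows the numerically stable form of `sum(P * log Q)` for one token: subtract the max before exponentiating, then normalize. In the TP version the max and the two normalizers are combined across ranks with `all_reduce` over `tp_group`; here the full vocab is local, so no communication is needed.

```python
import math

def neg_cross_entropy(t_logits, s_logits):
    """sum_v P_v * log Q_v for one token, numerically stable.
    P = softmax(t_logits) (teacher), Q = softmax(s_logits) (student)."""
    m_t, m_s = max(t_logits), max(s_logits)
    z_t = sum(math.exp(x - m_t) for x in t_logits)   # teacher normalizer
    z_s = sum(math.exp(x - m_s) for x in s_logits)   # student normalizer
    p = [math.exp(x - m_t) / z_t for x in t_logits]          # teacher softmax
    log_q = [(x - m_s) - math.log(z_s) for x in s_logits]    # student log-softmax
    return sum(pi * lqi for pi, lqi in zip(p, log_q))

t = [2.0, 0.5, -1.0, 0.0]
s = [1.5, 0.7, -0.5, 0.2]
print(neg_cross_entropy(t, s))
```

Since `log Q <= 0` everywhere, the result is always non-positive, which is why the caller negates it to obtain a loss.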
- class nemo_automodel.components.loss.kd_loss.KDLoss(
- ignore_index: int = -100,
- temperature: float = 1.0,
- fp32_upcast: bool = True,
- tp_group: Optional[torch.distributed.ProcessGroup] = None,
Bases: `torch.nn.Module`

Forward KL divergence loss for knowledge distillation.

Computes `KL(P_teacher ‖ P_student)` averaged over valid (non-padding) tokens.

Supports tensor-parallel (TP) training: when logits are vocab-sharded DTensors, the TP group is inferred automatically and a distributed softmax is used to avoid gathering the full vocabulary on each rank. A `tp_group` can also be supplied explicitly.

- Parameters:
ignore_index – Label value marking padding tokens (default `-100`).

temperature – Softmax temperature T. Both teacher and student logits are divided by T before computing probabilities. The loss is then multiplied by T² so that gradient magnitudes remain independent of the chosen temperature (Hinton et al., 2015).

fp32_upcast – Cast logits to float32 before computing softmax / log-softmax for numerical stability (default `True`).

tp_group – Explicit TP process group. When `None` (default) the group is inferred from the DTensor placement of `student_logits`, or the non-TP path is used for plain tensors.
Initialization
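The temperature and T² scaling described above can be shown for a single token in pure Python. This is a sketch of the documented semantics, not the module's code: the real `KDLoss` operates on batched torch tensors, while `kd_loss_one_token` here is an illustrative helper.

```python
import math

def softmax(logits, T=1.0):
    """Stable softmax of logits / T."""
    scaled = [x / T for x in logits]
    m = max(scaled)
    z = sum(math.exp(x - m) for x in scaled)
    return [math.exp(x - m) / z for x in scaled]

def kd_loss_one_token(t_logits, s_logits, T=1.0):
    """Forward KL(P_teacher || P_student) at temperature T, scaled by T^2."""
    p = softmax(t_logits, T)   # teacher distribution
    q = softmax(s_logits, T)   # student distribution
    kl = sum(pi * (math.log(pi) - math.log(qi)) for pi, qi in zip(p, q))
    return kl * T * T          # T^2 keeps gradient scale independent of T

print(kd_loss_one_token([2.0, 0.5, -1.0], [1.5, 0.7, -0.5], T=2.0))
```

Note the loss is zero when student and teacher logits agree, and non-negative otherwise, as expected of a KL divergence.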
- forward(
- student_logits: torch.Tensor,
- teacher_logits: torch.Tensor,
- labels: torch.Tensor,
- num_batch_labels: int | None = None,
Compute the KD loss.
- Parameters:
student_logits – Shape `[*, vocab_size]`, or `[*, local_vocab_size]` for TP.

teacher_logits – Same shape as `student_logits`.

labels – Shape `[*]`. Positions equal to `ignore_index` are excluded from the loss.

num_batch_labels – Total number of valid tokens across all gradient-accumulation steps. When provided, the loss is `sum(kl_per_token) / num_batch_labels`; otherwise it is `mean(kl_per_token)` over the valid tokens in this micro-batch.
- Returns:
Scalar KD loss.
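The masking and two reduction modes of `forward()` can be sketched in pure Python. This is an illustration of the documented semantics, not the module's code: `kl_per_token` is a plain list standing in for the per-token KL tensor, and `reduce_kd` is a hypothetical helper name.

```python
IGNORE_INDEX = -100  # default ignore_index documented above

def reduce_kd(kl_per_token, labels, num_batch_labels=None):
    """Mask out ignore_index positions, then reduce per the docs:
    sum/num_batch_labels when given (gradient accumulation), else mean."""
    valid = [kl for kl, y in zip(kl_per_token, labels) if y != IGNORE_INDEX]
    if num_batch_labels is not None:
        # Normalizer is shared across accumulation steps, so summing the
        # per-step losses reproduces the full-batch mean.
        return sum(valid) / num_batch_labels
    return sum(valid) / len(valid)  # mean over this micro-batch only

kl = [0.2, 0.4, 0.6, 0.8]
labels = [5, -100, 7, -100]
print(reduce_kd(kl, labels))                      # micro-batch mean
print(reduce_kd(kl, labels, num_batch_labels=8))  # accumulation-aware sum
```

Passing `num_batch_labels` matters when micro-batches have different numbers of valid tokens; per-micro-batch means would otherwise weight tokens unevenly.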