nemo_automodel.components.loss.kd_loss
Module Contents
Classes
KDLoss | Forward KL divergence loss for knowledge distillation.
Functions
_infer_tp_group_from_dtensor | If logits is a DTensor sharded on the vocab (last) dimension, return its TP process group.
_kl_forward_tp | Compute per-token negative cross-entropy sum(P * log Q) with tensor parallelism.
_kl_forward_chunked | Compute per-token sum(P * log Q) in chunks to reduce peak memory.
API
- nemo_automodel.components.loss.kd_loss._infer_tp_group_from_dtensor(logits: torch.Tensor)
If logits is a DTensor sharded on the vocab (last) dimension, return its TP process group.
Iterates over the DTensor placements to find the mesh dimension that holds a vocab-dim Shard and returns the corresponding process group. Returns None for plain tensors or DTensors that are not vocab-sharded.
- nemo_automodel.components.loss.kd_loss._kl_forward_tp(t_logits: torch.Tensor, s_logits: torch.Tensor, tp_group: torch.distributed.ProcessGroup)
Compute per-token negative cross-entropy sum(P * log Q) with tensor parallelism.
Both t_logits and s_logits are local vocab-sharded tensors of shape [valid_tokens, local_vocab_size]. A numerically stable global softmax / log-softmax is computed via all_reduce over tp_group, avoiding the need to gather the full vocabulary.
- Parameters:
t_logits: Local teacher logit shard, shape [valid_tokens, local_vocab_size].
s_logits: Local student logit shard, shape [valid_tokens, local_vocab_size].
tp_group: Process group spanning the tensor-parallel ranks.
- Returns:
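Per-token sum(P * log Q), shape [valid_tokens]. This is the negative cross-entropy between teacher and student; it differs from -KL(P ‖ Q) only by the teacher entropy, which does not depend on the student. Negate and average in the caller to obtain the final loss.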
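The all_reduce pattern described above can be checked on a single process. The following is a NumPy sketch, not the library's torch code: the vocabulary is split into hypothetical per-rank shards, and each max or sum across the per-rank list stands in for an all_reduce (MAX or SUM) over tp_group. The names kl_forward_tp_simulated and full_vocab_reference are illustrative only.

```python
import numpy as np


def full_vocab_reference(t_logits, s_logits):
    """Per-token sum(P * log Q) computed on the full, unsharded vocabulary."""
    t = t_logits - t_logits.max(axis=-1, keepdims=True)
    s = s_logits - s_logits.max(axis=-1, keepdims=True)
    p = np.exp(t) / np.exp(t).sum(axis=-1, keepdims=True)
    log_q = s - np.log(np.exp(s).sum(axis=-1, keepdims=True))
    return (p * log_q).sum(axis=-1)


def kl_forward_tp_simulated(t_shards, s_shards):
    """Per-token sum(P * log Q) from vocab shards, one [tokens, local_vocab] array per rank."""
    # Stands in for all_reduce(MAX): global per-token maxima for a stable softmax.
    t_max = np.max([t.max(axis=-1) for t in t_shards], axis=0)
    s_max = np.max([s.max(axis=-1) for s in s_shards], axis=0)

    # Stands in for all_reduce(SUM): global softmax denominators from local exp-sums.
    t_den = np.sum([np.exp(t - t_max[:, None]).sum(axis=-1) for t in t_shards], axis=0)
    s_log_den = np.log(
        np.sum([np.exp(s - s_max[:, None]).sum(axis=-1) for s in s_shards], axis=0)
    )

    # Each rank sums over its own vocab slice; a final all_reduce(SUM) combines them.
    partials = []
    for t, s in zip(t_shards, s_shards):
        p = np.exp(t - t_max[:, None]) / t_den[:, None]  # teacher probs on this shard
        log_q = s - s_max[:, None] - s_log_den[:, None]  # student log-probs on this shard
        partials.append((p * log_q).sum(axis=-1))
    return np.sum(partials, axis=0)  # shape [valid_tokens]


rng = np.random.default_rng(0)
t_logits = rng.normal(size=(4, 12))
s_logits = rng.normal(size=(4, 12))
# Pretend the TP size is 3: each rank holds 4 of the 12 vocab entries.
sharded = kl_forward_tp_simulated(np.split(t_logits, 3, axis=-1), np.split(s_logits, 3, axis=-1))
print(np.allclose(sharded, full_vocab_reference(t_logits, s_logits)))  # True
```

Only per-token scalars ([valid_tokens]-sized vectors) cross the simulated rank boundary, which is why this approach avoids gathering the full vocabulary.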
- nemo_automodel.components.loss.kd_loss._kl_forward_chunked(t_logits: torch.Tensor, s_logits: torch.Tensor, chunk_size: int)
Compute per-token sum(P * log Q) in chunks to reduce peak memory.
Processes chunk_size tokens at a time so that only one chunk's worth of the [chunk_size, vocab_size] fp32 probability matrix is live at any moment.
- Parameters:
t_logits: Teacher logits, shape [num_valid_tokens, vocab_size].
s_logits: Student logits, shape [num_valid_tokens, vocab_size].
chunk_size: Number of tokens per chunk.
- Returns:
Per-token sum(P * log Q), shape [num_valid_tokens].
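A minimal NumPy sketch of the chunking idea, assuming the inputs are already restricted to valid tokens. kl_forward_chunked_sketch and log_softmax are hypothetical helpers, and the real function operates on torch tensors. Because every operation is row-wise, the chunked result should match the unchunked one up to floating-point rounding.

```python
import numpy as np


def log_softmax(x):
    """Row-wise numerically stable log-softmax."""
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))


def kl_forward_chunked_sketch(t_logits, s_logits, chunk_size):
    """Per-token sum(P * log Q), upcasting only chunk_size rows to fp32 at a time."""
    num_tokens = t_logits.shape[0]
    out = np.empty(num_tokens, dtype=np.float32)
    for start in range(0, num_tokens, chunk_size):
        stop = min(start + chunk_size, num_tokens)
        # Only this [chunk, vocab] slice is materialized in fp32.
        t = t_logits[start:stop].astype(np.float32)
        s = s_logits[start:stop].astype(np.float32)
        out[start:stop] = (np.exp(log_softmax(t)) * log_softmax(s)).sum(axis=-1)
    return out


rng = np.random.default_rng(0)
# Low-precision inputs, as in mixed-precision training (NumPy has no bf16, so fp16 is used).
t_logits = rng.normal(size=(10, 32)).astype(np.float16)
s_logits = rng.normal(size=(10, 32)).astype(np.float16)
chunked = kl_forward_chunked_sketch(t_logits, s_logits, chunk_size=3)
unchunked = kl_forward_chunked_sketch(t_logits, s_logits, chunk_size=10)
print(np.allclose(chunked, unchunked))  # True
```

The trade-off is the one stated above: peak fp32 memory scales with chunk_size rather than num_valid_tokens, at the cost of one loop iteration (and its kernel launches, in the torch version) per chunk.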
- class nemo_automodel.components.loss.kd_loss.KDLoss(ignore_index: int = -100, temperature: float = 1.0, fp32_upcast: bool = True, tp_group: Optional[torch.distributed.ProcessGroup] = None, chunk_size: int = 0)
Bases: torch.nn.Module
Forward KL divergence loss for knowledge distillation.
Computes KL(P_teacher ‖ P_student) averaged over valid (non-padding) tokens.
Supports tensor-parallel (TP) training: when logits are vocab-sharded DTensors, the TP group is inferred automatically and a distributed softmax is used to avoid gathering the full vocabulary on each rank. A tp_group can also be supplied explicitly.
- Parameters:
ignore_index: Label value marking padding tokens (default -100).
temperature: Softmax temperature T. Both teacher and student logits are divided by T before computing probabilities. The loss is then multiplied by T² so that gradient magnitudes stay roughly independent of the chosen temperature (Hinton et al., 2015).
fp32_upcast: Cast logits to float32 before computing softmax / log-softmax for numerical stability (default True).
tp_group: Explicit TP process group. When None (default), the group is inferred from the DTensor placement of student_logits, or the non-TP path is used for plain tensors.
chunk_size: When positive, valid tokens are processed in chunks of this size to avoid materializing the full [num_valid_tokens, vocab_size] probability matrix in fp32. This reduces peak memory at the cost of slightly more kernel launches. 0 (default) disables chunking. Ignored when using the TP path.
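Combining the temperature and reduction rules, the objective can be sketched as follows. This assumes, as the helper docstrings describe, that the forward pass negates the per-token sum(P * log Q); here z^t_i and z^s_i are the teacher and student logits of valid token i, \mathcal{V} is the set of valid tokens, and N is |\mathcal{V}|, or num_batch_labels when that argument is given.

$$
p_i = \mathrm{softmax}\!\left(z^t_i / T\right), \qquad q_i = \mathrm{softmax}\!\left(z^s_i / T\right)
$$

$$
\mathcal{L} = -\frac{T^2}{N} \sum_{i \in \mathcal{V}} \sum_{v} p_{i,v} \log q_{i,v}
= \frac{T^2}{N} \sum_{i \in \mathcal{V}} \Big[ \mathrm{KL}(p_i \,\|\, q_i) + H(p_i) \Big]
$$

Since the teacher entropy H(p_i) = -\sum_v p_{i,v} \log p_{i,v} does not depend on the student, minimizing this cross-entropy form yields the same student gradients as minimizing the KL divergence itself.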
- forward(student_logits: torch.Tensor, teacher_logits: torch.Tensor, labels: torch.Tensor, num_batch_labels: int | None = None)
Compute the KD loss.
- Parameters:
student_logits: Shape [*, vocab_size], or [*, local_vocab_size] for TP.
teacher_logits: Same shape as student_logits.
labels: Shape [*]. Positions equal to ignore_index are excluded from the loss.
num_batch_labels: Total number of valid tokens across all gradient-accumulation steps. When provided, the loss is sum(kl_per_token) / num_batch_labels; otherwise it is mean(kl_per_token) over the valid tokens in this micro-batch.
- Returns:
Scalar KD loss.
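The reduction rules for forward can be sketched in NumPy as below. This is an illustration under stated assumptions, not the library's implementation: kd_loss_sketch is a hypothetical name, the TP and chunked paths are omitted, and the per-token term is taken to be -sum(P * log Q), as the helper docstrings describe, then rescaled by T² as stated for the temperature parameter.

```python
import numpy as np


def log_softmax(x):
    """Row-wise numerically stable log-softmax."""
    x = x - x.max(axis=-1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=-1, keepdims=True))


def kd_loss_sketch(student_logits, teacher_logits, labels,
                   num_batch_labels=None, ignore_index=-100, temperature=1.0):
    """Hypothetical NumPy mirror of the documented forward() reduction (non-TP, unchunked)."""
    vocab = student_logits.shape[-1]
    # Flatten [*, vocab] -> [tokens, vocab] and drop positions labelled ignore_index.
    valid = labels.reshape(-1) != ignore_index
    s = student_logits.reshape(-1, vocab)[valid].astype(np.float32) / temperature
    t = teacher_logits.reshape(-1, vocab)[valid].astype(np.float32) / temperature
    # Negated per-token sum(P * log Q), i.e. the teacher-student cross-entropy.
    per_token = -(np.exp(log_softmax(t)) * log_softmax(s)).sum(axis=-1)
    if num_batch_labels is not None:
        loss = per_token.sum() / num_batch_labels  # normalize by the global token count
    else:
        loss = per_token.mean()  # NaN if this micro-batch has no valid tokens
    return loss * temperature ** 2  # T^2 rescaling (Hinton et al., 2015)


rng = np.random.default_rng(0)
student = rng.normal(size=(2, 3, 8))
teacher = rng.normal(size=(2, 3, 8))
labels = np.array([[5, 1, -100], [7, -100, 2]])  # 4 valid tokens
local = kd_loss_sketch(student, teacher, labels, temperature=2.0)
# With gradient accumulation over 8 valid tokens in total, this micro-batch
# contributes sum / 8, i.e. half of its local mean.
accum = kd_loss_sketch(student, teacher, labels, num_batch_labels=8, temperature=2.0)
print(np.isclose(accum, local / 2))  # True
```

Dividing by num_batch_labels rather than the local token count gives every valid token the same weight across micro-batches of different sizes, so the accumulated gradient matches what a single large batch would produce.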