nemo_automodel.components.loss.kd_loss#

Module Contents#

Classes#

API#

class nemo_automodel.components.loss.kd_loss.KDLoss(
ignore_index: int = -100,
temperature: float = 1.0,
fp32_upcast: bool = True,
)#

Bases: torch.nn.Module

Initialization

forward(
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: torch.Tensor,
num_batch_labels: int | None = None,
) → torch.Tensor#

Calculates KL(P_teacher‖P_student) averaged over valid tokens.

Logits are (optionally) cast to fp32 for numerical stability, probabilities are obtained with softmax / log_softmax after temperature scaling, and padding tokens (== ignore_index) are ignored in the average.
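
For illustration, the computation is roughly equivalent to the minimal PyTorch sketch below. This is not the library's implementation; the function name and the "mean over valid tokens" fallback are assumptions based on the description above.

import torch
import torch.nn.functional as F

def kd_kl_sketch(student_logits, teacher_logits, labels,
                 temperature=1.0, ignore_index=-100, fp32_upcast=True,
                 num_batch_labels=None):
    # Optionally upcast to fp32 for numerical stability.
    if fp32_upcast:
        student_logits = student_logits.float()
        teacher_logits = teacher_logits.float()

    # Temperature-scaled distributions.
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    log_p_teacher = F.log_softmax(teacher_logits / temperature, dim=-1)
    p_teacher = log_p_teacher.exp()

    # Per-token KL(P_teacher || P_student) = sum_v p_t * (log p_t - log p_s).
    kl_per_token = (p_teacher * (log_p_teacher - log_p_student)).sum(dim=-1)

    # Ignore padding positions (labels == ignore_index).
    valid = (labels != ignore_index)
    kl_per_token = kl_per_token * valid

    if num_batch_labels is None:
        return kl_per_token.sum() / valid.sum().clamp(min=1)
    return kl_per_token.sum() / num_batch_labels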

Parameters:
  • student_logits (torch.Tensor) – The logits of the student model.

  • teacher_logits (torch.Tensor) – The logits of the teacher model.

  • labels (torch.Tensor) – The labels of the batch.

  • num_batch_labels (int | None) – The number of valid labels in the batch.

Important note on num_batch_labels:

  • If num_batch_labels is None, the loss is the mean of kl_per_token over valid tokens.

  • If num_batch_labels is not None, the loss is sum(kl_per_token) / num_batch_labels. Note that num_batch_labels is usually larger than the number of valid labels in the labels tensor, for example when using gradient accumulation.

We prefer passing num_batch_labels over counting the valid labels in each batch because it simplifies gradient accumulation and per-token loss computation.
Returns:

The scalar KL divergence loss.
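
A hypothetical end-to-end usage sketch with gradient accumulation. Only the KDLoss constructor and forward signature are taken from this module; student_model, teacher_model, and micro_batches are illustrative placeholders.

import torch
from nemo_automodel.components.loss.kd_loss import KDLoss

kd_loss_fn = KDLoss(ignore_index=-100, temperature=2.0, fp32_upcast=True)

# Count valid labels across all micro-batches so every micro-batch loss
# is normalized by the same global denominator.
num_batch_labels = sum(int((mb["labels"] != -100).sum()) for mb in micro_batches)

for mb in micro_batches:
    student_logits = student_model(mb["input_ids"]).logits
    with torch.no_grad():
        teacher_logits = teacher_model(mb["input_ids"]).logits

    loss = kd_loss_fn(
        student_logits=student_logits,
        teacher_logits=teacher_logits,
        labels=mb["labels"],
        num_batch_labels=num_batch_labels,
    )
    loss.backward()  # gradients accumulate across micro-batches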