nemo_automodel.components.loss.kd_loss#
Module Contents#
Classes#
API#
- class nemo_automodel.components.loss.kd_loss.KDLoss(ignore_index: int = -100, temperature: float = 1.0, fp32_upcast: bool = True)
Bases: torch.nn.Module
Initialization
- forward(student_logits: torch.Tensor, teacher_logits: torch.Tensor, labels: torch.Tensor, num_batch_labels: int | None = None)
Calculates KL(P_teacher‖P_student) averaged over valid tokens.
Logits are optionally cast to fp32 for numerical stability, probabilities are obtained with softmax / log_softmax after temperature scaling, and padding tokens (positions where labels == ignore_index) are excluded from the average.
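For reference, a minimal sketch of this computation, assuming standard (batch, seq, vocab) logits. This is not the library's exact implementation, and the helper name kd_kl_per_token is hypothetical:

```python
import torch
import torch.nn.functional as F

def kd_kl_per_token(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor,
    temperature: float = 1.0,
    fp32_upcast: bool = True,
    ignore_index: int = -100,
):
    # Optionally upcast to fp32 for numerical stability.
    if fp32_upcast:
        student_logits = student_logits.float()
        teacher_logits = teacher_logits.float()
    # Temperature-scaled distributions.
    teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = teacher_log_probs.exp()
    # KL(P_teacher || P_student) per token:
    # sum_v p_t(v) * (log p_t(v) - log p_s(v)).
    kl_per_token = (teacher_probs * (teacher_log_probs - student_log_probs)).sum(dim=-1)
    # Mask marking non-padding positions (labels != ignore_index).
    valid = labels != ignore_index
    return kl_per_token, valid
```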
- Parameters:
student_logits (torch.Tensor) – The logits of the student model.
teacher_logits (torch.Tensor) – The logits of the teacher model.
labels (torch.Tensor) – The labels of the batch.
num_batch_labels (int | None) – The number of valid labels in the batch.
Important note on num_batch_labels:
- If num_batch_labels is None, the method returns the mean over kl_per_token at valid positions.
- If num_batch_labels is not None, the method returns sum(kl_per_token) / num_batch_labels. Note that num_batch_labels is usually greater than the number of valid labels in the labels tensor, for example when doing gradient accumulation.
We prefer the num_batch_labels argument over counting the valid labels in the batch because it simplifies gradient accumulation and per-token loss computation, as illustrated in the sketch below.
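To make the two reductions concrete, a small self-contained sketch with hypothetical per-token KL values:

```python
import torch

# Illustrative per-token KL values and a validity mask (hypothetical data).
kl_per_token = torch.tensor([0.5, 0.2, 0.0, 0.3])
valid = torch.tensor([True, True, False, True])  # third token is padding

# num_batch_labels is None: mean over kl_per_token at valid positions.
loss_mean = kl_per_token[valid].mean()  # (0.5 + 0.2 + 0.3) / 3

# num_batch_labels given: sum(kl_per_token) / num_batch_labels. With gradient
# accumulation, num_batch_labels typically counts valid labels across all
# micro-batches, so it can exceed the 3 valid labels seen in this micro-batch.
num_batch_labels = 12
loss_scaled = kl_per_token[valid].sum() / num_batch_labels
```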
- Returns:
The KL loss.
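A brief usage sketch, following the module path and signature documented above; the tensor shapes (batch of 2 sequences, 8 tokens, vocabulary of 32) are assumptions for illustration:

```python
import torch
from nemo_automodel.components.loss.kd_loss import KDLoss

# Hypothetical shapes: (batch=2, seq=8, vocab=32).
student_logits = torch.randn(2, 8, 32)
teacher_logits = torch.randn(2, 8, 32)
labels = torch.randint(0, 32, (2, 8))
labels[:, -2:] = -100  # mark trailing positions as padding (ignore_index)

kd = KDLoss(ignore_index=-100, temperature=2.0, fp32_upcast=True)
loss = kd(student_logits, teacher_logits, labels)  # mean over valid tokens
```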