nemo_automodel.components.loss.kd_loss

Module Contents

Classes

| Name | Description |
| --- | --- |
| KDLoss | Forward KL divergence loss for knowledge distillation. |

Functions

| Name | Description |
| --- | --- |
| _infer_tp_group_from_dtensor | If logits is a DTensor sharded on the vocab (last) dimension, return its TP process group. |
| _kl_forward_chunked | Compute per-token sum(P * log Q) in chunks to reduce peak memory. |
| _kl_forward_tp | Compute per-token negative cross-entropy sum(P * log Q) with tensor parallelism. |

Data

_HAVE_DTENSOR

API

class nemo_automodel.components.loss.kd_loss.KDLoss(
ignore_index: int = -100,
temperature: float = 1.0,
fp32_upcast: bool = True,
tp_group: typing.Optional[torch.distributed.ProcessGroup] = None,
chunk_size: int = 0
)

Bases: Module

Forward KL divergence loss for knowledge distillation.

Computes KL(P_teacher ‖ P_student) averaged over valid (non-padding) tokens.

Supports tensor-parallel (TP) training: when logits are vocab-sharded DTensors, the TP group is inferred automatically and a distributed softmax is used to avoid gathering the full vocabulary on each rank. A tp_group can also be supplied explicitly.
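Written out as a formula, the quantity described above is roughly the following. The notation is assumed for illustration and does not come from the library: 𝒱 is the set of positions whose label differs from ignore_index, z^t_i and z^s_i are the teacher and student logit vectors at position i, T is the temperature, and N is |𝒱| by default or num_batch_labels when that argument is given.

```latex
\mathcal{L}_{\mathrm{KD}}
  = \frac{T^{2}}{N} \sum_{i \in \mathcal{V}} \sum_{v} P_{i,v}\,\bigl(\log P_{i,v} - \log Q_{i,v}\bigr),
\qquad
P_{i} = \mathrm{softmax}\bigl(z^{t}_{i} / T\bigr),
\quad
Q_{i} = \mathrm{softmax}\bigl(z^{s}_{i} / T\bigr)
```

Only the -Σ P log Q part depends on the student. The teacher-entropy part Σ P log P is constant with respect to the student parameters and therefore contributes no gradient.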

Parameters:

ignore_index
int, defaults to -100

Label value that marks padding tokens.

temperature
float, defaults to 1.0

Softmax temperature T. Both teacher and student logits are divided by T before probabilities are computed, and the loss is then multiplied by T² so that gradient magnitudes stay roughly independent of the chosen temperature (Hinton et al., 2015).

fp32_upcast
bool, defaults to True

Cast logits to float32 before computing softmax / log-softmax, for numerical stability.

tp_group
Optional[torch.distributed.ProcessGroup], defaults to None

Explicit TP process group. When None (the default), the group is inferred from the DTensor placement of student_logits; if student_logits is a plain tensor, the non-TP path is used.

chunk_size
int, defaults to 0

When positive, valid tokens are processed in chunks of this size to avoid materializing the full [num_valid_tokens, vocab_size] probability matrix in fp32. Reduces peak memory at the cost of slightly more kernel launches. 0 (default) disables chunking. Ignored when using the TP path.
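To get a feel for why chunking matters, the arithmetic below sizes a single [tokens, vocab_size] fp32 matrix (4 bytes per element). The vocabulary size of 128,256 and the token counts are hypothetical values chosen for illustration, and the helper name fp32_matrix_bytes is not part of the library. Softmax and log-softmax outputs for teacher and student mean several such matrices can be live at once, so real peaks are a multiple of this figure.

```python
def fp32_matrix_bytes(num_tokens: int, vocab_size: int) -> int:
    """Bytes for one [num_tokens, vocab_size] float32 matrix (4 bytes per element)."""
    return num_tokens * vocab_size * 4


# Hypothetical sizes, for illustration only.
VOCAB_SIZE = 128_256
unchunked_bytes = fp32_matrix_bytes(8192, VOCAB_SIZE)  # all valid tokens at once
chunked_bytes = fp32_matrix_bytes(1024, VOCAB_SIZE)    # chunk_size=1024

print(f"unchunked: {unchunked_bytes / 2**30:.2f} GiB per matrix")
print(f"chunk_size=1024: {chunked_bytes / 2**30:.2f} GiB per matrix")
```

With these numbers the full matrix is about 3.91 GiB, while a 1024-token chunk is about 0.49 GiB, an 8x reduction per live matrix.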

nemo_automodel.components.loss.kd_loss.KDLoss.forward(
student_logits: torch.Tensor,
teacher_logits: torch.Tensor,
labels: torch.Tensor,
num_batch_labels: int | None = None
) -> torch.Tensor

Compute the KD loss.

Parameters:

student_logits
torch.Tensor

Shape [*, vocab_size] or [*, local_vocab_size] for TP.

teacher_logits
torch.Tensor

Same shape as student_logits.

labels
torch.Tensor

Shape [*]. Positions equal to ignore_index are excluded from the loss.

num_batch_labels
int | None, defaults to None

Total number of valid tokens across all gradient-accumulation steps. When provided, the loss is sum(kl_per_token) / num_batch_labels; otherwise, it is mean(kl_per_token) over the valid tokens in this micro-batch.

Returns: torch.Tensor

Scalar KD loss.
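The forward contract above can be illustrated with a minimal single-device sketch in plain PyTorch. The function kd_loss_reference is hypothetical and written for this page: it computes the KL exactly as documented (masking, temperature, T² scaling, optional num_batch_labels normalization) but is not the library's implementation, which works through the Σ P log Q helpers and supports TP and chunking. The usage part shows why num_batch_labels exists: two micro-batches each normalized by the global valid-token count add up to the full-batch mean.

```python
import torch
import torch.nn.functional as F


def kd_loss_reference(student_logits, teacher_logits, labels, *, ignore_index=-100,
                      temperature=1.0, num_batch_labels=None):
    """Hypothetical single-device sketch of the documented KD loss (not the library's code)."""
    # Keep only positions whose label is not ignore_index: [..., V] -> [num_valid, V].
    mask = labels != ignore_index
    t = teacher_logits[mask].float() / temperature
    s = student_logits[mask].float() / temperature
    # Per-token KL(P_teacher || P_student) = sum_v P * (log P - log Q).
    log_p = F.log_softmax(t, dim=-1)
    log_q = F.log_softmax(s, dim=-1)
    kl_per_token = (log_p.exp() * (log_p - log_q)).sum(dim=-1)
    # Scale by T^2 (Hinton et al., 2015).
    kl_per_token = kl_per_token * temperature ** 2
    if num_batch_labels is not None:
        # Normalize by the global valid-token count across accumulation steps.
        return kl_per_token.sum() / num_batch_labels
    return kl_per_token.mean()


# Usage: two micro-batches normalized by the global count match one full-batch mean.
torch.manual_seed(0)
vocab = 11
student = torch.randn(2, 5, vocab)
teacher = torch.randn(2, 5, vocab)
labels = torch.randint(0, vocab, (2, 5))
labels[0, :2] = -100  # mark two positions as padding
total_valid = int((labels != -100).sum())

accumulated = sum(
    kd_loss_reference(student[i:i + 1], teacher[i:i + 1], labels[i:i + 1],
                      num_batch_labels=total_valid)
    for i in range(2)
)
full_batch = kd_loss_reference(student, teacher, labels)
```

Without num_batch_labels, each micro-batch would be averaged over its own token count, so micro-batches with more padding would be over-weighted relative to a single large batch.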

nemo_automodel.components.loss.kd_loss._infer_tp_group_from_dtensor(
logits: torch.Tensor
) -> typing.Optional[torch.distributed.ProcessGroup]

If logits is a DTensor sharded on the vocab (last) dimension, return its TP process group.

Iterates over the DTensor placements to find the mesh dimension that holds a vocab-dim Shard and returns the corresponding process group. Returns None for plain tensors or DTensors that are not vocab-sharded.
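A sketch of that placement scan is shown below. The name infer_vocab_tp_group and the HAVE_DTENSOR flag are hypothetical (the module's own _HAVE_DTENSOR flag presumably plays a similar role). The sketch assumes the public DTensor API, where a DTensor exposes placements and device_mesh, Shard has a dim attribute, and DeviceMesh.get_group(mesh_dim) returns the process group for one mesh dimension. The import falls back to the older private path because the public location changed across PyTorch releases.

```python
import torch

# Import DTensor defensively; the public module path differs across PyTorch versions.
try:
    from torch.distributed.tensor import DTensor, Shard
    HAVE_DTENSOR = True
except ImportError:
    try:
        from torch.distributed._tensor import DTensor, Shard
        HAVE_DTENSOR = True
    except ImportError:
        HAVE_DTENSOR = False


def infer_vocab_tp_group(logits):
    """Return the process group of the mesh dim that shards the last (vocab) dim, else None."""
    if not HAVE_DTENSOR or not isinstance(logits, DTensor):
        return None  # plain tensors take the non-TP path
    last_dim = logits.ndim - 1
    for mesh_dim, placement in enumerate(logits.placements):
        # Shard.dim may be negative; normalize it before comparing with the last dim.
        if isinstance(placement, Shard) and placement.dim % logits.ndim == last_dim:
            return logits.device_mesh.get_group(mesh_dim)
    return None  # replicated, or sharded on a non-vocab dim
```

Returning None rather than raising lets the caller fall back to the single-device path for plain tensors and for DTensors that are not vocab-sharded.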

nemo_automodel.components.loss.kd_loss._kl_forward_chunked(
t_logits: torch.Tensor,
s_logits: torch.Tensor,
chunk_size: int
) -> torch.Tensor

Compute per-token sum(P * log Q) in chunks to reduce peak memory.

Processes chunk_size tokens at a time, so only one [chunk_size, vocab_size] fp32 probability matrix, rather than the full [num_valid_tokens, vocab_size] one, is live at any moment.

Parameters:

t_logits
torch.Tensor

Teacher logits, shape [num_valid_tokens, vocab_size].

s_logits
torch.Tensor

Student logits, shape [num_valid_tokens, vocab_size].

chunk_size
int

Number of tokens per chunk.

Returns: torch.Tensor

Per-token sum(P * log Q), shape [num_valid_tokens].
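The chunking idea can be sketched as follows. kl_forward_chunked_sketch is a hypothetical stand-in, not the library function: it upcasts and normalizes one slice of rows at a time, then concatenates the per-token results, which must match the unchunked computation. The usage feeds bf16 logits (an assumption about typical inputs) and a token count that does not divide evenly by chunk_size.

```python
import torch
import torch.nn.functional as F


def kl_forward_chunked_sketch(t_logits, s_logits, chunk_size):
    """Hypothetical sketch: per-token sum_v P * log Q, computed chunk_size rows at a time."""
    per_token = []
    for start in range(0, t_logits.shape[0], chunk_size):
        # Upcast only the current chunk to fp32.
        t = t_logits[start:start + chunk_size].float()
        s = s_logits[start:start + chunk_size].float()
        p = F.softmax(t, dim=-1)          # teacher probabilities P
        log_q = F.log_softmax(s, dim=-1)  # student log-probabilities log Q
        per_token.append((p * log_q).sum(dim=-1))
    return torch.cat(per_token)


# Usage: 10 tokens with chunk_size=3 (the last chunk has a single row).
torch.manual_seed(0)
t_logits = torch.randn(10, 32).to(torch.bfloat16)
s_logits = torch.randn(10, 32).to(torch.bfloat16)
chunked = kl_forward_chunked_sketch(t_logits, s_logits, chunk_size=3)
unchunked = (F.softmax(t_logits.float(), -1) * F.log_softmax(s_logits.float(), -1)).sum(-1)
```

One caveat on this naive form: when the logits require gradients, autograd keeps the tensors saved for backward from every chunk, so realizing the full memory savings during training may additionally need activation checkpointing or a custom backward. Whether the library does this is not stated here.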

nemo_automodel.components.loss.kd_loss._kl_forward_tp(
t_logits: torch.Tensor,
s_logits: torch.Tensor,
tp_group: torch.distributed.ProcessGroup
) -> torch.Tensor

Compute per-token negative cross-entropy sum(P * log Q) with tensor parallelism.

Both t_logits and s_logits are local vocab-sharded tensors of shape [valid_tokens, local_vocab_size]. A numerically stable global softmax / log-softmax is computed via all_reduce over tp_group, avoiding the need to gather the full vocab.

Parameters:

t_logits
torch.Tensor

Local teacher logit shard, shape [valid_tokens, local_vocab_size].

s_logits
torch.Tensor

Local student logit shard, shape [valid_tokens, local_vocab_size].

tp_group
torch.distributed.ProcessGroup

Process group spanning the tensor-parallel ranks.

Returns: torch.Tensor

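Per-token sum(P * log Q), shape [valid_tokens]. This is the negative cross-entropy -H(P, Q); since KL(P ‖ Q) = H(P, Q) - H(P), it differs from -KL(P ‖ Q) only by the teacher entropy H(P).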
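The vocab-parallel computation can be simulated in a single process, assuming each element of a Python list stands in for one TP rank's [tokens, local_vocab] shard. In this hypothetical sketch (kl_forward_tp_simulated is not a library function), max and sum over the list replace all_reduce with MAX and SUM over tp_group. Only per-token scalars cross "ranks", never a full vocab row, and the result must match the unsharded computation.

```python
import torch


def kl_forward_tp_simulated(t_shards, s_shards):
    """Hypothetical single-process simulation of a vocab-parallel sum_v P * log Q.

    Each list element stands in for one TP rank's [tokens, local_vocab] shard; the max/sum
    over the list stand in for all_reduce(MAX) / all_reduce(SUM) over tp_group.
    """
    t_shards = [t.float() for t in t_shards]
    s_shards = [s.float() for s in s_shards]
    # 1) Global per-token max for a numerically stable softmax ("all_reduce MAX").
    t_max = torch.stack([t.amax(dim=-1) for t in t_shards]).amax(dim=0)
    s_max = torch.stack([s.amax(dim=-1) for s in s_shards]).amax(dim=0)
    # 2) Global softmax normalizers: sum of each shard's local sum-exp ("all_reduce SUM").
    t_sum = sum(torch.exp(t - t_max[:, None]).sum(dim=-1) for t in t_shards)
    s_sum = sum(torch.exp(s - s_max[:, None]).sum(dim=-1) for s in s_shards)
    # 3) Each shard's partial sum of P * log Q, then a final "all_reduce SUM".
    partials = []
    for t, s in zip(t_shards, s_shards):
        p = torch.exp(t - t_max[:, None]) / t_sum[:, None]
        log_q = s - s_max[:, None] - torch.log(s_sum)[:, None]
        partials.append((p * log_q).sum(dim=-1))
    return sum(partials)


# Usage: split a [4, 12] vocab into 3 shards of 4 columns, as 3 TP ranks would hold it.
torch.manual_seed(0)
t_full = torch.randn(4, 12) * 5
s_full = torch.randn(4, 12) * 5
sharded = kl_forward_tp_simulated(list(t_full.chunk(3, dim=-1)), list(s_full.chunk(3, dim=-1)))
reference = (torch.softmax(t_full, -1) * torch.log_softmax(s_full, -1)).sum(-1)
```

In a real multi-rank setting, gradients must flow through the SUM reductions, which calls for an autograd-aware collective such as torch.distributed.nn.functional.all_reduce; the per-token max needs no gradient because log-softmax is invariant to that shift.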

nemo_automodel.components.loss.kd_loss._HAVE_DTENSOR = True