nemo_automodel.components.loss.kd_loss
API
Bases: Module
Forward KL divergence loss for knowledge distillation.
Computes KL(P_teacher ‖ P_student) averaged over valid (non-padding) tokens.
Supports tensor-parallel (TP) training: when the logits are vocab-sharded DTensors, the TP group is inferred automatically and a distributed softmax is used, so the full vocabulary is never gathered on any rank. A tp_group can also be supplied explicitly.
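As a reference for the quantity described above, the following framework-free sketch computes forward KL(P_teacher ‖ P_student) per token and averages it over tokens whose label is not ignore_index. It uses plain Python lists instead of torch tensors and leaves out temperature, chunking, and TP; the function names are hypothetical and this is not the module's implementation. Returning 0.0 when every token is padding is an assumption of the sketch.

```python
import math

def log_softmax(logits):
    # Subtract the max before exponentiating so exp() cannot overflow.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def forward_kl_mean(teacher_logits, student_logits, labels, ignore_index=-100):
    """Mean of KL(P_teacher || P_student) over tokens whose label != ignore_index."""
    per_token = []
    for t, s, label in zip(teacher_logits, student_logits, labels):
        if label == ignore_index:
            continue  # padding token: excluded from the loss
        log_p = log_softmax(t)  # teacher log-probabilities
        log_q = log_softmax(s)  # student log-probabilities
        # KL(P || Q) = sum_v P(v) * (log P(v) - log Q(v))
        per_token.append(sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q)))
    # Assumption of this sketch: no valid tokens -> loss of 0.0.
    return sum(per_token) / len(per_token) if per_token else 0.0
```

Because the direction is teacher-first, the student is penalized most where the teacher puts probability mass that the student misses.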
Parameters:
Label value marking padding tokens (default -100).
Softmax temperature T. Both teacher and student logits are divided by T before computing probabilities. The loss is then multiplied by T² so that gradient magnitudes stay roughly independent of the chosen temperature (Hinton et al., 2015).
Cast logits to float32 before computing softmax / log-softmax, for numerical stability (default True).
Explicit TP process group. When None (default), the group is inferred from the DTensor placement of student_logits; for plain tensors, the non-TP path is used.
When positive, valid tokens are processed in chunks of this size so that the full [num_valid_tokens, vocab_size] fp32 probability matrix is never materialized at once. This reduces peak memory at the cost of slightly more kernel launches. 0 (default) disables chunking. Ignored on the TP path.
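The temperature parameter can be illustrated with a single-token sketch in plain Python (kd_token_loss is a hypothetical helper, not the module's code): both logit vectors are divided by T, forward KL is computed on the softened distributions, and the result is multiplied by T².

```python
import math

def log_softmax(logits):
    # Max-subtraction keeps exp() in a safe range.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def kd_token_loss(teacher_logits, student_logits, temperature=1.0):
    """T^2 * KL(softmax(t / T) || softmax(s / T)) for a single token."""
    T = temperature
    log_p = log_softmax([x / T for x in teacher_logits])
    log_q = log_softmax([x / T for x in student_logits])
    kl = sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
    # Soft-target KL (and its gradients) shrink roughly like 1/T^2, so rescale.
    return T * T * kl

teacher = [2.0, 0.5, -1.0]
student = [0.0, 0.0, 0.0]
for T in (1.0, 10.0, 100.0, 1000.0):
    print(T, kd_token_loss(teacher, student, T))
```

For these logits the printed values settle near 0.75 as T grows. That matches the high-temperature approximation KL ≈ Var(d) / (2T²), where d = teacher - student logits and the variance is taken over the vocabulary (here Var(d) = 1.5): the T² factor cancels the 1/T² decay.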
Compute the KD loss.
Parameters:
Shape [*, vocab_size] or [*, local_vocab_size] for TP.
Same shape as student_logits.
Shape [*]. Positions equal to ignore_index are excluded from the loss.
Total number of valid tokens across all gradient-accumulation steps. When provided, the loss is sum(kl_per_token) / num_batch_labels; otherwise it is mean(kl_per_token) over the valid tokens in this micro-batch.
Returns: torch.Tensor
Scalar KD loss.
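The effect of num_batch_labels under gradient accumulation can be checked with plain numbers. The sketch below assumes the trainer sums the per-micro-batch losses (rather than averaging them) and takes the per-token KL values as already computed; micro_batch_loss is a hypothetical helper.

```python
def micro_batch_loss(kl_per_token, num_batch_labels=None):
    """kl_per_token: per-token KL values for the valid tokens of one micro-batch."""
    if num_batch_labels is not None:
        return sum(kl_per_token) / num_batch_labels  # normalize by the global count
    return sum(kl_per_token) / len(kl_per_token)     # local mean

# Two micro-batches with different numbers of valid tokens.
micro_batches = [[0.2, 0.4, 0.6], [1.0]]
total_valid = sum(len(mb) for mb in micro_batches)  # 4 valid tokens overall

# Summing losses normalized by the global count gives the mean over all tokens.
accumulated = sum(micro_batch_loss(mb, total_valid) for mb in micro_batches)
# Averaging local means weights each micro-batch equally, regardless of size.
averaged_means = sum(micro_batch_loss(mb) for mb in micro_batches) / len(micro_batches)
print(accumulated, averaged_means)
```

Here the accumulated loss is 0.55, the mean over all four tokens, while averaging per-micro-batch means gives 0.7 because the single-token micro-batch counts as much as the three-token one.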
If logits is a DTensor sharded on the vocab (last) dimension, return its TP process group.
Iterates over the DTensor placements to find the mesh dimension that holds a Shard on the vocab dimension, and returns the corresponding process group. Returns None for plain tensors or DTensors that are not vocab-sharded.
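The placement scan described above can be sketched without a distributed setup by using small stand-in classes. Shard, Replicate, FakeMesh, and FakeDTensor below are hypothetical mocks that imitate only what the scan needs (a placements tuple, Shard.dim, and a mesh with get_group(mesh_dim)); infer_vocab_tp_group is likewise an illustrative name, not the module's function. With a real DTensor, the same loop would read logits.placements and logits.device_mesh.

```python
from dataclasses import dataclass

# Hypothetical stand-ins that mimic only the DTensor attributes used below.
@dataclass(frozen=True)
class Shard:
    dim: int  # tensor dimension that is split across the mesh dimension

@dataclass(frozen=True)
class Replicate:
    pass

class FakeMesh:
    def __init__(self, groups):
        self._groups = groups  # one process-group handle per mesh dimension

    def get_group(self, mesh_dim):
        return self._groups[mesh_dim]

@dataclass
class FakeDTensor:
    ndim: int
    placements: tuple
    device_mesh: FakeMesh

def infer_vocab_tp_group(logits):
    """Return the group of the mesh dim that shards the last (vocab) dim, else None."""
    placements = getattr(logits, "placements", None)
    if placements is None:
        return None  # plain tensor: use the non-TP path
    for mesh_dim, placement in enumerate(placements):
        # Accept both -1 and the explicit index of the last dimension.
        if isinstance(placement, Shard) and placement.dim in (-1, logits.ndim - 1):
            return logits.device_mesh.get_group(mesh_dim)
    return None  # DTensor, but not sharded on the vocab dimension
```

For example, a [batch, seq, vocab] DTensor with placements (Replicate(), Shard(2)) on a two-dimensional mesh resolves to the group of mesh dimension 1.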
Compute per-token sum(P * log Q) in chunks to reduce peak memory.
Processes chunk_size tokens at a time so that only one chunk’s worth of the [chunk_size, vocab_size] fp32 probability matrix is live at any moment.
Parameters:
Teacher logits, shape [num_valid_tokens, vocab_size].
Student logits, shape [num_valid_tokens, vocab_size].
Number of tokens per chunk.
Returns: torch.Tensor
Per-token sum(P * log Q), shape [num_valid_tokens].
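A plain-Python sketch of the chunking idea: per-token sum(P * log Q) is computed row by row, and the chunked variant converts only chunk_size rows to probabilities at a time. Because each token's value depends only on its own row, the chunked result matches the unchunked one. Lists stand in for [num_valid_tokens, vocab_size] tensors, and both function names are hypothetical.

```python
import math

def log_softmax(row):
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]

def per_token_p_log_q(t_rows, s_rows):
    """Per-token sum_v P(v) * log Q(v) with P = softmax(t), Q = softmax(s)."""
    out = []
    for t, s in zip(t_rows, s_rows):
        p = [math.exp(lp) for lp in log_softmax(t)]  # teacher probabilities
        out.append(sum(pv * lq for pv, lq in zip(p, log_softmax(s))))
    return out

def chunked_p_log_q(t_rows, s_rows, chunk_size):
    """Same values, but only chunk_size rows are turned into probabilities at once."""
    out = []
    for start in range(0, len(t_rows), chunk_size):
        end = start + chunk_size
        # In a tensor implementation, only this chunk's probability block is live here.
        out.extend(per_token_p_log_q(t_rows[start:end], s_rows[start:end]))
    return out
```

Every value is negative, since log Q < 0 whenever the vocabulary has more than one entry; subtracting it from the teacher entropy term sum(P * log P) yields the per-token KL.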
Compute per-token negative cross-entropy sum(P * log Q) with tensor parallelism.
Both t_logits and s_logits are local vocab-sharded tensors of shape [valid_tokens, local_vocab_size]. A numerically stable global softmax / log-softmax is computed via all_reduce over tp_group, avoiding the need to gather the full vocab.
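To see why no vocabulary gather is needed, the sketch below simulates one token's logits split across TP ranks in plain Python. It assumes one plausible reduction pattern (a MAX all-reduce for the stable max, a SUM all-reduce for the softmax normalizer, and a final SUM for the partial P * log Q terms); Python's max() and sum() over the shard list stand in for torch.distributed.all_reduce, and tp_p_log_q is a hypothetical name. The source does not specify the exact communication steps.

```python
import math

def tp_p_log_q(t_shards, s_shards):
    """sum_v P(v) * log Q(v) for one token whose vocab is split across ranks.

    t_shards / s_shards: one list of logits per simulated TP rank.
    max()/sum() over the shard list stand in for all_reduce(MAX) / all_reduce(SUM).
    """
    # 1) Global max per logit set: local max on each rank, then all_reduce(MAX).
    t_max = max(max(shard) for shard in t_shards)
    s_max = max(max(shard) for shard in s_shards)
    # 2) Global normalizer: local sum of exp(x - max), then all_reduce(SUM).
    t_sum = sum(sum(math.exp(x - t_max) for x in shard) for shard in t_shards)
    s_sum = sum(sum(math.exp(x - s_max) for x in shard) for shard in s_shards)
    s_lse = s_max + math.log(s_sum)  # global log-sum-exp of the student logits
    # 3) Each rank computes P * log Q on its own vocab slice only.
    partials = []
    for t_shard, s_shard in zip(t_shards, s_shards):
        partials.append(sum(
            math.exp(t - t_max) / t_sum * (s - s_lse)
            for t, s in zip(t_shard, s_shard)
        ))
    # 4) A final all_reduce(SUM) combines the per-rank partial sums.
    return sum(partials)
```

Splitting the same logits into two or three shards gives the same value as the unsharded computation, up to floating-point rounding, while each simulated rank only ever touches its own slice plus a few scalars.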
Parameters:
Local teacher logit shard, shape [valid_tokens, local_vocab_size].
Local student logit shard, shape [valid_tokens, local_vocab_size].
Process group spanning the tensor-parallel ranks.
Returns: torch.Tensor
Per-token sum(P * log Q), shape [valid_tokens]. This is the negative KL term;