
# nemo_automodel.components.loss.kd_loss

## Module Contents

### Classes

| Name                                                       | Description                                            |
| ---------------------------------------------------------- | ------------------------------------------------------ |
| [`KDLoss`](#nemo_automodel-components-loss-kd_loss-KDLoss) | Forward KL divergence loss for knowledge distillation. |

### Functions

| Name                                                                                                   | Description                                                                                  |
| ------------------------------------------------------------------------------------------------------ | -------------------------------------------------------------------------------------------- |
| [`_infer_tp_group_from_dtensor`](#nemo_automodel-components-loss-kd_loss-_infer_tp_group_from_dtensor) | If *logits* is a DTensor sharded on the vocab (last) dimension, return its TP process group. |
| [`_kl_forward_chunked`](#nemo_automodel-components-loss-kd_loss-_kl_forward_chunked)                   | Compute per-token sum(P \* log Q) in chunks to reduce peak memory.                           |
| [`_kl_forward_tp`](#nemo_automodel-components-loss-kd_loss-_kl_forward_tp)                             | Compute per-token negative cross-entropy `sum(P * log Q)` with tensor parallelism.           |

### Data

[`_HAVE_DTENSOR`](#nemo_automodel-components-loss-kd_loss-_HAVE_DTENSOR)

### API

```python
class nemo_automodel.components.loss.kd_loss.KDLoss(
    ignore_index: int = -100,
    temperature: float = 1.0,
    fp32_upcast: bool = True,
    tp_group: typing.Optional[torch.distributed.ProcessGroup] = None,
    chunk_size: int = 0
)
```

**Bases:** `Module`

Forward KL divergence loss for knowledge distillation.

Computes `KL(P_teacher ‖ P_student)` averaged over valid (non-padding) tokens.

Supports tensor-parallel (TP) training: when logits are vocab-sharded `DTensor`s, the TP
group is inferred automatically and a distributed softmax is used to avoid gathering the full
vocabulary on each rank.  A `tp_group` can also be supplied explicitly.

**Parameters:**

- `ignore_index`: Label value marking padding tokens (default `-100`).
- `temperature`: Softmax temperature *T* (default `1.0`). Both teacher and student logits are divided by *T* before computing probabilities. The loss is then multiplied by *T²* so that gradient magnitudes stay roughly independent of the chosen temperature (Hinton et al., 2015).
- `fp32_upcast`: Cast logits to float32 before computing softmax / log-softmax for numerical stability (default `True`).
- `tp_group`: Explicit TP process group. When `None` (default), the group is inferred from the DTensor placement of `student_logits`; plain tensors use the non-TP path.
- `chunk_size`: When positive, valid tokens are processed in chunks of this size to avoid materializing the full `[num_valid_tokens, vocab_size]` probability matrix in fp32. This reduces peak memory at the cost of slightly more kernel launches. `0` (default) disables chunking. Ignored on the TP path.
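Read together, the class description and the `temperature` parameter amount to the objective below. This is a restatement under stated assumptions, not a formula taken from the source: $t_i$ and $s_i$ are the teacher and student logit vectors at position $i$, $\mathcal{V}$ is the set of positions whose label differs from `ignore_index`, $V$ is the vocabulary size, and $N = |\mathcal{V}|$ (or `num_batch_labels` when it is passed to `forward`). Applying $T^2$ before or after the division by $N$ gives the same value.

```math
\begin{aligned}
p_i &= \operatorname{softmax}\!\left(t_i / T\right), \qquad q_i = \operatorname{softmax}\!\left(s_i / T\right), \\
\mathcal{L} &= \frac{T^2}{N} \sum_{i \in \mathcal{V}} \sum_{v=1}^{V} p_{i,v} \left( \log p_{i,v} - \log q_{i,v} \right).
\end{aligned}
```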

```python
nemo_automodel.components.loss.kd_loss.KDLoss.forward(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor,
    num_batch_labels: int | None = None
) -> torch.Tensor
```

Compute the KD loss.

**Parameters:**

- `student_logits`: Shape `[*, vocab_size]`, or `[*, local_vocab_size]` for TP.
- `teacher_logits`: Same shape as `student_logits`.
- `labels`: Shape `[*]`. Positions equal to `ignore_index` are excluded from the loss.
- `num_batch_labels`: Total number of valid tokens across all gradient-accumulation steps. When provided, the loss is `sum(kl_per_token) / num_batch_labels`; otherwise it is `mean(kl_per_token)` over the valid tokens in this micro-batch.
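The arithmetic below illustrates why `num_batch_labels` matters under gradient accumulation. The per-token KL values are made up, and the "without" case assumes a trainer that averages the per-micro-batch means over the number of micro-batches.

```python
# Made-up per-token KL values for the two micro-batches of one optimizer step.
micro_batches = [
    [0.5, 0.5, 0.5, 0.5],  # micro-batch 1: 4 valid tokens
    [2.0],                 # micro-batch 2: 1 valid token
]

num_batch_labels = sum(len(mb) for mb in micro_batches)  # 5 valid tokens in total

# With num_batch_labels: each micro-batch contributes sum / total, so the
# accumulated loss equals the mean over every valid token in the step.
accumulated = sum(sum(mb) / num_batch_labels for mb in micro_batches)

# Without it (assumed trainer behaviour): each micro-batch is averaged locally
# and the means are averaged again, so the single token in micro-batch 2
# weighs as much as all four tokens in micro-batch 1.
mean_of_means = sum(sum(mb) / len(mb) for mb in micro_batches) / len(micro_batches)

print(accumulated)    # 0.8
print(mean_of_means)  # 1.25
```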

**Returns:** `torch.Tensor`

Scalar KD loss.
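For checking numerics, here is a minimal single-device sketch of what `forward` computes on the non-TP, non-chunked path. It is written from the parameter descriptions above and is not the library's source; `kd_loss_reference` is a hypothetical name.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def kd_loss_reference(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor,
    ignore_index: int = -100,
    temperature: float = 1.0,
    fp32_upcast: bool = True,
    num_batch_labels: Optional[int] = None,
) -> torch.Tensor:
    """Hypothetical reference for forward-KL distillation (no TP, no chunking)."""
    # Keep only positions whose label is not the padding marker.
    mask = labels != ignore_index
    s = student_logits[mask]  # [num_valid_tokens, vocab_size]
    t = teacher_logits[mask]
    if fp32_upcast:
        s, t = s.float(), t.float()

    # Temperature-scaled log-distributions: log P (teacher), log Q (student).
    t_log_p = F.log_softmax(t / temperature, dim=-1)
    s_log_q = F.log_softmax(s / temperature, dim=-1)

    # Per-token KL(P || Q) = sum_v P * (log P - log Q).
    kl_per_token = (t_log_p.exp() * (t_log_p - s_log_q)).sum(dim=-1)

    if num_batch_labels is not None:
        loss = kl_per_token.sum() / num_batch_labels
    else:
        loss = kl_per_token.mean()
    # T^2 rescaling in the style of Hinton et al. (2015).
    return loss * temperature**2
```

On the same inputs this should agree with `F.kl_div(log_q, log_p, reduction="batchmean", log_target=True) * T**2` computed over the masked tokens, which is a convenient cross-check.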

```python
nemo_automodel.components.loss.kd_loss._infer_tp_group_from_dtensor(
    logits: torch.Tensor
) -> typing.Optional[torch.distributed.ProcessGroup]
```

If *logits* is a DTensor sharded on the vocab (last) dimension, return its TP process group.

Iterates over the DTensor placements to find the mesh dimension that holds a vocab-dim
`Shard` and returns the corresponding process group.  Returns `None` for plain tensors
or DTensors that are not vocab-sharded.
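A sketch of that lookup using PyTorch's DTensor types follows. `infer_tp_group_sketch` is a hypothetical name, and accepting both `-1` and `ndim - 1` as the vocab dimension is an assumption; the library's own check may differ.

```python
from __future__ import annotations

from typing import Optional

import torch
import torch.distributed as dist

try:  # Public DTensor namespace in recent PyTorch releases.
    from torch.distributed.tensor import DTensor, Shard
except ImportError:  # Older releases exposed it under a private module.
    from torch.distributed._tensor import DTensor, Shard


def infer_tp_group_sketch(logits: torch.Tensor) -> Optional[dist.ProcessGroup]:
    """Hypothetical re-implementation of the vocab-shard lookup."""
    if not isinstance(logits, DTensor):
        return None  # Plain tensors take the non-TP path.
    last_dim = logits.ndim - 1
    for mesh_dim, placement in enumerate(logits.placements):
        # A Shard placement on the last (vocab) dimension identifies the TP mesh dim.
        if isinstance(placement, Shard) and placement.dim in (-1, last_dim):
            return logits.device_mesh.get_group(mesh_dim)
    return None  # Replicated or batch-sharded DTensors are not vocab-sharded.
```

Only the plain-tensor branch can be exercised without initializing a device mesh; the DTensor branch needs a running process group.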

```python
nemo_automodel.components.loss.kd_loss._kl_forward_chunked(
    t_logits: torch.Tensor,
    s_logits: torch.Tensor,
    chunk_size: int
) -> torch.Tensor
```

Compute per-token `sum(P * log Q)` in chunks to reduce peak memory.

Processes `chunk_size` tokens at a time so that only one `[chunk_size, vocab_size]` fp32 probability matrix is live at any moment, instead of the full `[num_valid_tokens, vocab_size]` matrix.

**Parameters:**

- `t_logits`: Teacher logits, shape `[num_valid_tokens, vocab_size]`.
- `s_logits`: Student logits, shape `[num_valid_tokens, vocab_size]`.
- `chunk_size`: Number of tokens per chunk.

**Returns:** `torch.Tensor`

Per-token `sum(P * log Q)`, shape `[num_valid_tokens]`.
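A minimal sketch of the chunking idea follows, assuming the inputs have already been masked to valid tokens and that no temperature scaling is applied inside the helper. `sum_p_log_q_chunked` is a hypothetical name, not the library's function.

```python
import torch
import torch.nn.functional as F


def sum_p_log_q_chunked(
    t_logits: torch.Tensor, s_logits: torch.Tensor, chunk_size: int
) -> torch.Tensor:
    """Hypothetical chunked per-token sum(P * log Q) over [num_tokens, vocab]."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    pieces = []
    for start in range(0, t_logits.size(0), chunk_size):
        # Upcast only the current slice, so the fp32 [chunk, vocab] tensors
        # are created one chunk at a time.
        t_chunk = t_logits[start : start + chunk_size].float()
        s_chunk = s_logits[start : start + chunk_size].float()
        p = F.softmax(t_chunk, dim=-1)
        log_q = F.log_softmax(s_chunk, dim=-1)
        pieces.append((p * log_q).sum(dim=-1))  # [chunk]
    if not pieces:  # no valid tokens
        return t_logits.new_zeros(0, dtype=torch.float32)
    return torch.cat(pieces)
```

In a forward-only pass, the fp32 intermediates here scale with `chunk_size` rather than `num_valid_tokens`. When gradients are required, `softmax`, `log_softmax`, and the product save tensors for backward, so this sketch alone does not guarantee the same savings during training.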

```python
nemo_automodel.components.loss.kd_loss._kl_forward_tp(
    t_logits: torch.Tensor,
    s_logits: torch.Tensor,
    tp_group: torch.distributed.ProcessGroup
) -> torch.Tensor
```

Compute per-token negative cross-entropy `sum(P * log Q)` with tensor parallelism.

Both `t_logits` and `s_logits` are **local** vocab-sharded tensors of shape
`[valid_tokens, local_vocab_size]`.  A numerically stable global softmax / log-softmax is
computed via `all_reduce` over `tp_group`, avoiding the need to gather the full vocab.

**Parameters:**

- `t_logits`: Local teacher logit shard, shape `[valid_tokens, local_vocab_size]`.
- `s_logits`: Local student logit shard, shape `[valid_tokens, local_vocab_size]`.
- `tp_group`: Process group spanning the tensor-parallel ranks.

**Returns:** `torch.Tensor`

Per-token `sum(P * log Q)`, shape `[valid_tokens]`. This is the *negative* cross-entropy term of the KL divergence, not the full KL.
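To illustrate the `all_reduce` pattern without launching several processes, the sketch below keeps every rank's shard in one Python list and replaces each collective with a reduction over that list. The helper names are hypothetical, and splitting the work into a global max, a global sum of exponentials, and a final sum of partial results is an assumed decomposition, not the library's code. Recovering the full per-token KL would additionally need the teacher term `sum(P * log P)`, which can be computed the same way.

```python
import torch


def _all_reduce_max(per_rank):
    # Stand-in for dist.all_reduce(x, op=ReduceOp.MAX) across TP ranks.
    return torch.stack(per_rank).amax(dim=0)


def _all_reduce_sum(per_rank):
    # Stand-in for dist.all_reduce(x, op=ReduceOp.SUM) across TP ranks.
    return torch.stack(per_rank).sum(dim=0)


def _global_log_softmax(shards):
    # Per-token max over the full vocab, for numerical stability.
    g_max = _all_reduce_max([x.amax(dim=-1, keepdim=True) for x in shards])
    # Per-token sum of exp(x - max) over the full vocab.
    g_sum = _all_reduce_sum([(x - g_max).exp().sum(dim=-1, keepdim=True) for x in shards])
    log_z = g_max + g_sum.log()
    # Each rank normalizes only its own shard with the shared log-partition.
    return [x - log_z for x in shards]


def sum_p_log_q_sharded(t_shards, s_shards):
    """Single-process simulation of vocab-parallel per-token sum(P * log Q)."""
    log_p = _global_log_softmax([t.float() for t in t_shards])
    log_q = _global_log_softmax([s.float() for s in s_shards])
    # Partial sums over each shard's vocab slice, then a global sum.
    partial = [(lp.exp() * lq).sum(dim=-1) for lp, lq in zip(log_p, log_q)]
    return _all_reduce_sum(partial)
```

On real hardware each rank holds only one entry of these lists, and every `_all_reduce_*` call would become a collective over `tp_group`; no rank ever materializes the full vocabulary.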

```python
nemo_automodel.components.loss.kd_loss._HAVE_DTENSOR = True
```