
# nemo_automodel.components.speculative.eagle.peagle_trainer

P-EAGLE (parallel-drafting EAGLE-3) training logic.

Split out of `core.py` so the P-EAGLE trainer can evolve independently of the
EAGLE-3 test-time-training trainer. The shared step-metrics container
(`Eagle3StepMetrics`) still lives in `core.py` and is imported here.

## Module Contents

### Classes

| Name                                                                                                     | Description                                                            |
| -------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------- |
| [`PEagleTrainerModule`](#nemo_automodel-components-speculative-eagle-peagle_trainer-PEagleTrainerModule) | Draft-side P-EAGLE (parallel-drafting EAGLE-3) trainer module.         |
| [`_PeaglePlan`](#nemo_automodel-components-speculative-eagle-peagle_trainer-_PeaglePlan)                 | Sequence-partitioning plan: one `unit` per non-empty `(row, segment)`. |

### Functions

| Name                                                                                       | Description                                                   |
| ------------------------------------------------------------------------------------------ | ------------------------------------------------------------- |
| [`_kl_div_loss`](#nemo_automodel-components-speculative-eagle-peagle_trainer-_kl_div_loss) | Per-position KL(target \|\| draft) over the draft vocabulary. |

### API

```python
class nemo_automodel.components.speculative.eagle.peagle_trainer.PEagleTrainerModule(
    draft_model: torch.nn.Module,
    selected_token_ids: torch.Tensor,
    selected_token_mask: torch.Tensor,
    num_depths: int,
    mask_token_id: int,
    down_sample_ratio: float = 0.7,
    down_sample_ratio_min: float = 0.2,
    sequence_partitions: int = 1
)
```

**Bases:** `Module`

Draft-side P-EAGLE (parallel-drafting EAGLE-3) trainer module.

Faithful port of speculators' P-EAGLE
([https://github.com/vllm-project/speculators/pull/480](https://github.com/vllm-project/speculators/pull/480)): the draft predicts
all `num_depths` tokens in a *single* parallel forward over a flat,
COD-subsampled sequence -- it does NOT run EAGLE-3's autoregressive TTT
recurrence.

Per training sequence:

1. **COD sampling** (`generate_cod_sample_indices`) draws
   `(anchor_pos, depth)` pairs: depth 0 keeps every position, and depth `d`
   keeps a geometrically decaying `down_sample_ratio**d` fraction. The
   reference position of each element is `anchor_pos + depth`.
2. **Flat input assembly.** All depths are concatenated into one
   `[1, total_sampled]` sequence. Depth-0 slots take the real token id and
   the `fc`-projected target aux hidden state; depth >= 1 slots take
   `mask_token_id` and the single learnable `mask_hidden` placeholder
   (projected through the same `fc`).
3. **COD flex attention.** A single `flex_attention` forward with the
   `create_peagle_mask_mod` block mask: each element attends to the
   causal depth-0 context of its document plus earlier-or-equal depths of
   its own rollout. This is exactly what vLLM's parallel-drafting runtime
   sees at inference.
4. **Count-normalized KL loss.** `KL(target || draft)` over the draft vocab
   at every supervised sampled position, normalized by a single total token
   count -- deeper depths (fewer COD positions) naturally contribute less
   gradient. No `0.8**d` schedule.
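
The attention rule in step 3 can be pictured as a dense boolean mask. The sketch below is not `create_peagle_mask_mod`; `peagle_mask_dense` is a hypothetical helper for a single document. It assumes that "causal depth-0 context" means depth-0 elements whose anchor is at or before the query's anchor, and that a "rollout" is the set of elements sharing one anchor. Both readings are assumptions made for illustration.

```python
import torch


def peagle_mask_dense(anchor_pos: torch.Tensor, depth: torch.Tensor) -> torch.Tensor:
    """Return a [Q, K] boolean mask; True means query q may attend to key k.

    Hypothetical illustration only, for one document with no padding.
    """
    q_anchor, k_anchor = anchor_pos[:, None], anchor_pos[None, :]
    q_depth, k_depth = depth[:, None], depth[None, :]
    # Assumed rule 1: causal depth-0 context (keys at depth 0, anchor <= query anchor).
    context = (k_depth == 0) & (k_anchor <= q_anchor)
    # Assumed rule 2: same rollout (same anchor), earlier-or-equal depth.
    rollout = (k_anchor == q_anchor) & (k_depth <= q_depth)
    return context | rollout


# Three depth-0 elements followed by two depth-1 elements (anchors 0 and 2).
anchor_pos = torch.tensor([0, 1, 2, 0, 2])
depth = torch.tensor([0, 0, 0, 1, 1])
mask = peagle_mask_dense(anchor_pos, depth)
```

Under these assumptions, the depth-1 element anchored at 0 sees only the depth-0 element at anchor 0 and itself, while the depth-1 element anchored at 2 sees all three depth-0 elements and itself.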

Batches with `batch_size > 1` are processed row-by-row (speculators is
batch-size-1); per-row losses are accumulated with a shared denominator so
the normalization stays count-based across the whole batch.
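
A small numeric illustration of why the shared denominator matters; the per-row numbers are made up. With a shared denominator every supervised token carries the same weight, whereas averaging per-row means would up-weight tokens in short rows.

```python
# Hypothetical per-row (sum of KL, supervised-token count) pairs.
rows = [(6.0, 3), (4.0, 1)]

# Count-based normalization: one numerator and one denominator for the batch.
total_num = sum(num for num, _ in rows)
total_den = sum(den for _, den in rows)
count_normalized = total_num / total_den  # 10.0 / 4 = 2.5

# For contrast: the mean of per-row means weights the 1-token row as heavily
# as the 3-token row.
mean_of_row_means = sum(num / den for num, den in rows) / len(rows)  # (2.0 + 4.0) / 2 = 3.0
```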

**Sequence partitioning (`sequence_partitions > 1`).** The flat COD
forward attends over `n * sum(r**d)` positions, so its peak attention and
activation memory grows with the context length and can run out of memory on
long sequences. P-EAGLE's Algorithm 1 (arXiv:2602.01469) splits each sequence
into `S` segments by dependency lineage (`assign_cod_segments`) and runs a
*separate* forward+backward per segment so only one segment's activations are
resident at a time. The partition is exact: each segment additionally reads
every depth-0 position up to its right boundary as key/value context (causal
completion), so a segment's queries see exactly the key/value set they would
in the single flat forward -- the gradients accumulated across segments equal
the single-forward gradient.
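
To get a feel for the `n * sum(r**d)` growth, here is a back-of-the-envelope sketch. It assumes `r` is `down_sample_ratio`, that depths run from `0` to `num_depths - 1`, and that each depth keeps exactly `r**d` of the positions; it ignores `down_sample_ratio_min` and any rounding the real sampler applies. `expected_flat_length` is a hypothetical helper name.

```python
def expected_flat_length(n: int, num_depths: int, ratio: float) -> float:
    """Approximate flat COD sequence length under the assumptions above."""
    return n * sum(ratio**d for d in range(num_depths))


# Example: a 4096-token sequence, 4 depths, ratio 0.7.
# sum(r**d) = 1 + 0.7 + 0.49 + 0.343 = 2.533, so about 10375 positions.
flat_len = expected_flat_length(4096, num_depths=4, ratio=0.7)
```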

The split is *caller-driven* so the gradient sync stays correct under DDP:
`build_peagle_plan` (no-grad) assigns COD elements to segments, then the
recipe runs one `forward(..., peagle_segment=(plan, i))` per segment and
owns the `backward()`. Doing the per-segment backward here (inside a single
`forward`) would bypass `DistributedDataParallel`'s reducer -- its grad
all-reduce hooks only fire for backwards over the tensor `DDP.forward`
returns -- and silently desynchronize ranks. `sequence_partitions == 1` and
eval take the single flat `forward` unchanged.
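
The caller-driven loop might look roughly like the sketch below. `ToyTrainer` is a stand-in that only mimics the documented call shape (`build_peagle_plan`, then `forward(..., peagle_segment=(plan, i))`); it is not `PEagleTrainerModule`, its loss is a toy squared term rather than KL, and it uses one unit per row instead of per `(row, segment)`. `partitioned_step` is a hypothetical recipe-side helper. Under DDP the forward call would go through the DDP wrapper, which this single-process sketch omits.

```python
import torch
from types import SimpleNamespace


class ToyTrainer(torch.nn.Module):
    """Stand-in with the documented call shape; NOT the real trainer."""

    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.tensor(2.0))

    def build_peagle_plan(self, loss_mask):
        # Toy plan: one unit per non-empty row, plus the shared denominator.
        units = [b for b in range(loss_mask.shape[0]) if loss_mask[b].any()]
        return SimpleNamespace(units=units, total_den=loss_mask.sum())

    def forward(self, x, loss_mask, peagle_segment=None):
        per_token = (self.w * x) ** 2 * loss_mask
        if peagle_segment is None:
            # Single flat forward over everything.
            loss = per_token.sum() / loss_mask.sum()
        else:
            # One segment's share of the count-normalized loss.
            plan, i = peagle_segment
            loss = per_token[plan.units[i]].sum() / plan.total_den
        return SimpleNamespace(loss=loss)


def partitioned_step(trainer, x, loss_mask):
    """Caller owns backward(): one forward + backward per segment."""
    plan = trainer.build_peagle_plan(loss_mask)
    total = 0.0
    for i in range(len(plan.units)):
        metrics = trainer(x, loss_mask, peagle_segment=(plan, i))
        metrics.loss.backward()  # gradients accumulate in .grad across segments
        total += metrics.loss.item()
    return total


x = torch.tensor([[1.0, 2.0], [3.0, 0.0]])
loss_mask = torch.tensor([[1.0, 1.0], [1.0, 0.0]])

trainer = ToyTrainer()
seg_loss = partitioned_step(trainer, x, loss_mask)
seg_grad = trainer.w.grad.clone()

# Reference: the single flat forward and its gradient.
trainer.w.grad = None
flat = trainer(x, loss_mask).loss
flat.backward()
flat_grad = trainer.w.grad.clone()
```

In this toy setting the summed segment losses and the accumulated gradient match the flat forward, which is the property the partitioned path relies on.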

```python
nemo_automodel.components.speculative.eagle.peagle_trainer.PEagleTrainerModule._forward_peagle_segment(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    aux_hidden_states: torch.Tensor,
    target_logits: torch.Tensor,
    peagle_segment: tuple
) -> nemo_automodel.components.speculative.eagle.core.Eagle3StepMetrics
```

Loss for one segment of a `build_peagle_plan` plan.

The recipe drives this once per `plan.units` entry and back-propagates
each result, so each segment owns a self-contained autograd graph that is
freed before the next -- and the backward flows through `DDP.forward` so
gradients all-reduce correctly. `metrics.loss` is the segment's share of
the count-normalized batch loss (`loss / total_den`); summing it over the
plan reproduces the single flat forward's loss.

```python
nemo_automodel.components.speculative.eagle.peagle_trainer.PEagleTrainerModule._peagle_position_loss(
    input_ids_row: torch.Tensor,
    aux_row: torch.Tensor,
    target_logits_row: torch.Tensor,
    anchor_pos: torch.Tensor,
    depth: torch.Tensor,
    orig_positions: torch.Tensor,
    loss_positions: torch.Tensor,
    row_length: torch.Tensor,
    seq_len: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
```

Draft forward + count-normalized KL for one row's COD elements.

Shared by the single flat forward (one call per row, all sampled
positions charged) and the partitioned forward (one call per segment,
only the segment's owned/supervised positions charged -- the rest ride
along as key/value context). Returns `(loss_num, loss_den, correct,
valid)` as float scalars, where `loss_num` is `Σ KL` over `loss_positions`
and `loss_den` is their count; the caller normalizes.

```python
nemo_automodel.components.speculative.eagle.peagle_trainer.PEagleTrainerModule.build_peagle_plan(
    loss_mask: torch.Tensor
) -> '_PeaglePlan'
```

Assign COD elements to segments for the sequence-partitioning path.

Samples COD once per row (the indices must be reused across the segment
forwards), runs Algorithm 1 assignment (`assign_cod_segments`) plus
causal completion, and emits one `unit` per non-empty `(row, segment)`
as `(b, anchor, depth, orig_positions, loss_positions)`. `loss_positions`
marks the segment's *owned* supervised elements (charged loss); the other
elements are depth-0 causal-completion context (key/value only). The shared
`total_den` is the batch's total supervised-token count, so each segment's
`loss / total_den` sums to the single-forward loss.

```python
nemo_automodel.components.speculative.eagle.peagle_trainer.PEagleTrainerModule.forward(
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    loss_mask: torch.Tensor,
    aux_hidden_states: torch.Tensor,
    target_logits: torch.Tensor,
    peagle_segment: tuple | None = None
) -> nemo_automodel.components.speculative.eagle.core.Eagle3StepMetrics
```

Run the P-EAGLE parallel-drafting loss for one batch.

`attention_mask` supplies the per-row valid length so padded positions
are excluded from attention (document mask) and from supervision.

`peagle_segment` selects the sequence-partitioning path: when it is a
`(plan, index)` pair (built by `build_peagle_plan`) this computes
the loss for that one segment only -- the recipe calls this once per
segment and owns the `backward()` so DDP's gradient sync stays correct.
When `None` (`sequence_partitions == 1`, or eval), a single flat
forward over the whole COD sequence returns a grad-carrying loss.

```python
class nemo_automodel.components.speculative.eagle.peagle_trainer._PeaglePlan(
    units: list[tuple],
    total_den: torch.Tensor
)
```

Sequence-partitioning plan: one `unit` per non-empty `(row, segment)`.

A plain (non-container) class on purpose: `DistributedDataParallel` only
scatters tensors and built-in containers across its inputs, so passing this
object through `forward(peagle_segment=(plan, i))` leaves its tensors intact
(a `dict`/`tuple` of tensors would be sliced along dim 0 by DDP's scatter).

```python
nemo_automodel.components.speculative.eagle.peagle_trainer._kl_div_loss(
    logits: torch.Tensor,
    target_logits: torch.Tensor
) -> torch.Tensor
```

Per-position KL(target || draft) over the draft vocabulary.

Matches speculators' `kl_div_loss`: `log_softmax` the draft logits,
`softmax` the target logits, and sum the elementwise KL over the vocab
axis. Shapes `[*, draft_vocab]` -> `[*]`.
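
The description above maps onto standard PyTorch calls. The following is a minimal sketch, not the module's actual source: `kl_div_loss_sketch` is a hypothetical name, and the upcast to `float32` is an added assumption for numerical stability.

```python
import torch
import torch.nn.functional as F


def kl_div_loss_sketch(logits: torch.Tensor, target_logits: torch.Tensor) -> torch.Tensor:
    """Per-position KL(target || draft); shapes [*, vocab] -> [*]."""
    log_q = F.log_softmax(logits.float(), dim=-1)   # draft log-probabilities
    p = F.softmax(target_logits.float(), dim=-1)    # target probabilities
    # F.kl_div takes log-probs as input and probs as target; with
    # reduction="none" it returns p * (log p - log q) elementwise.
    return F.kl_div(log_q, p, reduction="none").sum(dim=-1)


torch.manual_seed(0)
draft = torch.randn(2, 5, 8)
target = torch.randn(2, 5, 8)
kl = kl_div_loss_sketch(draft, target)  # one non-negative value per position
```

Both inputs are assumed to be already restricted to the draft vocabulary, so their last dimensions match.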