# nemo_automodel.components.speculative.eagle.peagle_attention

Flex-attention mask for P-EAGLE parallel-group prediction.

P-EAGLE flattens all COD-sampled depths into one sequence and runs a *single*
attention forward over it. The cross-depth visibility pattern is not plain
causal: an element may attend to (a) any earlier-position element at depth 0
(the committed real context) and (b) earlier-or-equal depths *of its own
rollout* (the masked multi-token-prediction chain it belongs to). Documents are
isolated so packed / padded rows never attend across each other.

This is a verbatim port of speculators' `create_peagle_mask_mod`
([https://github.com/vllm-project/speculators/pull/480](https://github.com/vllm-project/speculators/pull/480)) so training reproduces
exactly what vLLM's parallel-drafting runtime sees at inference.

## Module Contents

### Functions

| Name                                                                                                             | Description                                                      |
| ---------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------- |
| [`create_peagle_mask_mod`](#nemo_automodel-components-speculative-eagle-peagle_attention-create_peagle_mask_mod) | Build a `flex_attention` `mask_mod` for P-EAGLE parallel groups. |

### API

```python
nemo_automodel.components.speculative.eagle.peagle_attention.create_peagle_mask_mod(
    anchor_pos: torch.Tensor,
    depth: torch.Tensor,
    lengths: torch.Tensor,
    total_seq_len: int
)
```

Build a `flex_attention` `mask_mod` for P-EAGLE parallel groups.

Each query attends only to previous elements in the same sampling chain /
rollout, plus the causal depth-0 context in the same document. `lengths`
drives a document-id map so padded positions (ids `-1`) and cross-document
pairs are excluded.
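The document-id map described above can be pictured with a small pure-Python sketch. It is an illustration only, not the library's code: the helper names `build_document_ids` and `same_document` are hypothetical, and it assumes that documents are laid out back to back in the original sequence with all padding at the end.

```python
def build_document_ids(lengths, total_seq_len):
    """Map each original position to a document index, or -1 for padding.

    Assumption: documents are packed contiguously from position 0 and
    everything after the last document is padding.
    """
    document_ids = [-1] * total_seq_len
    start = 0
    for doc_index, length in enumerate(lengths):
        for position in range(start, start + length):
            document_ids[position] = doc_index
        start += length
    return document_ids


def same_document(document_ids, q_pos, kv_pos):
    """True only when both positions are real tokens of the same document."""
    q_doc = document_ids[q_pos]
    return q_doc != -1 and q_doc == document_ids[kv_pos]


ids = build_document_ids(lengths=[3, 2], total_seq_len=6)
print(ids)  # [0, 0, 0, 1, 1, -1]
```

With these ids, position 2 (document 0) and position 3 (document 1) fail the `same_document` check, and the padded position 5 is excluded even against itself.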

Example (one document of length 6, COD sampling):

- Round 1 positions: `[0, 1, 3, 4]`; Round 2: `[0, 3]`; Round 3: `[0]`
- `anchor_pos`: `[0,1,2,3,4,5, 0,1,3,4, 0,3, 0]`
- `depth`: `[0,0,0,0,0,0, 1,1,1,1, 2,2, 3]`
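To make the example concrete, the sketch below applies one reading of the visibility rule to these two arrays in plain Python. It is not the ported implementation: the helpers `may_attend` and `visible_keys` are hypothetical, the rule is inferred from the description above (depth-0 context up to the query's anchor, plus earlier-or-equal depths sharing the query's anchor), and the document check is omitted because the example contains a single unpadded document.

```python
# Flat sequence from the example: 6 depth-0 elements, then the COD rounds.
anchor_pos = [0, 1, 2, 3, 4, 5, 0, 1, 3, 4, 0, 3, 0]
depth = [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3]


def may_attend(q_idx, kv_idx):
    """Assumed rule: causal depth-0 context OR the query's own rollout."""
    # (a) committed context: depth-0 keys at or before the query's anchor
    in_context = depth[kv_idx] == 0 and anchor_pos[kv_idx] <= anchor_pos[q_idx]
    # (b) own rollout: same anchor, earlier-or-equal depth
    in_chain = (
        anchor_pos[kv_idx] == anchor_pos[q_idx] and depth[kv_idx] <= depth[q_idx]
    )
    return in_context or in_chain


def visible_keys(q_idx):
    """Flat indices of every key the query at q_idx may attend to."""
    return [kv for kv in range(len(anchor_pos)) if may_attend(q_idx, kv)]


for q_idx in (4, 7, 11, 12):
    print(q_idx, visible_keys(q_idx))
# 4 [0, 1, 2, 3, 4]
# 7 [0, 1, 7]
# 11 [0, 1, 2, 3, 8, 11]
# 12 [0, 6, 10, 12]
```

Under this reading, the depth-2 element anchored at position 3 (flat index 11) sees the real context at positions 0 to 3 plus its own depth-1 and depth-2 entries, but nothing from the rollouts anchored at 0, 1, or 4.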

**Parameters:**

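**`anchor_pos`** (`torch.Tensor`): Chain-start position in the original sequence per element.

**`depth`** (`torch.Tensor`): COD round per element.

**`lengths`** (`torch.Tensor`): Valid (unpadded) length of each document in the flat sequence.

**`total_seq_len`** (`int`): Combined padded length of the original sequence(s).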

**Returns:**

A `mask_mod` callable compatible with `create_block_mask`.