# nemo_automodel.components.attention.flex_attention

## Module Contents

### Classes

| Name                                                                                 | Description                                                        |
| ------------------------------------------------------------------------------------ | ------------------------------------------------------------------ |
| [`FlexAttention`](#nemo_automodel-components-attention-flex_attention-FlexAttention) | FlexAttention module that uses torch.nn.attention.flex\_attention. |

### Data

[`FLEX_ATTN_MASK_T`](#nemo_automodel-components-attention-flex_attention-FLEX_ATTN_MASK_T)

### API

```python
class nemo_automodel.components.attention.flex_attention.FlexAttention()
```

**Bases:** `Module`

FlexAttention module that uses torch.nn.attention.flex\_attention.

It wraps torch.nn.attention.flex\_attention and implements common
attention mask types, such as causal and block\_causal.

**Parameters:**

The type of attention mask. Currently supported values are
"causal" and "block\_causal". With "causal", each query attends only to keys
at the same or earlier positions, so only the lower triangle of the
attention matrix (including the diagonal) is kept. With "block\_causal", the
sequence is divided into blocks whose boundaries are defined by EOS tokens,
and causal masking is applied within each block.

The block size used to perform attention.
If specified, each sequence is further divided into blocks of at most
`fixed_block_size` tokens. A query attends only to keys within the same
block.

```python
nemo_automodel.components.attention.flex_attention.FlexAttention._fixed_block_mask_mod(
    mask_mod: torch.nn.attention.flex_attention._mask_mod_signature,
    fixed_block_size: int
) -> torch.nn.attention.flex_attention._mask_mod_signature
```

staticmethod

Given an arbitrary mask\_mod, divide the input sequence into blocks
and only allow attention within the same block.

**Parameters:**

`mask_mod`: The mask mod to apply to the documents.

`fixed_block_size`: The number of tokens in each block.
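The wrapping described above can be illustrated with a minimal sketch. It is not the library's implementation. It assumes the usual flex\_attention `mask_mod` convention of `(b, h, q_idx, kv_idx)` returning `True` where attention is allowed, and it assumes the wrapped `mask_mod` is evaluated on positions measured from the start of each block. The names `fixed_block_mask_mod`, `blocked_mask_mod`, and `causal` are hypothetical.

```python
def fixed_block_mask_mod(mask_mod, fixed_block_size):
    """Hypothetical stand-in for _fixed_block_mask_mod (not the library code)."""

    def blocked_mask_mod(b, h, q_idx, kv_idx):
        # A query and a key may interact only if they fall in the same block.
        same_block = (q_idx // fixed_block_size) == (kv_idx // fixed_block_size)
        # Assumption: the wrapped mask_mod sees positions relative to the block start.
        inner = mask_mod(b, h, q_idx % fixed_block_size, kv_idx % fixed_block_size)
        return same_block & inner

    return blocked_mask_mod


def causal(b, h, q_idx, kv_idx):
    # Plain causal rule used as the wrapped mask_mod in this example.
    return q_idx >= kv_idx


blocked = fixed_block_mask_mod(causal, fixed_block_size=2)

# Evaluate the mask on a 4-token sequence; rows are queries, columns are keys.
blocked_grid = [[int(blocked(0, 0, q, kv)) for kv in range(4)] for q in range(4)]
for row in blocked_grid:
    print(row)
```

With a block size of 2, tokens 0 and 1 form one block and tokens 2 and 3 form another, so no query in the second block sees a key in the first. The sketch uses plain integers so it runs without PyTorch; the same comparisons and `&` also work elementwise on index tensors.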

```python
nemo_automodel.components.attention.flex_attention.FlexAttention._get_block_causal_mask_mod(
    batch: torch.Tensor,
    eos_id: int
) -> torch.nn.attention.flex_attention._mask_mod_signature
```

staticmethod
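This method has no description here, so the following is only a sketch of how a block-causal mask could be derived from EOS positions, based on the "block\_causal" behavior described for the class. It uses plain Python lists instead of a `torch.Tensor` so that it runs without PyTorch. It assumes that an EOS token belongs to the document it terminates. The names `block_causal_mask_mod_sketch` and `doc_ids` are hypothetical.

```python
def block_causal_mask_mod_sketch(batch, eos_id):
    """Hypothetical sketch; `batch` is a list of token-id lists, not a tensor."""
    doc_ids = []
    for tokens in batch:
        ids, current = [], 0
        for tok in tokens:
            ids.append(current)
            # Assumption: EOS closes its document; the next token starts a new one.
            if tok == eos_id:
                current += 1
        doc_ids.append(ids)

    def block_causal_mask(b, h, q_idx, kv_idx):
        # Attend only within the same document, and only to earlier or equal positions.
        same_doc = doc_ids[b][q_idx] == doc_ids[b][kv_idx]
        return same_doc and q_idx >= kv_idx

    return block_causal_mask


# One sequence holding two documents; token id 0 plays the role of EOS.
mask_mod = block_causal_mask_mod_sketch([[5, 6, 0, 7, 8]], eos_id=0)
doc_grid = [[int(mask_mod(0, 0, q, kv)) for kv in range(5)] for q in range(5)]
for row in doc_grid:
    print(row)
```

The printed grid is block-diagonal: the first three tokens attend causally among themselves, and the last two attend only to each other. A tensor-based version would need elementwise `&` in place of `and`.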

```python
nemo_automodel.components.attention.flex_attention.FlexAttention._get_causal_mask_mod() -> torch.nn.attention.flex_attention._mask_mod_signature
```

staticmethod
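As a rough illustration of what a causal `mask_mod` looks like, the sketch below follows the flex\_attention convention of taking `(b, h, q_idx, kv_idx)` and returning `True` where the query may attend to the key. The function name `causal_mask_mod` is hypothetical and the body is an assumption about the rule, not a copy of the library code.

```python
def causal_mask_mod(b, h, q_idx, kv_idx):
    # A query may attend to keys at the same or an earlier position.
    return q_idx >= kv_idx


# Evaluate on a 3-token sequence; rows are queries, columns are keys.
causal_grid = [[int(causal_mask_mod(0, 0, q, kv)) for kv in range(3)] for q in range(3)]
for row in causal_grid:
    print(row)
```

The kept entries form the lower triangle of the attention matrix, including the diagonal.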

```python
nemo_automodel.components.attention.flex_attention.FlexAttention._get_sliding_window_mask_mod(
    window: int
) -> torch.nn.attention.flex_attention._mask_mod_signature
```

staticmethod

Returns a mask\_mod function that allows attention only when both conditions hold:

* kv\_idx ≤ q\_idx (causal)
* (q\_idx - kv\_idx) ≤ window
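The two conditions above can be written directly as a `mask_mod`. The following is a minimal sketch under the flex\_attention convention of `(b, h, q_idx, kv_idx)` returning `True` where attention is allowed. The names `sliding_window_mask_mod` and `sliding_mod` are hypothetical.

```python
def sliding_window_mask_mod(window):
    """Hypothetical sketch of a causal sliding-window mask_mod factory."""

    def sliding_mod(b, h, q_idx, kv_idx):
        causal = kv_idx <= q_idx                 # never look at future keys
        in_window = (q_idx - kv_idx) <= window   # stay within `window` positions back
        return causal & in_window

    return sliding_mod


sliding = sliding_window_mask_mod(window=1)

# Evaluate on a 4-token sequence; rows are queries, columns are keys.
window_grid = [[int(sliding(0, 0, q, kv)) for kv in range(4)] for q in range(4)]
for row in window_grid:
    print(row)
```

Because the comparison is `≤ window`, each query can see up to `window + 1` keys, counting itself.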

```python
nemo_automodel.components.attention.flex_attention.FlexAttention.forward(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: float | None = None,
    sink_weights: torch.Tensor | None = None,
    sliding_window: int = 0,
    enable_gqa: bool = False
) -> torch.Tensor
```

```python
nemo_automodel.components.attention.flex_attention.FLEX_ATTN_MASK_T = tuple[str, int | None] | tuple[int, int, int]
```