
# nemo_automodel.components.distributed.magi_attn_utils

MagiAttention integration for Automodel.

MagiAttention ([https://github.com/SandAI-org/MagiAttention](https://github.com/SandAI-org/MagiAttention)) is a distributed
(context-parallel) attention implementation built on a Flex-Flash-Attention (FFA) kernel. It
shards a single packed sequence across a CP process group using a load-balancing
dispatch solver and exchanges KV with zero-redundant GroupCast/GroupReduce
collectives.

This module wires MagiAttention into the HF-transformers-based LLM path used by
`recipes/llm/train_ft.py` following MagiAttention's official
`examples/transformers` recipe:

1. `register_magi_attention()` registers a `"magi"` entry in HF's
   `ALL_ATTENTION_FUNCTIONS` so that a model loaded with
   `attn_implementation="magi"` routes its attention through the FFA kernel.
2. `magi_prepare_batch()` builds the per-step dist-attn runtime key, dispatches
   `input_ids`/`position_ids`/`labels` across the CP group and stamps
   `cp_group` on every attention sub-module so the registered forward finds the key.
3. Each rank runs the model on its local shard and computes a per-shard loss; the
   recipe's cross-CP reduction sums the shards into the global loss (like TE-CP).
   Sharding labels (rather than undispatching logits) keeps the loss path identical
   for the HF and custom-model backends.
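
Why sharding labels is enough can be checked with a small pure-Python sketch: when the labels go through the same permutation as the inputs and the loss does not shift, the per-shard cross-entropy sums add up to the global sum. The `shard` helper and the shuffled permutation are hypothetical stand-ins for MagiAttention's dispatch solver, and the sketch slices precomputed logits, whereas in the recipe each rank produces logits from its own input shard.

```python
import math
import random


def cross_entropy_sum(logits_rows, labels, ignore_index=-100):
    """Sum of token-level cross-entropy; rows labelled ignore_index are skipped."""
    total = 0.0
    for row, label in zip(logits_rows, labels):
        if label == ignore_index:
            continue
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # stable logsumexp
        total += log_z - row[label]
    return total


def shard(seq, perm, cp_size):
    """Hypothetical dispatch: permute positions, then split into cp_size equal shards."""
    permuted = [seq[i] for i in perm]
    n = len(permuted) // cp_size
    return [permuted[r * n:(r + 1) * n] for r in range(cp_size)]


random.seed(0)
seqlen, vocab, cp_size = 8, 5, 2
logits = [[random.uniform(-2, 2) for _ in range(vocab)] for _ in range(seqlen)]
labels = [random.randrange(vocab) for _ in range(seqlen)]
labels[-1] = -100  # e.g. a padded position dropped by the loss

perm = list(range(seqlen))
random.shuffle(perm)  # stand-in for the load-balancing permutation

global_loss = cross_entropy_sum(logits, labels)
per_shard = [
    cross_entropy_sum(logit_shard, label_shard)
    for logit_shard, label_shard in zip(shard(logits, perm, cp_size), shard(labels, perm, cp_size))
]
```

Summing `per_shard` (the role of the cross-CP reduction) reproduces `global_loss` up to floating-point rounding.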

When `cp_size == 1` the dispatch is a no-op shard (identity + chunk padding),
so this path is also a clean way to swap *only* the attention kernel (FFA) in
place of eager/SDPA/flash for convergence-parity comparisons.

## Module Contents

### Classes

| Name                                                                                  | Description                                                                 |
| ------------------------------------------------------------------------------------- | --------------------------------------------------------------------------- |
| [`AttnMaskSpec`](#nemo_automodel-components-distributed-magi_attn_utils-AttnMaskSpec) | Backend-agnostic description of an attention mask as AttnSlice rectangles.  |
| [`MagiState`](#nemo_automodel-components-distributed-magi_attn_utils-MagiState)       | Resolved MagiAttention wiring for a recipe, produced by `setup_magi`.       |

### Functions

| Name                                                                                                                      | Description                                                                          |
| ------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------ |
| [`_build_self_key`](#nemo_automodel-components-distributed-magi_attn_utils-_build_self_key)                               | Build a cp=1 causal varlen key matching the actual q length (no dispatch).           |
| [`_flex_key_for`](#nemo_automodel-components-distributed-magi_attn_utils-_flex_key_for)                                   | Return the flex key for `spec`, rebuilding only when the mask changes.               |
| [`_get_head_config`](#nemo_automodel-components-distributed-magi_attn_utils-_get_head_config)                             | Extract (num\_heads\_q, num\_heads\_kv, head\_dim) from an HF model/config.          |
| [`_iter_language_model_attention`](#nemo_automodel-components-distributed-magi_attn_utils-_iter_language_model_attention) | Yield attention sub-modules belonging to the *language* backbone only.               |
| [`_packed_cp_doc_seqlens`](#nemo_automodel-components-distributed-magi_attn_utils-_packed_cp_doc_seqlens)                 | Resolve per-document lengths spanning the *padded* THD layout of length `total_len`. |
| [`_self_key_for`](#nemo_automodel-components-distributed-magi_attn_utils-_self_key_for)                                   | Return a causal self-key for `seqlen`, (re)building only when it changes.            |
| [`_set_cp_group_on_attention`](#nemo_automodel-components-distributed-magi_attn_utils-_set_cp_group_on_attention)         | Stamp `cp_group` on every attention sub-module so the FFA forward finds the key.     |
| [`build_flex_key`](#nemo_automodel-components-distributed-magi_attn_utils-build_flex_key)                                 | Build a magi dist-attn key for an arbitrary AttnSlice mask (no extra padding).       |
| [`get_active_attn_spec`](#nemo_automodel-components-distributed-magi_attn_utils-get_active_attn_spec)                     | Return the active mask spec (None -> plain causal self-key).                         |
| [`get_active_cp_group`](#nemo_automodel-components-distributed-magi_attn_utils-get_active_cp_group)                       | Return the CP group set by `set_active_cp_group` (may be None).                      |
| [`get_cp_group`](#nemo_automodel-components-distributed-magi_attn_utils-get_cp_group)                                     | Return the CP process group from the device mesh (size-1 group is fine).             |
| [`is_magi_available`](#nemo_automodel-components-distributed-magi_attn_utils-is_magi_available)                           | Return True if the `magi_attention` package is importable.                           |
| [`magi_prepare_batch`](#nemo_automodel-components-distributed-magi_attn_utils-magi_prepare_batch)                         | Dispatch a (batch\_size==1) sequence for MagiAttention on the HF path.               |
| [`magi_prepare_packed_cp`](#nemo_automodel-components-distributed-magi_attn_utils-magi_prepare_packed_cp)                 | Context-parallel prep for a packed (THD) batch on the custom-model path.             |
| [`magi_prepare_vlm`](#nemo_automodel-components-distributed-magi_attn_utils-magi_prepare_vlm)                             | Prepare a VLM (bs==1) step for MagiAttention on the language backbone.               |
| [`make_magi_attn_func`](#nemo_automodel-components-distributed-magi_attn_utils-make_magi_attn_func)                       | Build the attn\_func used by the custom-model attention factory.                     |
| [`register_magi_attention`](#nemo_automodel-components-distributed-magi_attn_utils-register_magi_attention)               | Register the `"magi"` attention backend in HF transformers (idempotent).             |
| [`set_active_attn_spec`](#nemo_automodel-components-distributed-magi_attn_utils-set_active_attn_spec)                     | Set the mask spec the custom-model magi attn\_func should apply this step.           |
| [`set_active_cp_group`](#nemo_automodel-components-distributed-magi_attn_utils-set_active_cp_group)                       | Record the CP group the custom-model magi attn\_func should use.                     |
| [`setup_magi`](#nemo_automodel-components-distributed-magi_attn_utils-setup_magi)                                         | Resolve MagiAttention from config: register the backend and CP group.                |

### Data

[`DEFAULT_CHUNK_SIZE`](#nemo_automodel-components-distributed-magi_attn_utils-DEFAULT_CHUNK_SIZE)

[`_ACTIVE_ATTN_SPEC`](#nemo_automodel-components-distributed-magi_attn_utils-_ACTIVE_ATTN_SPEC)

[`_ACTIVE_CP_GROUP`](#nemo_automodel-components-distributed-magi_attn_utils-_ACTIVE_CP_GROUP)

[`_FLEX_KEY_CACHE`](#nemo_automodel-components-distributed-magi_attn_utils-_FLEX_KEY_CACHE)

[`_MAGI_REGISTERED`](#nemo_automodel-components-distributed-magi_attn_utils-_MAGI_REGISTERED)

[`_MAGI_SELF_KEY_LEN`](#nemo_automodel-components-distributed-magi_attn_utils-_MAGI_SELF_KEY_LEN)

[`logger`](#nemo_automodel-components-distributed-magi_attn_utils-logger)

### API

```python
class nemo_automodel.components.distributed.magi_attn_utils.AttnMaskSpec(
    q_ranges: list,
    k_ranges: list,
    mask_types: list,
    total_seqlen: int
)
```

Dataclass

Backend-agnostic description of an attention mask as AttnSlice rectangles.
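
To make the rectangle encoding concrete, here is a hypothetical pure-Python expansion of an `AttnMaskSpec`-like object into a dense boolean mask. It assumes half-open `(start, end)` ranges and string mask-type tags, and it only handles square `CAUSAL` rectangles; `MaskSpec` and `to_dense` are illustrative names, and the real class's range and mask-type representations may differ.

```python
from dataclasses import dataclass

FULL, CAUSAL = "full", "causal"  # hypothetical mask-type tags


@dataclass
class MaskSpec:
    """Hypothetical mirror of the four documented AttnMaskSpec fields."""
    q_ranges: list    # [(q_start, q_end), ...], half-open
    k_ranges: list    # [(k_start, k_end), ...], paired with q_ranges
    mask_types: list  # FULL or CAUSAL, one per rectangle
    total_seqlen: int


def to_dense(spec: MaskSpec) -> list:
    """Expand AttnSlice rectangles into a dense [total_seqlen][total_seqlen] bool mask."""
    n = spec.total_seqlen
    mask = [[False] * n for _ in range(n)]
    for (qs, qe), (ks, ke), kind in zip(spec.q_ranges, spec.k_ranges, spec.mask_types):
        if kind == CAUSAL and (qe - qs) != (ke - ks):
            raise ValueError("this sketch only handles square CAUSAL rectangles")
        for q in range(qs, qe):
            for k in range(ks, ke):
                # FULL: every (q, k) pair; CAUSAL (square): key offset <= query offset.
                if kind == FULL or (k - ks) <= (q - qs):
                    mask[q][k] = True
    return mask


# Two packed documents of length 3 and 2: a block-diagonal causal mask.
spec = MaskSpec([(0, 3), (3, 5)], [(0, 3), (3, 5)], [CAUSAL, CAUSAL], 5)
mask = to_dense(spec)
```

Only the listed rectangles are ever attendable, so anything outside them (here, cross-document pairs) stays masked without being described explicitly.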

```python
nemo_automodel.components.distributed.magi_attn_utils.AttnMaskSpec.causal(
    seqlen: int
) -> 'AttnMaskSpec'
```

classmethod

A single full causal sequence.

```python
nemo_automodel.components.distributed.magi_attn_utils.AttnMaskSpec.fingerprint() -> tuple
```

Hashable identity used to cache the built key across layers.

```python
nemo_automodel.components.distributed.magi_attn_utils.AttnMaskSpec.prefix_tree(
    node_lengths: list[int],
    sample_paths: list[list[int]]
)
```

classmethod

Build a prefix-tree mask over a flat deduplicated token layout.

Each node attends FULL to every ancestor node in its path and CAUSAL to
itself; duplicate rectangles (shared nodes) are emitted once.

**Parameters:**

* `node_lengths`: token count of each node, in flat layout order. The flat
  layout is `[node_0 | node_1 | ...]` with node `i` occupying
  `[offset_i, offset_i + node_lengths[i])`.
* `sample_paths`: one list of node indices per sample, root -> leaf. Every
  sample is the causal concatenation of its nodes; a shared prefix
  node simply appears in multiple paths.

**Returns:** `(spec, sample_token_ranges)`

`spec` is the AttnMaskSpec; the second element is `sample_token_ranges`.
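
A sketch of the rectangle construction described above, under stated assumptions: half-open `(start, end)` ranges, string mask-type tags, and a plain tuple return instead of an `AttnMaskSpec`. `prefix_tree_rectangles` is a hypothetical name, not part of this module, and the sketch does not reproduce `sample_token_ranges`.

```python
def prefix_tree_rectangles(node_lengths, sample_paths):
    """Hypothetical sketch of the rectangles a prefix-tree mask needs.

    Returns (q_ranges, k_ranges, mask_types, total_seqlen) with half-open ranges.
    """
    offsets = [0]
    for length in node_lengths:
        offsets.append(offsets[-1] + length)

    def span(node):
        return (offsets[node], offsets[node + 1])

    seen = set()
    q_ranges, k_ranges, mask_types = [], [], []

    def emit(q, k, kind):
        if (q, k, kind) not in seen:  # shared nodes produce duplicates; keep one
            seen.add((q, k, kind))
            q_ranges.append(q)
            k_ranges.append(k)
            mask_types.append(kind)

    for path in sample_paths:
        for depth, node in enumerate(path):
            for ancestor in path[:depth]:
                emit(span(node), span(ancestor), "full")  # FULL to every ancestor
            emit(span(node), span(node), "causal")        # CAUSAL to itself
    return q_ranges, k_ranges, mask_types, offsets[-1]


# A shared 3-token prompt (node 0) with two continuations (nodes 1 and 2).
q_ranges, k_ranges, mask_types, total = prefix_tree_rectangles([3, 2, 4], [[0, 1], [0, 2]])
```

The shared prompt is stored and attended once, and the two continuations never see each other because no rectangle pairs node 1 with node 2.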

```python
nemo_automodel.components.distributed.magi_attn_utils.AttnMaskSpec.varlen(
    seqlens: list[int],
    causal: bool = True
) -> 'AttnMaskSpec'
```

classmethod

Block-diagonal mask for packed sequences (one block per document).
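
The block-diagonal mask presumably reduces to one diagonal rectangle per document, and `causal(seqlen)` plausibly corresponds to the single-document case. A hypothetical sketch returning plain tuples (half-open ranges and string tags are assumptions):

```python
def varlen_rectangles(seqlens, causal=True):
    """Hypothetical sketch: one diagonal rectangle per packed document."""
    q_ranges, k_ranges = [], []
    start = 0
    for n in seqlens:
        q_ranges.append((start, start + n))  # queries of this document...
        k_ranges.append((start, start + n))  # ...attend only keys of the same document
        start += n
    kind = "causal" if causal else "full"
    return q_ranges, k_ranges, [kind] * len(seqlens), start


packed = varlen_rectangles([3, 2])
single = varlen_rectangles([5])  # one causal document, as a `causal(5)` spec would describe
```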

```python
class nemo_automodel.components.distributed.magi_attn_utils.MagiState(
    enabled: bool = False,
    custom: bool = False,
    cp_group: typing.Optional['dist.ProcessGroup'] = None,
    cp_size: int = 1
)
```

Dataclass

Resolved MagiAttention wiring for a recipe, produced by `setup_magi`.

A single handle (stored as `self.magi`) replacing the scattered
`magi_enabled`/`magi_custom`/`magi_cp_group`/`magi_cp_size` recipe
attributes. When MagiAttention is not configured, `enabled` is False and the
per-step methods are no-ops, so recipes can call them unconditionally.

HF path: dispatch the sequence (input + labels) across CP for a per-shard loss.

Distinguishes the HF `attn_implementation=magi` path (single causal sequence,
`magi_prepare_batch`) from the custom-model factory path; both shard labels
and compute a per-shard loss at cp>1, so neither undispatches logits.

```python
nemo_automodel.components.distributed.magi_attn_utils.MagiState.prepare_llm_batch(
    model,
    batch,
    device_mesh,
    is_thd,
    pad_id,
    num_chunks
)
```

Per-step batch prep for the LLM recipe (assumes `enabled`).

Returns `(train_ctx, batch)`. magi does its own CP, so `train_ctx` is
always `nullcontext` (no torch-native DTensor CP context).
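
The calling pattern that lets recipes invoke the per-step methods unconditionally can be sketched as follows. `MagiStateSketch`, `prepare_batch`, and the `dispatch` callable are hypothetical names; the real `prepare_llm_batch` also takes `model`, `device_mesh`, `is_thd`, `pad_id`, and `num_chunks`.

```python
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class MagiStateSketch:
    """Hypothetical reduction of MagiState to the fields documented above."""
    enabled: bool = False
    custom: bool = False
    cp_group: Optional[Any] = None
    cp_size: int = 1

    def prepare_batch(self, batch: dict, dispatch: Callable[[dict], dict]):
        """Return (train_ctx, batch); `dispatch` stands in for the magi prep helpers."""
        if not self.enabled:
            return nullcontext(), batch  # magi not configured: pass the batch through
        # magi shards the sequence itself, so no torch-native CP context is needed
        return nullcontext(), dispatch(batch)


def fake_dispatch(batch):
    return {**batch, "dispatched": True}


off_ctx, off_batch = MagiStateSketch().prepare_batch({"input_ids": [1, 2]}, fake_dispatch)
on_ctx, on_batch = MagiStateSketch(enabled=True).prepare_batch({"input_ids": [1, 2]}, fake_dispatch)
with on_ctx:
    pass  # the training step would run here
```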

```python
nemo_automodel.components.distributed.magi_attn_utils.MagiState.prepare_vlm_batch(
    model,
    batch
)
```

Per-step batch prep for the VLM recipe (assumes `enabled`).

For HF VLMs, the cp\_group is stamped on the language-backbone attention modules; custom VLMs
use the factory attn\_func with the active cp\_group set in `setup_magi`
(the vision tower stays on SDPA either way). Returns `(train_ctx, batch)`.

```python
nemo_automodel.components.distributed.magi_attn_utils._build_self_key(
    cp_group,
    seqlen,
    num_heads_q,
    num_heads_kv,
    head_dim,
    device
)
```

Build a cp=1 causal varlen key matching the actual q length (no dispatch).

```python
nemo_automodel.components.distributed.magi_attn_utils._flex_key_for(
    cp_group,
    spec,
    num_heads_q,
    num_heads_kv,
    head_dim
)
```

Return the flex key for `spec`, rebuilding only when the mask changes.
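
The "rebuild only when the mask changes" behaviour amounts to a single-entry cache keyed on a hashable fingerprint. A hypothetical sketch: `fingerprint`, `key_for`, and `_cache` are illustrative names, `build` stands in for the expensive key construction, and the real function also receives `cp_group` and the head config, which a fuller cache would presumably include in its key.

```python
from typing import Any, Callable, Hashable

_cache: dict = {}  # hypothetical single-entry cache in the spirit of _FLEX_KEY_CACHE


def fingerprint(q_ranges, k_ranges, mask_types, total_seqlen) -> tuple:
    """Hashable identity of a mask: convert lists to tuples so they compare and hash."""
    return (
        tuple(map(tuple, q_ranges)),
        tuple(map(tuple, k_ranges)),
        tuple(mask_types),
        total_seqlen,
    )


def key_for(fp: Hashable, build: Callable[[], Any]) -> Any:
    """Reuse the cached key while the mask is unchanged; rebuild when it changes."""
    if _cache.get("fingerprint") != fp:
        _cache["fingerprint"] = fp
        _cache["key"] = build()
    return _cache["key"]


build_calls = []


def build_key():
    build_calls.append(1)  # count how often the (expensive) key is rebuilt
    return object()


fp_a = fingerprint([(0, 4)], [(0, 4)], ["causal"], 4)
fp_b = fingerprint([(0, 2), (2, 4)], [(0, 2), (2, 4)], ["causal", "causal"], 4)
layer_keys = [key_for(fp_a, build_key) for _ in range(3)]  # three layers, same mask
key_for(fp_b, build_key)                                   # next step, new mask
```

All layers in one forward share the key, so the build cost is paid once per distinct mask rather than once per layer.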

```python
nemo_automodel.components.distributed.magi_attn_utils._get_head_config(
    model
) -> tuple[int, int, int]
```

Extract (num\_heads\_q, num\_heads\_kv, head\_dim) from an HF model/config.

For VLMs the text attention dims live under `config.text_config`; prefer that
sub-config when the top-level config does not expose `num_attention_heads`.
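
A hedged sketch of the lookup, exercised with `SimpleNamespace` stand-ins for HF configs. The KV-head and `head_dim` fallbacks follow common HF conventions and are assumptions here, not documented behaviour of `_get_head_config`; `get_head_config` is a hypothetical name.

```python
from types import SimpleNamespace


def get_head_config(model_or_config) -> tuple:
    """Hypothetical sketch: return (num_heads_q, num_heads_kv, head_dim)."""
    cfg = getattr(model_or_config, "config", model_or_config)
    # VLMs: text attention dims live under text_config.
    if not hasattr(cfg, "num_attention_heads") and hasattr(cfg, "text_config"):
        cfg = cfg.text_config
    num_q = cfg.num_attention_heads
    # Assumed HF conventions: no KV-head count means MHA, and
    # head_dim defaults to hidden_size // num_attention_heads.
    num_kv = getattr(cfg, "num_key_value_heads", None) or num_q
    head_dim = getattr(cfg, "head_dim", None) or cfg.hidden_size // num_q
    return num_q, num_kv, head_dim


llm = SimpleNamespace(config=SimpleNamespace(num_attention_heads=32, num_key_value_heads=8, hidden_size=4096))
vlm_cfg = SimpleNamespace(text_config=SimpleNamespace(num_attention_heads=16, hidden_size=2048, head_dim=256))
```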

```python
nemo_automodel.components.distributed.magi_attn_utils._iter_language_model_attention(
    model
)
```

Yield attention sub-modules belonging to the *language* backbone only.

For VLMs we must leave the vision tower on its own (bidirectional) attention.
HF VLMs nest the text stack under a `language_model`/`model.language_model`
attribute; we walk only that subtree. Falls back to the whole model for plain
LLMs (no language\_model attribute).
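
Locating the subtree can be sketched with plain attribute lookups; the lookup order below is an assumption. With a real `torch.nn.Module`, one would then iterate over `root.modules()` and keep the attention layers; how this module recognizes attention sub-modules is not documented here, so the hypothetical `language_model_root` stops at finding the root.

```python
from types import SimpleNamespace


def language_model_root(model):
    """Hypothetical sketch: return the text-backbone subtree, or the model itself."""
    for path in ("language_model", "model.language_model"):
        obj = model
        for attr in path.split("."):
            obj = getattr(obj, attr, None)
            if obj is None:
                break
        if obj is not None:
            return obj
    return model  # plain LLM: no language_model attribute


vlm = SimpleNamespace(model=SimpleNamespace(language_model="text stack", vision_tower="ViT"))
llm = SimpleNamespace(model="decoder stack")
```

The vision tower sits outside the returned subtree, so it keeps its own bidirectional attention.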

```python
nemo_automodel.components.distributed.magi_attn_utils._packed_cp_doc_seqlens(
    batch: dict,
    total_len: int
) -> list
```

Resolve per-document lengths spanning the *padded* THD layout of length `total_len`.

The TE collater pads each document for the THD layout, so `cu_seqlens_padded`
spans the full flat `input_ids` while `cu_seqlens` covers only the real
tokens. magi dispatches the whole flat sequence, so the dist key must be built
over the padded layout — otherwise the dispatched shard length (from
`input_ids`) won't match `get_position_ids` (built from the key), which
surfaces downstream as a RoPE q vs cos/sin length mismatch. Causal masking keeps
real tokens from attending to the trailing per-document padding, and pad-token rows are
dropped by the loss (labels == ignore\_index), so this is numerically equivalent
to attending only the real tokens.
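
A worked example of why the padded offsets are the right ones. `padded_doc_seqlens` is a hypothetical helper, and the offsets below are illustrative values for two documents of 3 and 2 real tokens, each padded to 4; only the padded layout spans the full flat `input_ids`, so only it passes the span check.

```python
def padded_doc_seqlens(cu_seqlens_padded, total_len):
    """Hypothetical sketch: per-document lengths from padded cumulative offsets."""
    seqlens = [end - start for start, end in zip(cu_seqlens_padded[:-1], cu_seqlens_padded[1:])]
    if sum(seqlens) != total_len:
        raise ValueError(f"document layout spans {sum(seqlens)} tokens, expected {total_len}")
    return seqlens


cu_seqlens = [0, 3, 5]         # real tokens only: spans 5 tokens
cu_seqlens_padded = [0, 4, 8]  # padded THD layout: spans all 8 flat tokens
doc_lens = padded_doc_seqlens(cu_seqlens_padded, total_len=8)
```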

**Raises:**

* `ValueError`: if the resolved document layout does not span `total_len`.

```python
nemo_automodel.components.distributed.magi_attn_utils._self_key_for(
    cp_group,
    seqlen,
    num_heads_q,
    num_heads_kv,
    head_dim,
    device
)
```

Return a causal self-key for `seqlen`, (re)building only when it changes.

All attention layers in one forward share the same sequence length, so the
first layer builds the key and the rest reuse it via `get_most_recent_key`.

```python
nemo_automodel.components.distributed.magi_attn_utils._set_cp_group_on_attention(
    model,
    cp_group
) -> None
```

Stamp `cp_group` on every attention sub-module so the FFA forward finds the key.

```python
nemo_automodel.components.distributed.magi_attn_utils.build_flex_key(
    spec: 'AttnMaskSpec',
    num_heads_q,
    num_heads_kv,
    head_dim,
    cp_group
)
```

Build a magi dist-attn key for an arbitrary AttnSlice mask (no extra padding).

```python
nemo_automodel.components.distributed.magi_attn_utils.get_active_attn_spec() -> typing.Optional['AttnMaskSpec']
```

Return the active mask spec (None -> plain causal self-key).

```python
nemo_automodel.components.distributed.magi_attn_utils.get_active_cp_group() -> typing.Optional['dist.ProcessGroup']
```

Return the CP group set by `set_active_cp_group` (may be None).

```python
nemo_automodel.components.distributed.magi_attn_utils.get_cp_group(
    device_mesh
) -> typing.Optional[torch.distributed.ProcessGroup]
```

Return the CP process group from the device mesh (size-1 group is fine).

```python
nemo_automodel.components.distributed.magi_attn_utils.is_magi_available() -> bool
```

Return True if the `magi_attention` package is importable.

```python
nemo_automodel.components.distributed.magi_attn_utils.magi_prepare_batch(
    model,
    batch: dict,
    cp_group: typing.Optional[torch.distributed.ProcessGroup],
    chunk_size: int = DEFAULT_CHUNK_SIZE
)
```

Dispatch a (batch\_size==1) sequence for MagiAttention on the HF path.

Builds a causal varlen dist-attn key over the single sequence and dispatches
`input_ids`, `position_ids` *and* `labels` across `cp_group` (identity
shard at cp\_size==1; load-balanced sharding at cp\_size>1). Labels are sharded the
same way as the input so the loss is computed per-shard and summed across CP — no
logit undispatch; `MaskedCrossEntropy` does not shift, so the dispatch
permutation is harmless (logits\[j] stay paired with labels\[j]). The FFA kernel
performs the cross-rank KV exchange using the stamped `cp_group` and the dist key.
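
The dispatch/undispatch contract can be illustrated with a round trip. MagiAttention uses its own load-balancing solver; the sketch below substitutes zigzag assignment, a simpler heuristic common in context-parallel attention, so `zigzag_dispatch` and `zigzag_undispatch` are hypothetical and not the module's algorithm. It shows the identity (plus padding) behaviour at cp\_size==1 and that labels dispatched the same way stay paired with their tokens.

```python
def zigzag_dispatch(seq, cp_size, chunk_size, pad):
    """Hypothetical stand-in for magi's load-balanced dispatch (not its solver).

    Pads to a multiple of chunk_size * 2 * cp_size, cuts 2 * cp_size equal parts,
    and gives rank r parts r and (2 * cp_size - 1 - r), so every rank holds one
    early (cheap under causal masking) and one late (expensive) part.
    """
    n_parts = 2 * cp_size
    padded = list(seq) + [pad] * (-len(seq) % (chunk_size * n_parts))
    size = len(padded) // n_parts
    parts = [padded[i * size:(i + 1) * size] for i in range(n_parts)]
    return [parts[r] + parts[n_parts - 1 - r] for r in range(cp_size)]


def zigzag_undispatch(shards, orig_len):
    """Invert zigzag_dispatch and strip the padding."""
    n_parts = 2 * len(shards)
    half = len(shards[0]) // 2
    parts = [None] * n_parts
    for r, shard in enumerate(shards):
        parts[r], parts[n_parts - 1 - r] = shard[:half], shard[half:]
    return [tok for part in parts for tok in part][:orig_len]


position_ids = list(range(10))
shards = zigzag_dispatch(position_ids, cp_size=2, chunk_size=2, pad=-1)
# labels go through the identical dispatch, so label j stays paired with token j
label_shards = zigzag_dispatch([100 + p for p in position_ids], cp_size=2, chunk_size=2, pad=-100)
```

Because `position_ids` are dispatched alongside the tokens, each rank still sees the original positions for RoPE even though its tokens are not contiguous.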

**Parameters:**

* `model`: the (possibly FSDP-wrapped) HF causal-LM.
* `batch`: dict with at least `input_ids` of shape `[1, S]`.
* `cp_group`: CP process group (size 1 -> identity shard).
* `chunk_size`: dispatch solver chunk size.

**Returns:** `(new_batch, key)`

`new_batch` has dispatched `input_ids`/`position_ids`/`labels`; `key` is the dist-attn runtime key for this step.

```python
nemo_automodel.components.distributed.magi_attn_utils.magi_prepare_packed_cp(
    model,
    batch: dict,
    cp_group
)
```

Context-parallel prep for a packed (THD) batch on the custom-model path.

Takes a *global* THD batch (flat `input_ids`/`labels`/`position_ids` plus
`cu_seqlens` marking document boundaries), builds a per-document varlen dist
key over `cp_group` and dispatches the sequence with MagiAttention's own
load-balancing solver (not TE's THD sharding). Each rank then runs the model on
its local shard; the attn\_func uses this pre-built key so the FFA kernel does
the cross-rank KV exchange. Labels are dispatched the same way as the input, so
each rank computes a per-shard loss that the recipe's cross-CP reduction sums
into the global loss (like TE-CP).

**Returns:** `(new_batch, key)`

`new_batch` has the local `input_ids`/`position_ids`/`labels` shards; `key` is the per-document varlen dist key.

```python
nemo_automodel.components.distributed.magi_attn_utils.magi_prepare_vlm(
    model,
    batch: dict,
    cp_group: typing.Optional[torch.distributed.ProcessGroup]
)
```

Prepare a VLM (bs==1) step for MagiAttention on the language backbone.

Unlike the LLM path, the VLM merges image features into `inputs_embeds`
*inside* its forward, so we cannot dispatch `input_ids` (image-placeholder
positions must stay put). For `cp_size == 1` we instead build a no-padding
causal key (`chunk_size=1` -> dispatch/undispatch are identity), stamp the
cp\_group on the language-model attention modules only, and let the FFA kernel
run on the natural-length q/k/v. The vision tower falls back to SDPA.

Returns the (unchanged) batch and `None` (the key is built lazily in the
attention forward from the real query length).

```python
nemo_automodel.components.distributed.magi_attn_utils.make_magi_attn_func(
    softmax_scale: typing.Optional[float] = None
)
```

Build the attn\_func used by the custom-model attention factory.

The returned callable accepts q/k/v in either THD `[t, nh, hd]` or BSHD
`[b, s, nh, hd]` (b must be 1) layout — the same layouts the custom models
feed to their backend attn\_func — and runs the MagiAttention FFA kernel.

Mask selection (no CP dispatch; cp=1 in-order tokens):

* if an `AttnMaskSpec` is active (`set_active_attn_spec`), use its
  flex key; this covers packing, sliding-window and prefix-tree masks;
* otherwise build a plain causal self-key from the q length.
  If no CP group is active it falls back to causal SDPA so non-magi modules
  (e.g. a VLM vision tower routed through the same factory) keep working.

**Parameters:**

* `softmax_scale`: attention softmax scale (defaults to 1/sqrt(head\_dim) inside
  FFA when None); forwarded so non-default scales stay correct.
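
For intuition about what the returned callable computes for one head, here is a naive pure-Python reference: masked softmax attention with `softmax_scale` falling back to `1/sqrt(head_dim)` when None, as the parameter description states. `reference_attention` is a hypothetical readable reference, not the FFA kernel, and it ignores THD/BSHD layouts, multiple heads, and CP.

```python
import math


def reference_attention(q, k, v, allowed, softmax_scale=None):
    """Naive single-head masked attention over [seq][head_dim] lists.

    allowed(i, j) says whether query i may attend key j; every query is
    assumed to have at least one allowed key (true for causal masks).
    """
    head_dim = len(q[0])
    scale = softmax_scale if softmax_scale is not None else 1.0 / math.sqrt(head_dim)
    out = []
    for i, qi in enumerate(q):
        scores = {j: scale * sum(a * b for a, b in zip(qi, kj)) for j, kj in enumerate(k) if allowed(i, j)}
        m = max(scores.values())  # subtract the max for a numerically stable softmax
        weights = {j: math.exp(s - m) for j, s in scores.items()}
        z = sum(weights.values())
        out.append([sum(w * v[j][d] for j, w in weights.items()) / z for d in range(len(v[0]))])
    return out


def causal(i, j):
    return j <= i


q = k = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out = reference_attention(q, k, v, causal)
```

Checking `softmax_scale is not None` rather than truthiness matters: an explicit scale of `0.0` must still be honoured (it yields uniform weights).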

```python
nemo_automodel.components.distributed.magi_attn_utils.register_magi_attention() -> None
```

Register the `"magi"` attention backend in HF transformers (idempotent).

```python
nemo_automodel.components.distributed.magi_attn_utils.set_active_attn_spec(
    spec: typing.Optional['AttnMaskSpec']
) -> None
```

Set the mask spec the custom-model magi attn\_func should apply this step.

```python
nemo_automodel.components.distributed.magi_attn_utils.set_active_cp_group(
    cp_group: typing.Optional['dist.ProcessGroup']
) -> None
```

Record the CP group the custom-model magi attn\_func should use.

```python
nemo_automodel.components.distributed.magi_attn_utils.setup_magi(
    cfg,
    device_mesh,
    label: str = ''
) -> nemo_automodel.components.distributed.magi_attn_utils.MagiState
```

Resolve MagiAttention from config: register the backend and CP group.

Enabled when the model is configured with `attn_implementation="magi"` (HF) or
`backend.attn="magi"` (custom models). Returns a `MagiState`
(`enabled=False` when magi is not configured). `label` is an optional suffix
for the log line (e.g. `"VLM language backbone"`).
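
The enable check reduces to two config lookups. A hedged sketch over a plain dict: the real `cfg` is a recipe config object whose exact nesting is an assumption here, `magi_requested` is a hypothetical name, and mapping the backend path to `MagiState.custom` is inferred from the field name.

```python
def magi_requested(model_cfg: dict) -> tuple:
    """Hypothetical sketch: return (enabled, custom) from a plain-dict model config."""
    hf_path = model_cfg.get("attn_implementation") == "magi"
    custom_path = (model_cfg.get("backend") or {}).get("attn") == "magi"
    return hf_path or custom_path, custom_path


hf_cfg = {"attn_implementation": "magi"}
custom_cfg = {"backend": {"attn": "magi"}}
sdpa_cfg = {"attn_implementation": "sdpa"}
```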

```python
nemo_automodel.components.distributed.magi_attn_utils.DEFAULT_CHUNK_SIZE = 512
```

```python
nemo_automodel.components.distributed.magi_attn_utils._ACTIVE_ATTN_SPEC: Optional['AttnMaskSpec'] = None
```

```python
nemo_automodel.components.distributed.magi_attn_utils._ACTIVE_CP_GROUP: Optional['dist.ProcessGroup'] = None
```

```python
nemo_automodel.components.distributed.magi_attn_utils._FLEX_KEY_CACHE: dict = {}
```

```python
nemo_automodel.components.distributed.magi_attn_utils._MAGI_REGISTERED = False
```

```python
nemo_automodel.components.distributed.magi_attn_utils._MAGI_SELF_KEY_LEN: dict = {}
```

```python
nemo_automodel.components.distributed.magi_attn_utils.logger = logging.getLogger(__name__)
```