nemo_automodel.components.distributed.magi_attn_utils


MagiAttention integration for Automodel.

MagiAttention (https://github.com/SandAI-org/MagiAttention) is a distributed (context-parallel) attention mechanism built on a Flex-Flash-Attention (FFA) kernel. It shards a single packed sequence across a CP process group using a load-balancing dispatch solver, and exchanges KV with zero-redundant GroupCast/GroupReduce collectives.

This module wires MagiAttention into the HF-transformers-based LLM path used by recipes/llm/train_ft.py, following MagiAttention’s official examples/transformers recipe:

  1. register_magi_attention() registers a "magi" entry in HF’s ALL_ATTENTION_FUNCTIONS so that a model loaded with attn_implementation="magi" routes its attention through the FFA kernel.
  2. magi_prepare_batch() builds the per-step dist-attn runtime key, dispatches input_ids/position_ids/labels across the CP group and stamps cp_group on every attention sub-module so the registered forward finds the key.
  3. Each rank runs the model on its local shard and computes a per-shard loss; the recipe’s cross-CP reduction sums the shards into the global loss (like TE-CP). Sharding labels (rather than undispatching logits) keeps the loss path identical for the HF and custom-model backends.
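The per-shard loss in step 3 can be illustrated without any distributed machinery. The following is a minimal sketch, not the recipe's code: per-token losses are plain floats, the dispatch is modelled as an arbitrary permutation, and `shard_loss_sum` is a hypothetical helper. Because labels are dispatched with the same permutation as the inputs, adding up the per-shard sums recovers the global sum.

```python
import math
import random


def shard_loss_sum(token_losses, order, cp_size, rank):
    """Sum the losses of the tokens that land on `rank` after dispatch.

    `order` is a hypothetical dispatch permutation: slot j of the
    dispatched sequence holds original token order[j]. Labels follow the
    same permutation, so each loss stays attached to its own token.
    """
    shard_len = len(order) // cp_size
    local = order[rank * shard_len:(rank + 1) * shard_len]
    return sum(token_losses[i] for i in local)


random.seed(0)
token_losses = [random.random() for _ in range(16)]

# Any permutation works; a real solver would pick one that balances load.
order = list(range(16))
random.shuffle(order)

cp_size = 4
per_rank = [shard_loss_sum(token_losses, order, cp_size, r) for r in range(cp_size)]

# Stand-in for the cross-CP reduction performed by the recipe.
global_sum = sum(per_rank)
```

The sketch only covers the summation; normalising by the number of non-ignored tokens is left out.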

When cp_size == 1 the dispatch is a no-op shard (identity + chunk padding), so this path is also a clean way to swap only the attention kernel (FFA) in place of eager/SDPA/flash for convergence-parity comparisons.
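The chunk padding mentioned above can be sketched as simple arithmetic. This assumes the sequence is padded up to a multiple of cp_size * chunk_size, which this page does not state explicitly; `pad_size` is a hypothetical helper, and 512 mirrors DEFAULT_CHUNK_SIZE.

```python
def pad_size(total_seqlen, cp_size, chunk_size=512):
    """Tokens of padding needed so the sequence splits into whole chunks.

    Assumption: the padded length must be a multiple of cp_size * chunk_size.
    """
    block = cp_size * chunk_size
    # Python's modulo of a negative number yields the distance to the
    # next multiple of `block` (0 when already aligned).
    return (-total_seqlen) % block


pad_cp1 = pad_size(1000, cp_size=1)      # pad 1000 up to 1024
pad_aligned = pad_size(1024, cp_size=1)  # already aligned
pad_cp2 = pad_size(1500, cp_size=2)      # pad 1500 up to 2048
```

Under this assumption, a cp_size of 1 only ever adds tail padding and leaves token order untouched, which is why the path suits kernel-only comparisons.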

Module Contents

Classes

| Name | Description |
| --- | --- |
| AttnMaskSpec | Backend-agnostic description of an attention mask as AttnSlice rectangles. |
| MagiState | Resolved MagiAttention wiring for a recipe, produced by setup_magi(). |

Functions

| Name | Description |
| --- | --- |
| _build_self_key | Build a cp=1 causal varlen key matching the actual q length (no dispatch). |
| _flex_key_for | Return the flex key for spec, rebuilding only when the mask changes. |
| _get_head_config | Extract (num_heads_q, num_heads_kv, head_dim) from an HF model/config. |
| _iter_language_model_attention | Yield attention sub-modules belonging to the language backbone only. |
| _packed_cp_doc_seqlens | Resolve per-document lengths spanning the padded THD layout of length total_len. |
| _self_key_for | Return a causal self-key for seqlen, (re)building only when it changes. |
| _set_cp_group_on_attention | Stamp cp_group on every attention sub-module so the FFA forward finds the key. |
| build_flex_key | Build a magi dist-attn key for an arbitrary AttnSlice mask (no extra padding). |
| get_active_attn_spec | Return the active mask spec (None -> plain causal self-key). |
| get_active_cp_group | Return the CP group set by set_active_cp_group() (may be None). |
| get_cp_group | Return the CP process group from the device mesh (size-1 group is fine). |
| is_magi_available | Return True if the magi_attention package is importable. |
| magi_prepare_batch | Dispatch a (batch_size==1) sequence for MagiAttention on the HF path. |
| magi_prepare_packed_cp | Context-parallel prep for a packed (THD) batch on the custom-model path. |
| magi_prepare_vlm | Prepare a VLM (bs==1) step for MagiAttention on the language backbone. |
| make_magi_attn_func | Build the attn_func used by the custom-model attention factory. |
| register_magi_attention | Register the "magi" attention backend in HF transformers (idempotent). |
| set_active_attn_spec | Set the mask spec the custom-model magi attn_func should apply this step. |
| set_active_cp_group | Record the CP group the custom-model magi attn_func should use. |
| setup_magi | Resolve MagiAttention from config: register the backend and CP group. |

Data

DEFAULT_CHUNK_SIZE

_ACTIVE_ATTN_SPEC

_ACTIVE_CP_GROUP

_FLEX_KEY_CACHE

_MAGI_REGISTERED

_MAGI_SELF_KEY_LEN

logger

API

class nemo_automodel.components.distributed.magi_attn_utils.AttnMaskSpec(
q_ranges: list,
k_ranges: list,
mask_types: list,
total_seqlen: int
)
Dataclass

Backend-agnostic description of an attention mask as AttnSlice rectangles.

k_ranges
list
mask_types
list
q_ranges
list
total_seqlen
int
nemo_automodel.components.distributed.magi_attn_utils.AttnMaskSpec.causal(
seqlen: int
) -> 'AttnMaskSpec'
classmethod

A single full causal sequence.

nemo_automodel.components.distributed.magi_attn_utils.AttnMaskSpec.fingerprint() -> tuple

Hashable identity used to cache the built key across layers.

nemo_automodel.components.distributed.magi_attn_utils.AttnMaskSpec.prefix_tree(
node_lengths: list[int],
sample_paths: list[list[int]]
)
classmethod

Build a prefix-tree mask over a flat deduplicated token layout.

Each node attends FULL to every ancestor node in its path and CAUSAL to itself; duplicate rectangles (shared nodes) are emitted once.

Parameters:

node_lengths
list[int]

token count of each node, in flat layout order. The flat layout is [node_0 | node_1 | ...] with node i occupying [offset_i, offset_i + node_lengths[i]).

sample_paths
list[list[int]]

one list of node indices per sample, root -> leaf. Every sample is the causal concatenation of its nodes; a shared prefix node simply appears in multiple paths.

Returns: (spec, sample_token_ranges)

spec is the AttnMaskSpec; the second
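To make the rectangle rule concrete, here is a minimal sketch of how a prefix-tree mask could be enumerated. It is not the library's implementation: `prefix_tree_rectangles` is a hypothetical helper, rectangles are plain `(q_range, k_range, mask_type)` tuples, and the `"full"`/`"causal"` strings are placeholders for the real mask-type values.

```python
def prefix_tree_rectangles(node_lengths, sample_paths):
    """List (q_range, k_range, mask_type) rectangles for a prefix tree."""
    # Flat layout: node i occupies [offsets[i], offsets[i + 1]).
    offsets = [0]
    for length in node_lengths:
        offsets.append(offsets[-1] + length)

    def span(node):
        return (offsets[node], offsets[node + 1])

    seen = set()
    rects = []

    def emit(q, k, mask_type):
        # Shared nodes appear in several paths; emit each rectangle once.
        rect = (q, k, mask_type)
        if rect not in seen:
            seen.add(rect)
            rects.append(rect)

    for path in sample_paths:
        for depth, node in enumerate(path):
            for ancestor in path[:depth]:
                emit(span(node), span(ancestor), "full")
            emit(span(node), span(node), "causal")
    return rects


# Two samples sharing root node 0: paths 0 -> 1 and 0 -> 2.
rects = prefix_tree_rectangles([4, 2, 3], [[0, 1], [0, 2]])
```

In this example the shared root contributes a single causal rectangle even though it appears in both paths, so five rectangles are produced rather than six.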

nemo_automodel.components.distributed.magi_attn_utils.AttnMaskSpec.varlen(
seqlens: list[int],
causal: bool = True
) -> 'AttnMaskSpec'
classmethod

Block-diagonal mask for packed sequences (one block per document).
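A block-diagonal mask over packed documents reduces to one range pair per document. The sketch below shows that derivation under stated assumptions: `varlen_ranges` is a hypothetical helper, it returns a plain dict rather than an AttnMaskSpec, and the mask-type strings are placeholders.

```python
from itertools import accumulate


def varlen_ranges(seqlens, causal=True):
    """Derive per-document q/k ranges from packed document lengths."""
    # Cumulative boundaries: [0, len_0, len_0 + len_1, ...].
    bounds = [0, *accumulate(seqlens)]
    ranges = list(zip(bounds, bounds[1:]))
    mask_type = "causal" if causal else "full"
    return {
        "q_ranges": ranges,
        # Each document attends only to itself, so k mirrors q.
        "k_ranges": list(ranges),
        "mask_types": [mask_type] * len(ranges),
        "total_seqlen": bounds[-1],
    }


packed = varlen_ranges([3, 5, 2])
```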

class nemo_automodel.components.distributed.magi_attn_utils.MagiState(
enabled: bool = False,
custom: bool = False,
cp_group: typing.Optional['dist.ProcessGroup'] = None,
cp_size: int = 1
)
Dataclass

Resolved MagiAttention wiring for a recipe, produced by setup_magi().

A single handle (stored as self.magi) replacing the scattered magi_enabled/magi_custom/magi_cp_group/magi_cp_size recipe attributes. When MagiAttention is not configured, enabled is False and the per-step methods are no-ops, so recipes can call them unconditionally.

cp_group
Optional['dist.ProcessGroup'] = None
cp_size
int = 1
custom
bool = False
enabled
bool = False
hf_dispatch
bool

HF path: dispatch the sequence (input + labels) across CP for a per-shard loss.

Distinguishes the HF attn_implementation=magi path (single causal sequence, magi_prepare_batch()) from the custom-model factory path; both shard labels and compute a per-shard loss at cp>1, so neither undispatches logits.

nemo_automodel.components.distributed.magi_attn_utils.MagiState.prepare_llm_batch(
model,
batch,
device_mesh,
is_thd,
pad_id,
num_chunks
)

Per-step batch prep for the LLM recipe (assumes enabled).

Returns (train_ctx, batch). magi does its own CP, so train_ctx is always nullcontext (no torch-native DTensor CP context).

nemo_automodel.components.distributed.magi_attn_utils.MagiState.prepare_vlm_batch(
model,
batch
)

Per-step batch prep for the VLM recipe (assumes enabled).

HF VLMs stamp the cp_group on the language-backbone attention; custom VLMs use the factory attn_func with the active cp_group set in setup_magi() (the vision tower stays on SDPA either way). Returns (train_ctx, batch).

nemo_automodel.components.distributed.magi_attn_utils._build_self_key(
cp_group,
seqlen,
num_heads_q,
num_heads_kv,
head_dim,
device
)

Build a cp=1 causal varlen key matching the actual q length (no dispatch).

nemo_automodel.components.distributed.magi_attn_utils._flex_key_for(
cp_group,
spec,
num_heads_q,
num_heads_kv,
head_dim
)

Return the flex key for spec, rebuilding only when the mask changes.
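The "rebuild only when the mask changes" behaviour can be pictured as memoisation on a hashable fingerprint. This is a sketch under assumptions, not the module's code: `MaskSpec` is a hypothetical stand-in for AttnMaskSpec, `key_for` ignores the CP group and head dimensions, and `object()` stands in for the expensive key build.

```python
from dataclasses import dataclass


@dataclass
class MaskSpec:
    """Hypothetical stand-in for AttnMaskSpec."""
    q_ranges: list
    k_ranges: list
    mask_types: list
    total_seqlen: int

    def fingerprint(self):
        # Lists are unhashable, so convert everything to nested tuples.
        return (
            tuple(map(tuple, self.q_ranges)),
            tuple(map(tuple, self.k_ranges)),
            tuple(self.mask_types),
            self.total_seqlen,
        )


_cache = {}
build_calls = 0


def key_for(spec):
    """Return a cached key, building it only for an unseen fingerprint."""
    global build_calls
    fp = spec.fingerprint()
    if fp not in _cache:
        build_calls += 1
        _cache[fp] = object()  # placeholder for the real key build
    return _cache[fp]


first = key_for(MaskSpec([[0, 4]], [[0, 4]], ["causal"], 4))
second = key_for(MaskSpec([[0, 4]], [[0, 4]], ["causal"], 4))  # cache hit
other = key_for(MaskSpec([[0, 8]], [[0, 8]], ["causal"], 8))   # new mask
```

Whether the real cache keeps every fingerprint or only the most recent one is not specified here; the sketch keeps all of them for simplicity.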

nemo_automodel.components.distributed.magi_attn_utils._get_head_config(
model
) -> tuple[int, int, int]

Extract (num_heads_q, num_heads_kv, head_dim) from an HF model/config.

For VLMs the text attention dims live under config.text_config; prefer that sub-config when the top-level config does not expose num_attention_heads.
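The sub-config preference can be sketched with plain namespaces standing in for HF config objects. `head_config` is a hypothetical helper; the fallbacks for num_key_value_heads (to the query head count) and head_dim (to hidden_size // num_attention_heads) follow common HF config conventions and are assumptions about what the real function does.

```python
from types import SimpleNamespace


def head_config(config):
    """Return (num_heads_q, num_heads_kv, head_dim) from a config-like object."""
    # VLM case: attention dims live under text_config.
    if not hasattr(config, "num_attention_heads") and hasattr(config, "text_config"):
        config = config.text_config

    num_heads_q = config.num_attention_heads
    # Assumed fallback: no GQA info means kv heads equal q heads.
    num_heads_kv = getattr(config, "num_key_value_heads", None) or num_heads_q
    # Assumed fallback: derive head_dim when the config omits it.
    head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_heads_q
    return num_heads_q, num_heads_kv, head_dim


llm_cfg = SimpleNamespace(num_attention_heads=32, num_key_value_heads=8, hidden_size=4096)
vlm_cfg = SimpleNamespace(
    text_config=SimpleNamespace(num_attention_heads=16, hidden_size=2048, head_dim=128)
)
llm_heads = head_config(llm_cfg)
vlm_heads = head_config(vlm_cfg)
```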

nemo_automodel.components.distributed.magi_attn_utils._iter_language_model_attention(
model
)

Yield attention sub-modules belonging to the language backbone only.

For VLMs we must leave the vision tower on its own (bidirectional) attention. HF VLMs nest the text stack under a language_model/model.language_model attribute; we walk only that subtree. Falls back to the whole model for plain LLMs (no language_model attribute).

nemo_automodel.components.distributed.magi_attn_utils._packed_cp_doc_seqlens(
batch: dict,
total_len: int
) -> list

Resolve per-document lengths spanning the padded THD layout of length total_len.

The TE collater pads each document for the THD layout, so cu_seqlens_padded spans the full flat input_ids while cu_seqlens covers only the real tokens. magi dispatches the whole flat sequence, so the dist key must be built over the padded layout. Otherwise the dispatched shard length (taken from input_ids) will not match the output of get_position_ids (built from the key), which surfaces downstream as a RoPE length mismatch between q and cos/sin. Causal masking keeps real tokens from attending to the trailing per-document pad, and pad-token rows are dropped by the loss (labels == ignore_index), so this is numerically equivalent to attending only to the real tokens.

Raises:

  • ValueError: if the resolved document layout does not span total_len.
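
The length resolution and its check can be sketched on plain Python lists. `doc_seqlens` is a hypothetical helper; falling back to cu_seqlens when cu_seqlens_padded is absent is an assumption, and real batches would hold tensors rather than lists.

```python
def doc_seqlens(batch, total_len):
    """Per-document lengths over the padded flat layout."""
    cu = batch.get("cu_seqlens_padded")
    if cu is None:
        cu = batch["cu_seqlens"]  # assumed fallback when no padding info exists
    cu = [int(x) for x in cu]

    # Consecutive differences of the cumulative offsets give each length.
    seqlens = [end - start for start, end in zip(cu, cu[1:])]
    if sum(seqlens) != total_len:
        raise ValueError(
            f"document layout spans {sum(seqlens)} tokens, expected {total_len}"
        )
    return seqlens


batch = {"cu_seqlens": [0, 3, 7], "cu_seqlens_padded": [0, 4, 8]}
lengths = doc_seqlens(batch, total_len=8)
```

Using the unpadded offsets against the same total_len would fail the check, which is the mismatch the ValueError guards against.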
nemo_automodel.components.distributed.magi_attn_utils._self_key_for(
cp_group,
seqlen,
num_heads_q,
num_heads_kv,
head_dim,
device
)

Return a causal self-key for seqlen, (re)building only when it changes.

All attention layers in one forward share the same sequence length, so the first layer builds the key and the rest reuse it via get_most_recent_key.

nemo_automodel.components.distributed.magi_attn_utils._set_cp_group_on_attention(
model,
cp_group
) -> None

Stamp cp_group on every attention sub-module so the FFA forward finds the key.

nemo_automodel.components.distributed.magi_attn_utils.build_flex_key(
spec: 'AttnMaskSpec',
num_heads_q,
num_heads_kv,
head_dim,
cp_group
)

Build a magi dist-attn key for an arbitrary AttnSlice mask (no extra padding).

nemo_automodel.components.distributed.magi_attn_utils.get_active_attn_spec() -> typing.Optional['AttnMaskSpec']

Return the active mask spec (None -> plain causal self-key).

nemo_automodel.components.distributed.magi_attn_utils.get_active_cp_group() -> typing.Optional['dist.ProcessGroup']

Return the CP group set by set_active_cp_group() (may be None).

nemo_automodel.components.distributed.magi_attn_utils.get_cp_group(
device_mesh
) -> typing.Optional[torch.distributed.ProcessGroup]

Return the CP process group from the device mesh (size-1 group is fine).

nemo_automodel.components.distributed.magi_attn_utils.is_magi_available() -> bool

Return True if the magi_attention package is importable.
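An importability check of this kind is commonly written with `importlib.util.find_spec`, which looks a package up without importing it. The sketch below shows that pattern; `magi_available` is a hypothetical name, and the real function may use a different mechanism such as a try/except around the import.

```python
import importlib.util


def magi_available():
    """True if the magi_attention package can be located on this system."""
    # find_spec returns None for a missing top-level package.
    return importlib.util.find_spec("magi_attention") is not None


available = magi_available()
```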

nemo_automodel.components.distributed.magi_attn_utils.magi_prepare_batch(
model,
batch: dict,
cp_group: typing.Optional[torch.distributed.ProcessGroup],
chunk_size: int = DEFAULT_CHUNK_SIZE
)

Dispatch a (batch_size==1) sequence for MagiAttention on the HF path.

Builds a causal varlen dist-attn key over the single sequence and dispatches input_ids, position_ids and labels across cp_group (identity shard at cp_size==1; load-balanced sharding at cp_size>1). Labels are sharded the same way as the input, so the loss is computed per shard and summed across CP with no logit undispatch. MaskedCrossEntropy does not shift, so the dispatch permutation is harmless: logits[j] stay paired with labels[j]. The FFA kernel performs the cross-rank KV exchange using the stamped cp_group and the dist key.

Parameters:

model

the (possibly FSDP-wrapped) HF causal-LM.

batch
dict

dict with at least input_ids of shape [1, S].

cp_group
Optional[dist.ProcessGroup]

CP process group (size 1 -> identity shard).

chunk_size
int. Defaults to DEFAULT_CHUNK_SIZE

dispatch solver chunk size.

Returns: (new_batch, key)

new_batch has dispatched input_ids/position_ids/

nemo_automodel.components.distributed.magi_attn_utils.magi_prepare_packed_cp(
model,
batch: dict,
cp_group
)

Context-parallel prep for a packed (THD) batch on the custom-model path.

Takes a global THD batch (flat input_ids/labels/position_ids plus cu_seqlens marking document boundaries), builds a per-document varlen dist key over cp_group and dispatches the sequence with MagiAttention’s own load-balancing solver (not TE’s THD sharding). Each rank then runs the model on its local shard; the attn_func uses this pre-built key so the FFA kernel does the cross-rank KV exchange. Labels are dispatched the same way as the input, so each rank computes a per-shard loss that the recipe’s cross-CP reduction sums into the global loss (like TE-CP).

Returns: (new_batch, key)

new_batch has the local input_ids/position_ids

nemo_automodel.components.distributed.magi_attn_utils.magi_prepare_vlm(
model,
batch: dict,
cp_group: typing.Optional[torch.distributed.ProcessGroup]
)

Prepare a VLM (bs==1) step for MagiAttention on the language backbone.

Unlike the LLM path, the VLM merges image features into inputs_embeds inside its forward, so we cannot dispatch input_ids (image-placeholder positions must stay put). For cp_size == 1 we instead build a no-padding causal key (chunk_size=1 -> dispatch/undispatch are identity), stamp the cp_group on the language-model attention modules only, and let the FFA kernel run on the natural-length q/k/v. The vision tower falls back to SDPA.

Returns the (unchanged) batch and None (the key is built lazily in the attention forward from the real query length).

nemo_automodel.components.distributed.magi_attn_utils.make_magi_attn_func(
softmax_scale: typing.Optional[float] = None
)

Build the attn_func used by the custom-model attention factory.

The returned callable accepts q/k/v in either THD [t, nh, hd] or BSHD [b, s, nh, hd] (b must be 1) layout — the same layouts the custom models feed to their backend attn_func — and runs the MagiAttention FFA kernel.

Mask selection (no CP dispatch; cp=1 in-order tokens):

  • if an AttnMaskSpec is active (set via set_active_attn_spec()), use its flex key, which covers packing, sliding-window and prefix-tree masks;
  • otherwise build a plain causal self-key from the q length. If no CP group is active it falls back to causal SDPA so non-magi modules (e.g. a VLM vision tower routed through the same factory) keep working.
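
The two accepted layouts differ only by a leading batch dimension of 1. As a sketch of that normalisation on shapes alone (no tensors involved), the hypothetical helper `thd_shape` maps either layout to `(t, nh, hd)`; how the real attn_func reshapes its inputs is not shown on this page.

```python
def thd_shape(shape):
    """Map a THD or BSHD shape to the THD form (t, nh, hd)."""
    if len(shape) == 3:
        # Already THD: [t, nh, hd].
        return tuple(shape)
    if len(shape) == 4:
        b, s, nh, hd = shape
        if b != 1:
            raise ValueError(f"BSHD input must have batch size 1, got {b}")
        # Dropping the unit batch dimension leaves t == s.
        return (s, nh, hd)
    raise ValueError(f"expected 3 or 4 dimensions, got {len(shape)}")


from_bshd = thd_shape((1, 128, 8, 64))
from_thd = thd_shape((128, 8, 64))
```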

Parameters:

softmax_scale
Optional[float]. Defaults to None

attention softmax scale (defaults to 1/sqrt(head_dim) inside FFA when None); forwarded so non-default scales stay correct.

nemo_automodel.components.distributed.magi_attn_utils.register_magi_attention() -> None

Register the "magi" attention backend in HF transformers (idempotent).

nemo_automodel.components.distributed.magi_attn_utils.set_active_attn_spec(
spec: typing.Optional['AttnMaskSpec']
) -> None

Set the mask spec the custom-model magi attn_func should apply this step.

nemo_automodel.components.distributed.magi_attn_utils.set_active_cp_group(
cp_group: typing.Optional['dist.ProcessGroup']
) -> None

Record the CP group the custom-model magi attn_func should use.

nemo_automodel.components.distributed.magi_attn_utils.setup_magi(
cfg,
device_mesh,
label: str = ''
) -> nemo_automodel.components.distributed.magi_attn_utils.MagiState

Resolve MagiAttention from config: register the backend and CP group.

Enabled when the model is configured with attn_implementation="magi" (HF) or backend.attn="magi" (custom models). Returns a MagiState (enabled=False when magi is not configured). label is an optional suffix for the log line (e.g. "VLM language backbone").
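
The enabling condition can be sketched as a predicate over a model config. The dict layout below is an assumption made for illustration, and `magi_configured` is a hypothetical helper; the real function reads the recipe's own config object and also resolves the CP group.

```python
def magi_configured(model_cfg):
    """True if either the HF or the custom-model switch selects magi."""
    # HF path: attn_implementation="magi".
    if model_cfg.get("attn_implementation") == "magi":
        return True
    # Custom-model path: backend.attn="magi".
    backend = model_cfg.get("backend") or {}
    return backend.get("attn") == "magi"


hf_enabled = magi_configured({"attn_implementation": "magi"})
custom_enabled = magi_configured({"backend": {"attn": "magi"}})
disabled = magi_configured({"attn_implementation": "sdpa"})
```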

nemo_automodel.components.distributed.magi_attn_utils.DEFAULT_CHUNK_SIZE = 512
nemo_automodel.components.distributed.magi_attn_utils._ACTIVE_ATTN_SPEC: Optional['AttnMaskSpec'] = None
nemo_automodel.components.distributed.magi_attn_utils._ACTIVE_CP_GROUP: Optional['dist.ProcessGroup'] = None
nemo_automodel.components.distributed.magi_attn_utils._FLEX_KEY_CACHE: dict = {}
nemo_automodel.components.distributed.magi_attn_utils._MAGI_REGISTERED = False
nemo_automodel.components.distributed.magi_attn_utils._MAGI_SELF_KEY_LEN: dict = {}
nemo_automodel.components.distributed.magi_attn_utils.logger = logging.getLogger(__name__)