nemo_automodel.components.distributed.magi_attn_utils
MagiAttention integration for Automodel.
MagiAttention (https://github.com/SandAI-org/MagiAttention) is a distributed (context-parallel) attention mechanism built on a Flex-Flash-Attention (FFA) kernel. It shards a single packed sequence across a CP process group with a load-balancing dispatch solver and exchanges KV with zero-redundant GroupCast/GroupReduce collectives.
This module wires MagiAttention into the HF-transformers-based LLM path used by
recipes/llm/train_ft.py following MagiAttention’s official
examples/transformers recipe:
- register_magi_attention() registers a "magi" entry in HF’s ALL_ATTENTION_FUNCTIONS so that a model loaded with attn_implementation="magi" routes its attention through the FFA kernel.
- magi_prepare_batch() builds the per-step dist-attn runtime key, dispatches input_ids/position_ids/labels across the CP group and stamps cp_group on every attention sub-module so the registered forward finds the key.
- Each rank runs the model on its local shard and computes a per-shard loss; the recipe’s cross-CP reduction sums the shards into the global loss (like TE-CP). Sharding labels (rather than undispatching logits) keeps the loss path identical for the HF and custom-model backends.
When cp_size == 1 the dispatch is a no-op shard (identity + chunk padding),
so this path is also a clean way to swap only the attention kernel (FFA) in
place of eager/SDPA/flash for convergence-parity comparisons.
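The per-shard loss reduction can be illustrated with a minimal pure-Python sketch. The round-robin chunk assignment and the helper names (dispatch, shard_loss_sum) are illustrative assumptions, not this module's API; MagiAttention's solver picks its own load-balanced layout. The sketch only shows that summing per-shard losses reproduces the global sum when losses and labels are dispatched with the same permutation.

```python
IGNORE_INDEX = -100  # assumed ignore label, as in common PyTorch losses


def dispatch(values, cp_size, chunk_size):
    """Split values into chunks and deal them round-robin to cp_size ranks.

    Hypothetical layout for illustration; the real solver balances by mask area.
    """
    chunks = [values[i:i + chunk_size] for i in range(0, len(values), chunk_size)]
    shards = [[] for _ in range(cp_size)]
    for idx, chunk in enumerate(chunks):
        shards[idx % cp_size].extend(chunk)
    return shards


def shard_loss_sum(token_losses, labels):
    """Sum per-token losses on one shard, skipping ignored labels."""
    return sum(loss for loss, y in zip(token_losses, labels) if y != IGNORE_INDEX)


token_losses = [0.5, 1.0, 2.0, 0.25, 4.0, 8.0, 0.125, 16.0]
labels = [3, 7, IGNORE_INDEX, 1, 2, 9, 4, IGNORE_INDEX]

# Losses and labels go through the same dispatch, so pairs stay aligned.
loss_shards = dispatch(token_losses, cp_size=2, chunk_size=2)
label_shards = dispatch(labels, cp_size=2, chunk_size=2)

global_sum = shard_loss_sum(token_losses, labels)
cross_cp_sum = sum(shard_loss_sum(l, y) for l, y in zip(loss_shards, label_shards))
```

With cp_size=1 the same helper returns a single shard holding the original order, which mirrors the identity shard mentioned above.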
Module Contents
Classes
Functions
Data
API
Backend-agnostic description of an attention mask as AttnSlice rectangles.
A single full causal sequence.
Hashable identity used to cache the built key across layers.
Build a prefix-tree mask over a flat deduplicated token layout.
Each node attends FULL to every ancestor node in its path and CAUSAL to itself; duplicate rectangles (shared nodes) are emitted once.
Parameters:
token count of each node, in flat layout order. The flat
layout is [node_0 | node_1 | ...] with node i occupying
[offset_i, offset_i + node_lengths[i]).
one list of node indices per sample, root -> leaf. Every sample is the causal concatenation of its nodes; a shared prefix node simply appears in multiple paths.
Returns: (spec, sample_token_ranges)
spec is the AttnMaskSpec; the second
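As a rough illustration of the prefix-tree rule (FULL to ancestors, CAUSAL to self, shared rectangles emitted once), the sketch below enumerates rectangles as plain tuples. The tuple format and the function name prefix_tree_rectangles are assumptions made for this example; the module's own AttnSlice and AttnMaskSpec types are not reproduced here.

```python
from itertools import accumulate


def prefix_tree_rectangles(node_lengths, paths):
    """Return (q_range, k_range, kind) tuples for a prefix-tree mask.

    Ranges are half-open [start, end) offsets into the flat node layout.
    """
    # offsets[i] is where node i starts; offsets[i + 1] is where it ends.
    offsets = [0, *accumulate(node_lengths)]
    seen = set()
    rects = []

    def emit(rect):
        # Shared nodes appear in several paths; emit each rectangle once.
        if rect not in seen:
            seen.add(rect)
            rects.append(rect)

    for path in paths:
        for depth, node in enumerate(path):
            q_range = (offsets[node], offsets[node + 1])
            for ancestor in path[:depth]:
                k_range = (offsets[ancestor], offsets[ancestor + 1])
                emit((q_range, k_range, "full"))
            emit((q_range, q_range, "causal"))
    return rects


# Two samples sharing node 0 as a common prefix.
rects = prefix_tree_rectangles(node_lengths=[4, 2, 3], paths=[[0, 1], [0, 2]])
```

Here node 0 contributes a single causal rectangle even though it appears in both paths.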
Block-diagonal mask for packed sequences (one block per document).
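A block-diagonal layout can be sketched from cumulative document boundaries. The tuple representation, the function name and the default mask kind are assumptions for illustration; the source does not state which mask type is applied inside each block.

```python
def block_diagonal_rectangles(cu_seqlens, mask_kind="causal"):
    """One (q_range, k_range, kind) rectangle per non-empty document.

    cu_seqlens holds cumulative boundaries, e.g. [0, 3, 7] for lengths 3 and 4.
    mask_kind is an assumed label; pass whichever kind the model needs.
    """
    return [
        ((start, end), (start, end), mask_kind)
        for start, end in zip(cu_seqlens, cu_seqlens[1:])
        if end > start  # skip zero-length documents
    ]


blocks = block_diagonal_rectangles([0, 3, 3, 7])
```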
Resolved MagiAttention wiring for a recipe, produced by :func:setup_magi.
A single handle (stored as self.magi) replacing the scattered
magi_enabled/magi_custom/magi_cp_group/magi_cp_size recipe
attributes. When MagiAttention is not configured, enabled is False and the
per-step methods are no-ops, so recipes can call them unconditionally.
HF path: dispatch the sequence (input + labels) across CP for a per-shard loss.
Distinguishes the HF attn_implementation=magi path (single causal sequence,
:func:magi_prepare_batch) from the custom-model factory path; both shard labels
and compute a per-shard loss at cp>1, so neither undispatches logits.
Per-step batch prep for the LLM recipe (assumes enabled).
Returns (train_ctx, batch). magi does its own CP, so train_ctx is
always nullcontext (no torch-native DTensor CP context).
Per-step batch prep for the VLM recipe (assumes enabled).
HF VLMs stamp the cp_group on the language-backbone attention; custom VLMs
use the factory attn_func with the active cp_group set in :func:setup_magi
(the vision tower stays on SDPA either way). Returns (train_ctx, batch).
Build a cp=1 causal varlen key matching the actual q length (no dispatch).
Return the flex key for spec, rebuilding only when the mask changes.
Extract (num_heads_q, num_heads_kv, head_dim) from an HF model/config.
For VLMs the text attention dims live under config.text_config; prefer that
sub-config when the top-level config does not expose num_attention_heads.
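The lookup order described above might look like the following sketch. It uses SimpleNamespace stand-ins rather than real HF configs, and the fallbacks for num_key_value_heads and head_dim are assumptions based on common HF config conventions, not a confirmed copy of this function.

```python
from types import SimpleNamespace


def attention_dims(config):
    """Return (num_heads_q, num_heads_kv, head_dim), preferring text_config for VLMs."""
    text_cfg = config
    if getattr(config, "num_attention_heads", None) is None:
        text_cfg = getattr(config, "text_config", None)
        if text_cfg is None:
            raise ValueError("config exposes neither num_attention_heads nor text_config")
    num_heads_q = text_cfg.num_attention_heads
    # Assumed fallback: no GQA field means one KV head per query head.
    num_heads_kv = getattr(text_cfg, "num_key_value_heads", None) or num_heads_q
    # Assumed fallback: derive head_dim from hidden_size when not given explicitly.
    head_dim = getattr(text_cfg, "head_dim", None) or text_cfg.hidden_size // num_heads_q
    return num_heads_q, num_heads_kv, head_dim


llm_cfg = SimpleNamespace(num_attention_heads=32, num_key_value_heads=8, hidden_size=4096)
vlm_cfg = SimpleNamespace(text_config=SimpleNamespace(num_attention_heads=16, hidden_size=2048))

llm_dims = attention_dims(llm_cfg)
vlm_dims = attention_dims(vlm_cfg)
```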
Yield attention sub-modules belonging to the language backbone only.
For VLMs we must leave the vision tower on its own (bidirectional) attention.
HF VLMs nest the text stack under a language_model/model.language_model
attribute; we walk only that subtree. Falls back to the whole model for plain
LLMs (no language_model attribute).
Resolve per-document lengths spanning the padded THD layout of length total_len.
The TE collater pads each document for the THD layout, so cu_seqlens_padded
spans the full flat input_ids while cu_seqlens covers only the real
tokens. magi dispatches the whole flat sequence, so the dist key must be built
over the padded layout — otherwise the dispatched shard length (from
input_ids) won’t match get_position_ids (built from the key), which
surfaces downstream as a RoPE q vs cos/sin length mismatch. Causal masking keeps
real tokens from attending to the trailing per-document pad, and pad-token rows are
dropped by the loss (labels == ignore_index), so this is numerically equivalent
to attending only to the real tokens.
Raises:
ValueError: if the resolved document layout does not span total_len.
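A minimal sketch of the padded-layout resolution follows, using plain Python lists instead of tensors. The function name and argument order are hypothetical; only the rule (prefer cu_seqlens_padded, require the layout to span total_len) comes from the description above.

```python
def resolve_doc_lengths(total_len, cu_seqlens, cu_seqlens_padded=None):
    """Return per-document lengths that together span the flat sequence."""
    # Prefer padded boundaries: they cover the whole flat input_ids.
    bounds = list(cu_seqlens_padded if cu_seqlens_padded is not None else cu_seqlens)
    lengths = [end - start for start, end in zip(bounds, bounds[1:])]
    spans = bool(bounds) and bounds[0] == 0 and bounds[-1] == total_len
    if not spans or any(n < 0 for n in lengths):
        raise ValueError(
            f"document layout {bounds} does not span total_len={total_len}"
        )
    return lengths


# Two documents with 5 and 6 real tokens, each padded to 8 in the THD layout.
padded_lengths = resolve_doc_lengths(16, [0, 5, 11], [0, 8, 16])
```

Passing only the unpadded cu_seqlens for the same 16-token input raises ValueError in this sketch, which is the mismatch the padded layout is meant to avoid.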
Return a causal self-key for seqlen, (re)building only when it changes.
All attention layers in one forward share the same sequence length, so the
first layer builds the key and the rest reuse it via get_most_recent_key.
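The rebuild-only-on-change behaviour can be sketched with a small cache object. CausalKeyCache and the tuple used as a fake key are hypothetical stand-ins; the real code reuses the key through MagiAttention's get_most_recent_key rather than a Python-side cache like this one.

```python
class CausalKeyCache:
    """Cache one key and rebuild it only when the sequence length changes."""

    def __init__(self, build_fn):
        self._build_fn = build_fn
        self._seqlen = None
        self._key = None
        self.builds = 0  # counts how often the expensive build ran

    def get(self, seqlen):
        if self._key is None or seqlen != self._seqlen:
            self._key = self._build_fn(seqlen)
            self._seqlen = seqlen
            self.builds += 1
        return self._key


cache = CausalKeyCache(lambda seqlen: ("causal-key", seqlen))

# Four attention layers in one forward share the same length: one build.
layer_keys = [cache.get(128) for _ in range(4)]
builds_after_first_forward = cache.builds

# A later step with a different length triggers exactly one rebuild.
next_key = cache.get(256)
```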
Stamp cp_group on every attention sub-module so the FFA forward finds the key.
Build a magi dist-attn key for an arbitrary AttnSlice mask (no extra padding).
Return the active mask spec (None -> plain causal self-key).
Return the CP group set by :func:set_active_cp_group (may be None).
Return the CP process group from the device mesh (size-1 group is fine).
Return True if the magi_attention package is importable.
Dispatch a (batch_size==1) sequence for MagiAttention on the HF path.
Builds a causal varlen dist-attn key over the single sequence and dispatches
input_ids, position_ids and labels across cp_group (identity
shard at cp_size==1; load-balanced sharding at cp_size>1). Labels are sharded the
same way as the input so the loss is computed per-shard and summed across CP — no
logit undispatch; MaskedCrossEntropy does not shift, so the dispatch
permutation is harmless (logits[j] stay paired with labels[j]). The FFA kernel
does the cross-rank KV exchange via the stamped cp_group + the dist key.
Parameters:
the (possibly FSDP-wrapped) HF causal-LM.
dict with at least input_ids of shape [1, S].
CP process group (size 1 -> identity shard).
dispatch solver chunk size.
Returns: (new_batch, key)
new_batch has dispatched input_ids/position_ids/
Context-parallel prep for a packed (THD) batch on the custom-model path.
Takes a global THD batch (flat input_ids/labels/position_ids plus
cu_seqlens marking document boundaries), builds a per-document varlen dist
key over cp_group and dispatches the sequence with MagiAttention’s own
load-balancing solver (not TE’s THD sharding). Each rank then runs the model on
its local shard; the attn_func uses this pre-built key so the FFA kernel does
the cross-rank KV exchange. Labels are dispatched the same way as the input, so
each rank computes a per-shard loss that the recipe’s cross-CP reduction sums
into the global loss (like TE-CP).
Returns: (new_batch, key)
new_batch has the local input_ids/position_ids
Prepare a VLM (bs==1) step for MagiAttention on the language backbone.
Unlike the LLM path, the VLM merges image features into inputs_embeds
inside its forward, so we cannot dispatch input_ids (image-placeholder
positions must stay put). For cp_size == 1 we instead build a no-padding
causal key (chunk_size=1 -> dispatch/undispatch are identity), stamp the
cp_group on the language-model attention modules only, and let the FFA kernel
run on the natural-length q/k/v. The vision tower falls back to SDPA.
Returns the (unchanged) batch and None (the key is built lazily in the
attention forward from the real query length).
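Why chunk_size=1 yields a no-padding key can be seen from the padding arithmetic. The sketch assumes the dispatched length must be a multiple of cp_size * chunk_size; that divisibility rule is an assumption about the dispatcher, and pad_for_dispatch is a hypothetical helper.

```python
def pad_for_dispatch(seqlen, cp_size, chunk_size):
    """Tokens to append so seqlen is a multiple of cp_size * chunk_size (assumed rule)."""
    block = cp_size * chunk_size
    return (-seqlen) % block


# cp_size == 1 and chunk_size == 1: every length is already aligned.
vlm_pad = pad_for_dispatch(1000, cp_size=1, chunk_size=1)

# A larger chunk size at cp_size == 4 needs padding up to the next block.
llm_pad = pad_for_dispatch(1000, cp_size=4, chunk_size=512)
```

Because no pad tokens are added in the first case, image-placeholder positions keep their original indices.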
Build the attn_func used by the custom-model attention factory.
The returned callable accepts q/k/v in either THD [t, nh, hd] or BSHD
[b, s, nh, hd] (b must be 1) layout — the same layouts the custom models
feed to their backend attn_func — and runs the MagiAttention FFA kernel.
Mask selection (no CP dispatch; cp=1 in-order tokens):
- if an :class:AttnMaskSpec is active (set_active_attn_spec), use its flex key; this covers packing, sliding-window and prefix-tree masks;
- otherwise build a plain causal self-key from the q length.
If no CP group is active it falls back to causal SDPA so non-magi modules (e.g. a VLM vision tower routed through the same factory) keep working.
Parameters:
attention softmax scale (defaults to 1/sqrt(head_dim) inside FFA when None); forwarded so non-default scales stay correct.
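The layout handling and the scale default can be sketched on shapes alone, without tensors. Both helper names are hypothetical; the sketch shows only the two rules stated above (BSHD requires b == 1, and a None scale means 1/sqrt(head_dim)).

```python
import math


def to_thd_shape(shape):
    """Map a THD [t, nh, hd] or BSHD [1, s, nh, hd] shape to THD."""
    if len(shape) == 3:
        return tuple(shape)
    if len(shape) == 4:
        b, s, nh, hd = shape
        if b != 1:
            raise ValueError(f"BSHD input must have batch size 1, got {b}")
        return (s, nh, hd)
    raise ValueError(f"expected 3 or 4 dims, got {len(shape)}")


def resolve_softmax_scale(scale, head_dim):
    """Use the caller's scale when given, else the 1/sqrt(head_dim) default."""
    return scale if scale is not None else 1.0 / math.sqrt(head_dim)


thd = to_thd_shape((1, 128, 8, 64))
default_scale = resolve_softmax_scale(None, head_dim=64)
custom_scale = resolve_softmax_scale(0.5, head_dim=64)
```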
Register the "magi" attention backend in HF transformers (idempotent).
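Idempotent registration can be sketched against a plain dict. The dict is a stand-in for HF's ALL_ATTENTION_FUNCTIONS, and register_backend and magi_forward_stub are hypothetical names; the real forward function is not shown here.

```python
ATTENTION_FUNCTIONS = {}  # stand-in for HF's ALL_ATTENTION_FUNCTIONS registry


def magi_forward_stub(*args, **kwargs):
    """Placeholder for the FFA-backed attention forward."""
    raise NotImplementedError


def register_backend(registry, name="magi", fn=magi_forward_stub):
    """Add name -> fn once; return True only if this call registered it."""
    if name in registry:
        return False  # already registered: leave the existing entry alone
    registry[name] = fn
    return True


first_call = register_backend(ATTENTION_FUNCTIONS)
second_call = register_backend(ATTENTION_FUNCTIONS)
```

Calling the helper from every recipe setup is therefore safe in this sketch, since repeat calls do not overwrite the entry.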
Set the mask spec the custom-model magi attn_func should apply this step.
Record the CP group the custom-model magi attn_func should use.
Resolve MagiAttention from config: register the backend and CP group.
Enabled when the model is configured with attn_implementation="magi" (HF) or
backend.attn="magi" (custom models). Returns a :class:MagiState
(enabled=False when magi is not configured). label is an optional suffix
for the log line (e.g. "VLM language backbone").
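The enablement check might be sketched as below. The nested-dict config layout (a model section holding attn_implementation and backend.attn) and the function name magi_path are assumptions for illustration; the recipe's real config object may be structured differently.

```python
def magi_path(cfg):
    """Return "hf", "custom", or None depending on how magi is configured."""
    model = cfg.get("model", {})
    if model.get("attn_implementation") == "magi":
        return "hf"  # HF path: registered attention backend
    if model.get("backend", {}).get("attn") == "magi":
        return "custom"  # custom-model path: factory attn_func
    return None  # magi not configured, so the state would be disabled


hf_cfg = {"model": {"attn_implementation": "magi"}}
custom_cfg = {"model": {"backend": {"attn": "magi"}}}
plain_cfg = {"model": {"attn_implementation": "sdpa"}}

paths = [magi_path(c) for c in (hf_cfg, custom_cfg, plain_cfg)]
```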