nemo_automodel._transformers.te_attention#
TransformerEngine attention injection for HuggingFace models.
Replaces F.scaled_dot_product_attention within each HF self_attn
module’s forward pass with TE’s DotProductAttention, enabling the
FlashAttention-3 kernel and FP8 training without requiring model-specific
rewrites.
The injection works by:
1. Detecting self_attn modules with a standard HF projection layout (separate q_proj, k_proj, v_proj).
2. Creating a DotProductAttention instance (stored as module.attn_module) so that _uses_te_attention can detect it.
3. Monkey-patching module.forward to temporarily swap in a TE-backed replacement for torch.nn.functional.scaled_dot_product_attention while the original HF forward runs.
Call inject_te_attention on the model before FSDP wrapping and after any weight loading (so head-count / head-dim values are correct).
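A minimal usage sketch (the checkpoint name is illustrative; only inject_te_attention comes from this module):

```python
import torch
from transformers import AutoModelForCausalLM

from nemo_automodel._transformers.te_attention import inject_te_attention

# Load weights first so head-count / head-dim attributes are final.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # illustrative checkpoint
    torch_dtype=torch.bfloat16,
)

# Inject TE attention before any FSDP wrapping.
inject_te_attention(model)

# ... wrap with FSDP / fully_shard and train as usual ...
```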
Supported patterns#
Standard Llama-style layout: separate q_proj / k_proj / v_proj, GQA via repeat_kv (enable_gqa=False) or enable_gqa=True. Covers Llama, Gemma, Qwen2, Mistral, and most popular HF causal LMs.
Mask handling#
- is_causal=True with attn_mask=None → TE "causal".
- is_causal=False with attn_mask=None → TE "no_mask".
- A 4D attn_mask matching HF's canonical causal or causal+sliding pattern is detected by _detect_causal_mask in O(S) and converted to ("causal", window_size), so HF's always-present mask doesn't force a fallback. Per-sample padded batches or non-canonical mask patterns still fall back to native SDPA (see the sketch below).
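These rules condense into a small helper; resolve_te_mask below is a hypothetical sketch (not part of the module), while _detect_causal_mask is the module's own function documented further down:

```python
from typing import Optional, Tuple

import torch

from nemo_automodel._transformers.te_attention import _detect_causal_mask


def resolve_te_mask(
    is_causal: bool,
    attn_mask: Optional[torch.Tensor],
    window_size: Tuple[int, int] = (-1, 0),
) -> Optional[Tuple[str, Tuple[int, int]]]:
    """Return (attn_mask_type, window_size) for TE, or None to fall back to native SDPA."""
    if attn_mask is None:
        return ("causal" if is_causal else "no_mask"), window_size
    # Canonical HF causal / sliding-window masks convert cleanly; padded or
    # non-canonical masks return None and the caller falls back.
    return _detect_causal_mask(attn_mask, window_size)
```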
Sliding window#
The per-layer module.sliding_window attribute is read at injection time and converted to TE's (window_size[0], 0) convention (sliding_window - 1 tokens to the left). Both the runtime mask detector and the attn_mask=None path pass this to TE as window_size.
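The conversion amounts to something like the following (a sketch; te_window_size is a hypothetical name):

```python
def te_window_size(module) -> tuple[int, int]:
    """Map an HF self_attn module's sliding_window attribute to TE's (left, right) convention."""
    sliding_window = getattr(module, "sliding_window", None)
    if sliding_window is None:
        return (-1, 0)  # unbounded left context, i.e. plain causal attention
    return (sliding_window - 1, 0)  # sliding_window - 1 tokens to the left of each query
```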
Limitations#
Models whose modeling code uses from torch.nn.functional import scaled_dot_product_attention (binding a local alias to the function rather than looking it up through the module) will not pick up the runtime patch; affected modules are skipped with a warning.
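The underlying reason is Python name binding: a modeling file that imports the function by name keeps its own reference, so reassigning torch.nn.functional.scaled_dot_product_attention later has no effect on it. A contrived illustration (neither file is real):

```python
# hypothetical modeling_a.py -- NOT patched
from torch.nn.functional import scaled_dot_product_attention  # reference bound at import time

def attention_a(q, k, v):
    # Always calls the function captured above, even while the TE patch is active.
    return scaled_dot_product_attention(q, k, v, is_causal=True)


# hypothetical modeling_b.py -- patched
import torch.nn.functional as F

def attention_b(q, k, v):
    # Looks the attribute up at call time, so the temporary TE replacement is used.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)
```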
Module Contents#
Functions#
| Function | Summary |
|---|---|
| get_te_attention_stats | Return a snapshot of the TE dispatch counters. |
| reset_te_attention_stats | Zero out the TE dispatch counters (test / benchmarking helper). |
| _proj_out_features | Return the output feature count of a projection module. |
| _infer_attn_params | Infer attention hyper-parameters from a HF self_attn module. |
| _create_te_dot_product_attention | Instantiate a TE DotProductAttention for the given attention shape. |
| _detect_causal_mask | Map a 4D HF additive attention mask to a TE (attn_mask_type, window_size) pair. |
| _maybe_log_stats | Periodically emit the dispatch counters when auto-logging is enabled. |
| _make_te_sdpa | Return a callable that replaces F.scaled_dot_product_attention. |
| _patch_module_forward | Shadow module.forward with a version that uses TE for SDPA. |
| inject_te_attention_into_module | Inject TE attention into a single HF self_attn module. |
| inject_te_attention | Walk model and inject TE attention into all compatible self_attn modules. |
Data#
API#
- nemo_automodel._transformers.te_attention.logger#
‘getLogger(…)’
- nemo_automodel._transformers.te_attention._TE_MODULE_ATTR#
‘attn_module’
- nemo_automodel._transformers.te_attention._TE_MODEL_FLAG#
‘_te_attention_injected’
- nemo_automodel._transformers.te_attention._TE_STATS: dict[str, int]#
None
- nemo_automodel._transformers.te_attention._STATS_LOG_EVERY#
‘int(…)’
- nemo_automodel._transformers.te_attention._SCALE_MISMATCH_WARNED#
False
- nemo_automodel._transformers.te_attention.get_te_attention_stats() dict[str, int]#
Return a snapshot of the TE dispatch counters.
Keys:
- te_hits: ran the real TE kernel.
- fallback_mask: fell back because attn_mask was non-None.
- fallback_scale_mismatch: fell back because the runtime scale argument disagreed with the TE module's fixed softmax_scale.
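A typical check that calls are actually reaching TE (counter keys are those listed above):

```python
from nemo_automodel._transformers.te_attention import (
    get_te_attention_stats,
    reset_te_attention_stats,
)

reset_te_attention_stats()
# ... run a few training or inference steps ...
stats = get_te_attention_stats()
print(stats)  # e.g. {'te_hits': 64, 'fallback_mask': 0, 'fallback_scale_mismatch': 0}
```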
- nemo_automodel._transformers.te_attention.reset_te_attention_stats() None#
Zero out the TE dispatch counters (test / benchmarking helper).
- nemo_automodel._transformers.te_attention._proj_out_features(proj: torch.nn.Module | None) int | None#
Return the output feature count of a projection module.
Handles three layouts:
- Standard nn.Linear: reads proj.out_features directly.
- Weight-only: reads proj.weight.shape[0] (works on meta device).
- Wrapped linear (e.g. Gemma4ClippableLinear): recurses into the proj.linear child module.
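Equivalent logic as a sketch (the real function may differ in detail; the .linear attribute follows the wrapped-linear layout described above):

```python
from typing import Optional

import torch.nn as nn


def proj_out_features_sketch(proj: Optional[nn.Module]) -> Optional[int]:
    if proj is None:
        return None
    if hasattr(proj, "out_features"):        # standard nn.Linear
        return proj.out_features
    weight = getattr(proj, "weight", None)
    if weight is not None:                   # weight-only layout; shape is known even on meta device
        return weight.shape[0]
    inner = getattr(proj, "linear", None)    # wrapped linear, e.g. a clipping wrapper
    if inner is not None:
        return proj_out_features_sketch(inner)
    return None
```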
- nemo_automodel._transformers.te_attention._infer_attn_params(
- module: torch.nn.Module,
- )#

Infer attention hyper-parameters from a HF self_attn module.

Returns None when the module does not match the expected layout.

Head counts are read from module attributes when present (standard HF), or inferred from projection out_features when absent (e.g. Gemma4TextAttention, which stores head count only in the config).
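A sketch of the fallback inference (attribute names vary across HF model families, so treat these as illustrative):

```python
def infer_heads_sketch(module):
    """Derive (num_heads, num_kv_heads, head_dim) when the module lacks head-count attributes."""

    def out_features(proj):
        # Minimal stand-in for _proj_out_features.
        return getattr(proj, "out_features", None) or proj.weight.shape[0]

    head_dim = getattr(module, "head_dim", None) or module.config.head_dim
    num_heads = getattr(module, "num_heads", None) or out_features(module.q_proj) // head_dim
    num_kv_heads = (
        getattr(module, "num_key_value_heads", None)
        or out_features(module.k_proj) // head_dim
    )
    return num_heads, num_kv_heads, head_dim
```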
- nemo_automodel._transformers.te_attention._create_te_dot_product_attention(
- num_heads: int,
- num_kv_heads: int,
- head_dim: int,
- window_size: tuple[int, int] = (-1, 0),
- softmax_scale: float | None = None,
- )#

Instantiate a TE DotProductAttention for the given attention shape.
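Construction roughly follows transformer_engine.pytorch.DotProductAttention; the kwargs below are a sketch and should be checked against the installed TE version:

```python
import transformer_engine.pytorch as te


def create_te_dpa_sketch(num_heads, num_kv_heads, head_dim,
                         window_size=(-1, 0), softmax_scale=None):
    return te.DotProductAttention(
        num_attention_heads=num_heads,
        kv_channels=head_dim,          # per-head dimension
        num_gqa_groups=num_kv_heads,   # lets TE run GQA without repeat_kv
        attention_dropout=0.0,
        qkv_format="bshd",             # matches the [B, S, H, D] layout used at call time
        attn_mask_type="causal",
        window_size=window_size,
        softmax_scale=softmax_scale,
    )
```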
- nemo_automodel._transformers.te_attention._detect_causal_mask(
- attn_mask: torch.Tensor,
- window_size: tuple[int, int],
- )#

Map a 4D HF additive attention mask to a TE (attn_mask_type, window_size) pair.

HF's create_causal_mask and create_sliding_window_causal_mask always emit a 4D float mask even when the batch has no padding. This causes the generic "attn_mask is not None" guard to fire for every sliding/full-attention layer and route every call to native SDPA instead of TE.

This function detects the two structurally trivial cases:

- Pure causal (lower-triangular 0 / -inf): returns ("causal", (-1, 0)).
- Sliding-window causal: returns ("causal", window_size).
Returns None when the mask cannot be safely converted, e.g. when it encodes per-sample padding, has an unexpected shape, or is not a float additive mask. The caller should fall back to native SDPA in that case.

Detection is O(S) per call (two row reductions of length S_k):

- Upper-right corner scalar check: must be < -1e4 (causal-like).
- First-row visible-key count across all batch items must equal 1 (each first query token can only attend to itself).
- Last-row visible-key count across all batch items must equal S (full causal) or min(S, window_size[0] + 1) (sliding window). Any deviation indicates padding or a non-standard mask structure.
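These checks translate to roughly the following tensor operations (a sketch that assumes S_q == S_k, i.e. no KV-cache offset):

```python
import torch


def detect_causal_mask_sketch(attn_mask: torch.Tensor, window_size: tuple[int, int]):
    """Return ("causal", window) for canonical HF masks, or None to force an SDPA fallback."""
    if attn_mask.dim() != 4 or not attn_mask.is_floating_point():
        return None
    s_k = attn_mask.shape[-1]
    visible = attn_mask > -1e4                       # True where a key may be attended to
    # 1) Upper-right corner must be masked out (causal-like).
    if s_k > 1 and bool(visible[..., 0, -1].any()):
        return None
    # 2) The first query row may see exactly one key (itself) in every batch item.
    if not bool((visible[..., 0, :].sum(dim=-1) == 1).all()):
        return None
    # 3) The last query row sees all keys (full causal) or exactly the sliding window.
    last = visible[..., -1, :].sum(dim=-1)
    if bool((last == s_k).all()):
        return "causal", (-1, 0)
    if window_size[0] >= 0 and bool((last == min(s_k, window_size[0] + 1)).all()):
        return "causal", window_size
    return None
```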
- nemo_automodel._transformers.te_attention._maybe_log_stats() None#
Periodically emit the dispatch counters when auto-logging is enabled.
- nemo_automodel._transformers.te_attention._make_te_sdpa(
- te_module: torch.nn.Module,
- num_heads: int,
- num_kv_heads: int,
- original_sdpa,
- window_size: tuple[int, int] = (-1, 0),
- softmax_scale: float | None = None,
- )#

Return a callable that replaces F.scaled_dot_product_attention.

The replacement:

- Transposes Q/K/V from HF's [B, H, S, D] to TE's [B, S, H, D].
- Undoes repeat_kv when TE can handle GQA natively.
- Falls back to original_sdpa for non-trivial attn_mask inputs.
- Falls back when the caller passes an explicit scale that disagrees with the softmax_scale captured at TE-module creation time (TE freezes that value at construction; overriding it silently would change numerics).
- Transposes the TE output back to [B, H, S, D] before returning.
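A heavily simplified wrapper illustrating the list above; the TE forward kwargs and the final reshape encode assumptions about TE's bshd output layout and may not match the real implementation exactly:

```python
def make_te_sdpa_sketch(te_module, num_heads, num_kv_heads, original_sdpa,
                        window_size=(-1, 0), softmax_scale=None):
    def te_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                is_causal=False, scale=None, **kwargs):
        # Non-trivial masks and mismatched runtime scales fall back to native SDPA.
        if attn_mask is not None:
            return original_sdpa(query, key, value, attn_mask=attn_mask,
                                 dropout_p=dropout_p, is_causal=is_causal,
                                 scale=scale, **kwargs)
        if scale is not None and softmax_scale is not None and scale != softmax_scale:
            return original_sdpa(query, key, value, dropout_p=dropout_p,
                                 is_causal=is_causal, scale=scale, **kwargs)

        # Undo repeat_kv: if HF already expanded K/V to num_heads, keep one head per
        # GQA group so TE can run grouped-query attention natively.
        if num_kv_heads != num_heads and key.shape[1] == num_heads:
            group = num_heads // num_kv_heads
            key, value = key[:, ::group], value[:, ::group]

        b, h, s, d = query.shape
        # HF hands SDPA [B, H, S, D]; TE with qkv_format="bshd" expects [B, S, H, D].
        q = query.transpose(1, 2).contiguous()
        k = key.transpose(1, 2).contiguous()
        v = value.transpose(1, 2).contiguous()

        out = te_module(q, k, v,
                        attn_mask_type="causal" if is_causal else "no_mask",
                        window_size=window_size)
        # Assumes TE returns the context with heads merged ([B, S, H*D]); restore
        # the [B, H, S, D] layout SDPA callers expect.
        return out.view(b, s, h, d).transpose(1, 2)

    return te_sdpa
```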
- nemo_automodel._transformers.te_attention._patch_module_forward(module: torch.nn.Module, te_sdpa) None#
Shadow module.forward with a version that uses TE for SDPA.
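The patch can be pictured as a thin wrapper that swaps the function only for the duration of the original forward (a sketch; details such as re-entrancy handling may differ):

```python
import functools

import torch.nn.functional as F


def patch_module_forward_sketch(module, te_sdpa):
    original_forward = module.forward

    @functools.wraps(original_forward)
    def forward_with_te(*args, **kwargs):
        original = F.scaled_dot_product_attention
        F.scaled_dot_product_attention = te_sdpa   # seen by F.scaled_dot_product_attention lookups
        try:
            return original_forward(*args, **kwargs)
        finally:
            F.scaled_dot_product_attention = original  # always restore the native kernel

    module.forward = forward_with_te
```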
- nemo_automodel._transformers.te_attention.inject_te_attention_into_module(module: torch.nn.Module) bool#
Inject TE attention into a single HF self_attn module.

Returns True on success, False when the module does not match the expected layout.
- nemo_automodel._transformers.te_attention.inject_te_attention(model: torch.nn.Module) None#
Walk model and inject TE attention into all compatible self_attn modules.

Skips modules that already carry attn_module (i.e. custom models or modules that were already patched). Sets model._te_attention_injected on success so that _uses_te_attention can short-circuit the walk.
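The walk can be pictured as follows (a sketch; the real function also logs skipped modules and may identify attention modules by more than the self_attn name):

```python
import torch.nn as nn

from nemo_automodel._transformers.te_attention import inject_te_attention_into_module


def inject_te_attention_sketch(model: nn.Module) -> None:
    if getattr(model, "_te_attention_injected", False):
        return  # already injected; nothing to do
    injected_any = False
    for name, module in model.named_modules():
        if not name.endswith("self_attn"):
            continue
        if hasattr(module, "attn_module"):
            continue  # custom model or module that was already patched
        injected_any |= inject_te_attention_into_module(module)
    if injected_any:
        model._te_attention_injected = True
```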