nemo_automodel._transformers.te_attention#

TransformerEngine attention injection for HuggingFace models.

Replaces F.scaled_dot_product_attention within each HF self_attn module’s forward pass with TE’s DotProductAttention, enabling the FlashAttention-3 kernel and FP8 training without requiring model-specific rewrites.

The injection works by:

  1. Detecting self_attn modules with a standard HF projection layout (separate q_proj, k_proj, v_proj).

  2. Creating a DotProductAttention instance (stored as module.attn_module) so that _uses_te_attention can detect it.

  3. Monkey-patching module.forward to temporarily swap in a TE-backed replacement for torch.nn.functional.scaled_dot_product_attention while the original HF forward runs.

Call inject_te_attention on the model before FSDP wrapping and after any weight loading (so head-count/head-dim values are correct).
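
A typical call sequence, sketched below (the checkpoint name is only an example):

```python
import torch
from transformers import AutoModelForCausalLM

from nemo_automodel._transformers.te_attention import inject_te_attention

# Load weights first so head-count / head-dim attributes reflect the checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",       # example checkpoint
    torch_dtype=torch.bfloat16,
)

# Inject TE attention before any FSDP wrapping.
inject_te_attention(model)

# ... wrap with FSDP and train as usual ...
```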

Supported patterns#

  • Standard Llama-style layout: separate q_proj/k_proj/v_proj, GQA via repeat_kv (enable_gqa=False) or enable_gqa=True. Covers Llama, Gemma, Qwen2, Mistral, and most popular HF causal LMs.

Mask handling#

  • is_causal=True with attn_mask=None → TE "causal".

  • is_causal=False with attn_mask=None → TE "no_mask".

  • A 4D attn_mask matching HF’s canonical causal or causal+sliding pattern is detected by _detect_causal_mask in O(S) and converted to ("causal", window_size), so HF’s always-present mask doesn’t force a fallback. Per-sample padded batches or non-canonical mask patterns still fall back to native SDPA.
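
The decision logic, sketched under those rules (helper names mirror the descriptions above, not the exact implementation):

```python
def resolve_mask(attn_mask, is_causal, window_size):
    # Map SDPA-style arguments to a TE (attn_mask_type, window_size) pair, or None for fallback.
    if attn_mask is None:
        return ("causal" if is_causal else "no_mask", window_size)
    detected = _detect_causal_mask(attn_mask, window_size)  # ("causal", window) or None
    return detected  # None -> caller falls back to native SDPA
```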

Sliding window#

  • The per-layer module.sliding_window attribute is read at injection time and converted to TE’s (left, right) window convention as (sliding_window - 1, 0), i.e. sliding_window - 1 tokens of left context and no right context. Both the runtime mask detector and the attn_mask=None path pass this to TE as window_size.
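
A minimal sketch of that conversion (the helper name is hypothetical):

```python
def _window_from_module(module) -> tuple[int, int]:
    # Hypothetical helper: map HF's per-layer sliding_window to TE's (left, right) convention.
    sliding_window = getattr(module, "sliding_window", None)
    if sliding_window is None:
        return (-1, 0)                # unbounded left context, strictly causal
    return (sliding_window - 1, 0)    # sliding_window - 1 tokens of left context
```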

Limitations#

  • Models using from torch.nn.functional import scaled_dot_product_attention (a local import) will not pick up the runtime patch; affected modules are skipped with a warning.
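
This is ordinary Python name binding, illustrated by the self-contained snippet below (not code from the module):

```python
import torch
import torch.nn.functional as F
from torch.nn.functional import scaled_dot_product_attention as local_sdpa  # binds the original object

original = F.scaled_dot_product_attention

def traced_sdpa(*args, **kwargs):
    print("replacement called")
    return original(*args, **kwargs)

q = k = v = torch.randn(1, 2, 4, 8)
F.scaled_dot_product_attention = traced_sdpa   # what the runtime patch effectively does
F.scaled_dot_product_attention(q, k, v)        # goes through the replacement
local_sdpa(q, k, v)                            # still the original: the local name never sees the patch
F.scaled_dot_product_attention = original      # restore
```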

Module Contents#

Functions#

get_te_attention_stats

Return a snapshot of the TE dispatch counters.

reset_te_attention_stats

Zero out the TE dispatch counters (test / benchmarking helper).

_proj_out_features

Return the output feature count of a projection module.

_infer_attn_params

Infer attention hyper-parameters from a HF self_attn module.

_create_te_dot_product_attention

Instantiate a TE DotProductAttention for the given attention shape.

_detect_causal_mask

Map a 4D HF additive attention mask to a TE (attn_mask_type, window_size) pair.

_maybe_log_stats

Periodically emit the dispatch counters when auto-logging is enabled.

_make_te_sdpa

Return a callable that replaces F.scaled_dot_product_attention.

_patch_module_forward

Shadow module.forward with a version that uses TE for SDPA.

inject_te_attention_into_module

Inject TE attention into a single HF self_attn module.

inject_te_attention

Walk model and inject TE attention into all compatible self_attn modules.

Data#

API#

nemo_automodel._transformers.te_attention.logger#

‘getLogger(…)’

nemo_automodel._transformers.te_attention._TE_MODULE_ATTR#

‘attn_module’

nemo_automodel._transformers.te_attention._TE_MODEL_FLAG#

‘_te_attention_injected’

nemo_automodel._transformers.te_attention._TE_STATS: dict[str, int]#

None

nemo_automodel._transformers.te_attention._STATS_LOG_EVERY#

‘int(…)’

nemo_automodel._transformers.te_attention._SCALE_MISMATCH_WARNED#

False

nemo_automodel._transformers.te_attention.get_te_attention_stats() dict[str, int]#

Return a snapshot of the TE dispatch counters.

Keys: te_hits (ran the real TE kernel), fallback_mask (fell back because attn_mask was non-None), fallback_scale_mismatch (fell back because the runtime scale argument disagreed with the TE module’s fixed softmax_scale).
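
For example, to check whether the TE kernel is actually being hit during a run:

```python
from nemo_automodel._transformers.te_attention import (
    get_te_attention_stats,
    reset_te_attention_stats,
)

reset_te_attention_stats()
# ... run a few forward passes ...
stats = get_te_attention_stats()
print(stats)  # {'te_hits': ..., 'fallback_mask': ..., 'fallback_scale_mismatch': ...}
```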

nemo_automodel._transformers.te_attention.reset_te_attention_stats() None#

Zero out the TE dispatch counters (test / benchmarking helper).

nemo_automodel._transformers.te_attention._proj_out_features(proj: torch.nn.Module | None) int | None#

Return the output feature count of a projection module.

Handles three layouts:

  • Standard nn.Linear: reads proj.out_features directly.

  • Weight-only: reads proj.weight.shape[0] (works on meta device).

  • Wrapped linear (e.g. Gemma4ClippableLinear): recurses into the proj.linear child module.
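
A sketch covering the three layouts (illustrative, not the exact implementation):

```python
import torch.nn as nn

def proj_out_features(proj: nn.Module | None) -> int | None:
    if proj is None:
        return None
    if isinstance(proj, nn.Linear):                  # standard nn.Linear
        return proj.out_features
    if getattr(proj, "weight", None) is not None:    # weight-only: shape is available even on meta device
        return proj.weight.shape[0]
    if hasattr(proj, "linear"):                      # wrapped linear: recurse into the child
        return proj_out_features(proj.linear)
    return None
```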

nemo_automodel._transformers.te_attention._infer_attn_params(
module: torch.nn.Module,
) dict[str, Any] | None#

Infer attention hyper-parameters from a HF self_attn module.

Returns None when the module does not match the expected layout.

Head counts are read from module attributes when present (standard HF), or inferred from projection out_features when absent (e.g. Gemma4TextAttention which stores head count only in the config).
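
An illustrative sketch of that fallback order, reusing the proj_out_features sketch above (attribute names are assumptions):

```python
from typing import Any

def infer_attn_params(module) -> dict[str, Any] | None:
    q_out = proj_out_features(getattr(module, "q_proj", None))
    k_out = proj_out_features(getattr(module, "k_proj", None))
    head_dim = getattr(module, "head_dim", None)
    if q_out is None or k_out is None or head_dim is None:
        return None                                   # not the expected separate-projection layout
    # Prefer module attributes; fall back to projection widths when they are absent.
    num_heads = getattr(module, "num_heads", None) or q_out // head_dim
    num_kv_heads = getattr(module, "num_key_value_heads", None) or k_out // head_dim
    return {"num_heads": num_heads, "num_kv_heads": num_kv_heads, "head_dim": head_dim}
```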

nemo_automodel._transformers.te_attention._create_te_dot_product_attention(
num_heads: int,
num_kv_heads: int,
head_dim: int,
window_size: tuple[int, int] = (-1, 0),
softmax_scale: float | None = None,
) transformer_engine.pytorch.attention.DotProductAttention#

Instantiate a TE DotProductAttention for the given attention shape.
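
A sketch of the instantiation under a recent TransformerEngine release (keyword names may differ slightly across TE versions):

```python
from transformer_engine.pytorch import DotProductAttention

def make_te_attention(num_heads: int, num_kv_heads: int, head_dim: int,
                      window_size: tuple[int, int] = (-1, 0),
                      softmax_scale: float | None = None) -> DotProductAttention:
    return DotProductAttention(
        num_attention_heads=num_heads,
        kv_channels=head_dim,
        num_gqa_groups=num_kv_heads,   # TE handles GQA natively
        qkv_format="bshd",             # batch-first layout used by the SDPA replacement
        attn_mask_type="causal",
        window_size=window_size,       # (-1, 0) = unbounded causal window
        softmax_scale=softmax_scale,
    )
```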

nemo_automodel._transformers.te_attention._detect_causal_mask(
attn_mask: torch.Tensor,
window_size: tuple[int, int],
) tuple[str, tuple[int, int]] | None#

Map a 4D HF additive attention mask to a TE (attn_mask_type, window_size) pair.

HF’s create_causal_mask and create_sliding_window_causal_mask always emit a 4D float mask even when the batch has no padding, so the generic attn_mask is not None guard would otherwise fire for every sliding/full-attention layer and route every call to native SDPA instead of TE.

This function detects the two structurally trivial cases:

  • Pure causal (lower-triangular 0 / -inf): returns ("causal", (-1, 0)).

  • Sliding-window causal: returns ("causal", window_size).

Returns None when the mask cannot be safely converted — e.g. it encodes per-sample padding, has an unexpected shape, or is not a float additive mask. The caller should fall back to native SDPA in that case.

Detection is O(S) per call (two row reductions of length S_k):

  1. Upper-right corner scalar check: must be < -1e4 (causal-like).

  2. First-row visible-key count across all batch items must equal 1 (each first query token can only attend to itself).

  3. Last-row visible-key count across all batch items must equal S (full causal) or min(S, window_size[0]+1) (sliding window). Any deviation indicates padding or a non-standard mask structure.
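
An illustrative sketch of those three checks (not the exact implementation):

```python
import torch

def detect_causal_mask(attn_mask: torch.Tensor, window_size: tuple[int, int]):
    if attn_mask.dim() != 4 or not attn_mask.is_floating_point():
        return None
    s_k = attn_mask.shape[-1]
    # 1. Upper-right corner must be strongly negative, otherwise the mask is not causal-like.
    if attn_mask[..., 0, -1].max() >= -1e4:
        return None
    visible = attn_mask > -1e4                     # True where a key is attendable
    first_row = visible[..., 0, :].sum(dim=-1)
    last_row = visible[..., -1, :].sum(dim=-1)
    # 2. Each first query token may only attend to itself.
    if not torch.all(first_row == 1):
        return None
    # 3. The last query row sees all keys (full causal) or exactly the sliding window.
    if torch.all(last_row == s_k):
        return ("causal", (-1, 0))
    if window_size[0] >= 0 and torch.all(last_row == min(s_k, window_size[0] + 1)):
        return ("causal", window_size)
    return None                                    # padding or non-standard structure
```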

nemo_automodel._transformers.te_attention._maybe_log_stats() None#

Periodically emit the dispatch counters when auto-logging is enabled.

nemo_automodel._transformers.te_attention._make_te_sdpa(
te_module: torch.nn.Module,
num_heads: int,
num_kv_heads: int,
original_sdpa,
window_size: tuple[int, int] = (-1, 0),
softmax_scale: float | None = None,
) Any#

Return a callable that replaces F.scaled_dot_product_attention.

The replacement:

  • Transposes Q/K/V from HF’s [B, H, S, D] to TE’s [B, S, H, D].

  • Undoes repeat_kv when TE can handle GQA natively.

  • Falls back to original_sdpa for non-trivial attn_mask inputs.

  • Falls back when the caller passes an explicit scale that disagrees with the softmax_scale captured at TE-module creation time (TE freezes that value at construction; trying to override it silently would change numerics).

  • Transposes the TE output back to [B, H, S, D] before returning.
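
A condensed sketch of the factory and the closure it returns, reusing the detect_causal_mask sketch above (simplified; the real implementation also updates the dispatch counters and handles more argument combinations):

```python
def make_te_sdpa(te_module, num_heads, num_kv_heads, original_sdpa,
                 window_size=(-1, 0), softmax_scale=None):
    def te_sdpa(query, key, value, attn_mask=None, dropout_p=0.0,
                is_causal=False, scale=None, enable_gqa=False):
        mask_ok = attn_mask is None or detect_causal_mask(attn_mask, window_size) is not None
        scale_ok = scale is None or softmax_scale is None or scale == softmax_scale
        if not (mask_ok and scale_ok):
            # Non-trivial mask or mismatched scale: use the native kernel.
            return original_sdpa(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p,
                                 is_causal=is_causal, scale=scale, enable_gqa=enable_gqa)
        # HF passes [B, H, S, D]; TE's bshd layout expects [B, S, H, D].
        q, k, v = (t.transpose(1, 2) for t in (query, key, value))
        if not enable_gqa and k.shape[2] == num_heads and num_kv_heads < num_heads:
            # Undo repeat_kv: keep one head per GQA group so TE sees the true KV heads.
            rep = num_heads // num_kv_heads
            k, v = k[:, :, ::rep, :], v[:, :, ::rep, :]
        out = te_module(q, k, v)        # [B, S, H, D]
        return out.transpose(1, 2)      # back to [B, H, S, D]
    return te_sdpa
```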

nemo_automodel._transformers.te_attention._patch_module_forward(module: torch.nn.Module, te_sdpa) None#

Shadow module.forward with a version that uses TE for SDPA.
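
A sketch of the shadowing pattern, assuming the swap is scoped strictly to the wrapped call:

```python
import functools
import torch.nn.functional as F

def patch_module_forward(module, te_sdpa) -> None:
    original_forward = module.forward

    @functools.wraps(original_forward)
    def forward(*args, **kwargs):
        previous_sdpa = F.scaled_dot_product_attention
        F.scaled_dot_product_attention = te_sdpa            # swap in the TE-backed replacement
        try:
            return original_forward(*args, **kwargs)
        finally:
            F.scaled_dot_product_attention = previous_sdpa  # always restore the previous kernel

    module.forward = forward
```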

nemo_automodel._transformers.te_attention.inject_te_attention_into_module(module: torch.nn.Module) bool#

Inject TE attention into a single HF self_attn module.

Returns True on success, False when the module does not match the expected layout.

nemo_automodel._transformers.te_attention.inject_te_attention(model: torch.nn.Module) None#

Walk model and inject TE attention into all compatible self_attn modules.

Skips modules that already carry attn_module (i.e. custom models or modules that were already patched). Sets model._te_attention_injected on success so that _uses_te_attention can short-circuit the walk.
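
An illustrative sketch of the walk (the module-name heuristic is an assumption):

```python
def inject_te_attention(model) -> None:
    injected = False
    for name, module in model.named_modules():
        if not name.endswith("self_attn"):
            continue                         # only HF attention submodules
        if hasattr(module, "attn_module"):
            continue                         # already patched, or a custom model with its own TE attention
        injected |= inject_te_attention_into_module(module)
    if injected:
        model._te_attention_injected = True  # lets _uses_te_attention short-circuit later walks
```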