nemo_automodel.components.models.glm_moe_dsa.layers
GLM-5.2 DSA layers.
Contains the GlmMoeDsaIndexer for top-k sparse attention selection and GlmMoeDsaMLA which integrates the indexer with Multi-head Latent Attention.
Module Contents
API
class GlmMoeDsaIndexer
Bases: Module
Indexer for top-k sparse attention selection.
Based on the official GLM-5.2 training implementation. Computes per-head attention scores between queries and keys, applies ReLU to the scores, combines the heads using learned per-head weights, then selects the top-k key positions for each query to attend to.
Key features:
- Uses LayerNorm (not RMSNorm) for key normalization
- Has a weights_proj that learns per-head importance weights
- Optional Hadamard transform (rotate_activation) on Q and K
- ReLU activation on attention scores before weighting
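The scoring steps listed above can be sketched as follows. This is a minimal illustration, not the module's actual code: the function name, the tensor layouts, the single shared key per position, the causal constraint, and the omission of the Hadamard rotation and any score scaling are all assumptions made for the example.

```python
import torch
import torch.nn.functional as F


def indexer_topk_sketch(q, k, head_weights, topk):
    """Toy top-k selection (hypothetical helper, not the library API).

    q:            [B, S, H, D] indexer queries (assumed layout)
    k:            [B, S, D]    one key per position shared by all heads (assumption)
    head_weights: [B, S, H]    per-head weights, e.g. the output of weights_proj
    returns:      [B, S, topk] selected key indices for each query position
    """
    seq_len = q.shape[1]
    # LayerNorm (not RMSNorm) on the keys; learned affine parameters are omitted.
    k = F.layer_norm(k.float(), k.shape[-1:])
    # Per-head query-key scores: [B, S_q, H, S_k].
    scores = torch.einsum("bqhd,bkd->bqhk", q.float(), k)
    # ReLU first, then weight each head and sum over heads: [B, S_q, S_k].
    index_scores = (F.relu(scores) * head_weights.float().unsqueeze(-1)).sum(dim=2)
    # Assumed causal constraint: a query never selects a later position.
    pos = torch.arange(seq_len, device=q.device)
    future = pos[None, :] > pos[:, None]
    index_scores = index_scores.masked_fill(future, float("-inf"))
    # Early rows with fewer than topk visible keys will also return masked
    # positions; a real implementation has to handle those separately.
    return index_scores.topk(min(topk, seq_len), dim=-1).indices
```

The returned indices have the same shape as the documented output, [B, S, topk], and can be turned into an additive attention mask in a later step.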
Compute top-k indices for sparse attention.
Parameters:
Hidden states [B, S, hidden] or [T, hidden] for thd format
Q lora residual from MLA [B, S, q_lora_rank] or [T, q_lora_rank]
RoPE frequencies
Optional attention mask
Additional attention kwargs (cu_seqlens, etc.)
Returns: torch.Tensor
Indices of top-k positions [B, S, topk] or [T, topk]
class GlmMoeDsaMLA
Bases: Module
Multi-head Latent Attention with Indexer for sparse attention.
This extends the V3 MLA with an Indexer module that performs top-k selection for sparse attention. The indexer uses the q_lora residual and hidden states to compute which positions to attend to.
Build a sparse attention mask/bias from top-k indices.
Creates a mask tensor in which top-k positions are 0 and all other positions are set to finfo.min. Works for both TE (core_attention_bias) and SDPA (attn_mask).
Uses the same efficient pattern as the official DeepSeek inference code, but with finfo.min instead of -inf, because F.scaled_dot_product_attention mishandles -inf float masks:
torch.full(..., finfo.min).scatter_(-1, topk_indices, 0)
Parameters:
Indices of top-k positions [B, S, topk] or [T, topk]
Sequence length
"bshd" or "thd"
Batch size (only used for bshd format)
Number of attention heads to expand to
Data type for the output tensor
Optional attention mask to combine with (for SDPA)
If True, union top-k across batches (for TE); if False, keep per-batch masks (for SDPA)
Returns: torch.Tensor
Mask tensor with shape:
- [1, n_heads, S, S] if union_across_batches=True
- [B, n_heads, S, S] if union_across_batches=False (bshd)
- [1, n_heads, T, T] for thd format
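A minimal sketch of the full-then-scatter pattern for the per-batch bshd case is shown below. The function name and argument list are hypothetical; the union-across-batches path, the thd format, and combining with an existing attention mask are left out.

```python
import torch


def sparse_mask_sketch(topk_indices, seq_len, n_heads, dtype=torch.bfloat16):
    """Build an additive mask [B, n_heads, S, seq_len] from indices [B, S, topk].

    Hypothetical helper for illustration only.
    """
    batch, n_queries, _ = topk_indices.shape
    # Start with every key masked (finfo.min rather than -inf).
    mask = torch.full(
        (batch, n_queries, seq_len),
        torch.finfo(dtype).min,
        dtype=dtype,
        device=topk_indices.device,
    )
    # Open up the selected key positions by writing 0 along the key axis.
    mask.scatter_(-1, topk_indices, 0.0)
    # Broadcast the same selection to all heads without copying memory.
    return mask.unsqueeze(1).expand(batch, n_heads, n_queries, seq_len)
```

Because the result is expanded rather than copied, callers that need to write into it would have to call .contiguous() or .clone() first.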
Run MLA with (optionally shared) DSA sparse attention.
Parameters:
Hidden states [B, S, hidden] (bshd) or [T, hidden] (thd).
RoPE frequencies.
Optional additive attention mask.
Top-k indices from the most recent “full” indexer layer. Required (and only used) when this is a “shared” layer (skip_topk=True).
When True, return (attn_out, topk_indices) so the caller can thread the selection to subsequent shared layers (GLM IndexShare). When False (default), return just attn_out.
Returns:
attn_out tensor, or (attn_out, topk_indices) when return_topk_indices.
Apply NON-interleaved (half-split) RoPE to the indexer’s rope slice.
The DSA indexer uses half-split RoPE (rotate_half: pair dim j with j + d/2),
unlike the main MLA attention which uses interleaved RoPE. freqs_cis is the same
complex tensor used by the MLA (exp(i * theta_j * pos) for j in [0, d/2)); we read
its real/imag parts as cos/sin so the angles match exactly.
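The half-split pairing described above can be sketched for the [B, S, H, d] bshd case as follows. The function name is hypothetical, and the way freqs_cis is built in the usage snippet (base 10000) is an assumption made only to have a table to test against.

```python
import torch


def half_split_rope_sketch(x, freqs_cis):
    """Rotate x [B, S, H, d] using a complex table freqs_cis [S, d/2].

    Dim j is paired with dim j + d/2 (rotate_half style), not with dim j + 1.
    Hypothetical helper for illustration only.
    """
    # Read cos/sin straight from the complex table so the angles match the MLA's.
    cos = freqs_cis.real[None, :, None, :]  # [1, S, 1, d/2]
    sin = freqs_cis.imag[None, :, None, :]
    x1, x2 = x.float().chunk(2, dim=-1)  # first half, second half
    out = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return out.to(x.dtype)


# Usage with an assumed frequency table: exp(i * theta_j * pos), j in [0, d/2).
d, seq_len = 8, 5
theta = 1.0 / (10000 ** (torch.arange(0, d, 2).float() / d))
angles = torch.outer(torch.arange(seq_len).float(), theta)
freqs_cis = torch.polar(torch.ones_like(angles), angles)
rotated = half_split_rope_sketch(torch.randn(2, seq_len, 3, d), freqs_cis)
```

Each (j, j + d/2) pair is rotated by a plain 2D rotation, so the per-vector norm is preserved and position 0 is left unchanged.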
Parameters:
RoPE slice, [B, S, H, d] / [B, S, d] (bshd) or [T, H, d] / [T, d] (thd).
Complex RoPE table with trailing dim d/2.
"bshd" or "thd".
Apply Hadamard rotation activation.
Parameters:
Input tensor (must be bfloat16).
Returns: torch.Tensor
Rotated tensor.
Convert a {0,1} keep-mask (1=attend, 0=mask) to an ADDITIVE key mask (0 / finfo.min).
Masked positions use finfo.min rather than -inf: F.scaled_dot_product_attention
mishandles -inf float masks (its fused kernels corrupt the softmax). HF builds the
attention bias with create_causal_mask, which likewise masks padding to finfo.min.
The recipe, however, hands the model a 2D {0,1} padding mask. Adding that mask to the
scores unconverted (the previous behaviour) both fails to mask padding (0 -> +0 instead
of finfo.min) and adds +1.0 to every kept key. That offset is softmax-invariant only in
fp32; in bf16 the +1.0 swamps the (scaled) score differences and collapses attention
toward uniform. A mask that is already additive (values <= 0) is returned unchanged.
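The conversion can be sketched as below. The function name is hypothetical, and the test used to detect an already-additive mask (no value above 0) is an assumption based on the description, not the library's exact check.

```python
import torch


def keep_mask_to_additive_sketch(mask, dtype):
    """Convert a {0,1} keep-mask to an additive mask (0 / finfo.min).

    Hypothetical helper for illustration only. The output has the same shape
    as the input; broadcasting to the score shape is left to the caller.
    """
    # Assumed check: a non-bool mask with no positive entries is already additive.
    if mask.dtype != torch.bool and mask.numel() > 0 and bool(mask.max() <= 0):
        return mask
    keep = mask.to(torch.bool)
    additive = torch.zeros(mask.shape, dtype=dtype, device=mask.device)
    # Kept keys contribute 0; masked keys get the most negative finite value.
    return additive.masked_fill(~keep, torch.finfo(dtype).min)


padding_mask = torch.tensor([[1, 1, 0]])  # last key is padding
bias = keep_mask_to_additive_sketch(padding_mask, torch.bfloat16)
```

Kept keys end up with a bias of exactly 0, so no constant offset is added to the scores in low precision.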
Fallback hadamard_transform when fast_hadamard_transform is not available.
Multiply H_n @ u, where H_n is the n x n Hadamard matrix. n must be a power of 2.
Parameters:
u: Tensor of shape (…, n)
normalize: If True, divide the result by 2^{m/2}, where m = log_2(n).
Returns:
product: Tensor of shape (…, n)
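One way to compute H_n @ u without materializing H_n is the iterative fast Walsh-Hadamard butterfly, sketched below. The function name is hypothetical, and whether the actual fallback uses a butterfly or an explicit matrix multiply is not stated in this page; the sketch assumes the Sylvester (natural-order) Hadamard matrix.

```python
import torch


def hadamard_transform_sketch(u, normalize=False):
    """Apply the n x n Sylvester Hadamard matrix along the last dim of u.

    Hypothetical reference implementation for illustration only.
    """
    n = u.shape[-1]
    if n < 1 or n & (n - 1) != 0:
        raise ValueError("last dimension must be a power of 2")
    m = n.bit_length() - 1  # m = log_2(n)
    x = u.reshape(-1, n)
    h = 1
    while h < n:
        # Split every block of 2h entries into halves a, b -> (a + b, a - b).
        x = x.reshape(-1, n // (2 * h), 2, h)
        a, b = x[:, :, 0, :], x[:, :, 1, :]
        x = torch.stack((a + b, a - b), dim=2).reshape(-1, n)
        h *= 2
    if normalize:
        x = x / (2 ** (m / 2))
    return x.reshape(u.shape)
```

With normalize=True the transform is orthogonal and its own inverse, so applying it twice returns the input up to rounding error.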