nemo_automodel.components.models.glm_moe_dsa.layers


GLM-5.2 DSA layers.

Contains GlmMoeDsaIndexer, which performs top-k sparse attention selection, and GlmMoeDsaMLA, which integrates the indexer with Multi-head Latent Attention.

Module Contents

Classes

Name | Description
GlmMoeDsaIndexer | Indexer for top-k sparse attention selection.
GlmMoeDsaMLA | Multi-head Latent Attention with Indexer for sparse attention.

Functions

Name | Description
_apply_index_rope_half_split | Apply NON-interleaved (half-split) RoPE to the indexer’s rope slice.
_rotate_activation | Apply Hadamard rotation activation.
_to_additive_key_mask | Convert a {0,1} keep-mask (1=attend, 0=mask) to an ADDITIVE key mask (0 / finfo.min).
hadamard_transform | Fallback hadamard_transform when fast_hadamard_transform is not available.
hadamard_transform_torch | Multiply H_n @ u where H_n is the Hadamard matrix of dimension n x n.

Data

_FAST_HADAMARD_AVAILABLE

API

class nemo_automodel.components.models.glm_moe_dsa.layers.GlmMoeDsaIndexer(
config: transformers.models.glm_moe_dsa.configuration_glm_moe_dsa.GlmMoeDsaConfig,
backend: nemo_automodel.components.models.common.BackendConfig
)

Bases: Module

Indexer for top-k sparse attention selection.

Based on the official GLM-5.2 training implementation. Computes attention scores between queries and keys with per-head weights, applies ReLU activation, then selects the top-k positions to attend to.

Key features:

  • Uses LayerNorm (not RMSNorm) for key normalization
  • Has a weights_proj that learns per-head importance weights
  • Optional Hadamard transform (rotate_activation) on Q and K
  • ReLU activation on attention scores before weighting
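The scoring steps listed above can be sketched in a few lines. This is a minimal illustration, not the module's code: the name toy_indexer_topk, the tensor layouts, the point at which softmax_scale is applied, and the causal masking are all assumptions made for the example. LayerNorm, RoPE, and the Hadamard rotation are left out.

```python
import torch


def toy_indexer_topk(q, k, head_weights, topk, softmax_scale):
    """Toy index scoring: per-head q.k, ReLU, weighted sum over heads, then top-k.

    q:            [B, S, H, D] index queries (hypothetical layout)
    k:            [B, S, D]    index keys, shared across heads
    head_weights: [B, S, H]    per-head importance weights per query position
    """
    # Per-head scores between every query position s and key position t.
    scores = torch.einsum("bshd,btd->bsht", q, k) * softmax_scale
    # ReLU first, then weight each head and sum the heads away.
    scores = torch.relu(scores)
    index_scores = (scores * head_weights.unsqueeze(-1)).sum(dim=2)  # [B, S, S]
    # Assumption of this sketch: future keys are excluded before selection.
    seq_len = q.shape[1]
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    index_scores = index_scores.masked_fill(future, float("-inf"))
    return index_scores.topk(min(topk, seq_len), dim=-1).indices  # [B, S, topk]


torch.manual_seed(0)
B, S, H, D = 2, 8, 4, 16
q = torch.randn(B, S, H, D)
k = torch.randn(B, S, D)
w = torch.randn(B, S, H)
idx = toy_indexer_topk(q, k, w, topk=3, softmax_scale=D ** -0.5)
print(idx.shape)  # torch.Size([2, 8, 3])
```

Query positions with fewer than topk visible keys still return topk indices, so some of them point at masked positions; a caller would need to combine the selection with the causal mask afterwards.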
head_dim = config.index_head_dim
hidden_size = config.hidden_size
index_topk = config.index_topk
k_norm
num_heads = config.index_n_heads
q_lora_rank = config.q_lora_rank
qk_nope_head_dim = self.head_dim - self.qk_rope_head_dim
qk_rope_head_dim = config.qk_rope_head_dim
softmax_scale = self.head_dim ** -0.5
weights_proj
wk
wq_b

nemo_automodel.components.models.glm_moe_dsa.layers.GlmMoeDsaIndexer.forward(
x: torch.Tensor,
q_resid: torch.Tensor,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
**attn_kwargs: typing.Any
) -> torch.Tensor

Compute top-k indices for sparse attention.

Parameters:

x
torch.Tensor

Hidden states [B, S, hidden] or [T, hidden] for thd format

q_resid
torch.Tensor

Q lora residual from MLA [B, S, q_lora_rank] or [T, q_lora_rank]

freqs_cis
torch.Tensor

RoPE frequencies

attention_mask
torch.Tensor | None, defaults to None

Optional attention mask

**attn_kwargs
Any, defaults to {}

Additional attention kwargs (cu_seqlens, etc.)

Returns: torch.Tensor

Indices of top-k positions [B, S, topk] or [T, topk]

nemo_automodel.components.models.glm_moe_dsa.layers.GlmMoeDsaIndexer.init_weights(
init_std: float = 0.02
)
class nemo_automodel.components.models.glm_moe_dsa.layers.GlmMoeDsaMLA(
config: transformers.models.glm_moe_dsa.configuration_glm_moe_dsa.GlmMoeDsaConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
skip_topk: bool = False
)

Bases: Module

Multi-head Latent Attention with Indexer for sparse attention.

This extends the DeepSeek-V3 MLA with an Indexer module that performs top-k selection for sparse attention. The indexer uses the q_lora residual and the hidden states to compute which positions to attend to.

index_topk = config.index_topk
indexer
kv_a_layernorm
kv_a_proj_with_mqa
kv_b_proj
kv_lora_rank = config.kv_lora_rank
n_heads = config.num_attention_heads
o_proj
q_a_layernorm
q_a_proj
q_b_proj
q_lora_rank = config.q_lora_rank
qk_head_dim
qk_nope_head_dim = config.qk_nope_head_dim
qk_rope_head_dim = config.qk_rope_head_dim
rope_fusion = backend.rope_fusion
softmax_scale = self.qk_head_dim ** -0.5
v_head_dim = config.v_head_dim

nemo_automodel.components.models.glm_moe_dsa.layers.GlmMoeDsaMLA._build_sparse_mask(
topk_indices: torch.Tensor,
seq_len: int,
qkv_format: str,
bsz: int = 1,
n_heads: int = 1,
dtype: torch.dtype = torch.bfloat16,
attention_mask: torch.Tensor | None = None,
union_across_batches: bool = False
) -> torch.Tensor

Build a sparse attention mask/bias from top-k indices.

Creates a mask tensor in which every non-top-k position is set to finfo.min and every selected position is 0. Works for both TE (core_attention_bias) and SDPA (attn_mask).

Uses the same efficient pattern as the official DeepSeek inference code, but with finfo.min instead of -inf, because F.scaled_dot_product_attention mishandles -inf float masks: torch.full(..., finfo.min).scatter_(-1, topk_indices, 0)

Parameters:

topk_indices
torch.Tensor

Indices of top-k positions [B, S, topk] or [T, topk]

seq_len
int

Sequence length

qkv_format
str

"bshd" or "thd"

bsz
int, defaults to 1

Batch size (only used for bshd format)

n_heads
int, defaults to 1

Number of attention heads to expand to

dtype
torch.dtype, defaults to torch.bfloat16

Data type for the output tensor

attention_mask
torch.Tensor | None, defaults to None

Optional attention mask to combine with (for SDPA)

union_across_batches
bool, defaults to False

If True, union top-k across batches (for TE); if False, keep per-batch masks (for SDPA)

Returns: torch.Tensor

Mask tensor with shape:

  • [1, n_heads, S, S] if union_across_batches=True
  • [B, n_heads, S, S] if union_across_batches=False (bshd)
  • [1, n_heads, T, T] for thd format
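The torch.full(...).scatter_(...) pattern described for _build_sparse_mask can be sketched as follows for the per-batch bshd case. This is a simplified illustration under stated assumptions: toy_sparse_bias is a hypothetical name, and the thd format, the union across batches, and the merge with an existing attention_mask are not shown.

```python
import torch


def toy_sparse_bias(topk_indices, seq_len, n_heads=1, dtype=torch.bfloat16):
    """Additive bias for bshd inputs: 0 at selected keys, finfo.min elsewhere."""
    bsz = topk_indices.shape[0]
    min_value = torch.finfo(dtype).min
    # Start fully masked, then zero out the selected key positions per query.
    bias = torch.full((bsz, seq_len, seq_len), min_value, dtype=dtype)
    bias.scatter_(-1, topk_indices, 0.0)
    # Add a head axis and broadcast it to n_heads: [B, n_heads, S, S].
    return bias.unsqueeze(1).expand(bsz, n_heads, seq_len, seq_len)


# [B=1, S=4, topk=2]; duplicate indices are harmless because scatter_ writes 0 twice.
topk_indices = torch.tensor([[[0, 0], [0, 1], [1, 2], [0, 3]]])
bias = toy_sparse_bias(topk_indices, seq_len=4, n_heads=2)
print(bias.shape)  # torch.Size([1, 2, 4, 4])
print((bias[0, 0] == 0).int())  # 1 marks a key the query may attend to
```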
nemo_automodel.components.models.glm_moe_dsa.layers.GlmMoeDsaMLA.forward(
x: torch.Tensor,
freqs_cis: torch.Tensor,
attention_mask: torch.Tensor | None = None,
prev_topk_indices: torch.Tensor | None = None,
return_topk_indices: bool = False,
**attn_kwargs: typing.Any
)

Run MLA with (optionally shared) DSA sparse attention.

Parameters:

x
torch.Tensor

Hidden states [B, S, hidden] (bshd) or [T, hidden] (thd).

freqs_cis
torch.Tensor

RoPE frequencies.

attention_mask
torch.Tensor | None, defaults to None

Optional additive attention mask.

prev_topk_indices
torch.Tensor | None, defaults to None

Top-k indices from the most recent “full” indexer layer. Required (and only used) when this is a “shared” layer (skip_topk=True).

return_topk_indices
bool, defaults to False

When True, return (attn_out, topk_indices) so the caller can thread the selection to subsequent shared layers (GLM IndexShare). When False (default), return just attn_out.

Returns:

The attn_out tensor, or the tuple (attn_out, topk_indices) when return_topk_indices is True.
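How a caller might thread the selection from a full layer to the shared layers that follow it can be sketched with a stand-in class. ToyDsaLayer is hypothetical and mimics only the prev_topk_indices / return_topk_indices calling convention documented above. It omits freqs_cis and attention_mask, uses a placeholder score in place of the real indexer, and returns its input as the attention output.

```python
import torch


class ToyDsaLayer:
    """Hypothetical stand-in that mimics only the calling convention described above."""

    def __init__(self, skip_topk=False, topk=2):
        self.skip_topk = skip_topk
        self.topk = topk

    def __call__(self, x, prev_topk_indices=None, return_topk_indices=False):
        if self.skip_topk:
            # Shared layer: reuse the selection made by the last full layer.
            if prev_topk_indices is None:
                raise ValueError("shared layers need prev_topk_indices")
            topk_indices = prev_topk_indices
        else:
            # Full layer: make a fresh selection (placeholder score, not the real indexer).
            scores = x @ x.transpose(-1, -2)
            topk_indices = scores.topk(self.topk, dim=-1).indices
        attn_out = x  # placeholder for the attention output
        return (attn_out, topk_indices) if return_topk_indices else attn_out


# One full layer, two shared layers, then another full layer.
layers = [ToyDsaLayer(), ToyDsaLayer(skip_topk=True), ToyDsaLayer(skip_topk=True), ToyDsaLayer()]
x = torch.randn(1, 4, 8)
topk_indices = None
history = []
for layer in layers:
    x, topk_indices = layer(x, prev_topk_indices=topk_indices, return_topk_indices=True)
    history.append(topk_indices)
```

In this sketch the two shared layers receive the very same index tensor the first layer produced, and the final full layer replaces it with its own selection.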

nemo_automodel.components.models.glm_moe_dsa.layers.GlmMoeDsaMLA.init_weights(
_buffer_device: torch.device,
init_std: float = 0.02
)
nemo_automodel.components.models.glm_moe_dsa.layers._apply_index_rope_half_split(
x: torch.Tensor,
freqs_cis: torch.Tensor,
qkv_format: str
) -> torch.Tensor

Apply NON-interleaved (half-split) RoPE to the indexer’s rope slice.

The DSA indexer uses half-split RoPE (rotate_half: pair dim j with j + d/2), unlike the main MLA attention which uses interleaved RoPE. freqs_cis is the same complex tensor used by the MLA (exp(i * theta_j * pos) for j in [0, d/2)); we read its real/imag parts as cos/sin so the angles match exactly.

Parameters:

x
torch.Tensor

Rope slice, with shape [B, S, H, d] or [B, S, d] (bshd), or [T, H, d] or [T, d] (thd).

freqs_cis
torch.Tensor

Complex RoPE table with trailing dim d/2.

qkv_format
str

"bshd" or "thd".
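The half-split pairing (dim j with dim j + d/2) can be checked against a plain complex multiplication. The sketch below is an illustration only: toy_rope_half_split is a hypothetical name, only the 4D bshd layout is handled, and the base of 10000 used to build the frequency table is an assumption of the example.

```python
import torch


def toy_rope_half_split(x, freqs_cis):
    """Rotate x[..., j] together with x[..., j + d/2] by the angle in freqs_cis[..., j].

    x:         [B, S, H, d] real tensor
    freqs_cis: [S, d/2] complex tensor, exp(i * theta_j * pos)
    """
    cos = freqs_cis.real[None, :, None, :]  # broadcast over batch and heads
    sin = freqs_cis.imag[None, :, None, :]
    x1, x2 = x.chunk(2, dim=-1)  # first half, second half
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)


torch.manual_seed(0)
B, S, H, d = 1, 5, 2, 8
theta = 10000.0 ** (-torch.arange(0, d, 2, dtype=torch.float32) / d)  # [d/2]
pos = torch.arange(S, dtype=torch.float32)
freqs_cis = torch.polar(torch.ones(S, d // 2), torch.outer(pos, theta))  # [S, d/2]

x = torch.randn(B, S, H, d)
y = toy_rope_half_split(x, freqs_cis)

# The same rotation written with complex numbers built from the (j, j + d/2) pairs.
z = torch.complex(x[..., : d // 2], x[..., d // 2 :]) * freqs_cis[None, :, None, :]
print(torch.allclose(y[..., : d // 2], z.real, atol=1e-5))
print(torch.allclose(y[..., d // 2 :], z.imag, atol=1e-5))
```

An interleaved variant would instead pair dim 2j with dim 2j + 1, which is why the two layouts cannot share rotated tensors even though they read the same freqs_cis table.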

nemo_automodel.components.models.glm_moe_dsa.layers._rotate_activation(
x: torch.Tensor
) -> torch.Tensor

Apply Hadamard rotation activation.

Parameters:

x
torch.Tensor

Input tensor (must be bfloat16).

Returns: torch.Tensor

Rotated tensor.

nemo_automodel.components.models.glm_moe_dsa.layers._to_additive_key_mask(
mask: torch.Tensor,
dtype: torch.dtype
) -> torch.Tensor

Convert a {0,1} keep-mask (1=attend, 0=mask) to an ADDITIVE key mask (0 / finfo.min).

Masked positions use finfo.min rather than -inf because F.scaled_dot_product_attention mishandles -inf float masks (its fused kernels corrupt the softmax). HF builds the attention bias with create_causal_mask, which likewise masks padding to finfo.min. The recipe, however, hands the model a 2D {0,1} padding mask. Adding that mask to the scores raw (the previous behaviour) both fails to mask padding (0 -> +0 instead of finfo.min) and adds +1.0 to every kept key. That shift is softmax-invariant only in fp32; in bf16 the +1.0 swamps the (scaled) score differences and collapses attention toward uniform. A mask that is already additive (values <= 0) is returned unchanged.
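The conversion and the bf16 rounding problem can both be reproduced in a few lines. This is a sketch under assumptions: toy_additive_key_mask is a hypothetical name, it keeps the input shape rather than reshaping for broadcasting, and its "already additive" test (maximum value <= 0) is one possible reading of the rule above.

```python
import torch


def toy_additive_key_mask(mask, dtype):
    """Map a {0,1} keep-mask to 0 / finfo.min; pass an already-additive mask through."""
    if mask.dtype != torch.bool and mask.max() <= 0:
        # All values <= 0: treated as already additive in this sketch.
        return mask
    keep = mask.to(torch.bool)
    additive = torch.zeros(mask.shape, dtype=dtype)
    return additive.masked_fill(~keep, torch.finfo(dtype).min)


padding_mask = torch.tensor([[1, 1, 1, 0]])  # the last key is padding
bias = toy_additive_key_mask(padding_mask, torch.bfloat16)
print(bias)

# Why adding the raw {0,1} mask is harmful in bf16: near 1.0 the spacing between
# representable values is 2**-7, so small score differences are rounded away.
scores = torch.tensor([0.001, 0.003], dtype=torch.bfloat16)
shifted = scores + torch.ones(2, dtype=torch.bfloat16)
print(scores[0] != scores[1], shifted[0] == shifted[1])
```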

nemo_automodel.components.models.glm_moe_dsa.layers.hadamard_transform(
x: torch.Tensor,
scale: float
) -> torch.Tensor

Fallback hadamard_transform when fast_hadamard_transform is not available.

nemo_automodel.components.models.glm_moe_dsa.layers.hadamard_transform_torch(
u,
scale: float,
normalize = False
)

Multiply H_n @ u, where H_n is the Hadamard matrix of dimension n x n. n must be a power of 2.

Parameters:

u

Tensor of shape (…, n)

normalize

If True, divide the result by 2^{m/2}, where m = log_2(n).

Returns:

product: Tensor of shape (…, n)
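The product H_n @ u can be computed without materializing H_n by using butterfly steps. The sketch below is one way to do so and is not necessarily how hadamard_transform_torch is written: toy_hadamard_transform and sylvester are hypothetical names, and the Sylvester (natural) ordering of H_n is an assumption.

```python
import torch


def toy_hadamard_transform(u, scale=1.0):
    """Fast Walsh-Hadamard transform over the last dim (Sylvester ordering)."""
    n = u.shape[-1]
    if n == 0 or (n & (n - 1)) != 0:
        raise ValueError("last dimension must be a power of 2")
    x = u.reshape(-1, n)
    h = 1
    while h < n:
        # Butterfly step: combine neighbouring blocks of size h as (a + b, a - b).
        x = x.reshape(-1, n // (2 * h), 2, h)
        a, b = x[:, :, 0, :], x[:, :, 1, :]
        x = torch.stack([a + b, a - b], dim=2).reshape(-1, n)
        h *= 2
    return (x * scale).reshape(u.shape)


def sylvester(n):
    """Explicit H_n built from H_2n = [[H_n, H_n], [H_n, -H_n]] (reference only)."""
    H = torch.ones(1, 1)
    while H.shape[0] < n:
        H = torch.cat([torch.cat([H, H], dim=1), torch.cat([H, -H], dim=1)], dim=0)
    return H


torch.manual_seed(0)
u = torch.randn(3, 8)
out = toy_hadamard_transform(u)
ref = u @ sylvester(8).T
print(torch.allclose(out, ref, atol=1e-5))

# With scale = n ** -0.5 the transform is orthonormal, so applying it twice returns u.
twice = toy_hadamard_transform(toy_hadamard_transform(u, 8 ** -0.5), 8 ** -0.5)
print(torch.allclose(twice, u, atol=1e-5))
```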

nemo_automodel.components.models.glm_moe_dsa.layers._FAST_HADAMARD_AVAILABLE = True