nemo_automodel.components.models.glm_moe_dsa.optimized_kernels
Optional GLM-5.2 DSA TileLang kernel dispatch.
The dense torch path in layers.py (GlmMoeDsaIndexer / GlmMoeDsaMLA)
is the numerical reference and the default. This module provides the optional
TileLang-backed sparse path, gated behind backend.attn == "tilelang":
- the fused lightning indexer (logits + top-k), and
- the gather-top-k sparse MLA attention (reads only the selected KV; no
[T, T] mask is materialized).
The kernels are vendored from THUDM's slime GLM-5.2 plugin under
nemo_automodel.components.models.glm_moe_dsa.kernels (see that package's
__init__ for attribution). They are imported with safe_import_from so
environments without tilelang still import the model and use the torch path.
Mirrors the structure of deepseek_v4/optimized_kernels.py.
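The fallback described above is a guarded import: a missing kernel package must never turn into an ImportError at model import time. The sketch below shows that pattern with importlib. guarded_import is a hypothetical stand-in. It is not NeMo's safe_import_from, whose exact signature is not documented in this section, so the (ok, obj) return shape is an assumption.

```python
import importlib
from types import ModuleType
from typing import Optional, Tuple


def guarded_import(module_name: str, attr: Optional[str] = None) -> Tuple[bool, object]:
    """Hypothetical stand-in for safe_import_from; never raises ImportError.

    Returns (True, obj) on success and (False, None) when the module or the
    attribute is missing, so the caller can stay on the torch reference path.
    """
    try:
        module: ModuleType = importlib.import_module(module_name)
    except ImportError:  # also covers ModuleNotFoundError
        return False, None
    if attr is None:
        return True, module
    if not hasattr(module, attr):
        return False, None
    return True, getattr(module, attr)


# One module that always exists and one that (almost certainly) does not.
ok_math, sqrt = guarded_import("math", "sqrt")
ok_tl, kernel = guarded_import("definitely_not_installed_tilelang_pkg", "kernel")
print(ok_math, sqrt(16.0))  # True 4.0
print(ok_tl, kernel)        # False None
```

Returning a flag along with the object lets dispatch code test availability once and pick a path without wrapping every call site in try/except.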
API
Build per-query key windows for padded THD layout.
cu_seqlens stores the cumulative real (unpadded) document lengths in compact
form. cu_seqlens_padded stores the corresponding offsets in the actual flattened
token tensor, including the CP padding between documents. TileLang top-k indices refer to the flattened
token tensor, so starts/ends must be in the padded coordinate space while
still excluding CP padding keys for real query tokens.
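A minimal pure-Python sketch of the window construction follows. It assumes three things: attention is causal, each document's CP padding follows its real tokens, and padding query rows get an empty window (start == end). build_key_windows is a hypothetical name. The real function's name, tensor types and handling of padding rows are not shown in this section.

```python
from typing import List, Tuple


def build_key_windows(
    cu_seqlens: List[int], cu_seqlens_padded: List[int]
) -> Tuple[List[int], List[int]]:
    """Hypothetical sketch: per-query causal key windows [start, end) in padded coordinates."""
    total_padded = cu_seqlens_padded[-1]
    starts = [0] * total_padded
    ends = [0] * total_padded
    for d in range(len(cu_seqlens) - 1):
        real_len = cu_seqlens[d + 1] - cu_seqlens[d]  # compact (real) length
        pad_start = cu_seqlens_padded[d]              # offset in the flattened tensor
        pad_end = cu_seqlens_padded[d + 1]
        for p in range(pad_start, pad_end):
            if p < pad_start + real_len:
                # Real query: attends to real keys of its own document, up to itself.
                starts[p] = pad_start
                ends[p] = p + 1
            else:
                # CP padding row (assumption): empty window, so no keys are selected.
                starts[p] = p
                ends[p] = p
    return starts, ends


# Two documents of real length 3 and 2, each padded to a multiple of 4.
starts, ends = build_key_windows([0, 3, 5], [0, 4, 8])
print(starts)  # [0, 0, 0, 3, 4, 4, 6, 7]
print(ends)    # [1, 2, 3, 3, 5, 6, 6, 7]
```

Since padding sits after each document's real tokens, a real query's window never reaches a padding key, yet every index stays in the padded coordinate space used by the top-k output.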
Return whether the optional TileLang kernel package for name is importable.
Decide whether to run the TileLang kernel; raise if forced but unavailable.
backend="tilelang" forces the kernel (and raises a clear error if it cannot run);
backend="auto" silently falls back to torch when the kernel is unavailable;
backend="torch" always uses the torch reference.
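The three-way policy above can be sketched as follows. use_tilelang is a hypothetical name, and the choice of RuntimeError and ValueError is an assumption, because the section does not name the real exception type. Availability is checked with importlib.util.find_spec, a standard-library call, as a stand-in for the module's own importability check.

```python
import importlib.util


def use_tilelang(backend: str, kernel_available: bool) -> bool:
    """Hypothetical sketch of the torch / auto / tilelang dispatch policy."""
    if backend == "torch":
        return False                      # always the torch reference
    if backend == "auto":
        return kernel_available           # silent fallback to torch
    if backend == "tilelang":
        if not kernel_available:          # forced but unusable: fail loudly
            raise RuntimeError(
                "backend='tilelang' was requested but the TileLang kernel is unavailable"
            )
        return True
    raise ValueError(f"unknown backend {backend!r}; expected 'torch', 'auto' or 'tilelang'")


# find_spec returns None when a top-level package cannot be found.
available = importlib.util.find_spec("tilelang") is not None
print(use_tilelang("auto", available))
```

Raising only in the forced case means a misconfigured run fails at dispatch time rather than silently training on the slower dense path.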
Fused lightning-indexer top-k selection (THD/varlen).
Parameters:
[T, index_n_heads, index_head_dim] bf16 (rope already applied).
[T, index_head_dim] bf16 (k_norm + rope already applied).
[T, index_n_heads] fp32; the caller must fold the index
softmax_scale into the weight (the kernel computes relu(q·k) * w
with no internal scale), i.e. weights_proj(x) * index_n_heads**-0.5 * index_head_dim**-0.5.
[num_seq + 1] cumulative sequence lengths of the packed batch.
number of keys to keep (e.g. 2048).
Optional global THD token indices for the local query
rows. Used by context parallelism when index_q is sharded but
index_k has been all-gathered in global token order.
Optional cumulative lengths in the padded THD token
layout. CP-packed datasets pad each document to a CP multiple, so
local query indices address this padded layout rather than the
compact cu_seqlens layout.
Returns: torch.Tensor
topk_indices [T, 1, index_topk] int32 (-1 for invalid/causal-masked).
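A dense pure-Python sketch of what the fused indexer computes is shown below. It folds the scale into the weights as the caller must. For each query it scores every causal key in the same document as the sum over heads of relu(q·k) * w, then keeps the index_topk best and pads with -1. The function names are hypothetical, lists stand in for tensors, and the ordering of the returned indices (highest logit first, ties by lower index) is an assumption. The real kernel may order them differently.

```python
from typing import List


def fold_index_scale(raw_weights: List[List[float]], index_n_heads: int,
                     index_head_dim: int) -> List[List[float]]:
    # weights_proj(x) * index_n_heads**-0.5 * index_head_dim**-0.5, folded by the caller.
    scale = index_n_heads ** -0.5 * index_head_dim ** -0.5
    return [[w * scale for w in row] for row in raw_weights]


def indexer_topk_reference(q, k, weights, cu_seqlens, index_topk):
    """Hypothetical dense reference: q [T][H][D], k [T][D], weights [T][H] -> [T][1][topk]."""
    out = []
    for d in range(len(cu_seqlens) - 1):
        doc_start, doc_end = cu_seqlens[d], cu_seqlens[d + 1]
        for t in range(doc_start, doc_end):
            scored = []
            for s in range(doc_start, t + 1):  # causal, same document only
                logit = 0.0
                for h, q_th in enumerate(q[t]):
                    dot = sum(a * b for a, b in zip(q_th, k[s]))
                    logit += max(dot, 0.0) * weights[t][h]  # relu(q·k) * w, no extra scale
                scored.append((logit, s))
            # Highest logit first; ties broken by lower index (ordering is an assumption).
            scored.sort(key=lambda item: (-item[0], item[1]))
            idx = [s for _, s in scored[:index_topk]]
            idx += [-1] * (index_topk - len(idx))  # -1 sentinel for missing keys
            out.append([idx])
    return out


q = [[[1.0, 0.0]], [[0.0, 1.0]], [[1.0, 1.0]]]  # T=3, H=1, D=2
k = [[2.0, 0.0], [0.0, 3.0], [1.0, -5.0]]
weights = fold_index_scale([[2.0], [2.0], [2.0]], index_n_heads=1, index_head_dim=2)
topk = indexer_topk_reference(q, k, weights, cu_seqlens=[0, 3], index_topk=2)
print(topk)  # [[[0, -1]], [[1, 0]], [[1, 0]]]
```

The first query can see only itself, so its second slot is -1. Key 2 scores zero for the last query because relu clamps the negative dot product.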
Gather-top-k sparse MLA attention on the absorbed latent representation.
Parameters:
[T, n_heads, kv_lora_rank + qk_rope_head_dim] bf16; the absorbed query
cat([q_nope @ w_kc, q_pe], -1) (e.g. 512 + 64 = 576).
[T, 1, kv_lora_rank + qk_rope_head_dim] bf16; the latent KV
cat([kv_compressed, k_pe], -1).
[T, 1, index_topk] int32 (-1 sentinel).
[n_heads, v_head_dim, kv_lora_rank]; the value up-projection used to
map the latent attention output back to v_head_dim.
MLA attention scale mscale**2 / sqrt(qk_head_dim) (NOT the
kernel's 1/sqrt(dim+tail) default).
Returns: torch.Tensor
attn_out [T, n_heads, v_head_dim] bf16.
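A pure-Python sketch of gather-top-k sparse MLA on the absorbed representation follows. For each query and head it scores only the selected latent KV rows (skipping -1) with the MLA scale. It then softmaxes and averages the latent values, and maps the result to v_head_dim with the up-projection. Three points are assumptions rather than statements from this page: the values are the first kv_lora_rank channels of the latent KV (kv_compressed, excluding k_pe), a query with no valid index yields zeros, and the function names are hypothetical.

```python
import math


def mla_softmax_scale(mscale: float, qk_head_dim: int) -> float:
    # MLA scale mscale**2 / sqrt(qk_head_dim), not the kernel's 1/sqrt(dim + tail).
    return mscale ** 2 / math.sqrt(qk_head_dim)


def sparse_mla_reference(q, kv, topk_indices, w_vc, kv_lora_rank, scale):
    """Hypothetical reference: q [T][H][R+P], kv [T][1][R+P], topk [T][1][K], w_vc [H][V][R]."""
    out = []
    for t, q_t in enumerate(q):
        valid = [s for s in topk_indices[t][0] if s >= 0]  # drop -1 sentinels
        heads_out = []
        for h, q_th in enumerate(q_t):
            v_head_dim = len(w_vc[h])
            if not valid:  # assumption: no selected keys -> zero output
                heads_out.append([0.0] * v_head_dim)
                continue
            scores = [scale * sum(a * b for a, b in zip(q_th, kv[s][0])) for s in valid]
            m = max(scores)  # subtract the max for a numerically stable softmax
            exps = [math.exp(x - m) for x in scores]
            z = sum(exps)
            latent = [0.0] * kv_lora_rank
            for e, s in zip(exps, valid):  # values = kv_compressed part only (assumption)
                for r in range(kv_lora_rank):
                    latent[r] += (e / z) * kv[s][0][r]
            heads_out.append([sum(w * l for w, l in zip(w_vc[h][v], latent))
                              for v in range(v_head_dim)])
        out.append(heads_out)
    return out


# Toy sizes: T=2, n_heads=1, kv_lora_rank=2, qk_rope_head_dim=1, v_head_dim=1.
kv = [[[1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0]]]
q = [[[0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0]]]  # zero query -> uniform weights
topk = [[[0, -1]], [[0, 1]]]
w_vc = [[[1.0, 2.0]]]
attn = sparse_mla_reference(q, kv, topk, w_vc, kv_lora_rank=2, scale=1.0)
print(attn)  # [[[1.0]], [[1.5]]]
```

Only the K selected rows are touched per query, which is why no [T, T] mask is ever built. In a real call, scale would come from mla_softmax_scale with the model's mscale and qk_head_dim rather than the 1.0 used here.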