nemo_automodel.components.models.glm_moe_dsa.optimized_kernels

Optional GLM-5.2 DSA TileLang kernel dispatch.

The dense torch path in layers.py (GlmMoeDsaIndexer / GlmMoeDsaMLA) is the numerical reference and the default. This module provides the optional TileLang-backed sparse path, gated behind backend.attn == "tilelang":

  • the fused lighting indexer (logits + top-k), and
  • the gather-top-k sparse MLA attention (reads only the selected KV; no [T, T] mask is materialized).

The kernels are vendored from THUDM's slime GLM-5.2 plugin under nemo_automodel.components.models.glm_moe_dsa.kernels (see that package's __init__ for attribution). They are imported with safe_import_from so environments without tilelang still import the model and use the torch path.
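The optional-import behavior described above can be illustrated with a minimal stdlib sketch. The helper name optional_import and its return shape are hypothetical and are not the signature of safe_import_from; the sketch only shows the general pattern of recording availability instead of failing at import time.

```python
import importlib
from typing import Any


def optional_import(module: str, name: str) -> tuple[bool, Any]:
    """Hypothetical helper: return (True, attr) if `module.name` resolves,
    otherwise (False, None) instead of raising."""
    try:
        mod = importlib.import_module(module)
        return True, getattr(mod, name)
    except (ImportError, AttributeError):
        return False, None


# A module that exists resolves normally.
HAVE_SQRT, sqrt = optional_import("math", "sqrt")

# A missing package does not raise at import time; callers check the flag
# and fall back to the reference path instead.
HAVE_FAKE, fake_kernel = optional_import("no_such_kernel_pkg_xyz", "kernel")
```

With this pattern, the availability flag plays the role that is_dsa_kernel_available plays in this module: dispatch code consults it before attempting to call the kernel.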

Mirrors the structure of deepseek_v4/optimized_kernels.py.

Module Contents

Functions

Name                                   Description
_all_cuda                              -
_generate_padded_varlen_mask_params    Build per-query key windows for padded THD layout.
is_dsa_kernel_available                Return whether the optional TileLang kernel package for name is importable.
should_use_tilelang                    Decide whether to run the TileLang kernel; raise if forced but unavailable.
tilelang_indexer_topk                  Fused lighting-indexer top-k selection (THD/varlen).
tilelang_sparse_attention              Gather-top-k sparse MLA attention on the absorbed latent representation.

Data

DsaIndexerBackend

DsaSparseAttentionBackend

API

nemo_automodel.components.models.glm_moe_dsa.optimized_kernels._all_cuda(
tensors: torch.Tensor = ()
) -> bool
nemo_automodel.components.models.glm_moe_dsa.optimized_kernels._generate_padded_varlen_mask_params(
cu_seqlens: torch.Tensor,
cu_seqlens_padded: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]

Build per-query key windows for padded THD layout.

cu_seqlens stores real document lengths compactly. cu_seqlens_padded stores their offsets in the actual flattened token tensor, including CP padding between documents. TileLang top-k indices refer to the flattened token tensor, so starts/ends must be in the padded coordinate space while still excluding CP padding keys for real query tokens.
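The relationship between the compact and padded coordinate spaces can be sketched in plain Python. This is an illustration under stated assumptions, not the module's implementation: the function name padded_varlen_windows is hypothetical, windows are assumed to be causal half-open ranges [start, end), and padding query rows are assumed to receive an empty window.

```python
def padded_varlen_windows(cu_seqlens, cu_seqlens_padded):
    """Hypothetical sketch: per-query key windows in padded coordinates.

    cu_seqlens:        cumulative real document lengths, e.g. [0, 3, 5]
    cu_seqlens_padded: cumulative padded offsets,        e.g. [0, 4, 8]
    Returns (starts, ends), one entry per token of the padded layout.
    """
    total = cu_seqlens_padded[-1]
    starts = [0] * total
    ends = [0] * total
    for i in range(len(cu_seqlens) - 1):
        real_len = cu_seqlens[i + 1] - cu_seqlens[i]
        doc_start = cu_seqlens_padded[i]
        doc_end = cu_seqlens_padded[i + 1]
        for pos in range(doc_start, doc_end):
            starts[pos] = doc_start
            if pos - doc_start < real_len:
                # Real query: causal window up to and including itself.
                # It never reaches the padding keys at the document tail.
                ends[pos] = pos + 1
            else:
                # Padding query: empty window (an assumption of this sketch).
                ends[pos] = doc_start
    return starts, ends


# Two documents of real length 3 and 2, each padded to a multiple of 4.
starts, ends = padded_varlen_windows([0, 3, 5], [0, 4, 8])
```

In this example the second document starts at padded offset 4 rather than compact offset 3, so its first real query sees keys [4, 5) in the flattened tensor.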

nemo_automodel.components.models.glm_moe_dsa.optimized_kernels.is_dsa_kernel_available(
name: typing.Literal['indexer', 'sparse_attn']
) -> bool

Return whether the optional TileLang kernel package for name is importable.

nemo_automodel.components.models.glm_moe_dsa.optimized_kernels.should_use_tilelang(
backend: str,
available: bool,
kernel_name: str,
tensors: tuple[torch.Tensor, ...],
require_bf16: bool = False
) -> bool

Decide whether to run the TileLang kernel; raise if forced but unavailable.

backend="tilelang" forces the kernel (and raises a clear error if it cannot run); backend="auto" silently falls back to torch when the kernel is unavailable; backend="torch" always uses the torch reference.
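The three backend modes can be summarized as a small decision function. This is a sketch of the documented behavior, not the actual code: decide_tilelang is a hypothetical name, the single inputs_supported flag stands in for whatever tensor checks the real function performs (for example device and, with require_bf16, dtype), and the exception types are assumptions.

```python
def decide_tilelang(backend: str, available: bool, kernel_name: str,
                    inputs_supported: bool) -> bool:
    """Hypothetical sketch of the torch / tilelang / auto dispatch rule."""
    if backend == "torch":
        # The torch reference path is always used.
        return False
    can_run = available and inputs_supported
    if backend == "tilelang":
        # Forced: fail loudly rather than silently changing numerics.
        if not can_run:
            raise RuntimeError(
                f"backend='tilelang' requested but the {kernel_name} "
                "kernel cannot run in this environment"
            )
        return True
    if backend == "auto":
        # Best effort: fall back to torch without raising.
        return can_run
    raise ValueError(f"unknown backend: {backend!r}")
```

For instance, decide_tilelang("auto", False, "indexer", True) returns False, while the same call with backend "tilelang" raises.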

nemo_automodel.components.models.glm_moe_dsa.optimized_kernels.tilelang_indexer_topk(
index_q: torch.Tensor,
index_k: torch.Tensor,
head_weights: torch.Tensor,
cu_seqlens: torch.Tensor,
index_topk: int,
query_indices: torch.Tensor | None = None,
cu_seqlens_padded: torch.Tensor | None = None
) -> torch.Tensor

Fused lighting-indexer top-k selection (THD/varlen).

Parameters:

index_q
torch.Tensor

[T, index_n_heads, index_head_dim] bf16 (rope already applied).

index_k
torch.Tensor

[T, index_head_dim] bf16 (k_norm + rope already applied).

head_weights
torch.Tensor

[T, index_n_heads] fp32; the caller must fold the index softmax_scale into the weight (the kernel computes relu(q·k) * w with no internal scale), i.e. weights_proj(x) * index_n_heads**-0.5 * index_head_dim**-0.5.

cu_seqlens
torch.Tensor

[num_seq + 1] cumulative sequence lengths of the packed batch.

index_topk
int

Number of keys to keep per query (e.g. 2048).

query_indices
torch.Tensor | None. Defaults to None

Optional global THD token indices for the local query rows. Used by context parallelism when index_q is sharded but index_k has been all-gathered in global token order.

cu_seqlens_padded
torch.Tensor | None. Defaults to None

Optional cumulative lengths in the padded THD token layout. CP-packed datasets pad each document to a CP multiple, so local query indices address this padded layout rather than the compact cu_seqlens layout.

Returns: torch.Tensor

topk_indices [T, 1, index_topk] int32 (-1 for invalid/causal-masked).
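The scale folding and the -1 sentinel can be illustrated with a dense, single-query reference in plain Python. This is a sketch under assumptions, not the kernel: the function names are hypothetical, per-head scores are assumed to be summed over heads, and the ordering of the returned indices (descending score here) is an assumption.

```python
def folded_head_weight(raw_weight, index_n_heads, index_head_dim):
    """Fold the index softmax scale into the head weight, as the caller must."""
    return raw_weight * index_n_heads ** -0.5 * index_head_dim ** -0.5


def index_score(q_heads, k, weights):
    """Score of one (query, key) pair: sum over heads of relu(q_h . k) * w_h."""
    total = 0.0
    for q_h, w_h in zip(q_heads, weights):
        dot = sum(a * b for a, b in zip(q_h, k))
        total += max(dot, 0.0) * w_h
    return total


def topk_for_query(q_heads, keys, weights, window, index_topk):
    """Pick up to index_topk keys inside [start, end); pad with -1."""
    start, end = window
    scored = [(index_score(q_heads, keys[j], weights), j)
              for j in range(start, end)]
    scored.sort(key=lambda s: (-s[0], s[1]))  # best score first
    picked = [j for _, j in scored[:index_topk]]
    return picked + [-1] * (index_topk - len(picked))


# Two index heads with head dimension 2, three candidate keys.
q_heads = [[1.0, 0.0], [0.0, 1.0]]
keys = [[1.0, 0.0], [0.0, 2.0], [-1.0, -1.0]]
weights = [folded_head_weight(1.0, 2, 2)] * 2

full = topk_for_query(q_heads, keys, weights, (0, 3), 2)
short = topk_for_query(q_heads, keys, weights, (0, 1), 2)
```

The second call shows the sentinel: when the causal window holds fewer than index_topk keys, the remaining slots are filled with -1.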

nemo_automodel.components.models.glm_moe_dsa.optimized_kernels.tilelang_sparse_attention(
q: torch.Tensor,
kv_latent: torch.Tensor,
topk_indices: torch.Tensor,
w_vc: torch.Tensor,
softmax_scale: float
) -> torch.Tensor

Gather-top-k sparse MLA attention on the absorbed latent representation.

Parameters:

q
torch.Tensor

[T, n_heads, kv_lora_rank + qk_rope_head_dim] bf16. This is the absorbed query cat([q_nope @ w_kc, q_pe], -1) (e.g. 512 + 64 = 576).

kv_latent
torch.Tensor

[T, 1, kv_lora_rank + qk_rope_head_dim] bf16. This is the latent KV cat([kv_compressed, k_pe], -1).

topk_indices
torch.Tensor

[T, 1, index_topk] int32 (-1 sentinel).

w_vc
torch.Tensor

[n_heads, v_head_dim, kv_lora_rank]. This is the value up-projection used to map the latent attention output back to v_head_dim.

softmax_scale
float

MLA attention scale mscale**2 / sqrt(qk_head_dim) (NOT the kernel's 1/sqrt(dim+tail) default).

Returns: torch.Tensor

attn_out [T, n_heads, v_head_dim] bf16.
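The data flow described by these parameters can be sketched for a single query in plain Python. This is an illustrative reference under assumptions, not the TileLang kernel: the function names are hypothetical, values are assumed to be the first kv_lora_rank channels of the latent, a query with no valid index is assumed to produce zeros, and the example dimension 192 for qk_head_dim is a made-up value.

```python
import math


def mla_softmax_scale(qk_head_dim, mscale=1.0):
    """Scale the caller passes in: mscale**2 / sqrt(qk_head_dim)."""
    return mscale ** 2 / math.sqrt(qk_head_dim)


def sparse_attend_one(q, kv_latent, topk, w_vc, softmax_scale, kv_lora_rank):
    """Hypothetical single-query reference.

    q:         [n_heads][D] absorbed query rows, D = kv_lora_rank + rope dim
    kv_latent: [T][D] latent KV rows
    topk:      selected key indices, -1 marks an unused slot
    w_vc:      [n_heads][v_head_dim][kv_lora_rank]
    """
    valid = [j for j in topk if j >= 0]  # drop sentinels; gather only these
    out = []
    for h, q_h in enumerate(q):
        if not valid:
            out.append([0.0] * len(w_vc[h]))  # assumption: empty -> zeros
            continue
        logits = [softmax_scale * sum(a * b for a, b in zip(q_h, kv_latent[j]))
                  for j in valid]
        m = max(logits)
        exps = [math.exp(x - m) for x in logits]
        denom = sum(exps)
        # Attention output in latent space (first kv_lora_rank channels).
        latent = [sum(e * kv_latent[j][c] for e, j in zip(exps, valid)) / denom
                  for c in range(kv_lora_rank)]
        # Up-project back to v_head_dim with this head's w_vc.
        out.append([sum(w * x for w, x in zip(row, latent)) for row in w_vc[h]])
    return out


# One head, kv_lora_rank = 2, rope dim = 1, v_head_dim = 3.
kv_latent = [[1.0, 2.0, 0.5], [3.0, 4.0, 0.5]]
w_vc = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]]
scale = mla_softmax_scale(192)  # hypothetical qk_head_dim
result = sparse_attend_one([[0.1, 0.2, 0.3]], kv_latent, [1, -1],
                           w_vc, scale, kv_lora_rank=2)
```

Note that mla_softmax_scale(192) is larger than 1/sqrt(576), the value a default based on the absorbed width of 576 would give, which is why the scale is passed explicitly.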

nemo_automodel.components.models.glm_moe_dsa.optimized_kernels.DsaIndexerBackend = Literal['torch', 'tilelang', 'auto']
nemo_automodel.components.models.glm_moe_dsa.optimized_kernels.DsaSparseAttentionBackend = Literal['torch', 'tilelang', 'auto']