nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer


TileLang-based DSA Indexer for DeepSeek-V4.

Adapts GLM-5’s lighting_indexer to V4’s SBHD data layout and causal masking. Provides both a low-level per-sample interface and a batched autograd Function.

Module Contents

Classes

Name | Description
V4IndexerFunction | Autograd function for V4 tilelang indexer.

Functions

Name | Description
pytorch_extract_topk_scores | -
v4_lighting_indexer | Main entry point for V4 tilelang indexer.

API

class nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer.V4IndexerFunction()

Bases: Function

Autograd function for V4 tilelang indexer.

Inputs are in V4’s native SBHD layout:

q: [seqlen, batch, heads, dim] bf16
k: [seqlen_kv, batch, dim] bf16
weights: [seqlen, batch, heads] fp32
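
The layout above can be illustrated with a small sketch that converts batch-first tensors into the sequence-first shapes and dtypes listed. The helper name `to_sbhd_inputs` is hypothetical and not part of this module. The relation `seqlen_kv = seqlen // compress_ratio` and the `.contiguous()` calls are assumptions; the page does not state whether the kernel requires either.

```python
import torch


def to_sbhd_inputs(q_bshd, k_bsd, w_bsh):
    """Hypothetical helper: move the sequence axis first and cast to the documented dtypes."""
    index_q = q_bshd.permute(1, 0, 2, 3).contiguous().to(torch.bfloat16)  # [seqlen, batch, heads, dim]
    index_k = k_bsd.permute(1, 0, 2).contiguous().to(torch.bfloat16)      # [seqlen_kv, batch, dim]
    weights = w_bsh.permute(1, 0, 2).contiguous().to(torch.float32)       # [seqlen, batch, heads]
    return index_q, index_k, weights


batch, seqlen, heads, dim, compress_ratio = 2, 16, 4, 32, 4
seqlen_kv = seqlen // compress_ratio  # assumption: one compressed key per compress_ratio tokens

index_q, index_k, weights = to_sbhd_inputs(
    torch.randn(batch, seqlen, heads, dim),
    torch.randn(batch, seqlen_kv, dim),
    torch.randn(batch, seqlen, heads),
)
```

Note that index_k carries no heads axis, so under this layout a single key vector per position is shared by all indexer heads.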

nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer.V4IndexerFunction.backward(
ctx,
grad_scores,
grad_indices
)
staticmethod
nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer.V4IndexerFunction.forward(
ctx,
index_q: torch.Tensor,
index_k: torch.Tensor,
weights: torch.Tensor,
compress_ratio: int,
topk: int,
topk_indices: torch.Tensor | None = None
)
staticmethod
nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer.pytorch_extract_topk_scores(
logits,
topk_indices,
dim = -1
)
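
This page gives no description for pytorch_extract_topk_scores. Judging only from its name and signature, it plausibly gathers the entries of logits selected by topk_indices along dim. The sketch below shows that operation under this assumption; `extract_topk_scores_sketch` is a hypothetical stand-in, and the real function may treat padding or invalid indices differently.

```python
import torch


def extract_topk_scores_sketch(logits, topk_indices, dim=-1):
    """Assumed behaviour: pick logits at the given indices along `dim`."""
    # torch.gather requires int64 indices, so int32 indices are cast first.
    return torch.gather(logits, dim, topk_indices.to(torch.int64))


logits = torch.tensor([[0.1, 0.9, 0.3, 0.7],
                       [0.5, 0.2, 0.8, 0.4]])
topk_indices = torch.tensor([[1, 3],
                             [2, 0]], dtype=torch.int32)

picked = extract_topk_scores_sketch(logits, topk_indices)
```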
nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer.v4_lighting_indexer(
index_q: torch.Tensor,
index_k: torch.Tensor,
weights: torch.Tensor,
compress_ratio: int,
topk: int,
topk_indices: torch.Tensor | None = None
)

Main entry point for V4 tilelang indexer.

Parameters:

index_q
torch.Tensor

[seqlen, batch, heads, dim] bf16

index_k
torch.Tensor

[seqlen_kv, batch, dim] bf16

weights
torch.Tensor

[seqlen, batch, heads] fp32

compress_ratio
int

compression ratio (4 for C4 layers)

topk
int

number of top-k indices to select

topk_indices
torch.Tensor | None. Defaults to None

optional pre-computed topk indices [batch, seqlen, topk] int32

Returns:

[batch, seqlen, topk] fp32
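
To make the parameter shapes concrete, here is a dense PyTorch sketch of what a top-k indexer of this kind might compute. It is not the TileLang kernel, and every semantic choice in it is an assumption: the score formula (a per-head ReLU of q·k, weighted and summed over heads, as in DSA-style indexers), the causal rule for compressed keys, and the handling of rows with fewer visible keys than topk. The name `dense_indexer_reference` is hypothetical.

```python
import torch


def dense_indexer_reference(index_q, index_k, weights, compress_ratio, topk):
    """Dense reference sketch (assumed semantics), returning [batch, seqlen, topk] tensors."""
    seqlen = index_q.shape[0]
    seqlen_kv = index_k.shape[0]
    device = index_q.device

    q = index_q.float().permute(1, 2, 0, 3)              # [batch, heads, seqlen, dim]
    k = index_k.float().permute(1, 2, 0).unsqueeze(1)    # [batch, 1, dim, seqlen_kv]
    w = weights.float().permute(1, 2, 0).unsqueeze(-1)   # [batch, heads, seqlen, 1]

    # Assumed score: sum over heads of weight * relu(q . k).
    scores = (torch.relu(q @ k) * w).sum(dim=1)          # [batch, seqlen, seqlen_kv]

    # Assumed causal rule: query t may see compressed key j only when the last
    # token that key covers, (j + 1) * compress_ratio - 1, is at or before t.
    t = torch.arange(seqlen, device=device).unsqueeze(1)
    last_covered = (torch.arange(seqlen_kv, device=device) + 1) * compress_ratio - 1
    visible = last_covered.unsqueeze(0) <= t             # [seqlen, seqlen_kv]
    scores = scores.masked_fill(~visible, float("-inf"))

    # Rows with fewer visible keys than topk keep -inf scores in this sketch.
    top_scores, top_idx = scores.topk(min(topk, seqlen_kv), dim=-1)
    return top_scores, top_idx.to(torch.int32)


torch.manual_seed(0)
seqlen, batch, heads, dim, compress_ratio, topk = 8, 2, 2, 4, 4, 2
seqlen_kv = seqlen // compress_ratio  # assumption

index_q = torch.randn(seqlen, batch, heads, dim).to(torch.bfloat16)
index_k = torch.randn(seqlen_kv, batch, dim).to(torch.bfloat16)
weights = torch.randn(seqlen, batch, heads)

top_scores, top_idx = dense_indexer_reference(index_q, index_k, weights, compress_ratio, topk)
```

A dense reference like this materializes the full [batch, seqlen, seqlen_kv] score matrix, so it is only practical for small shapes, for example when comparing against a fused kernel in a unit test.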