nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd

Module Contents

Functions

Name                     Description
batched_indexer_bwd      Batched backward: loops over batch dim.
indexer_bwd_interface    Backward interface for a single batch element.
tl_indexer_bwd_impl      -

Data

BF16

FP32

INT32

pass_configs

API

nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.batched_indexer_bwd(
index_q,
weights,
index_k,
topk_indices,
grad_scores
)

Batched backward: loops over batch dim.

Parameters:

index_q

[seqlen, batch, heads, dim] bf16

weights

[seqlen, batch, heads] fp32

index_k

[seqlen_kv, batch, dim] bf16

topk_indices

[batch, seqlen, topk] int32

grad_scores

[batch, seqlen, topk] fp32

Returns:

[seqlen, batch, heads, dim] bf16
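
The layouts listed above share the sizes seqlen, batch, heads, dim, seqlen_kv and topk, so a mismatch can be caught before any kernel is launched. The following is a minimal sketch of such a check, not part of this module: `check_batched_shapes` is a hypothetical helper that works on plain shape tuples (for example `tuple(t.shape)` of each tensor) and only encodes the shapes documented here.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def check_batched_shapes(
    index_q: Shape,
    weights: Shape,
    index_k: Shape,
    topk_indices: Shape,
    grad_scores: Shape,
) -> Dict[str, int]:
    """Check the documented batched layouts and return the inferred sizes.

    Hypothetical helper for illustration; it is not provided by
    tilelang_indexer_bwd.
    """
    # index_q: [seqlen, batch, heads, dim]
    if len(index_q) != 4:
        raise ValueError(f"index_q must be [seqlen, batch, heads, dim], got {index_q}")
    seqlen, batch, heads, dim = index_q

    # weights: [seqlen, batch, heads]
    if tuple(weights) != (seqlen, batch, heads):
        raise ValueError(f"weights must be {(seqlen, batch, heads)}, got {weights}")

    # index_k: [seqlen_kv, batch, dim]; seqlen_kv may differ from seqlen
    if len(index_k) != 3 or tuple(index_k[1:]) != (batch, dim):
        raise ValueError(f"index_k must be [seqlen_kv, {batch}, {dim}], got {index_k}")
    seqlen_kv = index_k[0]

    # topk_indices: [batch, seqlen, topk] (batch-first, unlike index_q)
    if len(topk_indices) != 3 or tuple(topk_indices[:2]) != (batch, seqlen):
        raise ValueError(f"topk_indices must be [{batch}, {seqlen}, topk], got {topk_indices}")
    topk = topk_indices[2]

    # grad_scores: same shape as topk_indices
    if tuple(grad_scores) != (batch, seqlen, topk):
        raise ValueError(f"grad_scores must be {(batch, seqlen, topk)}, got {grad_scores}")

    return {
        "seqlen": seqlen,
        "batch": batch,
        "heads": heads,
        "dim": dim,
        "seqlen_kv": seqlen_kv,
        "topk": topk,
    }
```

Note that the sketch checks shapes only. The documented dtypes (bf16, fp32, int32) would need a separate check on the tensors themselves.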

nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.indexer_bwd_interface(
index_q: torch.Tensor,
weights: torch.Tensor,
index_k: torch.Tensor,
topk_indices: torch.Tensor,
grad_scores: torch.Tensor
)

Backward interface for a single batch element.

Parameters:

index_q
torch.Tensor

[seq_len, heads, dim] bf16

weights
torch.Tensor

[seq_len, heads] fp32

index_k
torch.Tensor

[seq_len_kv, dim] bf16

topk_indices
torch.Tensor

[seq_len, topk] int32

grad_scores
torch.Tensor

[seq_len, topk] fp32

Returns:

[seq_len, heads, dim] bf16
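
Comparing the two parameter lists, each single-element shape is the corresponding batched shape with the batch axis removed. That axis is not in the same position for every argument: it is axis 1 for index_q, weights and index_k, and axis 0 for topk_indices and grad_scores. The sketch below derives the per-element shapes from the batched ones. `BATCH_AXIS` and `per_element_shapes` are hypothetical names introduced for illustration, and the mapping is read off the documented layouts rather than taken from the implementation.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]

# Position of the batch axis in each batched argument, as read off the
# documented layouts (an assumption based on this page, not on the source code).
BATCH_AXIS: Dict[str, int] = {
    "index_q": 1,       # [seqlen, batch, heads, dim]
    "weights": 1,       # [seqlen, batch, heads]
    "index_k": 1,       # [seqlen_kv, batch, dim]
    "topk_indices": 0,  # [batch, seqlen, topk]
    "grad_scores": 0,   # [batch, seqlen, topk]
}


def per_element_shapes(batched: Dict[str, Shape]) -> Dict[str, Shape]:
    """Drop the batch axis from each batched shape."""
    result: Dict[str, Shape] = {}
    for name, shape in batched.items():
        axis = BATCH_AXIS[name]
        result[name] = tuple(shape[:axis]) + tuple(shape[axis + 1:])
    return result
```

With torch tensors, dropping axis 1 corresponds to indexing such as `index_q[:, b]`, and dropping axis 0 to `topk_indices[b]`, for a batch index `b`.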

nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.tl_indexer_bwd_impl(
heads: int,
dim: int,
topk: int,
block_I: int = 32,
num_stages: int = 0,
num_threads: int = 128
)

nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.BF16 = T.bfloat16

nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.FP32 = T.float32

nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.INT32 = T.int32

nemo_automodel.components.models.deepseek_v4.kernels.tilelang_indexer_bwd.pass_configs = {tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tl.PassConfigKey.TL_DISABLE_WARP_S...