nemo_automodel.components.models.deepseek_v4.kernels.tilelang_sparse_mla_bwd

Module Contents

Functions

Name                        Description
bwd                         -
postprocess                 -
preprocess                  -
sparse_mqa_bwd_interface    Backward interface for V4 sparse MQA attention.

API

nemo_automodel.components.models.deepseek_v4.kernels.tilelang_sparse_mla_bwd.bwd(
B,
S,
S_kv,
H,
D,
topk,
sm_scale = None,
block_size = 32,
num_stages = 0,
threads = 128,
indices_dtype = T.int32,
dtype = T.bfloat16,
accum_dtype = T.float32
)
nemo_automodel.components.models.deepseek_v4.kernels.tilelang_sparse_mla_bwd.postprocess(
B,
S_kv,
D,
block_N = 64,
threads = 128,
dtype = T.bfloat16,
accum_dtype = T.float32
)
nemo_automodel.components.models.deepseek_v4.kernels.tilelang_sparse_mla_bwd.preprocess(
B,
S,
H,
D,
block_ND = 32,
num_stages = 5,
dtype = T.bfloat16,
accum_dtype = T.float32
)
nemo_automodel.components.models.deepseek_v4.kernels.tilelang_sparse_mla_bwd.sparse_mqa_bwd_interface(
q,
kv,
attn_sink,
o,
do,
topk_idxs,
lse,
sm_scale = None,
return_dkv_accum_dtype = False
)

Backward interface for V4 sparse MQA attention.

Parameters:

q

[B, S, H, D] bf16

kv

[B, S_kv, D] bf16

attn_sink

[H] fp32

o

[B, S, H, D] bf16 (forward output)

do

[B, S, H, D] bf16 (grad of output)

topk_idxs

[B, S, topk] int32

lse

[B, S, H] fp32 (log-sum-exp from forward)

sm_scale

float or None. Defaults to None.

Returns:

[B, S, H, D] bf16
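
The shape contract listed above is easy to get wrong when wiring a custom backward pass, so a small pre-flight check can help. The following is a minimal sketch, not part of nemo_automodel: `documented_shapes` and `check_shapes` are hypothetical helper names, and the code only encodes the shapes and dtypes stated in the parameter list. It does not call the kernel and uses only the Python standard library.

```python
from typing import Dict, List, Sequence, Tuple

ShapeSpec = Tuple[Tuple[int, ...], str]


def documented_shapes(B: int, S: int, S_kv: int, H: int, D: int, topk: int) -> Dict[str, ShapeSpec]:
    """Hypothetical helper: map each documented tensor argument to (shape, dtype).

    The entries mirror the parameter list of sparse_mqa_bwd_interface as
    documented; nothing here is taken from the kernel source itself.
    """
    return {
        "q": ((B, S, H, D), "bf16"),
        "kv": ((B, S_kv, D), "bf16"),
        "attn_sink": ((H,), "fp32"),
        "o": ((B, S, H, D), "bf16"),          # forward output
        "do": ((B, S, H, D), "bf16"),         # grad of output
        "topk_idxs": ((B, S, topk), "int32"),
        "lse": ((B, S, H), "fp32"),           # log-sum-exp from forward
    }


def check_shapes(
    actual: Dict[str, Sequence[int]],
    B: int, S: int, S_kv: int, H: int, D: int, topk: int,
) -> List[str]:
    """Return a list of mismatch messages; an empty list means every shape matches."""
    problems = []
    for name, (shape, _dtype) in documented_shapes(B, S, S_kv, H, D, topk).items():
        got = actual.get(name)
        if got is None:
            problems.append(f"{name}: missing")
        elif tuple(got) != shape:
            problems.append(f"{name}: expected {shape}, got {tuple(got)}")
    return problems
```

In practice, `actual` could be built from the `.shape` of each tensor before calling `sparse_mqa_bwd_interface`, for example `{"q": q.shape, "kv": kv.shape, ...}`. Dtype checks are left out of the sketch because how dtypes are represented depends on the tensor library in use.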