nemo_automodel.components.models.deepseek_v4.kernels.tilelang_sparse_mla_bwd

Module Contents

Functions

preprocess

postprocess

bwd

sparse_mqa_bwd_interface

Backward interface for V4 sparse MQA attention.

API

nemo_automodel.components.models.deepseek_v4.kernels.tilelang_sparse_mla_bwd.preprocess(
B,
S,
H,
D,
block_ND=32,
num_stages=5,
dtype=T.bfloat16,
accum_dtype=T.float32,
)
nemo_automodel.components.models.deepseek_v4.kernels.tilelang_sparse_mla_bwd.postprocess(
B,
S_kv,
D,
block_N=64,
threads=128,
dtype=T.bfloat16,
accum_dtype=T.float32,
)
nemo_automodel.components.models.deepseek_v4.kernels.tilelang_sparse_mla_bwd.bwd(
B,
S,
S_kv,
H,
D,
topk,
sm_scale=None,
block_size=32,
num_stages=0,
threads=128,
indices_dtype=T.int32,
dtype=T.bfloat16,
accum_dtype=T.float32,
)
nemo_automodel.components.models.deepseek_v4.kernels.tilelang_sparse_mla_bwd.sparse_mqa_bwd_interface(
q,
kv,
attn_sink,
o,
do,
topk_idxs,
lse,
sm_scale=None,
return_dkv_accum_dtype=False,
)

Backward interface for V4 sparse MQA attention.

Parameters:
  • q – [B, S, H, D] bf16

  • kv – [B, S_kv, D] bf16

  • attn_sink – [H] fp32

  • o – [B, S, H, D] bf16 (forward output)

  • do – [B, S, H, D] bf16 (grad of output)

  • topk_idxs – [B, S, topk] int32

  • lse – [B, S, H] fp32 (log-sum-exp from forward)

  • sm_scale – float or None

Returns:
  • dq – [B, S, H, D] bf16

  • dkv – [B, S_kv, D] bf16

  • d_attn_sink – [H] fp32
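The tensor layouts documented for `sparse_mqa_bwd_interface` can be checked before launching the kernel. The following is a minimal sketch in plain Python, not part of the module: `expected_grad_shapes` is a hypothetical helper that validates the documented shapes of `q`, `kv`, and `topk_idxs` and returns the shapes that the docstring lists for `dq`, `dkv`, and `d_attn_sink`. It only encodes the layouts stated above; it does not check dtypes or index ranges.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def expected_grad_shapes(
    q_shape: Shape,
    kv_shape: Shape,
    topk_idxs_shape: Shape,
) -> Dict[str, Shape]:
    """Hypothetical helper (not in the library): validate the documented
    input layouts and return the documented gradient shapes."""
    # q is documented as [B, S, H, D].
    if len(q_shape) != 4:
        raise ValueError(f"q must be [B, S, H, D], got {q_shape}")
    B, S, H, D = q_shape

    # kv is documented as [B, S_kv, D]: a single KV stream shared by all heads.
    if len(kv_shape) != 3:
        raise ValueError(f"kv must be [B, S_kv, D], got {kv_shape}")
    B_kv, S_kv, D_kv = kv_shape
    if (B_kv, D_kv) != (B, D):
        raise ValueError(f"kv {kv_shape} does not match q batch/dim {(B, D)}")

    # topk_idxs is documented as [B, S, topk].
    if len(topk_idxs_shape) != 3 or tuple(topk_idxs_shape[:2]) != (B, S):
        raise ValueError(f"topk_idxs must be [B, S, topk], got {topk_idxs_shape}")

    return {
        "dq": (B, S, H, D),       # same layout as q
        "dkv": (B, S_kv, D),      # same layout as kv
        "d_attn_sink": (H,),      # one value per head, like attn_sink
    }
```

With real tensors, the shapes would be passed as `tuple(q.shape)`, `tuple(kv.shape)`, and `tuple(topk_idxs.shape)`, and the result compared against the shapes of the returned gradients.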