nemo_automodel.components.models.deepseek_v4.kernels.tilelang_sparse_mla_bwd
Module Contents
Functions
sparse_mqa_bwd_interface | Backward interface for V4 sparse MQA attention.
API
- nemo_automodel.components.models.deepseek_v4.kernels.tilelang_sparse_mla_bwd.preprocess(
    B,
    S,
    H,
    D,
    block_ND=32,
    num_stages=5,
    dtype=T.bfloat16,
    accum_dtype=T.float32,
)
- nemo_automodel.components.models.deepseek_v4.kernels.tilelang_sparse_mla_bwd.postprocess(
    B,
    S_kv,
    D,
    block_N=64,
    threads=128,
    dtype=T.bfloat16,
    accum_dtype=T.float32,
)
- nemo_automodel.components.models.deepseek_v4.kernels.tilelang_sparse_mla_bwd.bwd(
    B,
    S,
    S_kv,
    H,
    D,
    topk,
    sm_scale=None,
    block_size=32,
    num_stages=0,
    threads=128,
    indices_dtype=T.int32,
    dtype=T.bfloat16,
    accum_dtype=T.float32,
)
- nemo_automodel.components.models.deepseek_v4.kernels.tilelang_sparse_mla_bwd.sparse_mqa_bwd_interface(
    q,
    kv,
    attn_sink,
    o,
    do,
    topk_idxs,
    lse,
    sm_scale=None,
    return_dkv_accum_dtype=False,
)
Backward interface for V4 sparse MQA attention.
- Parameters:
q – [B, S, H, D] bf16
kv – [B, S_kv, D] bf16
attn_sink – [H] fp32
o – [B, S, H, D] bf16 (forward output)
do – [B, S, H, D] bf16 (grad of output)
topk_idxs – [B, S, topk] int32
lse – [B, S, H] fp32 (log-sum-exp from forward)
sm_scale – float or None
- Returns:
dq – [B, S, H, D] bf16
dkv – [B, S_kv, D] bf16
d_attn_sink – [H] fp32
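
The shape contract listed above can be captured in a small pure-Python helper, which is handy for validating tensors before calling the kernel. This is only a sketch: `expected_shapes` and `check_shapes` are hypothetical names that are not part of this module, and the helper encodes nothing beyond the shapes documented in the parameter and return lists (dtypes and the effect of `return_dkv_accum_dtype` are not checked).

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def expected_shapes(B: int, S: int, S_kv: int, H: int, D: int, topk: int) -> Dict[str, Shape]:
    """Hypothetical helper: documented shapes for sparse_mqa_bwd_interface."""
    return {
        # Inputs
        "q": (B, S, H, D),
        "kv": (B, S_kv, D),
        "attn_sink": (H,),
        "o": (B, S, H, D),
        "do": (B, S, H, D),
        "topk_idxs": (B, S, topk),
        "lse": (B, S, H),
        # Outputs: each gradient has the shape of the input it belongs to
        "dq": (B, S, H, D),
        "dkv": (B, S_kv, D),
        "d_attn_sink": (H,),
    }


def check_shapes(actual: Dict[str, Shape], B: int, S: int, S_kv: int, H: int, D: int, topk: int) -> None:
    """Raise ValueError if any provided shape disagrees with the documented one."""
    expected = expected_shapes(B, S, S_kv, H, D, topk)
    for name, shape in actual.items():
        if name not in expected:
            raise KeyError(f"unknown tensor name: {name}")
        if tuple(shape) != expected[name]:
            raise ValueError(f"{name}: expected {expected[name]}, got {tuple(shape)}")
```

With real tensors, one would pass `tuple(t.shape)` for each argument, for example `check_shapes({"q": tuple(q.shape), "kv": tuple(kv.shape)}, B, S, S_kv, H, D, topk)`.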