SDPA Backward (SM100, D=256)#
This is an experimental API and subject to change.
Overview#
SDPA backward pass for D=256 on NVIDIA Blackwell GPUs (SM100).
Computes attention gradients (dQ, dK, dV) for scaled dot-product attention using a 2-kernel CUTE DSL implementation.
Given forward tensors and statistics (Q, K, V, O, dO, LSE), this API returns:
- dQ: gradient w.r.t. query
- dK: gradient w.r.t. key
- dV: gradient w.r.t. value
This is available through a standalone API (detailed below) or as part of the experimental SDPA PyTorch custom operator.
Acknowledgments#
This kernel was jointly developed by Shengbin Di, Yuxi Chi, and Linfeng Zheng in close collaboration with Alibaba. We would like to extend special thanks to the core contributors from Alibaba: Siyu Wang, Haoyan Huang, Lanbo Li, Yun Zhong, Man Yuan, Minmin Sun, Yong Li, and Wei Lin for their significant contributions to this work.
Equations#
For attention probabilities
\( P = \operatorname{softmax}(S), \quad S = \frac{QK^T}{\sqrt{D}} + \text{mask} \)
the backward pass follows:
\( dV = P^T dO,\quad dP = dO V^T,\quad dS = P \odot \left(dP - \sum(dP \odot P)\right),\quad dQ = dS K,\quad dK = dS^T Q \)
Here, the \( \sum(dP \odot P) \) term is reduced row-wise over the key/softmax dimension for each query row.
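As a sanity check, the equations above can be expressed as a naive single-head PyTorch reference (a sketch: it recomputes P directly instead of using the stored LSE, and the \(1/\sqrt{D}\) scale enters dQ and dK via the chain rule through S):

```python
import torch

def sdpa_bwd_reference(q, k, v, do, scale):
    """Naive reference for the backward equations above (single head)."""
    s = (q @ k.transpose(-2, -1)) * scale         # S = scale * Q K^T
    p = torch.softmax(s, dim=-1)                  # P = softmax(S)
    dv = p.transpose(-2, -1) @ do                 # dV = P^T dO
    dp = do @ v.transpose(-2, -1)                 # dP = dO V^T
    delta = (dp * p).sum(dim=-1, keepdim=True)    # row-wise sum(dP ⊙ P)
    ds = p * (dp - delta)                         # dS = P ⊙ (dP - delta)
    dq = (ds @ k) * scale                         # chain rule through S = scale * Q K^T
    dk = (ds.transpose(-2, -1) @ q) * scale
    return dq, dk, dv
```

This matches what PyTorch autograd produces for `softmax(q @ k.T * scale) @ v`, which is a convenient way to validate the formulas numerically.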
API Usage#
High-level wrapper#
import torch
from cudnn import sdpa_bwd_wrapper_sm100_d256
result = sdpa_bwd_wrapper_sm100_d256(
q_tensor=q,
k_tensor=k,
v_tensor=v,
o_tensor=o,
do_tensor=do,
lse_tensor=lse,
cum_seqlen_q_tensor=cum_seqlen_q, # None for B,H,S,D layout
cum_seqlen_k_tensor=cum_seqlen_k, # None for B,H,S,D layout
max_s_q=1024, # Optional; inferred from cum_seqlen_q_tensor when None
max_s_k=1024, # Optional; inferred from cum_seqlen_k_tensor when None
acc_dtype=torch.float32,
mma_tiler_mn=(128, 128),
dkdv_mma_tiler_mn=(128, 64),
is_causal=False,
window_size=(-1, -1),
scale_softmax=None, # Defaults to 1/sqrt(D)
current_stream=None,
)
dq, dk, dv = result
# Key access: result["dq_tensor"], result["dk_tensor"], result["dv_tensor"]
Class API#
import torch
from cudnn import SdpabwdSm100D256
sdpa_bwd = SdpabwdSm100D256(
sample_q=q,
sample_k=k,
sample_v=v,
sample_o=o,
sample_do=do,
sample_lse=lse,
sample_dq=dq,
sample_dk=dk,
sample_dv=dv,
sample_cum_seqlen_q=cum_seqlen_q, # None for B,H,S,D layout
sample_cum_seqlen_k=cum_seqlen_k, # None for B,H,S,D layout
max_s_q=1024, # Optional; inferred from sample_cum_seqlen_q when None
max_s_k=1024, # Optional; inferred from sample_cum_seqlen_k when None
acc_dtype=torch.float32,
mma_tiler_mn=(128, 128),
dkdv_mma_tiler_mn=(128, 64),
is_causal=False,
window_size=(-1, -1),
scale_softmax=None,
)
assert sdpa_bwd.check_support()
sdpa_bwd.compile()
sdpa_bwd.execute(
q_tensor=q,
k_tensor=k,
v_tensor=v,
o_tensor=o,
do_tensor=do,
lse_tensor=lse,
dq_tensor=dq,
dk_tensor=dk,
dv_tensor=dv,
cum_seqlen_q_tensor=cum_seqlen_q, # None for B,H,S,D layout
cum_seqlen_k_tensor=cum_seqlen_k, # None for B,H,S,D layout
scale_softmax=None,
current_stream=None,
)
Parameters#
Input/Output tensors#
Input tensor Q:
- q_tensor (wrapper) or sample_q / q_tensor (class)
- B,H,S,D layout: shape (B, H_q, S_q, D)
- T,H,D layout: shape (T_q, H_q, D) (or (1, T_q, H_q, D))
- Dtype: {torch.float16, torch.bfloat16}
Input tensor K:
- k_tensor (wrapper) or sample_k / k_tensor (class)
- B,H,S,D layout: shape (B, H_k, S_k, D)
- T,H,D layout: shape (T_k, H_k, D) (or (1, T_k, H_k, D))
- Dtype: must match Q
Input tensor V:
- v_tensor (wrapper) or sample_v / v_tensor (class)
- B,H,S,D layout: shape (B, H_k, S_k, D)
- T,H,D layout: shape (T_k, H_k, D) (or (1, T_k, H_k, D))
- Dtype: must match Q
Input tensor O:
- o_tensor (wrapper) or sample_o / o_tensor (class)
- Same layout and shape family as Q
- Dtype: must match Q
Input tensor dO:
- do_tensor (wrapper) or sample_do / do_tensor (class)
- Same layout and shape family as Q
- Dtype: must match Q
Input tensor LSE:
- lse_tensor (wrapper) or sample_lse / lse_tensor (class)
- B,H,S,D layout: shape (B, H_q, S_q) (must be contiguous)
- T,H,D layout: shape (T_q, H_q) (or (T_q, H_q, 1))
- Dtype: torch.float32
Output tensor dQ:
- Wrapper: result["dq_tensor"]
- Class: sample_dq / dq_tensor
- Shape family: same as Q
- Dtype: must match Q
Output tensor dK:
- Wrapper: result["dk_tensor"]
- Class: sample_dk / dk_tensor
- Shape family: same as K
- Dtype: must match Q
Output tensor dV:
- Wrapper: result["dv_tensor"]
- Class: sample_dv / dv_tensor
- Shape family: same as V
- Dtype: must match Q
Varlen cumulative lengths (T,H,D layout only):
- cum_seqlen_q_tensor, cum_seqlen_k_tensor
- Shape: (batch_size + 1,)
- Dtype: torch.int32
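For the T,H,D (varlen) layout, the cumulative-length tensors can be built from per-sequence lengths like this (the lengths shown are hypothetical example values):

```python
import itertools
import torch

seq_lens_q = [128, 256, 64]  # per-batch query sequence lengths (example values)

# Prefix sums with a leading 0 give the required (batch_size + 1,) int32 tensor:
cum_seqlen_q = torch.tensor(
    [0] + list(itertools.accumulate(seq_lens_q)),
    dtype=torch.int32,
)

# The packed Q tensor then has shape (T_q, H_q, D), where T_q = cum_seqlen_q[-1].
```

The same construction applies to cum_seqlen_k from the per-batch key lengths.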
Common parameters#
- acc_dtype (torch.dtype): accumulator dtype. Supported: torch.float32 only
- mma_tiler_mn (Tuple[int, int]): dQ kernel MMA tile. Supported: (128, 128) only
- dkdv_mma_tiler_mn (Tuple[int, int]): dK/dV kernel MMA tile. Supported: (128, 64) only
- is_causal (bool): enables causal masking. If True, the right window bound is forced to 0
- window_size (Tuple[int, int]): sliding-window bounds (left, right). Default: (-1, -1) (full window)
- scale_softmax (Optional[float]): softmax scale. Default: 1/sqrt(D) when omitted or set to 0.0
- current_stream (CUDA stream, in both wrapper and class API). Default: None (uses the default stream). When provided, both workspace initialization and kernel launch are enqueued on that stream
Wrapper-specific parameters: sdpa_bwd_wrapper_sm100_d256#
- q_tensor, k_tensor, v_tensor, o_tensor, do_tensor, lse_tensor
- cum_seqlen_q_tensor, cum_seqlen_k_tensor
- max_s_q, max_s_k (Optional[int], default None; for varlen mode these are inferred from cumulative lengths when omitted)
- Common parameters listed above
Wrapper return values#
Returns a TupleDict with keys:
- dq_tensor
- dk_tensor
- dv_tensor
Tuple unpacking order is: (dq_tensor, dk_tensor, dv_tensor).
Class-specific parameters: SdpabwdSm100D256#
SdpabwdSm100D256 (constructor)#
- sample_q, sample_k, sample_v, sample_o, sample_do, sample_lse
- sample_dq, sample_dk, sample_dv
- sample_cum_seqlen_q, sample_cum_seqlen_k
- max_s_q, max_s_k (Optional[int], default None; for varlen mode these are inferred from sample cumulative lengths when omitted)
- Common parameters listed above
SdpabwdSm100D256.execute#
- Runtime tensors for all inputs/outputs listed above
- scale_softmax, current_stream
Support Surface and Constraints#
Shapes and head geometry#
- D_qk must equal D_v
- head_dim (D) must be exactly 256
- H_q must be divisible by H_k (supports GQA/MQA-style head grouping)
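The GQA/MQA grouping implied by the divisibility rule can be sketched as follows (the head counts are hypothetical; each group of H_q // H_k consecutive query heads shares one KV head):

```python
# Hypothetical head counts satisfying the H_q % H_k == 0 rule:
H_q, H_k = 32, 8          # 32 query heads, 8 shared KV heads
group = H_q // H_k        # 4 query heads per KV head

# Query head i attends against KV head i // group:
kv_head = [q_head // group for q_head in range(H_q)]
```

With H_k == 1 this reduces to MQA (all query heads share a single KV head); with H_q == H_k it is standard multi-head attention.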
Layout and stride constraints#
B,H,S,D mode (Q/K/V/O/dO/dQ/dK/dV):
- Both cum_seqlen_q_tensor and cum_seqlen_k_tensor must be None
- Tensors must be rank-4 with (B, H, S, D) shape semantics
- Stride order must match s,h,d,b (as produced by views like (B, S, H, D).transpose(1, 2))
- LSE must be contiguous
T,H,D mode (Q/K/V/O/dO/dQ/dK/dV):
- Both cum_seqlen_q_tensor and cum_seqlen_k_tensor must be provided
- Tensors must be rank-3 (T, H, D) or rank-4 (1, T, H, D)
- If rank-4 is used, the batch dimension must be 1
- If constructor max_s_q/max_s_k are provided, they must be greater than or equal to the maxima implied by sample cumulative lengths
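For B,H,S,D mode, a tensor with the required stride order can be produced exactly as the constraint describes, by allocating in (B, S, H, D) memory order and transposing (a sketch with small illustrative sizes; real inputs must be CUDA tensors with D = 256):

```python
import torch

B, H, S, D = 2, 4, 16, 256  # illustrative sizes; real inputs need D = 256 and a CUDA device

# Allocate contiguously as (B, S, H, D), then view as (B, H, S, D):
q = torch.empty(B, S, H, D, dtype=torch.bfloat16).transpose(1, 2)

# The shape is (B, H, S, D), but the underlying memory keeps the
# (B, S, H, D) layout that the stride-order check expects:
print(q.shape)    # (B, H, S, D)
print(q.stride()) # (S*H*D, D, H*D, 1)
```

A `.contiguous()` call after the transpose would destroy this layout, so build the view last.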
Dtypes#
- Q, K, V, O, dO, dQ, dK, dV must have the same dtype. Supported dtypes: torch.float16, torch.bfloat16
- LSE and acc_dtype must be torch.float32
- cum_seqlen_q_tensor and cum_seqlen_k_tensor must be torch.int32 and have matching shape
Window/mask behavior#
- window_size_left < s_k_max - 1
- window_size_right < s_q_max - 1
- Non-causal mode currently supports only the full-window configuration: window_size == (-1, -1)
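The window semantics can be illustrated with a mask sketch (the helper name is hypothetical, s_q == s_k is assumed for simplicity, and a negative bound is treated as unbounded on that side):

```python
import torch

def sliding_window_keep(s, left, right):
    """Boolean keep-mask: query i attends to key j iff i - left <= j <= i + right.
    A negative bound means 'unbounded' on that side."""
    i = torch.arange(s).unsqueeze(1)            # query positions, column vector
    j = torch.arange(s).unsqueeze(0)            # key positions, row vector
    keep = torch.ones(s, s, dtype=torch.bool)
    if left >= 0:
        keep &= j >= i - left                   # left window bound
    if right >= 0:
        keep &= j <= i + right                  # right window bound
    return keep
```

Under this reading, `window_size=(-1, -1)` keeps everything (full window), and `is_causal=True` forces the right bound to 0, yielding a lower-triangular mask.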
Hardware requirements#
- CUDA must be available
- All tensors must be CUDA tensors on the same device (including cum_seqlen_q_tensor/cum_seqlen_k_tensor when provided)
- Requires SM100+ compute capability
- SM103 is not supported
Usage Examples#
For runnable examples and reference-comparison checks, see:
- test/python/fe_api/test_sdpa_bwd.py
- test/python/fe_api/test_sdpa_bwd_utils.py