SDPA Forward (SM100, D=256)#
This is an experimental API and subject to change.
Overview#
SDPA forward pass for D=256 on NVIDIA Blackwell GPUs (SM100+).
Computes the attention output O and log-sum-exp statistics LSE using a CUTE DSL implementation.
This is available through a standalone API (documented below) and is also used by the experimental SDPA PyTorch custom operator for supported D=256 BHSD cases.
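As a point of reference, the outputs correspond to the standard SDPA definition. A minimal PyTorch sketch for a single B,H,S,D batch with `H_q == H_k` (this is an illustration of the math, not the CUTE DSL kernel):

```python
import math
import torch

def sdpa_reference(q, k, v, scale_softmax=None, scale_output=1.0):
    """Reference SDPA: O = scale_output * softmax(scale * Q K^T) V and
    LSE = logsumexp(scale * Q K^T) over the key axis. Illustration only."""
    d = q.shape[-1]
    # Matches the documented default: 1/sqrt(D) when omitted or set to 0.0.
    scale = scale_softmax if scale_softmax else 1.0 / math.sqrt(d)
    s = torch.matmul(q, k.transpose(-2, -1)) * scale   # (B, H, S_q, S_k)
    lse = torch.logsumexp(s.float(), dim=-1)           # (B, H, S_q)
    p = torch.softmax(s.float(), dim=-1).to(v.dtype)
    o = torch.matmul(p, v) * scale_output              # (B, H, S_q, D)
    return o, lse
```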
API Usage#
High-level wrapper#
```python
import torch
from cudnn import sdpa_fwd_wrapper_sm100_d256

result = sdpa_fwd_wrapper_sm100_d256(
    q_tensor=q,
    k_tensor=k,
    v_tensor=v,
    cum_seqlen_q_tensor=cum_seqlen_q,  # None for B,H,S,D layout
    cum_seqlen_k_tensor=cum_seqlen_k,  # None for B,H,S,D layout
    max_s_q=1024,  # optional; inferred from cum_seqlen_q_tensor when None
    max_s_k=1024,  # optional; inferred from cum_seqlen_k_tensor when None
    qk_acc_dtype=torch.float32,
    pv_acc_dtype=torch.float32,
    mma_tiler_mn=(128, 128),
    is_causal=False,
    window_size=(-1, -1),
    scale_softmax=None,  # defaults to 1/sqrt(D)
    scale_output=1.0,
    current_stream=None,
)

o_tensor, lse_tensor = result
# Key access: result["o_tensor"], result["lse_tensor"]
```
Class API#
```python
import torch
from cudnn import SdpafwdSm100D256

sdpa_fwd = SdpafwdSm100D256(
    sample_q=q,
    sample_k=k,
    sample_v=v,
    sample_o=o,
    sample_lse=lse,
    sample_cum_seqlen_q=cum_seqlen_q,  # None for B,H,S,D layout
    sample_cum_seqlen_k=cum_seqlen_k,  # None for B,H,S,D layout
    max_s_q=1024,
    max_s_k=1024,
    qk_acc_dtype=torch.float32,
    pv_acc_dtype=torch.float32,
    mma_tiler_mn=(128, 128),
    is_causal=False,
    window_size=(-1, -1),
    scale_softmax=None,
    scale_output=1.0,
)

assert sdpa_fwd.check_support()
sdpa_fwd.compile()
sdpa_fwd.execute(
    q_tensor=q,
    k_tensor=k,
    v_tensor=v,
    o_tensor=o,
    lse_tensor=lse,
    cum_seqlen_q_tensor=cum_seqlen_q,
    cum_seqlen_k_tensor=cum_seqlen_k,
    scale_softmax=None,
    scale_output=1.0,
    current_stream=None,
)
```
Parameters#
Input / output tensors#
**Q**

- B,H,S,D layout: `(B, H_q, S_q, D)`
- T,H,D layout: `(T_q, H_q, D)` or `(1, T_q, H_q, D)`
- Dtype: `torch.float16` or `torch.bfloat16`

**K**

- B,H,S,D layout: `(B, H_k, S_k, D)`
- T,H,D layout: `(T_k, H_k, D)` or `(1, T_k, H_k, D)`
- Dtype: must match `Q`

**V**

- Same layout family and dtype constraints as `K`

**O**

- Same layout family and dtype constraints as `Q`

**LSE**

- B,H,S,D layout: `(B, H_q, S_q)`; must be contiguous
- T,H,D layout: `(T_q, H_q)` or `(T_q, H_q, 1)`
- Dtype: `torch.float32`
Varlen cumulative lengths (T,H,D layout only)

**cum_seqlen_q_tensor**, **cum_seqlen_k_tensor**

- Shape: `(batch_size + 1,)`
- Dtype: `torch.int32`
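The cumulative-length tensors are prefix sums over the per-sequence lengths with a leading zero, which is why their shape is `(batch_size + 1,)`. A sketch of how such tensors are typically built (the variable names are illustrative):

```python
import torch

# Per-sequence query lengths for a batch of 3 ragged sequences (illustrative).
seqlens_q = torch.tensor([5, 3, 8], dtype=torch.int32)

# (batch_size + 1,) int32 prefix sums with a leading zero: [0, 5, 8, 16].
cum_seqlen_q = torch.zeros(seqlens_q.numel() + 1, dtype=torch.int32)
cum_seqlen_q[1:] = torch.cumsum(seqlens_q, dim=0)

# The total token count T_q is the last entry; Q is then packed as (T_q, H_q, D).
t_q = int(cum_seqlen_q[-1])
```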
Common parameters#
**qk_acc_dtype**, **pv_acc_dtype**

- Supported: `torch.float32` only

**mma_tiler_mn**

- Supported: `(128, 128)` only

**is_causal**

- If `True`, the right window bound is forced to `0`

**window_size**

- Sliding-window bounds `(left, right)`
- Default: `(-1, -1)` (full window)

**scale_softmax**

- Defaults to `1 / sqrt(D)` when omitted or set to `0.0`

**scale_output**

- Output scaling applied after the attention reduction
- Default: `1.0`
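The causal and sliding-window options compose: `is_causal=True` forces the right window bound to `0`, and `-1` means unbounded on that side. A sketch of the mask these parameters describe (this illustrates the semantics, not the kernel's internal masking; for `s_q != s_k` it assumes the query and key indices align on the main diagonal):

```python
import torch

def window_mask(s_q, s_k, window_size=(-1, -1), is_causal=False):
    """Boolean mask (True = attend) for sliding-window / causal attention."""
    left, right = window_size
    if is_causal:
        right = 0  # causal mode forces the right window bound to 0
    q_idx = torch.arange(s_q).unsqueeze(1)  # (s_q, 1)
    k_idx = torch.arange(s_k).unsqueeze(0)  # (1, s_k)
    mask = torch.ones(s_q, s_k, dtype=torch.bool)
    if left >= 0:   # -1 means unbounded on the left
        mask &= k_idx >= q_idx - left
    if right >= 0:  # -1 means unbounded on the right
        mask &= k_idx <= q_idx + right
    return mask
```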
Wrapper return values#
Returns a `TupleDict` with keys `o_tensor` and `lse_tensor`. Tuple unpacking order is `(o_tensor, lse_tensor)`.
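Conceptually, `TupleDict` supports both positional unpacking and string-key access; a minimal sketch of that behavior (not the library's actual class):

```python
class TupleDictSketch(tuple):
    """Illustration of combined tuple and dict-style access."""
    _keys = ("o_tensor", "lse_tensor")

    def __getitem__(self, key):
        # String keys map to positions; integer indexing works as on a tuple.
        if isinstance(key, str):
            return tuple.__getitem__(self, self._keys.index(key))
        return tuple.__getitem__(self, key)

result = TupleDictSketch(("O", "LSE"))
o_tensor, lse_tensor = result            # positional unpacking
assert result["o_tensor"] == o_tensor    # key access
```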
Support surface and constraints#
- `head_dim` must be exactly `256`
- `H_q` must be divisible by `H_k`
- `D_qk` must equal `D_v`
- Requires NVIDIA Blackwell (SM100+)
- `window_size_left < s_k_max - 1`
- `window_size_right < s_q_max - 1`
- Non-causal mode currently supports only the full-window configuration: `window_size == (-1, -1)`
- The standalone API supports both B,H,S,D and T,H,D-style varlen inputs
- The experimental SDPA PyTorch custom op currently routes only plain B,H,S,D `D=256` cases through this OSS forward kernel
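The shape and window constraints above can be collected into a small pre-flight sketch. This mirrors what `check_support()` would reject; the helper name and error strings are illustrative, not part of the API:

```python
def check_sdpa_fwd_sm100_d256(head_dim, h_q, h_k, d_qk, d_v,
                              window_size, s_q_max, s_k_max, is_causal):
    """Illustrative validation mirroring the documented constraints."""
    assert head_dim == 256, "head_dim must be exactly 256"
    assert h_q % h_k == 0, "H_q must be divisible by H_k"
    assert d_qk == d_v, "D_qk must equal D_v"
    left, right = window_size
    assert left < s_k_max - 1, "window_size_left < s_k_max - 1"
    assert right < s_q_max - 1, "window_size_right < s_q_max - 1"
    if not is_causal:
        assert window_size == (-1, -1), \
            "non-causal mode supports only the full window (-1, -1)"
    return True
```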