fused_attn.h
Enums and functions for fused attention.
Enums
-
enum NVTE_QKV_Layout
Memory layouts of QKV tensors.
S, B, H, D, and T stand for sequence length, batch size, number of heads, head size, and the total number of tokens in a batch, i.e. t = sum(s_i) for i = 0...b-1. SBHD- and BSHD-based layouts are used when sequences in a batch are of equal length or padded to the same length; THD-based layouts are used when sequences in a batch have different lengths.
Values:
-
enumerator NVTE_SB3HD
SB3HD layout
-
enumerator NVTE_SBH3D
SBH3D layout
-
enumerator NVTE_SBHD_SB2HD
SBHD_SB2HD layout
-
enumerator NVTE_SBHD_SBH2D
SBHD_SBH2D layout
-
enumerator NVTE_SBHD_SBHD_SBHD
SBHD_SBHD_SBHD layout
-
enumerator NVTE_BS3HD
BS3HD layout
-
enumerator NVTE_BSH3D
BSH3D layout
-
enumerator NVTE_BSHD_BS2HD
BSHD_BS2HD layout
-
enumerator NVTE_BSHD_BSH2D
BSHD_BSH2D layout
-
enumerator NVTE_BSHD_BSHD_BSHD
BSHD_BSHD_BSHD layout
-
enumerator NVTE_T3HD
T3HD layout
-
enumerator NVTE_TH3D
TH3D layout
-
enumerator NVTE_THD_T2HD
THD_T2HD layout
-
enumerator NVTE_THD_TH2D
THD_TH2D layout
-
enumerator NVTE_THD_THD_THD
THD_THD_THD layout
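As a concrete illustration of these layouts (a sketch only, not part of the API), the helpers below compute the flat offset of element (b, s, h, d) for contiguous BSHD, SBHD and THD storage. The interleaved variants (e.g. SB3HD) insert an extra dimension of size 3, shared by Q, K and V, at the position of the 3 in the name.

```c
#include <stddef.h>

/* BSHD: batch-major, index order [b][s][h][d]. */
static size_t bshd_offset(size_t b, size_t s, size_t h, size_t d,
                          size_t S, size_t H, size_t D) {
  return ((b * S + s) * H + h) * D + d;
}

/* SBHD: sequence-major, index order [s][b][h][d]. */
static size_t sbhd_offset(size_t b, size_t s, size_t h, size_t d,
                          size_t B, size_t H, size_t D) {
  return ((s * B + b) * H + h) * D + d;
}

/* THD: sequences packed back to back; token t of sequence i sits at flat
 * token index cu_seqlens[i] + t, index order [t][h][d]. */
static size_t thd_offset(size_t token, size_t h, size_t d,
                         size_t H, size_t D) {
  return (token * H + h) * D + d;
}
```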
-
enum NVTE_QKV_Layout_Group
QKV layout groups.
Values:
-
enumerator NVTE_3HD
3HD QKV layouts, i.e. BS3HD, SB3HD, T3HD
-
enumerator NVTE_H3D
H3D QKV layouts, i.e. BSH3D, SBH3D, TH3D
-
enumerator NVTE_HD_2HD
HD_2HD QKV layouts, i.e. BSHD_BS2HD, SBHD_SB2HD, THD_T2HD
-
enumerator NVTE_HD_H2D
HD_H2D QKV layouts, i.e. BSHD_BSH2D, SBHD_SBH2D, THD_TH2D
-
enumerator NVTE_HD_HD_HD
HD_HD_HD QKV layouts, i.e. BSHD_BSHD_BSHD, SBHD_SBHD_SBHD, THD_THD_THD
-
enum NVTE_QKV_Format
QKV formats.
Values:
-
enumerator NVTE_SBHD
SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD
-
enumerator NVTE_BSHD
BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD
-
enumerator NVTE_THD
THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD
-
enum NVTE_Bias_Type
Bias types.
Values:
-
enumerator NVTE_NO_BIAS
No bias
-
enumerator NVTE_PRE_SCALE_BIAS
Bias before scale
-
enumerator NVTE_POST_SCALE_BIAS
Bias after scale
-
enumerator NVTE_ALIBI
ALiBi
-
enum NVTE_Mask_Type
Attention mask types.
Values:
-
enumerator NVTE_NO_MASK
No masking
-
enumerator NVTE_PADDING_MASK
Padding attention mask
-
enumerator NVTE_CAUSAL_MASK
Causal attention mask (aligned to the top left corner)
-
enumerator NVTE_PADDING_CAUSAL_MASK
Padding and causal attention mask (aligned to the top left corner)
-
enumerator NVTE_CAUSAL_BOTTOM_RIGHT_MASK
Causal attention mask (aligned to the bottom right corner)
-
enumerator NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK
Padding and causal attention mask (aligned to the bottom right corner)
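The two causal alignments differ only when the Q and KV sequence lengths differ. As a minimal sketch (assumptions mine, matching the corner descriptions above), with row i indexing Q positions and column j indexing KV positions of the sq x skv score matrix: for sq = 2 and skv = 4, the top-left alignment keeps [1 0 0 0; 1 1 0 0] while the bottom-right alignment keeps [1 1 1 0; 1 1 1 1].

```c
#include <stdbool.h>

/* Entry (i, j) survives the top-left-aligned causal mask. */
static bool causal_top_left(int i, int j) { return j <= i; }

/* Entry (i, j) survives the bottom-right-aligned causal mask: the diagonal
 * is shifted so that it ends in the bottom-right corner of the matrix. */
static bool causal_bottom_right(int i, int j, int sq, int skv) {
  return j <= i + (skv - sq);
}
```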
-
enum NVTE_Fused_Attn_Backend
Fused attention backends.
Values:
-
enumerator NVTE_No_Backend
No supported backend
-
enumerator NVTE_F16_max512_seqlen
cuDNN-based FP16/BF16 fused attention for <= 512 sequence length
-
enumerator NVTE_F16_arbitrary_seqlen
cuDNN-based FP16/BF16 fused attention for any sequence length
-
enumerator NVTE_FP8
cuDNN-based FP8 fused attention for <= 512 sequence length
Functions
-
NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout)
Get QKV layout group for a given QKV layout.
- Parameters:
qkv_layout – [in] QKV layout, e.g. sbh3d.
- Returns:
qkv layout group, e.g. h3d.
-
NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout)
Get QKV format for a given QKV layout.
- Parameters:
qkv_layout – [in] QKV layout, e.g. sbh3d.
- Returns:
qkv format, e.g. sbhd.
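For example, assuming the enum values defined above, the two helpers classify a layout as follows:

```c
NVTE_QKV_Layout layout = NVTE_SBH3D;
NVTE_QKV_Layout_Group group = nvte_get_qkv_layout_group(layout);  /* NVTE_H3D  */
NVTE_QKV_Format format = nvte_get_qkv_format(layout);             /* NVTE_SBHD */
```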
-
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right)
Get fused attention backend based on input parameters.
- Parameters:
q_dtype – [in] The data type of Tensor Q.
kv_dtype – [in] The data type of Tensors K, V.
qkv_layout – [in] The layout of Tensors Q, K, V.
bias_type – [in] The attention bias type.
attn_mask_type – [in] The attention mask type.
dropout – [in] The dropout probability.
num_attn_heads – [in] The number of heads in Q.
num_gqa_groups – [in] The number of heads in K, V.
max_seqlen_q – [in] The sequence length of Q.
max_seqlen_kv – [in] The sequence length of K, V.
head_dim_qk – [in] The head dimension of Q, K.
head_dim_v – [in] The head dimension of V.
window_size_left – [in] Sliding window size (the left half).
window_size_right – [in] Sliding window size (the right half).
- Returns:
The selected fused attention backend; NVTE_No_Backend means no backend supports the given configuration.
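A minimal usage sketch, assuming the NVTEDType enumerators from transformer_engine.h (e.g. kNVTEBFloat16) and the common convention that a window size of -1 disables the sliding window:

```c
NVTE_Fused_Attn_Backend backend = nvte_get_fused_attn_backend(
    kNVTEBFloat16, kNVTEBFloat16,    /* q_dtype, kv_dtype */
    NVTE_SB3HD,                      /* qkv_layout */
    NVTE_NO_BIAS, NVTE_CAUSAL_MASK,  /* bias_type, attn_mask_type */
    0.1f,                            /* dropout */
    16, 16,                          /* num_attn_heads, num_gqa_groups */
    2048, 2048,                      /* max_seqlen_q, max_seqlen_kv */
    64, 64,                          /* head_dim_qk, head_dim_v */
    -1, -1);                         /* window_size_left, window_size_right */
if (backend == NVTE_No_Backend) {
  /* No fused backend supports this configuration; fall back elsewhere. */
}
```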
-
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream)
Compute dot product attention with packed QKV input.
Computes:
P = Q * Transpose(K) + Bias
S = ScaleMaskSoftmax(P)
D = Dropout(S)
O = D * Transpose(V)
Support Matrix:
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
|---|---|---|---|---|---|---|---|
| 0 | FP16/BF16 | BS3HD,SB3HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
| 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
| 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
Notes:
Tensor cu_seqlens_padded helps identify the correct offsets of different sequences in tensors Q, K, V and O. When the QKV format (nvte_get_qkv_format(qkv_layout)) is bshd or sbhd, the offset tensor is not used in the attention calculation and can be set to an empty NVTETensor. When the QKV format is thd, this tensor should follow these rules: when there is no padding between sequences, the offset tensor should be equal to cu_seqlens; when there is padding between sequences, users are responsible for adjusting the offsets as needed. For example, a tensor of 4 sequences [a, PAD, b, b, c, PAD, PAD, d, d] should have cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].
- Parameters:
QKV – [in] The QKV tensor in packed format, H3D or 3HD.
Bias – [in] The Bias tensor.
S – [inout] The S tensor.
O – [out] The output O tensor.
Aux_CTX_Tensors – [out] Auxiliary output tensors when training, e.g. M, ZInv, rng_state.
cu_seqlens – [in] Cumulative sequence lengths, [batch_size + 1].
cu_seqlens_padded – [in] Cumulative sequence offsets for QKV, [batch_size + 1].
rng_state – [in] Seed and offset of CUDA random number generator.
max_seqlen – [in] Max sequence length used for computing; it may be >= max(seqlen_i) for i = 0, ..., batch_size-1.
is_training – [in] Whether this is in training mode or inference.
attn_scale – [in] Scaling factor for Q * K.T.
dropout – [in] Dropout probability.
qkv_layout – [in] QKV tensor’s layout.
bias_type – [in] Bias type.
attn_mask_type – [in] Attention mask type.
window_size_left – [in] Sliding window size (the left half).
window_size_right – [in] Sliding window size (the right half).
workspace – [in] Workspace tensor.
stream – [in] CUDA stream used for this operation.
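A host-side sketch of assembling the offset tensors described in the notes above; the resulting arrays must be copied to device memory before being wrapped in NVTETensors. With seqlens = {1, 2, 1, 2} and seqlens_padded = {2, 2, 3, 2} this reproduces the cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9] example:

```c
#include <stdint.h>

/* Prefix-sum the actual and padded per-sequence lengths into the
 * [batch_size + 1] cumulative arrays expected by the fused-attention API. */
void build_cu_seqlens(const int32_t *seqlens, const int32_t *seqlens_padded,
                      int32_t batch_size,
                      int32_t *cu_seqlens, int32_t *cu_seqlens_padded) {
  cu_seqlens[0] = 0;
  cu_seqlens_padded[0] = 0;
  for (int32_t i = 0; i < batch_size; ++i) {
    cu_seqlens[i + 1] = cu_seqlens[i] + seqlens[i];
    cu_seqlens_padded[i + 1] = cu_seqlens_padded[i] + seqlens_padded[i];
  }
}
```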
-
void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream)
Compute the backward of the dot product attention with packed QKV input.
Support Matrix:
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
|---|---|---|---|---|---|---|---|
| 0 | FP16/BF16 | BS3HD,SB3HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
| 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
| 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
Notes:
Tensor cu_seqlens_padded helps identify the correct offsets of different sequences in tensors Q, K, V and O. When the QKV format (nvte_get_qkv_format(qkv_layout)) is bshd or sbhd, the offset tensor is not used in the attention calculation and can be set to an empty NVTETensor. When the QKV format is thd, this tensor should follow these rules: when there is no padding between sequences, the offset tensor should be equal to cu_seqlens; when there is padding between sequences, users are responsible for adjusting the offsets as needed. For example, a tensor of 4 sequences [a, PAD, b, b, c, PAD, PAD, d, d] should have cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].
- Parameters:
QKV – [in] The QKV tensor in packed format, H3D or 3HD.
O – [in] The O tensor from forward.
dO – [in] The gradient of the O tensor.
S – [in] The S tensor.
dP – [inout] The gradient of the P tensor.
Aux_CTX_Tensors – [in] Auxiliary tensors from context when in training mode, e.g. M, ZInv, rng_state.
dQKV – [out] The gradient of the QKV tensor.
dBias – [out] The gradient of the Bias tensor.
cu_seqlens – [in] Cumulative sequence lengths, [batch_size + 1].
cu_seqlens_padded – [in] Cumulative sequence offsets for QKV, [batch_size + 1].
max_seqlen – [in] Max sequence length used for computing; it may be >= max(seqlen_i) for i = 0, ..., batch_size-1.
attn_scale – [in] Scaling factor for Q * K.T.
dropout – [in] Dropout probability.
qkv_layout – [in] QKV tensor’s layout.
bias_type – [in] Bias type.
attn_mask_type – [in] Attention mask type.
window_size_left – [in] Sliding window size (the left half).
window_size_right – [in] Sliding window size (the right half).
deterministic – [in] Whether to execute with deterministic behavior.
workspace – [in] Workspace tensor.
stream – [in] CUDA stream used for this operation.
-
void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream)
Compute dot product attention with packed KV input.
Computes:
P = Q * Transpose(K) + Bias
S = ScaleMaskSoftmax(P)
D = Dropout(S)
O = D * Transpose(V)
Support Matrix:
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
|---|---|---|---|---|---|---|---|
| 0 | FP16/BF16 | BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
| 1 | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
Notes:
Tensors cu_seqlens_q_padded and cu_seqlens_kv_padded help identify the correct offsets of different sequences in tensors Q, K, V and O. When the QKV format (nvte_get_qkv_format(qkv_layout)) is bshd or sbhd, the offset tensors are not used in the attention calculation and can be set to empty NVTETensors. When the QKV format is thd, these tensors should follow these rules: when there is no padding between sequences, the offset tensors should be equal to cu_seqlens_q and cu_seqlens_kv respectively; when there is padding between sequences, users are responsible for adjusting the offsets as needed. For example, a tensor of 4 sequences [a, PAD, b, b, c, PAD, PAD, d, d] should have cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].
- Parameters:
Q – [in] The Q tensor, in HD layouts.
KV – [in] The KV tensor, in 2HD or H2D layouts.
Bias – [in] The Bias tensor.
S – [inout] The S tensor.
O – [out] The output O tensor.
Aux_CTX_Tensors – [out] Auxiliary output tensors when training, e.g. M, ZInv, rng_state.
cu_seqlens_q – [in] Cumulative sequence lengths for Q, [batch_size + 1].
cu_seqlens_kv – [in] Cumulative sequence lengths for KV, [batch_size + 1].
cu_seqlens_q_padded – [in] Cumulative sequence offsets for Q, [batch_size + 1].
cu_seqlens_kv_padded – [in] Cumulative sequence offsets for KV, [batch_size + 1].
rng_state – [in] Seed and offset of CUDA random number generator.
max_seqlen_q – [in] Max sequence length used for computing for Q; it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.
max_seqlen_kv – [in] Max sequence length used for computing for KV; it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.
is_training – [in] Whether this is in training mode or inference.
attn_scale – [in] Scaling factor for Q * K.T.
dropout – [in] Dropout probability.
qkv_layout – [in] QKV tensor’s layout.
bias_type – [in] Bias type.
attn_mask_type – [in] Attention mask type.
window_size_left – [in] Sliding window size (the left half).
window_size_right – [in] Sliding window size (the right half).
workspace – [in] Workspace tensor.
stream – [in] CUDA stream used for this operation.
-
void nvte_fused_attn_bwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream)
Compute the backward of the dot product attention with packed KV input.
Support Matrix:
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
|---|---|---|---|---|---|---|---|
| 0 | FP16/BF16 | BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
| 1 | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
Notes:
Tensors cu_seqlens_q_padded and cu_seqlens_kv_padded help identify the correct offsets of different sequences in tensors Q, K, V and O. When the QKV format (nvte_get_qkv_format(qkv_layout)) is bshd or sbhd, the offset tensors are not used in the attention calculation and can be set to empty NVTETensors. When the QKV format is thd, these tensors should follow these rules: when there is no padding between sequences, the offset tensors should be equal to cu_seqlens_q and cu_seqlens_kv respectively; when there is padding between sequences, users are responsible for adjusting the offsets as needed. For example, a tensor of 4 sequences [a, PAD, b, b, c, PAD, PAD, d, d] should have cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].
- Parameters:
Q – [in] The Q tensor, in HD layouts.
KV – [in] The KV tensor, in H2D or 2HD layouts.
O – [in] The O tensor from forward.
dO – [in] The gradient of the O tensor.
S – [in] The S tensor.
dP – [inout] The gradient of the P tensor.
Aux_CTX_Tensors – [in] Auxiliary tensors from context when in training mode, e.g. M, ZInv, rng_state.
dQ – [out] The gradient of the Q tensor.
dKV – [out] The gradient of the KV tensor.
dBias – [out] The gradient of the Bias tensor.
cu_seqlens_q – [in] Cumulative sequence lengths for Q, [batch_size + 1].
cu_seqlens_kv – [in] Cumulative sequence lengths for KV, [batch_size + 1].
cu_seqlens_q_padded – [in] Cumulative sequence offsets for Q, [batch_size + 1].
cu_seqlens_kv_padded – [in] Cumulative sequence offsets for KV, [batch_size + 1].
max_seqlen_q – [in] Max sequence length used for computing for Q; it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.
max_seqlen_kv – [in] Max sequence length used for computing for KV; it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.
attn_scale – [in] Scaling factor for Q * K.T.
dropout – [in] Dropout probability.
qkv_layout – [in] QKV tensor’s layout.
bias_type – [in] Bias type.
attn_mask_type – [in] Attention mask type.
window_size_left – [in] Sliding window size (the left half).
window_size_right – [in] Sliding window size (the right half).
deterministic – [in] Whether to execute with deterministic behavior.
workspace – [in] Workspace tensor.
stream – [in] CUDA stream used for this operation.
-
void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream)
Compute dot product attention with separate Q, K and V.
Computes:
P = Q * Transpose(K) + Bias
S = ScaleMaskSoftmax(P)
D = Dropout(S)
O = D * Transpose(V)
Support Matrix:
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
|---|---|---|---|---|---|---|---|
| 0 | FP16/BF16 | BS3HD,SB3HD,BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
| 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D, BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D, BSHD_BSHD_BSHD,SBHD_SBHD_SBHD | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
| 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
Notes:
Tensors cu_seqlens_q_padded and cu_seqlens_kv_padded help identify the correct offsets of different sequences in tensors Q, K, V and O. When the QKV format (nvte_get_qkv_format(qkv_layout)) is bshd or sbhd, the offset tensors are not used in the attention calculation and can be set to empty NVTETensors. When the QKV format is thd, these tensors should follow these rules: when there is no padding between sequences, the offset tensors should be equal to cu_seqlens_q and cu_seqlens_kv respectively; when there is padding between sequences, users are responsible for adjusting the offsets as needed. For example, a tensor of 4 sequences [a, PAD, b, b, c, PAD, PAD, d, d] should have cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].
- Parameters:
Q – [in] The Q tensor.
K – [in] The K tensor.
V – [in] The V tensor.
Bias – [in] The Bias tensor.
S – [inout] The S tensor.
O – [out] The output O tensor.
Aux_CTX_Tensors – [out] Auxiliary output tensors when training, e.g. M, ZInv, rng_state.
cu_seqlens_q – [in] Cumulative sequence lengths for Q, [batch_size + 1].
cu_seqlens_kv – [in] Cumulative sequence lengths for K and V, [batch_size + 1].
cu_seqlens_q_padded – [in] Cumulative sequence offsets for Q, [batch_size + 1].
cu_seqlens_kv_padded – [in] Cumulative sequence offsets for KV, [batch_size + 1].
rng_state – [in] Seed and offset of CUDA random number generator.
max_seqlen_q – [in] Max sequence length used for computing for Q; it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.
max_seqlen_kv – [in] Max sequence length used for computing for K and V; it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.
is_training – [in] Whether this is in training mode or inference.
attn_scale – [in] Scaling factor for Q * K.T.
dropout – [in] Dropout probability.
qkv_layout – [in] QKV tensors’ layout.
bias_type – [in] Bias type.
attn_mask_type – [in] Attention mask type.
window_size_left – [in] Sliding window size (the left half).
window_size_right – [in] Sliding window size (the right half).
workspace – [in] Workspace tensor.
stream – [in] CUDA stream used for this operation.
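A hedged end-to-end sketch of one forward call. All NVTETensor handles are assumed to have been created elsewhere with device buffers of the proper shapes; nvte_tensor_pack_create/nvte_tensor_pack_destroy are the NVTETensorPack helpers declared in this header, and the library's usual two-call workspace convention (a first call with an empty workspace tensor only fills in the required workspace shape and dtype) is assumed to apply:

```c
NVTETensorPack aux_ctx;
nvte_tensor_pack_create(&aux_ctx);

/* Sketch only: q, k, v, bias, s, o, the cu_seqlens tensors, rng_state,
 * workspace and stream are assumed valid. */
nvte_fused_attn_fwd(
    q, k, v, bias,                   /* bias may be empty for NVTE_NO_BIAS */
    s, o, &aux_ctx,
    cu_seqlens_q, cu_seqlens_kv,
    cu_seqlens_q_padded,             /* empty for bshd/sbhd formats */
    cu_seqlens_kv_padded,
    rng_state,
    /*max_seqlen_q=*/2048, /*max_seqlen_kv=*/2048,
    /*is_training=*/true,
    /*attn_scale=*/0.125f,           /* typically 1/sqrt(head_dim_qk) */
    /*dropout=*/0.1f,
    NVTE_BSHD_BSHD_BSHD, NVTE_NO_BIAS, NVTE_PADDING_CAUSAL_MASK,
    /*window_size_left=*/-1, /*window_size_right=*/-1,
    workspace, stream);

nvte_tensor_pack_destroy(&aux_ctx);
```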
-
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream)
Compute the backward of the dot product attention with separate Q, K and V.
Support Matrix:
| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
|---|---|---|---|---|---|---|---|
| 0 | FP16/BF16 | BS3HD,SB3HD,BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
| 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D, BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D, BSHD_BSHD_BSHD,SBHD_SBHD_SBHD | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
| 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |
Notes:
Tensors cu_seqlens_q_padded and cu_seqlens_kv_padded help identify the correct offsets of different sequences in tensors Q, K, V and O. When the QKV format (nvte_get_qkv_format(qkv_layout)) is bshd or sbhd, the offset tensors are not used in the attention calculation and can be set to empty NVTETensors. When the QKV format is thd, these tensors should follow these rules: when there is no padding between sequences, the offset tensors should be equal to cu_seqlens_q and cu_seqlens_kv respectively; when there is padding between sequences, users are responsible for adjusting the offsets as needed. For example, a tensor of 4 sequences [a, PAD, b, b, c, PAD, PAD, d, d] should have cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].
- Parameters:
Q – [in] The Q tensor.
K – [in] The K tensor.
V – [in] The V tensor.
O – [in] The O tensor from forward.
dO – [in] The gradient of the O tensor.
S – [in] The S tensor.
dP – [inout] The gradient of the P tensor.
Aux_CTX_Tensors – [in] Auxiliary tensors from context when in training mode, e.g. M, ZInv, rng_state.
dQ – [out] The gradient of the Q tensor.
dK – [out] The gradient of the K tensor.
dV – [out] The gradient of the V tensor.
dBias – [out] The gradient of the Bias tensor.
cu_seqlens_q – [in] Cumulative sequence lengths for Q, [batch_size + 1].
cu_seqlens_kv – [in] Cumulative sequence lengths for K and V, [batch_size + 1].
cu_seqlens_q_padded – [in] Cumulative sequence offsets for Q, [batch_size + 1].
cu_seqlens_kv_padded – [in] Cumulative sequence offsets for KV, [batch_size + 1].
max_seqlen_q – [in] Max sequence length used for computing for Q; it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.
max_seqlen_kv – [in] Max sequence length used for computing for K and V; it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.
attn_scale – [in] Scaling factor for Q * K.T.
dropout – [in] Dropout probability.
qkv_layout – [in] QKV tensors’ layout.
bias_type – [in] Bias type.
attn_mask_type – [in] Attention mask type.
window_size_left – [in] Sliding window size (the left half).
window_size_right – [in] Sliding window size (the right half).
deterministic – [in] Whether to execute with deterministic behavior.
workspace – [in] Workspace tensor.
stream – [in] CUDA stream used for this operation.