fused_attn.h

Enums and functions for fused attention.

Enums

enum NVTE_QKV_Layout

Memory layouts of QKV tensors. S, B, H, D, and T stand for sequence length, batch size, number of heads, head size, and the total number of tokens in a batch, i.e. t = sum(s_i) for i = 0...b-1. SBHD- and BSHD-based layouts are used when the sequences in a batch are of equal length or padded to the same length, and THD-based layouts are used when the sequences in a batch have different lengths.

Values:

enumerator NVTE_SB3HD

SB3HD layout

enumerator NVTE_SBH3D

SBH3D layout

enumerator NVTE_SBHD_SB2HD

SBHD_SB2HD layout

enumerator NVTE_SBHD_SBH2D

SBHD_SBH2D layout

enumerator NVTE_SBHD_SBHD_SBHD

SBHD_SBHD_SBHD layout

enumerator NVTE_BS3HD

BS3HD layout

enumerator NVTE_BSH3D

BSH3D layout

enumerator NVTE_BSHD_BS2HD

BSHD_BS2HD layout

enumerator NVTE_BSHD_BSH2D

BSHD_BSH2D layout

enumerator NVTE_BSHD_BSHD_BSHD

BSHD_BSHD_BSHD layout

enumerator NVTE_T3HD

T3HD layout

enumerator NVTE_TH3D

TH3D layout

enumerator NVTE_THD_T2HD

THD_T2HD layout

enumerator NVTE_THD_TH2D

THD_TH2D layout

enumerator NVTE_THD_THD_THD

THD_THD_THD layout
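To make the layout names concrete, the sketch below computes flat element offsets for a few of them. It assumes each name spells out a contiguous, row-major tensor shape (for example SB3HD as [S, B, 3, H, D]) and that the packed dimension of size 3 is ordered Q, K, V. The helper names are hypothetical and are not part of fused_attn.h.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helpers for illustration only; not part of fused_attn.h.
// Assumption: every layout is a contiguous row-major tensor whose shape is
// spelled out by the layout name.

// SBHD: shape [S, B, H, D].
std::size_t offset_sbhd(std::size_t s, std::size_t b, std::size_t h, std::size_t d,
                        std::size_t B, std::size_t H, std::size_t D) {
  assert(b < B && h < H && d < D);
  return ((s * B + b) * H + h) * D + d;
}

// BSHD: shape [B, S, H, D].
std::size_t offset_bshd(std::size_t s, std::size_t b, std::size_t h, std::size_t d,
                        std::size_t S, std::size_t H, std::size_t D) {
  assert(s < S && h < H && d < D);
  return ((b * S + s) * H + h) * D + d;
}

// SB3HD: shape [S, B, 3, H, D]; which = 0, 1, 2 is assumed to select Q, K, V.
std::size_t offset_sb3hd(std::size_t s, std::size_t b, std::size_t which, std::size_t h,
                         std::size_t d, std::size_t B, std::size_t H, std::size_t D) {
  assert(b < B && which < 3 && h < H && d < D);
  return (((s * B + b) * 3 + which) * H + h) * D + d;
}

// SBH3D: shape [S, B, H, 3, D]; Q, K, V are interleaved per head.
std::size_t offset_sbh3d(std::size_t s, std::size_t b, std::size_t h, std::size_t which,
                         std::size_t d, std::size_t B, std::size_t H, std::size_t D) {
  assert(b < B && h < H && which < 3 && d < D);
  return (((s * B + b) * H + h) * 3 + which) * D + d;
}

// THD: shape [T, H, D], where t indexes tokens across all sequences.
std::size_t offset_thd(std::size_t t, std::size_t h, std::size_t d,
                       std::size_t H, std::size_t D) {
  assert(h < H && d < D);
  return (t * H + h) * D + d;
}
```

Under these assumptions, the K block of a token starts H * D elements after its Q block in SB3HD, but only D elements after it in SBH3D.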

enum NVTE_QKV_Layout_Group

QKV layout groups.

Values:

enumerator NVTE_3HD

3HD QKV layouts, i.e. BS3HD, SB3HD, T3HD

enumerator NVTE_H3D

H3D QKV layouts, i.e. BSH3D, SBH3D, TH3D

enumerator NVTE_HD_2HD

HD_2HD QKV layouts, i.e. BSHD_BS2HD, SBHD_SB2HD, THD_T2HD

enumerator NVTE_HD_H2D

HD_H2D QKV layouts, i.e. BSHD_BSH2D, SBHD_SBH2D, THD_TH2D

enumerator NVTE_HD_HD_HD

HD_HD_HD QKV layouts, i.e. BSHD_BSHD_BSHD, SBHD_SBHD_SBHD, THD_THD_THD

enum NVTE_QKV_Format

QKV formats.

Values:

enumerator NVTE_SBHD

SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD

enumerator NVTE_BSHD

BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD

enumerator NVTE_THD

THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD

enum NVTE_Bias_Type

Bias types.

Values:

enumerator NVTE_NO_BIAS

No bias

enumerator NVTE_PRE_SCALE_BIAS

Bias before scale

enumerator NVTE_POST_SCALE_BIAS

Bias after scale

enumerator NVTE_ALIBI

ALiBi

enum NVTE_Mask_Type

Attention mask types.

Values:

enumerator NVTE_NO_MASK

No masking

enumerator NVTE_PADDING_MASK

Padding attention mask

enumerator NVTE_CAUSAL_MASK

Causal attention mask

enumerator NVTE_PADDING_CAUSAL_MASK

Padding and causal attention mask
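The four mask types can be read as predicates on a (query position, key position) pair. The sketch below is one such reading, not the library's kernel logic: it assumes the causal mask is aligned to the top-left corner of the score matrix (key j is visible to query i when j <= i) and that padding masks hide every position at or beyond the actual sequence length. MaskKind and may_attend are illustrative names only.

```cpp
#include <cstddef>

// Local, illustrative mirror of the four mask types; not the library enum.
enum class MaskKind { None, Padding, Causal, PaddingCausal };

// Returns true if query position i may attend to key position j.
// seqlen_q and seqlen_kv are the actual (unpadded) lengths of the sequence.
// Assumption: the causal mask is top-left aligned, i.e. key j is visible
// to query i when j <= i.
bool may_attend(MaskKind kind, std::size_t i, std::size_t j,
                std::size_t seqlen_q, std::size_t seqlen_kv) {
  const bool padding = kind == MaskKind::Padding || kind == MaskKind::PaddingCausal;
  const bool causal = kind == MaskKind::Causal || kind == MaskKind::PaddingCausal;
  if (padding && (i >= seqlen_q || j >= seqlen_kv)) return false;  // padded slot
  if (causal && j > i) return false;                               // future key
  return true;
}
```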

enum NVTE_Fused_Attn_Backend

Fused attention backends.

Values:

enumerator NVTE_No_Backend

No supported backend

enumerator NVTE_F16_max512_seqlen

cuDNN-based FP16/BF16 fused attention for <= 512 sequence length

enumerator NVTE_F16_arbitrary_seqlen

cuDNN-based FP16/BF16 fused attention for any sequence length

enumerator NVTE_FP8

cuDNN-based FP8 fused attention for <= 512 sequence length

Functions

NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout)

Get QKV layout group for a given QKV layout.

Parameters

qkv_layout[in] QKV layout, e.g. sbh3d.

Returns

QKV layout group, e.g. h3d.

NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout)

Get QKV format for a given QKV layout.

Parameters

qkv_layout[in] QKV layout, e.g. sbh3d.

Returns

QKV format, e.g. sbhd.
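The relationship between a layout, its group, and its format follows from the naming convention alone, as in sbh3d giving group h3d and format sbhd. The sketch below derives both from a lowercase layout name by string manipulation. It only illustrates the convention: format_of and group_of are hypothetical helpers, and the library functions operate on enum values, not strings.

```cpp
#include <cctype>
#include <string>

// Hypothetical helpers that illustrate the naming convention only.
// Both expect a lowercase layout name such as "sbh3d" or "thd_t2hd".

// Format: take the first (Q) part of the name and drop the packing digit.
// "sbh3d" -> "sbhd", "thd_t2hd" -> "thd".
std::string format_of(const std::string& layout) {
  std::string out;
  for (char c : layout) {
    if (c == '_') break;
    if (!std::isdigit(static_cast<unsigned char>(c))) out += c;
  }
  return out;
}

// Group: drop the leading s/b/t dimensions from every part of the name.
// "sbh3d" -> "h3d", "sbhd_sb2hd" -> "hd_2hd".
std::string group_of(const std::string& layout) {
  std::string out;
  bool at_part_start = true;
  for (char c : layout) {
    if (c == '_') {
      out += c;
      at_part_start = true;
      continue;
    }
    if (at_part_start && (c == 's' || c == 'b' || c == 't')) continue;
    at_part_start = false;
    out += c;
  }
  return out;
}
```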

NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim)

Get fused attention backend based on input parameters.

Parameters
  • q_dtype[in] The data type of Tensor Q.

  • kv_dtype[in] The data type of Tensors K, V.

  • qkv_layout[in] The layout of Tensors Q, K, V.

  • bias_type[in] The attention bias type.

  • attn_mask_type[in] The attention mask type.

  • dropout[in] The dropout probability.

  • num_attn_heads[in] The number of heads in Q.

  • num_gqa_groups[in] The number of heads in K, V.

  • max_seqlen_q[in] The maximum sequence length of Q.

  • max_seqlen_kv[in] The maximum sequence length of K, V.

  • head_dim[in] The head dimension of Q, K, V.
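Because num_attn_heads counts the heads in Q and num_gqa_groups counts the heads in K and V, several Q heads share one K/V head whenever the two differ. The sketch below shows the usual grouped-query mapping, in which consecutive Q heads share a K/V head. That grouping is an assumption made for illustration, and kv_head_for_q_head is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: which K/V head serves a given Q head.
// Assumption: consecutive Q heads share one K/V head (the common
// grouped-query attention convention).
std::size_t kv_head_for_q_head(std::size_t q_head, std::size_t num_attn_heads,
                               std::size_t num_gqa_groups) {
  assert(num_gqa_groups > 0 && num_attn_heads % num_gqa_groups == 0);
  assert(q_head < num_attn_heads);
  const std::size_t heads_per_group = num_attn_heads / num_gqa_groups;
  return q_head / heads_per_group;
}
```

With num_gqa_groups equal to num_attn_heads this reduces to ordinary multi-head attention, and with num_gqa_groups equal to 1 every Q head reads the same K/V head.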

void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute dot product attention with packed QKV input.

Computes:

  • P = Q * Transpose(K) + Bias

  • S = ScaleMaskSoftmax(P)

  • D = Dropout(S)

  • O = D * Transpose(V)
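For checking results on small inputs, the four steps above can be written as a naive single-head reference. This is a sketch under several assumptions, not the fused kernel: dropout is omitted (D = S), the mask is an optional top-left-aligned causal mask, Q is stored as [seqlen_q][head_dim] and K, V as [seqlen_kv][head_dim], so the last step contracts the attention weights with V over key positions. Matrix and reference_attention are illustrative names.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

using Matrix = std::vector<std::vector<float>>;

// Naive single-head reference, for illustration only.
// q: [seqlen_q][head_dim], k and v: [seqlen_kv][head_dim],
// bias: [seqlen_q][seqlen_kv] or empty for no bias.
// Dropout is omitted, so D == S.
Matrix reference_attention(const Matrix& q, const Matrix& k, const Matrix& v,
                           const Matrix& bias, float attn_scale, bool causal) {
  const std::size_t sq = q.size();
  const std::size_t skv = k.size();
  const std::size_t d = v.empty() ? 0 : v[0].size();
  const float neg_inf = -std::numeric_limits<float>::infinity();
  Matrix o(sq, std::vector<float>(d, 0.0f));
  for (std::size_t i = 0; i < sq; ++i) {
    // P = Q * Transpose(K) + Bias, followed by scaling and masking.
    std::vector<float> p(skv);
    float row_max = neg_inf;
    for (std::size_t j = 0; j < skv; ++j) {
      float dot = 0.0f;
      for (std::size_t x = 0; x < q[i].size(); ++x) dot += q[i][x] * k[j][x];
      if (!bias.empty()) dot += bias[i][j];
      p[j] = (causal && j > i) ? neg_inf : attn_scale * dot;
      row_max = std::max(row_max, p[j]);
    }
    // S = softmax over the key positions of this row.
    float denom = 0.0f;
    for (std::size_t j = 0; j < skv; ++j) {
      p[j] = std::exp(p[j] - row_max);
      denom += p[j];
    }
    // O: weighted sum of the rows of V (contraction over key positions).
    for (std::size_t j = 0; j < skv; ++j) {
      for (std::size_t x = 0; x < d; ++x) o[i][x] += (p[j] / denom) * v[j][x];
    }
  }
  return o;
}
```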

Support Matrix:

| backend | precision |        qkv layout       |           bias           |                 mask                  | dropout |  sequence length  | head_dim         |
|   0     | FP16/BF16 |       BS3HD,SB3HD       |   NO/POST_SCALE_BIAS     | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   | <= 512, % 64 == 0 |    64            |
|   1     | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   |  > 512, % 64 == 0 | <= 128, % 8 == 0 |
|   2     |   FP8     |          T3HD           |          NO_BIAS         |               PADDING_MASK            |   Yes   | <= 512, % 64 == 0 |    64            |
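Read row by row, the support matrix above amounts to a set of predicates on the call's configuration. The sketch below encodes a literal reading of the three rows and returns the matching backend number, or -1 if no row matches. It is not how nvte_get_fused_attn_backend is implemented, and the real selection may depend on factors the table does not list. The function name and the string-based arguments are assumptions made to keep the example self-contained.

```cpp
#include <cstddef>
#include <string>

// Hypothetical, literal encoding of the packed-QKV support matrix.
// Arguments use the spellings from the table, e.g. "FP16", "BS3HD",
// "POST_SCALE_BIAS", "PADDING_MASK". Returns 0, 1 or 2, or -1 if no row matches.
int backend_for_qkvpacked(const std::string& precision, const std::string& layout,
                          const std::string& bias, const std::string& mask,
                          std::size_t seqlen, std::size_t head_dim) {
  const bool half = precision == "FP16" || precision == "BF16";
  const bool seq_ok = seqlen > 0 && seqlen % 64 == 0;
  const bool is_3hd = layout == "BS3HD" || layout == "SB3HD";
  const bool is_h3d = layout == "BSH3D" || layout == "SBH3D";
  const bool no_or_post = bias == "NO_BIAS" || bias == "POST_SCALE_BIAS";
  // Rows 0 and 1 accept all four mask types, so mask is not checked there.
  if (half && is_3hd && no_or_post && seq_ok && seqlen <= 512 && head_dim == 64) {
    return 0;
  }
  if (half && (is_3hd || is_h3d) && (no_or_post || bias == "ALIBI") && seq_ok &&
      seqlen > 512 && head_dim <= 128 && head_dim % 8 == 0) {
    return 1;
  }
  if (precision == "FP8" && layout == "T3HD" && bias == "NO_BIAS" &&
      mask == "PADDING_MASK" && seq_ok && seqlen <= 512 && head_dim == 64) {
    return 2;
  }
  return -1;
}
```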

Notes:

Tensor cu_seqlens_padded identifies the offsets of the individual sequences in tensors Q, K, V and O. When the QKV format (nvte_get_qkv_format(qkv_layout)) is bshd or sbhd, the offset tensor is not used in the attention calculation and can be set to an empty NVTETensor. When the QKV format is thd, the tensor must follow these rules. When there is no padding between sequences, the offset tensor should be equal to cu_seqlens. When there is padding between sequences, users are responsible for adjusting the offsets as needed. For example, a tensor of 4 sequences [a, PAD, b, b, c, PAD, PAD, d, d] should have cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].
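The two offset arrays in the example are prefix sums: cu_seqlens accumulates the actual sequence lengths, and cu_seqlens_padded accumulates each length plus the padding that follows it. A minimal host-side sketch is shown below. SeqOffsets and build_offsets are hypothetical names, and the 32-bit integer element type is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side helper; the int32 element type is an assumption.
struct SeqOffsets {
  std::vector<std::int32_t> cu_seqlens;         // prefix sums of actual lengths
  std::vector<std::int32_t> cu_seqlens_padded;  // prefix sums of length + padding
};

// seqlens[i] is the number of real tokens in sequence i and pad_after[i] is
// the number of PAD slots stored directly after it.
SeqOffsets build_offsets(const std::vector<std::int32_t>& seqlens,
                         const std::vector<std::int32_t>& pad_after) {
  assert(seqlens.size() == pad_after.size());
  SeqOffsets out;
  out.cu_seqlens.push_back(0);
  out.cu_seqlens_padded.push_back(0);
  for (std::size_t i = 0; i < seqlens.size(); ++i) {
    out.cu_seqlens.push_back(out.cu_seqlens.back() + seqlens[i]);
    out.cu_seqlens_padded.push_back(out.cu_seqlens_padded.back() + seqlens[i] +
                                    pad_after[i]);
  }
  return out;
}
```

For [a, PAD, b, b, c, PAD, PAD, d, d], calling build_offsets({1, 2, 1, 2}, {1, 0, 2, 0}) yields cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9]. With all padding counts set to zero the two arrays are identical.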

Parameters
  • QKV[in] The QKV tensor in packed format, H3D or 3HD.

  • Bias[in] The Bias tensor.

  • S[inout] The S tensor.

  • O[out] The output O tensor.

  • Aux_CTX_Tensors[out] Auxiliary output tensors when training, e.g. M, ZInv, rng_state.

  • cu_seqlens[in] Cumulative sequence lengths, [batch_size + 1].

  • cu_seqlens_padded[in] Cumulative sequence offsets for QKV, [batch_size + 1].

  • rng_state[in] Seed and offset of CUDA random number generator.

  • max_seqlen[in] Max sequence length used for computing. It may be >= max(seqlen_i) for i = 0, …, batch_size-1.

  • is_training[in] Whether this is in training mode or inference.

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.

void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute the backward of the dot product attention with packed QKV input.

Support Matrix:

| backend | precision |        qkv layout       |           bias           |                 mask                  | dropout |  sequence length  | head_dim         |
|   0     | FP16/BF16 |       BS3HD,SB3HD       |   NO/POST_SCALE_BIAS     | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   | <= 512, % 64 == 0 |    64            |
|   1     | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   |  > 512, % 64 == 0 | <= 128, % 8 == 0 |
|   2     |   FP8     |          T3HD           |          NO_BIAS         |               PADDING_MASK            |   Yes   | <= 512, % 64 == 0 |    64            |

Notes:

Tensor cu_seqlens_padded identifies the offsets of the individual sequences in tensors Q, K, V and O. When the QKV format (nvte_get_qkv_format(qkv_layout)) is bshd or sbhd, the offset tensor is not used in the attention calculation and can be set to an empty NVTETensor. When the QKV format is thd, the tensor must follow these rules. When there is no padding between sequences, the offset tensor should be equal to cu_seqlens. When there is padding between sequences, users are responsible for adjusting the offsets as needed. For example, a tensor of 4 sequences [a, PAD, b, b, c, PAD, PAD, d, d] should have cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].

Parameters
  • QKV[in] The QKV tensor in packed format, H3D or 3HD.

  • O[in] The O tensor from forward.

  • dO[in] The gradient of the O tensor.

  • S[in] The S tensor.

  • dP[inout] The gradient of the P tensor.

  • Aux_CTX_Tensors[in] Auxiliary tensors from context when in training mode, e.g. M, ZInv, rng_state.

  • dQKV[out] The gradient of the QKV tensor.

  • dBias[out] The gradient of the Bias tensor.

  • cu_seqlens[in] Cumulative sequence lengths, [batch_size + 1].

  • cu_seqlens_padded[in] Cumulative sequence offsets for QKV, [batch_size + 1].

  • max_seqlen[in] Max sequence length used for computing. It may be >= max(seqlen_i) for i = 0, …, batch_size-1.

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.

void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute dot product attention with packed KV input.

Computes:

  • P = Q * Transpose(K) + Bias

  • S = ScaleMaskSoftmax(P)

  • D = Dropout(S)

  • O = D * Transpose(V)

Support Matrix:

| backend | precision |                 qkv layout                  |           bias           |                 mask                  | dropout |  sequence length  | head_dim         |
|   0     | FP16/BF16 |            BSHD_BS2HD,SBHD_SB2HD            |   NO/POST_SCALE_BIAS     | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   | <= 512, % 64 == 0 |    64            |
|   1     | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   |  > 512, % 64 == 0 | <= 128, % 8 == 0 |

Notes:

Tensors cu_seqlens_q_padded and cu_seqlens_kv_padded identify the offsets of the individual sequences in tensors Q, K, V and O. When the QKV format (nvte_get_qkv_format(qkv_layout)) is bshd or sbhd, the offset tensors are not used in the attention calculation and can be set to empty NVTETensors. When the QKV format is thd, the tensors must follow these rules. When there is no padding between sequences, the offset tensors should be equal to cu_seqlens_q and cu_seqlens_kv respectively. When there is padding between sequences, users are responsible for adjusting the offsets as needed. For example, a tensor of 4 sequences [a, PAD, b, b, c, PAD, PAD, d, d] should have cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].
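Going the other way, per-sequence lengths, padding counts, and a valid lower bound for max_seqlen_q or max_seqlen_kv can all be recovered from the cumulative arrays by taking adjacent differences. The helpers below are a hypothetical host-side sketch with assumed names and an assumed 32-bit integer element type.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host-side helpers; the int32 element type is an assumption.

// Differences between neighbouring entries of a cumulative array.
std::vector<std::int32_t> adjacent_diffs(const std::vector<std::int32_t>& cu) {
  assert(!cu.empty());
  std::vector<std::int32_t> out;
  for (std::size_t i = 1; i < cu.size(); ++i) {
    assert(cu[i] >= cu[i - 1]);  // cumulative arrays never decrease
    out.push_back(cu[i] - cu[i - 1]);
  }
  return out;
}

// Smallest value that is valid as a max sequence length for this batch.
std::int32_t min_valid_max_seqlen(const std::vector<std::int32_t>& cu_seqlens) {
  const std::vector<std::int32_t> lens = adjacent_diffs(cu_seqlens);
  return lens.empty() ? 0 : *std::max_element(lens.begin(), lens.end());
}

// Number of PAD slots stored after each sequence.
std::vector<std::int32_t> padding_per_sequence(
    const std::vector<std::int32_t>& cu_seqlens,
    const std::vector<std::int32_t>& cu_seqlens_padded) {
  const std::vector<std::int32_t> lens = adjacent_diffs(cu_seqlens);
  const std::vector<std::int32_t> slots = adjacent_diffs(cu_seqlens_padded);
  assert(lens.size() == slots.size());
  std::vector<std::int32_t> pad(lens.size());
  for (std::size_t i = 0; i < lens.size(); ++i) {
    assert(slots[i] >= lens[i]);
    pad[i] = slots[i] - lens[i];
  }
  return pad;
}
```

For the example arrays, the sequence lengths are [1, 2, 1, 2], the smallest valid max sequence length is 2, and the padding counts are [1, 0, 2, 0].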

Parameters
  • Q[in] The Q tensor, in HD layouts.

  • KV[in] The KV tensor, in 2HD or H2D layouts.

  • Bias[in] The Bias tensor.

  • S[inout] The S tensor.

  • O[out] The output O tensor.

  • Aux_CTX_Tensors[out] Auxiliary output tensors when training, e.g. M, ZInv, rng_state.

  • cu_seqlens_q[in] Cumulative sequence lengths for Q, [batch_size + 1].

  • cu_seqlens_kv[in] Cumulative sequence lengths for KV, [batch_size + 1].

  • cu_seqlens_q_padded[in] Cumulative sequence offsets for Q, [batch_size + 1].

  • cu_seqlens_kv_padded[in] Cumulative sequence offsets for KV, [batch_size + 1].

  • rng_state[in] Seed and offset of CUDA random number generator.

  • max_seqlen_q[in] Max sequence length used for computing for Q. It may be >= max(seqlen_q_i) for i = 0, …, batch_size-1.

  • max_seqlen_kv[in] Max sequence length used for computing for KV. It may be >= max(seqlen_kv_i) for i = 0, …, batch_size-1.

  • is_training[in] Whether this is in training mode or inference.

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.

void nvte_fused_attn_bwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute the backward of the dot product attention with packed KV input.

Support Matrix:

| backend | precision |                 qkv layout                  |           bias           |                 mask                  | dropout |  sequence length  | head_dim         |
|   0     | FP16/BF16 |            BSHD_BS2HD,SBHD_SB2HD            |   NO/POST_SCALE_BIAS     | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   | <= 512, % 64 == 0 |    64            |
|   1     | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   |  > 512, % 64 == 0 | <= 128, % 8 == 0 |

Notes:

Tensors cu_seqlens_q_padded and cu_seqlens_kv_padded identify the offsets of the individual sequences in tensors Q, K, V and O. When the QKV format (nvte_get_qkv_format(qkv_layout)) is bshd or sbhd, the offset tensors are not used in the attention calculation and can be set to empty NVTETensors. When the QKV format is thd, the tensors must follow these rules. When there is no padding between sequences, the offset tensors should be equal to cu_seqlens_q and cu_seqlens_kv respectively. When there is padding between sequences, users are responsible for adjusting the offsets as needed. For example, a tensor of 4 sequences [a, PAD, b, b, c, PAD, PAD, d, d] should have cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].

Parameters
  • Q[in] The Q tensor, in HD layouts.

  • KV[in] The KV tensor, in H2D or 2HD layouts.

  • O[in] The O tensor from forward.

  • dO[in] The gradient of the O tensor.

  • S[in] The S tensor.

  • dP[inout] The gradient of the P tensor.

  • Aux_CTX_Tensors[in] Auxiliary tensors from context when in training mode, e.g. M, ZInv, rng_state.

  • dQ[out] The gradient of the Q tensor.

  • dKV[out] The gradient of the KV tensor.

  • dBias[out] The gradient of the Bias tensor.

  • cu_seqlens_q[in] Cumulative sequence lengths for Q, [batch_size + 1].

  • cu_seqlens_kv[in] Cumulative sequence lengths for KV, [batch_size + 1].

  • cu_seqlens_q_padded[in] Cumulative sequence offsets for Q, [batch_size + 1].

  • cu_seqlens_kv_padded[in] Cumulative sequence offsets for KV, [batch_size + 1].

  • max_seqlen_q[in] Max sequence length used for computing for Q. It may be >= max(seqlen_q_i) for i = 0, …, batch_size-1.

  • max_seqlen_kv[in] Max sequence length used for computing for KV. It may be >= max(seqlen_kv_i) for i = 0, …, batch_size-1.

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.

void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute dot product attention with separate Q, K and V.

Computes:

  • P = Q * Transpose(K) + Bias

  • S = ScaleMaskSoftmax(P)

  • D = Dropout(S)

  • O = D * Transpose(V)

Support Matrix:

| backend | precision |                qkv layout                   |           bias           |                 mask                  | dropout |  sequence length  | head_dim         |
|   0     | FP16/BF16 |     BS3HD,SB3HD,BSHD_BS2HD,SBHD_SB2HD       |   NO/POST_SCALE_BIAS     | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   | <= 512, % 64 == 0 |    64            |
|   1     | FP16/BF16 |          BS3HD,SB3HD,BSH3D,SBH3D            | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   |  > 512, % 64 == 0 | <= 128, % 8 == 0 |
|         |           | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D |                          |                                       |         |                   |                  |
|         |           |       BSHD_BSHD_BSHD,SBHD_SBHD_SBHD         |                          |                                       |         |                   |                  |
|   2     |   FP8     |                 T3HD                        |          NO_BIAS         |               PADDING_MASK            |   Yes   | <= 512, % 64 == 0 |    64            |

Notes:

Tensors cu_seqlens_q_padded and cu_seqlens_kv_padded identify the offsets of the individual sequences in tensors Q, K, V and O. When the QKV format (nvte_get_qkv_format(qkv_layout)) is bshd or sbhd, the offset tensors are not used in the attention calculation and can be set to empty NVTETensors. When the QKV format is thd, the tensors must follow these rules. When there is no padding between sequences, the offset tensors should be equal to cu_seqlens_q and cu_seqlens_kv respectively. When there is padding between sequences, users are responsible for adjusting the offsets as needed. For example, a tensor of 4 sequences [a, PAD, b, b, c, PAD, PAD, d, d] should have cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].

Parameters
  • Q[in] The Q tensor.

  • K[in] The K tensor.

  • V[in] The V tensor.

  • Bias[in] The Bias tensor.

  • S[inout] The S tensor.

  • O[out] The output O tensor.

  • Aux_CTX_Tensors[out] Auxiliary output tensors when training, e.g. M, ZInv, rng_state.

  • cu_seqlens_q[in] Cumulative sequence lengths for Q, [batch_size + 1].

  • cu_seqlens_kv[in] Cumulative sequence lengths for K and V, [batch_size + 1].

  • cu_seqlens_q_padded[in] Cumulative sequence offsets for Q, [batch_size + 1].

  • cu_seqlens_kv_padded[in] Cumulative sequence offsets for KV, [batch_size + 1].

  • rng_state[in] Seed and offset of CUDA random number generator.

  • max_seqlen_q[in] Max sequence length used for computing for Q. It may be >= max(seqlen_q_i) for i = 0, …, batch_size-1.

  • max_seqlen_kv[in] Max sequence length used for computing for K and V. It may be >= max(seqlen_kv_i) for i = 0, …, batch_size-1.

  • is_training[in] Whether this is in training mode or inference.

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensors’ layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.

void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute the backward of the dot product attention with separate Q, K and V.

Support Matrix:

| backend | precision |                qkv layout                   |           bias           |                 mask                  | dropout |  sequence length  | head_dim         |
|   0     | FP16/BF16 |     BS3HD,SB3HD,BSHD_BS2HD,SBHD_SB2HD       |   NO/POST_SCALE_BIAS     | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   | <= 512, % 64 == 0 |    64            |
|   1     | FP16/BF16 |          BS3HD,SB3HD,BSH3D,SBH3D            | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   |  > 512, % 64 == 0 | <= 128, % 8 == 0 |
|         |           | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D |                          |                                       |         |                   |                  |
|         |           |       BSHD_BSHD_BSHD,SBHD_SBHD_SBHD         |                          |                                       |         |                   |                  |
|   2     |   FP8     |                 T3HD                        |          NO_BIAS         |               PADDING_MASK            |   Yes   | <= 512, % 64 == 0 |    64            |

Notes:

Tensors cu_seqlens_q_padded and cu_seqlens_kv_padded identify the offsets of the individual sequences in tensors Q, K, V and O. When the QKV format (nvte_get_qkv_format(qkv_layout)) is bshd or sbhd, the offset tensors are not used in the attention calculation and can be set to empty NVTETensors. When the QKV format is thd, the tensors must follow these rules. When there is no padding between sequences, the offset tensors should be equal to cu_seqlens_q and cu_seqlens_kv respectively. When there is padding between sequences, users are responsible for adjusting the offsets as needed. For example, a tensor of 4 sequences [a, PAD, b, b, c, PAD, PAD, d, d] should have cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].

Parameters
  • Q[in] The Q tensor.

  • K[in] The K tensor.

  • V[in] The V tensor.

  • O[in] The O tensor from forward.

  • dO[in] The gradient of the O tensor.

  • S[in] The S tensor.

  • dP[inout] The gradient of the P tensor.

  • Aux_CTX_Tensors[in] Auxiliary tensors from context when in training mode, e.g. M, ZInv, rng_state.

  • dQ[out] The gradient of the Q tensor.

  • dK[out] The gradient of the K tensor.

  • dV[out] The gradient of the V tensor.

  • dBias[out] The gradient of the Bias tensor.

  • cu_seqlens_q[in] Cumulative sequence lengths for Q, [batch_size + 1].

  • cu_seqlens_kv[in] Cumulative sequence lengths for K and V, [batch_size + 1].

  • cu_seqlens_q_padded[in] Cumulative sequence offsets for Q, [batch_size + 1].

  • cu_seqlens_kv_padded[in] Cumulative sequence offsets for KV, [batch_size + 1].

  • max_seqlen_q[in] Max sequence length used for computing for Q. It may be >= max(seqlen_q_i) for i = 0, …, batch_size-1.

  • max_seqlen_kv[in] Max sequence length used for computing for K and V. It may be >= max(seqlen_kv_i) for i = 0, …, batch_size-1.

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensors’ layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.