fused_attn.h

Enums and functions for fused attention.

Enums

enum NVTE_QKV_Layout

QKV matrix layouts.

Values:

enumerator NVTE_NOT_INTERLEAVED

Separate Q, K, V tensors.

  Q: [total_seqs_q, num_heads, head_dim]
                      | Q   Q   Q        ...       Q
                      | \___________  _____________/
      total_seqs_q   <|             \/
                      |   num_heads * head_dim
  K: [total_seqs_kv, num_heads, head_dim]
                      | K   K   K        ...       K
                      | \___________  _____________/
      total_seqs_kv  <|             \/
                      |   num_heads * head_dim
  V: [total_seqs_kv, num_heads, head_dim]
                      | V   V   V        ...       V
                      | \___________  _____________/
      total_seqs_kv  <|             \/
                      |   num_heads * head_dim

enumerator NVTE_QKV_INTERLEAVED

Packed QKV.

  QKV: [total_seqs, 3, num_heads, head_dim]
                      | Q   Q   Q        ...       Q K K K ... K V V V ... V
                      | \___________  _____________/
        total_seqs   <|             \/
                      |   num_heads * head_dim

enumerator NVTE_KV_INTERLEAVED

Q and packed KV.

   Q: [total_seqs_q, num_heads, head_dim]
                      | Q   Q   Q        ...       Q
                      | \___________  _____________/
       total_seqs_q  <|             \/
                      |   num_heads * head_dim
   KV: [total_seqs_kv, 2, num_heads, head_dim]
                      | K   K   K        ...       K V V V ... V
                      | \___________  _____________/
       total_seqs_kv <|             \/
                      |   num_heads * head_dim
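
The packed shapes above can be read as ordinary row-major index arithmetic. The following is a minimal sketch, assuming the buffer is contiguous and row-major; `packed_offset` and its parameter names are hypothetical helpers for illustration and are not part of this header.

```cpp
#include <cassert>
#include <cstddef>

// Flat element offset into a contiguous, row-major packed buffer of shape
// [total_seqs, parts, num_heads, head_dim].
// parts = 3 for packed QKV (part 0 = Q, 1 = K, 2 = V);
// parts = 2 for packed KV  (part 0 = K, 1 = V).
// Hypothetical helper: assumes no padding or custom strides.
std::size_t packed_offset(std::size_t token, std::size_t part,
                          std::size_t head, std::size_t dim,
                          std::size_t parts, std::size_t num_heads,
                          std::size_t head_dim) {
  assert(part < parts && head < num_heads && dim < head_dim);
  return ((token * parts + part) * num_heads + head) * head_dim + dim;
}
```

With `num_heads = 2` and `head_dim = 4`, each token of a packed QKV buffer occupies 24 elements, so the K block of token 0 starts at offset 8 and token 1 starts at offset 24.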

enum NVTE_Bias_Type

Bias types.

Values:

enumerator NVTE_NO_BIAS

No bias

enumerator NVTE_PRE_SCALE_BIAS

Bias before scale

enumerator NVTE_POST_SCALE_BIAS

Bias after scale
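
One plausible reading of the three bias types, shown on a single attention logit. This is a sketch only: `BiasKind` and `biased_logit` are local stand-ins rather than names from this header, and the exact placement of the bias inside a given backend is an assumption.

```cpp
#include <cassert>

// Local stand-in for NVTE_Bias_Type, for illustration only.
enum class BiasKind { None, PreScale, PostScale };

// Combines one raw dot product q.k with a bias term and the attention scale.
// PreScale adds the bias before scaling; PostScale adds it after scaling.
float biased_logit(float qk, float bias, float attn_scale, BiasKind kind) {
  switch (kind) {
    case BiasKind::None:      return attn_scale * qk;
    case BiasKind::PreScale:  return attn_scale * (qk + bias);
    case BiasKind::PostScale: return attn_scale * qk + bias;
  }
  assert(false && "unknown BiasKind");
  return 0.0f;
}
```

For `qk = 4`, `bias = 2`, and `attn_scale = 0.5`, the pre-scale form gives 3 and the post-scale form gives 4, so the two options differ whenever the scale is not 1.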

enum NVTE_Mask_Type

Attention mask types.

Values:

enumerator NVTE_NO_MASK

No masking

enumerator NVTE_PADDING_MASK

Padding attention mask

enumerator NVTE_CAUSAL_MASK

Causal attention mask
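
The mask types can be pictured as a predicate over one (query position, key position) pair. The sketch below uses the conventional definitions and is an assumption about the semantics: `MaskKind` and `may_attend` are hypothetical names, and whether a backend also applies padding under the causal mask is not stated in this header.

```cpp
#include <cassert>
#include <cstddef>

// Local stand-in for NVTE_Mask_Type, for illustration only.
enum class MaskKind { None, Padding, Causal };

// Returns true when query position i may attend to key position j inside one
// sequence whose unpadded lengths are len_q and len_kv.
bool may_attend(std::size_t i, std::size_t j, std::size_t len_q,
                std::size_t len_kv, MaskKind kind) {
  switch (kind) {
    case MaskKind::None:    return true;
    case MaskKind::Padding: return i < len_q && j < len_kv;  // skip padded slots
    case MaskKind::Causal:  return j <= i;                   // no look-ahead
  }
  assert(false && "unknown MaskKind");
  return false;
}
```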

enum NVTE_Fused_Attn_Backend

Fused attention backends.

Values:

enumerator NVTE_No_Backend

No supported backend

enumerator NVTE_F16_max512_seqlen

cuDNN-based FP16/BF16 fused attention for sequence lengths <= 512

enumerator NVTE_F16_arbitrary_seqlen

cuDNN-based FP16/BF16 fused attention for any sequence length

enumerator NVTE_FP8

cuDNN-based FP8 fused attention for sequence lengths <= 512

Functions

NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim)

Get fused attention backend based on input parameters.

Parameters
  • q_dtype[in] The data type of Tensor Q.

  • kv_dtype[in] The data type of Tensors K, V.

  • qkv_layout[in] The layout of Tensors Q, K, V.

  • bias_type[in] The attention bias type.

  • attn_mask_type[in] The attention mask type.

  • dropout[in] The dropout probability.

  • max_seqlen_q[in] The maximum sequence length of Q.

  • max_seqlen_kv[in] The maximum sequence length of K, V.

  • head_dim[in] The head dimension of Q, K, V.
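
To give a feel for how the inputs narrow down a backend, the sketch below encodes a literal reading of the packed-QKV support matrix listed later in this header (rows 0, 1, 2). It is not the library's selection logic: `pick_backend_row` and the local enums are hypothetical, and the real function may apply further checks (for example on dropout, GPU architecture, or cuDNN version).

```cpp
#include <cassert>
#include <cstddef>

// Local stand-ins for illustration only; the real enums live in fused_attn.h.
enum class Precision { F16OrBF16, FP8 };
enum class Bias { None, PreScale, PostScale };
enum class Mask { None, Padding, Causal };

// Returns the matching row index of the packed-QKV support matrix (0, 1, 2),
// or -1 when no row matches. Assumes the QKV_INTERLEAVED layout throughout.
int pick_backend_row(Precision p, Bias b, Mask m, std::size_t max_seqlen,
                     std::size_t head_dim) {
  assert(max_seqlen > 0 && head_dim > 0);
  const bool short_seq = max_seqlen <= 512;
  if (p == Precision::FP8) {
    const bool ok = b == Bias::None && m == Mask::Padding && short_seq &&
                    head_dim == 64;
    return ok ? 2 : -1;
  }
  if (short_seq) {
    const bool bias_ok = b == Bias::None || b == Bias::PostScale;
    const bool mask_ok = m == Mask::Padding || m == Mask::Causal;
    return (bias_ok && mask_ok && head_dim == 64) ? 0 : -1;
  }
  const bool ok = b == Bias::None && m == Mask::Causal &&
                  (head_dim == 64 || head_dim == 128);
  return ok ? 1 : -1;
}
```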

void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, const NVTETensor rng_state, size_t max_seqlen, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute dot product attention with packed QKV input.

Computes:

  • P = Q * Transpose(K) + Bias

  • S = ScaleMaskSoftmax(P)

  • D = Dropout(S)

  • O = D * V
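
For checking results on small inputs, the steps above can be written as a naive CPU reference. This is a sketch under simplifying assumptions: a single head, NO_BIAS, NO_MASK, no dropout (so D equals S), and row-major storage with V laid out as [seqlen_kv, head_dim], so the last step is the plain product of the softmax output with V. The name `attention_reference` is hypothetical and not part of this header.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// q: [sq, d], k: [skv, d], v: [skv, d], all row-major. Returns o: [sq, d].
std::vector<float> attention_reference(const std::vector<float>& q,
                                       const std::vector<float>& k,
                                       const std::vector<float>& v,
                                       std::size_t sq, std::size_t skv,
                                       std::size_t d, float attn_scale) {
  assert(skv > 0);
  assert(q.size() == sq * d && k.size() == skv * d && v.size() == skv * d);
  std::vector<float> o(sq * d, 0.0f);
  std::vector<float> s(skv);
  for (std::size_t i = 0; i < sq; ++i) {
    // P = Q * Transpose(K), then scale.
    float row_max = -std::numeric_limits<float>::infinity();
    for (std::size_t j = 0; j < skv; ++j) {
      float p = 0.0f;
      for (std::size_t c = 0; c < d; ++c) p += q[i * d + c] * k[j * d + c];
      s[j] = attn_scale * p;
      row_max = std::max(row_max, s[j]);
    }
    // S = softmax over the key axis, shifted by the row max for stability.
    float denom = 0.0f;
    for (std::size_t j = 0; j < skv; ++j) {
      s[j] = std::exp(s[j] - row_max);
      denom += s[j];
    }
    // O = S * V (dropout omitted).
    for (std::size_t j = 0; j < skv; ++j)
      for (std::size_t c = 0; c < d; ++c)
        o[i * d + c] += (s[j] / denom) * v[j * d + c];
  }
  return o;
}
```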

Support Matrix:

| backend | precision |    qkv layout   |          bias           |      mask      | dropout | sequence length | head_dim |
| 0       | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL |   Yes   |     <= 512      |    64    |
| 1       | FP16/BF16 | QKV_INTERLEAVED |         NO_BIAS         |    CAUSAL      |   Yes   |      > 512      |  64, 128 |
| 2       | FP8       | QKV_INTERLEAVED |         NO_BIAS         |    PADDING     |   Yes   |     <= 512      |    64    |

Parameters
  • QKV[in] The QKV tensor in packed format, [total_seqs, 3, num_heads, head_dim].

  • Bias[in] The Bias tensor.

  • S[inout] The S tensor.

  • O[out] The output O tensor.

  • Aux_CTX_Tensors[out] Auxiliary output tensors when training, e.g. M, ZInv, rng_state.

  • cu_seqlens[in] Cumulative sequence lengths, [batch_size + 1].

  • rng_state[in] Seed and offset of CUDA random number generator.

  • max_seqlen[in] Max sequence length used for computing; it may be >= max(cu_seqlens).

  • is_training[in] Whether this is in training mode or inference.

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.
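
The `cu_seqlens` argument has `batch_size + 1` entries, which matches an exclusive prefix sum over the per-sequence lengths. A minimal host-side sketch follows; `make_cu_seqlens` is a hypothetical helper and the 32-bit integer element type is an assumption, not something this header states.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Builds cumulative sequence lengths: cu[0] = 0 and
// cu[b + 1] = cu[b] + seqlens[b], so the result has batch_size + 1 entries
// and cu.back() equals the total number of tokens in the batch.
std::vector<std::int32_t> make_cu_seqlens(
    const std::vector<std::int32_t>& seqlens) {
  std::vector<std::int32_t> cu(seqlens.size() + 1, 0);
  for (std::size_t b = 0; b < seqlens.size(); ++b) {
    assert(seqlens[b] >= 0);
    cu[b + 1] = cu[b] + seqlens[b];
  }
  return cu;
}
```

For sequence lengths 3, 5, and 2, this yields 0, 3, 8, 10; the final entry corresponds to `total_seqs` in the packed tensor shape.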

void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, const NVTETensor cu_seqlens, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute the backward of the dot product attention with packed QKV input.

Support Matrix:

| backend | precision |    qkv layout   |          bias           |      mask      | dropout | sequence length | head_dim |
| 0       | FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL |   Yes   |     <= 512      |    64    |
| 1       | FP16/BF16 | QKV_INTERLEAVED |         NO_BIAS         |    CAUSAL      |   Yes   |      > 512      |  64, 128 |
| 2       | FP8       | QKV_INTERLEAVED |         NO_BIAS         |    PADDING     |   Yes   |     <= 512      |    64    |

Parameters
  • QKV[in] The QKV tensor in packed format, [total_seqs, 3, num_heads, head_dim].

  • O[in] The O tensor from forward.

  • dO[in] The gradient of the O tensor.

  • S[in] The S tensor.

  • dP[inout] The gradient of the P tensor.

  • Aux_CTX_Tensors[in] Auxiliary tensors from context when in training mode, e.g. M, ZInv, rng_state.

  • dQKV[out] The gradient of the QKV tensor.

  • dBias[out] The gradient of the Bias tensor.

  • cu_seqlens[in] Cumulative sequence lengths, [batch_size + 1].

  • max_seqlen[in] Max sequence length used for computing; it may be >= max(cu_seqlens).

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.

void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute dot product attention with packed KV input.

Computes:

  • P = Q * Transpose(K) + Bias

  • S = ScaleMaskSoftmax(P)

  • D = Dropout(S)

  • O = D * V

Support Matrix:

| backend | precision |    qkv layout   |          bias           |      mask      | dropout | sequence length | head_dim |
| 0       | FP16/BF16 | KV_INTERLEAVED  | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL |   Yes   |     <= 512      |    64    |

Parameters
  • Q[in] The Q tensor, [total_seqs_q, num_heads, head_dim].

  • KV[in] The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].

  • Bias[in] The Bias tensor.

  • S[inout] The S tensor.

  • O[out] The output O tensor.

  • Aux_CTX_Tensors[out] Auxiliary output tensors when training, e.g. M, ZInv, rng_state.

  • cu_seqlens_q[in] Cumulative sequence lengths for Q, [batch_size + 1].

  • cu_seqlens_kv[in] Cumulative sequence lengths for KV, [batch_size + 1].

  • rng_state[in] Seed and offset of CUDA random number generator.

  • max_seqlen_q[in] Max sequence length used for computing for Q; it may be >= max(cu_seqlens_q).

  • max_seqlen_kv[in] Max sequence length used for computing for KV; it may be >= max(cu_seqlens_kv).

  • is_training[in] Whether this is in training mode or inference.

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.
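
Because the `cu_seqlens_q` and `cu_seqlens_kv` entries are cumulative, one plausible reading of the lower bound on `max_seqlen_q` and `max_seqlen_kv` is the longest individual sequence, that is, the largest difference between adjacent entries. The sketch below computes that value on the host; `longest_sequence` is a hypothetical helper and the 32-bit integer element type is an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Returns the longest per-sequence length encoded in a cumulative array with
// batch_size + 1 non-decreasing entries. Any value >= this result is a
// candidate for the corresponding max_seqlen argument.
std::int32_t longest_sequence(const std::vector<std::int32_t>& cu_seqlens) {
  assert(!cu_seqlens.empty());
  std::int32_t longest = 0;
  for (std::size_t b = 0; b + 1 < cu_seqlens.size(); ++b) {
    assert(cu_seqlens[b + 1] >= cu_seqlens[b]);
    longest = std::max(longest, cu_seqlens[b + 1] - cu_seqlens[b]);
  }
  return longest;
}
```

Calling it once on the Q array and once on the KV array gives separate bounds, which matters for cross-attention where the two sides have different lengths.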

void nvte_fused_attn_bwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute the backward of the dot product attention with packed KV input.

Support Matrix:

| backend | precision |    qkv layout   |          bias           |      mask      | dropout | sequence length | head_dim |
| 0       | FP16/BF16 | KV_INTERLEAVED  | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL |   Yes   |     <= 512      |    64    |

Parameters
  • Q[in] The Q tensor, [total_seqs_q, num_heads, head_dim].

  • KV[in] The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].

  • O[in] The O tensor from forward.

  • dO[in] The gradient of the O tensor.

  • S[in] The S tensor.

  • dP[inout] The gradient of the P tensor.

  • Aux_CTX_Tensors[in] Auxiliary tensors from context when in training mode, e.g. M, ZInv, rng_state.

  • dQ[out] The gradient of the Q tensor.

  • dKV[out] The gradient of the KV tensor.

  • dBias[out] The gradient of the Bias tensor.

  • cu_seqlens_q[in] Cumulative sequence lengths for Q, [batch_size + 1].

  • cu_seqlens_kv[in] Cumulative sequence lengths for KV, [batch_size + 1].

  • max_seqlen_q[in] Max sequence length used for computing for Q; it may be >= max(cu_seqlens_q).

  • max_seqlen_kv[in] Max sequence length used for computing for KV; it may be >= max(cu_seqlens_kv).

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.