fused_attn.h

Enums

enum NVTE_QKV_Layout

Values:

enumerator NVTE_NOT_INTERLEAVED

separate Q, K, V tensors:

    Q: [total_seqs_q, num_heads, head_dim]
    K: [total_seqs_kv, num_heads, head_dim]
    V: [total_seqs_kv, num_heads, head_dim]

enumerator NVTE_QKV_INTERLEAVED

packed QKV tensor:

    QKV: [total_seqs, 3, num_heads, head_dim]

enumerator NVTE_KV_INTERLEAVED

Q and packed KV tensor:

    Q: [total_seqs_q, num_heads, head_dim]
    KV: [total_seqs_kv, 2, num_heads, head_dim]
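
To make the three layouts concrete, the sketch below computes the flat element offset of one (token, head, channel) triple under each layout, assuming contiguous row-major storage. The helper names are illustrative only; the actual strides are owned by the NVTETensor objects.

    #include <cstddef>

    // NVTE_NOT_INTERLEAVED: Q is [total_seqs_q, num_heads, head_dim]
    // (K and V are indexed the same way in their own buffers).
    size_t offset_separate(size_t t, size_t h, size_t d,
                           size_t num_heads, size_t head_dim) {
      return (t * num_heads + h) * head_dim + d;
    }

    // NVTE_QKV_INTERLEAVED: QKV is [total_seqs, 3, num_heads, head_dim].
    // which: 0 = Q, 1 = K, 2 = V
    size_t offset_qkv_packed(size_t t, size_t which, size_t h, size_t d,
                             size_t num_heads, size_t head_dim) {
      return ((t * 3 + which) * num_heads + h) * head_dim + d;
    }

    // NVTE_KV_INTERLEAVED: KV is [total_seqs_kv, 2, num_heads, head_dim].
    // which: 0 = K, 1 = V
    size_t offset_kv_packed(size_t t, size_t which, size_t h, size_t d,
                            size_t num_heads, size_t head_dim) {
      return ((t * 2 + which) * num_heads + h) * head_dim + d;
    }
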
enum NVTE_Bias_Type

Values:

enumerator NVTE_NO_BIAS

no bias

enumerator NVTE_PRE_SCALE_BIAS

bias before scale

enumerator NVTE_POST_SCALE_BIAS

bias after scale
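
In terms of the forward computation listed for the functions below, the two bias placements can be read as follows. This formalization is inferred from the enumerator descriptions above rather than spelled out by the header, so treat it as a reading, not a specification:

    \begin{aligned}
    \text{NVTE\_PRE\_SCALE\_BIAS:}  \quad & S = \mathrm{Softmax}\big(\text{attn\_scale} \cdot (Q K^{T} + \mathrm{Bias})\big) \\
    \text{NVTE\_POST\_SCALE\_BIAS:} \quad & S = \mathrm{Softmax}\big(\text{attn\_scale} \cdot Q K^{T} + \mathrm{Bias}\big)
    \end{aligned}
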

enum NVTE_Mask_Type

Values:

enumerator NVTE_PADDING_MASK

padding attention mask

enumerator NVTE_CAUSAL_MASK

causal attention mask

enumerator NVTE_NO_MASK

no masking
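
The padding mask is driven by the cu_seqlens tensors taken by the functions below: cumulative sequence lengths of shape [batch_size + 1], whose consecutive differences are the per-sequence lengths. A minimal host-side construction sketch (the device copy and NVTETensor wrapping are omitted, and int32_t is an assumption here; the required integer dtype is defined by the library):

    #include <cstdint>
    #include <vector>

    // cu_seqlens[0] = 0 and cu_seqlens[b + 1] - cu_seqlens[b] = seqlens[b],
    // so cu_seqlens[batch_size] equals total_seqs, the leading dimension of
    // the packed tensors.
    std::vector<int32_t> make_cu_seqlens(const std::vector<int32_t>& seqlens) {
      std::vector<int32_t> cu(seqlens.size() + 1, 0);
      for (size_t b = 0; b < seqlens.size(); ++b) {
        cu[b + 1] = cu[b] + seqlens[b];
      }
      return cu;
    }

    // Example: seqlens {3, 5, 2} -> cu_seqlens {0, 3, 8, 10}, total_seqs = 10.
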

Functions

void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_Output_Tensors, const NVTETensor cu_seqlens, const NVTETensor rng_state, size_t max_seqlen, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute dot product attention with packed QKV input.

Computes:

  • P = Q * K.T + Bias

  • S = ScaleMaskSoftmax(P)

  • D = Dropout(S)

  • O = D * V

Support Matrix:

| precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
|-----------|-----------------|-------------------------|----------------|---------|-----------------|----------|
| FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | Yes     | <= 512          | 64       |
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |

Parameters
  • QKV[in] The QKV tensor in packed format, [total_seqs, 3, num_heads, head_dim].

  • Bias[in] The Bias tensor.

  • S[inout] The S tensor.

  • O[out] The output O tensor.

  • Aux_Output_Tensors[out] Auxiliary output tensors when training, e.g. M, ZInv.

  • cu_seqlens[in] Cumulative sequence lengths, [batch_size + 1].

  • rng_state[in] Seed and offset of CUDA random number generator.

  • max_seqlen[in] Max sequence length used for the computation; it may be >= the largest sequence length encoded in cu_seqlens.

  • is_training[in] Whether the call is for training or inference.

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.
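
A minimal call sketch for this forward pass. The handle names and surrounding setup are illustrative assumptions; creating and sizing the NVTETensor and NVTETensorPack objects through the library's tensor API is taken as already done:

    #include <cuda_runtime.h>
    // #include <transformer_engine/fused_attn.h>  (this header; path may vary)

    void run_fwd(NVTETensor qkv, NVTETensor bias, NVTETensor s, NVTETensor o,
                 NVTETensorPack* aux, NVTETensor cu_seqlens, NVTETensor rng_state,
                 NVTETensor workspace, cudaStream_t stream) {
      const size_t max_seqlen = 512;    // >= longest sequence in the batch
      const bool is_training = true;    // requests the auxiliary outputs
      const float attn_scale = 0.125f;  // e.g. 1/sqrt(head_dim) for head_dim = 64
      const float dropout = 0.1f;

      nvte_fused_attn_fwd_qkvpacked(qkv, bias, s, o, aux, cu_seqlens, rng_state,
                                    max_seqlen, is_training, attn_scale, dropout,
                                    NVTE_QKV_INTERLEAVED, NVTE_NO_BIAS,
                                    NVTE_PADDING_MASK, workspace, stream);
    }
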

void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, const NVTETensor cu_seqlens, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute the backward of the dot product attention with packed QKV input.

Support Matrix:

| precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
|-----------|-----------------|-------------------------|----------------|---------|-----------------|----------|
| FP8       | QKV_INTERLEAVED | NO_BIAS                 | PADDING        | Yes     | <= 512          | 64       |
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |

Parameters
  • QKV[in] The QKV tensor in packed format, [total_seqs, 3, num_heads, head_dim].

  • O[in] The O tensor from forward.

  • dO[in] The gradient of the O tensor.

  • S[in] The S tensor.

  • dP[inout] The gradient of the P tensor.

  • Aux_CTX_Tensors[in] Auxiliary tensors from forward when in training mode.

  • dQKV[out] The gradient of the QKV tensor.

  • dBias[out] The gradient of the Bias tensor.

  • cu_seqlens[in] Cumulative sequence lengths, [batch_size + 1].

  • max_seqlen[in] Max sequence length used for the computation; it may be >= the largest sequence length encoded in cu_seqlens.

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.
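
The backward call mirrors the forward one. Under the same illustrative conventions as the forward sketch above, the key point is that Aux_CTX_Tensors is the pack produced by the training-mode forward call, and the scalar hyper-parameters must match that call:

    void run_bwd(NVTETensor qkv, NVTETensor o, NVTETensor d_o, NVTETensor s,
                 NVTETensor d_p, const NVTETensorPack* aux,  // from the forward pass
                 NVTETensor d_qkv, NVTETensor d_bias, NVTETensor cu_seqlens,
                 NVTETensor workspace, cudaStream_t stream) {
      nvte_fused_attn_bwd_qkvpacked(qkv, o, d_o, s, d_p, aux, d_qkv, d_bias,
                                    cu_seqlens, /*max_seqlen=*/512,
                                    /*attn_scale=*/0.125f, /*dropout=*/0.1f,
                                    NVTE_QKV_INTERLEAVED, NVTE_NO_BIAS,
                                    NVTE_PADDING_MASK, workspace, stream);
    }
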

void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_Output_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute dot product attention with packed KV input.

Computes:

  • P = Q * K.T + Bias

  • S = ScaleMaskSoftmax(P)

  • D = Dropout(S)

  • O = D * V

Support Matrix:

| precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
|-----------|-----------------|-------------------------|----------------|---------|-----------------|----------|
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |

Parameters
  • Q[in] The Q tensor, [total_seqs_q, num_heads, head_dim].

  • KV[in] The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].

  • Bias[in] The Bias tensor.

  • S[inout] The S tensor.

  • O[out] The output O tensor.

  • Aux_Output_Tensors[out] Auxiliary output tensors when training, e.g. M, ZInv.

  • cu_seqlens_q[in] Cumulative sequence lengths for Q, [batch_size + 1].

  • cu_seqlens_kv[in] Cumulative sequence lengths for KV, [batch_size + 1].

  • rng_state[in] Seed and offset of CUDA random number generator.

  • max_seqlen_q[in] Max sequence length used for the computation for Q; it may be >= the largest sequence length encoded in cu_seqlens_q.

  • max_seqlen_kv[in] Max sequence length used for the computation for KV; it may be >= the largest sequence length encoded in cu_seqlens_kv.

  • is_training[in] Whether the call is for training or inference.

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.
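
The KV-packed variant is the natural fit for cross-attention, where queries and keys/values come from different sequences and may have different lengths. A call sketch under the same illustrative conventions as above, using the NVTE_KV_INTERLEAVED layout that matches the packed-KV description:

    void run_fwd_cross(NVTETensor q, NVTETensor kv, NVTETensor bias, NVTETensor s,
                       NVTETensor o, NVTETensorPack* aux,
                       NVTETensor cu_seqlens_q, NVTETensor cu_seqlens_kv,
                       NVTETensor rng_state, NVTETensor workspace,
                       cudaStream_t stream) {
      // e.g. decoder queries (<= 128 tokens) attending over encoder
      // outputs (<= 512 tokens)
      nvte_fused_attn_fwd_kvpacked(q, kv, bias, s, o, aux,
                                   cu_seqlens_q, cu_seqlens_kv, rng_state,
                                   /*max_seqlen_q=*/128, /*max_seqlen_kv=*/512,
                                   /*is_training=*/true, /*attn_scale=*/0.125f,
                                   /*dropout=*/0.0f, NVTE_KV_INTERLEAVED,
                                   NVTE_POST_SCALE_BIAS, NVTE_PADDING_MASK,
                                   workspace, stream);
    }
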

void nvte_fused_attn_bwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute the backward of the dot product attention with packed KV input.

Support Matrix:

| precision | qkv layout      | bias                    | mask           | dropout | sequence length | head_dim |
|-----------|-----------------|-------------------------|----------------|---------|-----------------|----------|
| FP16/BF16 | QKV_INTERLEAVED | NO_BIAS/POST_SCALE_BIAS | PADDING/CAUSAL | No      | <= 512          | 64       |

Parameters
  • Q[in] The Q tensor, [total_seqs_q, num_heads, head_dim].

  • KV[in] The KV tensor, [total_seqs_kv, 2, num_heads, head_dim].

  • O[in] The O tensor from forward.

  • dO[in] The gradient of the O tensor.

  • S[in] The S tensor.

  • dP[inout] The gradient of the P tensor.

  • Aux_CTX_Tensors[in] Auxiliary tensors from forward when in training mode.

  • dQ[out] The gradient of the Q tensor.

  • dKV[out] The gradient of the KV tensor.

  • dBias[out] The gradient of the Bias tensor.

  • cu_seqlens_q[in] Cumulative sequence lengths for Q, [batch_size + 1].

  • cu_seqlens_kv[in] Cumulative sequence lengths for KV, [batch_size + 1].

  • max_seqlen_q[in] Max sequence length used for the computation for Q; it may be >= the largest sequence length encoded in cu_seqlens_q.

  • max_seqlen_kv[in] Max sequence length used for the computation for KV; it may be >= the largest sequence length encoded in cu_seqlens_kv.

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.
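
For reference, the gradients this function produces correspond to the standard dot-product-attention backward pass. Writing the forward pass (bias, mask, and dropout omitted for brevity) as P = attn_scale * Q * K.T, S = Softmax(P), O = S * V, the chain rule gives:

    \begin{aligned}
    dV &= S^{T}\, dO \\
    dS &= dO\, V^{T} \\
    dP_{ij} &= S_{ij}\Big(dS_{ij} - \sum_{k} dS_{ik}\, S_{ik}\Big) \\
    dQ &= \text{attn\_scale} \cdot dP\, K \\
    dK &= \text{attn\_scale} \cdot dP^{T}\, Q
    \end{aligned}

This is a reference derivation only; it says nothing about the kernel's internal computation order or the exact contents of the dP tensor, which follow the library's own conventions.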