fused_attn.h

Enums and functions for fused attention.

Enums

enum NVTE_QKV_Layout

Memory layouts of QKV tensors. S, B, H, D, and T stand for sequence length, batch size, number of heads, head size, and the total number of tokens in a batch, i.e. T = sum(S_i) for i = 0...B-1. SBHD- and BSHD-based layouts are used when all sequences in a batch have the same length or are padded to the same length; THD-based layouts are used when sequences in a batch have different lengths (see the shape sketch after the value list below).

Values:

enumerator NVTE_SB3HD

SB3HD layout

enumerator NVTE_SBH3D

SBH3D layout

enumerator NVTE_SBHD_SB2HD

SBHD_SB2HD layout

enumerator NVTE_SBHD_SBH2D

SBHD_SBH2D layout

enumerator NVTE_SBHD_SBHD_SBHD

SBHD_SBHD_SBHD layout

enumerator NVTE_BS3HD

BS3HD layout

enumerator NVTE_BSH3D

BSH3D layout

enumerator NVTE_BSHD_BS2HD

BSHD_BS2HD layout

enumerator NVTE_BSHD_BSH2D

BSHD_BSH2D layout

enumerator NVTE_BSHD_BSHD_BSHD

BSHD_BSHD_BSHD layout

enumerator NVTE_T3HD

T3HD layout

enumerator NVTE_TH3D

TH3D layout

enumerator NVTE_THD_T2HD

THD_T2HD layout

enumerator NVTE_THD_TH2D

THD_TH2D layout

enumerator NVTE_THD_THD_THD

THD_THD_THD layout
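
As a hedged illustration of the naming convention described above, the layout string maps letter-by-letter onto tensor dimensions; the shapes below are inferred from the names, not quoted from the header:

    /* Dimension orderings the layout names encode (a sketch):
       NVTE_SB3HD:        packed QKV tensor of shape [s, b, 3, h, d]
       NVTE_SBH3D:        packed QKV tensor of shape [s, b, h, 3, d]
       NVTE_BSHD_BS2HD:   Q: [b, s, h, d];  packed KV: [b, s, 2, h, d]
       NVTE_THD_THD_THD:  separate Q, K, V, each of shape [t, h, d]   */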

enum NVTE_QKV_Layout_Group

QKV layout groups.

Values:

enumerator NVTE_3HD

3HD QKV layouts, i.e. BS3HD, SB3HD, T3HD

enumerator NVTE_H3D

H3D QKV layouts, i.e. BSH3D, SBH3D, TH3D

enumerator NVTE_HD_2HD

HD_2HD QKV layouts, i.e. BSHD_BS2HD, SBHD_SB2HD, THD_T2HD

enumerator NVTE_HD_H2D

HD_H2D QKV layouts, i.e. BSHD_BSH2D, SBHD_SBH2D, THD_TH2D

enumerator NVTE_HD_HD_HD

HD_HD_HD QKV layouts, i.e. BSHD_BSHD_BSHD, SBHD_SBHD_SBHD, THD_THD_THD

enum NVTE_QKV_Format

QKV formats.

Values:

enumerator NVTE_SBHD

SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD

enumerator NVTE_BSHD

BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD

enumerator NVTE_THD

THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD

enum NVTE_Bias_Type

Bias types.

Values:

enumerator NVTE_NO_BIAS

No bias

enumerator NVTE_PRE_SCALE_BIAS

Bias before scale

enumerator NVTE_POST_SCALE_BIAS

Bias after scale

enumerator NVTE_ALIBI

ALiBi

enum NVTE_Mask_Type

Attention mask types.

Values:

enumerator NVTE_NO_MASK

No masking

enumerator NVTE_PADDING_MASK

Padding attention mask

enumerator NVTE_CAUSAL_MASK

Causal attention mask

enumerator NVTE_PADDING_CAUSAL_MASK

Padding and causal attention mask

enum NVTE_Fused_Attn_Backend

Fused attention backends.

Values:

enumerator NVTE_No_Backend

No supported backend

enumerator NVTE_F16_max512_seqlen

cuDNN-based FP16/BF16 fused attention for sequence lengths <= 512

enumerator NVTE_F16_arbitrary_seqlen

cuDNN-based FP16/BF16 fused attention for any sequence length

enumerator NVTE_FP8

cuDNN-based FP8 fused attention for sequence lengths <= 512

Functions

NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout)

Get QKV layout group for a given QKV layout.

Parameters

qkv_layout[in] QKV layout, e.g. sbh3d.

Returns

qkv layout group, e.g. h3d.

NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout)

Get QKV format for a given QKV layout.

Parameters

qkv_layout[in] QKV layout, e.g. sbh3d.

Returns

qkv format, e.g. sbhd.
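
A minimal usage sketch of the two helpers above, using the example values from their documentation:

    #include <transformer_engine/fused_attn.h>

    NVTE_QKV_Layout layout = NVTE_SBH3D;
    NVTE_QKV_Layout_Group group = nvte_get_qkv_layout_group(layout);  // == NVTE_H3D
    NVTE_QKV_Format format = nvte_get_qkv_format(layout);             // == NVTE_SBHD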

NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim)

Get fused attention backend based on input parameters.

Parameters
  • q_dtype[in] The data type of Tensor Q.

  • kv_dtype[in] The data type of Tensors K, V.

  • qkv_layout[in] The layout of Tensors Q, K, V.

  • bias_type[in] The attention bias type.

  • attn_mask_type[in] The attention mask type.

  • dropout[in] The dropout probability.

  • num_attn_heads[in] The number of heads in Q.

  • num_gqa_groups[in] The number of heads in K, V.

  • max_seqlen_q[in] The max sequence length of Q.

  • max_seqlen_kv[in] The max sequence length of K, V.

  • head_dim[in] The head dimension of Q, K, V.

Returns

The fused attention backend that supports the given input combination, or NVTE_No_Backend if none does.
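
A hedged sketch of backend selection: query which fused-attention backend (if any) supports a concrete configuration, and fall back when none does. The configuration values are illustrative:

    #include <transformer_engine/fused_attn.h>
    #include <transformer_engine/transformer_engine.h>

    bool fused_attn_supported() {
      NVTE_Fused_Attn_Backend backend = nvte_get_fused_attn_backend(
          kNVTEFloat16, kNVTEFloat16,      // Q and K/V data types
          NVTE_SB3HD,                      // QKV layout
          NVTE_NO_BIAS, NVTE_CAUSAL_MASK,  // bias and mask types
          /*dropout=*/0.1f,
          /*num_attn_heads=*/16, /*num_gqa_groups=*/16,
          /*max_seqlen_q=*/512, /*max_seqlen_kv=*/512,
          /*head_dim=*/64);
      // Caller falls back to an unfused attention path when this is false.
      return backend != NVTE_No_Backend;
    }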

void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, const NVTETensor rng_state, size_t max_seqlen, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute dot product attention with packed QKV input.

Computes:

  • P = Q * Transpose(K) + Bias

  • S = ScaleMaskSoftmax(P)

  • D = Dropout(S)

  • O = D * Transpose(V)

Support Matrix:

| backend | precision |        qkv layout       |           bias           |                 mask                  | dropout |  sequence length  | head_dim         |
|   0     | FP16/BF16 |       BS3HD,SB3HD       |   NO/POST_SCALE_BIAS     | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   | <= 512, % 64 == 0 |    64            |
|   1     | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   |  > 512, % 64 == 0 | <= 128, % 8 == 0 |
|   2     |   FP8     |          T3HD           |          NO_BIAS         |               PADDING_MASK            |   Yes   | <= 512, % 64 == 0 |    64            |

Parameters
  • QKV[in] The QKV tensor in packed format, H3D or 3HD.

  • Bias[in] The Bias tensor.

  • S[inout] The S tensor.

  • O[out] The output O tensor.

  • Aux_CTX_Tensors[out] Auxiliary output tensors when training, e.g. M, ZInv, rng_state.

  • cu_seqlens[in] Cumulative sequence lengths, [batch_size + 1].

  • rng_state[in] Seed and offset of CUDA random number generator.

  • max_seqlen[in] Max sequence length used for computing; it may be >= max(seqlen_i) for i = 0, ..., batch_size-1.

  • is_training[in] Whether the operation is in training or inference mode.

  • attn_scale[in] Scaling factor for Q * Transpose(K).

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.
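
A hedged sketch of a forward call for packed QKV in the SB3HD layout. Tensor creation is framework-specific and omitted: each NVTETensor parameter is assumed to already wrap a correctly shaped device buffer, and the NVTETensorPack helpers are assumed to come from this header:

    #include <cmath>
    #include <cuda_runtime.h>
    #include <transformer_engine/fused_attn.h>

    void fused_attn_fwd_sketch(NVTETensor qkv, NVTETensor bias, NVTETensor s,
                               NVTETensor o, NVTETensor cu_seqlens,
                               NVTETensor rng_state, NVTETensor workspace,
                               size_t max_seqlen, size_t head_dim,
                               cudaStream_t stream) {
      NVTETensorPack aux_ctx_tensors;
      nvte_tensor_pack_create(&aux_ctx_tensors);  // receives e.g. M, ZInv, rng_state

      nvte_fused_attn_fwd_qkvpacked(
          qkv, bias, s, o, &aux_ctx_tensors, cu_seqlens, rng_state, max_seqlen,
          /*is_training=*/true,
          /*attn_scale=*/1.0f / std::sqrt(static_cast<float>(head_dim)),
          /*dropout=*/0.1f, NVTE_SB3HD, NVTE_NO_BIAS, NVTE_CAUSAL_MASK,
          workspace, stream);
      // Keep aux_ctx_tensors alive: the backward pass consumes its contents.
    }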

void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, const NVTETensor cu_seqlens, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute the backward of the dot product attention with packed QKV input.

Support Matrix:

| backend | precision |        qkv layout       |           bias           |                 mask                  | dropout |  sequence length  | head_dim         |
|   0     | FP16/BF16 |       BS3HD,SB3HD       |   NO/POST_SCALE_BIAS     | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   | <= 512, % 64 == 0 |    64            |
|   1     | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   |  > 512, % 64 == 0 | <= 128, % 8 == 0 |
|   2     |   FP8     |          T3HD           |          NO_BIAS         |               PADDING_MASK            |   Yes   | <= 512, % 64 == 0 |    64            |

Parameters
  • QKV[in] The QKV tensor in packed format, H3D or 3HD.

  • O[in] The O tensor from forward.

  • dO[in] The gradient of the O tensor.

  • S[in] The S tensor.

  • dP[inout] The gradient of the P tensor.

  • Aux_CTX_Tensors[in] Auxiliary tensors from context when in training mode, e.g. M, ZInv, rng_state.

  • dQKV[out] The gradient of the QKV tensor.

  • dBias[out] The gradient of the Bias tensor.

  • cu_seqlens[in] Cumulative sequence lengths, [batch_size + 1].

  • max_seqlen[in] Max sequence length used for computing; it may be >= max(seqlen_i) for i = 0, ..., batch_size-1.

  • attn_scale[in] Scaling factor for Q * Transpose(K).

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.
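
Continuing the forward sketch above (same translation unit and includes): the backward call consumes the auxiliary pack filled by the forward pass. dQKV, dBias, and dP are assumed pre-allocated with the shapes of QKV, Bias, and S, and the pack-destroy helper is assumed to come from this header:

    void fused_attn_bwd_sketch(NVTETensor qkv, NVTETensor o, NVTETensor d_o,
                               NVTETensor s, NVTETensor d_p,
                               NVTETensorPack *aux_ctx_tensors,
                               NVTETensor d_qkv, NVTETensor d_bias,
                               NVTETensor cu_seqlens, NVTETensor workspace,
                               size_t max_seqlen, float attn_scale,
                               cudaStream_t stream) {
      nvte_fused_attn_bwd_qkvpacked(
          qkv, o, d_o, s, d_p, aux_ctx_tensors, d_qkv, d_bias, cu_seqlens,
          max_seqlen, attn_scale, /*dropout=*/0.1f,
          NVTE_SB3HD, NVTE_NO_BIAS, NVTE_CAUSAL_MASK, workspace, stream);
      nvte_tensor_pack_destroy(aux_ctx_tensors);  // release the aux handles afterwards
    }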

void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute dot product attention with packed KV input.

Computes:

  • P = Q * Transpose(K) + Bias

  • S = ScaleMaskSoftmax(P)

  • D = Dropout(S)

  • O = D * Transpose(V)

Support Matrix:

| backend | precision |                 qkv layout                  |           bias           |                 mask                  | dropout |  sequence length  | head_dim         |
|   0     | FP16/BF16 |            BSHD_BS2HD,SBHD_SB2HD            |   NO/POST_SCALE_BIAS     | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   | <= 512, % 64 == 0 |    64            |
|   1     | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   |  > 512, % 64 == 0 | <= 128, % 8 == 0 |

Parameters
  • Q[in] The Q tensor, in HD layouts.

  • KV[in] The KV tensor, in 2HD or H2D layouts.

  • Bias[in] The Bias tensor.

  • S[inout] The S tensor.

  • O[out] The output O tensor.

  • Aux_CTX_Tensors[out] Auxiliary output tensors when training, e.g. M, ZInv, rng_state.

  • cu_seqlens_q[in] Cumulative sequence lengths for Q, [batch_size + 1].

  • cu_seqlens_kv[in] Cumulative sequence lengths for KV, [batch_size + 1].

  • rng_state[in] Seed and offset of CUDA random number generator.

  • max_seqlen_q[in] Max sequence length used for computing for Q; it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.

  • max_seqlen_kv[in] Max sequence length used for computing for KV; it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.

  • is_training[in] Whether the operation is in training or inference mode.

  • attn_scale[in] Scaling factor for Q * Transpose(K).

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.
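
A hedged cross-attention sketch (same includes as the earlier sketches): Q comes from one sequence stream and the packed KV tensor from another, so the two cumulative-length tensors and max lengths may differ. All handles are assumed to wrap correctly shaped buffers, with aux_ctx_tensors created as before:

    void cross_attn_fwd_sketch(NVTETensor q, NVTETensor kv, NVTETensor bias,
                               NVTETensor s, NVTETensor o,
                               NVTETensorPack *aux_ctx_tensors,
                               NVTETensor cu_seqlens_q, NVTETensor cu_seqlens_kv,
                               NVTETensor rng_state, NVTETensor workspace,
                               size_t max_seqlen_q, size_t max_seqlen_kv,
                               float attn_scale, cudaStream_t stream) {
      nvte_fused_attn_fwd_kvpacked(
          q, kv, bias, s, o, aux_ctx_tensors,
          cu_seqlens_q, cu_seqlens_kv, rng_state,
          max_seqlen_q, max_seqlen_kv,
          /*is_training=*/true, attn_scale, /*dropout=*/0.0f,
          NVTE_BSHD_BS2HD, NVTE_NO_BIAS, NVTE_PADDING_MASK,
          workspace, stream);
    }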

void nvte_fused_attn_bwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute the backward of the dot product attention with packed KV input.

Support Matrix:

| backend | precision |                 qkv layout                  |           bias           |                 mask                  | dropout |  sequence length  | head_dim         |
|   0     | FP16/BF16 |            BSHD_BS2HD,SBHD_SB2HD            |   NO/POST_SCALE_BIAS     | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   | <= 512, % 64 == 0 |    64            |
|   1     | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   |  > 512, % 64 == 0 | <= 128, % 8 == 0 |

Parameters
  • Q[in] The Q tensor, in HD layouts.

  • KV[in] The KV tensor, in H2D or 2HD layouts.

  • O[in] The O tensor from forward.

  • dO[in] The gradient of the O tensor.

  • S[in] The S tensor.

  • dP[inout] The gradient of the P tensor.

  • Aux_CTX_Tensors[in] Auxiliary tensors from context when in training mode, e.g. M, ZInv, rng_state.

  • dQ[out] The gradient of the Q tensor.

  • dKV[out] The gradient of the KV tensor.

  • dBias[out] The gradient of the Bias tensor.

  • cu_seqlens_q[in] Cumulative sequence lengths for Q, [batch_size + 1].

  • cu_seqlens_kv[in] Cumulative sequence lengths for KV, [batch_size + 1].

  • max_seqlen_q[in] Max sequence length used for computing for Q; it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.

  • max_seqlen_kv[in] Max sequence length used for computing for KV; it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.

  • attn_scale[in] Scaling factor for Q * Transpose(K).

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.

void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute dot product attention with separate Q, K and V.

Computes:

  • P = Q * Transpose(K) + Bias

  • S = ScaleMaskSoftmax(P)

  • D = Dropout(S)

  • O = D * Transpose(V)

Support Matrix:

| backend | precision | qkv layout                                                                                                 | bias                     | mask                                  | dropout | sequence length   | head_dim         |
| 0       | FP16/BF16 | BS3HD, SB3HD, BSHD_BS2HD, SBHD_SB2HD                                                                       | NO/POST_SCALE_BIAS       | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes     | <= 512, % 64 == 0 | 64               |
| 1       | FP16/BF16 | BS3HD, SB3HD, BSH3D, SBH3D, BSHD_BS2HD, BSHD_BSH2D, SBHD_SB2HD, SBHD_SBH2D, BSHD_BSHD_BSHD, SBHD_SBHD_SBHD | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes     | > 512, % 64 == 0  | <= 128, % 8 == 0 |
| 2       | FP8       | T3HD                                                                                                       | NO_BIAS                  | PADDING_MASK                          | Yes     | <= 512, % 64 == 0 | 64               |

Parameters
  • Q[in] The Q tensor.

  • K[in] The K tensor.

  • V[in] The V tensor.

  • Bias[in] The Bias tensor.

  • S[inout] The S tensor.

  • O[out] The output O tensor.

  • Aux_CTX_Tensors[out] Auxiliary output tensors when training, e.g. M, ZInv, rng_state.

  • cu_seqlens_q[in] Cumulative sequence lengths for Q, [batch_size + 1].

  • cu_seqlens_kv[in] Cumulative sequence lengths for K and V, [batch_size + 1].

  • rng_state[in] Seed and offset of CUDA random number generator.

  • max_seqlen_q[in] Max sequence length used for computing for Q; it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.

  • max_seqlen_kv[in] Max sequence length used for computing for K and V; it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.

  • is_training[in] Whether the operation is in training or inference mode.

  • attn_scale[in] Scaling factor for Q * Transpose(K).

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensors’ layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.
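
A hedged sketch with separate Q, K and V in the BSHD_BSHD_BSHD layout (same includes as the earlier sketches); with grouped-query attention, K and V may carry fewer heads (num_gqa_groups) than Q. All handles are assumed to wrap correctly shaped buffers:

    void separate_qkv_fwd_sketch(NVTETensor q, NVTETensor k, NVTETensor v,
                                 NVTETensor bias, NVTETensor s, NVTETensor o,
                                 NVTETensorPack *aux_ctx_tensors,
                                 NVTETensor cu_seqlens_q, NVTETensor cu_seqlens_kv,
                                 NVTETensor rng_state, NVTETensor workspace,
                                 size_t max_seqlen_q, size_t max_seqlen_kv,
                                 float attn_scale, cudaStream_t stream) {
      nvte_fused_attn_fwd(
          q, k, v, bias, s, o, aux_ctx_tensors,
          cu_seqlens_q, cu_seqlens_kv, rng_state,
          max_seqlen_q, max_seqlen_kv,
          /*is_training=*/true, attn_scale, /*dropout=*/0.1f,
          NVTE_BSHD_BSHD_BSHD, NVTE_NO_BIAS, NVTE_PADDING_CAUSAL_MASK,
          workspace, stream);
    }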

void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream)

Compute the backward of the dot product attention with separate Q, K and V.

Support Matrix:

| backend | precision | qkv layout                                                                                                 | bias                     | mask                                  | dropout | sequence length   | head_dim         |
| 0       | FP16/BF16 | BS3HD, SB3HD, BSHD_BS2HD, SBHD_SB2HD                                                                       | NO/POST_SCALE_BIAS       | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes     | <= 512, % 64 == 0 | 64               |
| 1       | FP16/BF16 | BS3HD, SB3HD, BSH3D, SBH3D, BSHD_BS2HD, BSHD_BSH2D, SBHD_SB2HD, SBHD_SBH2D, BSHD_BSHD_BSHD, SBHD_SBHD_SBHD | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes     | > 512, % 64 == 0  | <= 128, % 8 == 0 |
| 2       | FP8       | T3HD                                                                                                       | NO_BIAS                  | PADDING_MASK                          | Yes     | <= 512, % 64 == 0 | 64               |

Parameters
  • Q[in] The Q tensor.

  • K[in] The K tensor.

  • V[in] The V tensor.

  • O[in] The O tensor from forward.

  • dO[in] The gradient of the O tensor.

  • S[in] The S tensor.

  • dP[inout] The gradient of the P tensor.

  • Aux_CTX_Tensors[in] Auxiliary tensors from context when in training mode, e.g. M, ZInv, rng_state.

  • dQ[out] The gradient of the Q tensor.

  • dK[out] The gradient of the K tensor.

  • dV[out] The gradient of the V tensor.

  • dBias[out] The gradient of the Bias tensor.

  • cu_seqlens_q[in] Cumulative sequence lengths for Q, [batch_size + 1].

  • cu_seqlens_kv[in] Cumulative sequence lengths for K and V, [batch_size + 1].

  • max_seqlen_q[in] Max sequence length used for computing for Q; it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.

  • max_seqlen_kv[in] Max sequence length used for computing for K and V; it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.

  • attn_scale[in] Scaling factor for Q * Transpose(K).

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensors’ layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.
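
The matching backward sketch (same includes as the earlier sketches): dQ, dK, dV and dBias are assumed pre-allocated with the shapes of Q, K, V and Bias, and aux_ctx_tensors is the pack filled by the forward call above:

    void separate_qkv_bwd_sketch(NVTETensor q, NVTETensor k, NVTETensor v,
                                 NVTETensor o, NVTETensor d_o, NVTETensor s,
                                 NVTETensor d_p, NVTETensorPack *aux_ctx_tensors,
                                 NVTETensor d_q, NVTETensor d_k, NVTETensor d_v,
                                 NVTETensor d_bias, NVTETensor cu_seqlens_q,
                                 NVTETensor cu_seqlens_kv, NVTETensor workspace,
                                 size_t max_seqlen_q, size_t max_seqlen_kv,
                                 float attn_scale, cudaStream_t stream) {
      nvte_fused_attn_bwd(
          q, k, v, o, d_o, s, d_p, aux_ctx_tensors,
          d_q, d_k, d_v, d_bias,
          cu_seqlens_q, cu_seqlens_kv,
          max_seqlen_q, max_seqlen_kv,
          attn_scale, /*dropout=*/0.1f,
          NVTE_BSHD_BSHD_BSHD, NVTE_NO_BIAS, NVTE_PADDING_CAUSAL_MASK,
          workspace, stream);
    }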