fused_attn.h

Enums and functions for fused attention.

Enums

enum NVTE_QKV_Layout

Memory layouts of QKV tensors. S, B, H, D, and T stand for sequence length, batch size, number of heads, head size, and the total number of tokens in a batch, i.e. t = sum(s_i) for i = 0, ..., b-1. SBHD- and BSHD-based layouts are used when the sequences in a batch are of equal length or padded to the same length; THD-based layouts are used when sequences in a batch have different lengths. Paged_KV-based layouts are used for paged attention.
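
To make the naming concrete, the following sketch shows the logical shapes implied by two representative layouts. The helper functions are purely illustrative and are not part of fused_attn.h.

    #include <array>
    #include <cstddef>

    // Illustration only (hypothetical helpers). b = batch size, s = sequence
    // length, h = number of heads, d = head size, t = total tokens in the batch.

    // NVTE_SB3HD: Q, K and V interleaved in one buffer of shape [s, b, 3, h, d].
    constexpr std::array<std::size_t, 5> sb3hd_shape(std::size_t s, std::size_t b,
                                                     std::size_t h, std::size_t d) {
      return {s, b, 3, h, d};
    }

    // NVTE_THD_THD_THD: three separate buffers, each of shape [t, h, d], with
    // sequences of different lengths packed back to back along the t dimension.
    constexpr std::array<std::size_t, 3> thd_shape(std::size_t t, std::size_t h,
                                                   std::size_t d) {
      return {t, h, d};
    }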

Values:

enumerator NVTE_SB3HD

SB3HD layout

enumerator NVTE_SBH3D

SBH3D layout

enumerator NVTE_SBHD_SB2HD

SBHD_SB2HD layout

enumerator NVTE_SBHD_SBH2D

SBHD_SBH2D layout

enumerator NVTE_SBHD_SBHD_SBHD

SBHD_SBHD_SBHD layout

enumerator NVTE_BS3HD

BS3HD layout

enumerator NVTE_BSH3D

BSH3D layout

enumerator NVTE_BSHD_BS2HD

BSHD_BS2HD layout

enumerator NVTE_BSHD_BSH2D

BSHD_BSH2D layout

enumerator NVTE_BSHD_BSHD_BSHD

BSHD_BSHD_BSHD layout

enumerator NVTE_T3HD

T3HD layout

enumerator NVTE_TH3D

TH3D layout

enumerator NVTE_THD_T2HD

THD_T2HD layout

enumerator NVTE_THD_TH2D

THD_TH2D layout

enumerator NVTE_THD_THD_THD

THD_THD_THD layout

enumerator NVTE_SBHD_BSHD_BSHD

SBHD_BSHD_BSHD layout

enumerator NVTE_BSHD_SBHD_SBHD

BSHD_SBHD_SBHD layout

enumerator NVTE_THD_BSHD_BSHD

THD_BSHD_BSHD layout

enumerator NVTE_THD_SBHD_SBHD

THD_SBHD_SBHD layout

enumerator NVTE_Paged_KV_BSHD_BSHD_BSHD

Paged_KV_BSHD_BSHD_BSHD layout

enumerator NVTE_Paged_KV_BSHD_SBHD_SBHD

Paged_KV_BSHD_SBHD_SBHD layout

enumerator NVTE_Paged_KV_SBHD_BSHD_BSHD

Paged_KV_SBHD_BSHD_BSHD layout

enumerator NVTE_Paged_KV_SBHD_SBHD_SBHD

Paged_KV_SBHD_SBHD_SBHD layout

enumerator NVTE_Paged_KV_THD_BSHD_BSHD

Paged_KV_THD_BSHD_BSHD layout

enumerator NVTE_Paged_KV_THD_SBHD_SBHD

Paged_KV_THD_SBHD_SBHD layout

enum NVTE_QKV_Layout_Group

QKV layout groups.

Values:

enumerator NVTE_3HD

3HD QKV layouts, i.e. BS3HD, SB3HD, T3HD

enumerator NVTE_H3D

H3D QKV layouts, i.e. BSH3D, SBH3D, TH3D

enumerator NVTE_HD_2HD

HD_2HD QKV layouts, i.e. BSHD_BS2HD, SBHD_SB2HD, THD_T2HD

enumerator NVTE_HD_H2D

HD_H2D QKV layouts, i.e. BSHD_BSH2D, SBHD_SBH2D, THD_TH2D

enumerator NVTE_HD_HD_HD

HD_HD_HD QKV layouts, i.e. BSHD_BSHD_BSHD, SBHD_SBHD_SBHD, THD_THD_THD

enumerator NVTE_Paged_KV_HD_HD_HD

Paged_KV_HD_HD_HD QKV layouts, e.g. Paged_KV_BSHD_BSHD_BSHD, Paged_KV_THD_SBHD_SBHD

enum NVTE_QKV_Format

QKV formats.

Values:

enumerator NVTE_SBHD

SBHD QKV format, i.e. SB3HD, SBH3D, SBHD_SB2HD, SBHD_SBH2D, SBHD_SBHD_SBHD, Paged_KV_SBHD_SBHD_SBHD

enumerator NVTE_BSHD

BSHD QKV format, i.e. BS3HD, BSH3D, BSHD_BS2HD, BSHD_BSH2D, BSHD_BSHD_BSHD, Paged_KV_BSHD_BSHD_BSHD

enumerator NVTE_THD

THD QKV format, i.e. T3HD, TH3D, THD_T2HD, THD_TH2D, THD_THD_THD

enumerator NVTE_BSHD_2SBHD

BSHD format for Q and SBHD format for KV, i.e. BSHD_SBHD_SBHD, Paged_KV_BSHD_SBHD_SBHD

enumerator NVTE_SBHD_2BSHD

SBHD format for Q and BSHD format for KV, i.e. SBHD_BSHD_BSHD, Paged_KV_SBHD_BSHD_BSHD

enumerator NVTE_THD_2BSHD

THD format for Q and BSHD format for KV, i.e. THD_BSHD_BSHD, Paged_KV_THD_BSHD_BSHD

enumerator NVTE_THD_2SBHD

THD format for Q and SBHD format for KV, i.e. THD_SBHD_SBHD, Paged_KV_THD_SBHD_SBHD

enum NVTE_Bias_Type

Bias types.

Values:

enumerator NVTE_NO_BIAS

No bias

enumerator NVTE_PRE_SCALE_BIAS

Bias before scale

enumerator NVTE_POST_SCALE_BIAS

Bias after scale

enumerator NVTE_ALIBI

ALiBi

enum NVTE_Mask_Type

Attention mask types.

Values:

enumerator NVTE_NO_MASK

No masking

enumerator NVTE_PADDING_MASK

Padding attention mask

enumerator NVTE_CAUSAL_MASK

Causal attention mask (aligned to the top left corner)

enumerator NVTE_PADDING_CAUSAL_MASK

Padding and causal attention mask (aligned to the top left corner)

enumerator NVTE_CAUSAL_BOTTOM_RIGHT_MASK

Causal attention mask (aligned to the bottom right corner)

enumerator NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK

Padding and causal attention mask (aligned to the bottom right corner)

enum NVTE_Fused_Attn_Backend

Fused attention backends.

Values:

enumerator NVTE_No_Backend

No supported backend

enumerator NVTE_F16_max512_seqlen

cuDNN-based FP16/BF16 fused attention for <= 512 sequence length

enumerator NVTE_F16_arbitrary_seqlen

cuDNN-based FP16/BF16 fused attention for any sequence length

enumerator NVTE_FP8

cuDNN-based FP8 fused attention for <= 512 sequence length

Functions

NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout)

Get QKV layout group for a given QKV layout.

Parameters:

qkv_layout[in] QKV layout, e.g. sbh3d.

Returns:

qkv layout group, e.g. h3d.

NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout)

Get QKV format for a given QKV layout.

Parameters:

qkv_layout[in] QKV layout, e.g. sbh3d.

Returns:

qkv format, e.g. sbhd.

NVTE_QKV_Format nvte_get_q_format(NVTE_QKV_Layout qkv_layout)

Get Q format for a given QKV layout.

Parameters:

qkv_layout[in] QKV layout, e.g. sbhd_bshd_bshd.

Returns:

q format, e.g. sbhd.

NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout)

Get KV format for a given QKV layout.

Parameters:

qkv_layout[in] QKV layout, e.g. sbhd_bshd_bshd.

Returns:

kv format, e.g. bshd.
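
Taken together, a minimal sketch of these four helpers, reusing the examples above (the transformer_engine include path is an assumption; expected results are shown in comments):

    #include <transformer_engine/fused_attn.h>

    // Group and overall format for the sbh3d example:
    NVTE_QKV_Layout_Group group = nvte_get_qkv_layout_group(NVTE_SBH3D);   // NVTE_H3D
    NVTE_QKV_Format format = nvte_get_qkv_format(NVTE_SBH3D);              // NVTE_SBHD

    // Per-tensor formats for the sbhd_bshd_bshd example:
    NVTE_QKV_Format q_format = nvte_get_q_format(NVTE_SBHD_BSHD_BSHD);     // NVTE_SBHD
    NVTE_QKV_Format kv_format = nvte_get_kv_format(NVTE_SBHD_BSHD_BSHD);   // NVTE_BSHD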

NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right)

Get fused attention backend based on input parameters.

Parameters:
  • q_dtype[in] The data type of Tensor Q.

  • kv_dtype[in] The data type of Tensors K, V.

  • qkv_layout[in] The layout of Tensors Q, K, V.

  • bias_type[in] The attention bias type.

  • attn_mask_type[in] The attention mask type.

  • dropout[in] The dropout probability.

  • num_attn_heads[in] The number of heads in Q.

  • num_gqa_groups[in] The number of heads in K, V.

  • max_seqlen_q[in] The sequence length of Q.

  • max_seqlen_kv[in] The sequence length of K, V.

  • head_dim_qk[in] The head dimension of Q, K.

  • head_dim_v[in] The head dimension of V.

  • window_size_left[in] Sliding window size (the left half).

  • window_size_right[in] Sliding window size (the right half).
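
For example, a hedged sketch of querying the backend for a BF16 configuration. The concrete values are illustrative; kNVTEBFloat16 comes from the NVTEDType enum, and window sizes of -1 are assumed to disable sliding-window attention:

    #include <transformer_engine/fused_attn.h>
    #include <transformer_engine/transformer_engine.h>

    // BF16, separate Q/K/V in BSHD layout, causal mask, no bias, no dropout,
    // 16 heads without GQA (num_gqa_groups == num_attn_heads), head size 64.
    NVTE_Fused_Attn_Backend backend = nvte_get_fused_attn_backend(
        kNVTEBFloat16, kNVTEBFloat16, NVTE_BSHD_BSHD_BSHD,
        NVTE_NO_BIAS, NVTE_CAUSAL_MASK,
        /*dropout=*/0.0f, /*num_attn_heads=*/16, /*num_gqa_groups=*/16,
        /*max_seqlen_q=*/2048, /*max_seqlen_kv=*/2048,
        /*head_dim_qk=*/64, /*head_dim_v=*/64,
        /*window_size_left=*/-1, /*window_size_right=*/-1);
    // Per the backend list above, a supported 2048-token BF16 case would map to
    // NVTE_F16_arbitrary_seqlen; NVTE_No_Backend signals an unsupported combination.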

void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream)

Compute dot product attention with packed QKV input.

Computes:

  • P = Q * Transpose(K) + Bias

  • S = ScaleMaskSoftmax(P)

  • D = Dropout(S)

  • O = D * Transpose(V)

Support Matrix:

| backend | precision |        qkv layout       |           bias           |                 mask                  | dropout |  sequence length  | head_dim         |
|   0     | FP16/BF16 |       BS3HD,SB3HD       |   NO/POST_SCALE_BIAS     | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   | <= 512, % 64 == 0 |    64            |
|   1     | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   |  > 512, % 64 == 0 | <= 128, % 8 == 0 |
|   2     |   FP8     |          T3HD           |          NO_BIAS         |               PADDING_MASK            |   Yes   | <= 512, % 64 == 0 |    64            |

Notes:

Tensor cu_seqlens_padded helps identify the correct offsets of the different sequences in tensors Q, K, V and O. When the QKV format (nvte_get_qkv_format(qkv_layout)) is bshd or sbhd, this offset tensor is not used in the attention calculation and can be set to an empty NVTETensor. When the QKV format is thd, the tensor must follow these rules: when there is no padding between sequences, it should be equal to cu_seqlens; when there is padding between sequences, users are responsible for adjusting the offsets as needed. For example, a tensor holding 4 sequences, [a, PAD, b, b, c, PAD, PAD, d, d], should have cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].
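
Spelling the note's example out as code (a sketch; int32 is assumed for the sequence-length tensors):

    #include <cstdint>

    // 4 sequences packed in THD format with padding:
    //   tokens:          [a, PAD, b, b, c, PAD, PAD, d, d]
    //   actual lengths:  1, 2, 1, 2    padded lengths: 2, 2, 3, 2
    const int32_t cu_seqlens[5]        = {0, 1, 3, 4, 6};  // cumulative actual lengths
    const int32_t cu_seqlens_padded[5] = {0, 2, 4, 7, 9};  // cumulative padded offsets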

Parameters:
  • QKV[in] The QKV tensor in packed format, H3D or 3HD.

  • Bias[in] The Bias tensor.

  • S[inout] The S tensor.

  • O[out] The output O tensor.

  • Aux_CTX_Tensors[out] Auxiliary output tensors when training, e.g. M, ZInv, rng_state.

  • cu_seqlens[in] Cumulative sequence lengths, [batch_size + 1].

  • cu_seqlens_padded[in] Cumulative sequence offsets for QKV, [batch_size + 1].

  • rng_state[in] Seed and offset of CUDA random number generator.

  • max_seqlen[in] Max sequence length used for computing; it may be >= max(seqlen_i) for i = 0, ..., batch_size-1.

  • is_training[in] Whether this is in training mode or inference.

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • window_size_left[in] Sliding window size (the left half).

  • window_size_right[in] Sliding window size (the right half).

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.

void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream)

Compute the backward of the dot product attention with packed QKV input.

Support Matrix:

| backend | precision |        qkv layout       |           bias           |                 mask                  | dropout |  sequence length  | head_dim         |
|   0     | FP16/BF16 |       BS3HD,SB3HD       |   NO/POST_SCALE_BIAS     | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   | <= 512, % 64 == 0 |    64            |
|   1     | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   |  > 512, % 64 == 0 | <= 128, % 8 == 0 |
|   2     |   FP8     |          T3HD           |          NO_BIAS         |               PADDING_MASK            |   Yes   | <= 512, % 64 == 0 |    64            |

Notes:

Tensor cu_seqlens_padded helps identify the correct offsets of the different sequences in tensors Q, K, V and O. When the QKV format (nvte_get_qkv_format(qkv_layout)) is bshd or sbhd, this offset tensor is not used in the attention calculation and can be set to an empty NVTETensor. When the QKV format is thd, the tensor must follow these rules: when there is no padding between sequences, it should be equal to cu_seqlens; when there is padding between sequences, users are responsible for adjusting the offsets as needed. For example, a tensor holding 4 sequences, [a, PAD, b, b, c, PAD, PAD, d, d], should have cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].

Parameters:
  • QKV[in] The QKV tensor in packed format, H3D or 3HD.

  • O[in] The O tensor from forward.

  • dO[in] The gradient of the O tensor.

  • S[in] The S tensor.

  • dP[inout] The gradient of the P tensor.

  • Aux_CTX_Tensors[in] Auxiliary tensors from context when in training mode, e.g. M, ZInv, rng_state.

  • dQKV[out] The gradient of the QKV tensor.

  • dBias[out] The gradient of the Bias tensor.

  • cu_seqlens[in] Cumulative sequence lengths, [batch_size + 1].

  • cu_seqlens_padded[in] Cumulative sequence offsets for QKV, [batch_size + 1].

  • max_seqlen[in] Max sequence length used for computing; it may be >= max(seqlen_i) for i = 0, ..., batch_size-1.

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • window_size_left[in] Sliding window size (the left half).

  • window_size_right[in] Sliding window size (the right half).

  • deterministic[in] Whether to execute with deterministic behaviours.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.

void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream)

Compute dot product attention with packed KV input.

Computes:

  • P = Q * Transpose(K) + Bias

  • S = ScaleMaskSoftmax(P)

  • D = Dropout(S)

  • O = D * Transpose(V)

Support Matrix:

| backend | precision |                 qkv layout                  |           bias           |                 mask                  | dropout |  sequence length  | head_dim         |
|   0     | FP16/BF16 |            BSHD_BS2HD,SBHD_SB2HD            |   NO/POST_SCALE_BIAS     | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   | <= 512, % 64 == 0 |    64            |
|   1     | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   |  > 512, % 64 == 0 | <= 128, % 8 == 0 |

Notes:

Tensors cu_seqlens_q_padded and cu_seqlens_kv_padded help identify the correct offsets of the different sequences in tensors Q, K, V and O. When the QKV format (nvte_get_qkv_format(qkv_layout)) is bshd or sbhd, these offset tensors are not used in the attention calculation and can be set to empty NVTETensors. When the QKV format is thd, the tensors must follow these rules: when there is no padding between sequences, they should be equal to cu_seqlens_q and cu_seqlens_kv respectively; when there is padding between sequences, users are responsible for adjusting the offsets as needed. For example, a tensor holding 4 sequences, [a, PAD, b, b, c, PAD, PAD, d, d], should have cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].

Parameters:
  • Q[in] The Q tensor, in HD layouts.

  • KV[in] The KV tensor, in 2HD or H2D layouts.

  • Bias[in] The Bias tensor.

  • S[inout] The S tensor.

  • O[out] The output O tensor.

  • Aux_CTX_Tensors[out] Auxiliary output tensors when training, e.g. M, ZInv, rng_state.

  • cu_seqlens_q[in] Cumulative sequence lengths for Q, [batch_size + 1].

  • cu_seqlens_kv[in] Cumulative sequence lengths for KV, [batch_size + 1].

  • cu_seqlens_q_padded[in] Cumulative sequence offsets for Q, [batch_size + 1].

  • cu_seqlens_kv_padded[in] Cumulative sequence offsets for KV, [batch_size + 1].

  • page_table_k[in] Page table for K cache, [batch_size, max_pages_per_seq_k].

  • page_table_v[in] Page table for V cache, [batch_size, max_pages_per_seq_v].

  • rng_state[in] Seed and offset of CUDA random number generator.

  • max_seqlen_q[in] Max sequence length used for computing for Q; it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.

  • max_seqlen_kv[in] Max sequence length used for computing for KV; it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.

  • is_training[in] Whether this is in training mode or inference.

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • window_size_left[in] Sliding window size (the left half).

  • window_size_right[in] Sliding window size (the right half).

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.

void nvte_fused_attn_bwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream)

Compute the backward of the dot product attention with packed KV input.

Support Matrix:

| backend | precision |                 qkv layout                  |           bias           |                 mask                  | dropout |  sequence length  | head_dim         |
|   0     | FP16/BF16 |            BSHD_BS2HD,SBHD_SB2HD            |   NO/POST_SCALE_BIAS     | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   | <= 512, % 64 == 0 |    64            |
|   1     | FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK |   Yes   |  > 512, % 64 == 0 | <= 128, % 8 == 0 |

Notes:

Tensors cu_seqlens_q_padded and cu_seqlens_kv_padded help identify the correct offsets of the different sequences in tensors Q, K, V and O. When the QKV format (nvte_get_qkv_format(qkv_layout)) is bshd or sbhd, these offset tensors are not used in the attention calculation and can be set to empty NVTETensors. When the QKV format is thd, the tensors must follow these rules: when there is no padding between sequences, they should be equal to cu_seqlens_q and cu_seqlens_kv respectively; when there is padding between sequences, users are responsible for adjusting the offsets as needed. For example, a tensor holding 4 sequences, [a, PAD, b, b, c, PAD, PAD, d, d], should have cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].

Parameters:
  • Q[in] The Q tensor, in HD layouts.

  • KV[in] The KV tensor, in H2D or 2HD layouts.

  • O[in] The O tensor from forward.

  • dO[in] The gradient of the O tensor.

  • S[in] The S tensor.

  • dP[inout] The gradient of the P tensor.

  • Aux_CTX_Tensors[in] Auxiliary tensors from context when in training mode, e.g. M, ZInv, rng_state.

  • dQ[out] The gradient of the Q tensor.

  • dKV[out] The gradient of the KV tensor.

  • dBias[out] The gradient of the Bias tensor.

  • cu_seqlens_q[in] Cumulative sequence lengths for Q, [batch_size + 1].

  • cu_seqlens_kv[in] Cumulative sequence lengths for KV, [batch_size + 1].

  • cu_seqlens_q_padded[in] Cumulative sequence offsets for Q, [batch_size + 1].

  • cu_seqlens_kv_padded[in] Cumulative sequence offsets for KV, [batch_size + 1].

  • max_seqlen_q[in] Max sequence length used for computing for Q; it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.

  • max_seqlen_kv[in] Max sequence length used for computing for KV; it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensor’s layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • window_size_left[in] Sliding window size (the left half).

  • window_size_right[in] Sliding window size (the right half).

  • deterministic[in] Whether to execute with deterministic behaviours.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.

void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream)

Compute dot product attention with separate Q, K and V.

Computes:

  • P = Q * Transpose(K) + Bias

  • S = ScaleMaskSoftmax(P)

  • D = Dropout(S)

  • O = D * Transpose(V)

Support Matrix:

| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
|    0    | FP16/BF16 | BS3HD, SB3HD, BSHD_BS2HD, SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
|    1    | FP16/BF16 | BS3HD, SB3HD, BSH3D, SBH3D, BSHD_BS2HD, BSHD_BSH2D, SBHD_SB2HD, SBHD_SBH2D, BSHD_BSHD_BSHD, SBHD_SBHD_SBHD | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
|    2    |    FP8    | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |

Notes:

Tensors cu_seqlens_q_padded and cu_seqlens_kv_padded help identify the correct offsets of the different sequences in tensors Q, K, V and O. When the QKV format (nvte_get_qkv_format(qkv_layout)) is bshd or sbhd, these offset tensors are not used in the attention calculation and can be set to empty NVTETensors. When the QKV format is thd, the tensors must follow these rules: when there is no padding between sequences, they should be equal to cu_seqlens_q and cu_seqlens_kv respectively; when there is padding between sequences, users are responsible for adjusting the offsets as needed. For example, a tensor holding 4 sequences, [a, PAD, b, b, c, PAD, PAD, d, d], should have cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].

Parameters:
  • Q[in] The Q tensor.

  • K[in] The K tensor.

  • V[in] The V tensor.

  • Bias[in] The Bias tensor.

  • S[inout] The S tensor.

  • O[out] The output O tensor.

  • Aux_CTX_Tensors[out] Auxiliary output tensors when training, e.g. M, ZInv, rng_state.

  • cu_seqlens_q[in] Cumulative sequence lengths for Q, [batch_size + 1].

  • cu_seqlens_kv[in] Cumulative sequence lengths for K and V, [batch_size + 1].

  • cu_seqlens_q_padded[in] Cumulative sequence offsets for Q, [batch_size + 1].

  • cu_seqlens_kv_padded[in] Cumulative sequence offsets for KV, [batch_size + 1].

  • page_table_k[in] Page table for K cache, [batch_size, max_pages_per_seq_k].

  • page_table_v[in] Page table for V cache, [batch_size, max_pages_per_seq_v].

  • rng_state[in] Seed and offset of CUDA random number generator.

  • max_seqlen_q[in] Max sequence length used for computing for Q; it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.

  • max_seqlen_kv[in] Max sequence length used for computing for K and V; it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.

  • is_training[in] Whether this is in training mode or inference.

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensors’ layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • window_size_left[in] Sliding window size (the left half).

  • window_size_right[in] Sliding window size (the right half).

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.
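
As an end-to-end sketch of the forward call (tensor construction is elided; passing a null handle where the note above allows an "empty NVTETensor" is an assumption):

    #include <transformer_engine/fused_attn.h>

    // Forward pass with separate Q/K/V in BSHD_BSHD_BSHD layout. All NVTETensor
    // arguments are assumed to have been created and sized by the caller.
    void fused_attn_fwd_sketch(NVTETensor q, NVTETensor k, NVTETensor v,
                               NVTETensor s, NVTETensor o, NVTETensorPack *aux_ctx,
                               NVTETensor cu_seqlens_q, NVTETensor cu_seqlens_kv,
                               NVTETensor rng_state, NVTETensor workspace,
                               cudaStream_t stream) {
      NVTETensor empty = nullptr;  // bshd format: padded offsets and page tables unused
      nvte_fused_attn_fwd(
          q, k, v, /*Bias=*/empty, s, o, aux_ctx,
          cu_seqlens_q, cu_seqlens_kv,
          /*cu_seqlens_q_padded=*/empty, /*cu_seqlens_kv_padded=*/empty,
          /*page_table_k=*/empty, /*page_table_v=*/empty, rng_state,
          /*max_seqlen_q=*/2048, /*max_seqlen_kv=*/2048,
          /*is_training=*/true, /*attn_scale=*/0.125f,  // 1/sqrt(head_dim = 64)
          /*dropout=*/0.0f, NVTE_BSHD_BSHD_BSHD, NVTE_NO_BIAS, NVTE_CAUSAL_MASK,
          /*window_size_left=*/-1, /*window_size_right=*/-1, workspace, stream);
    }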

void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream)

Compute the backward of the dot product attention with separate Q, K and V.

Support Matrix:

| backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim |
|    0    | FP16/BF16 | BS3HD, SB3HD, BSHD_BS2HD, SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 |
|    1    | FP16/BF16 | BS3HD, SB3HD, BSH3D, SBH3D, BSHD_BS2HD, BSHD_BSH2D, SBHD_SB2HD, SBHD_SBH2D, BSHD_BSHD_BSHD, SBHD_SBHD_SBHD | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 |
|    2    |    FP8    | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 |

Notes:

Tensors cu_seqlens_q_padded and cu_seqlens_kv_padded help identify the correct offsets of the different sequences in tensors Q, K, V and O. When the QKV format (nvte_get_qkv_format(qkv_layout)) is bshd or sbhd, these offset tensors are not used in the attention calculation and can be set to empty NVTETensors. When the QKV format is thd, the tensors must follow these rules: when there is no padding between sequences, they should be equal to cu_seqlens_q and cu_seqlens_kv respectively; when there is padding between sequences, users are responsible for adjusting the offsets as needed. For example, a tensor holding 4 sequences, [a, PAD, b, b, c, PAD, PAD, d, d], should have cu_seqlens = [0, 1, 3, 4, 6] and cu_seqlens_padded = [0, 2, 4, 7, 9].

Parameters:
  • Q[in] The Q tensor.

  • K[in] The K tensor.

  • V[in] The V tensor.

  • O[in] The O tensor from forward.

  • dO[in] The gradient of the O tensor.

  • S[in] The S tensor.

  • dP[inout] The gradient of the P tensor.

  • Aux_CTX_Tensors[in] Auxiliary tensors from context when in training mode, e.g. M, ZInv, rng_state.

  • dQ[out] The gradient of the Q tensor.

  • dK[out] The gradient of the K tensor.

  • dV[out] The gradient of the V tensor.

  • dBias[out] The gradient of the Bias tensor.

  • cu_seqlens_q[in] Cumulative sequence lengths for Q, [batch_size + 1].

  • cu_seqlens_kv[in] Cumulative sequence lengths for K and V, [batch_size + 1].

  • cu_seqlens_q_padded[in] Cumulative sequence offsets for Q, [batch_size + 1].

  • cu_seqlens_kv_padded[in] Cumulative sequence offsets for KV, [batch_size + 1].

  • max_seqlen_q[in] Max sequence length used for computing for Q; it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.

  • max_seqlen_kv[in] Max sequence length used for computing for K and V; it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.

  • attn_scale[in] Scaling factor for Q * K.T.

  • dropout[in] Dropout probability.

  • qkv_layout[in] QKV tensors’ layout.

  • bias_type[in] Bias type.

  • attn_mask_type[in] Attention mask type.

  • window_size_left[in] Sliding window size (the left half).

  • window_size_right[in] Sliding window size (the right half).

  • deterministic[in] Whether to execute with deterministic behaviours.

  • workspace[in] Workspace tensor.

  • stream[in] CUDA stream used for this operation.

void nvte_populate_rng_state_async(NVTETensor rng_state_dst, const NVTETensor seed, size_t q_max_seqlen, size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend, cudaStream_t stream)

Update the RNG state with the seed and calculated offset.

Warning

This API is experimental and subject to change.

Parameters:
  • rng_state_dst[out] RNG state to store seed and offset.

  • seed[in] Seed for RNG state.

  • q_max_seqlen[in] Max sequence length used for computing for Q; it may be >= max(seqlen_q_i) for i = 0, ..., batch_size-1.

  • kv_max_seqlen[in] Max sequence length used for computing for K and V; it may be >= max(seqlen_kv_i) for i = 0, ..., batch_size-1.

  • backend[in] Fused attention backend.

  • stream[in] CUDA stream used for this operation.

uint32_t nvte_get_runtime_num_segments(NVTETensor cu_seqlen, NVTETensor workspace, size_t len, cudaStream_t stream)

Get the number of sequence segments at runtime, given cumulative sequence lengths.

Warning

This API is experimental and subject to change.

Parameters:
  • cu_seqlen[in] Cumulative sequence lengths, [batch_size + 1].

  • workspace[in] Workspace tensor.

  • len[in] batch_size x sequence_length.

  • stream[in] CUDA stream used for this operation.

void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t *seed_ptr, uint64_t seed_val, int64_t *offset_ptr, uint64_t offset_val, uint32_t offset_intragraph, cudaStream_t stream)

Set the seed and offset for RNG state.

Warning

This API is experimental and subject to change.

Parameters:
  • rng_state_ptr[out] A size 2 array storing the RNG’s seed and offset respectively.

  • captured[in] Whether a CUDA graph is being captured.

  • seed_ptr[in] Seed pointer.

  • seed_val[in] Seed value.

  • offset_ptr[in] Offset pointer.

  • offset_val[in] Offset value.

  • offset_intragraph[in] Intragraph offset in RNG states. For use with CUDA Graphs.

  • stream[in] CUDA stream used for this operation.

void nvte_copy_to_kv_cache(NVTETensor new_k, NVTETensor new_v, NVTETensor k_cache, NVTETensor v_cache, NVTETensor page_table, NVTETensor cu_new_lens, NVTETensor cu_cached_lens, NVTE_QKV_Format qkv_format, int b, int max_ctx_len, int max_seq_len, int max_pages_per_seq, int is_non_paged, cudaStream_t stream)

Copy keys and values into the KV cache.

Warning

This API is experimental and subject to change.

Parameters:
  • new_k[in] Key tensor.

  • new_v[in] Value tensor.

  • k_cache[out] Key cache.

  • v_cache[out] Value cache.

  • page_table[in] Page table for K cache, [batch_size, max_pages_per_seq].

  • cu_new_lens[in] Cumulative sequence lengths.

  • cu_cached_lens[in] Cached cumulative sequence lengths.

  • qkv_format[in] QKV format, e.g. sbhd.

  • b[in] Batch size.

  • max_ctx_len[in] Maximum context length.

  • max_seq_len[in] Maximum sequence length.

  • max_pages_per_seq[in] Maximum number of pages per sequence.

  • is_non_paged[in] Whether the KV cache is non-paged.

  • stream[in] CUDA stream used for this operation.

void nvte_cp_thd_read_half_tensor(const NVTETensor &tensor, const NVTETensor &cu_seqlens, NVTETensor half, int half_idx, cudaStream_t stream)

Extract the first half (half_idx=0) or second half (half_idx=1) of a THD tensor.

Warning

This API is experimental and subject to change.

Parameters:
  • tensor[in] Input tensor.

  • cu_seqlens[in] Cumulative sequence lengths, [batch_size + 1].

  • half[out] Output tensor.

  • half_idx[in] Whether to read the first (0) or second (1) half of the input tensor.

  • stream[in] CUDA stream used for this operation.

void nvte_cp_thd_second_half_lse_correction(NVTETensor lse, const NVTETensor &lse_per_step, const NVTETensor &cu_seqlens, int lse_packed, cudaStream_t stream)

Correct the second half of the softmax LSE (LogSumExp) for context parallelism.

Warning

This API is experimental and subject to change.

Parameters:
  • lse[out] Output tensor.

  • lse_per_step[in] Input tensor.

  • cu_seqlens[in] Cumulative sequence lengths, [batch_size + 1].

  • lse_packed[in] Whether or not lse_per_step is packed.

  • stream[in] CUDA stream used for this operation.

void nvte_cp_thd_read_second_half_lse(const NVTETensor &lse, const NVTETensor &cu_seqlens, NVTETensor half_lse, int lse_packed, int second_half_lse_seqlen, cudaStream_t stream)

Read the second half of the softmax LSE (LogSumExp) for context parallelism.

Warning

This API is experimental and subject to change.

Parameters:
  • lse[in] Input tensor.

  • cu_seqlens[in] Cumulative sequence lengths, [batch_size + 1].

  • half_lse[out] Output tensor.

  • lse_packed[in] Whether or not the softmax LSE is in packed format.

  • second_half_lse_seqlen[in] Sequence length of the second half of the softmax LSE.

  • stream[in] CUDA stream used for this operation.

void nvte_cp_thd_out_correction(NVTETensor out, const NVTETensor &out_per_step, const NVTETensor &lse, const NVTETensor &lse_per_step, const NVTETensor &cu_seqlens, int only_second_half, int lse_packed, cudaStream_t stream)

Correct the THD format output of context parallelism in forward pass.

Warning

This API is experimental and subject to change.

Parameters:
  • out[out] Output tensor.

  • out_per_step[in] THD format output of context parallelism in forward pass.

  • lse[in] Softmax LSE.

  • lse_per_step[in] Softmax LSE per step.

  • cu_seqlens[in] Cumulative sequence lengths, [batch_size + 1].

  • only_second_half[in] Whether or not to correct only the second half.

  • lse_packed[in] Whether or not the softmax LSE is in packed format.

  • stream[in] CUDA stream used for this operation.

void nvte_cp_thd_grad_correction(NVTETensor grad, const NVTETensor &grad_per_step, const NVTETensor &cu_seqlens, const char *first_half, const char *second_half, cudaStream_t stream)

Correct the THD format gradients of context parallelism in backward pass.

Warning

This API is experimental and subject to change.

Parameters:
  • grad[out] Output tensor.

  • grad_per_step[in] THD format gradient of context parallelism.

  • cu_seqlens[in] Cumulative sequence lengths, [batch_size + 1].

  • first_half[in] Correction operation for the first half; one of “add”, “copy”, “none”.

  • second_half[in] Correction operation for the second half; one of “add”, “copy”, “none”. Must be different from first_half.

  • stream[in] CUDA stream used for this operation.

void nvte_cp_thd_get_partitioned_indices(const NVTETensor &cu_seqlens, NVTETensor output, int total_tokens, int world_size, int rank, cudaStream_t stream)

Generate partitioned indices for inputs in THD format.

Warning

This API is experimental and subject to change.

Parameters:
  • cu_seqlens[in] Cumulative sequence lengths, [batch_size + 1].

  • output[out] Output tensor.

  • total_tokens[in] Total number of tokens.

  • world_size[in] Total number of devices for context parallelism.

  • rank[in] Device ID for current device.

  • stream[in] CUDA stream used for this operation.

void nvte_convert_thd_to_bshd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor, int b, int max_seq_len, cudaStream_t stream)

Convert tensor from THD to BSHD format.

Warning

This API is experimental and subject to change.

Parameters:
  • tensor[in] Input tensor.

  • cu_seqlens[in] Cumulative sequence lengths, [batch_size + 1].

  • new_tensor[out] Output tensor.

  • b[in] Batch size.

  • max_seq_len[in] Maximum sequence length.

  • stream[in] CUDA stream used for this operation.

void nvte_convert_bshd_to_thd(NVTETensor tensor, NVTETensor cu_seqlens, NVTETensor new_tensor, int t, cudaStream_t stream)

Convert tensor from BSHD to THD format.

Warning

This API is experimental and subject to change.

Parameters:
  • tensor[in] Input tensor.

  • cu_seqlens[in] Cumulative sequence lengths, [batch_size + 1].

  • new_tensor[out] Output tensor.

  • t[in] Total number of tokens in the batch.

  • stream[in] CUDA stream used for this operation.

void nvte_prepare_flash_attn_fwd(NVTETensor qkvi, NVTETensor qkv, cudaStream_t stream)

Prepare QKV tensor for Flash Attention forward kernel.

Warning

This API is experimental and subject to change.

Parameters:
  • qkvi[in] Input tensor.

  • qkv[out] Output tensor.

  • stream[in] CUDA stream used for this operation.

void nvte_prepare_flash_attn_bwd(NVTETensor q, NVTETensor k, NVTETensor v, NVTETensor qkv, cudaStream_t stream)

Prepare QKV tensor for Flash Attention backward kernel.

Warning

This API is experimental and subject to change.

Parameters:
  • q[in] Input query tensor.

  • k[in] Input key tensor.

  • v[in] Input value tensor.

  • qkv[out] Output tensor.

  • stream[in] CUDA stream used for this operation.