Attention Is All You Need!

The core idea behind Transformer models is the attention mechanism [1]. It identifies the correlations between words, selects the most important parts of the sentence to focus on, and captures meaningful patterns and dependencies in the data. Figure 1 shows a typical attention mechanism, where the pre-softmax operations can be a combination of scaling, bias, and masking, while the post-softmax operation is often just dropout.


Figure 1: Dot product attention.
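
For reference, here is a minimal PyTorch sketch of the computation in Figure 1 (scale, bias and mask applied before the softmax, dropout after it). It is illustrative only and not Transformer Engine’s implementation; the tensor shapes follow the bshd convention used later in this document.

import math
import torch
import torch.nn.functional as F

def dot_product_attention(q, k, v, bias=None, mask=None, dropout_p=0.0):
    # q, k, v: [batch, seqlen, num_heads, head_dim] (bshd)
    q, k, v = (x.transpose(1, 2) for x in (q, k, v))           # -> [b, h, s, d]
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    if bias is not None:                                       # pre-softmax bias
        scores = scores + bias
    if mask is not None:                                       # True means "mask out"
        scores = scores.masked_fill(mask, float("-inf"))
    probs = F.dropout(torch.softmax(scores, dim=-1), p=dropout_p)  # post-softmax dropout
    return torch.matmul(probs, v).transpose(1, 2)              # back to [b, s, h, d]

q = k = v = torch.randn(2, 512, 16, 64)
out = dot_product_attention(q, k, v)                           # [2, 512, 16, 64]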

Transformer Engine supports the calculation of dot product attention in three frameworks: PyTorch, JAX and PaddlePaddle. The API for each framework is transformer_engine.pytorch.DotProductAttention, transformer_engine.jax.flax.DotProductAttention, and transformer_engine.paddle.DotProductAttention, respectively.

1. Attention Backends

Transformer Engine provides multiple attention backends for each supported framework. The framework-native backends provide a robust baseline, while the fused, GPU-optimized implementations, such as the flash-attention and cuDNN attention backends in PyTorch, offer higher performance. The framework-native backends are often named with “unfused”, while the more optimized backends are named with “fused” or “flash”.

| Framework | Backend (Module Name) | Module Location |
|---|---|---|
| PyTorch | cuDNN attention (FusedAttention) | transformer_engine.pytorch.attention |
| | flash-attention (FlashAttention) | |
| | PyTorch-native attention (UnfusedDotProductAttention) | |
| JAX | cuDNN attention (_FusedDotProductAttention) | transformer_engine.jax.flax.transformer |
| | JAX-native attention (_UnfusedDotProductAttention) | |
| PaddlePaddle | cuDNN attention (_te_forward) | transformer_engine.paddle.layer.attention |
| | PaddlePaddle-native attention (_pd_forward) | |

1.1 Flash vs. Non-Flash

The attention calculation has computational and memory complexities that are quadratic in the sequence length: when the sequence length doubles, runtime and memory requirements quadruple. This presents a significant challenge to scaling Transformer models to longer contexts in pursuit of higher model quality.

Compared to the standard, non-flash algorithm, the flash algorithm [2] was proposed to reduce the memory scaling to linear and to improve computational efficiency through optimized memory accesses. It employs the following two distinctive techniques (a minimal sketch of the idea follows the list).

  • Tiling: The non-flash algorithm tries to process the query, key, value tensors in one single step, requiring large amounts of global memory and incurring high volumes of reads/writes between global memory and shared memory. The flash algorithm decomposes the input into several tiles, based on the available shared memory and register size, and it computes the softmax one tile at a time.

  • Recomputation: The non-flash algorithm stores the softmax matrix (quadratic to sequence length) to global memory for the backward pass, while the flash algorithm only saves the softmax normalization factors (linear to sequence length). This reduces the amount of memory required as well as the bandwidth utilization between global memory and shared memory. Even though there is extra computation incurred in order to recalculate the attention in the backward pass, the bandwidth savings still provide significant improvement in efficiency.
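
To make these two techniques concrete, below is a minimal, single-head sketch of a flash-style forward pass: it tiles over the key/value sequence and maintains an online softmax with running normalization factors. It is illustrative only, not how Transformer Engine’s or flash-attn’s kernels are implemented.

import torch

def flash_attention_forward(q, k, v, tile_size=128):
    # q, k, v: [seqlen, head_dim]; single head, no mask, for illustration
    s, d = q.shape
    scale = d ** -0.5
    out = torch.zeros_like(q)
    row_max = torch.full((s, 1), float("-inf"))    # running max per query row
    row_sum = torch.zeros(s, 1)                    # running softmax denominator

    for start in range(0, s, tile_size):           # tile over the k/v sequence
        k_tile = k[start:start + tile_size]
        v_tile = v[start:start + tile_size]
        scores = (q @ k_tile.T) * scale            # [s, tile]

        new_max = torch.maximum(row_max, scores.max(dim=-1, keepdim=True).values)
        correction = torch.exp(row_max - new_max)  # rescale earlier partial results
        p = torch.exp(scores - new_max)            # unnormalized softmax of this tile

        row_sum = row_sum * correction + p.sum(dim=-1, keepdim=True)
        out = out * correction + p @ v_tile
        row_max = new_max

    # Only out and the normalization statistics (row_max, row_sum), which are
    # linear in the sequence length, need to be saved for the backward pass;
    # the full [s, s] softmax matrix never materializes.
    return out / row_sum

The result matches the non-flash reference torch.softmax(q @ k.T * d ** -0.5, dim=-1) @ v up to floating-point error.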

Note

Transformer Engine’s flash-attention backend (available in PyTorch) and cuDNN attention backend sub-backends 1 and 2 (available in PyTorch, JAX and PaddlePaddle) are both based on the flash algorithm.

1.2 flash-attention

The flash-attention backend, available only in PyTorch, is a module wrapped around the public flash-attn package [3].

The flash-attention backend supports flash-attn’s features as they are released. To facilitate the use of flash-attn, the backend also offers a few convenience functionalities, such as converting an attention_mask to the cumulative sequence lengths cu_seqlens required for the padding mask. Please see transformer_engine.pytorch.attention.FlashAttention for more details.

The flash-attn dependency is regularly updated in Transformer Engine. As of v1.7, Transformer Engine supports flash-attn 2.0.6+ (see setup.py).

To understand flash-attn’s performance, please refer to their benchmarks.

1.3 cuDNN Attention

The cuDNN attention backend, available in PyTorch, JAX and PaddlePaddle, offers another high-performance solution to the attention calculation. It requires cuDNN to run, and has several sub-backends to support different precisions and sequence lengths. Out of the three, sub-backends 1 and 2 are based on the flash algorithm, as is flash-attn.

| Sub-Backend | Algorithm | Precision | Sequence Length | Architecture | Additional info |
|---|---|---|---|---|---|
| 0 | Non-Flash | BF16/FP16 | ≤512 | sm80, 90 | cuDNN |
| 1 | Flash | BF16/FP16 | Any | sm80+ | cuDNN, cudnn-frontend |
| 2 | Flash | FP8 | cuDNN pre-9.0: ≤512 | cuDNN pre-9.0: sm90 | |
| | | | cuDNN 9.0+: Any | cuDNN 9.0+: sm90+ | cuDNN 9.0+: cudnn-frontend |

The cuDNN attention backend and flash-attention backend have several notable differences. As of Transformer Engine 1.7, cuDNN 9.0 and flash-attn 2.4.2,

  • flash-attention only supports the PyTorch framework while cuDNN attention supports PyTorch, JAX and PaddlePaddle.

  • flash-attention supports BF16, FP16 precisions while cuDNN attention also supports FP8 (through its sub-backend 2).

  • flash-attention supports bshd, thd input formats, without any transposes, and sbhd format, with transposes, while cuDNN attention supports all three without transposes (see Section 3.1 for more details).

  • flash-attention does not support post_scale_bias, and cuDNN attention does.

  • flash-attention supports sliding window attention, and cuDNN attention does not.

  • flash-attention uses bottom right diagonal for causal mask in cross attention, and cuDNN attention uses top left (see flash-attn’s change log).

  • flash-attention outperforms cuDNN attention on Ampere architectures, and cuDNN attention has 20-50% advantages on Hopper architectures, based on our benchmarks for a number of commonly-used model configurations.

To compare cuDNN attention and flash-attention, users can modify the model_configs dictionary in benchmarks/attention/benchmark_attention.py to collect performance numbers. The script runs each entry in model_configs num_iters times, each time with one forward pass and one backward pass. Both backends are tried; if one backend does not support the specific user input, its runtimes and speedups in the final table are reported as 0.

[ ]:
model_configs = {
    #   test:             b,  h, hg,   d,   sq,  skv,   p,     mask,              bias
    "test_0": ModelConfig(2, 16, 16,  64,  512,  512, 0.0, "no_mask",         "no_bias"), # short seq
    "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0,  "causal",         "no_bias"), # longer seq, mask
    "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0,  "causal", "post_scale_bias"), # bias
    "test_3": ModelConfig(2, 32,  4, 128, 8192, 8192, 0.0,  "causal",         "no_bias"), # GQA
}
[2]:
!cd ../../../benchmarks/attention/ && python benchmark_attention.py
Device 0: NVIDIA H100 PCIe GPU, sm90 compute capability, 79.1GB memory
Running test_0 with cuDNN attention and flash-attention...
Running test_1 with cuDNN attention and flash-attention...
Running test_2 with cuDNN attention...
Running test_3 with cuDNN attention and flash-attention...

        cuDNN fwd+bwd (ms)  flash-attn fwd+bwd (ms)  cuDNN vs flash speedup
test_0              0.0638                   0.0858                  1.3454
test_1              0.5415                   0.7496                  1.3842
test_2              1.2302                   0.0000                  0.0000
test_3             12.0122                  19.0716                  1.5877

2. Backend Selection

Given the various attention backends, Transformer Engine has a selection logic in place to choose the most appropriate backend for a particular set of user inputs and runtime environment. The selection logic is based on both backend availability and backend performance.

Backend availability is determined by factors such as model configuration, training hyper-parameters, software versions, and the GPU architecture in question. For example, some considerations are the sequence length, number of attention heads, head size, attention mask type, attention bias type, training or inference mode, self or cross attention, MHA or MQA/GQA, flash-attn/cuDNN library versions, and the compute capability of the GPU.

When multiple backends are available, Transformer Engine selects among them based on performance. In general, the selection logic follows a few rules (see the table below). As we monitor the performance of different backends, the selection logic may change.

| Framework | Selection Order |
|---|---|
| PyTorch | sm90: cuDNN attention > flash-attention > PyTorch-native attention |
| | sm80: flash-attention > cuDNN attention > PyTorch-native attention |
| | cuDNN attention: sub-backend 1 > sub-backend 0 |
| JAX | cuDNN attention > JAX-native attention |
| PaddlePaddle | cuDNN attention > PaddlePaddle-native attention |

2.1 Debug Information

To find out which backend is being used during runtime, users can turn on these debugging flags. Logging is done using the logging package.

NVTE_DEBUG       = 0/1   # disables/enables debugging
NVTE_DEBUG_LEVEL = 0/1/2 # enables logging.WARNING/INFO/DEBUG-level messages

Note

These flags are supported in PyTorch only as of Transformer Engine 1.7. JAX and PaddlePaddle support is expected to be added in the future.

The example_attention.py script runs a very basic model with two attention backends, cuDNN attention and flash-attention. Here NVTE_DEBUG_LEVEL=1 allows us to find out which backend/sub-backend was actually used during runtime.

[22]:
!NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python example_attention.py

Run cuDNN attention...
[INFO     | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)

Run flash-attention...
[INFO     | DotProductAttention]: Running with FlashAttention backend

Test passed.

To collect more information, users can turn on NVTE_DEBUG_LEVEL=2. In this example, it prints the full run config, which users are encouraged to include when filing a bug with Transformer Engine. For example,

[25]:
!NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2 python example_attention.py

Run cuDNN attention...
[DEBUG    | DotProductAttention]: Disabling FlashAttention due to NVTE_FLASH_ATTN=0
[INFO     | DotProductAttention]: Running with FusedAttention backend (sub-backend 1)
[DEBUG    | DotProductAttention]: Running with {'compute_capability': 'sm90', 'q_dtype': torch.bfloat16, 'k_dtype': torch.bfloat16, 'v_dtype': torch.bfloat16, 'q_shape': [2, 512, 16, 64], 'k_shape': [2, 512, 16, 64], 'v_shape': [2, 512, 16, 64], 'qkv_format': 'bshd', 'qkv_layout': 'bshd_bshd_bshd', 'mask_type': 'no_mask', 'bias_type': 'no_bias', 'bias_shape': None, 'dropout': 0.0, 'context_parallel': False, 'is_training': True, 'transformer_engine_version': <Version('1.8.0.dev0')>, 'flash_attn_version': <Version('2.4.2')>, 'cudnn_version': '9.2.0'}
[DEBUG    | FusedAttnFunc      ]: Running forward in torch.bfloat16
[DEBUG    | FusedAttnFunc      ]: Running backward in torch.bfloat16

Run flash-attention...
[DEBUG    | DotProductAttention]: Disabling FusedAttention due to NVTE_FUSED_ATTN=0
[INFO     | DotProductAttention]: Running with FlashAttention backend
[DEBUG    | DotProductAttention]: Running with {'compute_capability': 'sm90', 'q_dtype': torch.bfloat16, 'k_dtype': torch.bfloat16, 'v_dtype': torch.bfloat16, 'q_shape': [2, 512, 16, 64], 'k_shape': [2, 512, 16, 64], 'v_shape': [2, 512, 16, 64], 'qkv_format': 'bshd', 'qkv_layout': 'bshd_bshd_bshd', 'mask_type': 'no_mask', 'bias_type': 'no_bias', 'bias_shape': None, 'dropout': 0.0, 'context_parallel': False, 'is_training': True, 'transformer_engine_version': <Version('1.8.0.dev0')>, 'flash_attn_version': <Version('2.4.2')>, 'cudnn_version': '9.2.0'}

Test passed.

2.2 User Control

Users usually do not need to worry about backend selection. However, if a convergence or performance issue is encountered, Transformer Engine provides a few environment variables that let users experiment with different backends.

flash-attention or cuDNN attention: Users can enable/disable the flash-attention backend or cuDNN attention backend via the following two environment variables in PyTorch.

NVTE_FLASH_ATTN = 0 # disables flash-attention; default = 1
NVTE_FUSED_ATTN = 0 # disables cuDNN attention; default = 1

cuDNN attention sub-backends: This environment variable allows users to express their preference of cuDNN attention sub-backends. However, the preferred sub-backend will only be used if it is eligible, i.e. if it supports the provided inputs and runtime environment.

NVTE_FUSED_ATTN_BACKEND = 0/1/2 # user preference of cuDNN sub-backend
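
These variables are read by Transformer Engine’s PyTorch attention code, so a safe pattern is to set them before transformer_engine is imported, either in the shell (as in the examples in this document) or, as a sketch, via os.environ in Python.

import os

# Set backend preferences before importing Transformer Engine so that they
# are in effect when the attention module initializes.
os.environ["NVTE_FLASH_ATTN"] = "0"          # disable flash-attention
os.environ["NVTE_FUSED_ATTN"] = "1"          # keep cuDNN attention enabled
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"  # prefer cuDNN sub-backend 1

import transformer_engine.pytorch as te      # noqa: E402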

Execution paths of cuDNN sub-backend 1: cuDNN attention sub-backend 1 offers two execution paths: a workspace optimization path and a non-workspace optimization path. The workspace optimization path requires a larger amount of global memory, provides determinism, and offers bias gradient support. Before cuDNN 9.0, it also had a 20-30% performance advantage over the non-workspace optimization path; with cuDNN 9.0 and later, it is 20-30% slower than the non-workspace optimization path.

Users can experiment with these two paths through the following environment variable. However, please be aware of the possible Out-Of-Memory risks.

Before cuDNN 9.0:
    NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 0 # disables workspace optimization path
    NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT = 1 # enables workspace optimization path

After cuDNN 9.0:
    NVTE_ALLOW_NONDETERMINISTIC_ALGO = 1 # disables workspace optimization path
    NVTE_ALLOW_NONDETERMINISTIC_ALGO = 0 # enables workspace optimization path

Note

Environment variables NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT and NVTE_ALLOW_NONDETERMINISTIC_ALGO are only supported in PyTorch, and will be added to JAX and PaddlePaddle in the future.

2.3 Example Tests

Our unit tests demonstrate the use of the Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine into their ML workflows.

For example, in PyTorch, test_dot_product_attention offers a variety of use cases of pytorch.DotProductAttention, from data types, model configs, checkpointing, to QKV layouts.

3. Backend Support

Transformer Engine supports commonly-used features such as self and cross attention, FP16/BF16 precisions, dropout, and checkpointing. But it also offers a range of other features. As of v1.7, Transformer Engine’s attention backends have the following support matrix.

| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Context Parallelism | Determinism Possible |
|---|---|---|---|---|---|---|
| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes (only for bshd, sbhd) | Yes |
| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | Yes (only for bshd, thd) | Yes |
| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | No | Yes |

Some unit tests are provided to serve as a starting point for integrating such features into users’ models. For example,

  • sliding window attention: test_dpa_swa

  • MQA/GQA: test_te_layer_mqa_gqa

  • context parallelism: test_cp_with_fused_attention, test_cp_with_flash_attention

3.1 QKV Layout

Transformer Engine supports various layouts of the query q, key k, value v tensors. It defines 15 QKV layouts, which are grouped into 3 QKV formats and 5 QKV layout groups so that layouts with similar memory/computational behavior can be handled together. The mapping between these layouts and groups is as follows.

| qkv_layout | qkv_layout_group=3hd | h3d | hd_2hd | hd_h2d | hd_hd_hd |
|---|---|---|---|---|---|
| qkv_format=sbhd | sb3hd | sbh3d | sbhd_sb2hd | sbhd_sbh2d | sbhd_sbhd_sbhd |
| bshd | bs3hd | bsh3d | bshd_bs2hd | bshd_bsh2d | bshd_bshd_bshd |
| thd | t3hd | th3d | thd_t2hd | thd_th2d | thd_thd_thd |

In this notation, b stands for the batch size, s the sequence length, h the number of attention heads, d the head dimension, and t the total number of tokens in the batch, i.e. t = sum(s_i) for i in 0, ..., b-1. Here are a few examples of the layouts and their explanations to help clarify the definitions.

qkv_layout=sb3hd: q, k, v are sequence first, i.e. s is the leading dimension in each tensor. They are different slices of one tensor qkv: q, k, v = [qkv[:,:,i,:,:] for i in range(3)]. They are interleaved at the h * d dimension.

qkv_layout=bshd_bsh2d: q, k, v are batch first, i.e. b is the leading dimension in each tensor. q is contiguous, and k, v are different slices of tensor kv: k, v = [kv[:,:,:,i,:] for i in range(2)]. k, v are interleaved at the d dimension.

The s and h in bsh2d are the max sequence length and number of heads for k, v, which can be different from the s and h in bshd for q. They are denoted the same here for brevity; Transformer Engine does differentiate their values during actual execution.

qkv_layout=thd_thd_thd: q, k, v have variable sequence lengths in a batch. They are all contiguous and have no interleaving.
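
To make the shape conventions concrete, the following sketch builds q, k, v for three of the layouts above, using the slicing expressions given in the text (random tensors, shapes only).

import torch

b, s, h, d = 2, 512, 16, 64

# qkv_layout = sb3hd: one packed tensor, sequence first, interleaved at h * d
qkv = torch.randn(s, b, 3, h, d)
q, k, v = [qkv[:, :, i, :, :] for i in range(3)]   # each is [s, b, h, d]

# qkv_layout = bshd_bsh2d: q is contiguous, k and v share one packed kv tensor
q = torch.randn(b, s, h, d)
kv = torch.randn(b, s, h, 2, d)
k, v = [kv[:, :, :, i, :] for i in range(2)]       # each is [b, s, h, d]

# qkv_layout = thd_thd_thd: variable sequence lengths, t = sum of all lengths
t = 2 + 4 + 1                                      # e.g. a batch of 3 sequences
q, k, v = (torch.randn(t, h, d) for _ in range(3))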

As of v1.7, Transformer Engine has the following support matrix.

| Backend | Supported QKV Formats | Notes |
|---|---|---|
| flash-attention | bshd, sbhd, thd | PyTorch: 3 formats, i.e. 15 layouts |
| cuDNN attention | bshd, sbhd, thd | PyTorch: 3 formats, i.e. 15 layouts |
| | | JAX, PaddlePaddle: bs3hd, bshd_bs2hd, bshd_bshd_bshd layouts |
| Framework-native attention | bshd, sbhd | PyTorch, JAX, PaddlePaddle: 2 formats, i.e. 10 layouts |

Some example usage of the different layouts can be found at test_dpa_qkv_layout and test_dpa_qkv_layout_thd. Transformer Engine also provides a utility function transformer_engine.pytorch.attention.get_qkv_layout to help determine which layout a set of q, k, v tensors have (PyTorch only).

Note

When RoPE is employed, the qkv_layout may change in Transformer Engine PyTorch through get_qkv_layout. This is due to the in-place nature of our RoPE implementations: the q, k, v tensors are converted from their initial layout to the corresponding hd_hd_hd layout, for example, from sbh3d in pytorch.MultiHeadAttention before RoPE to sbhd_sbhd_sbhd in pytorch.DotProductAttention after RoPE.

3.2 Attention Mask

Transformer Engine supports 5 mask types. For all of them, True masks out the corresponding element and False includes it in the attention calculation.

  • no_mask, padding, causal, padding_causal (equivalent to causal_padding), arbitrary

Different backends offer different support for attention mask. As of Transformer Engine 1.7,

| Backend | Supported Mask Types | Requires attention_mask |
|---|---|---|
| flash-attention | no_mask, causal, padding, padding_causal | no_mask, causal: No |
| | | padding, padding_causal: Yes if cu_seqlens not provided |
| cuDNN attention | no_mask, causal, padding, padding_causal | no_mask, causal: No |
| | | padding, padding_causal: Yes if cu_seqlens not provided |
| Framework-native attention | no_mask, causal, arbitrary | no_mask, causal: No |
| | | arbitrary: Yes |

padding and padding_causal: For these two mask types, users need to provide sequence length information to help Transformer Engine figure out where each sequence ends in a batch. As of Transformer Engine 1.7, there are two options to do so in PyTorch and one in JAX and PaddlePaddle (a minimal conversion sketch follows the list).

  • PyTorch: When both options are provided by the user, cu_seqlens is preferred as there is no extra conversion needed.

    • cu_seqlens: Users can provide cumulative sequence length tensors cu_seqlens_q and cu_seqlens_kv for q and k/v to the flash-attention or cuDNN attention backend. An example of cu_seqlens is [0, 2, 6, 7] for a batch of 3 [aa000, bbbb0, c0000].

    • attention_mask: Users can also provide attention_mask as an alternative, which will then be converted to cu_seqlens. For self-attention, attention_mask should be one single tensor in shape [batch_size, 1, 1, seqlen_q], and for cross-attention, attention_mask should be a list of two tensors in shapes [batch_size, 1, 1, seqlen_q] and [batch_size, 1, 1, seqlen_kv], respectively.

  • JAX and PaddlePaddle: Users should provide the attention_mask tensor in shape [batch_size, 1, seqlen_q, seqlen_kv].
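
As a minimal sketch of the relationship between the two PyTorch options (using the [aa000, bbbb0, c0000] example above; this is illustrative, not Transformer Engine’s internal conversion):

import torch

# Batch of 3 sequences with actual lengths 2, 4, 1, padded to length 5
# (the [aa000, bbbb0, c0000] example); True masks out a position.
seqlens = torch.tensor([2, 4, 1], dtype=torch.int32)
attention_mask = torch.arange(5).unsqueeze(0) >= seqlens.unsqueeze(1)  # [b, sq]
attention_mask = attention_mask.view(3, 1, 1, 5)   # [batch_size, 1, 1, seqlen_q]

# Cumulative sequence lengths: prepend 0, then accumulate the true lengths.
actual_seqlens = (~attention_mask).sum(dim=(1, 2, 3)).to(torch.int32)
cu_seqlens_q = torch.cat(
    [torch.zeros(1, dtype=torch.int32),
     torch.cumsum(actual_seqlens, dim=0).to(torch.int32)]
)
print(cu_seqlens_q)   # tensor([0, 2, 6, 7], dtype=torch.int32)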

qkv_format=thd: Transformer Engine extracts the max sequence length information from q, k, v if max_seqlen_q and max_seqlen_kv are not provided. This requires GPU-CPU copy and synchronization operations. For performance reasons, please set max_seqlen_q and max_seqlen_kv to their appropriate values for thd QKV format.

Arbitrary mask: cuDNN does not support Arbitrary mask type as of v9.0. However, users can convert the mask to a regular post_scale_bias bias and achieve the same functionality. An example script for this conversion is arbitrary_mask_to_post_scale_bias.py.

[6]:
!NVTE_DEBUG=1 python arbitrary_mask_to_post_scale_bias.py
Run with post_scale_bias:
[DotProductAttention]: using cuDNN attention (sub-backend 1)
Run with arbitrary mask:
[DotProductAttention]: using unfused DPA
Test passed!
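
The idea behind this conversion is to replace every masked-out position with a large negative additive bias, so that its post-softmax probability is effectively zero. A minimal sketch of the idea (not the referenced script itself):

import torch

# Arbitrary mask in shape [batch_size, 1, seqlen_q, seqlen_kv]; True masks out.
b, sq, skv = 2, 512, 512
arbitrary_mask = torch.rand(b, 1, sq, skv) > 0.5

# Equivalent additive bias (B1SS shape): 0 where attended, a large negative
# value where masked out; applied as post_scale_bias, i.e. added to the
# scaled Q*K^T before the softmax.
bias = torch.zeros(b, 1, sq, skv, dtype=torch.bfloat16)
bias = bias.masked_fill(arbitrary_mask, -10000.0)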

Some more examples of running Transformer Engine with different attention masks can be found at test_dpa_mask.

3.3 Attention Bias

Transformer Engine supports 4 attention bias types, no_bias, pre_scale_bias, post_scale_bias, and ALiBi (with/without custom slopes). As of Transformer Engine 1.7, their support matrix is as follows.

| Backend | Bias Type | Bias Shape | Bias Data Type | Architecture |
|---|---|---|---|---|
| flash-attention | no_bias, ALiBi (with slopes) | N/A | ALiBi slopes: FP32 | sm80+ |
| cuDNN attention | PyTorch: no_bias, post_scale_bias, ALiBi (without slopes) | post_scale_bias: BHSS, 1HSS, B1SS, 11SS for forward, 1HSS for backward | post_scale_bias: same as QKV type | cuDNN 8.9.6+: sm90 |
| | JAX, PaddlePaddle: no_bias, post_scale_bias | | ALiBi slopes: FP32 | cuDNN 9.0+: sm80+ |
| Framework-native attention | no_bias, pre_scale_bias, post_scale_bias | post_scale_bias: BHSS, 1HSS, B1SS, 11SS | post_scale_bias: same as QKV type | sm80+ |

The flash-attention backend enables ALiBi by asking the user to pass in an alibi_slopes tensor, which can contain the default slopes of vanilla ALiBi or user-defined slopes. cuDNN attention, on the other hand, supports ALiBi by taking in a Boolean flag, and it only supports vanilla ALiBi as of cuDNN 9.0.

The framework-native backends do not explicitly support ALiBi, but users can convert ALiBi to a regular post_scale_bias bias to achieve the same effect. In PyTorch, the utility function transformer_engine.pytorch.attention.get_alibi can be used to help with the conversion.
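
For illustration, here is a sketch of a vanilla-ALiBi-equivalent post_scale_bias in 1HSS shape, using the default geometric slopes from the ALiBi paper. This is a conceptual sketch, not Transformer Engine’s get_alibi implementation, and it assumes the number of heads is a power of two.

import torch

def vanilla_alibi_bias(num_heads, seqlen_q, seqlen_kv):
    # Default ALiBi slopes: the geometric sequence 2^(-8/n), 2^(-16/n), ...
    slopes = 2.0 ** (-8.0 * torch.arange(1, num_heads + 1) / num_heads)
    # Linearly penalize attention to distant (earlier) positions.
    rel_pos = (torch.arange(seqlen_kv).view(1, 1, 1, seqlen_kv)
               - torch.arange(seqlen_q).view(1, 1, seqlen_q, 1))
    return slopes.view(1, num_heads, 1, 1) * rel_pos   # [1, H, Sq, Skv] (1HSS)

alibi_bias = vanilla_alibi_bias(16, 512, 512)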

More examples of how to use the various attention biases are at test_dpa_bias.

3.4 FP8 Attention

A unique feature of Transformer Engine is its FP8 support, not only for the Linear layers but also for dot product attention. Transformer Engine’s FP8 attention support is through its cuDNN attention sub-backend 2. Recall Figure 1: the two MatMul operations are performed in FP8 for computational efficiency, and the SoftMax operation is performed in FP32 for numerical accuracy.

Transformer Engine supports FP8 attention through its C APIs and its PyTorch API, as of v1.7. The PyTorch API offers two options, both controlled through the FP8 recipe definition, transformer_engine.common.recipe.DelayedScaling (a usage sketch follows the list).

  • DelayedScaling.fp8_dpa=True (default=False): This enables the use of cuDNN attention sub-backend 2, when it does support the provided user inputs. The FusedAttention module for cuDNN attention takes FP16 or BF16 tensors as inputs, performs dot product attention in FP8, and returns attention logits in FP16 or BF16 (same as the input type). Casting operations are required to cast tensors to FP8 at the beginning, and back to FP16/BF16 at the end of the module.

  • DelayedScaling.fp8_mha=True (default=False): This option, on top of fp8_dpa=True, removes the casting operations at the beginning and end of the FusedAttention module. This feature is experimental.
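
A minimal usage sketch of the first option is shown below. The fp8_dpa and fp8_mha fields are the recipe options described above; the TransformerLayer construction, FP8 format and tensor shapes are illustrative assumptions and should be adapted to the actual model code.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Enable FP8 dot product attention (cuDNN attention sub-backend 2) through the
# recipe; setting fp8_mha=True would additionally remove the surrounding casts.
fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, fp8_dpa=True, fp8_mha=False)

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    params_dtype=torch.bfloat16,
).cuda()
x = torch.randn(512, 2, 1024, device="cuda", dtype=torch.bfloat16)  # [s, b, hidden]

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = layer(x)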

Examples of using the two features are available at test_dpa_fp8_vs_f16 and test_mha_fp8_vs_f16. To disable FP8 attention for backward and only use it for forward, users can also set NVTE_FP8_DPA_BWD=0 (default=1). This should result in the following print when the debug flags are turned on, NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=2.

[DEBUG    | DotProductAttention]: Running with fp8_recipe.fp8_mha=False, fp8_recipe.fp8_dpa=True and NVTE_FP8_DPA_BWD=0
[DEBUG    | FusedAttnFunc      ]: Running forward in FP8
[DEBUG    | FusedAttnFunc      ]: Running backward in torch.bfloat16