Getting Started

Overview

Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, providing better performance with lower memory utilization in both training and inference. It supports 8-bit floating point (FP8) precision on Hopper and Ada GPUs, as well as FP8 and 4-bit floating point (NVFP4) precision on Blackwell GPUs.

TE implements a collection of highly optimized building blocks for popular Transformer architectures and exposes an automatic-mixed-precision-like API that can be used seamlessly with your deep learning code.

Currently two frameworks are supported: PyTorch and JAX.

This tutorial uses PyTorch, so basic knowledge of it is recommended.

Baseline: Pure Framework Implementation

Let’s build a Transformer decoder layer!

We’ll create a basic GPT-style layer with causal masking, which prevents each position from attending to future positions. This will be our baseline for later comparisons with Transformer Engine.

[Figure: Structure of a GPT decoder layer. LayerNorm → QKV Projection → Dot Product Attention → Projection → Dropout → residual add (+), then LayerNorm → MLP → residual add (+).]

We construct the components as follows:

  • LayerNorm: torch.nn.LayerNorm

  • QKV Projection: torch.nn.Linear (Q, K, and V fused into a single layer with 3x the output width)

  • DotProductAttention: Custom implementation using torch.bmm

  • Projection: torch.nn.Linear

  • Dropout: torch.nn.Dropout

  • MLP: Two torch.nn.Linear layers with torch.nn.functional.gelu activation
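
The custom DotProductAttention used below is not shown in this excerpt. A minimal sketch of what such a torch.bmm-based implementation might look like, assuming q, k, and v arrive with shape [batch, seq, heads, kv_channels] (as produced by the QKV reshape in the layer below) and the output is flattened back to [batch, seq, hidden]:

```python
import math
from typing import Optional

import torch


class DotProductAttention(torch.nn.Module):
    """Causal scaled dot-product attention built from torch.bmm.

    Expects q, k, v of shape [batch, seq, heads, kv_channels] and returns
    a tensor of shape [batch, seq, heads * kv_channels].
    """

    def __init__(self, num_attention_heads: int, kv_channels: int, attention_dropout: float) -> None:
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.kv_channels = kv_channels
        self.dropout = torch.nn.Dropout(attention_dropout)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        b, s, h, d = q.shape
        # Fold batch and head dims together so torch.bmm operates per head.
        q = q.permute(0, 2, 1, 3).reshape(b * h, s, d)
        k = k.permute(0, 2, 1, 3).reshape(b * h, s, d)
        v = v.permute(0, 2, 1, 3).reshape(b * h, s, d)
        scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(d)
        # Causal mask: each position attends only to itself and earlier positions.
        causal = torch.triu(torch.ones(s, s, dtype=torch.bool, device=scores.device), diagonal=1)
        scores = scores.masked_fill(causal, float("-inf"))
        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask, float("-inf"))
        probs = self.dropout(torch.softmax(scores, dim=-1))
        out = torch.bmm(probs, v)  # [b * h, s, d]
        return out.view(b, h, s, d).permute(0, 2, 1, 3).reshape(b, s, h * d)
```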

Putting it all together:

First, define the MLP block:

import torch
from typing import Optional


class PyTorchMLP(torch.nn.Module):
    """Feed-forward network in Transformer layer.
    Built with plain PyTorch modules.
    """

    hidden_size: int
    ffn_hidden_size: int

    def __init__(self, hidden_size: int, ffn_hidden_size: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.linear1 = torch.nn.Linear(hidden_size, ffn_hidden_size, bias=True)
        self.linear2 = torch.nn.Linear(ffn_hidden_size, hidden_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        x = torch.nn.functional.gelu(x, approximate="tanh")
        x = self.linear2(x)
        return x


Now, putting it all together into a GPT decoder layer:

class PyTorchTransformerLayer(torch.nn.Module):
    """Basic Transformer layer using plain PyTorch modules."""

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
        layernorm_eps: float = 1e-5,
        attention_dropout: float = 0.1,
        hidden_dropout: float = 0.1,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.kv_channels = hidden_size // num_attention_heads
        self.ln1 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)
        self.qkv_projection = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.attention = DotProductAttention(
            num_attention_heads=num_attention_heads,
            kv_channels=self.kv_channels,
            attention_dropout=attention_dropout,
        )
        self.projection = torch.nn.Linear(hidden_size, hidden_size, bias=True)
        self.dropout = torch.nn.Dropout(hidden_dropout)
        self.ln2 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)
        self.mlp = PyTorchMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size)

    def forward(
        self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        res = x
        x = self.ln1(x)

        # Fused QKV projection
        qkv = self.qkv_projection(x)
        qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
        q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)

        x = self.attention(q, k, v, attention_mask)
        x = self.projection(x)
        x = self.dropout(x)
        x = res + x

        # Second residual connection
        res = x
        x = self.ln2(x)
        x = self.mlp(x)

        return x + res
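
The benchmarks below rely on a speedometer timing helper and on hidden_size, ffn_hidden_size, num_attention_heads, dtype, and the input tensor x, none of which are defined in this excerpt. A minimal sketch of what such a helper might look like (the exact warmup and iteration counts are assumptions):

```python
import time
from typing import Any, Dict, Optional

import torch


def speedometer(
    model: torch.nn.Module,
    x: torch.Tensor,
    forward_kwargs: Optional[Dict[str, Any]] = None,
    autocast_kwargs: Optional[Dict[str, Any]] = None,
    label: str = "",
    warmup_iters: int = 5,
    timing_iters: int = 20,
) -> float:
    """Time forward + backward passes and return the mean time in ms."""
    forward_kwargs = forward_kwargs or {}

    def step() -> None:
        if autocast_kwargs is not None:
            # Only used with TE models; assumes te.autocast is available.
            import transformer_engine.pytorch as te
            with te.autocast(**autocast_kwargs):
                y = model(x, **forward_kwargs)
        else:
            y = model(x, **forward_kwargs)
        y.sum().backward()

    for _ in range(warmup_iters):
        step()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(timing_iters):
        step()
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    mean_ms = (time.perf_counter() - start) / timing_iters * 1000
    print(f"Mean time: {mean_ms:.3f} ms")
    return mean_ms
```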


Benchmark the baseline implementation:

baseline = (
    PyTorchTransformerLayer(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        num_attention_heads=num_attention_heads,
    )
    .to(dtype=dtype)
    .cuda()
)

print("Baseline PyTorch:")
time_baseline = speedometer(baseline, x, forward_kwargs={"attention_mask": None}, label="baseline")
Output:
Baseline PyTorch:
Mean time: 48.280 ms

TE Unfused: Basic TE Modules

Now let’s replace the standard framework modules with TE equivalents. This is the simplest way to start using Transformer Engine.

Replace PyTorch modules with TE equivalents:

import transformer_engine.pytorch as te

Mapping:

  • torch.nn.Linear → te.Linear

  • torch.nn.LayerNorm → te.LayerNorm

class TEUnfusedMLP(torch.nn.Module):
    """MLP using TE modules."""

    hidden_size: int
    ffn_hidden_size: int

    def __init__(self, hidden_size: int, ffn_hidden_size: int) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.linear1 = te.Linear(hidden_size, ffn_hidden_size, bias=True)
        self.linear2 = te.Linear(ffn_hidden_size, hidden_size, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear1(x)
        x = torch.nn.functional.gelu(x, approximate="tanh")
        x = self.linear2(x)
        return x


class TEUnfusedTransformerLayer(torch.nn.Module):
    """Transformer layer using basic TE modules."""

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
        layernorm_eps: float = 1e-5,
        attention_dropout: float = 0.1,
        hidden_dropout: float = 0.1,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.kv_channels = hidden_size // num_attention_heads
        self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps)
        self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.attention = DotProductAttention(
            num_attention_heads=num_attention_heads,
            kv_channels=self.kv_channels,
            attention_dropout=attention_dropout,
        )
        self.projection = te.Linear(hidden_size, hidden_size, bias=True)
        self.dropout1 = torch.nn.Dropout(hidden_dropout)
        self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps)
        self.mlp = TEUnfusedMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size)
        self.dropout2 = torch.nn.Dropout(hidden_dropout)

    def forward(
        self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        res = x
        x = self.ln1(x)

        # Fused QKV projection
        qkv = self.qkv_projection(x)
        qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
        q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)

        x = self.attention(q, k, v, attention_mask)
        x = self.projection(x)
        x = self.dropout1(x)
        x = res + x

        # Second residual connection
        res = x
        x = self.ln2(x)
        x = self.mlp(x)
        x = self.dropout2(x)

        return x + res


Benchmark the TE unfused implementation:

te_unfused = (
    TEUnfusedTransformerLayer(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        num_attention_heads=num_attention_heads,
    )
    .to(dtype=dtype)
    .cuda()
)

print("TE Unfused:")
time_te_unfused = speedometer(
    te_unfused, x, forward_kwargs={"attention_mask": None}, label="te_unfused"
)
Output:
TE Unfused:
Mean time: 49.342 ms

TE Unfused + TE Attention

Now let’s also replace the attention mechanism with TE’s optimized DotProductAttention. TE’s attention automatically selects the best available backend — for example, FlashAttention or cuDNN fused attention — based on your hardware and input configuration, delivering optimal performance without manual tuning.
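
When debugging which backend TE selects, the candidate backends can also be constrained through environment variables before the first attention call. A sketch (variable support can vary by TE version):

```python
import os

# "1" enables, "0" disables a candidate backend; TE picks among the enabled ones.
# Set these before the first attention forward pass.
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "1"
```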

Replace the custom attention with TE’s optimized implementation:

  • Custom DotProductAttention → te.DotProductAttention

class TEUnfusedAttnTransformerLayer(torch.nn.Module):
    """Transformer layer using TE modules including TE DotProductAttention."""

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
        layernorm_eps: float = 1e-5,
        attention_dropout: float = 0.1,
        hidden_dropout: float = 0.1,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.kv_channels = hidden_size // num_attention_heads
        self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps)
        self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True)
        self.attention = te.DotProductAttention(
            num_attention_heads=num_attention_heads,
            kv_channels=self.kv_channels,
            attention_dropout=attention_dropout,
            attn_mask_type="causal",
        )
        self.projection = te.Linear(hidden_size, hidden_size, bias=True)
        self.dropout1 = torch.nn.Dropout(hidden_dropout)
        self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps)
        self.mlp = TEUnfusedMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size)
        self.dropout2 = torch.nn.Dropout(hidden_dropout)

    def forward(
        self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        res = x
        x = self.ln1(x)

        # Fused QKV projection
        qkv = self.qkv_projection(x)
        qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
        q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)

        x = self.attention(q, k, v, attention_mask)
        x = self.projection(x)
        x = self.dropout1(x)
        x = res + x

        # Second residual connection
        res = x
        x = self.ln2(x)
        x = self.mlp(x)
        x = self.dropout2(x)

        return x + res


Benchmark TE Unfused with TE Attention:

te_unfused_attn = (
    TEUnfusedAttnTransformerLayer(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        num_attention_heads=num_attention_heads,
    )
    .to(dtype=dtype)
    .cuda()
)

print("TE Unfused + TE Attention:")
time_te_unfused_attn = speedometer(
    te_unfused_attn, x, forward_kwargs={"attention_mask": None}, label="te_unfused_attn"
)
Output:
TE Unfused + TE Attention:
Mean time: 35.709 ms

TE Unfused + TE Attention + FP8

Now let’s combine TE modules and TE Attention with FP8 precision. To enable FP8, wrap the forward pass in TE’s autocast context manager. This provides significant speedups on supported hardware (Hopper, Ada, and Blackwell GPUs).

from transformer_engine.common.recipe import Format, DelayedScaling

recipe = DelayedScaling(
    fp8_format=Format.HYBRID,
    amax_history_len=16,
    amax_compute_algo="max"
)

with te.autocast(enabled=True, recipe=recipe):
    y = te_unfused(x, attention_mask=None)

Note

The autocast should only wrap the forward pass and must exit before starting a backward pass.
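
A minimal sketch of that pattern, assuming a TE model, an input x, and the recipe defined above (requires a supported GPU):

```python
import transformer_engine.pytorch as te

# The forward pass runs inside the autocast context so TE layers use FP8 GEMMs.
with te.autocast(enabled=True, recipe=recipe):
    y = model(x, attention_mask=None)

# The backward pass starts only after the context has exited.
loss = y.float().sum()
loss.backward()
```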

Benchmark TE Unfused with FP8:

recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")

te_unfused_fp8 = (
    TEUnfusedAttnTransformerLayer(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        num_attention_heads=num_attention_heads,
    )
    .to(dtype=dtype)
    .cuda()
)

print("TE Unfused + TE Attention + FP8:")
time_te_unfused_fp8 = speedometer(
    te_unfused_fp8,
    x,
    forward_kwargs={"attention_mask": None},
    autocast_kwargs={"enabled": True, "recipe": recipe},
    label="te_unfused_fp8",
)
Output:
TE Unfused + TE Attention + FP8:
Mean time: 23.406 ms

TE Fused + TE Attention + FP8: Optimized Modules

Fused modules use kernel fusion to combine multiple operations. While speedups are modest on a single GPU, they scale better in multi-GPU setups. Combined with TE Attention and FP8, this delivers peak performance.

Fused modules available:

  • te.LayerNormLinear - fuses LayerNorm + Linear

  • te.LayerNormMLP - fuses LayerNorm + MLP
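
Functionally, a fused module computes the same thing as the unfused pair it replaces; only the kernel launches change. In plain-PyTorch terms (illustrative only, with an arbitrary small hidden size):

```python
import torch

hidden_size = 16

# What te.LayerNormLinear(hidden_size, 3 * hidden_size) computes, written as
# the unfused LayerNorm -> Linear pair from the previous section:
ln = torch.nn.LayerNorm(hidden_size)
qkv = torch.nn.Linear(hidden_size, 3 * hidden_size)

x = torch.randn(2, 4, hidden_size)
y = qkv(ln(x))  # te.LayerNormLinear fuses these two steps
print(y.shape)  # torch.Size([2, 4, 48])
```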

class TEFusedTransformerLayer(torch.nn.Module):
    """Transformer layer using fused TE modules for better performance."""

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
        layernorm_eps: float = 1e-5,
        attention_dropout: float = 0.1,
        hidden_dropout: float = 0.1,
    ):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.kv_channels = hidden_size // num_attention_heads

        # Fused LayerNorm + QKV projection
        self.ln_qkv = te.LayerNormLinear(hidden_size, 3 * hidden_size, eps=layernorm_eps, bias=True)
        self.attention = te.DotProductAttention(
            num_attention_heads=num_attention_heads,
            kv_channels=self.kv_channels,
            attention_dropout=attention_dropout,
            attn_mask_type="causal",
        )
        self.projection = te.Linear(hidden_size, hidden_size, bias=True)
        self.dropout1 = torch.nn.Dropout(hidden_dropout)

        # Fused LayerNorm + MLP
        self.ln_mlp = te.LayerNormMLP(hidden_size, ffn_hidden_size, eps=layernorm_eps, bias=True)
        self.dropout2 = torch.nn.Dropout(hidden_dropout)

    def forward(
        self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        res = x

        # Fused LayerNorm + QKV projection
        qkv = self.ln_qkv(x)
        qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
        q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)

        x = self.attention(q, k, v, attention_mask)
        x = self.projection(x)
        x = self.dropout1(x)
        x = res + x

        # Fused LayerNorm + MLP
        res = x
        x = self.ln_mlp(x)
        x = self.dropout2(x)

        return x + res


Benchmark TE Fused with FP8:

te_fused_fp8 = (
    TEFusedTransformerLayer(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        num_attention_heads=num_attention_heads,
    )
    .to(dtype=dtype)
    .cuda()
)

print("TE Fused + TE Attention + FP8:")
time_te_fused_fp8 = speedometer(
    te_fused_fp8,
    x,
    forward_kwargs={"attention_mask": None},
    autocast_kwargs={"enabled": True, "recipe": recipe},
    label="te_fused_fp8",
)
Output:
TE Fused + TE Attention + FP8:
Mean time: 22.964 ms

TE TransformerLayer + FP8: Ready-to-use Module

For the simplest integration, Transformer Engine provides a ready-to-use TransformerLayer module that includes all optimizations out of the box.

Just use te.TransformerLayer - it handles everything for you:

te_transformer_layer = (
    te.TransformerLayer(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        num_attention_heads=num_attention_heads,
        self_attn_mask_type="causal",
        layernorm_epsilon=1e-5,
        bias=True,
        hidden_dropout=0.0,
        attention_dropout=0.0,
    )
    .to(dtype=dtype)
    .cuda()
)

print("TE TransformerLayer + FP8:")
time_te_transformer_layer = speedometer(
    te_transformer_layer,
    x,
    forward_kwargs={"attention_mask": None},
    autocast_kwargs={"enabled": True, "recipe": recipe},
    label="te_transformer_layer",
)
Output:
TE TransformerLayer + FP8:
Mean time: 21.670 ms

Benchmark Summary

The table below summarizes the performance improvements achieved with Transformer Engine on an NVIDIA H100 GPU. Results may vary depending on hardware and configuration. While this tutorial focuses on a simple single-GPU scenario, features like fused layers can provide additional benefits in more complex setups such as multi-GPU training.
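
The Speedup column is simply the baseline time divided by each variant's time:

```python
# Speedups in the summary table: baseline time / variant time.
baseline_ms = 48.28
times_ms = {
    "TE Unfused": 49.34,
    "TE Unfused + TE Attention": 35.71,
    "TE Unfused + TE Attention + FP8": 23.41,
    "TE Fused + TE Attention + FP8": 22.96,
    "TE TransformerLayer + FP8": 21.67,
}
for name, t in times_ms.items():
    print(f"{name}: {baseline_ms / t:.2f}x")
```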

Implementation                       Time (ms)   Speedup
Baseline PyTorch                         48.28     1.00x
TE Unfused                               49.34     0.98x
TE Unfused + TE Attention                35.71     1.35x
TE Unfused + TE Attention + FP8          23.41     2.06x
TE Fused + TE Attention + FP8            22.96     2.10x
TE TransformerLayer + FP8                21.67     2.23x