Getting Started
Overview
Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, providing better performance with lower memory utilization in both training and inference. It supports 8-bit floating point (FP8) precision on Hopper and Ada GPUs, and both 8-bit (FP8) and 4-bit (NVFP4) floating point precision on Blackwell GPUs.
TE implements a collection of highly optimized building blocks for popular Transformer architectures and exposes an automatic-mixed-precision-like API that can be used seamlessly with your deep learning code.
Currently two frameworks are supported: PyTorch and JAX. Basic familiarity with the framework you plan to use (PyTorch or JAX) is recommended before working through this tutorial.
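To give a feel for the API before building anything, below is a minimal sketch of the mixed-precision-style workflow in PyTorch. The layer sizes, the Sequential wrapper, and the tensor shapes are illustrative only; the autocast context and the FP8 recipe are explained in detail later in this tutorial.

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

# Toy model built from TE drop-in replacements for torch.nn modules
model = torch.nn.Sequential(
    te.LayerNorm(1024),
    te.Linear(1024, 1024, bias=True),
).to(dtype=torch.bfloat16).cuda()

x = torch.randn(8, 128, 1024, dtype=torch.bfloat16, device="cuda")

# Low-precision execution is opt-in and wraps only the forward pass
recipe = DelayedScaling(fp8_format=Format.HYBRID)
with te.autocast(enabled=True, recipe=recipe):
    y = model(x)
y.sum().backward()  # the backward pass runs outside the autocast context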
Baseline: Pure Framework Implementation
Let’s build a Transformer decoder layer!
We’ll create a basic GPT-style layer with causal masking, which prevents each position from attending to future positions. This will be our baseline for later comparisons with Transformer Engine.
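For reference, a causal mask is simply a lower-triangular pattern over the sequence dimension. A minimal sketch of how one could be built and applied in PyTorch (this is purely illustrative; the layers below rely on their attention modules' own masking):

import torch

seq_len = 8
# Lower-triangular boolean mask: position i may attend only to positions j <= i
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Applying it to a matrix of attention scores blocks all "future" positions
scores = torch.randn(seq_len, seq_len)
scores = scores.masked_fill(~causal_mask, float("-inf"))
weights = torch.softmax(scores, dim=-1)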
Structure of a GPT decoder layer
We construct the components as follows:
PyTorch:
- LayerNorm: torch.nn.LayerNorm
- QKV Projection: torch.nn.Linear (Q, K and V fused into a single layer, 3x larger)
- DotProductAttention: custom implementation using torch.bmm (a sketch of such a module follows these lists)
- Projection: torch.nn.Linear
- Dropout: torch.nn.Dropout
- MLP: two torch.nn.Linear layers with torch.nn.functional.gelu activation
Flax:
- LayerNorm: nn.LayerNorm
- QKV Projection: nn.Dense (Q, K and V fused into a single layer, 3x larger)
- DotProductAttention: nn.dot_product_attention
- Projection: nn.Dense
- Dropout: nn.Dropout
- MLP: two nn.Dense layers with nn.gelu activation
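The custom DotProductAttention module used as the baseline PyTorch attention is not reproduced in this section. The following is a minimal sketch of what such a torch.bmm-based implementation could look like; the constructor arguments match how the module is instantiated below, but the exact shape handling and mask semantics are assumptions.

import math
from typing import Optional

import torch


class DotProductAttention(torch.nn.Module):
    """Hypothetical scaled dot-product attention built on torch.bmm."""

    def __init__(self, num_attention_heads: int, kv_channels: int, attention_dropout: float = 0.1) -> None:
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.kv_channels = kv_channels
        self.dropout = torch.nn.Dropout(attention_dropout)

    def forward(
        self,
        q: torch.Tensor,  # (batch, seq, heads, kv_channels)
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        b, s, h, d = q.shape
        # Fold batch and head dims so torch.bmm can be used: (b*h, s, d)
        q = q.permute(0, 2, 1, 3).reshape(b * h, s, d)
        k = k.permute(0, 2, 1, 3).reshape(b * h, s, d)
        v = v.permute(0, 2, 1, 3).reshape(b * h, s, d)

        scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(d)  # (b*h, s, s)
        # Default to a causal mask so no position attends to later positions
        causal = torch.tril(torch.ones(s, s, dtype=torch.bool, device=q.device))
        mask = causal if attention_mask is None else attention_mask
        scores = scores.masked_fill(~mask, float("-inf"))

        weights = self.dropout(torch.softmax(scores, dim=-1))
        out = torch.bmm(weights, v)  # (b*h, s, d)
        return out.reshape(b, h, s, d).permute(0, 2, 1, 3).reshape(b, s, h * d)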
Putting it all together:
First, define the MLP block:
class PyTorchMLP(torch.nn.Module):
"""Feed-forward network in Transformer layer.
Built with plain PyTorch modules.
"""
hidden_size: int
ffn_hidden_size: int
def __init__(self, hidden_size: int, ffn_hidden_size: int) -> None:
super().__init__()
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.linear1 = torch.nn.Linear(hidden_size, ffn_hidden_size, bias=True)
self.linear2 = torch.nn.Linear(ffn_hidden_size, hidden_size, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
x = torch.nn.functional.gelu(x, approximate="tanh")
x = self.linear2(x)
return x
Now, putting it all together into a GPT decoder layer:
class PyTorchTransformerLayer(torch.nn.Module):
"""Basic Transformer layer using plain PyTorch modules."""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_eps: float = 1e-5,
attention_dropout: float = 0.1,
hidden_dropout: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.kv_channels = hidden_size // num_attention_heads
self.ln1 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)
self.qkv_projection = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=True)
self.attention = DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=self.kv_channels,
attention_dropout=attention_dropout,
)
self.projection = torch.nn.Linear(hidden_size, hidden_size, bias=True)
self.dropout = torch.nn.Dropout(hidden_dropout)
self.ln2 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)
self.mlp = PyTorchMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
res = x
x = self.ln1(x)
# Fused QKV projection
qkv = self.qkv_projection(x)
qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)
x = self.attention(q, k, v, attention_mask)
x = self.projection(x)
x = self.dropout(x)
x = res + x
# Second residual connection
res = x
x = self.ln2(x)
x = self.mlp(x)
return x + res
Benchmark the baseline implementation:
baseline = (
PyTorchTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("Baseline PyTorch:")
time_baseline = speedometer(baseline, x, forward_kwargs={"attention_mask": None}, label="baseline")
Baseline PyTorch:
Mean time: 48.280 ms
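The speedometer helper used in the benchmark above (and throughout this tutorial) is not defined in this section. A rough sketch of what such a utility might do in the PyTorch case is shown below; the warmup/iteration counts, the CUDA synchronization, and the handling of autocast_kwargs are assumptions based purely on how the helper is called here. The Flax benchmarks call it with a slightly different signature (apply function plus params), which this sketch does not cover.

import time

import torch
import transformer_engine.pytorch as te


def speedometer(module, x, forward_kwargs=None, autocast_kwargs=None,
                label="", warmup_iters=10, timing_iters=100):
    """Hypothetical timing helper: mean forward+backward time in milliseconds."""
    forward_kwargs = forward_kwargs or {}
    # label: kept for interface compatibility; unused in this sketch

    def step():
        if autocast_kwargs is not None:
            with te.autocast(**autocast_kwargs):
                out = module(x, **forward_kwargs)
        else:
            out = module(x, **forward_kwargs)
        out.sum().backward()  # backward pass runs outside the autocast context

    for _ in range(warmup_iters):
        step()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(timing_iters):
        step()
    torch.cuda.synchronize()
    mean_ms = (time.perf_counter() - start) / timing_iters * 1e3
    print(f"Mean time: {mean_ms:.3f} ms")
    return mean_ms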
First, define the MLP block:
class FlaxMLP(nn.Module):
"""Feed-forward network in Transformer layer.
Built with plain Flax modules.
"""
hidden_size: int
ffn_hidden_size: int
@nn.compact
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
x = nn.Dense(features=self.ffn_hidden_size, use_bias=True)(x)
x = nn.gelu(x, approximate=True)
x = nn.Dense(features=self.hidden_size, use_bias=True)(x)
return x
Now, putting it all together into a GPT decoder layer:
class FlaxTransformerLayer(nn.Module):
"""Basic Transformer layer using plain Flax modules."""
hidden_size: int
ffn_hidden_size: int
num_attention_heads: int
layernorm_eps: float = 1e-5
attention_dropout: float = 0.1
def setup(self):
self.kv_channels = self.hidden_size // self.num_attention_heads
@nn.compact
def __call__(
self,
x: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = False,
) -> jnp.ndarray:
if attention_mask is None:
attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)
res = x
x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)
# Fused QKV projection
qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True)(x)
qkv = qkv.reshape(
qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels
)
q, k, v = jnp.split(qkv, 3, axis=3)
dropout_rng = None
if not deterministic and self.attention_dropout > 0:
dropout_rng = self.make_rng("dropout")
x = nn.dot_product_attention(
query=q,
key=k,
value=v,
mask=attention_mask,
dropout_rng=dropout_rng,
dropout_rate=self.attention_dropout,
deterministic=deterministic,
broadcast_dropout=True,
)
x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)
x = nn.Dense(features=self.hidden_size, use_bias=True)(x)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
x = res + x
res = x
x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)
mlp = FlaxMLP(hidden_size=self.hidden_size, ffn_hidden_size=self.ffn_hidden_size)
x = mlp(x)
return x + res
Benchmark the baseline implementation:
baseline = FlaxTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
params = baseline.init(key, x, deterministic=False)
print("Baseline Flax:")
time_baseline = speedometer(
baseline.apply, params, x, forward_kwargs={"deterministic": True}, label="baseline"
)
Baseline Flax:
Mean time: 86.580 ms
TE Unfused: Basic TE Modules
Now let’s replace the standard framework modules with TE equivalents. This is the simplest way to start using Transformer Engine.
Replace PyTorch modules with TE equivalents:
import transformer_engine.pytorch as te
Mapping:
- torch.nn.Linear → te.Linear
- torch.nn.LayerNorm → te.LayerNorm
class TEUnfusedMLP(torch.nn.Module):
"""MLP using TE modules."""
hidden_size: int
ffn_hidden_size: int
def __init__(self, hidden_size: int, ffn_hidden_size: int) -> None:
super().__init__()
self.hidden_size = hidden_size
self.ffn_hidden_size = ffn_hidden_size
self.linear1 = te.Linear(hidden_size, ffn_hidden_size, bias=True)
self.linear2 = te.Linear(ffn_hidden_size, hidden_size, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
x = torch.nn.functional.gelu(x, approximate="tanh")
x = self.linear2(x)
return x
class TEUnfusedTransformerLayer(torch.nn.Module):
"""Transformer layer using basic TE modules."""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_eps: float = 1e-5,
attention_dropout: float = 0.1,
hidden_dropout: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.kv_channels = hidden_size // num_attention_heads
self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps)
self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True)
self.attention = DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=self.kv_channels,
attention_dropout=attention_dropout,
)
self.projection = te.Linear(hidden_size, hidden_size, bias=True)
self.dropout1 = torch.nn.Dropout(hidden_dropout)
self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps)
self.mlp = TEUnfusedMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size)
self.dropout2 = torch.nn.Dropout(hidden_dropout)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
res = x
x = self.ln1(x)
# Fused QKV projection
qkv = self.qkv_projection(x)
qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)
x = self.attention(q, k, v, attention_mask)
x = self.projection(x)
x = self.dropout1(x)
x = res + x
# Second residual connection
res = x
x = self.ln2(x)
x = self.mlp(x)
x = self.dropout2(x)
return x + res
Benchmark the TE unfused implementation:
te_unfused = (
TEUnfusedTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("TE Unfused:")
time_te_unfused = speedometer(
te_unfused, x, forward_kwargs={"attention_mask": None}, label="te_unfused"
)
TE Unfused:
Mean time: 49.342 ms
Replace Flax modules with TE equivalents:
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
Mapping:
- nn.Dense → te_flax.DenseGeneral
- nn.LayerNorm → te_flax.LayerNorm
class TEUnfusedMLP(nn.Module):
"""MLP using TE modules."""
hidden_size: int
ffn_hidden_size: int
@nn.compact
def __call__(self, x: jnp.ndarray, deterministic: bool) -> jnp.ndarray:
x = te_flax.DenseGeneral(features=self.ffn_hidden_size, use_bias=True)(x)
x = x.reshape(*x.shape[:-1], 1, x.shape[-1])
x = te.activation.activation(x, activation_type=("gelu",))
x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)
return x
class TEUnfusedTransformerLayer(nn.Module):
"""Transformer layer using basic TE modules (without TE attention)."""
hidden_size: int
ffn_hidden_size: int
num_attention_heads: int
layernorm_eps: float = 1e-5
attention_dropout: float = 0.1
def setup(self):
self.kv_channels = self.hidden_size // self.num_attention_heads
@nn.compact
def __call__(
self,
x: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = False,
) -> jnp.ndarray:
if attention_mask is None:
attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)
res = x
x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)
qkv = te_flax.DenseGeneral(features=3 * self.hidden_size, use_bias=True)(x)
qkv = qkv.reshape(
qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels
)
q, k, v = jnp.split(qkv, 3, axis=3)
dropout_rng = None
if not deterministic and self.attention_dropout > 0:
dropout_rng = self.make_rng("dropout")
x = nn.dot_product_attention(
query=q,
key=k,
value=v,
mask=attention_mask,
dropout_rng=dropout_rng,
dropout_rate=self.attention_dropout,
deterministic=deterministic,
broadcast_dropout=True,
)
x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)
x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
x = res + x
res = x
x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)
mlp = TEUnfusedMLP(hidden_size=self.hidden_size, ffn_hidden_size=self.ffn_hidden_size)
x = mlp(x, deterministic=deterministic)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
return x + res
Benchmark the TE unfused implementation:
te_unfused = TEUnfusedTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
params = te_unfused.init(key, x, deterministic=False)
print("TE Unfused:")
time_te_unfused = speedometer(
te_unfused.apply, params, x, forward_kwargs={"deterministic": True}, label="te_unfused"
)
TE Unfused:
Mean time: 42.252 ms
TE Unfused + TE Attention
Now let’s also replace the attention mechanism with TE’s optimized DotProductAttention.
TE’s attention automatically selects the best available backend — for example, FlashAttention or cuDNN fused attention — based on your hardware and input configuration,
delivering optimal performance without manual tuning.
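If you want to see or influence which backend is picked, TE exposes environment variables for this. The variable names below are shown as an illustration; check the attention documentation for your TE version for the authoritative list and semantics.

import os

# Set these before building the model / before the first attention call
os.environ["NVTE_DEBUG"] = "1"        # enable TE debug logging
os.environ["NVTE_DEBUG_LEVEL"] = "2"  # verbose: logs which attention backend is selected

# Backends can also be disabled individually to force a different choice,
# e.g. fall back from FlashAttention to cuDNN fused attention:
os.environ["NVTE_FLASH_ATTN"] = "0"
# os.environ["NVTE_FUSED_ATTN"] = "0"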
Replace the custom attention with TE’s optimized implementation:
Custom DotProductAttention → te.DotProductAttention
class TEUnfusedAttnTransformerLayer(torch.nn.Module):
"""Transformer layer using TE modules including TE DotProductAttention."""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_eps: float = 1e-5,
attention_dropout: float = 0.1,
hidden_dropout: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.kv_channels = hidden_size // num_attention_heads
self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps)
self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True)
self.attention = te.DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=self.kv_channels,
attention_dropout=attention_dropout,
attn_mask_type="causal",
)
self.projection = te.Linear(hidden_size, hidden_size, bias=True)
self.dropout1 = torch.nn.Dropout(hidden_dropout)
self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps)
self.mlp = TEUnfusedMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size)
self.dropout2 = torch.nn.Dropout(hidden_dropout)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
res = x
x = self.ln1(x)
# Fused QKV projection
qkv = self.qkv_projection(x)
qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)
x = self.attention(q, k, v, attention_mask)
x = self.projection(x)
x = self.dropout1(x)
x = res + x
# Second residual connection
res = x
x = self.ln2(x)
x = self.mlp(x)
x = self.dropout2(x)
return x + res
Benchmark TE Unfused with TE Attention:
te_unfused_attn = (
TEUnfusedAttnTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("TE Unfused + TE Attention:")
time_te_unfused_attn = speedometer(
te_unfused_attn, x, forward_kwargs={"attention_mask": None}, label="te_unfused_attn"
)
TE Unfused + TE Attention:
Mean time: 35.709 ms
Replace Flax’s attention with TE’s optimized implementation:
nn.dot_product_attention → te_flax.DotProductAttention
class TEUnfusedAttnTransformerLayer(nn.Module):
"""Transformer layer using TE modules including TE DotProductAttention."""
hidden_size: int
ffn_hidden_size: int
num_attention_heads: int
layernorm_eps: float = 1e-5
attention_dropout: float = 0.1
def setup(self):
self.kv_channels = self.hidden_size // self.num_attention_heads
@nn.compact
def __call__(
self,
x: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = False,
) -> jnp.ndarray:
res = x
x = te_flax.LayerNorm(epsilon=self.layernorm_eps, dtype=jnp.bfloat16)(x)
qkv = te_flax.DenseGeneral(
features=3 * self.hidden_size, use_bias=True, dtype=jnp.bfloat16
)(x)
qkv = qkv.reshape(
qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels
)
q, k, v = jnp.split(qkv, 3, axis=3)
attention = te_flax.DotProductAttention(
head_dim=self.kv_channels,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_attention_heads,
attention_dropout=self.attention_dropout,
attn_mask_type="causal",
transpose_batch_sequence=False,
)
x = attention(q, k, v, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True, dtype=jnp.bfloat16)(x)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
x = res + x
res = x
x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)
mlp = TEUnfusedMLP(hidden_size=self.hidden_size, ffn_hidden_size=self.ffn_hidden_size)
x = mlp(x, deterministic=deterministic)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
return x + res
Benchmark TE Unfused with TE Attention:
te_unfused_attn = TEUnfusedAttnTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
with te.autocast(enabled=False, mesh_resource=mesh_resource):
params = te_unfused_attn.init(key, x, deterministic=False)
print("TE Unfused + TE Attention:")
time_te_unfused_attn = speedometer(
te_unfused_attn.apply,
params,
x,
forward_kwargs={"deterministic": True},
autocast_kwargs={"enabled": False, "mesh_resource": mesh_resource},
label="te_unfused_attn",
)
TE Unfused + TE Attention:
Mean time: 35.054 ms
TE Unfused + TE Attention + FP8
Now let’s combine TE modules with TE Attention and enable FP8 precision.
Wrap your code within an autocast context manager to enable FP8.
This provides significant speedups on supported hardware (Hopper, Ada, Blackwell GPUs).
from transformer_engine.common.recipe import Format, DelayedScaling
recipe = DelayedScaling(
fp8_format=Format.HYBRID,
amax_history_len=16,
amax_compute_algo="max"
)
with te.autocast(enabled=True, recipe=recipe):
y = te_unfused(x, attention_mask=None)
Note
The autocast should only wrap the forward pass and must exit before
starting a backward pass.
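In a training step this looks like the following sketch; the loss function and optimizer are placeholders.

optimizer = torch.optim.Adam(te_unfused.parameters(), lr=1e-4)

with te.autocast(enabled=True, recipe=recipe):
    y = te_unfused(x, attention_mask=None)   # forward pass runs under autocast

loss = y.float().sum()   # placeholder loss, computed after leaving the context
loss.backward()          # backward pass starts outside the autocast context
optimizer.step()
optimizer.zero_grad()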
Benchmark TE Unfused with FP8:
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")
te_unfused_fp8 = (
TEUnfusedAttnTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("TE Unfused + TE Attention + FP8:")
time_te_unfused_fp8 = speedometer(
te_unfused_fp8,
x,
forward_kwargs={"attention_mask": None},
autocast_kwargs={"enabled": True, "recipe": recipe},
label="te_unfused_fp8",
)
TE Unfused + TE Attention + FP8:
Mean time: 23.406 ms
from transformer_engine.common.recipe import Format, DelayedScaling
recipe = DelayedScaling(
fp8_format=Format.HYBRID,
amax_history_len=16,
amax_compute_algo="max"
)
with te.autocast(enabled=True, recipe=recipe):
params = te_unfused.init(key, x, deterministic=False)
y = te_unfused.apply(params, x, deterministic=True)
Important
When using FP8 in JAX, the model must be initialized within the autocast context
to create the fp8_metas collection.
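After initializing inside the autocast context, the returned variable dict contains the FP8 scaling state alongside the regular parameters. A quick way to check (the exact collection names may vary between TE versions):

# Inspect the variable collections created by init under autocast
print(list(params.keys()))   # expected to include 'params' plus an FP8 metadata
                             # collection such as 'fp8_metas'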
Benchmark TE Unfused with FP8:
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")
te_unfused_fp8 = TEUnfusedAttnTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
with te.autocast(enabled=True, recipe=recipe, mesh_resource=mesh_resource):
params = te_unfused_fp8.init(key, x, deterministic=False)
print("TE Unfused + TE Attention + FP8:")
time_te_unfused_fp8 = speedometer(
te_unfused_fp8.apply,
params,
x,
forward_kwargs={"deterministic": True},
autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource},
label="te_unfused_fp8",
)
TE Unfused + TE Attention + FP8:
Mean time: 22.638 ms
TE Fused + TE Attention + FP8: Optimized Modules
Fused modules use kernel fusion to combine multiple operations. While speedups are modest on a single GPU, they scale better in multi-GPU setups. Combined with TE Attention and FP8, this delivers peak performance.
Fused modules available:
- te.LayerNormLinear: fuses LayerNorm + Linear
- te.LayerNormMLP: fuses LayerNorm + MLP
class TEFusedTransformerLayer(torch.nn.Module):
"""Transformer layer using fused TE modules for better performance."""
def __init__(
self,
hidden_size: int,
ffn_hidden_size: int,
num_attention_heads: int,
layernorm_eps: float = 1e-5,
attention_dropout: float = 0.1,
hidden_dropout: float = 0.1,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.kv_channels = hidden_size // num_attention_heads
# Fused LayerNorm + QKV projection
self.ln_qkv = te.LayerNormLinear(hidden_size, 3 * hidden_size, eps=layernorm_eps, bias=True)
self.attention = te.DotProductAttention(
num_attention_heads=num_attention_heads,
kv_channels=self.kv_channels,
attention_dropout=attention_dropout,
attn_mask_type="causal",
)
self.projection = te.Linear(hidden_size, hidden_size, bias=True)
self.dropout1 = torch.nn.Dropout(hidden_dropout)
# Fused LayerNorm + MLP
self.ln_mlp = te.LayerNormMLP(hidden_size, ffn_hidden_size, eps=layernorm_eps, bias=True)
self.dropout2 = torch.nn.Dropout(hidden_dropout)
def forward(
self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
res = x
# Fused LayerNorm + QKV projection
qkv = self.ln_qkv(x)
qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)
q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)
x = self.attention(q, k, v, attention_mask)
x = self.projection(x)
x = self.dropout1(x)
x = res + x
# Fused LayerNorm + MLP
res = x
x = self.ln_mlp(x)
x = self.dropout2(x)
return x + res
Benchmark TE Fused with FP8:
te_fused_fp8 = (
TEFusedTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
.to(dtype=dtype)
.cuda()
)
print("TE Fused + TE Attention + FP8:")
time_te_fused_fp8 = speedometer(
te_fused_fp8,
x,
forward_kwargs={"attention_mask": None},
autocast_kwargs={"enabled": True, "recipe": recipe},
label="te_fused_fp8",
)
TE Fused + TE Attention + FP8:
Mean time: 22.964 ms
Fused modules available:
- te_flax.LayerNormDenseGeneral: fuses LayerNorm + Dense
- te_flax.LayerNormMLP: fuses LayerNorm + MLP
class TEFusedTransformerLayer(nn.Module):
"""Transformer layer using fused TE modules for better performance."""
hidden_size: int
ffn_hidden_size: int
num_attention_heads: int
layernorm_eps: float = 1e-5
attention_dropout: float = 0.1
def setup(self):
self.kv_channels = self.hidden_size // self.num_attention_heads
@nn.compact
def __call__(
self,
x: jnp.ndarray,
attention_mask: Optional[jnp.ndarray] = None,
deterministic: bool = False,
) -> jnp.ndarray:
res = x
# Fused LayerNorm + QKV projection
qkv, _ = te_flax.LayerNormDenseGeneral(
features=3 * self.hidden_size,
epsilon=self.layernorm_eps,
use_bias=True,
return_layernorm_output=False,
)(x)
qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.num_attention_heads, self.kv_channels)
q, k, v = qkv[:, :, 0, :, :], qkv[:, :, 1, :, :], qkv[:, :, 2, :, :]
attention = te_flax.DotProductAttention(
head_dim=self.kv_channels,
num_attention_heads=self.num_attention_heads,
num_gqa_groups=self.num_attention_heads,
attention_dropout=self.attention_dropout,
attn_mask_type="causal",
qkv_layout="bshd_bshd_bshd",
transpose_batch_sequence=False,
)
x = attention(q, k, v, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
x = res + x
res = x
# Fused LayerNorm + MLP
x, _ = te_flax.LayerNormMLP(
intermediate_dim=self.ffn_hidden_size,
epsilon=self.layernorm_eps,
use_bias=True,
activations=("gelu",),
intermediate_dropout_rate=0.0,
return_layernorm_output=False,
)(x, deterministic=deterministic)
x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)
return x + res
Benchmark TE Fused with FP8:
te_fused_fp8 = TEFusedTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
with te.autocast(enabled=True, recipe=recipe, mesh_resource=mesh_resource):
params = te_fused_fp8.init(key, x, deterministic=False)
print("TE Fused + TE Attention + FP8:")
time_te_fused_fp8 = speedometer(
te_fused_fp8.apply,
params,
x,
forward_kwargs={"deterministic": True},
autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource},
label="te_fused_fp8",
)
TE Fused + TE Attention + FP8:
Mean time: 23.703 ms
TE TransformerLayer + FP8: Ready-to-use Module
For the simplest integration, Transformer Engine provides a ready-to-use TransformerLayer
module that includes all optimizations out of the box.
Just use te.TransformerLayer - it handles everything for you:
te_transformer_layer = (
te.TransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
self_attn_mask_type="causal",
layernorm_epsilon=1e-5,
bias=True,
hidden_dropout=0.0,
attention_dropout=0.0,
)
.to(dtype=dtype)
.cuda()
)
print("TE TransformerLayer + FP8:")
time_te_transformer_layer = speedometer(
te_transformer_layer,
x,
forward_kwargs={"attention_mask": None},
autocast_kwargs={"enabled": True, "recipe": recipe},
label="te_transformer_layer",
)
TE TransformerLayer + FP8:
Mean time: 21.670 ms
Just use te_flax.TransformerLayer - it handles everything for you:
te_transformer_layer = te_flax.TransformerLayer(
hidden_size=hidden_size,
mlp_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
mlp_activations=("gelu",),
self_attn_mask_type="causal",
layernorm_epsilon=1e-5,
use_bias=True,
attention_dropout=0.0,
intermediate_dropout=0.0,
hidden_dropout=0.0,
enable_relative_embedding=False,
self_attn_bias_type="no_bias",
dtype=jnp.bfloat16,
transpose_batch_sequence=False,
)
with te.autocast(enabled=True, recipe=recipe, mesh_resource=mesh_resource):
params = te_transformer_layer.init(key, x, deterministic=False)
print("TE TransformerLayer + FP8:")
time_te_transformer_layer = speedometer(
te_transformer_layer.apply,
params,
x,
forward_kwargs={"deterministic": True},
autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource},
label="te_transformer_layer",
)
TE TransformerLayer + FP8:
Mean time: 22.812 ms
Benchmark Summary
The table below summarizes the performance improvements achieved with Transformer Engine on an NVIDIA H100 GPU. Results may vary depending on hardware and configuration. While this tutorial focuses on a simple single-GPU scenario, features like fused layers can provide additional benefits in more complex setups such as multi-GPU training.
PyTorch:

| Implementation | Time (ms) | Speedup |
|---|---|---|
| Baseline PyTorch | 48.28 | 1.00x |
| TE Unfused | 49.34 | 0.98x |
| TE Unfused + TE Attention | 35.71 | 1.35x |
| TE Unfused + TE Attention + FP8 | 23.41 | 2.06x |
| TE Fused + TE Attention + FP8 | 22.96 | 2.10x |
| TE TransformerLayer + FP8 | 21.67 | 2.23x |
Flax:

| Implementation | Time (ms) | Speedup |
|---|---|---|
| Baseline Flax | 86.58 | 1.00x |
| TE Unfused | 42.25 | 2.05x |
| TE Unfused + TE Attention | 35.05 | 2.47x |
| TE Unfused + TE Attention + FP8 | 22.64 | 3.82x |
| TE Fused + TE Attention + FP8 | 23.70 | 3.65x |
| TE TransformerLayer + FP8 | 22.81 | 3.80x |