JAX: Integrating TE into an existing framework
This tutorial covers how to integrate TransformerEngine (TE) into an existing JAX model framework, such as MaxText (see MaxText's TE integration) or your own model framework.
Let's start with a standard JAX+Flax Transformer layer:
[1]:
import jax
import jax.numpy as jnp
from flax import linen as nn
import quickstart_jax_utils as utils
from typing import Optional
[11]:
from typing import Callable

class FlaxMLP(nn.Module):
    """Feed-forward network in Transformer layer
    Built with plain Flax modules.
    """
    hidden_size: int
    ffn_hidden_size: int
    # Factory called with no arguments; its result is passed to nn.Dense as dot_general.
    # Returning None makes nn.Dense fall back to jax.lax.dot_general.
    dot_general_cls: Callable = lambda: None

    @nn.compact
    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = nn.Dense(features=self.ffn_hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)
        x = nn.gelu(x, approximate=True)  # tanh approximation of GELU
        x = nn.Dense(features=self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)
        return x


class FlaxTransformerLayer(nn.Module):
    """Basic Transformer layer using plain Flax modules"""
    hidden_size: int
    ffn_hidden_size: int
    num_attention_heads: int
    layernorm_eps: float = 1e-5
    attention_dropout: float = 0.1
    dot_general_cls: Callable = lambda: None

    def setup(self):
        self.kv_channels = self.hidden_size // self.num_attention_heads

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = False,
    ) -> jnp.ndarray:
        # Create causal mask if not provided
        if attention_mask is None:
            attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)

        res = x
        x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)

        # Fused QKV projection
        qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)
        qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)
        q, k, v = jnp.split(qkv, 3, axis=3)
        # q, k, v now have shape [batch, seq_len, num_heads, kv_channels],
        # which is the layout nn.dot_product_attention expects

        # nn.dot_product_attention expects a mask broadcastable to
        # [batch, num_heads, q_length, kv_length]; nn.make_causal_mask returns
        # [batch, 1, seq_len, seq_len], which broadcasts over the head dimension.

        # Draw a dropout RNG key only when dropout is active (not deterministic and dropout_rate > 0)
        dropout_rng = None
        if not deterministic and self.attention_dropout > 0:
            dropout_rng = self.make_rng('dropout')

        # See quickstart_jax.ipynb for details on using TE's faster fused attention
        x = nn.dot_product_attention(
            query=q,
            key=k,
            value=v,
            mask=attention_mask,
            dropout_rng=dropout_rng,
            dropout_rate=self.attention_dropout,
            deterministic=deterministic,
            broadcast_dropout=True,
        )
        # Reshape output from [batch, seq_len, num_heads, kv_channels] to [batch, seq_len, hidden_size]
        x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)

        # Output projection and first residual connection
        x = nn.Dense(features=self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)
        x = res + x

        # Second residual connection
        res = x
        x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)

        # MLP
        mlp = FlaxMLP(
            hidden_size=self.hidden_size,
            ffn_hidden_size=self.ffn_hidden_size,
            dot_general_cls=self.dot_general_cls,
        )
        x = mlp(x)
        return x + res
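We've exposed dot_general_cls here so we can test out different GEMM implementations later. dot_general_cls is called with no arguments and its result is passed to nn.Dense as dot_general. The default, lambda: None, passes None, in which case Flax's nn.Dense uses JAX's GEMM, jax.lax.dot_general.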
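Any replacement passed through dot_general_cls only has to follow jax.lax.dot_general's calling convention. nn.Dense contracts the last axis of its input with the first axis of its kernel, using dimension numbers (((x.ndim - 1,), (0,)), ((), ())), and passes precision as a keyword argument. The sketch below uses a hypothetical bf16_dot_general (not part of Flax or TE) that casts both operands to bfloat16 and accumulates in float32; it only illustrates the signature a custom GEMM must accept.

```python
import jax
import jax.numpy as jnp


def bf16_dot_general(lhs, rhs, dimension_numbers, precision=None, preferred_element_type=None):
    """Hypothetical GEMM hook with the same signature as jax.lax.dot_general."""
    out_dtype = jnp.float32 if preferred_element_type is None else preferred_element_type
    return jax.lax.dot_general(
        lhs.astype(jnp.bfloat16),
        rhs.astype(jnp.bfloat16),
        dimension_numbers,
        precision=precision,
        preferred_element_type=out_dtype,  # accumulate and return in float32 by default
    )


# Dimension numbers nn.Dense uses: contract x's last axis with the kernel's first axis,
# with no batch dimensions.
x = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4) / 10.0
kernel = jnp.ones((4, 5), dtype=jnp.float32)
dims = (((x.ndim - 1,), (0,)), ((), ()))

y_ref = jax.lax.dot_general(x, kernel, dims)   # what nn.Dense computes by default
y_bf16 = bf16_dot_general(x, kernel, dims)     # the hypothetical drop-in
```

Because the layer calls self.dot_general_cls() with no arguments and hands the result to nn.Dense, a plain function like this could be plugged in as dot_general_cls=lambda: bf16_dot_general. TE's version is a Flax module instead, for reasons covered later in this tutorial.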
Testing Performance
Now let’s test the performance of our FlaxTransformerLayer:
[3]:
# Layer configuration
hidden_size = 4096
sequence_length = 2048
batch_size = 4
ffn_hidden_size = 16384
num_attention_heads = 32
dtype = jnp.bfloat16
# Synthetic data
key, x_key, dy_key, dropout_key = jax.random.split(jax.random.PRNGKey(42), 4)
x = jax.random.normal(x_key, (batch_size, sequence_length, hidden_size)).astype(dtype)
# Use a separate key so the output gradient is not identical to the input
dy = jax.random.normal(dy_key, (batch_size, sequence_length, hidden_size)).astype(dtype)
[4]:
# Initialize the FlaxTransformerLayer
flax_transformer = FlaxTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
)
# Initialize parameters
params = flax_transformer.init(key, x, attention_mask=None, deterministic=False)
print("Pure Flax FlaxTransformerLayer initialized successfully!")
print(f"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, params)}")
Pure Flax FlaxTransformerLayer initialized successfully!
Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}
[5]:
# Example usage of forward pass
y = flax_transformer.apply(params, x, attention_mask=None, deterministic=True)
print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(f"Output dtype: {y.dtype}")
print("Forward pass completed successfully!")
Input shape: (4, 2048, 4096)
Output shape: (4, 2048, 4096)
Output dtype: float32
Forward pass completed successfully!
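The float32 output dtype follows from JAX type promotion. Flax layers such as nn.Dense and nn.LayerNorm create float32 parameters by default (param_dtype=jnp.float32), and with dtype=None they compute in the promoted dtype of their inputs and parameters, so bfloat16 activations meeting float32 weights produce float32. A minimal check of that promotion rule, independent of Flax:

```python
import jax.numpy as jnp

x_bf16 = jnp.ones((2, 4), dtype=jnp.bfloat16)      # activations, like x above
kernel_f32 = jnp.ones((4, 3), dtype=jnp.float32)   # weights at Flax's default param_dtype

promoted = jnp.result_type(x_bf16.dtype, kernel_f32.dtype)
y = x_bf16 @ kernel_f32
```

Passing dtype=jnp.bfloat16 to these modules would keep activations in bfloat16; the tutorial leaves the defaults unchanged.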
[6]:
utils.speedometer(
model_apply_fn=flax_transformer.apply,
variables=params,
input=x,
output_grad=dy,
forward_kwargs={"attention_mask": None, "deterministic": False},
rngs={"dropout": dropout_key},
)
Mean time: 18.83516788482666 ms
Transformer Engine
TransformerEngine/JAX currently uses Flax Linen, but it is compatible with Flax NNX and Haiku through their interop layers:
* Use Flax NNX and Linen together
* Haiku and Flax interop
Additionally, with the approach in this tutorial, TransformerEngine does not manage any model parameters. You can keep all your existing model parameters, initialization, and sharding the same. The only change required is to pass TE's dot_general_cls to nn.Dense in place of the default dot_general implementation. TE's dot_general_cls produces a small Flax module that performs the quantized dense forward and backward (VJP) and stores a small amount of recipe-specific state.
Now we'll select a recipe. DelayedScaling and Float8CurrentScaling use per-tensor scaling and are supported on Hopper and Blackwell. MXFP8BlockScaling uses block scaling, and NVFP4BlockScaling combines block scaling with per-tensor scaling; both are supported on Blackwell.
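The difference between per-tensor and block scaling can be sketched in NumPy. Per-tensor current scaling picks one scale from the amax of the whole tensor, so a single outlier shrinks the scale for every element; block scaling computes one scale per small block, so the outlier only affects its own block. This is a simplified model: it maps each amax to the FP8 E4M3 maximum of 448 and skips the actual FP8 cast. Real MXFP8 uses 32-element blocks with power-of-two scales, and NVFP4 uses 16-element blocks with FP8 scales plus a per-tensor scale. The helper names are hypothetical.

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value


def per_tensor_scale(x):
    """One scale for the whole tensor, from the tensor's own amax."""
    amax = float(np.max(np.abs(x)))
    return E4M3_MAX / amax if amax > 0 else 1.0


def per_block_scales(x, block_size=32):
    """One scale per contiguous block of block_size elements along the last axis."""
    blocks = x.reshape(*x.shape[:-1], -1, block_size)
    amax = np.max(np.abs(blocks), axis=-1)
    safe_amax = np.where(amax > 0, amax, 1.0)  # avoid dividing by zero for all-zero blocks
    return np.where(amax > 0, E4M3_MAX / safe_amax, 1.0)


x = np.ones(64, dtype=np.float32)
x[0] = 100.0  # a single outlier in the first block

tensor_scale = per_tensor_scale(x)  # set by the outlier for all 64 elements
block_scales = per_block_scales(x)  # the outlier only lowers the first block's scale
```

With one scale for the tensor, the 63 ordinary values are scaled to about 4.48 and use a small part of the FP8 range; with per-block scales the second block is scaled to the full range.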
To customize a recipe further, pass options as arguments to the recipe's constructor.
[7]:
from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, MXFP8BlockScaling, NVFP4BlockScaling
from transformer_engine.jax import flax as te_flax
# Choose a quantization recipe. This can be modified to any of the recipes imported above.
quantization_recipe = DelayedScaling()
te_dot_general_cls = te_flax.make_dot_general_cls(quantization_recipe)
rngs = {'dropout': dropout_key}
if isinstance(quantization_recipe, NVFP4BlockScaling):
    # The NVFP4 recipe requires a Flax RNG for stochastic rounding
    rngs['sr_rng'] = jax.random.PRNGKey(0)
Now using this quantized dense in our model is as simple as passing dot_general_cls=te_dot_general_cls. Let's try it out!
Important: Remat Policy
TE's quantization uses specialized TE quantized GEMM primitives. If you are using any built-in JAX checkpoint policies that look for JAX GEMMs (dots), such as jax.checkpoint_policies.checkpoint_dots, replace the policy with transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms or a similar policy so that TE's quantized GEMM primitives are checkpointed correctly.
Otherwise, TE GEMMs will be rematerialized during the backward pass, adding extra GEMM work and making the performance comparison unfair.
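Remat policies are a plain JAX mechanism, so the idea can be sketched without TE. Here jax.checkpoint with jax.checkpoint_policies.checkpoint_dots saves the matmul outputs and recomputes the cheaper tanh in the backward pass, and the gradients match the un-checkpointed function. The two-layer function is illustrative only; in a TE model, this policy would not recognize TE's quantized GEMMs, which is why swapping in a TE-aware policy matters.

```python
import jax
import jax.numpy as jnp


def two_layer_mlp(w1, w2, x):
    # jnp.dot lowers to dot_general, which checkpoint_dots treats as saveable
    h = jnp.tanh(jnp.dot(x, w1))
    return jnp.sum(jnp.tanh(jnp.dot(h, w2)))


# Save matmul outputs as residuals and recompute everything else on the backward pass.
# For a TE model, use transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms
# (or a similar TE policy) instead, since TE's quantized GEMMs are not jax.lax.dot_general calls.
remat_mlp = jax.checkpoint(two_layer_mlp, policy=jax.checkpoint_policies.checkpoint_dots)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(k1, (8, 16))
w2 = jax.random.normal(k2, (16, 4))
x = jax.random.normal(k3, (2, 8))

grads_ref = jax.grad(two_layer_mlp, argnums=(0, 1))(w1, w2, x)
grads_remat = jax.grad(remat_mlp, argnums=(0, 1))(w1, w2, x)
```

Flax Linen users can pass the same policy argument to nn.remat.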
[12]:
# Initialize the FlaxTransformerLayer
flax_transformer = FlaxTransformerLayer(
hidden_size=hidden_size,
ffn_hidden_size=ffn_hidden_size,
num_attention_heads=num_attention_heads,
dot_general_cls=te_dot_general_cls,
)
# Initialize parameters
var_collect = flax_transformer.init(key, x, attention_mask=None, deterministic=False)
print("Pure Flax FlaxTransformerLayer initialized successfully!")
print(f"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, var_collect['params'])}")
print(f"Additional state: {jax.tree_util.tree_map(lambda x: x.shape, {k: v for k, v in var_collect.items() if k != 'params'})}")
Pure Flax FlaxTransformerLayer initialized successfully!
Parameter shapes: {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}
Additional state: {'_overwrite_with_gradient': {'FlaxMLP_0': {'TEWrapper_dot_general_0': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}, 'TEWrapper_dot_general_1': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}}, 'TEWrapper_dot_general_0': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}, 'TEWrapper_dot_general_1': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}}}
If using a recipe that stores additional state, such as DelayedScaling, you’ll see this additional state stored as Flax variables. It is important to maintain and pass the whole state of Flax variables var_collect across training steps, not just the model params, for proper usage of stateful recipes like DelayedScaling.
For example, the Additional state printout above shows an amax_history and a scale for the input (x), the kernel, and the gradient of each quantized GEMM; DelayedScaling uses the amax history to compute the per-tensor scale.
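The DelayedScaling bookkeeping can be sketched in NumPy. Each step records the tensor's amax into a fixed-length history, and the scale for the next quantization is derived from the maximum over that history as FP8_max / amax / 2^margin. This is a simplified model, assuming the recipe's default max-over-history amax computation; TE's history layout, fused update kernels, and FP8 cast differ. The helper names push_amax and delayed_scale are hypothetical.

```python
import numpy as np

E4M3_MAX = 448.0  # FP8 E4M3 maximum, used for activations and weights


def push_amax(history, amax):
    """Insert the newest amax at index 0 and drop the oldest entry."""
    history = np.roll(history, 1)
    history[0] = amax
    return history


def delayed_scale(history, fp8_max=E4M3_MAX, margin=0):
    """Scale for the next quantization, from the max over the amax history."""
    amax = float(np.max(history))
    if amax <= 0.0:
        return 1.0
    return fp8_max / amax / (2.0 ** margin)


history = np.zeros(4, dtype=np.float32)  # short for illustration; the printout above shows 1024 entries
for step_amax in [0.5, 2.0, 1.0]:
    history = push_amax(history, step_amax)

scale = delayed_scale(history)                      # 448 / 2.0
x = np.array([0.25, -1.5, 3.0], dtype=np.float32)   # 3.0 exceeds every recorded amax
x_scaled = np.clip(x * scale, -E4M3_MAX, E4M3_MAX)  # out-of-range values saturate
```

The last element shows the trade-off of delayed scaling: a value larger than anything in the recorded history saturates at the FP8 maximum, whereas current scaling would derive the scale from this tensor's own amax.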
te_dot_general_cls produces a Flax module, rather than a module-less function like jax.lax.dot_general, because some quantization recipes need to track internal state separately from the model parameters.
Flax modules can manage 3 things:
1. Model parameters/weights, e.g. your Dense "kernel", "bias", etc.
2. RNGs for dropout, stochastic rounding, etc.
3. Flax variables. These are additional state variables that are used across training steps but are distinct from model params in that you don't take gradients or optimize them. Currently, we only use this for DelayedScaling's amax_history state.
With the simplest quantization integration shown in this tutorial, users keep their existing model parameter setup, so they don't need to worry about preserving sharding, initialization distributions, etc. Point 1 is therefore not needed, since this code path with dot_general_cls does not create model parameters. te_dot_general_cls() must still produce a Flax module, though, because points 2 and 3 can only be handled inside one.
[13]:
# Example usage of forward pass
y = flax_transformer.apply(var_collect, x, attention_mask=None, deterministic=True, rngs=rngs)
print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(f"Output dtype: {y.dtype}")
print("Forward pass completed successfully!")
Input shape: (4, 2048, 4096)
Output shape: (4, 2048, 4096)
Output dtype: float32
Forward pass completed successfully!
Now let’s measure the performance!
[14]:
utils.speedometer(
model_apply_fn=flax_transformer.apply,
variables=var_collect,
input=x,
output_grad=dy,
forward_kwargs={"attention_mask": None, "deterministic": False},
rngs=rngs,
)
Mean time: 10.553865432739258 ms
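As a rough sanity check on the two measurements, the arithmetic below converts mean step time into GEMM throughput. It assumes the speedometer's mean time covers one forward and backward pass (it is given an output gradient) and uses the common estimate of about 6 FLOPs per weight per token for forward plus backward. It counts only the four GEMMs and ignores attention and elementwise work, so treat the results as approximate.

```python
# Layer configuration from the cells above
hidden_size, ffn_hidden_size = 4096, 16384
batch_size, sequence_length = 4, 2048

# Weights of the four GEMMs: fused QKV, attention output, and the two MLP layers
gemm_weights = (
    hidden_size * 3 * hidden_size
    + hidden_size * hidden_size
    + hidden_size * ffn_hidden_size
    + ffn_hidden_size * hidden_size
)
tokens = batch_size * sequence_length

# ~2 FLOPs per weight per token forward, ~4 backward (input and weight gradients)
step_flops = 6 * gemm_weights * tokens


def tflops(mean_ms):
    """Approximate GEMM TFLOP/s for a step that takes mean_ms milliseconds."""
    return step_flops / (mean_ms * 1e-3) / 1e12


baseline_ms, te_ms = 18.83516788482666, 10.553865432739258  # "Mean time" outputs above
speedup = baseline_ms / te_ms
```

With the timings shown, the TE version is roughly 1.78x faster than the plain Flax layer.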