{ "cells": [ { "cell_type": "markdown", "id": "962d87bb", "metadata": {}, "source": [ "\n", "\n", "# JAX: Integrating TE into an existing framework\n", "\n", "This tutorial will cover how to integrate TransformerEngine into an existing JAX model framework, such as [MaxText's TE integration](https://github.com/AI-Hypercomputer/maxtext/blob/ed517cf80d9aa81f76e236c5516dacebfe39e96d/src/MaxText/layers/quantizations.py#L753) or your own model framework. \n" ] }, { "cell_type": "markdown", "id": "b36876bb", "metadata": {}, "source": [ "Let's start with a standard JAX+Flax Transformer layer" ] }, { "cell_type": "code", "execution_count": 1, "id": "d5284a38", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "from flax import linen as nn\n", "import quickstart_jax_utils as utils\n", "from typing import Optional" ] }, { "cell_type": "code", "execution_count": 11, "id": "a4d1cfdc", "metadata": {}, "outputs": [], "source": [ "class FlaxMLP(nn.Module):\n", " \"\"\"Feed-forward network in Transformer layer\n", " Built with plain Flax modules.\n", " \"\"\"\n", " hidden_size: int\n", " ffn_hidden_size: int\n", " dot_general_cls: callable = lambda: None\n", "\n", " @nn.compact\n", " def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n", " x = nn.Dense(features=self.ffn_hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n", " x = nn.gelu(x, approximate=True) # equivalent to tanh approximation\n", " x = nn.Dense(features=self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n", " return x\n", "\n", "class FlaxTransformerLayer(nn.Module):\n", " \"\"\"Basic Transformer layer using plain Flax modules\"\"\"\n", " hidden_size: int\n", " ffn_hidden_size: int\n", " num_attention_heads: int\n", " layernorm_eps: float = 1e-5\n", " attention_dropout: float = 0.1\n", " dot_general_cls: callable = lambda: None\n", " \n", " def setup(self):\n", " self.kv_channels = self.hidden_size // self.num_attention_heads\n", "\n", " 
@nn.compact\n", " def __call__(\n", " self, \n", " x: jnp.ndarray, \n", " attention_mask: Optional[jnp.ndarray] = None,\n", " deterministic: bool = False\n", " ) -> jnp.ndarray:\n", " # Create causal mask if not provided\n", " if attention_mask is None:\n", " attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n", " \n", " res = x\n", " x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n", " \n", " # Fused QKV projection\n", " qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n", " qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n", " q, k, v = jnp.split(qkv, 3, axis=3)\n", " \n", " # q, k, v now have shape [batch, seq_len, num_heads, kv_channels]\n", " # which is the correct format for dot_product_attention\n", " \n", " # Apply dot product attention\n", " # Note: dot_product_attention expects mask to be broadcastable to \n", " # [batch, num_heads, q_length, kv_length], but attention_mask from \n", " # nn.make_causal_mask has shape [batch, 1, seq_len, seq_len]\n", " \n", " # Generate dropout RNG key when needed (not deterministic and dropout_rate > 0)\n", " dropout_rng = None\n", " if not deterministic and self.attention_dropout > 0:\n", " dropout_rng = self.make_rng('dropout')\n", " \n", " # See quickstart_jax.ipynb for details on using TE's faster fused attention\n", " x = nn.dot_product_attention(\n", " query=q,\n", " key=k,\n", " value=v,\n", " mask=attention_mask,\n", " dropout_rng=dropout_rng,\n", " dropout_rate=self.attention_dropout,\n", " deterministic=deterministic,\n", " broadcast_dropout=True,\n", " )\n", " \n", " # Reshape output from [batch, seq_len, num_heads, kv_channels] to [batch, seq_len, hidden_size]\n", " x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)\n", "\n", " # Output projection\n", " x = nn.Dense(features=self.hidden_size, use_bias=True, dot_general=self.dot_general_cls())(x)\n", " \n", " x = res + x\n", " \n", " # Second 
residual connection\n", " res = x\n", " x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n", " \n", " # MLP\n", " mlp = FlaxMLP(\n", " hidden_size=self.hidden_size,\n", " ffn_hidden_size=self.ffn_hidden_size,\n", " dot_general_cls=self.dot_general_cls,\n", " )\n", " x = mlp(x)\n", " \n", " return x + res\n" ] }, { "cell_type": "markdown", "id": "db16bf70", "metadata": {}, "source": [ "We've exposed `dot_general_cls` here so we can test out different GEMM implementations later. By default, Flax's `nn.Dense` will use JAX's GEMM `jax.lax.dot_general` when `dot_general` is `None`." ] }, { "cell_type": "markdown", "id": "fbc3510b", "metadata": {}, "source": [ "## Testing Performance\n", "\n", "Now let's test the performance of our FlaxTransformerLayer:\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "8b44649d", "metadata": {}, "outputs": [], "source": [ "# Layer configuration\n", "hidden_size = 4096\n", "sequence_length = 2048\n", "batch_size = 4\n", "ffn_hidden_size = 16384\n", "num_attention_heads = 32\n", "dtype = jnp.bfloat16\n", "\n", "# Synthetic data\n", "key, dropout_key = jax.random.split(jax.random.PRNGKey(42))\n", "x = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n", "dy = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "e44ed26d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Pure Flax FlaxTransformerLayer initialized successfully!\n", "Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}\n" ] } ], "source": [ "# Initialize the FlaxTransformerLayer\n", "flax_transformer = 
FlaxTransformerLayer(\n", " hidden_size=hidden_size,\n", " ffn_hidden_size=ffn_hidden_size,\n", " num_attention_heads=num_attention_heads,\n", ")\n", "\n", "# Initialize parameters\n", "params = flax_transformer.init(key, x, attention_mask=None, deterministic=False)\n", "\n", "print(\"Pure Flax FlaxTransformerLayer initialized successfully!\")\n", "print(f\"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, params)}\")\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "de91af7a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input shape: (4, 2048, 4096)\n", "Output shape: (4, 2048, 4096)\n", "Output dtype: float32\n", "Forward pass completed successfully!\n" ] } ], "source": [ "# Example usage of forward pass\n", "y = flax_transformer.apply(params, x, attention_mask=None, deterministic=True)\n", "print(f\"Input shape: {x.shape}\")\n", "print(f\"Output shape: {y.shape}\")\n", "print(f\"Output dtype: {y.dtype}\")\n", "print(\"Forward pass completed successfully!\")\n" ] }, { "cell_type": "code", "execution_count": 6, "id": "037bc8d9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean time: 18.83516788482666 ms\n" ] } ], "source": [ "import importlib\n", "import quickstart_jax_utils\n", "importlib.reload(quickstart_jax_utils)\n", "\n", "utils.speedometer(\n", " model_apply_fn=flax_transformer.apply,\n", " variables=params,\n", " input=x,\n", " output_grad=dy,\n", " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", " rngs={\"dropout\": dropout_key},\n", ")" ] }, { "cell_type": "markdown", "id": "5e9310c9", "metadata": {}, "source": [ "## Transformer Engine" ] }, { "cell_type": "markdown", "id": "1f8e213e", "metadata": {}, "source": [ "TransformerEngine/JAX is currently using Flax Linen. 
However, it can be used alongside Flax NNX or Haiku.\n", "* [Use Flax NNX and Linen together](https://flax.readthedocs.io/en/latest/guides/bridge_guide.html)\n", "* [Haiku and Flax interop](https://dm-haiku.readthedocs.io/en/latest/notebooks/flax.html)\n", "\n", "Additionally, with the approach in this tutorial, no model parameters need to be managed by TransformerEngine. You can keep all your existing model parameters, initialization, and sharding the same. The only change required is to use the module produced by TE's `dot_general_cls` instead of the default `dot_general` implementation in `nn.Dense`. TE's `dot_general_cls` is a small module that performs a quantized dense VJP and stores some small recipe-specific state." ] }, { "cell_type": "markdown", "id": "4477d4e9", "metadata": {}, "source": [ "Now we'll select a recipe. `DelayedScaling` and `Float8CurrentScaling` use per-tensor scaling and are supported on Hopper and Blackwell. `MXFP8BlockScaling` and `NVFP4BlockScaling` use block scaling, or a combination of per-tensor and block scaling, and are supported on Blackwell.\n", "\n", "To customize a recipe further, pass the relevant options as arguments to the recipe's constructor." ] }, { "cell_type": "code", "execution_count": 7, "id": "5ddf41e7", "metadata": {}, "outputs": [], "source": [ "from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, MXFP8BlockScaling, NVFP4BlockScaling\n", "from transformer_engine.jax import flax as te_flax \n", "\n", "# Choose a quantization recipe. 
This can be modified to any of the recipes imported above.\n", "quantization_recipe = DelayedScaling()\n", "\n", "te_dot_general_cls = te_flax.make_dot_general_cls(quantization_recipe)\n", "\n", "rngs = {'dropout': dropout_key}\n", "if isinstance(quantization_recipe, NVFP4BlockScaling):\n", " # The NVFP4 recipe requires a Flax RNG for stochastic rounding\n", " rngs['sr_rng'] = jax.random.PRNGKey(0)\n" ] }, { "cell_type": "markdown", "id": "c8769655", "metadata": {}, "source": [ "Using this quantized dense in our model is now as simple as passing `dot_general_cls=te_dot_general_cls` to the layer. Let's try it out!\n", "\n", "
\n", "\n", "Important: Remat Policy\n", "\n", "TE's quantization uses specialized quantized GEMM primitives. If you are using a built-in JAX checkpoint policy that looks for JAX GEMMs (dots), such as `jax.checkpoint_policies.checkpoint_dots`, replace it with `transformer_engine.jax.checkpoint_policies.checkpoint_dots_and_te_gemms` or a similar policy so that TE's quantized GEMM primitives are checkpointed correctly.\n", "\n", "Otherwise, TE GEMMs will be rematerialized in the backward pass, which makes any performance comparison misleading.\n", "\n", "
" ] }, { "cell_type": "code", "execution_count": 12, "id": "8407d2ea", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Pure Flax FlaxTransformerLayer initialized successfully!\n", "Parameter shapes: {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}\n", "Additional state: {'_overwrite_with_gradient': {'FlaxMLP_0': {'TEWrapper_dot_general_0': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}, 'TEWrapper_dot_general_1': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}}, 'TEWrapper_dot_general_0': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}, 'TEWrapper_dot_general_1': {'grad_amax_history': (1024,), 'grad_scale': (1,), 'kernel_amax_history': (1024,), 'kernel_scale': (1,), 'x_amax_history': (1024,), 'x_scale': (1,)}}}\n" ] } ], "source": [ "# Initialize the FlaxTransformerLayer\n", "flax_transformer = FlaxTransformerLayer(\n", " hidden_size=hidden_size,\n", " ffn_hidden_size=ffn_hidden_size,\n", " num_attention_heads=num_attention_heads,\n", " dot_general_cls=te_dot_general_cls,\n", ")\n", "\n", "# Initialize parameters\n", "var_collect = flax_transformer.init(key, x, attention_mask=None, deterministic=False)\n", "\n", "print(\"Pure Flax FlaxTransformerLayer initialized successfully!\")\n", "print(f\"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, var_collect['params'])}\")\n", "print(f\"Additional state: {jax.tree_util.tree_map(lambda x: 
x.shape, {k: v for k, v in var_collect.items() if k != 'params'})}\")" ] }, { "cell_type": "markdown", "id": "abe27237", "metadata": {}, "source": [ "If you use a recipe that stores additional state, such as `DelayedScaling`, that state is stored as Flax variables. For stateful recipes like `DelayedScaling` to work properly, you must maintain and pass the whole collection of Flax variables, `var_collect`, across training steps, not just the model params.\n", "\n", "For example, under `Additional state: ` above you'll see the amax history of each quantized tensor (`x_amax_history`, `kernel_amax_history`, `grad_amax_history`), which is used to compute the per-tensor scale in the `DelayedScaling` recipe." ] }, { "cell_type": "markdown", "id": "5ab72935", "metadata": {}, "source": [ "`te_dot_general_cls` produces a Flax module rather than a plain function like `jax.lax.dot_general` because some quantization recipes need to track internal state separately from the model parameters.\n", "\n", "Flax modules can manage 3 things:\n", "1. Model parameters/weights, e.g. your Dense \"kernel\", \"bias\", etc.\n", "2. RNGs for dropout, stochastic rounding, etc.\n", "3. Flax variables. These are additional state variables that are used across training steps but are distinct from model params in that you don't take gradients of them or optimize them. Currently, we only use this for `DelayedScaling`'s amax history state.\n", "\n", "With the simple quantization integration shown in this tutorial, users keep their existing model param setup, so they don't need to worry about preserving the sharding, init distribution, and so on. Point 1 is therefore not needed, since this codepath with `dot_general_cls` does not create model params. However, `te_dot_general_cls()` must still produce a Flax module, because points 2 and 3 may be required and both are only available inside a Flax module." 
] }, { "cell_type": "code", "execution_count": 13, "id": "3b6b344b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Input shape: (4, 2048, 4096)\n", "Output shape: (4, 2048, 4096)\n", "Output dtype: float32\n", "Forward pass completed successfully!\n" ] } ], "source": [ "# Example usage of forward pass\n", "y = flax_transformer.apply(var_collect, x, attention_mask=None, deterministic=True, rngs=rngs)\n", "print(f\"Input shape: {x.shape}\")\n", "print(f\"Output shape: {y.shape}\")\n", "print(f\"Output dtype: {y.dtype}\")\n", "print(\"Forward pass completed successfully!\")\n" ] }, { "cell_type": "markdown", "id": "d178f247", "metadata": {}, "source": [ "Now let's measure the performance!" ] }, { "cell_type": "code", "execution_count": 14, "id": "5cc6c2a7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean time: 10.553865432739258 ms\n" ] } ], "source": [ "import importlib\n", "import quickstart_jax_utils\n", "importlib.reload(quickstart_jax_utils)\n", "\n", "utils.speedometer(\n", " model_apply_fn=flax_transformer.apply,\n", " variables=var_collect,\n", " input=x,\n", " output_grad=dy,\n", " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", " rngs=rngs,\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 5 }