{ "cells": [ { "cell_type": "markdown", "id": "da9fd6a8", "metadata": {}, "source": [ "# Getting Started\n", "\n", "## Overview\n", "\n", "Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, providing better performance with lower memory utilization in both training and inference. It provides support for 8-bit floating point (FP8) precision on Hopper GPUs, implements a collection of highly optimized building blocks for popular Transformer architectures, and exposes an automatic-mixed-precision-like API that can be used seamlessly with your PyTorch code. It also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.\n", "\n", "## Let's build a Transformer layer!\n", "\n", "
\n", "\n", "Summary\n", " \n", "We build a basic Transformer layer using regular PyTorch modules. This will be our baseline for later comparisons with Transformer Engine.\n", "\n", "
\n", "\n", "Let's start by creating a GPT encoder layer using plain PyTorch. Figure 1 shows the overall structure.\n", "\n", "
\n", "\n", "
Figure 1: Structure of a GPT encoder layer.
\n", "
\n", "\n", "We construct the components as follows:\n", "\n", "- `LayerNorm`: `torch.nn.LayerNorm`\n", "- `QKV Projection`: `torch.nn.Linear` (conceptually three `Linear` layers for Q, K, and V separately, but we fuse them into a single `Linear` layer that is three times larger)\n", "- `DotProductAttention`: `DotProductAttention` from [quickstart_utils.py](quickstart_utils.py)\n", "- `Projection`: `torch.nn.Linear`\n", "- `Dropout`: `torch.nn.Dropout`\n", "- `MLP`: `BasicMLP` from [quickstart_utils.py](quickstart_utils.py)\n", "\n", "Over the course of this tutorial we will use a few modules and helper functions defined in [quickstart_utils.py](quickstart_utils.py). Putting it all together:" ] }, { "cell_type": "code", "execution_count": 1, "id": "2be43d64", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import quickstart_utils as utils\n", "\n", "class BasicTransformerLayer(torch.nn.Module):\n", "    def __init__(\n", "        self,\n", "        hidden_size: int,\n", "        ffn_hidden_size: int,\n", "        num_attention_heads: int,\n", "        layernorm_eps: float = 1e-5,\n", "        attention_dropout: float = 0.1,\n", "        hidden_dropout: float = 0.1,\n", "    ):\n", "        super().__init__()\n", "        self.num_attention_heads = num_attention_heads\n", "        self.kv_channels = hidden_size // num_attention_heads\n", "        self.ln1 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)\n", "        self.qkv_projection = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=True)\n", "        self.attention = utils.DotProductAttention(\n", "            num_attention_heads=num_attention_heads,\n", "            kv_channels=self.kv_channels,\n", "            attention_dropout=attention_dropout,\n", "        )\n", "        self.projection = torch.nn.Linear(hidden_size, hidden_size, bias=True)\n", "        self.dropout = torch.nn.Dropout(hidden_dropout)\n", "        self.ln2 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)\n", "        self.mlp = utils.BasicMLP(\n", "            hidden_size=hidden_size,\n", "            ffn_hidden_size=ffn_hidden_size,\n", "        )\n", "\n", "    def forward(\n", "        self,\n", "        x: torch.Tensor,\n", 
"        attention_mask: torch.Tensor\n", "    ) -> torch.Tensor:\n", "        res = x\n", "        x = self.ln1(x)\n", "\n", "        # Fused QKV projection\n", "        qkv = self.qkv_projection(x)\n", "        qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)\n", "        q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)\n", "\n", "        x = self.attention(q, k, v, attention_mask)\n", "        x = self.projection(x)\n", "        x = self.dropout(x)\n", "        x = res + x\n", "        res = x\n", "        x = self.ln2(x)\n", "        x = self.mlp(x)\n", "\n", "        return x + res" ] }, { "cell_type": "markdown", "id": "40724d1d", "metadata": {}, "source": [ "That's it! We now have a simple Transformer layer. We can test it:" ] }, { "cell_type": "code", "execution_count": 2, "id": "a786f0ea", "metadata": {}, "outputs": [], "source": [ "# Layer configuration\n", "hidden_size = 4096\n", "sequence_length = 2048\n", "batch_size = 4\n", "ffn_hidden_size = 16384\n", "num_attention_heads = 32\n", "dtype = torch.float16\n", "\n", "# Synthetic data\n", "x = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)\n", "dy = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)" ] }, { "cell_type": "code", "execution_count": 3, "id": "ffdbfb7a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "BasicTransformerLayer(\n", "  (ln1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n", "  (qkv_projection): Linear(in_features=4096, out_features=12288, bias=True)\n", "  (attention): DotProductAttention(\n", "    (dropout): Dropout(p=0.1, inplace=False)\n", "  )\n", "  (projection): Linear(in_features=4096, out_features=4096, bias=True)\n", "  (dropout): Dropout(p=0.1, inplace=False)\n", "  (ln2): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n", "  (mlp): BasicMLP(\n", "    (linear1): Linear(in_features=4096, out_features=16384, bias=True)\n", "    (linear2): Linear(in_features=16384, out_features=4096, bias=True)\n", "  )\n", ")" ] }, "execution_count": 3, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "basic_transformer = BasicTransformerLayer(\n", "    hidden_size,\n", "    ffn_hidden_size,\n", "    num_attention_heads,\n", ")\n", "basic_transformer.to(dtype=dtype).cuda()" ] }, { "cell_type": "code", "execution_count": 4, "id": "0162ad40", "metadata": {}, "outputs": [], "source": [ "torch.manual_seed(1234)\n", "y = basic_transformer(x, attention_mask=None)" ] }, { "cell_type": "code", "execution_count": 5, "id": "65ae6dd6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean time: 43.0663916015625 ms\n" ] } ], "source": [ "utils.speedometer(\n", "    basic_transformer,\n", "    x,\n", "    dy,\n", "    forward_kwargs = { \"attention_mask\": None },\n", ")" ] }, { "cell_type": "markdown", "id": "43717e36", "metadata": {}, "source": [ "## Meet Transformer Engine\n", "\n", "
\n", "\n", "Summary\n", " \n", "We modify the example Transformer layer to include the simplest TE modules: `Linear` and `LayerNorm`.\n", "\n", "
\n", "\n", "Now that we have a basic Transformer layer, let's use Transformer Engine to speed up training." ] }, { "cell_type": "code", "execution_count": 6, "id": "004d3c92", "metadata": {}, "outputs": [], "source": [ "import transformer_engine.pytorch as te" ] }, { "cell_type": "markdown", "id": "1931f911", "metadata": {}, "source": [ "TE provides a set of PyTorch modules that can be used to build Transformer layers. The simplest of the provided modules are the `Linear` and `LayerNorm` layers, which we can use instead of `torch.nn.Linear` and `torch.nn.LayerNorm`. Let's modify `BasicTransformerLayer`:" ] }, { "cell_type": "code", "execution_count": 7, "id": "1f44db50", "metadata": {}, "outputs": [], "source": [ "class BasicTEMLP(torch.nn.Module):\n", "    def __init__(self,\n", "                 hidden_size: int,\n", "                 ffn_hidden_size: int) -> None:\n", "        super().__init__()\n", "        self.linear1 = te.Linear(hidden_size, ffn_hidden_size, bias=True)\n", "        self.linear2 = te.Linear(ffn_hidden_size, hidden_size, bias=True)\n", "\n", "    def forward(self, x):\n", "        x = self.linear1(x)\n", "        x = torch.nn.functional.gelu(x, approximate='tanh')\n", "        x = self.linear2(x)\n", "        return x\n", "\n", "class BasicTETransformerLayer(torch.nn.Module):\n", "    def __init__(self,\n", "                 hidden_size: int,\n", "                 ffn_hidden_size: int,\n", "                 num_attention_heads: int,\n", "                 layernorm_eps: float = 1e-5,\n", "                 attention_dropout: float = 0.1,\n", "                 hidden_dropout: float = 0.1):\n", "        super().__init__()\n", "        self.num_attention_heads = num_attention_heads\n", "        self.kv_channels = hidden_size // num_attention_heads\n", "        self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps)\n", "        self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True)\n", "        self.attention = utils.DotProductAttention(\n", "            num_attention_heads=num_attention_heads,\n", "            kv_channels=self.kv_channels,\n", "            attention_dropout=attention_dropout,\n", "        )\n", "        self.projection = te.Linear(hidden_size, hidden_size, bias=True)\n", "        
self.dropout = torch.nn.Dropout(hidden_dropout)\n", "        self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps)\n", "        self.mlp = BasicTEMLP(\n", "            hidden_size=hidden_size,\n", "            ffn_hidden_size=ffn_hidden_size,\n", "        )\n", "\n", "    def forward(self,\n", "                x: torch.Tensor,\n", "                attention_mask: torch.Tensor):\n", "        res = x\n", "        x = self.ln1(x)\n", "\n", "        # Fused QKV projection\n", "        qkv = self.qkv_projection(x)\n", "        qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)\n", "        q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)\n", "\n", "        x = self.attention(q, k, v, attention_mask)\n", "        x = self.projection(x)\n", "        x = self.dropout(x)\n", "        x = res + x\n", "        res = x\n", "        x = self.ln2(x)\n", "        x = self.mlp(x)\n", "\n", "        return x + res" ] }, { "cell_type": "code", "execution_count": 8, "id": "916531e8", "metadata": {}, "outputs": [], "source": [ "basic_te_transformer = BasicTETransformerLayer(\n", "    hidden_size,\n", "    ffn_hidden_size,\n", "    num_attention_heads,\n", ")\n", "basic_te_transformer.to(dtype=dtype).cuda()\n", "utils.share_parameters_with_basic_te_model(basic_te_transformer, basic_transformer)" ] }, { "cell_type": "code", "execution_count": 9, "id": "3643fa54", "metadata": {}, "outputs": [], "source": [ "torch.manual_seed(1234)\n", "y = basic_te_transformer(x, attention_mask=None)" ] }, { "cell_type": "code", "execution_count": 10, "id": "10b92894", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean time: 43.1413232421875 ms\n" ] } ], "source": [ "utils.speedometer(\n", "    basic_te_transformer,\n", "    x,\n", "    dy,\n", "    forward_kwargs = { \"attention_mask\": None },\n", ")" ] }, { "cell_type": "markdown", "id": "3f990226", "metadata": {}, "source": [ "## Fused TE Modules\n", "\n", "
\n", "\n", "Summary\n", " \n", "We optimize the example Transformer layer with TE modules for fused operations.\n", "\n", "
\n", "\n", "The `Linear` layer is enough to build any Transformer model and it enables usage of Transformer Engine even for very custom Transformers. However, having more knowledge about the model allows for additional optimizations like kernel fusion, increasing the achievable speedup.\n", "\n", "Transformer Engine therefore provides coarser modules that span multiple layers:\n", "\n", "* `LayerNormLinear`\n", "* `LayerNormMLP`\n", "* `TransformerLayer`\n", "\n", "Let's build a third iteration of our Transformer layer with `LayerNormLinear` and `LayerNormMLP`:" ] }, { "cell_type": "code", "execution_count": 11, "id": "c55eae1f", "metadata": {}, "outputs": [], "source": [ "class FusedTETransformerLayer(torch.nn.Module):\n", "    def __init__(self,\n", "                 hidden_size: int,\n", "                 ffn_hidden_size: int,\n", "                 num_attention_heads: int,\n", "                 layernorm_eps: float = 1e-5,\n", "                 attention_dropout: float = 0.1,\n", "                 hidden_dropout: float = 0.1):\n", "        super().__init__()\n", "        self.num_attention_heads = num_attention_heads\n", "        self.kv_channels = hidden_size // num_attention_heads\n", "        self.ln_qkv = te.LayerNormLinear(hidden_size, 3 * hidden_size, eps=layernorm_eps, bias=True)\n", "        self.attention = utils.DotProductAttention(\n", "            num_attention_heads=num_attention_heads,\n", "            kv_channels=self.kv_channels,\n", "            attention_dropout=attention_dropout,\n", "        )\n", "        self.projection = te.Linear(hidden_size, hidden_size, bias=True)\n", "        self.dropout = torch.nn.Dropout(hidden_dropout)\n", "        self.ln_mlp = te.LayerNormMLP(hidden_size, ffn_hidden_size, eps=layernorm_eps, bias=True)\n", "\n", "    def forward(self,\n", "                x: torch.Tensor,\n", "                attention_mask: torch.Tensor):\n", "        res = x\n", "        qkv = self.ln_qkv(x)\n", "\n", "        # Split qkv into query, key and value\n", "        qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)\n", "        q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)\n", "\n", "        x = self.attention(q, k, v, attention_mask)\n", "        x = 
self.projection(x)\n", "        x = self.dropout(x)\n", "        x = res + x\n", "        res = x\n", "        x = self.ln_mlp(x)\n", "\n", "        return x + res" ] }, { "cell_type": "code", "execution_count": 12, "id": "85949421", "metadata": {}, "outputs": [], "source": [ "fused_te_transformer = FusedTETransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads)\n", "fused_te_transformer.to(dtype=dtype).cuda()\n", "utils.share_parameters_with_fused_te_model(fused_te_transformer, basic_transformer)" ] }, { "cell_type": "code", "execution_count": 13, "id": "2c263e71", "metadata": {}, "outputs": [], "source": [ "torch.manual_seed(1234)\n", "y = fused_te_transformer(x, attention_mask=None)" ] }, { "cell_type": "code", "execution_count": 14, "id": "24e101bc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean time: 43.1981201171875 ms\n" ] } ], "source": [ "utils.speedometer(\n", "    fused_te_transformer,\n", "    x,\n", "    dy,\n", "    forward_kwargs = { \"attention_mask\": None },\n", ")" ] }, { "cell_type": "markdown", "id": "33f13c26", "metadata": {}, "source": [ "Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures and it provides the highest degree of performance optimization:" ] }, { "cell_type": "code", "execution_count": 15, "id": "ec8c3685", "metadata": {}, "outputs": [], "source": [ "te_transformer = te.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads)\n", "te_transformer.to(dtype=dtype).cuda()\n", "utils.share_parameters_with_transformerlayer_te_model(te_transformer, basic_transformer)" ] }, { "cell_type": "code", "execution_count": 16, "id": "e48cd590", "metadata": {}, "outputs": [], "source": [ "torch.manual_seed(1234)\n", "y = te_transformer(x, attention_mask=None)" ] }, { "cell_type": "code", "execution_count": 17, "id": "3ec3707d-e63f-4899-8308-b11c55b5caa4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean time: 39.99169921875 ms\n" ] } ], 
"source": [ "utils.speedometer(\n", "    te_transformer,\n", "    x,\n", "    dy,\n", "    forward_kwargs = { \"attention_mask\": None },\n", ")" ] }, { "cell_type": "markdown", "id": "4034c3eb-8958-49f2-85f6-30c94977d884", "metadata": {}, "source": [ "## Enabling FP8\n", "\n", "
\n", "\n", "Summary\n", " \n", "We configure a TE module to perform compute in FP8.\n", "\n", "
\n", "\n", "Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager. Note that `fp8_autocast` should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options." ] }, { "cell_type": "code", "execution_count": 18, "id": "31256aa7-3d5e-425c-91ab-502b1326a748", "metadata": {}, "outputs": [], "source": [ "from transformer_engine.common.recipe import Format, DelayedScaling\n", "\n", "te_transformer = te.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads)\n", "te_transformer.to(dtype=dtype).cuda()\n", "utils.share_parameters_with_transformerlayer_te_model(te_transformer, basic_transformer)\n", "\n", "fp8_format = Format.HYBRID\n", "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n", "torch.manual_seed(1234)\n", "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", "    y = te_transformer(x, attention_mask=None)" ] }, { "cell_type": "code", "execution_count": 19, "id": "793ebd2d-b84b-47bc-811a-7991df8500aa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mean time: 28.61394775390625 ms\n" ] } ], "source": [ "utils.speedometer(\n", "    te_transformer,\n", "    x,\n", "    dy,\n", "    forward_kwargs = { \"attention_mask\": None },\n", "    fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n", ")" ] } ], "metadata": { "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 5 }