{ "cells": [ { "cell_type": "markdown", "id": "7b3e6954", "metadata": {}, "source": [ "# Using FP8 with Transformer Engine\n", "\n", "The H100 GPU introduced support for a new datatype, FP8 (8-bit floating point), enabling higher throughput of matrix multiplies and convolutions. In this example we will introduce the FP8 datatype and show how to use it with Transformer Engine.\n", "\n", "## Introduction to FP8\n", "\n", "### Structure\n", "\n", "The FP8 datatype supported by H100 is actually 2 distinct datatypes, useful in different parts of the training of neural networks:\n", "\n", "* E4M3 - it consists of 1 sign bit, 4 exponent bits and 3 bits of mantissa. It can store values up to +/-448 and `nan`.\n", "* E5M2 - it consists of 1 sign bit, 5 exponent bits and 2 bits of mantissa. It can store values up to +/-57344, +/- `inf` and `nan`. The tradeoff of the increased dynamic range is lower precision of the stored values.\n", "\n", "\n", "\n", "During the training of neural networks both of these types may be utilized. Typically forward activations and weights require more precision, so the E4M3 datatype is best used during the forward pass. In the backward pass, however, gradients flowing through the network are typically less susceptible to the loss of precision, but require higher dynamic range. Therefore they are best stored using the E5M2 data format. 
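As a sanity check, the maximum representable values quoted above can be derived directly from the two bit layouts. The snippet below is a standalone sketch written for this example; `fp8_max` is a hypothetical helper, not part of Transformer Engine or PyTorch:

```python
# Derive the largest finite value of each FP8 format from its bit layout.
# `fp8_max` is an illustrative helper written for this example only.
def fp8_max(exp_bits, man_bits, reserve_top_exponent):
    bias = 2 ** (exp_bits - 1) - 1
    if reserve_top_exponent:
        # E5M2 follows IEEE-754 conventions: the all-ones exponent is
        # reserved for inf/nan, so the top usable exponent is one lower
        # and all mantissa bits may be set.
        top_exp = 2 ** exp_bits - 2
        mantissa = 1 + (2 ** man_bits - 1) / 2 ** man_bits
    else:
        # The E4M3 variant used on H100 keeps the all-ones exponent for
        # normal numbers; only exponent=1111, mantissa=111 encodes nan,
        # so the largest finite value uses mantissa 110.
        top_exp = 2 ** exp_bits - 1
        mantissa = 1 + (2 ** man_bits - 2) / 2 ** man_bits
    return 2.0 ** (top_exp - bias) * mantissa

print(fp8_max(4, 3, reserve_top_exponent=False))  # 448.0   (E4M3)
print(fp8_max(5, 2, reserve_top_exponent=True))   # 57344.0 (E5M2)
```
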
H100 TensorCores provide support for any combination of these types as the inputs, enabling us to store each tensor using its preferred precision.\n", "\n", "### Mixed precision training - a quick introduction\n", "\n", "In order to understand how FP8 can be used for training Deep Learning models, it is useful to first remind ourselves how mixed precision works with other datatypes, especially FP16.\n", "\n", "The mixed precision recipe for FP16 training has 2 components: choosing which operations should be performed in FP16, and dynamic loss scaling.\n", "\n", "* Choosing the operations to be performed in FP16 precision requires analysis of the numerical behavior of the outputs with respect to the inputs of the operation, as well as the expected performance benefit. This enables marking operations like matrix multiplies, convolutions and normalization layers as safe, while leaving operations like `exp` or `log` as requiring high precision.\n", "* Dynamic loss scaling enables avoiding both over- and underflows of the gradients during training. Those may happen since, while the dynamic range of FP16 is enough to store the distribution of the gradient values, this distribution may be centered around values too high or too low for FP16 to handle. Scaling the loss shifts those distributions (without affecting numerics, by using only powers of 2) into the range representable in FP16.\n", "\n", "\n", "\n", "### Mixed precision training with FP8\n", "\n", "While the dynamic range provided by the FP8 types is sufficient to store any particular activation or gradient, it is not sufficient for all of them at the same time. This makes the single loss scaling factor strategy, which worked for FP16, infeasible for FP8 training; it instead requires using a distinct scaling factor for each FP8 tensor.\n", "\n", "There are multiple strategies for choosing a scaling factor that is appropriate for a given FP8 tensor:\n", "\n", "* just-in-time scaling. 
This strategy chooses the scaling factor based on the maximum of the absolute values (amax) of the tensor being produced. In practice it is infeasible, as it requires multiple passes through the data - the operator produces and writes out the output in higher precision, then the maximum absolute value of the output is found and a scaling factor derived from it is applied to all values in order to obtain the final FP8 output. This results in a lot of overhead, severely diminishing the gains from using FP8.\n", "* delayed scaling. This strategy chooses the scaling factor based on the maximums of absolute values seen in some number of previous iterations. This enables the full performance of FP8 computation, but requires storing the history of maximums as additional parameters of the FP8 operators.\n", "\n", "\n", "\n", "As one can see in Figure 3, the delayed scaling strategy requires both storing the history of amaxes and choosing a recipe for converting that history into the scaling factor used in the next iteration." ] }, { "cell_type": "markdown", "id": "cf5e0b0d", "metadata": {}, "source": [ "## Using FP8 with Transformer Engine\n", "\n", "The Transformer Engine library provides tools enabling easy-to-use FP8 training based on the delayed scaling strategy.\n", "\n", "### FP8 recipe\n", "\n", "The [DelayedScaling](../api/common.rst#transformer_engine.common.recipe.DelayedScaling) recipe from the `transformer_engine.common.recipe` module stores all of the required options for FP8 training - the length of the amax history to use for scaling factor computation, the FP8 data format, etc." 
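To build intuition for what such a recipe does, here is a minimal, framework-free sketch of delayed scaling. The helper name, its signature, and the `margin` parameter are illustrative assumptions for this example, not Transformer Engine's API:

```python
# A minimal sketch of delayed scaling. `update_scale` is a hypothetical
# helper for illustration only -- it is NOT Transformer Engine's API.
E4M3_MAX = 448.0  # largest finite value representable in E4M3

def update_scale(amax_history, new_amax, history_len=16,
                 fp8_max=E4M3_MAX, margin=0):
    """Record the latest amax and derive the scaling factor for the next step."""
    amax_history.append(new_amax)
    if len(amax_history) > history_len:
        amax_history.pop(0)          # keep only the most recent window
    amax = max(amax_history)         # the "max" amax_compute_algo
    # Choose a scale so that `amax` maps to (roughly) the FP8 maximum:
    # quantize with x_fp8 = cast_to_fp8(x * scale), dequantize with x_fp8 / scale.
    return fp8_max / amax / (2.0 ** margin)

history = []
scale = update_scale(history, new_amax=7.0)
print(scale)  # 448 / 7 = 64.0
```

A tensor whose recent amaxes peaked at 7.0 is thus multiplied by 64 before the FP8 cast, pushing its values toward the top of the representable range and preserving as much mantissa precision as possible.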
] }, { "cell_type": "code", "execution_count": 1, "id": "0c8fd0ef", "metadata": {}, "outputs": [], "source": [ "from transformer_engine.common.recipe import Format, DelayedScaling\n", "\n", "fp8_format = Format.HYBRID # E4M3 during forward pass, E5M2 during backward pass\n", "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")" ] }, { "cell_type": "markdown", "id": "f9591eb5", "metadata": {}, "source": [ "This recipe is then used to configure the FP8 training." ] }, { "cell_type": "markdown", "id": "734d3934", "metadata": {}, "source": [ "### FP8 autocasting\n", "\n", "Not every operation is safe to be performed using FP8. All of the modules provided by the Transformer Engine library are designed to provide the maximum performance benefit from the FP8 datatype while maintaining accuracy. In order to enable FP8 operations, TE modules need to be wrapped inside the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager." ] }, { "cell_type": "code", "execution_count": 2, "id": "f8b1ff7f", "metadata": {}, "outputs": [], "source": [ "import transformer_engine.pytorch as te\n", "import torch\n", "\n", "torch.manual_seed(12345)\n", "\n", "my_linear = te.Linear(768, 768, bias=True)\n", "\n", "inp = torch.rand((1024, 768)).cuda()\n", "\n", "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", "    out_fp8 = my_linear(inp)" ] }, { "cell_type": "markdown", "id": "e41161f1", "metadata": {}, "source": [ "The `fp8_autocast` context manager hides the complexity of handling FP8:\n", "\n", "- All FP8-safe operations have their inputs cast to FP8\n", "- Amax history is updated\n", "- New scaling factors are computed and ready for the next iteration\n", "\n", "