Using FP8 with Transformer Engine

The H100 GPU introduced support for a new datatype, FP8 (8-bit floating point), which enables higher throughput for matrix multiplies and convolutions. This example introduces the FP8 datatype and shows how to use it with Transformer Engine.

Introduction to FP8

Structure

The FP8 datatype supported by H100 is actually 2 distinct datatypes, useful in different parts of the training of neural networks:

  • E4M3 - it consists of 1 sign bit, 4 exponent bits and 3 bits of mantissa. It can store values up to +/-448 and nan.

  • E5M2 - it consists of 1 sign bit, 5 exponent bits and 2 bits of mantissa. It can store values up to +/-57344, +/- inf and nan. The tradeoff of the increased dynamic range is lower precision of the stored values.
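The limits quoted above follow from the bit layouts. The sketch below derives them, assuming the usual conventions for these formats: an exponent bias of 2^(exp_bits - 1) - 1, E5M2 reserving its top exponent code for inf and nan, and E4M3 reserving only the single all-ones bit pattern for nan. The helper name `fp8_limits` is made up for illustration.

```python
def fp8_limits(exp_bits, man_bits, ieee_like):
    """Return (largest finite value, smallest positive subnormal) of an FP8 layout."""
    bias = 2 ** (exp_bits - 1) - 1
    if ieee_like:
        # E5M2: the all-ones exponent code is reserved for inf and nan.
        top_exp_code = 2 ** exp_bits - 2
        top_mantissa = 2 ** man_bits - 1
    else:
        # E4M3: only "exponent and mantissa all ones" encodes nan, so the
        # top exponent code is usable with every other mantissa.
        top_exp_code = 2 ** exp_bits - 1
        top_mantissa = 2 ** man_bits - 2
    largest = 2.0 ** (top_exp_code - bias) * (1 + top_mantissa / 2 ** man_bits)
    smallest = 2.0 ** (1 - bias) / 2 ** man_bits
    return largest, smallest


e4m3_max, e4m3_min = fp8_limits(4, 3, ieee_like=False)
e5m2_max, e5m2_min = fp8_limits(5, 2, ieee_like=True)
print(e4m3_max, e4m3_min)  # 448.0 0.001953125
print(e5m2_max, e5m2_min)  # 57344.0 1.52587890625e-05
```

Giving up the inf encodings is what lets E4M3 reach 448 rather than stopping one exponent step lower.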


Figure 1: Structure of the floating point datatypes. All of the values shown (in FP16, BF16, FP8 E4M3 and FP8 E5M2) are the closest representations of value 0.3952.

Both of these types may be used when training neural networks. Forward activations and weights typically require more precision, so the E4M3 datatype is best used during the forward pass. In the backward pass, however, gradients flowing through the network are typically less susceptible to loss of precision but require higher dynamic range, so they are best stored in the E5M2 format. H100 Tensor Cores support any combination of these types as inputs, so each tensor can be stored in its preferred format.
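To make the precision versus range tradeoff concrete, here is a small pure-Python sketch that rounds a value to the nearest number in a simplified model of each format. The `quantize` helper is hypothetical and is not how Transformer Engine casts tensors: it ignores inf and nan and simply saturates at the format's largest value.

```python
import math


def quantize(x, exp_bits, man_bits, max_value):
    """Round x to the nearest value of a simplified FP8 format, saturating at max_value."""
    if x == 0.0:
        return 0.0
    bias = 2 ** (exp_bits - 1) - 1
    sign = -1.0 if x < 0 else 1.0
    magnitude = abs(x)
    # frexp returns magnitude = m * 2**e with 0.5 <= m < 1, so floor(log2) is e - 1.
    # Clamping to the smallest normal exponent gives subnormals a fixed spacing.
    exponent = max(math.frexp(magnitude)[1] - 1, 1 - bias)
    step = 2.0 ** (exponent - man_bits)
    rounded = round(magnitude / step) * step  # round() breaks ties to even
    return sign * min(rounded, max_value)


def to_e4m3(x):
    return quantize(x, 4, 3, 448.0)


def to_e5m2(x):
    return quantize(x, 5, 2, 57344.0)


print(to_e4m3(0.3952), to_e5m2(0.3952))  # 0.40625 0.375
print(to_e4m3(1000.0), to_e5m2(1000.0))  # 448.0 1024.0
print(to_e4m3(1e-4), to_e5m2(1e-4))      # E4M3 underflows to 0.0, E5M2 keeps about 1.07e-4
```

E4M3 lands closer to 0.3952, while E5M2 survives both the large and the tiny input. This is the behavior that motivates using E4M3 for the forward pass and E5M2 for gradients.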

Mixed precision training - a quick introduction

In order to understand how FP8 can be used for training Deep Learning models, it is useful to first remind ourselves how mixed precision works with other datatypes, especially FP16.

The mixed precision recipe for FP16 training has two components: choosing which operations should be performed in FP16, and dynamic loss scaling.

  • Choosing the operations to be performed in FP16 precision requires analysis of the numerical behavior of the outputs with respect to inputs of the operation as well as the expected performance benefit. This enables marking operations like matrix multiplies, convolutions and normalization layers as safe, while leaving norm or exp operations as requiring high precision.

  • Dynamic loss scaling avoids both overflows and underflows of the gradients during training. These can occur because, while the dynamic range of FP16 is wide enough to hold the distribution of gradient values, that distribution may be centered around values too high or too low for FP16 to represent. Scaling the loss shifts the distribution into the range representable in FP16. Using only powers of 2 as scaling factors keeps the scaling itself from affecting the numerics.


Figure 2: Scaling the loss enables shifting the gradient distribution into the representable range of FP16 datatype.
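The dynamic part of loss scaling can be sketched as a small state machine: halve the scale and skip the step when gradients overflow, and double it after a run of clean steps. This is a minimal illustration of the common scheme, not the code of any particular framework. The class name, the default values and the `growth_interval` parameter are assumptions made for the example.

```python
import math


class DynamicLossScaler:
    """Toy dynamic loss scaler whose scale always stays a power of 2."""

    def __init__(self, init_scale=2.0 ** 16, growth_interval=2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self.clean_steps = 0

    def update(self, grads):
        """Adjust the scale from this step's gradients.

        Returns True if the optimizer step should be applied.
        """
        if any(math.isinf(g) or math.isnan(g) for g in grads):
            # Overflow: back off and skip this step.
            self.scale /= 2.0
            self.clean_steps = 0
            return False
        self.clean_steps += 1
        if self.clean_steps == self.growth_interval:
            # A long enough clean run: try a larger scale again.
            self.scale *= 2.0
            self.clean_steps = 0
        return True


scaler = DynamicLossScaler(init_scale=1024.0, growth_interval=2)
history = []
for grads in ([float("inf"), 1.0], [0.5, 0.25], [0.1, 0.2]):
    applied = scaler.update(grads)
    history.append((applied, scaler.scale))
print(history)  # [(False, 512.0), (True, 512.0), (True, 1024.0)]
```

In a real training loop the loss would be multiplied by `scaler.scale` before the backward pass and the gradients divided by it before the optimizer step.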

Mixed precision training with FP8

While the dynamic range provided by the FP8 types is sufficient to store any particular activation or gradient, it is not sufficient for all of them at the same time. This makes the single loss scaling factor strategy, which worked for FP16, infeasible for FP8 training and instead requires using distinct scaling factors for each FP8 tensor.
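A quick calculation shows how much narrower the FP8 formats are. The sketch below measures each format's dynamic range as the number of powers of 2 between its smallest positive subnormal and its largest finite value. The limits are hard-coded from the standard definitions of these formats, and the helper name is illustrative.

```python
import math

# (largest finite value, smallest positive subnormal) for each format
FORMATS = {
    "FP16": (65504.0, 2.0 ** -24),
    "E5M2": (57344.0, 2.0 ** -16),
    "E4M3": (448.0, 2.0 ** -9),
}


def dynamic_range_in_powers_of_two(largest, smallest):
    return math.log2(largest / smallest)


ranges = {
    name: dynamic_range_in_powers_of_two(largest, smallest)
    for name, (largest, smallest) in FORMATS.items()
}
for name, powers in ranges.items():
    print(f"{name}: about {powers:.1f} powers of 2")
# FP16: about 40.0 powers of 2
# E5M2: about 31.8 powers of 2
# E4M3: about 17.8 powers of 2
```

With roughly 18 to 32 powers of 2 available instead of 40, a single global shift has far less room to place every tensor's distribution inside the representable range at once.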

There are multiple strategies for choosing a scaling factor that is appropriate for a given FP8 tensor:

  • just-in-time scaling. This strategy chooses the scaling factor based on the maximum of absolute values (amax) of the tensor being produced. In practice it is infeasible, as it requires multiple passes through data - the operator produces and writes out the output in higher precision, then the maximum absolute value of the output is found and applied to all values in order to obtain the final FP8 output. This results in a lot of overhead, severely diminishing gains from using FP8.

  • delayed scaling. This strategy chooses the scaling factor based on the maximums of absolute values seen in some number of previous iterations. This enables full performance of FP8 computation, but requires storing the history of maximums as additional parameters of the FP8 operators.


Figure 3: Delayed scaling strategy. The FP8 operator uses scaling factor obtained using the history of amaxes (maximums of absolute values) seen in some number of previous iterations and produces both the FP8 output and the current amax, which gets stored in the history.

As Figure 3 shows, the delayed scaling strategy requires both storing the history of amaxes and choosing a recipe for converting that history into the scaling factor used in the next iteration.
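The bookkeeping behind delayed scaling can be sketched in a few lines. The example below assumes the simplest recipe: take the maximum of the stored amax history and choose the scale that maps it to the largest E4M3 value. The class and method names are hypothetical, other reductions of the history are possible, and the actual rounding to FP8 is left out so that only the scaling logic is shown. It is not Transformer Engine's implementation.

```python
from collections import deque

E4M3_MAX = 448.0


class DelayedScaler:
    """Toy delayed scaling: the scale comes from past amaxes, never the current tensor."""

    def __init__(self, history_len=16, fp8_max=E4M3_MAX):
        self.amax_history = deque(maxlen=history_len)  # oldest entries drop out
        self.fp8_max = fp8_max

    def current_scale(self):
        amax = max(self.amax_history, default=0.0)
        # With no usable history yet, fall back to a neutral scale.
        return self.fp8_max / amax if amax > 0.0 else 1.0

    def scale_and_record(self, values):
        scale = self.current_scale()
        # Values that the stale scale pushes out of range saturate.
        scaled = [max(-self.fp8_max, min(self.fp8_max, v * scale)) for v in values]
        # The amax of the current tensor only influences later iterations.
        self.amax_history.append(max(abs(v) for v in values))
        return scaled, scale


scaler = DelayedScaler(history_len=2)
_, scale_0 = scaler.scale_and_record([1.0, -2.0])
scaled_1, scale_1 = scaler.scale_and_record([0.5, 4.0])
print(scale_0, scale_1, scaled_1)  # 1.0 224.0 [112.0, 448.0]
```

The second call shows the cost of the delay: the tensor's amax grew from 2.0 to 4.0, so its largest element saturates at 448 until the history catches up.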

MXFP8 and block scaling

NVIDIA Blackwell architecture introduced support for a new variant of the FP8 format: MXFP8.

MXFP8 vs FP8

The main difference between “regular” FP8 and MXFP8 lies in the granularity of the scaling. In FP8, each tensor has a single FP32 scaling factor, so all values in the tensor need to “fit” within the dynamic range of the FP8 datatype. This requires using the less precise E5M2 format to represent some tensors in the network (like gradients).

MXFP8 addresses this by assigning a different scaling factor to each block of 32 consecutive values. This allows all values to be represented with the E4M3 datatype.


Figure 4: MXFP8 uses multiple scaling factors for a single tensor. The picture shows only 4 values per block for simplicity, but real MXFP8 has 32 values per block.


Figure 5: With multiple scaling factors, the tensor’s dynamic range requirements are reduced, so the E4M3 format can be used because far fewer elements are flushed to 0.

The second difference is the datatype used to store the scaling factors. FP8 uses FP32 (E8M23) while MXFP8 uses an 8-bit representation of a power of 2 (E8M0).


Figure 6: Structure of the E8M0 datatype used for storing scaling factors in MXFP8.
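The following sketch shows how per-block power-of-2 scales might be chosen. It assumes the rule exponent = floor(log2(amax)) - 8, where 8 is the exponent of the largest E4M3 value (448 = 1.75 * 2^8). This rule is an assumption made for illustration and may differ from what the hardware or Transformer Engine does; the function name is also made up.

```python
import math

BLOCK_SIZE = 32
E4M3_MAX_EXPONENT = 8  # 448 = 1.75 * 2**8


def block_scale_exponents(values, block_size=BLOCK_SIZE):
    """Return one power-of-2 exponent per block of consecutive values."""
    exponents = []
    for start in range(0, len(values), block_size):
        amax = max(abs(v) for v in values[start:start + block_size])
        if amax == 0.0:
            exponents.append(0)  # any scale works for an all-zero block
            continue
        amax_exponent = math.frexp(amax)[1] - 1  # floor(log2(amax))
        exponents.append(amax_exponent - E4M3_MAX_EXPONENT)
    return exponents


# Two blocks whose magnitudes differ by six orders of magnitude.
values = [1.0] * 32 + [1e-6] * 32
exponents = block_scale_exponents(values)
print(exponents)  # [-8, -28]

# Stored element = value / 2**exponent.
shared_scale_element = 1e-6 / 2.0 ** exponents[0]  # small block reusing the large block's scale
own_scale_element = 1e-6 / 2.0 ** exponents[1]     # small block with its own scale
print(shared_scale_element < 2.0 ** -10)   # True: below half the smallest E4M3 subnormal
print(256.0 <= own_scale_element < 512.0)  # True: well inside the E4M3 range
```

With one shared scale the small block would round to zero. With its own scale it uses the same part of the E4M3 range as the large block. Under this rule a block's largest element lands in [256, 512), so elements that end up above 448 would still saturate.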

Handling transposes

The forward and backward passes of linear layers involve multiple matrix multiplications with different reduction dimensions. Blackwell Tensor Cores require MXFP8 data to be “consecutive” over the reduction dimension, so MXFP8 training uses non-transposed and transposed MXFP8 tensors at different points. However, while transposing FP8 data is numerically trivial, transposing MXFP8 data requires requantization.

To avoid loss of precision connected with this double quantization, Transformer Engine creates both regular and transposed copies of the tensor from the original high precision input.


Figure 7: Linear layer in MXFP8. Calculating both forward and backward pass requires tensors quantized in both directions.
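A toy example shows why a transposed MXFP8 tensor cannot reuse the scales of the original. It uses blocks of 2 values instead of 32 and assumes the illustrative rule exponent = floor(log2(amax)) - 8 for each block's power-of-2 scale. The helper is hypothetical and expects blocks with a nonzero amax.

```python
import math

E4M3_MAX_EXPONENT = 8  # 448 = 1.75 * 2**8


def scale_exponent(block):
    """Power-of-2 scale exponent for one block (assumes a nonzero amax)."""
    amax = max(abs(v) for v in block)
    return math.frexp(amax)[1] - 1 - E4M3_MAX_EXPONENT


matrix = [
    [1.0, 1e-6],
    [1.0, 1e-6],
]

# Blocks that run along rows: each block mixes a large and a small value.
row_exponents = [scale_exponent(row) for row in matrix]

# Blocks that run along columns, as needed when the other dimension is reduced.
columns = [list(column) for column in zip(*matrix)]
column_exponents = [scale_exponent(column) for column in columns]

print(row_exponents)     # [-8, -8]
print(column_exponents)  # [-8, -28]
```

The two directions group different elements into a block, so they produce different scales and different quantized values. Quantizing both layouts from the high precision input avoids rounding the data twice.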

Using FP8 with Transformer Engine

The Transformer Engine library provides tools that make it easy to train with the FP8 datatype, using either the FP8 delayed scaling or the MXFP8 strategy.

FP8 recipe

The DelayedScaling recipe from the transformer_engine.common.recipe module stores all of the required options for training with FP8 delayed scaling: length of the amax history to use for scaling factor computation, FP8 data format, etc. Similarly, MXFP8BlockScaling from the same module may be used to enable MXFP8 training.

[1]:
from transformer_engine.common.recipe import Format, DelayedScaling, MXFP8BlockScaling

fp8_format = Format.HYBRID  # E4M3 during forward pass, E5M2 during backward pass
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
mxfp8_format = Format.E4M3  # E4M3 used everywhere
mxfp8_recipe = MXFP8BlockScaling(fp8_format=mxfp8_format)

This recipe is then used to configure the FP8 training.

FP8 autocasting

Not every operation is safe to be performed using FP8. All of the modules provided by Transformer Engine library were designed to provide maximum performance benefit from FP8 datatype while maintaining accuracy. In order to enable FP8 operations, TE modules need to be wrapped inside the fp8_autocast context manager.

[2]:
import transformer_engine.pytorch as te
import torch

torch.manual_seed(12345)

my_linear = te.Linear(768, 768, bias=True)

inp = torch.rand((1024, 768)).cuda()

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out_fp8 = my_linear(inp)

The fp8_autocast context manager hides the complexity of handling FP8:

  • All FP8-safe operations have their inputs cast to FP8

  • Amax history is updated

  • New scaling factors are computed and ready for the next iteration

Note

Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors with shapes where both dimensions are divisible by 16. For the input to a full Transformer network, this typically means padding the sequence length to a multiple of 16.
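Computing the padded length is a one-liner. The helper below is a hypothetical convenience function, not part of Transformer Engine.

```python
def pad_to_multiple(length, multiple=16):
    """Round length up to the nearest multiple (ceiling division, then multiply)."""
    return -(-length // multiple) * multiple


for seq_len in (1000, 1024, 1):
    padded = pad_to_multiple(seq_len)
    print(seq_len, "->", padded, f"({padded - seq_len} padding tokens)")
# 1000 -> 1008 (8 padding tokens)
# 1024 -> 1024 (0 padding tokens)
# 1 -> 16 (15 padding tokens)
```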

Handling backward pass

When a model is run inside the fp8_autocast region, especially in multi-GPU training, some communication is required in order to synchronize the scaling factors and amax history. In order to perform that communication without introducing much overhead, fp8_autocast context manager aggregates the tensors before performing the communication.

Due to this aggregation the backward call needs to happen outside of the fp8_autocast context manager. It has no impact on the computation precision - the precision of the backward pass is determined by the precision of the forward pass.

[3]:
loss_fp8 = out_fp8.mean()

loss_fp8.backward()  # This backward pass uses FP8, since out_fp8 was calculated inside fp8_autocast

out_fp32 = my_linear(inp)
loss_fp32 = out_fp32.mean()
loss_fp32.backward()  # This backward pass does not use FP8, since out_fp32 was calculated outside fp8_autocast

Precision

If we compare the results of the FP32 and FP8 execution, we will see that they are relatively close, but different:

[4]:
out_fp8
[4]:
tensor([[ 0.2276,  0.2627,  0.3001,  ...,  0.0346,  0.2211,  0.1188],
        [-0.0963, -0.3725,  0.1717,  ...,  0.0901,  0.0522, -0.3472],
        [ 0.4526,  0.3482,  0.5976,  ..., -0.0687, -0.0382,  0.1566],
        ...,
        [ 0.1698,  0.6061,  0.0385,  ..., -0.2875, -0.1152, -0.0260],
        [ 0.0679,  0.2946,  0.2751,  ..., -0.2284,  0.0517, -0.1441],
        [ 0.1865,  0.2353,  0.9172,  ...,  0.1085,  0.1135,  0.1438]],
       device='cuda:0', grad_fn=<_LinearBackward>)
[5]:
out_fp32
[5]:
tensor([[ 0.2373,  0.2674,  0.2980,  ...,  0.0233,  0.2498,  0.1131],
        [-0.0767, -0.3778,  0.1862,  ...,  0.0858,  0.0676, -0.3369],
        [ 0.4615,  0.3593,  0.5813,  ..., -0.0779, -0.0349,  0.1422],
        ...,
        [ 0.1914,  0.6038,  0.0382,  ..., -0.2847, -0.0991, -0.0423],
        [ 0.0864,  0.2895,  0.2719,  ..., -0.2388,  0.0772, -0.1541],
        [ 0.2019,  0.2275,  0.9027,  ...,  0.1022,  0.1300,  0.1444]],
       device='cuda:0', grad_fn=<_LinearBackward>)

That happens because in the FP8 case both the input and weights are cast to FP8 before the computation. We can see this if instead of the original inputs we use the inputs representable in FP8 (using a function defined in quickstart_utils.py):

[6]:
from quickstart_utils import cast_to_representable

inp_representable = cast_to_representable(inp)
my_linear.weight.data = cast_to_representable(my_linear.weight.data)

out_fp32_representable = my_linear(inp_representable)

print(out_fp32_representable)
tensor([[ 0.2276,  0.2629,  0.3000,  ...,  0.0346,  0.2211,  0.1188],
        [-0.0963, -0.3724,  0.1717,  ...,  0.0901,  0.0522, -0.3470],
        [ 0.4526,  0.3479,  0.5976,  ..., -0.0686, -0.0382,  0.1566],
        ...,
        [ 0.1698,  0.6062,  0.0385,  ..., -0.2876, -0.1152, -0.0260],
        [ 0.0679,  0.2947,  0.2750,  ..., -0.2284,  0.0516, -0.1441],
        [ 0.1865,  0.2353,  0.9170,  ...,  0.1085,  0.1135,  0.1438]],
       device='cuda:0', grad_fn=<_LinearBackward>)

This time the difference is really small:

[7]:
out_fp8 - out_fp32_representable
[7]:
tensor([[ 4.9591e-05, -1.9073e-04,  9.5367e-05,  ..., -3.8147e-06,
          4.1962e-05,  2.2888e-05],
        [ 2.2888e-05, -3.4332e-05,  2.2888e-05,  ...,  2.6703e-05,
          5.3406e-05, -1.4114e-04],
        [-3.8147e-05,  2.6703e-04, -3.8147e-06,  ..., -5.7220e-05,
          4.1962e-05, -1.9073e-05],
        ...,
        [ 1.1444e-05, -7.2479e-05, -3.8147e-06,  ...,  5.3406e-05,
         -1.5259e-05,  2.2888e-05],
        [ 4.9591e-05, -9.5367e-05,  6.8665e-05,  ..., -1.5259e-05,
          7.6294e-05,  4.5776e-05],
        [-1.5259e-05, -7.6294e-06,  1.8692e-04,  ..., -3.0518e-05,
         -4.5776e-05,  7.6294e-06]], device='cuda:0', grad_fn=<SubBackward0>)
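One way to see why the remaining difference is so small: an E4M3 value has at most 4 significant bits, so the product of two such values needs at most 8 and fits exactly in FP32. Rounding can then only enter when the products are summed. The sketch below checks this exhaustively. It assumes that cast_to_representable yields E4M3 values, which is an assumption about quickstart_utils.py (not shown here). The helper names are illustrative.

```python
import struct


def e4m3_positive_values():
    """Enumerate every positive finite E4M3 value from its bit fields."""
    values = []
    for exp_code in range(16):
        for mantissa in range(8):
            if exp_code == 15 and mantissa == 7:
                continue  # the single nan encoding
            if exp_code == 0:
                value = (mantissa / 8) * 2.0 ** -6  # subnormals
            else:
                value = (1 + mantissa / 8) * 2.0 ** (exp_code - 7)
            if value > 0:
                values.append(value)
    return values


def fits_in_fp32(x):
    """True if x survives a round trip through a 32-bit float unchanged."""
    return struct.unpack("f", struct.pack("f", x))[0] == x


values = e4m3_positive_values()
all_exact = all(fits_in_fp32(a * b) for a in values for b in values)
print(len(values), max(values), all_exact)  # 126 448.0 True
```

This is consistent with the small residuals above, which plausibly come from the order and precision in which the products are accumulated.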

The differences in results caused by FP8 execution typically do not matter for the training process, but it is good to understand them, e.g. when debugging the model.