Performance Optimizations

This guide is a follow-up to the discussion in the quickstart guide. We will focus on techniques to achieve maximum performance when training a basic GPT encoder layer. For convenience, we use some helper functions defined in quickstart_utils.py.

[1]:
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
import quickstart_utils as utils

# Layer configuration
hidden_size = 4096
sequence_length = 2048
batch_size = 4
ffn_hidden_size = 16384
num_attention_heads = 32
dtype = torch.float16

# Synthetic data
x = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)
dy = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)
[2]:
# Construct layer
basic_transformer = te.TransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
)
basic_transformer.to(dtype=dtype).cuda()

fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(
    fp8_format=fp8_format,
    amax_history_len=16,
    amax_compute_algo="max",
)
# Training step
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = basic_transformer(x, attention_mask=None)
y.backward(dy)

# Measure step time
utils.speedometer(
    basic_transformer,
    x,
    dy,
    forward_kwargs = { "attention_mask": None },
    fp8_autocast_kwargs = { "enabled": True, "fp8_recipe": fp8_recipe },
)
Mean time: 27.82952880859375 ms
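The utils.speedometer helper lives in quickstart_utils.py, which is not reproduced here. The sketch below illustrates what a helper like it has to do: run some warm-up iterations, then average wall-clock time over a number of timed iterations. The name mean_step_time_ms, its parameters, and its default iteration counts are hypothetical. This is not the actual helper.

```python
import time
from typing import Callable, Optional


def mean_step_time_ms(
    step_fn: Callable[[], None],
    warmup_iters: int = 10,
    timing_iters: int = 50,
    sync_fn: Optional[Callable[[], None]] = None,
) -> float:
    """Hypothetical timing helper: mean wall-clock time per call of step_fn, in ms."""
    # Warm-up iterations absorb one-time costs such as memory allocation
    for _ in range(warmup_iters):
        step_fn()
    if sync_fn is not None:
        sync_fn()  # wait for any queued asynchronous work before starting the clock

    start = time.perf_counter()
    for _ in range(timing_iters):
        step_fn()
    if sync_fn is not None:
        sync_fn()  # make sure all timed work has actually finished
    elapsed = time.perf_counter() - start
    return elapsed * 1000.0 / timing_iters


# CPU-only usage example with a dummy "training step"
calls = []


def dummy_step() -> None:
    calls.append(sum(i * i for i in range(1000)))


t = mean_step_time_ms(dummy_step, warmup_iters=2, timing_iters=5)
print(f"Mean time: {t} ms")
```

For a real GPU training step, step_fn would run the forward pass inside te.fp8_autocast and then call backward. In that case sync_fn should be torch.cuda.synchronize. CUDA kernels are launched asynchronously, so without it the host clock could stop before the GPU has finished.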

Multi-GPU training

Summary

We parallelize a Transformer layer with data, tensor, and sequence parallelism.

A variety of parallelism strategies can be used to enable multi-GPU training of Transformer models. They are often based on different ways of distributing the \(\text{sequence\_length} \times \text{batch\_size} \times \text{hidden\_size}\) activation tensors. The most common approach is data parallelism, which distributes along the \(\text{batch\_size}\) dimension. Each GPU stores a duplicate copy of the model, so the forward and backward passes of the training step can be performed independently on each GPU, followed by a gradient synchronization.

A more advanced strategy is tensor parallelism, a type of model parallelism that distributes along the \(\text{hidden\_size}\) dimension. Because \(\text{hidden\_size}\) is typically larger than \(\text{batch\_size}\), this lets us scale past the limits of data parallelism. It also reduces per-GPU memory usage, since the model parameters are distributed as well. However, it incurs the overhead of communicating activation tensors between GPUs at every step. For a more detailed explanation, please see the Megatron-LM paper.

Finally, sequence parallelism distributes along the \(\text{sequence\_length}\) dimension. It can be used together with tensor parallelism to parallelize operations that run outside the tensor-parallel region (e.g. layer norm). For more details, please see this paper.
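To make the three strategies concrete, the sketch below computes the per-GPU shape of an activation tensor with this notebook's dimensions when it is split evenly along a single dimension. It is purely illustrative. The helper shard_shape is hypothetical, and it ignores which activations are actually split, and where, inside a Transformer Engine layer.

```python
from typing import Tuple

# Dimension indices for a (sequence_length, batch_size, hidden_size) activation tensor
SEQ, BATCH, HIDDEN = 0, 1, 2


def shard_shape(shape: Tuple[int, ...], dim: int, num_gpus: int) -> Tuple[int, ...]:
    """Hypothetical helper: per-GPU shape when `shape` is split evenly along `dim`."""
    if shape[dim] % num_gpus != 0:
        raise ValueError(
            f"dimension {dim} of size {shape[dim]} cannot be split evenly {num_gpus} ways"
        )
    local = list(shape)
    local[dim] //= num_gpus
    return tuple(local)


full = (2048, 4, 4096)  # sequence_length, batch_size, hidden_size from cell [1]

print(shard_shape(full, BATCH, 4))   # data parallelism:     (2048, 1, 4096)
print(shard_shape(full, HIDDEN, 8))  # tensor parallelism:   (2048, 4, 512)
print(shard_shape(full, SEQ, 8))     # sequence parallelism: (256, 4, 4096)

# With batch_size = 4, splitting the batch alone cannot keep 8 GPUs busy
try:
    shard_shape(full, BATCH, 8)
except ValueError as err:
    print("data parallelism limit:", err)
```

The last call shows the limit mentioned above. The batch dimension runs out long before the hidden dimension does.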

To show this in action, let’s first initialize NCCL with a trivial process group:

[3]:
# Configure parallel groups
import torch

# Initialize the default (world) process group with a single rank.
# init_process_group returns None, so there is no group handle to store.
torch.distributed.init_process_group(
    "nccl",
    init_method="file:///tmp/rdzv",
    world_size=1,
    rank=0,
)
data_parallel_group = torch.distributed.new_group(ranks=[0], backend="nccl")
tensor_parallel_group = torch.distributed.new_group(ranks=[0], backend="nccl")

We only initialize with one GPU to keep this example simple. Please consult the torch.distributed documentation for guidance on running with multiple GPUs. Note that we require each distributed process to correspond to exactly one GPU, so we treat the two terms interchangeably. In practice, several factors can affect the optimal parallel layout, including the system hardware, the network topology, and the use of other parallelism schemes such as pipeline parallelism. A rough rule of thumb is to interpret the GPUs as a 2D grid with dimensions \(\text{num\_nodes} \times \text{gpus\_per\_node}\). The rows are tensor-parallel groups and the columns are data-parallel groups. This keeps the tensor-parallel communication, which happens at every step, within a single node.
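The grid rule of thumb can be sketched in plain Python. The sketch assumes that ranks are numbered node by node (rank = node * gpus_per_node + local GPU index). This is a common launcher convention but not a universal one. Under that assumption, the hypothetical helper grid_process_groups lists the ranks in each tensor-parallel group (a row of the grid) and each data-parallel group (a column).

```python
from typing import List, Tuple


def grid_process_groups(
    num_nodes: int, gpus_per_node: int
) -> Tuple[List[List[int]], List[List[int]]]:
    """Hypothetical helper: split ranks into tensor-parallel rows and data-parallel columns.

    Assumes rank = node * gpus_per_node + local_gpu.
    """
    # Each row is one node: its GPUs form a tensor-parallel group
    tp_groups = [
        [node * gpus_per_node + gpu for gpu in range(gpus_per_node)]
        for node in range(num_nodes)
    ]
    # Each column is the same local GPU index on every node: a data-parallel group
    dp_groups = [
        [node * gpus_per_node + gpu for node in range(num_nodes)]
        for gpu in range(gpus_per_node)
    ]
    return tp_groups, dp_groups


tp_groups, dp_groups = grid_process_groups(num_nodes=2, gpus_per_node=4)
print("tensor-parallel groups:", tp_groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
print("data-parallel groups:  ", dp_groups)  # [[0, 4], [1, 5], [2, 6], [3, 7]]
```

In a real multi-GPU job, every rank would call torch.distributed.new_group once for every group in both lists, in the same order. Each rank would then keep the handles of the groups that contain its own rank. torch.distributed.new_group requires all processes in the default group to enter the call, even for groups they do not belong to.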

Enabling data parallelism with Transformer Engine is similar to enabling it for standard PyTorch models: simply wrap the modules with torch.nn.parallel.DistributedDataParallel. FP8 training requires extra synchronization of the scaling factors, so the data-parallel process group must also be passed to the fp8_autocast context manager through its fp8_group argument. Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it is applied to operations that are not amenable to tensor parallelism, and it uses the tensor-parallel process group.
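A simplified model of delayed scaling shows why the scaling factors need synchronizing. In the sketch below, each data-parallel rank keeps its own history of absolute-maximum (amax) values. With amax_compute_algo="max", a rank takes the maximum over its history. Reducing that value across the group with a max operation then gives every rank the same FP8 scaling factor. This is a conceptual pure-Python simulation built on simplifying assumptions: zero margin, the E4M3 format, and hypothetical helper names and amax values. It is not Transformer Engine's actual implementation.

```python
from typing import Dict, List

E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


def local_amax(history: List[float]) -> float:
    # amax_compute_algo="max": use the largest amax in the recorded history
    return max(history)


def fp8_scale(amax: float, fp8_max: float = E4M3_MAX) -> float:
    # Simplified delayed-scaling rule with zero margin: map amax onto fp8_max
    return fp8_max / amax


# Hypothetical per-rank amax histories (amax_history_len=3 here for brevity)
histories: Dict[int, List[float]] = {
    0: [1.5, 2.0, 1.8],
    1: [3.1, 2.9, 4.0],  # this rank happened to see larger values
}

# Without synchronization, each rank derives its own scaling factor
unsynced = {rank: fp8_scale(local_amax(h)) for rank, h in histories.items()}

# Simulates an all-reduce with a MAX reduction across the fp8_group
global_amax = max(local_amax(h) for h in histories.values())
synced = {rank: fp8_scale(global_amax) for rank in histories}

print("without fp8_group:", unsynced)  # {0: 224.0, 1: 112.0}
print("with fp8_group:   ", synced)    # {0: 112.0, 1: 112.0}
```

Taking the maximum across ranks is the conservative choice. The shared scaling factor is small enough that no rank's values overflow the FP8 range.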

[4]:
# Construct layer
parallel_transformer = te.TransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
    set_parallel_mode=True,
    tp_group=tensor_parallel_group,
    sequence_parallel=True,
)
parallel_transformer.to(dtype=dtype).cuda()
parallel_transformer = torch.nn.parallel.DistributedDataParallel(
    parallel_transformer,
    process_group=data_parallel_group,
)

# Training step
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=data_parallel_group):
    y = parallel_transformer(x, attention_mask=None)
y.backward(dy)

# Measure step time
utils.speedometer(
    parallel_transformer,
    x,
    dy,
    forward_kwargs = { "attention_mask": None },
    fp8_autocast_kwargs = {
        "enabled": True,
        "fp8_recipe": fp8_recipe,
        "fp8_group": data_parallel_group,
    },
)
Mean time: 29.09606689453125 ms

Gradient accumulation fusion

Summary

We take advantage of the ability of Tensor Cores to accumulate outputs directly into FP32.

PyTorch’s autograd functionality assumes that a model parameter and its corresponding gradient have the same data type. However, while low-precision data types like FP8 are sufficient for evaluating a neural network’s forward and backward passes, the optimization step typically requires full FP32 precision to avoid significant learning degradation. In addition, Tensor Cores on Hopper GPUs have the option to accumulate matrix products directly into FP32, resulting in better numerical accuracy and avoiding the need for a separate casting kernel. Thus, Transformer Engine provides an option to directly generate FP32 gradients for weight tensors. The FP32 gradients are not written to the parameter’s grad tensor, but rather to a main_grad tensor that must be initialized before the backward pass.
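The numerical benefit of accumulating in FP32 can be illustrated without a GPU. The sketch below uses Python's struct module to round values to IEEE half precision (the "e" format) and single precision (the "f" format). It then sums the same small gradient contribution many times with each kind of accumulator. This is only a toy model of accumulation rounding. Tensor Cores accumulate inside the matrix multiply, not one scalar at a time.

```python
import struct
from typing import Callable


def to_fp16(x: float) -> float:
    # Round a Python float to the nearest IEEE 754 half-precision value
    return struct.unpack("e", struct.pack("e", x))[0]


def to_fp32(x: float) -> float:
    # Round a Python float to the nearest IEEE 754 single-precision value
    return struct.unpack("f", struct.pack("f", x))[0]


def accumulate(grad: float, steps: int, round_fn: Callable[[float], float]) -> float:
    # Simulate an accumulator whose running total is rounded after every addition
    total = round_fn(0.0)
    g = round_fn(grad)
    for _ in range(steps):
        total = round_fn(total + g)
    return total


steps, grad = 20_000, 1e-4  # the exact sum would be 2.0
fp16_total = accumulate(grad, steps, to_fp16)
fp32_total = accumulate(grad, steps, to_fp32)
print("FP16 accumulator:", fp16_total)  # stalls far below 2.0
print("FP32 accumulator:", fp32_total)  # close to 2.0
```

The FP16 total stops growing once the spacing between representable values exceeds twice the increment, because each addition then rounds back to the old total. The FP32 accumulator has far finer spacing at the same magnitude and stays close to the exact sum.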

[5]:
# Construct layer
wgrad_transformer = te.TransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
    fuse_wgrad_accumulation=True,
    fuse_qkv_params=True, # Required for fuse_wgrad_accumulation
)
wgrad_transformer.to(dtype=dtype).cuda()
# Allocate FP32 main_grad buffers; fused weight gradients accumulate into them
for param in wgrad_transformer.parameters():
    param.grad = None
    param.main_grad = torch.zeros_like(param, dtype=torch.float32)

# Training step
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = wgrad_transformer(x, attention_mask=None)
y.backward(dy)
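# Gradients outside the fused path (e.g. biases) still land in .grad;
# accumulate them into main_grad so it holds every gradient in FP32
for param in wgrad_transformer.parameters():
    if param.grad is not None:
        param.main_grad.add_(param.grad)
        param.grad = None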

# Measure step time
utils.speedometer(
    wgrad_transformer,
    x,
    dy,
    forward_kwargs = { "attention_mask": None },
    fp8_autocast_kwargs = { "enabled": True, "fp8_recipe": fp8_recipe },
)
Mean time: 27.510029296875 ms

FP8 weight caching

Summary

We avoid redundant FP8 casting when training with multiple gradient accumulation steps.

Since weights are typically trained in FP32, a type conversion is required before we can perform compute in FP8. By default, the fp8_autocast context manager will handle this internally by casting non-FP8 tensors to FP8 as they are encountered. However, we can improve upon this in some cases. In particular, if our training iteration is split into multiple gradient accumulation steps, each micro-batch will encounter the same weight tensors. Thus, we only need to cast the weights to FP8 in the first gradient accumulation step and we can cache the resulting FP8 weights for the remaining gradient accumulation steps.
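The caching logic can be modeled with a toy layer. In the sketch below, ToyFP8Linear is hypothetical, and its "cast" is a stand-in that merely rounds values and counts how often it runs. The point is the control flow around is_first_microbatch, in a simplified form of the True/False/None values used in the next cell. True casts and caches the weights, False reuses the cache, and leaving it unset casts on every call.

```python
from typing import List, Optional


class ToyFP8Linear:
    """Hypothetical toy layer illustrating FP8 weight caching across microbatches."""

    def __init__(self, weight: List[float]):
        self.weight = weight                        # high-precision master weights
        self._cached: Optional[List[float]] = None  # weights cast in an earlier microbatch
        self.num_casts = 0                          # how many times the weights were cast

    def _cast(self, w: List[float]) -> List[float]:
        # Stand-in for an FP8 cast: coarse rounding, plus a counter for the work done
        self.num_casts += 1
        return [round(v, 1) for v in w]

    def forward(self, x: List[float], is_first_microbatch: Optional[bool] = None) -> float:
        if is_first_microbatch is False and self._cached is not None:
            w = self._cached                        # reuse the cached cast weights
        else:
            w = self._cast(self.weight)
            if is_first_microbatch:                 # cache only when explicitly requested
                self._cached = w
        return sum(wi * xi for wi, xi in zip(w, x))


microbatches = [[1.0, 2.0, 3.0]] * 4

# Gradient accumulation: weights do not change until after the last microbatch
layer = ToyFP8Linear([0.12, -0.57, 0.33])
for i, mb in enumerate(microbatches):
    layer.forward(mb, is_first_microbatch=(i == 0))
print("casts with caching:", layer.num_casts)        # 1

uncached = ToyFP8Linear([0.12, -0.57, 0.33])
for mb in microbatches:
    uncached.forward(mb)
print("casts without caching:", uncached.num_casts)  # 4
```

The design relies on the weights staying fixed between microbatches. After the optimizer step, the next minibatch would start again with is_first_microbatch=True so that the updated weights are re-cast. The explicit `is False` check keeps an unset value from being treated as "reuse the cache".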

[6]:
# Construct layer
weight_caching_transformer = te.TransformerLayer(
    hidden_size,
    ffn_hidden_size,
    num_attention_heads,
)
weight_caching_transformer.to(dtype=dtype).cuda()

# Cast weights in first gradient accumulation step
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = weight_caching_transformer(x, attention_mask=None, is_first_microbatch=True)
y.backward(dy)

# Reuse FP8 weights in subsequent gradient accumulation steps
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y = weight_caching_transformer(x, attention_mask=None, is_first_microbatch=False)
y.backward(dy)

# Measure step time
utils.speedometer(
    weight_caching_transformer,
    x,
    dy,
    forward_kwargs = { "attention_mask": None, "is_first_microbatch": False },
    fp8_autocast_kwargs = { "enabled": True, "fp8_recipe": fp8_recipe },
)
Mean time: 27.262666015625 ms