{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "24184f3f",
   "metadata": {},
   "source": [
    "# Performance Optimizations"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6dcbf25a",
   "metadata": {},
   "source": [
    "This guide is a follow-up to the discussion in the [quickstart guide](quickstart.ipynb). We will focus on techniques to achieve maximum performance when training a basic GPT encoder layer. For convenience, we use some helper functions defined in [quickstart_utils.py](quickstart_utils.py)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2b53dfa7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import transformer_engine.pytorch as te\n",
    "from transformer_engine.common.recipe import Format, DelayedScaling\n",
    "import quickstart_utils as utils\n",
    "\n",
    "# Layer configuration\n",
    "hidden_size = 4096\n",
    "sequence_length = 2048\n",
    "batch_size = 4\n",
    "ffn_hidden_size = 16384\n",
    "num_attention_heads = 32\n",
    "dtype = torch.float16\n",
    "\n",
    "# Fix the RNG seed so the synthetic data and the layer initialization are\n",
    "# reproducible across kernel restarts (the seed value itself is arbitrary)\n",
    "torch.manual_seed(1234)\n",
    "\n",
    "# Synthetic data, generated directly on the GPU in the target dtype to avoid\n",
    "# a host-side allocation, a host-to-device copy, and a separate cast\n",
    "x = torch.rand(sequence_length, batch_size, hidden_size, device=\"cuda\", dtype=dtype)\n",
    "dy = torch.rand(sequence_length, batch_size, hidden_size, device=\"cuda\", dtype=dtype)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b96a9ef6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean time: 27.82952880859375 ms\n"
     ]
    }
   ],
   "source": [
    "# Construct layer\n",
    "basic_transformer = te.TransformerLayer(\n",
    "    hidden_size,\n",
    "    ffn_hidden_size,\n",
    "    num_attention_heads,\n",
    ")\n",
    "basic_transformer.to(dtype=dtype).cuda()\n",
    "\n",
    "# FP8 recipe handed to fp8_autocast below\n",
    "fp8_format = Format.HYBRID\n",
    "fp8_recipe = DelayedScaling(\n",
    "    fp8_format=fp8_format,\n",
    "    amax_history_len=16,\n",
    "    amax_compute_algo=\"max\",\n",
    ")\n",
    "\n",
    "# Training step: forward pass inside the FP8 autocast context,\n",
    "# backward pass invoked after leaving it\n",
    "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n",
    "    y = basic_transformer(x, attention_mask=None)\n",
    "y.backward(dy)\n",
    "\n",
    "# Measure step time\n",
    "utils.speedometer(\n",
    "    basic_transformer,\n",
    "    x,\n",
    "    dy,\n",
    "    forward_kwargs = { \"attention_mask\": None },\n",
    "    fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11367f5b",
   "metadata": {},
   "source": [
    "## Multi-GPU training\n",
    "\n",
    "