Mixed Precision Training#
Mixed precision training improves computational efficiency by performing operations in low-precision formats, while selectively keeping a small amount of data in single precision to preserve numerically critical information in key parts of the network. NeMo Framework supports FP16, BF16, and FP8 via Transformer Engine (TE) across most models.
Half-Precision Training#
NeMo Framework supports half-precision (FP16 and BF16) training via Megatron Core and the distributed optimizer. This training recipe uses half precision for all layer computations while keeping the model states (optimizer states and master parameters) in single precision. To avoid repeatedly casting data types at each layer computation, Megatron Core keeps a separate copy of the half-precision parameters that is updated after each optimizer step.
Half-precision training is enabled by setting the trainer’s plugins to either fp16-mixed or bf16-mixed.
The parameter gradients are computed in the same half-precision, and the precision of gradient reduce-scatter across data-parallel GPUs is set automatically according to the trainer’s precision.
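These precision settings are typically provided through NeMo’s MegatronMixedPrecision plugin. As a minimal sketch, assuming the plugin accepts the precision string directly (the recipe helpers in the next section build an equivalent configuration), the plugin can also be constructed explicitly:

from nemo import lightning as nl

# Illustrative only: construct the half-precision plugin directly instead of
# using the bf16_mixed()/fp16_mixed() recipe helpers shown below.
plugin = nl.MegatronMixedPrecision(precision="bf16-mixed")  # or "fp16-mixed"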
Implement Half-Precision Training#
import nemo_run as run
from nemo import lightning as nl
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed, fp16_mixed

trainer_args = {TRAINER_ARGS}

# Set up trainer with bf16 precision
trainer_bf16 = run.Config(
    nl.Trainer,
    plugins=bf16_mixed(),
    **trainer_args,
)

# Set up trainer with fp16 precision
trainer_fp16 = run.Config(
    nl.Trainer,
    plugins=fp16_mixed(),
    **trainer_args,
)
It’s also possible to change precision for a specific recipe:
from functools import partial
from nemo.collections import llm
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed, fp16_mixed
# Load recipe
recipe = partial(llm.llama3_8b.pretrain_recipe)()
# Change precision
recipe.trainer.plugins = fp16_mixed()
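As a follow-up, the modified recipe can be launched with nemo_run. This is a minimal, illustrative sketch assuming a single-node local run with nemo_run’s LocalExecutor; adapt the executor (for example, a Slurm executor) and its settings to your environment.

import nemo_run as run

# Launch the modified recipe locally; executor settings here are example values.
executor = run.LocalExecutor(ntasks_per_node=8, launcher="torchrun")
run.run(recipe, executor=executor)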
FP8 Training#
The NVIDIA H100 GPU introduced support for a new data type, FP8 (8-bit floating point), enabling higher throughput for matrix multiplications and convolutions. NeMo Framework uses NVIDIA Transformer Engine (TE) to leverage FP8 speedups. The following table summarizes the FP8-related arguments that can be configured in NeMo (see the example config setting). For a more detailed overview, refer to the TE documentation, specifically the FP8 format and recipe.
| Argument | Description |
|---|---|
| fp8 | The FP8 training recipe format, either ‘hybrid’ or ‘e4m3’, with ‘hybrid’ being the default. In the ‘hybrid’ format, activation and weight tensors use the E4M3 format, while gradients use the E5M2 format to meet the additional dynamic range requirements of backward tensors. |
| fp8_margin | Shifts the scaling factors for FP8 tensors by a factor of $2^{margin}$. |
| fp8_amax_history_len | The window size for the amax history. The window size determines how many of the most recent absolute maximum values (amaxes) are stored per tensor. |
| fp8_amax_compute_algo | How to select an amax value from the given history: either “max” or “most_recent”. |
| fp8_params | Whether to store module-level parameters in FP8. Enabling this option can reduce memory consumption by eliminating the need to store a higher-precision copy of weights in cases where those weights are maintained externally, such as the master parameters in the optimizer. For more information, refer to the fp8_model_init API in TE. |
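The sketch below illustrates how these arguments might be set directly on the precision plugin. It assumes MegatronMixedPrecision exposes them as keyword arguments (the bf16_with_fp8_mixed and fp16_with_fp8_mixed helpers used in the next section configure the same fields); the values are illustrative, not tuned defaults.

from nemo import lightning as nl

# Illustrative FP8 configuration on the precision plugin (example values only).
fp8_plugin = nl.MegatronMixedPrecision(
    precision="bf16-mixed",
    fp8="hybrid",                 # E4M3 for activations/weights, E5M2 for gradients
    fp8_margin=0,                 # shift scaling factors by 2**margin
    fp8_amax_history_len=1024,    # number of recent amax values kept per tensor
    fp8_amax_compute_algo="max",  # take the maximum amax from the history
    fp8_params=True,              # store module-level parameters in FP8
)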
Implement FP8 Training#
import nemo_run as run
from nemo import lightning as nl
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_with_fp8_mixed, fp16_with_fp8_mixed

trainer_args = {TRAINER_ARGS}
fp8_args = {FP8_ARGS}

# Set up trainer with bf16 & fp8 precision
trainer_bf16_fp8 = run.Config(
    nl.Trainer,
    plugins=bf16_with_fp8_mixed(),
    **trainer_args,
    **fp8_args,
)

# Set up trainer with fp16 & fp8 precision
trainer_fp16_fp8 = run.Config(
    nl.Trainer,
    plugins=fp16_with_fp8_mixed(),
    **trainer_args,
    **fp8_args,
)
Resources#
Intro to FP8, floating point formats, and mixed precision training
Performance optimizations that are natively supported in NeMo Framework by enabling FP8 training with TE