Mixed Precision Training#
Mixed precision training significantly improves computational efficiency by performing operations in a low-precision format while selectively keeping a small amount of data in single precision to preserve critical information in key parts of the network. Megatron Bridge supports FP16, BF16, and FP8 (via Transformer Engine, TE) for most models through the bridge.training.mixed_precision.MixedPrecisionConfig configuration.
Configuration Overview#
Mixed precision is configured in Megatron Bridge through the mixed_precision field in bridge.training.config.ConfigContainer, which accepts either:
A string name referencing a predefined recipe (e.g., "bf16_mixed")
A bridge.training.mixed_precision.MixedPrecisionConfig object for custom configurations
The mixed precision configuration automatically updates the model, optimizer, and distributed data parallel settings with the appropriate precision parameters.
Half-Precision Training#
Megatron Bridge supports half-precision (FP16 and BF16) training via Megatron Core and the distributed optimizer. This training recipe performs all layer computation in half precision while keeping the model states (optimizer states and master parameters) in single precision. To avoid repeated data type casting at each layer computation, Megatron Core keeps a separate copy of half-precision parameters that is updated after each optimizer step.
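The master-weight pattern described above can be illustrated with a small, self-contained PyTorch sketch. This is a conceptual model only, not Megatron Core's implementation: the optimizer updates FP32 master parameters, while a separate BF16 working copy is used for computation and refreshed once after each optimizer step.
import torch

# FP32 "master" parameter owned by the optimizer (the model state kept in single precision)
master_param = torch.nn.Parameter(torch.randn(1024, 1024))
optimizer = torch.optim.SGD([master_param], lr=1e-3)

# Separate half-precision working copy used for layer computation
work_param = master_param.detach().to(torch.bfloat16).requires_grad_(True)

for _ in range(3):
    # Forward/backward run in BF16 and produce a BF16 gradient on the working copy
    activation = work_param @ torch.randn(1024, 1024, dtype=torch.bfloat16)
    loss = activation.float().sum()
    loss.backward()

    # Apply the gradient to the FP32 master parameter and step the optimizer
    master_param.grad = work_param.grad.float()
    optimizer.step()
    optimizer.zero_grad()
    work_param.grad = None

    # Refresh the half-precision copy once per step instead of casting at every layer
    with torch.no_grad():
        work_param.copy_(master_param.to(torch.bfloat16))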
Using Predefined Recipes#
The simplest way to enable mixed precision is using predefined recipe names:
from megatron.bridge.training.config import ConfigContainer

# Configure with BF16 mixed precision
config = ConfigContainer(
    mixed_precision="bf16_mixed",
    # ... other config parameters
)

# Configure with FP16 mixed precision
config = ConfigContainer(
    mixed_precision="fp16_mixed",
    # ... other config parameters
)
Custom Mixed Precision Configuration#
For more control, create a custom bridge.training.mixed_precision.MixedPrecisionConfig:
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.mixed_precision import MixedPrecisionConfig
import torch

# Custom BF16 configuration
bf16_config = MixedPrecisionConfig(
    bf16=True,
    params_dtype=torch.bfloat16,
    pipeline_dtype=torch.bfloat16,
    autocast_enabled=False,
    grad_reduce_in_fp32=True,
)

config = ConfigContainer(
    mixed_precision=bf16_config,
    # ... other config parameters
)
FP8 Training#
The NVIDIA H100 GPU introduced support for a new datatype, FP8 (8-bit floating point), enabling higher throughput for matrix multiplies and convolutions. Megatron Bridge uses NVIDIA Transformer Engine (TE) to leverage speedups from FP8. For a more detailed overview, refer to the TE documentation, specifically the FP8 format and recipe.
FP8 Configuration Parameters#
The bridge.training.mixed_precision.MixedPrecisionConfig provides several FP8-specific parameters:
Parameter | Description
---|---
fp8 | FP8 format: "e4m3" or "hybrid"
fp8_recipe | FP8 recipe type: "delayed", "tensorwise", "mxfp8", or "blockwise"
first_last_layers_bf16 | If True, retains the first and last N TransformerBlocks in BF16 instead of FP8
num_layers_at_start_in_bf16 | Number of layers at the start of the model to keep in BF16 precision when first_last_layers_bf16 is enabled
num_layers_at_end_in_bf16 | Number of layers at the end of the model to keep in BF16 precision when first_last_layers_bf16 is enabled
fp8_margin | Scaling factor shift by \(2^{margin}\)
fp8_amax_history_len | Window size for amax history storage
fp8_amax_compute_algo | Amax selection algorithm: "max" or "most_recent"
fp8_param | Store module-level parameters in FP8
fp8_param_gather | Enable FP8 parameter gathering
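As an illustration of the layer-retention parameters, the sketch below keeps the first and last transformer layers in BF16 while running the rest of the model in FP8. The field names first_last_layers_bf16, num_layers_at_start_in_bf16, and num_layers_at_end_in_bf16 are taken from the table above; verify them against your installed version of megatron.bridge.training.mixed_precision.
from megatron.bridge.training.mixed_precision import MixedPrecisionConfig
import torch

# FP8 everywhere except the first and last transformer layer, which stay in BF16
fp8_first_last_bf16 = MixedPrecisionConfig(
    bf16=True,
    params_dtype=torch.bfloat16,
    pipeline_dtype=torch.bfloat16,
    fp8="hybrid",                  # E4M3 for forward, E5M2 for gradients
    fp8_recipe="delayed",          # delayed scaling using the amax history
    first_last_layers_bf16=True,   # keep boundary layers out of FP8
    num_layers_at_start_in_bf16=1,
    num_layers_at_end_in_bf16=1,
)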
FP8 Recipe Examples#
Use any of the predefined FP8 recipe names with the mixed_precision parameter:
# Example: BF16 with FP8 current scaling
config = ConfigContainer(
    mixed_precision="bf16_with_fp8_current_scaling_mixed",
    # ... other config parameters
)
Available Mixed Precision Recipes#
Megatron Bridge provides numerous predefined mixed precision recipes for different use cases. You can use the get_mixed_precision_config() utility function to convert a string shortname to a class instance. For the complete list of available recipes and their specific configurations, see the megatron.bridge.training.mixed_precision module.
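For example, you can resolve a shortname into its MixedPrecisionConfig instance and inspect or adjust it before passing it to the ConfigContainer. This is a minimal sketch, assuming get_mixed_precision_config is importable from megatron.bridge.training.mixed_precision alongside the other utilities shown in this guide:
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.mixed_precision import get_mixed_precision_config

# Resolve a predefined recipe shortname into a MixedPrecisionConfig instance
mp_config = get_mixed_precision_config("bf16_mixed")

# Inspect or tweak individual fields before training
print(mp_config.bf16, mp_config.grad_reduce_in_fp32)
mp_config.grad_reduce_in_fp32 = False  # example adjustment

config = ConfigContainer(
    mixed_precision=mp_config,
    # ... other config parameters
)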
Custom FP8 Configuration#
For advanced use cases, create a custom FP8 configuration:
from megatron.bridge.training.config import ConfigContainer
from megatron.bridge.training.mixed_precision import MixedPrecisionConfig
import torch

# Custom FP8 configuration
fp8_config = MixedPrecisionConfig(
    bf16=True,
    params_dtype=torch.bfloat16,
    pipeline_dtype=torch.bfloat16,
    fp8="hybrid",
    fp8_recipe="tensorwise",
    fp8_margin=0,
    fp8_amax_history_len=1024,
    fp8_amax_compute_algo="max",
    fp8_param_gather=True,
)

config = ConfigContainer(
    mixed_precision=fp8_config,
    # ... other config parameters
)
Registering Custom Mixed Precision Recipes#
You can also register your own custom mixed precision configurations to work with the shortname system. Use the register() decorator on a function that returns a MixedPrecisionConfig object:
from megatron.bridge.training.mixed_precision import (
    MixedPrecisionConfig,
    get_mixed_precision_config,
    register,
)

@register
def my_custom_fp8_recipe() -> MixedPrecisionConfig:
    """Custom FP8 recipe with specific settings for my use case."""
    return MixedPrecisionConfig(
        bf16=True,
        fp8="hybrid",
        fp8_recipe="tensorwise",
        fp8_param_gather=True,
        # ... other custom settings
    )

# Now you can use it with the utility function
config = get_mixed_precision_config("my_custom_fp8_recipe")
Common recipe categories include:
Half-precision recipes: Basic BF16 and FP16 mixed precision
FP8 recipes: Various FP8 scaling strategies (delayed, current, subchannel)
Architecture-specific recipes: Optimized for specific GPU architectures (Hopper, Blackwell)
Model-specific recipes: Tuned for particular model families
Configuration Synchronization#
When a mixed precision configuration is provided, it automatically synchronizes precision-related settings across the model, optimizer, and distributed data parallel (DDP) configurations. This ensures consistent precision behavior throughout the training pipeline.
Important: Mixed precision settings will override any conflicting precision parameters that may have been set directly on the model, optimizer, or DDP configurations. The mixed precision configuration acts as the authoritative source for all precision-related parameters.
For example, if you set precision flags directly on the model and optimizer configs and also provide a mixed precision recipe:
# This will be overridden
model_config.bf16 = False
optimizer_config.bf16 = False

config = ConfigContainer(
    model=model_config,
    optimizer=optimizer_config,
    mixed_precision="bf16_mixed",  # This takes precedence during training
    # ... other configs
)
The mixed precision configuration will set bf16=True on both the model and optimizer configs, overriding the explicitly set False values. This synchronization prevents configuration mismatches that could lead to training issues.
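Conceptually, the synchronization is a one-way overwrite from the mixed precision config onto the other configs. The following standalone sketch models that precedence rule with plain dataclasses; it is an illustration of the behavior described above, not Megatron Bridge's actual implementation.
from dataclasses import dataclass

@dataclass
class _ModelConfig:       # stand-in for the real model config
    bf16: bool = False
    fp16: bool = False

@dataclass
class _OptimizerConfig:   # stand-in for the real optimizer config
    bf16: bool = False
    fp16: bool = False

@dataclass
class _MixedPrecision:    # stand-in for MixedPrecisionConfig
    bf16: bool = True
    fp16: bool = False

def sync_precision(mixed, *targets):
    """Copy precision flags from the mixed precision config onto each target,
    overwriting whatever was set there directly."""
    for target in targets:
        for field in ("bf16", "fp16"):
            if hasattr(target, field):
                setattr(target, field, getattr(mixed, field))

model_config = _ModelConfig(bf16=False)        # set directly...
optimizer_config = _OptimizerConfig(bf16=False)

sync_precision(_MixedPrecision(), model_config, optimizer_config)
print(model_config.bf16, optimizer_config.bf16)  # True True: the recipe wins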
Performance Considerations#
FP8 recipes are experimental and convergence has not been fully validated for all models
BF16 is generally recommended over FP16 for better numerical stability
FP8 provides the best performance on H100 GPUs but requires careful tuning
MXFP8 recipes are only supported on Blackwell architecture GPUs
Blockwise scaling recipes are optimized for Hopper architecture GPUs
Resources#
Intro to FP8, floating point formats, and mixed precision training
Performance optimizations that are natively supported in Megatron Bridge by enabling FP8 training with TE