# Copyright (c) NVIDIA CORPORATION and affiliates.
# All rights reserved.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Literal, Optional
import torch
import torch.nn as nn
from nemo_automodel.shared.import_utils import MISSING_TORCHAO_MSG
logger = logging.getLogger(__name__)
try:
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
HAVE_TORCHAO = True
except ImportError:
HAVE_TORCHAO = False
@dataclass
class FP8Config:
"""Configuration for FP8 quantization settings."""
enabled: bool = False
"""Whether FP8 quantization is enabled."""
recipe_name: Optional[Literal["tensorwise", "rowwise", "rowwise_with_gw_hp"]] = None
"""FP8 recipe to use. If None, uses tensorwise scaling with manual configuration."""
enable_fsdp_float8_all_gather: bool = False
"""Whether to enable float8 all-gather in FSDP, recommended for tensorwise scaling."""
precompute_float8_dynamic_scale_for_fsdp: bool = False
"""Whether to precompute float8 scales dynamically for FSDP, recommended for tensorwise scaling."""
force_recompute_fp8_weight_in_bwd: bool = False
"""Whether to force the recomputation of FP8 weights during backward pass."""
filter_fqns: List[str] = field(default_factory=list)
"""
List of fully qualified names of modules to skip applying float8 training to.
nn.Linear modules with any dim size not divisible by 16 are always skipped due to hardware requirements.
Example: ["attention.wq", "attention.wk", "attention.wv", "lm_head"]
"""
emulate: bool = False
"""If True, emulation is used instead of hardware accelerated gemm. This is for test purpose only"""
def __init__(
self,
enabled: bool = False,
recipe_name: Optional[Literal["tensorwise", "rowwise", "rowwise_with_gw_hp"]] = None,
enable_fsdp_float8_all_gather: bool = False,
precompute_float8_dynamic_scale_for_fsdp: bool = False,
force_recompute_fp8_weight_in_bwd: bool = False,
        filter_fqns: Optional[List[str]] = None,
emulate: bool = False,
):
self.enabled = enabled
self.recipe_name = recipe_name
self.enable_fsdp_float8_all_gather = enable_fsdp_float8_all_gather
self.precompute_float8_dynamic_scale_for_fsdp = precompute_float8_dynamic_scale_for_fsdp
self.force_recompute_fp8_weight_in_bwd = force_recompute_fp8_weight_in_bwd
self.filter_fqns = filter_fqns or []
self.emulate = emulate
@classmethod
def from_config_node(cls, config_node):
"""Create FP8Config from a configuration node."""
if config_node is None:
return cls()
kwargs = {}
for field_name in cls.__dataclass_fields__:
if hasattr(config_node, field_name):
kwargs[field_name] = getattr(config_node, field_name)
return cls(**kwargs)
    def to_dict(self):
        """Return the FP8 settings as a plain dict, suitable for logging or serialization."""
        return {
"enabled": self.enabled,
"fp8_recipe_name": self.recipe_name,
"enable_fsdp_float8_all_gather": self.enable_fsdp_float8_all_gather,
"precompute_float8_dynamic_scale_for_fsdp": self.precompute_float8_dynamic_scale_for_fsdp,
"force_recompute_fp8_weight_in_bwd": self.force_recompute_fp8_weight_in_bwd,
"fp8_filter_fqns": self.filter_fqns,
"fp8_emulate": self.emulate,
}
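
# Illustrative sketch (not used anywhere in this module): building an FP8Config by hand
# and dumping it for logging. The field values below are arbitrary examples chosen for
# demonstration, not recommended defaults.
def _example_fp8_config() -> Dict[str, Any]:
    cfg = FP8Config(
        enabled=True,
        recipe_name="tensorwise",
        enable_fsdp_float8_all_gather=True,
        precompute_float8_dynamic_scale_for_fsdp=True,
        filter_fqns=["lm_head"],
    )
    # to_dict() flattens the settings so they can be logged next to other run metadata.
    return cfg.to_dict()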
def _has_cuda_capability(major: int, minor: int) -> bool:
"""Check if CUDA device has required compute capability."""
if not torch.cuda.is_available():
return False
device = torch.cuda.current_device()
capability = torch.cuda.get_device_capability(device)
return capability >= (major, minor)
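
# Illustrative sketch of how a caller might gate FP8 on hardware support: use real FP8
# kernels only when the device reports compute capability (8, 9) or newer (Ada/Hopper),
# and fall back to emulation elsewhere. This mirrors the check in apply_fp8_to_model.
def _example_choose_emulation() -> FP8Config:
    supports_fp8 = _has_cuda_capability(8, 9)
    return FP8Config(enabled=True, recipe_name="tensorwise", emulate=not supports_fp8)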
def _module_filter_fn(module, name, filter_fqns: Optional[List[str]] = None):
"""
Filter function to exclude certain modules from FP8 conversion.
Args:
module: The module to check
name: Fully qualified name of the module
filter_fqns: List of FQNs to filter out
Returns:
True if module should be converted to FP8, False otherwise
"""
if filter_fqns is None:
filter_fqns = []
# Skip modules in filter list
for fqn in filter_fqns:
if fqn in name:
return False
# Always skip non-linear layers
if not isinstance(module, nn.Linear):
return False
# Skip layers with dimensions not divisible by 16
if hasattr(module, "weight"):
weight = module.weight
if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
logger.info(f"Skipping fp8 for layer {name} with weight shape {weight.shape}")
return False
return True
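
# Illustrative sketch of the filter semantics; the module names and shapes below are
# made up for demonstration. Only nn.Linear modules whose FQN does not match an entry in
# filter_fqns and whose weight dimensions are both divisible by 16 are converted.
def _example_filter_behaviour() -> None:
    converted = _module_filter_fn(nn.Linear(256, 512), "layers.0.mlp.up_proj", ["lm_head"])
    skipped_by_fqn = _module_filter_fn(nn.Linear(64, 128), "lm_head", ["lm_head"])
    skipped_by_shape = _module_filter_fn(nn.Linear(10, 10), "layers.0.tiny", ["lm_head"])
    assert converted and not skipped_by_fqn and not skipped_by_shape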
def apply_fp8_to_model(
model: nn.Module,
config: Optional[FP8Config] = None,
# Individual parameter options for backward compatibility
filter_fqns: Optional[List[str]] = None,
recipe_name: Optional[str] = None,
force_recompute_fp8_weight_in_bwd: bool = False,
enable_fsdp_float8_all_gather: bool = False,
emulate: bool = False,
enabled: bool = True,
precompute_float8_dynamic_scale_for_fsdp: bool = False,
) -> nn.Module:
"""
Apply FP8 quantization to a PyTorch model using torchao.
This function can be called in two ways:
1. With an FP8Config object: apply_fp8_to_model(model, config=fp8_config)
2. With individual parameters: apply_fp8_to_model(model, filter_fqns=..., recipe_name=..., etc.)
Args:
model: The model to convert
config: FP8Config object containing all configuration. If provided, individual
parameters are ignored.
filter_fqns: List of module names to exclude from FP8 conversion
recipe_name: Recipe name for FP8 configuration ("tensorwise", "rowwise", etc.)
force_recompute_fp8_weight_in_bwd: Whether to force recompute FP8 weight in backward pass
enable_fsdp_float8_all_gather: Whether to enable FSDP FP8 all-gather
emulate: Use emulation instead of hardware acceleration (for testing on older GPUs)
enabled: Whether FP8 quantization is enabled (only used when config is None)
precompute_float8_dynamic_scale_for_fsdp: Whether to precompute float8 scales dynamically
Returns:
The model with FP8 linear layers (modified in-place)
Raises:
ImportError: If torchao is not installed
ValueError: If hardware doesn't support FP8 and emulation is disabled
"""
# If config is provided, use it; otherwise create config from individual parameters
if config is not None:
# Use provided FP8Config
fp8_config = config
else:
# Create FP8Config from individual parameters
fp8_config = FP8Config(
enabled=enabled,
recipe_name=recipe_name,
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
precompute_float8_dynamic_scale_for_fsdp=precompute_float8_dynamic_scale_for_fsdp,
force_recompute_fp8_weight_in_bwd=force_recompute_fp8_weight_in_bwd,
filter_fqns=filter_fqns or [],
emulate=emulate,
)
# Check if FP8 is disabled
if not fp8_config.enabled:
logger.info("FP8 quantization is disabled")
return model
# Check if torchao is available
if not HAVE_TORCHAO:
raise ImportError(MISSING_TORCHAO_MSG)
# Set precompute attribute on model
model.precompute_float8_dynamic_scale_for_fsdp = (
fp8_config.precompute_float8_dynamic_scale_for_fsdp
and fp8_config.recipe_name == "tensorwise"
and fp8_config.enable_fsdp_float8_all_gather
)
# Handle config creation or recipe-based configuration
if fp8_config.recipe_name is not None and fp8_config.recipe_name != "tensorwise":
torchao_config = Float8LinearConfig.from_recipe_name(fp8_config.recipe_name)
logger.info(f"Using FP8 recipe: {fp8_config.recipe_name}")
# Enable inductor precision cast emulation for rowwise recipe
if fp8_config.recipe_name == "rowwise":
torch._inductor.config.emulate_precision_casts = True
logger.debug("Enabled torch._inductor.config.emulate_precision_casts for rowwise recipe")
else:
# Manual configuration for tensorwise scaling
torchao_config = Float8LinearConfig(
enable_fsdp_float8_all_gather=fp8_config.enable_fsdp_float8_all_gather,
force_recompute_fp8_weight_in_bwd=fp8_config.force_recompute_fp8_weight_in_bwd,
emulate=fp8_config.emulate,
)
logger.info("Using FP8 tensorwise scaling")
# Check hardware capability if not using emulation
config_emulate = getattr(torchao_config, "emulate", fp8_config.emulate)
if not _has_cuda_capability(8, 9) and not config_emulate:
raise ValueError(
"FP8 is only supported on SM89 or later GPUs (H100+). "
"To enable testing on older hardware, set emulate=True in Float8LinearConfig or pass emulate=True."
)
try:
filter_fqns_list = fp8_config.filter_fqns or []
filter_fn = partial(_module_filter_fn, filter_fqns=filter_fqns_list)
# Convert model to use FP8 linear layers
convert_to_float8_training(
model,
config=torchao_config,
module_filter_fn=filter_fn,
)
logger.info(
f"Successfully converted model to FP8 with torchAO, recipe: {fp8_config.recipe_name or 'tensorwise'}, "
f"fp8 all-gather enabled: {torchao_config.enable_fsdp_float8_all_gather}, "
f"force recompute FP8 weight in backward pass: {torchao_config.force_recompute_fp8_weight_in_bwd}"
)
verify_fp8_conversion(model)
logger.info("FP8 quantization applied successfully")
return model
except Exception as e:
logger.warning(f"FP8 quantization failed: {e}. Returning original model")
return model
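
# Illustrative end-to-end sketch (not called anywhere): convert a toy model with
# emulate=True so the example does not depend on an FP8-capable GPU. The architecture
# and dimensions are assumptions for demonstration; real callers pass their training model.
def _example_apply_fp8() -> nn.Module:
    toy_model = nn.Sequential(
        nn.Linear(256, 512),
        nn.ReLU(),
        nn.Linear(512, 256),
    )
    cfg = FP8Config(enabled=True, emulate=True)
    # Both Linear layers have dimensions divisible by 16, so both are eligible for conversion.
    return apply_fp8_to_model(toy_model, config=cfg)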
def verify_fp8_conversion(model: nn.Module) -> dict:
"""
Verify that FP8 conversion was successful by counting converted modules.
Args:
model: The model to verify
Returns:
Dict with conversion statistics
"""
from torchao.float8.float8_linear import Float8Linear
total_linear = 0
fp8_modules = []
for name, module in model.named_modules():
module_type = type(module).__name__
# Count both nn.Linear and Float8Linear as linear layers
if isinstance(module, nn.Linear):
total_linear += 1
logger.debug(f"Found nn.Linear: {name} ({module_type})")
# Check if it's a Float8Linear by comparing class names or checking attributes
if isinstance(module, Float8Linear):
fp8_modules.append(
{
"name": name,
"type": module_type,
"weight_shape": list(module.weight.shape) if hasattr(module, "weight") else None,
}
)
logger.debug(f"Found Float8Linear: {name} ({module_type})")
elif module_type == "Float8Linear":
# Fallback: check by class name in case isinstance fails
fp8_modules.append(
{
"name": name,
"type": module_type,
"weight_shape": list(module.weight.shape) if hasattr(module, "weight") else None,
}
)
logger.debug(f"Found Float8Linear by name: {name} ({module_type})")
logger.info(f"FP8 conversion: {len(fp8_modules)} Float8Linear modules, {total_linear} total linear modules")
return {
"linear_count": total_linear,
"fp8_count": len(fp8_modules),
"conversion_rate": (len(fp8_modules) / total_linear * 100) if total_linear > 0 else 0,
"fp8_modules": fp8_modules,
"success": len(fp8_modules) > 0,
}
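
# Illustrative sketch of consuming the statistics returned above. The 50% warning
# threshold is an arbitrary example, not a requirement of this module.
def _example_check_conversion(model: nn.Module) -> bool:
    stats = verify_fp8_conversion(model)
    if stats["conversion_rate"] < 50:
        logger.warning(
            "Only %d of %d linear layers were converted to FP8",
            stats["fp8_count"],
            stats["linear_count"],
        )
    return stats["success"]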
def create_fp8_config_from_dict(config_dict: Dict[str, Any]) -> FP8Config:
"""Create a FP8Config from a dictionary.
Args:
config_dict: Dictionary containing FP8 configuration.
Returns:
FP8Config instance.
"""
return FP8Config(
enabled=config_dict.get("enabled", False),
recipe_name=config_dict.get("recipe_name", None),
enable_fsdp_float8_all_gather=config_dict.get("enable_fsdp_float8_all_gather", False),
precompute_float8_dynamic_scale_for_fsdp=config_dict.get("precompute_float8_dynamic_scale_for_fsdp", False),
force_recompute_fp8_weight_in_bwd=config_dict.get("force_recompute_fp8_weight_in_bwd", False),
filter_fqns=config_dict.get("filter_fqns", []),
emulate=config_dict.get("emulate", False),
)
def build_fp8_config(cfg: Optional[Dict[str, Any]]) -> FP8Config:
"""Build a FP8 config from configuration.
Args:
cfg: Configuration dictionary for FP8 quantization.
Returns:
FP8Config instance.
"""
if cfg is None:
return FP8Config(enabled=False)
else:
return create_fp8_config_from_dict(cfg)
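
# Illustrative sketch: building the config from a plain dictionary, e.g. one parsed from
# a YAML training recipe. The keys mirror the FP8Config field names; missing keys fall
# back to their defaults via create_fp8_config_from_dict.
def _example_build_from_dict() -> FP8Config:
    raw_cfg = {
        "enabled": True,
        "recipe_name": "rowwise",
        "filter_fqns": ["lm_head"],
    }
    return build_fp8_config(raw_cfg)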