Source code for nemo_automodel.components.quantization.fp8

# Copyright (c) NVIDIA CORPORATION and affiliates.
# All rights reserved.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass, field
from functools import partial
from typing import List, Literal, Optional

import torch
import torch.nn as nn

from nemo_automodel.shared.import_utils import MISSING_TORCHAO_MSG

logger = logging.getLogger(__name__)

# torchao is optional: record whether its float8 API can be imported instead of
# failing at module import time.
try:
    from torchao.float8 import Float8LinearConfig, convert_to_float8_training
except ImportError:
    HAVE_TORCHAO = False
else:
    HAVE_TORCHAO = True


@dataclass
class FP8Config:
    """Settings that control FP8 (float8) training.

    Instances can be created directly or built from an attribute-style config
    object with :meth:`from_config_node`.
    """

    # FP8 recipe to use; None means tensorwise scaling configured by the flags below.
    recipe_name: Optional[Literal["tensorwise", "rowwise", "rowwise_with_gw_hp"]] = None
    # Enable float8 all-gather in FSDP (recommended for tensorwise scaling).
    enable_fsdp_float8_all_gather: bool = False
    # Precompute float8 dynamic scales for FSDP (recommended for tensorwise scaling).
    precompute_float8_dynamic_scale_for_fsdp: bool = False
    # Force recomputation of FP8 weights during the backward pass.
    force_recompute_fp8_weight_in_bwd: bool = False
    # Fully qualified module names to keep out of float8 training, e.g.
    # ["attention.wq", "attention.wk", "attention.wv", "lm_head"].
    # nn.Linear modules with any dim not divisible by 16 are always skipped
    # because of hardware requirements.
    filter_fqns: List[str] = field(default_factory=list)
    # Use emulation instead of hardware-accelerated GEMM; for testing only.
    emulate: bool = False

    @classmethod
    def from_config_node(cls, config_node):
        """Build an FP8Config from a config node.

        Each dataclass field that exists as an attribute on ``config_node``
        overrides the default; any other attributes are ignored. Passing
        ``None`` returns a config with all defaults.
        """
        if config_node is None:
            return cls()
        overrides = {
            attr_name: getattr(config_node, attr_name)
            for attr_name in cls.__dataclass_fields__
            if hasattr(config_node, attr_name)
        }
        return cls(**overrides)

    def to_dict(self):
        """Return the settings as a flat dict.

        The recipe, filter list and emulate flag are exported under
        ``fp8_``-prefixed keys; the remaining flags keep their field names.
        """
        return dict(
            fp8_recipe_name=self.recipe_name,
            enable_fsdp_float8_all_gather=self.enable_fsdp_float8_all_gather,
            precompute_float8_dynamic_scale_for_fsdp=self.precompute_float8_dynamic_scale_for_fsdp,
            force_recompute_fp8_weight_in_bwd=self.force_recompute_fp8_weight_in_bwd,
            fp8_filter_fqns=self.filter_fqns,
            fp8_emulate=self.emulate,
        )
def _has_cuda_capability(major: int, minor: int) -> bool:
    """Return whether the current CUDA device has compute capability >= (major, minor).

    Always returns False when CUDA is not available.
    """
    if torch.cuda.is_available():
        active_device = torch.cuda.current_device()
        return torch.cuda.get_device_capability(active_device) >= (major, minor)
    return False
def _module_filter_fn(module, name, filter_fqns: Optional[List[str]] = None):
    """
    Filter function to exclude certain modules from FP8 conversion.

    Args:
        module: The module to check
        name: Fully qualified name of the module
        filter_fqns: FQN patterns to filter out. A module is skipped when any
            pattern is a substring of ``name``. A single string is treated as
            one pattern (not as a sequence of characters). ``None`` means no
            patterns.

    Returns:
        True if module should be converted to FP8, False otherwise
    """
    if filter_fqns is None:
        filter_fqns = []
    elif isinstance(filter_fqns, str):
        # Iterating a str yields its characters, which would filter any module
        # whose name shares a character with the pattern.
        filter_fqns = [filter_fqns]

    # Skip modules matching any user-provided pattern (substring match).
    if any(fqn in name for fqn in filter_fqns):
        return False

    # Only nn.Linear (and subclasses) are candidates for float8 conversion.
    if not isinstance(module, nn.Linear):
        return False

    # Skip layers whose weight dims are not multiples of 16 (hardware requirement).
    if hasattr(module, "weight"):
        weight = module.weight
        if weight.shape[0] % 16 != 0 or weight.shape[1] % 16 != 0:
            logger.info("Skipping fp8 for layer %s with weight shape %s", name, weight.shape)
            return False

    return True
def apply_fp8_to_model(
    model: nn.Module,
    filter_fqns: Optional[List[str]] = None,
    recipe_name: Optional[str] = None,
    force_recompute_fp8_weight_in_bwd: bool = False,
    enable_fsdp_float8_all_gather: bool = False,
    emulate: bool = False,
) -> nn.Module:
    """
    Apply FP8 quantization to a PyTorch model using torchao.

    Args:
        model: The model to convert
        filter_fqns: List of module names to exclude from FP8 conversion
        recipe_name: Recipe name for FP8 configuration ("tensorwise", "rowwise", etc.).
            ``None`` or ``"tensorwise"`` builds a tensorwise config from the flags
            below; any other name is resolved with ``Float8LinearConfig.from_recipe_name``.
        force_recompute_fp8_weight_in_bwd: Whether to force recompute FP8 weight in
            backward pass. Only applied for tensorwise scaling; a warning is logged
            if it is set together with another recipe.
        enable_fsdp_float8_all_gather: Whether to enable FSDP FP8 all-gather. Only
            applied for tensorwise scaling; a warning is logged if it is set together
            with another recipe.
        emulate: Use emulation instead of hardware acceleration (for testing on older
            GPUs). Honored for every recipe.

    Returns:
        The model with FP8 linear layers (modified in-place)

    Raises:
        ImportError: If torchao is not installed
        ValueError: If hardware doesn't support FP8 and emulation is disabled
    """
    if not HAVE_TORCHAO:
        raise ImportError(MISSING_TORCHAO_MSG)

    if recipe_name is not None and recipe_name != "tensorwise":
        config = Float8LinearConfig.from_recipe_name(recipe_name)
        logger.info(f"Using FP8 recipe: {recipe_name}")

        # The recipe config is built from the name alone, so these flags have no effect here.
        if enable_fsdp_float8_all_gather or force_recompute_fp8_weight_in_bwd:
            logger.warning(
                "FP8 recipe '%s' ignores enable_fsdp_float8_all_gather=%s and "
                "force_recompute_fp8_weight_in_bwd=%s; these flags are only applied to tensorwise scaling.",
                recipe_name,
                enable_fsdp_float8_all_gather,
                force_recompute_fp8_weight_in_bwd,
            )

        # The recipe config is created without the caller's `emulate` flag; apply it explicitly so
        # emulation is actually used and the hardware check below does not reject it.
        # Assumes Float8LinearConfig is a dataclass (torchao defines it as a frozen dataclass).
        if emulate and not getattr(config, "emulate", False):
            import dataclasses

            config = dataclasses.replace(config, emulate=True)

        # Enable inductor precision cast emulation for rowwise recipe
        if recipe_name == "rowwise":
            # Import the submodule explicitly rather than relying on attribute access on `torch`.
            from torch._inductor import config as inductor_config

            inductor_config.emulate_precision_casts = True
            logger.debug("Enabled torch._inductor.config.emulate_precision_casts for rowwise recipe")
    else:
        # Manual configuration for tensorwise scaling
        config = Float8LinearConfig(
            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
            force_recompute_fp8_weight_in_bwd=force_recompute_fp8_weight_in_bwd,
            emulate=emulate,
        )
        logger.info("Using FP8 tensorwise scaling")

    # Hardware FP8 needs compute capability >= 8.9 unless emulation is enabled on the config.
    config_emulate = getattr(config, "emulate", emulate)
    if not _has_cuda_capability(8, 9) and not config_emulate:
        raise ValueError(
            "FP8 is only supported on SM89 or later GPUs (H100+). "
            "To enable testing on older hardware, set emulate=True in Float8LinearConfig or pass emulate=True."
        )

    if filter_fqns is None:
        filter_fqns = []
    filter_fn = partial(_module_filter_fn, filter_fqns=filter_fqns)

    # Convert model to use FP8 linear layers (in place)
    convert_to_float8_training(
        model,
        config=config,
        module_filter_fn=filter_fn,
    )

    logger.info(
        f"Successfully converted model to FP8 with torchAO, recipe: {recipe_name or 'tensorwise'}, "
        f"fp8 all-gather enabled: {config.enable_fsdp_float8_all_gather}, "
        f"force recompute FP8 weight in backward pass: {config.force_recompute_fp8_weight_in_bwd}"
    )
    verify_fp8_conversion(model)
    return model
def verify_fp8_conversion(model: nn.Module) -> dict:
    """
    Verify that FP8 conversion was successful by counting converted modules.

    Args:
        model: The model to verify

    Returns:
        Dict with conversion statistics:
            - "linear_count": number of ``nn.Linear`` instances (including subclasses)
            - "fp8_count": number of modules identified as ``Float8Linear``
            - "conversion_rate": ``fp8_count / linear_count * 100``, or ``0.0`` when
              the model has no linear modules
            - "fp8_modules": one ``{"name", "type", "weight_shape"}`` record per
              ``Float8Linear`` module found
            - "success": True if at least one ``Float8Linear`` module was found
    """
    try:
        from torchao.float8.float8_linear import Float8Linear
    except ImportError:
        # torchao is optional; fall back to the class-name check below instead of failing.
        Float8Linear = None

    total_linear = 0
    fp8_modules = []
    for name, module in model.named_modules():
        module_type = type(module).__name__

        # Count both nn.Linear and Float8Linear as linear layers
        # (relies on Float8Linear being an nn.Linear subclass).
        if isinstance(module, nn.Linear):
            total_linear += 1
            logger.debug("Found nn.Linear: %s (%s)", name, module_type)

        is_fp8_instance = Float8Linear is not None and isinstance(module, Float8Linear)
        # Fallback: match by class name in case isinstance is unavailable or fails.
        if not is_fp8_instance and module_type != "Float8Linear":
            continue

        fp8_modules.append(
            {
                "name": name,
                "type": module_type,
                "weight_shape": list(module.weight.shape) if hasattr(module, "weight") else None,
            }
        )
        if is_fp8_instance:
            logger.debug("Found Float8Linear: %s (%s)", name, module_type)
        else:
            logger.debug("Found Float8Linear by name: %s (%s)", name, module_type)

    fp8_count = len(fp8_modules)
    logger.info(f"FP8 conversion: {fp8_count} Float8Linear modules, {total_linear} total linear modules")
    return {
        "linear_count": total_linear,
        "fp8_count": fp8_count,
        # Always a float, including the empty-model case.
        "conversion_rate": (fp8_count / total_linear * 100) if total_linear > 0 else 0.0,
        "fp8_modules": fp8_modules,
        "success": fp8_count > 0,
    }