FP8 Training in NeMo AutoModel

NeMo AutoModel supports FP8 quantization using TorchAO and torch.compile to accelerate training on compatible hardware.

FP8 (8-bit floating point) quantization can provide substantial speedups for models in which most GEMMs (general matrix multiplications) are large, because the speedup from FP8 tensor cores must outweigh the overhead of dynamic quantization.

Requirements for FP8 Training in NeMo AutoModel

To enable FP8 training in NeMo AutoModel, the following hardware and software requirements must be met:

  • Hardware:
    An NVIDIA H100 GPU or newer is required. These GPUs feature FP8 tensor cores that accelerate training.

  • Software:
    The TorchAO library must be installed.

  • Configuration:
    Both torch.compile and fp8 must be enabled in your training configuration.
    Important: torch.compile is essential for achieving meaningful speedup with TorchAO FP8 training.
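
The three requirements above can be checked before a job is launched. The following is a minimal sketch, not part of NeMo AutoModel: fp8_preflight is a hypothetical helper, the training config is assumed to be the parsed YAML as a plain dict, and compute capability (9, 0) is used as the floor because that is what an H100 reports.

```python
import importlib.util


def fp8_preflight(cfg, capability=None, torchao_installed=None):
    """Hypothetical check: return a list of unmet FP8 requirements (empty means OK).

    cfg is the training config parsed into a dict. capability and
    torchao_installed can be passed explicitly; otherwise they are detected.
    """
    problems = []

    # Software: TorchAO must be importable.
    if torchao_installed is None:
        torchao_installed = importlib.util.find_spec("torchao") is not None
    if not torchao_installed:
        problems.append("TorchAO is not installed")

    # Hardware: query the current GPU if torch and CUDA are available.
    if capability is None:
        try:
            import torch
            if torch.cuda.is_available():
                capability = torch.cuda.get_device_capability()
        except ImportError:
            pass
    # Assumption: H100 reports compute capability (9, 0), so use it as the floor.
    if capability is None or tuple(capability) < (9, 0):
        problems.append(f"GPU compute capability {capability} is below H100 (9, 0)")

    # Configuration: both compile and fp8 must be enabled.
    if not cfg.get("compile", {}).get("enabled", False):
        problems.append("compile.enabled must be true")
    if not cfg.get("fp8", {}).get("enabled", False):
        problems.append("fp8.enabled must be true")
    return problems


example_cfg = {"compile": {"enabled": True}, "fp8": {"enabled": True}}
print(fp8_preflight(example_cfg, capability=(9, 0), torchao_installed=True))  # []
```

Passing capability and torchao_installed explicitly keeps the function testable on machines without a GPU; in a real launcher you would omit them and let the function detect both.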

Install TorchAO

Install TorchAO by following the TorchAO installation guide.

Usage

Configure FP8

To enable FP8 quantization with torch.compile, you need both FP8 and compilation enabled in your configuration:

```yaml
# Enable torch.compile (required for FP8 speedup)
compile:
  enabled: true
  mode: "default"
  fullgraph: false
  dynamic: false

# Enable FP8 quantization
fp8:
  enabled: true
  recipe_name: tensorwise
  enable_fsdp_float8_all_gather: true
  precompute_float8_dynamic_scale_for_fsdp: true
  force_recompute_fp8_weight_in_bwd: true
  filter_fqns: ["lm_head"]
  emulate: false
```
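
A config like this ultimately has to be translated into TorchAO calls. The sketch below shows one plausible translation; it is not NeMo AutoModel's code. build_fp8_kwargs and apply_fp8 are hypothetical helpers, while Float8LinearConfig, Float8LinearConfig.from_recipe_name, convert_to_float8_training (all from torchao.float8), and torch.compile are existing TorchAO and PyTorch APIs whose exact signatures may vary between releases. force_recompute_fp8_weight_in_bwd is left out for brevity. TorchAO also provides precompute_float8_dynamic_scale_for_fsdp(model), which a training loop would call after each optimizer step when FP8 all-gather is on; the sketch leaves that to the loop.

```python
def build_fp8_kwargs(fp8_cfg):
    """Hypothetical helper: normalize a parsed `fp8:` section into plain options."""
    return {
        # Assumption: a missing recipe falls back to tensorwise (the "Default" strategy).
        "recipe_name": fp8_cfg.get("recipe_name") or "tensorwise",
        "enable_fsdp_float8_all_gather": bool(fp8_cfg.get("enable_fsdp_float8_all_gather", False)),
        "emulate": bool(fp8_cfg.get("emulate", False)),
        "filter_fqns": list(fp8_cfg.get("filter_fqns", [])),
    }


def apply_fp8(model, fp8_cfg, compile_cfg):
    """Hypothetical helper: swap eligible nn.Linear layers to FP8, then compile.

    Requires torch and torchao; they are imported here so this module
    loads without them.
    """
    import torch
    from torchao.float8 import Float8LinearConfig, convert_to_float8_training

    opts = build_fp8_kwargs(fp8_cfg)
    if opts["recipe_name"] == "tensorwise":
        # FP8 all-gather is a tensorwise-scaling feature, so this sketch
        # forwards the extra flags only for this recipe.
        config = Float8LinearConfig(
            enable_fsdp_float8_all_gather=opts["enable_fsdp_float8_all_gather"],
            emulate=opts["emulate"],
        )
    else:
        config = Float8LinearConfig.from_recipe_name(opts["recipe_name"])

    skip = opts["filter_fqns"]
    convert_to_float8_training(
        model,
        config=config,
        # Keep a module in high precision if its FQN contains any filter entry.
        module_filter_fn=lambda mod, fqn: not any(s in fqn for s in skip),
    )

    if compile_cfg.get("enabled", False):
        model = torch.compile(
            model,
            mode=compile_cfg.get("mode", "default"),
            fullgraph=compile_cfg.get("fullgraph", False),
            dynamic=compile_cfg.get("dynamic", False),
        )
    return model


print(build_fp8_kwargs({"recipe_name": "tensorwise", "filter_fqns": ["lm_head"]}))
```

The conversion runs before torch.compile so that the compiler sees the FP8 layers and can fuse the dynamic-quantization steps around them, which is why compilation matters for the speedup.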

FP8 Config Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| recipe_name | str | None | FP8 recipe: "tensorwise", "rowwise", or "rowwise_with_gw_hp" |
| enable_fsdp_float8_all_gather | bool | False | Enable FP8 all-gather in FSDP to save communication bandwidth |
| force_recompute_fp8_weight_in_bwd | bool | False | Force recomputation of FP8 weights in the backward pass |
| precompute_float8_dynamic_scale_for_fsdp | bool | False | Precompute FP8 dynamic scales for FSDP |
| filter_fqns | list[str] | [] | Fully qualified names of modules to exclude from FP8 conversion |
| emulate | bool | False | Use emulation instead of hardware acceleration |
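
To catch typos such as an unknown recipe name before training starts, the fp8 section can be validated against the parameters above. This is a hypothetical sketch, not NeMo AutoModel's validation logic: field names follow the fp8 YAML example earlier on this page, types and defaults follow the table, and keys not listed there (such as enabled) are simply ignored.

```python
# Field names follow the fp8: YAML example; types and defaults follow the table.
ALLOWED_RECIPES = {"tensorwise", "rowwise", "rowwise_with_gw_hp"}

FP8_FIELDS = {
    "recipe_name": (str, None),
    "enable_fsdp_float8_all_gather": (bool, False),
    "force_recompute_fp8_weight_in_bwd": (bool, False),
    "precompute_float8_dynamic_scale_for_fsdp": (bool, False),
    "filter_fqns": (list, []),
    "emulate": (bool, False),
}


def resolve_fp8_section(section):
    """Hypothetical validator: fill defaults, type-check fields, check recipe_name."""
    resolved = {}
    for name, (expected, default) in FP8_FIELDS.items():
        value = section.get(name, default)
        if value is not None and not isinstance(value, expected):
            raise TypeError(f"{name} must be {expected.__name__}, got {type(value).__name__}")
        # Copy lists so the shared default [] is never mutated by callers.
        resolved[name] = list(value) if isinstance(value, list) else value
    recipe = resolved["recipe_name"]
    if recipe is not None and recipe not in ALLOWED_RECIPES:
        raise ValueError(f"unknown recipe_name {recipe!r}; expected one of {sorted(ALLOWED_RECIPES)}")
    return resolved


print(resolve_fp8_section({"recipe_name": "rowwise", "filter_fqns": ["lm_head"]}))
```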

Scaling Strategies

Tensorwise Scaling (Default)

  • Single scale per tensor
  • Good performance, moderate accuracy
  • Recommended for most use cases

Rowwise Scaling

  • Scale per row for better accuracy
  • Slower than tensorwise
  • Better numerical stability

For more on scaling strategies, refer to the TorchAO FP8 documentation.
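
The accuracy difference between the two strategies comes from how many scale factors are computed. The sketch below is a simplified illustration, assuming the float8 E4M3 format (largest finite value 448) and the common dynamic-scaling rule scale = 448 / amax; real TorchAO kernels add details that are omitted here.

```python
E4M3_MAX = 448.0  # largest finite float8 E4M3 value
EPS = 1e-12       # guards against division by zero for all-zero rows


def tensorwise_scales(rows):
    """One scale for the whole tensor, derived from the global absolute max."""
    amax = max(abs(v) for row in rows for v in row)
    return [E4M3_MAX / max(amax, EPS)]


def rowwise_scales(rows):
    """One scale per row, derived from each row's own absolute max."""
    return [E4M3_MAX / max(max(abs(v) for v in row), EPS) for row in rows]


x = [
    [1000.0, 2.0],  # row with a large outlier
    [0.01, 0.02],   # row of small values
]

t_scale = tensorwise_scales(x)[0]
print("tensorwise:", tensorwise_scales(x))
print("rowwise:   ", rowwise_scales(x))
# Second row after tensorwise scaling, just before the FP8 cast.
print("row 1 scaled (tensorwise):", [v * t_scale for v in x[1]])
```

With one tensor-wide scale, the outlier forces the small row down to roughly 0.0045 and 0.009, below E4M3's smallest normal value (2**-6 = 0.015625), so those values keep fewer significant bits. Rowwise scaling maps the same row to about 224 and 448, using the full FP8 range, at the cost of computing and applying one scale per row.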

Filter Modules

You can exclude specific modules from FP8 conversion using filter_fqns:

```yaml
fp8:
  enabled: true
  recipe_name: tensorwise
  filter_fqns: ["lm_head"]  # Skip these modules
```
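
TorchAO's convert_to_float8_training accepts a module_filter_fn that receives a module and its fully qualified name (FQN) and returns whether to convert it. The sketch below builds such a predicate from a filter_fqns list. It is hypothetical: make_fp8_filter is not a NeMo AutoModel function, and it assumes an entry matches when it appears anywhere in the FQN, which may differ from NeMo AutoModel's actual matching rule. It also skips layers whose dimensions are not divisible by 16, per the requirement listed under Performance Considerations.

```python
from types import SimpleNamespace


def make_fp8_filter(filter_fqns):
    """Hypothetical: build a (module, fqn) -> bool predicate for FP8 conversion."""

    def should_convert(module, fqn):
        # Assumption: an entry matches if it appears anywhere in the FQN.
        if any(skip in fqn for skip in filter_fqns):
            return False
        in_f = getattr(module, "in_features", None)
        out_f = getattr(module, "out_features", None)
        if in_f is None or out_f is None:
            return False  # not a Linear-like layer
        # FP8 GEMMs need dimensions divisible by 16.
        return in_f % 16 == 0 and out_f % 16 == 0

    return should_convert


# SimpleNamespace stands in for nn.Linear so the example runs without torch.
keep = make_fp8_filter(["lm_head"])
proj = SimpleNamespace(in_features=4096, out_features=4096)
print(keep(proj, "model.layers.0.mlp.up_proj"))  # True
print(keep(proj, "lm_head"))                     # False
```

Excluding lm_head is a common choice because the output projection feeds the loss directly and is often kept in higher precision.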

Speed and Convergence

FP8 quantization provides measurable performance improvements while maintaining model convergence:

  • Speed: More than 1.2x training speedup on 8xH100 with tensorwise scaling.
  • Convergence: FP8 training achieves loss parity with BF16 training.
  • Memory: FP8 training uses memory on par with the BF16 baseline.

Figure: FP8 convergence comparison. Loss curves comparing FP8 tensorwise scaling + torch.compile vs. BF16 + torch.compile training on 8xH100 with an 8k sequence length, demonstrating virtually identical convergence behavior with a 1.24x speedup.
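
A speedup factor is easy to misread as the percentage of time saved. Assuming the reported speedup is a ratio of step times (or, equivalently, of throughput), the fraction of wall-clock time saved is 1 - 1/speedup, as this small sketch computes.

```python
def step_time_reduction(speedup):
    """Fraction of wall-clock time saved for a given speedup factor."""
    return 1.0 - 1.0 / speedup


print(f"{step_time_reduction(1.24):.1%}")  # 19.4%
print(f"{step_time_reduction(1.2):.1%}")   # 16.7%
```

So a 1.24x speedup shortens each training step by roughly a fifth, not by 24%.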

Ready-to-Use Recipes

We provide FP8 training configs for popular models:

Check out our examples directory for more recipes and configurations.

To run any of these FP8 training recipes, use the following command:

```bash
automodel --nproc-per-node=8 <path-to-config.yaml>
```

For example, to train Llama 3.1 8B with FP8:

```bash
automodel --nproc-per-node=8 examples/llm_finetune/llama3_1/llama3_1_8b_hellaswag_fp8.yaml
```
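
If you drive experiments from a Python script, the same command can be launched with subprocess. This is a hypothetical wrapper: it only reproduces the automodel --nproc-per-node=N <config> form shown above and assumes no other CLI flags.

```python
import shutil
import subprocess


def fp8_launch_command(config_path, nproc_per_node=8):
    """Build the argv list for the documented launch command."""
    return ["automodel", f"--nproc-per-node={nproc_per_node}", config_path]


def launch(config_path, nproc_per_node=8):
    """Run the command and return its exit code (requires `automodel` on PATH)."""
    if shutil.which("automodel") is None:
        raise FileNotFoundError("automodel CLI not found on PATH")
    cmd = fp8_launch_command(config_path, nproc_per_node)
    return subprocess.run(cmd, check=False).returncode


print(fp8_launch_command("examples/llm_finetune/llama3_1/llama3_1_8b_hellaswag_fp8.yaml"))
```

Passing an argument list rather than a shell string avoids quoting problems when config paths contain spaces.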

Performance Considerations

FP8 requires specific conditions to be effective:

  • Input tensors must have dimensions divisible by 16
  • Use compatible hardware (H100+)
  • Train with torch.compile

FP8 works best when most GEMM operations are large enough that the speedup from FP8 tensor cores exceeds the overhead of dynamic quantization.
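
A back-of-the-envelope model shows why GEMM size matters. Under the simplifying assumption that dynamic quantization touches each element of the two GEMM operands once (compute amax, scale, cast) in the forward pass, its cost grows with the number of operand elements, while an (m x k) @ (k x n) GEMM costs 2*m*k*n FLOPs. The ratio simplifies to 2*m*n / (m + n), independent of k, so it grows with the token count and output width. This is an illustrative model, not a profiler measurement.

```python
def flops_per_quantized_element(m, k, n):
    """GEMM FLOPs divided by elements quantized for an (m x k) @ (k x n) product."""
    gemm_flops = 2 * m * k * n
    quantized_elements = m * k + k * n  # activation + weight
    return gemm_flops / quantized_elements


# For square GEMMs the ratio equals the matrix size.
for size in (64, 512, 4096):
    ratio = flops_per_quantized_element(size, size, size)
    print(f"{size:>5} x {size:<5} GEMM: {ratio:>7.0f} FLOPs per quantized element")
```

Small GEMMs do little work per quantized element, so the fixed quantization cost dominates and FP8 may even be slower than BF16; large GEMMs amortize it.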

Ideal Conditions for FP8 Performance

  • Linear layers are large and compute-intensive
  • The model is dominated by large matrix multiplications rather than many small operations
  • You have modern (H100+) hardware optimized for FP8 acceleration
  • Moderate numerical precision is acceptable and slight approximations won’t affect outcomes

References