Mixed Precision Training#

Mixed precision training improves computational efficiency by performing operations in low-precision formats while selectively keeping critical data in single precision. NeMo supports FP16 and BF16 precision via PyTorch Lightning, in both mixed and true half-precision modes.

Precision Modes#

PyTorch Lightning provides two categories of half-precision training:

Mixed Precision ("bf16-mixed" / "16-mixed"):

Operations run in half-precision where safe, but model weights are kept in FP32. Gradients are computed in half-precision and accumulated in FP32. This is the safest option and generally a good default for ASR and TTS training.

True Half Precision ("bf16-true" / "fp16-true"):

The entire model – weights, activations, and gradients – runs in half-precision. This uses less memory than mixed precision (no FP32 weight copy) and is faster, but requires the model to be numerically stable in half-precision. SpeechLM2 models use "bf16-true" by default for training.
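The difference between the two modes can be illustrated with plain PyTorch (a conceptual sketch using a CPU autocast context, not NeMo code):

```python
import torch
import torch.nn as nn

x = torch.randn(4, 16)

# Mixed precision: parameters stay in FP32; autocast runs the matmul in BF16
# on the fly, so outputs (and gradients) come out in half-precision.
mixed_layer = nn.Linear(16, 16)  # parameters are created in FP32
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    mixed_out = mixed_layer(x)

# True half precision: the parameters themselves are cast to BF16,
# so no FP32 copy of the weights is kept in memory.
true_layer = nn.Linear(16, 16).bfloat16()
true_out = true_layer(x.bfloat16())
```

Inspecting the dtypes confirms the distinction: in mixed mode the weights remain `torch.float32` while the output is `torch.bfloat16`; in true half mode both weights and output are `torch.bfloat16`.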

Configuration#

Set precision through the PyTorch Lightning trainer’s precision argument.

In YAML (with Hydra):

trainer:
    precision: "bf16-mixed"    # BF16 mixed precision
    # precision: "16-mixed"    # FP16 mixed precision
    # precision: "bf16-true"   # True BF16 half precision
    # precision: "fp16-true"   # True FP16 half precision

In Python:

import lightning.pytorch as pl

trainer = pl.Trainer(
    precision="bf16-mixed",
    devices=2,
    accelerator="gpu",
)

Choosing a Precision Format#

  • BF16 has the same dynamic range as FP32, which makes it more numerically stable and generally easier to use. It is the recommended choice for most Speech AI training workloads.

  • FP16 offers slightly higher throughput on some hardware but has a much smaller dynamic range, so large activations can overflow and small gradients can underflow to zero. In mixed precision mode, PyTorch Lightning applies loss scaling automatically to compensate.
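The dynamic-range difference, and the underflow problem that loss scaling addresses, can be demonstrated directly (illustrative values; the 1024 scale factor is arbitrary):

```python
import torch

# FP16's largest finite value is 65504, so FP32-scale activations overflow.
overflow = torch.tensor(1e5, dtype=torch.float16)   # becomes inf
# BF16 shares FP32's 8-bit exponent, so the same magnitude is representable.
ok = torch.tensor(1e5, dtype=torch.bfloat16)

# Loss scaling (applied automatically by Lightning in "16-mixed"):
# a tiny gradient underflows to zero in FP16 unless it is scaled up first,
# then unscaled in full precision before the optimizer step.
tiny = torch.tensor(1e-8)
lost = tiny.half()                                  # underflows to 0
scale = 1024.0
kept = (tiny * scale).half().float() / scale        # survives the round-trip
```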

HalfPrecisionForAudio#

Audio waveform tensors are sensitive to precision loss – downcasting raw audio samples to half-precision can degrade signal quality and hurt model accuracy. NeMo provides the HalfPrecisionForAudio plugin (in nemo.utils.trainer_utils), which extends Lightning’s HalfPrecision plugin to keep audio tensors in full precision while casting all other inputs to half-precision.

Specifically, when the training mini-batch is a dictionary, any tensor whose key contains the substring "audio" is kept in its original precision (typically FP32). All other floating-point tensors are cast to the target half-precision dtype.
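The key-based rule can be sketched as a standalone helper (illustrative only; the actual logic lives inside the plugin):

```python
import torch

def cast_batch_keep_audio(batch: dict, dtype: torch.dtype = torch.bfloat16) -> dict:
    """Cast floating-point tensors to the half-precision dtype, but keep any
    tensor whose key contains "audio" in its original precision.
    (Illustrative sketch, not NeMo's actual implementation.)"""
    out = {}
    for key, value in batch.items():
        if torch.is_tensor(value) and value.is_floating_point() and "audio" not in key:
            out[key] = value.to(dtype)
        else:
            out[key] = value
    return out

batch = {
    "audio": torch.randn(2, 16000),              # raw waveform: stays FP32
    "audio_lens": torch.tensor([16000, 16000]),  # integer tensor: untouched
    "features": torch.randn(2, 80),              # cast to BF16
}
casted = cast_batch_keep_audio(batch)
```

Only "features" is downcast here; the waveform keeps its FP32 samples, and integer tensors (lengths, token IDs) are never floating-point and pass through unchanged.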

This plugin is used automatically when you launch training with NeMo’s resolve_trainer_cfg utility (used by all NeMo example training scripts). When the trainer config specifies precision: "bf16-true" or precision: "fp16-true", resolve_trainer_cfg replaces the precision setting with the HalfPrecisionForAudio plugin:

from nemo.utils.trainer_utils import resolve_trainer_cfg

# In YAML: trainer.precision = "bf16-true"
# resolve_trainer_cfg automatically installs HalfPrecisionForAudio
trainer = pl.Trainer(**resolve_trainer_cfg(cfg.trainer))

If you construct the trainer manually, you can install the plugin directly:

from nemo.utils.trainer_utils import HalfPrecisionForAudio

trainer = pl.Trainer(
    plugins=[HalfPrecisionForAudio("bf16-true")],
    devices=2,
    accelerator="gpu",
)