TF32 Q^T L Q precision loss: standalone demo

Build a symmetric matrix L with a realistic wide-dynamic-range spectrum (top eigenvalues ~2.5×10⁶, smallest 10⁻¹, a span of more than 7 orders of magnitude). Because L is assembled from a known eigendecomposition L = Q diag(λ) Q^T, we can try to recover λ by computing diag(Q^T L Q) under different matmul precisions and compare the result against the ground truth.

import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")  # uncomment to force CPU
dtype = torch.float32
print(f"device={device}, dtype={dtype}")
if device.type == "cuda":
    print(f"  cuBLAS TF32 matmul:  {torch.backends.cuda.matmul.allow_tf32}")
    print(f"  global float32 matmul precision: {torch.get_float32_matmul_precision()}")
else:
    print("  (running on CPU — fp32 'high' and 'highest' are identical, so this notebook won't show a precision difference unless run on CUDA.)")

device=cuda, dtype=torch.float32
  cuBLAS TF32 matmul:  False
  global float32 matmul precision: highest

Construct L with a realistic SOAP-like spectrum

Pick m = 128 and assign true eigenvalues:

  • Top 5 nearly degenerate, stepping down by 2% of the top value from 2.5e6 to 2.3e6 (mimicking KL-Shampoo’s near-isotropic dominant subspace).

  • The remaining 123 decay log-uniformly from 2.25e6 down to 0.1 (the tail).

Then generate a random orthogonal Q via QR of a Gaussian matrix, and form L = Q diag(λ) Q^T. We build everything in fp64 for a clean ground truth, then cast to fp32 for the test. By construction, diag(Q^T L Q) equals λ in exact arithmetic, so any deviation is floating-point error: the fp32 cast of L and Q, plus the precision of the matmul itself.
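The identity this test relies on can be written out explicitly. It uses only the orthogonality of Q (Q^T Q = I); the symbol q_i for the i-th column of Q is notation introduced here for convenience.

```latex
\begin{aligned}
Q^\top L\, Q
  &= Q^\top \bigl( Q \operatorname{diag}(\lambda)\, Q^\top \bigr) Q
   = (Q^\top Q)\, \operatorname{diag}(\lambda)\, (Q^\top Q)
   = \operatorname{diag}(\lambda), \\
\bigl[ \operatorname{diag}(Q^\top L\, Q) \bigr]_i
  &= q_i^\top L\, q_i = \lambda_i .
\end{aligned}
```

Each recovered entry is therefore a Rayleigh quotient of a unit-norm eigenvector, which is nonnegative in exact arithmetic whenever all λ_i are positive, as they are here.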

m = 128

# True eigenvalues: top 5 near-degenerate around 2.5e6, tail decays log-uniformly to 0.1.
true_eigvals_f64 = torch.empty(m, dtype=torch.float64)
top = 2.5e6
true_eigvals_f64[:5] = torch.tensor([top, top * 0.98, top * 0.96, top * 0.94, top * 0.92], dtype=torch.float64)
true_eigvals_f64[5:] = torch.logspace(np.log10(top * 0.90), -1.0, m - 5, dtype=torch.float64)

# Random orthogonal Q via QR of a Gaussian matrix (built in fp64 for cleanliness).
g = torch.Generator().manual_seed(0)
A = torch.randn(m, m, generator=g, dtype=torch.float64)
Q_f64, _ = torch.linalg.qr(A)

# Assemble L = Q diag(λ) Q^T in fp64, symmetrize, then cast to fp32 on device.
L_f64 = Q_f64 @ torch.diag(true_eigvals_f64) @ Q_f64.T
L_f64 = (L_f64 + L_f64.T) / 2

L = L_f64.to(dtype).to(device)
Q = Q_f64.to(dtype).to(device)
true_eigvals = true_eigvals_f64.to(dtype)

print(f"L: shape={tuple(L.shape)}, dtype={L.dtype}, device={L.device}")
print("Spectrum:")
print(f"  top 5     : {true_eigvals[:5].tolist()}")
print(f"  bottom 5  : {true_eigvals[-5:].tolist()}")
print(f"  dynamic range: {(true_eigvals[0] / true_eigvals[-1]).item():.2e}")

L: shape=(128, 128), dtype=torch.float32, device=cuda:0
Spectrum:
  top 5     : [2500000.0, 2450000.0, 2400000.0, 2350000.0, 2300000.0]
  bottom 5  : [0.1742028146982193, 0.1516321748495102, 0.13198591768741608, 0.1148851215839386, 0.10000000149011612]
  dynamic range: 2.50e+07

Compute diag(Q^T L Q) under different matmul precisions

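PyTorch controls fp32 matmul precision through torch.set_float32_matmul_precision (on CUDA, the older torch.backends.cuda.matmul.allow_tf32 flag toggles the same behavior). We compare two settings:

  • torch.set_float32_matmul_precision("highest"): full fp32 (23 explicit mantissa bits, 24-bit significand).

  • torch.set_float32_matmul_precision("high"): TF32 on Ampere+ CUDA (10 explicit mantissa bits, 11-bit significand); identical to "highest" on CPU.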

Run on CUDA to see the TF32 precision loss. On CPU both runs produce the same output.
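To get a feel for how coarse a 10-bit mantissa is at this scale, the sketch below emulates TF32 input rounding in pure Python by dropping the low 13 mantissa bits of an fp32 value. The helper `round_to_tf32` is hypothetical (it is not a PyTorch API), and round-to-nearest-even is an assumption; the rounding mode used by actual tensor-core hardware may differ.

```python
import struct


def round_to_tf32(x: float) -> float:
    """Emulate TF32 input rounding: an fp32 value with only 10 explicit mantissa bits.

    Hypothetical helper for illustration; round-to-nearest-even is assumed.
    """
    # Reinterpret the fp32 encoding of x as a 32-bit unsigned integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    drop = 13  # fp32 has 23 explicit mantissa bits; TF32 keeps the top 10
    lsb = (bits >> drop) & 1
    bits += (1 << (drop - 1)) - 1 + lsb      # round to nearest, ties to even
    bits &= 0xFFFFFFFF ^ ((1 << drop) - 1)   # clear the dropped bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


top = 2.5e6
top_tf32 = round_to_tf32(top)
print(f"TF32({top:g}) = {top_tf32:.1f}, abs error = {abs(top_tf32 - top):g}")

# Unit roundoff (maximum relative rounding error) of each format.
u_fp32 = 2.0 ** -24
u_tf32 = 2.0 ** -11
print(f"worst-case rounding error near {top:g}: "
      f"fp32 ~{u_fp32 * top:.3g}, TF32 ~{u_tf32 * top:.3g}")
```

Under these assumptions 2.5e6 rounds to 2500608, an absolute error of 608, and the worst-case rounding error for a value near the top eigenvalue is about 1.2e3 in TF32 versus about 0.15 in fp32. Individual entries of L are smaller than the top eigenvalue, so this is an upper scale rather than a prediction of the observed error, but it shows that TF32 rounding noise can sit several orders of magnitude above the ~0.1 tail eigenvalues.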

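def diag_QtLQ(L_in: torch.Tensor, Q_in: torch.Tensor) -> torch.Tensor:
    """Compute diag(Q.T @ L @ Q) without materializing the full product (matches SOAP's `eig_utils.conjugate(..., diag=True)`)."""
    # The only matmul in this function, and so the only step the TF32 setting can affect.
    QtL = Q_in.T @ L_in
    # diag(QtL @ Q)[i] = sum_j QtL[i, j] * Q[j, i]: an elementwise product followed by a row sum.
    return (QtL * Q_in.T).sum(dim=-1)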


def run_at_precision(prec: str) -> torch.Tensor:
    """Run the diag(Q.T L Q) computation under a particular matmul precision setting."""
    prev_global = torch.get_float32_matmul_precision()
    prev_tf32 = torch.backends.cuda.matmul.allow_tf32
    if prec == "highest":
        # Belt-and-suspenders: also flip the cuBLAS knob to ensure no TF32 on CUDA.
        torch.backends.cuda.matmul.allow_tf32 = False
    torch.set_float32_matmul_precision(prec)
    try:
        return diag_QtLQ(L, Q).cpu()
    finally:
        torch.set_float32_matmul_precision(prev_global)
        torch.backends.cuda.matmul.allow_tf32 = prev_tf32


eigvals_highest = run_at_precision("highest")  # full fp32
eigvals_high = run_at_precision("high")        # TF32 on CUDA, fp32 on CPU

print(f"{'idx':>4}  {'true':>14}  {'fp32 (highest)':>16}  {'fp32 (high/TF32)':>18}")
print("-" * 60)
for i in [0, 1, 4, 20, 50, 80, 100, 115, 120, 124, 125, 126, 127]:
    print(
        f"  {i:>3}  {true_eigvals[i].item():>14.5g}  {eigvals_highest[i].item():>16.5g}  "
        f"{eigvals_high[i].item():>18.5g}"
    )

 idx            true    fp32 (highest)    fp32 (high/TF32)
------------------------------------------------------------
    0         2.5e+06           2.5e+06          2.4999e+06
    1        2.45e+06          2.45e+06          2.4501e+06
    4         2.3e+06           2.3e+06          2.2999e+06
   20      2.8069e+05        2.8069e+05          2.8067e+05
   50          4368.3            4368.3              4366.8
   80          67.983            67.979              75.758
  100          4.2376            4.2453             -5.1326
  115         0.52865           0.52618               16.15
  120         0.26415           0.26051              4.6133
  124         0.15163           0.15539             -12.385
  125         0.13199           0.13418             -1.5987
  126         0.11489           0.11431             -27.317
  127             0.1           0.10366              -6.226
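The table is easier to read as errors than as raw values. The sketch below recomputes absolute and relative errors from the printed rows only; the numbers are copied from the output above, so they carry just the 5 significant digits that were printed, and the errors for the largest eigenvalues are limited by that print precision. The helper name `errors` is introduced here for illustration.

```python
# (idx, true, fp32 "highest", TF32 "high"), copied from the printed table.
rows = [
    (0, 2.5e6, 2.5e6, 2.4999e6),
    (1, 2.45e6, 2.45e6, 2.4501e6),
    (4, 2.3e6, 2.3e6, 2.2999e6),
    (20, 2.8069e5, 2.8069e5, 2.8067e5),
    (50, 4368.3, 4368.3, 4366.8),
    (80, 67.983, 67.979, 75.758),
    (100, 4.2376, 4.2453, -5.1326),
    (115, 0.52865, 0.52618, 16.15),
    (120, 0.26415, 0.26051, 4.6133),
    (124, 0.15163, 0.15539, -12.385),
    (125, 0.13199, 0.13418, -1.5987),
    (126, 0.11489, 0.11431, -27.317),
    (127, 0.1, 0.10366, -6.226),
]


def errors(true: float, approx: float) -> tuple[float, float]:
    """Return (absolute error, relative error) of approx against true."""
    abs_err = abs(approx - true)
    return abs_err, abs_err / true


print(f"{'idx':>4}  {'abs fp32':>10}  {'rel fp32':>10}  {'abs TF32':>10}  {'rel TF32':>10}")
for idx, true, fp32_val, tf32_val in rows:
    abs_fp32, rel_fp32 = errors(true, fp32_val)
    abs_tf32, rel_tf32 = errors(true, tf32_val)
    print(f"{idx:>4}  {abs_fp32:>10.3g}  {rel_fp32:>10.2e}  {abs_tf32:>10.3g}  {rel_tf32:>10.2e}")

max_rel_fp32 = max(errors(t, a)[1] for _, t, a, _ in rows)
max_rel_tf32 = max(errors(t, a)[1] for _, t, _, a in rows)
n_negative_tf32 = sum(1 for row in rows if row[3] < 0)
print(f"max relative error: fp32 {max_rel_fp32:.2e}, TF32 {max_rel_tf32:.2e}")
print(f"negative TF32 estimates: {n_negative_tf32} of {len(rows)}")
```

On the printed values, full fp32 stays within about 4% of the true eigenvalue even at the bottom of the spectrum. The TF32 estimates in the tail are off by roughly 2 to 27 in absolute terms, up to about 240 times the eigenvalue itself, and five of the sampled estimates are negative even though L is positive definite by construction.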