Source code for emerging_optimizers.psgd.procrustes_step
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal
import torch
import emerging_optimizers.utils as utils
from emerging_optimizers.psgd.psgd_utils import norm_lower_bound_skew
__all__ = [
"procrustes_step",
]
[docs]
@torch.compile # type: ignore[misc]
def procrustes_step(
Q: torch.Tensor, max_step_size: float = 0.125, eps: float = 1e-8, order: Literal[2, 3] = 2
) -> torch.Tensor:
r"""One step of an online solver for the orthogonal Procrustes problem.
The orthogonal Procrustes problem is :math:`\min_U \| U Q - I \|_F` s.t. :math:`U^H U = I`
by rotating Q as :math:`\exp(a R) Q`, where :math:`R = Q^H - Q` is the generator and :math:`\|a R\| < 1`.
If using 2nd order expansion, `max_step_size` should be less than :math:`1/4` as we only expand :math:`\exp(a R)`
to its 2nd order term. If using 3rd order expansion, `max_step_size` should be less than :math:`5/8`.
This method is an expansion of a Lie algebra parametrized rotation that
uses a simple approximate line search to find the optimal step size, from Xi-Lin Li.
Args:
Q: Tensor of shape (n, n), general square matrix to orthogonalize.
max_step_size: Maximum step size for the line search. Default is 1/8. (0.125)
eps: Small number for numerical stability.
order: Order of the Taylor expansion. Must be 2 or 3. Default is 2.
"""
if order not in (2, 3):
raise ValueError(f"order must be 2 or 3, got {order}")
# Note: this function is written in fp32 to avoid numerical instability while computing the expansion of the exponential map
with utils.fp32_matmul_precision("highest"):
R = Q.T - Q
R /= torch.clamp(norm_lower_bound_skew(R), min=eps)
RQ = R @ Q
# trace of RQ is always positive,
# since tr(RQ) = ⟨R, Q⟩_F = ⟨Q^T - Q, Q⟩_F = ||Q||_F^2 - ⟨Q, Q⟩_F = ||Q||_F^2 - tr(Q^T Q) ≥ 0
tr_RQ = torch.trace(RQ)
RRQ = R @ RQ
tr_RRQ = torch.trace(RRQ)
if order == 2:
# clip step size to max_step_size, based on a 2nd order expansion.
_step_size = torch.clamp(-tr_RQ / tr_RRQ, min=0, max=max_step_size)
# If tr_RRQ >= 0, the quadratic approximation is not concave, we fallback to max_step_size.
step_size = torch.where(tr_RRQ < 0, _step_size, max_step_size)
# rotate Q as exp(a R) Q ~ (I + a R + a^2 R^2/2) Q with an optimal step size by line search
# for 2nd order expansion, only expand exp(a R) to its 2nd term.
# Q += _step_size * (RQ + 0.5 * _step_size * RRQ)
Q = torch.add(Q, torch.add(RQ, RRQ, alpha=0.5 * step_size), alpha=step_size)
if order == 3:
RRRQ = R @ RRQ
tr_RRRQ = torch.trace(RRRQ)
# for a 3rd order expansion, we take the larger root of the cubic.
_step_size = (-tr_RRQ - torch.sqrt(tr_RRQ * tr_RRQ - 1.5 * tr_RQ * tr_RRRQ)) / (0.75 * tr_RRRQ)
step_size = torch.clamp(_step_size, max=max_step_size)
# Q += step_size * (RQ + 0.5 * step_size * (RRQ + 0.25 * step_size * RRRQ))
Q = torch.add(
Q, torch.add(RQ, torch.add(RRQ, RRRQ, alpha=0.25 * step_size), alpha=0.5 * step_size), alpha=step_size
)
return Q