Source code for emerging_optimizers.soap.soap_utils

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TypeAlias

import torch

from emerging_optimizers.utils import eig as eig_utils


TensorList: TypeAlias = list[torch.Tensor]


__all__ = [
    "get_eigenbasis_eigh",
    "get_eigenbasis_qr",
    "sort_eigenbasis_by_approx_eigvals",
]


def sort_eigenbasis_by_approx_eigvals(
    kronecker_factor_list: TensorList,
    eigenbasis_list: TensorList,
    exp_avg_sq: torch.Tensor,
) -> tuple[TensorList, torch.Tensor]:
    """Reorder eigenbasis columns and ``exp_avg_sq`` axes by descending approximate eigenvalues.

    For every (Kronecker factor, eigenbasis) pair, approximate eigenvalues are obtained
    from ``eig_utils.conjugate(..., diag=True)``. The resulting descending sort order is
    applied to the columns of the eigenbasis and to the axis of ``exp_avg_sq`` whose index
    equals the pair's position in the input lists.

    Both the eigh and QR eigenbasis-update paths consume the sorted output: the eigh path
    discards the permuted eigenbasis but uses the permuted ``exp_avg_sq`` (the sort index
    best approximates the permutation component of the eigh-vs-old basis change under
    small drift); the QR path power-iterates from the pre-sorted basis.

    Args:
        kronecker_factor_list: List of preconditioner matrices (L and R).
        eigenbasis_list: List of current eigenbases (QL and QR). Must have the same
            length as ``kronecker_factor_list``.
        exp_avg_sq: Inner Adam second moment tensor. It is not modified in place; a
            permuted tensor is returned instead.

    Returns:
        ``(sorted_eigenbasis_list, sorted_exp_avg_sq)``: the column-permuted eigenbases,
        and ``exp_avg_sq`` permuted along each Kronecker-factor axis to match the new
        descending-eigenvalue column ordering.

    Raises:
        ValueError: If the two input lists differ in length (``zip(..., strict=True)``).
    """
    permuted_bases: TensorList = []
    permuted_second_moment = exp_avg_sq

    factor_basis_pairs = zip(kronecker_factor_list, eigenbasis_list, strict=True)
    for axis, (factor, basis) in enumerate(factor_basis_pairs):
        # Approximate eigenvalues of ``factor`` expressed in the current basis.
        eigval_estimates = eig_utils.conjugate(factor, basis, diag=True)
        order = torch.argsort(eigval_estimates, descending=True)

        # Apply the same permutation to the basis columns and to the matching axis of
        # the second-moment tensor so the two stay aligned.
        permuted_bases.append(basis[:, order])
        permuted_second_moment = permuted_second_moment.index_select(axis, order)

    return permuted_bases, permuted_second_moment
def get_eigenbasis_eigh(
    kronecker_factor_list: TensorList,
) -> TensorList:
    """Compute a fresh eigenbasis for every Kronecker factor via an eigh decomposition.

    Each factor is passed to ``eig_utils.eigh_with_fallback`` with ``force_double=False``;
    the eigenvalues are dropped and only the eigenvectors are kept.

    Args:
        kronecker_factor_list: List of matrices to compute eigenbases of.

    Returns:
        List of orthonormal Kronecker-factor eigenbasis matrices, in the same order as
        the input list.

    Example:
        .. code-block:: python

            # Create sample Kronecker factors (symmetric positive definite matrices)
            k_factor1 = torch.randn(4, 4)
            k_factor1 = k_factor1 @ k_factor1.T  # Make symmetric positive definite
            k_factor2 = torch.randn(5, 5)
            k_factor2 = k_factor2 @ k_factor2.T  # Make symmetric positive definite

            # Get orthogonal matrices for these factors
            ortho_matrices = get_eigenbasis_eigh([k_factor1, k_factor2])
            # ortho_matrices[0] has shape [4, 4] and ortho_matrices[1] has shape [5, 5]
    """
    # Lazy generator: each decomposition runs only when the comprehension below pulls it,
    # so factors are processed one at a time in input order.
    decompositions = (
        eig_utils.eigh_with_fallback(factor, force_double=False)
        for factor in kronecker_factor_list
    )
    return [eigenvectors for _, eigenvectors in decompositions]
def get_eigenbasis_qr(
    kronecker_factor_list: TensorList,
    eigenbasis_list: TensorList,
    power_iter_steps: int = 1,
) -> TensorList:
    """Refine the preconditioner eigenbases with power iteration followed by QR.

    Every (Kronecker factor, eigenbasis) pair is handed to
    ``eig_utils.orthogonal_iteration`` (orthogonal iteration: rounds of power iteration
    followed by a QR decomposition). ``eigenbasis_list`` is expected to be already sorted
    by descending approximate eigenvalues (see :func:`sort_eigenbasis_by_approx_eigvals`).

    Args:
        kronecker_factor_list: List of preconditioner matrices (L and R).
        eigenbasis_list: List of current eigenbases (QL and QR). Must have the same
            length as ``kronecker_factor_list``.
        power_iter_steps: Number of power iteration steps to perform before QR
            decomposition. More steps can lead to better convergence but increased
            computation time.

    Returns:
        Updated list of orthonormal eigenbases (QL and QR), in input order.

    Raises:
        ValueError: If the two input lists differ in length (``zip(..., strict=True)``).

    Example:
        .. code-block:: python

            # Create sample Kronecker factors (symmetric positive definite matrices)
            n, m = 10, 20
            k_factor1 = torch.randn(n, n)
            k_factor1 = k_factor1 @ k_factor1.T  # Make symmetric positive definite
            k_factor2 = torch.randn(m, m)
            k_factor2 = k_factor2 @ k_factor2.T  # Make symmetric positive definite

            # Get orthogonal matrices for these kronecker factors
            kronecker_factor_list = [k_factor1, k_factor2]
            eigenbasis_list = get_eigenbasis_eigh(kronecker_factor_list)

            # Perturb the kronecker factor matrices, simulating the effect of gradient updates
            perturbation = 1e-2 * torch.randn(n, m)
            perturbed_kronecker_factor_list = [None, None]
            perturbed_kronecker_factor_list[0] = k_factor1 + perturbation @ perturbation.T
            perturbed_kronecker_factor_list[1] = k_factor2 + perturbation.T @ perturbation

            # Refine the orthogonal matrices using QR (eigenbasis_list already sorted)
            updated_ortho_matrices = get_eigenbasis_qr(
                perturbed_kronecker_factor_list,
                eigenbasis_list,
            )
    """
    return [
        eig_utils.orthogonal_iteration(
            kronecker_factor=factor,
            eigenbasis=basis,
            power_iter_steps=power_iter_steps,
        )
        for factor, basis in zip(kronecker_factor_list, eigenbasis_list, strict=True)
    ]