Source code for emerging_optimizers.psgd.psgd_kron_contractions

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

import torch


# Names exported by `from <module> import *`. The underscore-prefixed helpers
# defined further down in this file are not listed here.
__all__ = [
    "partial_contraction",
    "apply_kronecker_factors",
    "apply_preconditioner",
]


def partial_contraction(G1: torch.Tensor, G2: torch.Tensor, axis: int) -> torch.Tensor:
    """Compute the partial contraction of G1 and G2 that keeps only axis `axis`.

    Every dimension except `axis` is summed over; the `axis` dimension of G1 becomes the
    rows of the result and the `axis` dimension of G2 becomes its columns.

    Args:
        G1: Tensor of shape (d_0, d_1, ..., d_{axis-1}, d_{axis}, d_{axis+1}, ..., d_N)
        G2: Tensor of shape (d_0, d_1, ..., d_{axis-1}, d_{axis}, d_{axis+1}, ..., d_N)
        axis: int, the axis that is kept (all other axes are contracted). Negative values
            count from the last dimension.

    Returns:
        Tensor of shape (d_{axis}, d_{axis})

    Raises:
        IndexError: If `axis` is outside `[-G1.dim(), G1.dim() - 1]`.
    """
    ndim = G1.dim()
    if not -ndim <= axis < ndim:
        raise IndexError(f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {axis})")

    # Normalize first: the filter below compares against non-negative indices, so a raw
    # negative `axis` would never match and every dim (including `axis`) would be contracted.
    kept_dim = axis % ndim

    # dims_to_contract = all dims except `axis`
    dims_to_contract = [i for i in range(ndim) if i != kept_dim]

    # Result has shape (d_{axis}, d_{axis}); it is symmetric when G1 and G2 are the same tensor.
    return torch.tensordot(G1, G2, dims=(dims_to_contract, dims_to_contract))
@torch.compile  # type: ignore[misc]
def apply_kronecker_factors(Q_list: List[torch.Tensor], X: torch.Tensor) -> torch.Tensor:
    """Apply every Kronecker factor to tensor :math:`X` once, one factor per dimension.

    The i-th entry of `Q_list` acts on the i-th dimension of :math:`X`. In the 2D case
    the result is :math:`Q_1 X Q_2^T`.

    Args:
        Q_list: List of :math:`Q` (the upper-triangular Kronecker factors), each of shape
            `(d_i, d_i)` or `(d_i,)`.
        X: Tensor of shape `(d_0, d_1, ..., d_N)`.

    Returns:
        Tensor of shape `(d_0, d_1, ..., d_N)`.

    Raises:
        ValueError: If the number of factors differs from the number of dimensions of `X`.
    """
    if X.dim() != len(Q_list):
        raise ValueError(
            f"Number of Kronecker factors {len(Q_list)} must match the number of dimensions of X {X.dim()}"
        )

    # Thread the running result through the factors, one dimension at a time.
    result = X
    for axis, _ in enumerate(Q_list):
        result = _apply_single_kronecker_factor(Q_list, result, axis)
    return result
@torch.compile  # type: ignore[misc]
def apply_preconditioner(Q_list: List[torch.Tensor], X: torch.Tensor) -> torch.Tensor:
    """Apply the full PSGD preconditioner to X.

    The preconditioner is the Kronecker product of the per-dimension factors :math:`Q^T Q`,
    so in the 2D case

    :math:`P X = (Q_1^T Q_1) X (Q_2^T Q_2)`

    It is computed in two passes: one with every factor, then one with every factor transposed.

    Args:
        Q_list: List of :math:`Q` (the Kronecker factors), each of shape `(d_i, d_i)` or `(d_i,)`.
        X: Tensor of shape `(d_0, d_1, ..., d_N)`.

    Returns:
        Tensor of shape `(d_0, d_1, ..., d_N)`.
    """
    # First pass: Q applied along every dimension.
    after_q = apply_kronecker_factors(Q_list, X)

    # Second pass: Q.T along every dimension, giving Q^T @ Q overall.
    # 1-D (diagonal) factors are their own transpose and are passed through as-is.
    transposed_factors = [factor if factor.dim() == 1 else factor.T for factor in Q_list]
    return apply_kronecker_factors(transposed_factors, after_q)
def _dim_n_mul_and_permute(X: torch.Tensor, M: torch.Tensor, contract_dim: int) -> torch.Tensor:
    """Multiply tensor X along axis `contract_dim` by 2D matrix M.

    Helper function for `_apply_single_kronecker_factor`.

    If M is (d_out, d_in) we contract M's second index with X's `contract_dim` index.
    `torch.tensordot` is used to contract the two tensors, and the new leading axis of the
    result is then moved to position `contract_dim`.

    Returns a new tensor of the same rank, but with size[contract_dim] replaced by d_out.
    Note that d_{contract_dim} == d_in.

    Args:
        X: Tensor of shape (d_0, d_1, ..., d_{contract_dim-1}, d_{contract_dim}, d_{contract_dim+1}, ..., d_N)
        M: Tensor of shape (d_out, d_in)
        contract_dim: int, the dimension to contract with M, with d_{contract_dim} == d_in.
            Negative values count from the last dimension of X.

    Returns:
        Tensor of shape (d_0, d_1, ..., d_{contract_dim-1}, d_out, d_{contract_dim+1}, ..., d_N)

    Raises:
        ValueError: If `X.shape[contract_dim]` differs from `M.shape[1]`.

    Examples
    --------
    >>> X = torch.randn(2, 3, 6)
    >>> M = torch.randn(5, 6)
    >>> contract_dim = 2
    >>> result = _dim_n_mul_and_permute(X, M, contract_dim)
    >>> print(result.shape)
    torch.Size([2, 3, 5])
    """
    if X.shape[contract_dim] != M.shape[1]:
        raise ValueError(
            f"Shape mismatch: X.shape[{contract_dim}] = {X.shape[contract_dim]}, M.shape[1] = {M.shape[1]}"
        )

    # Resolve a negative index up front so the target position below is unambiguous.
    dim = contract_dim + X.dim() if contract_dim < 0 else contract_dim

    # Contract M's 2nd dim (idx=1) with X's `dim` dim.
    Y = torch.tensordot(M, X, dims=([1], [dim]))

    # `torch.tensordot` puts M's surviving axis first, so Y has shape
    # (d_out, d_0, ..., d_{dim-1}, d_{dim+1}, ...). Move that axis back to position `dim`.
    return Y.movedim(0, dim)


@torch.compile  # type: ignore[misc]
def _apply_single_kronecker_factor(Q_list: List[torch.Tensor], X: torch.Tensor, axis: int) -> torch.Tensor:
    """Apply a single Kronecker factor Q to X at dimension `axis`.

    Helper function for `apply_kronecker_factors`.

    If Q is a vector, X is multiplied elementwise by Q, with Q laid out along `axis`.
    Otherwise Q is treated as a matrix and its second index is contracted with X's `axis` index.

    Args:
        Q_list: List of Q (e.g. the Kronecker factors). Only `Q_list[axis]` is used.
        X: Tensor of shape (d_0, d_1, ..., d_{axis-1}, d_{axis}, d_{axis+1}, ..., d_N)
        axis: int, the dimension of X that `Q_list[axis]` is applied to.

    Returns:
        Tensor with Q applied along `axis`. For a matrix Q the size of dimension `axis`
        becomes `Q.shape[0]`.
    """
    Q = Q_list[axis]
    if Q.dim() == 1:
        # Reshape Q to (1, ..., d_axis, ..., 1) so it broadcasts along `axis` only.
        shape = [1] * X.dim()
        shape[axis] = Q.size(0)
        return X * Q.view(shape)

    return _dim_n_mul_and_permute(X, Q, contract_dim=axis)