Source code for emerging_optimizers.utils.sinkhorn_mapper

# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
import torch.nn.functional as F


__all__ = [
    "SinkhornMapper",
]


[docs] class SinkhornMapper: """Applies the Sinkhorn-Knopp mapping to the input tensor. The Sinkhorn-Knopp mapping is an iterative technique for normalizing the rows and columns of a matrix: Input -> [Exp] -> [Iterative Row/Col Normalization] Supports batched inputs (3D+). The mapping operates on the last two dimensions. For an M×N matrix, the normalization targets are: - Square (M=N): row sums = 1.0, col sums = 1.0 (standard doubly-stochastic) - Wide (N>M): row sums = N/M, col sums = 1.0 - Tall (M>N): row sums = 1.0, col sums = M/N Based on Deepseek's Manifold-Constrained Hyperconnections (https://arxiv.org/abs/2512.24880) Args: num_iters: The number of iterations to run the Sinkhorn-Knopp mapping. eps: The epsilon value to use for the Sinkhorn-Knopp mapping for numerical stability. """ def __init__(self, num_iters: int = 20, eps: float = 1e-8): self.num_iters = num_iters self.eps = eps @torch.no_grad() def _sinkhorn_map(self, x: torch.Tensor, inplace: bool = False) -> torch.Tensor: """Apply Sinkhorn-Knopp mapping to the input tensor. Args: x: Input tensor to apply the mapping to. Must be at least 2D. Batched inputs (3D+) are supported. inplace: If True, modify x in place. If False, work on a copy. Returns: The tensor with the Sinkhorn-Knopp mapping applied. """ if x.dim() < 2: raise ValueError( f"{self.__class__.__name__} requires at least a 2D tensor, got {x.dim()}D with shape {x.shape}" ) result = x if inplace else x.clone() # Enforce positivity via exp with numerical stability. # Subtract global max before exp to prevent overflow (log-sum-exp trick). # The normalization step will scale the result, so subtracting any max (global, row, or column) # is sufficient for numerical stability. result.sub_(result.max()).exp_() # Determine normalization targets based on aspect ratio. # For non-square matrices (M x N), we scale the shorter dimension so that # rows sum to N/M and cols sum to 1.0 (if N > M), or # rows sum to 1.0 and cols sum to M/N (if M > N). # See chapter 4 of https://arxiv.org/abs/1803.00567. # For square matrices, both targets are 1.0 (standard doubly-stochastic). M, N = result.shape[-2], result.shape[-1] if N > M: row_target = N / M col_target = 1.0 else: row_target = 1.0 col_target = M / N # Iterative normalization of rows and columns for _ in range(self.num_iters): # Normalize columns (along row dimension) F.normalize(result, p=1, dim=-2, eps=self.eps, out=result) if col_target != 1.0: result.mul_(col_target) # Normalize rows (along column dimension) F.normalize(result, p=1, dim=-1, eps=self.eps, out=result) if row_target != 1.0: result.mul_(row_target) return result @torch.no_grad() def __call__(self, x: torch.Tensor, inplace: bool = False) -> torch.Tensor: """Apply Sinkhorn-Knopp mapping to the input tensor. Args: x: Input tensor to apply the mapping to. inplace: If True, modify x in place. If False, work on a copy. Returns: The tensor with the Sinkhorn-Knopp mapping applied (modified in place if inplace=True, otherwise a new tensor). """ return self._sinkhorn_map(x, inplace=inplace)