emerging_optimizers.utils#
- class emerging_optimizers.utils.SinkhornMapper(num_iters=20, eps=1e-08)[source]#
Applies the Sinkhorn-Knopp mapping to the input tensor.
The Sinkhorn-Knopp mapping is an iterative technique for normalizing the rows and columns of a matrix: Input -> [Exp] -> [Iterative Row/Col Normalization]
Supports batched inputs (3D+). The mapping operates on the last two dimensions.
- For an M×N matrix, the normalization targets are:
Square (M=N): row sums = 1.0, col sums = 1.0 (standard doubly-stochastic)
Wide (N>M): row sums = N/M, col sums = 1.0
Tall (M>N): row sums = 1.0, col sums = M/N
Based on DeepSeek’s Manifold-Constrained Hyperconnections (https://arxiv.org/abs/2512.24880)
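The mapping above can be sketched in a few lines of PyTorch. This is an illustrative re-implementation of the described behavior (exponentiate, then alternately rescale rows and columns toward the M×N targets), not the library's exact code:

```python
import torch

def sinkhorn_knopp(x, num_iters=20, eps=1e-8):
    """Sketch of the Sinkhorn-Knopp mapping: exponentiate, then
    alternately normalize the rows and columns of the last two dims."""
    m, n = x.shape[-2], x.shape[-1]
    a = torch.exp(x)
    # Targets follow the M x N convention above.
    row_target = max(n / m, 1.0)  # N/M when wide, 1.0 otherwise
    col_target = max(m / n, 1.0)  # M/N when tall, 1.0 otherwise
    for _ in range(num_iters):
        a = a * (row_target / (a.sum(dim=-1, keepdim=True) + eps))
        a = a * (col_target / (a.sum(dim=-2, keepdim=True) + eps))
    return a

y = sinkhorn_knopp(torch.randn(4, 4))
# For a square input, row and column sums both approach 1.0.
```

Because the row targets sum to the same total mass as the column targets (both equal N), the alternating rescaling converges for any strictly positive matrix, which `exp` guarantees.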
- emerging_optimizers.utils.fp32_matmul_precision(precision='highest')[source]#
Context manager for setting the precision of matmuls.
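A context manager like this can be sketched with `torch.set_float32_matmul_precision()`; the body below is an assumed implementation, shown only to illustrate the save/restore behavior:

```python
import contextlib
import torch

@contextlib.contextmanager
def fp32_matmul_precision(precision="highest"):
    """Sketch: temporarily set float32 matmul precision
    ("highest", "high", or "medium"), restoring the previous value."""
    prev = torch.get_float32_matmul_precision()
    torch.set_float32_matmul_precision(precision)
    try:
        yield
    finally:
        torch.set_float32_matmul_precision(prev)

# Usage: allow TF32-style fast matmuls only inside the block.
with fp32_matmul_precision("high"):
    c = torch.randn(8, 8) @ torch.randn(8, 8)
```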
emerging_optimizers.utils.eig#
- emerging_optimizers.utils.eig.conjugate(a, p, diag=False)[source]#
Calculate similarity transformation
This function calculates \(B = P^T A P\). It assumes P is orthogonal so that \(P^{-1} = P^T\) and the similarity transformation exists.
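A minimal sketch of this transform (the `diag` flag is assumed to return only the diagonal of \(B\)):

```python
import torch

def conjugate(a, p, diag=False):
    """Sketch of the similarity transform B = P^T A P. If diag=True,
    return only the diagonal of B."""
    b = p.transpose(-2, -1) @ a @ p
    return torch.diagonal(b, dim1=-2, dim2=-1) if diag else b
```

When P is orthogonal the transform preserves eigenvalues: for a symmetric A with eigendecomposition \(A = Q \Lambda Q^T\), `conjugate(A, Q)` recovers the diagonal eigenvalue matrix \(\Lambda\).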
- emerging_optimizers.utils.eig.eigh_with_fallback(x, force_double=False, eps=None)[source]#
torch.linalg.eigh() function with double precision fallback
Unified wrapper over eigh() with an automatic fallback and a force-double-precision option. It automatically falls back to double precision on failure and returns eigenvalues in descending order. The UPLO argument of eigh() is left at its default, ‘L’.
- Parameters:
x (Tensor) – Tensor of shape (*, n, n), where “*” is zero or more batch dimensions, consisting of symmetric or Hermitian matrices.
force_double (bool) – Force double precision computation. Default False.
eps (float | None) – Small offset for numerical stability. If None, uses dtype-appropriate values (1e-7 for float32, 1e-15 for float64). Default None.
- Returns:
Eigenvalues and eigenvectors tuple (eigenvalues in descending order).
- Return type:
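The fallback-and-reorder logic can be sketched as follows. This is an assumed implementation for illustration; it omits the `eps` stabilization offset for brevity:

```python
import torch

def eigh_with_fallback(x, force_double=False):
    """Sketch: eigh with a double-precision fallback, returning
    eigenvalues and matching eigenvectors in descending order."""
    orig_dtype = x.dtype
    if force_double:
        x = x.double()
    try:
        w, v = torch.linalg.eigh(x)
    except RuntimeError:
        # Retry in float64 when the lower-precision solve fails.
        w, v = torch.linalg.eigh(x.double())
    # eigh returns eigenvalues in ascending order; flip both the
    # eigenvalues and the corresponding eigenvector columns.
    w, v = w.flip(-1), v.flip(-1)
    return w.to(orig_dtype), v.to(orig_dtype)
```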
- emerging_optimizers.utils.eig.met_approx_eigvals_criteria(kronecker_factor, approx_eigvals, tolerance)[source]#
Determines whether the eigenbasis for a factor matrix meets the desired criteria
The approximated-eigenvalues update criterion is defined as \(\|\mathrm{diag}(Q^T K Q)\|_F \geq (1 - \text{tolerance}) \cdot \|Q^T K Q\|_F\), where \(Q\) is the matrix of approximated eigenvectors and \(K\) is the Kronecker factor (L or R).
We use the Kronecker factor and the approximated eigenvalues directly to save compute: because \(Q\) is orthogonal, the Frobenius norm of the Kronecker factor equals that of the projected matrix \(Q^T K Q\).
- Parameters:
- Returns:
Whether the eigenbasis meets the criteria and does not need to be updated
- Return type:
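The criterion above can be sketched directly from the formula. This illustrative version exploits the norm equality noted above (\(\|K\|_F = \|Q^T K Q\|_F\)) so it never forms the projected matrix:

```python
import torch

def met_approx_eigvals_criteria(kronecker_factor, approx_eigvals, tolerance):
    """Sketch of the criterion: the approximated eigenvalues (the diagonal
    of Q^T K Q) must carry nearly all of the Frobenius norm of K. Since Q
    is orthogonal, ||Q^T K Q||_F == ||K||_F, so K's norm is used directly."""
    diag_norm = torch.linalg.vector_norm(approx_eigvals)
    full_norm = torch.linalg.matrix_norm(kronecker_factor)  # Frobenius norm
    return bool(diag_norm >= (1.0 - tolerance) * full_norm)
```

With an exact eigenbasis the projection is fully diagonal, so the criterion holds even for tiny tolerances; a stale basis leaves mass off the diagonal and triggers an update.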
- emerging_optimizers.utils.eig.orthogonal_iteration(approx_eigvals, kronecker_factor, eigenbasis, ind, exp_avg_sq, power_iter_steps)[source]#
Computes the eigenbases of the preconditioner using power iteration and QR decomposition.
This function performs multiple rounds of power iteration followed by a QR decomposition to recompute the eigenbasis of the preconditioner’s Kronecker factor. It generalizes the SOAP algorithm of Vyas et al., which uses a single power-iteration step to update the eigenbasis.
- Parameters:
approx_eigvals (Tensor) – Projection of kronecker factor onto the eigenbasis, should be close to diagonal
kronecker_factor (Tensor) – Kronecker factor matrix.
eigenbasis (Tensor) – Kronecker factor eigenbasis matrix.
ind (int) – Index for selecting dimension in the exp_avg_sq matrix to apply the sorting order over.
exp_avg_sq (Tensor) – Inner Adam second moment (exp_avg_sq).
power_iter_steps (int) – Number of power iteration steps to perform before QR decomposition. More steps can lead to better convergence but increased computation time.
- Returns:
- A tuple containing:
Q: The updated eigenbasis
exp_avg_sq: The updated (sorted) inner Adam second moment
- Return type:
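The core of the update can be sketched as below. This simplified version keeps only the power-iteration/QR loop; the eigenvalue sorting and the reordering of `exp_avg_sq` that the real function also performs are omitted, and the signature here is illustrative:

```python
import torch

def orthogonal_iteration_step(kronecker_factor, eigenbasis, power_iter_steps=1):
    """Sketch of the eigenbasis update: multiply the current basis by the
    Kronecker factor several times, then re-orthogonalize via QR."""
    q = eigenbasis
    for _ in range(power_iter_steps):
        q = kronecker_factor @ q
    q, _ = torch.linalg.qr(q)
    return q
```

Repeated application drives `q` toward the eigenvectors of the factor, with more power-iteration steps per QR giving faster convergence at higher cost, as noted for the `power_iter_steps` parameter.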
emerging_optimizers.utils.modules#
- class emerging_optimizers.utils.modules.Conv1dFlatWeights(*args, **kwargs)[source]#
Conv1d with weights+bias stored in a single 2D tensor
Some LLMs use conv1d layers, for example in the Mamba mixer. Because the weight is not 2D, many of the emerging optimizers originally introduced for the 2D weights of bias-free Linear layers cannot be applied to it. Since a convolution can be viewed as a matrix multiplication via im2col (either implicit or explicit), we can flatten the weight into a single 2D tensor and then apply the emerging optimizers to it.
Bias is no longer common in most LLMs, but it is often included in this type of conv1d. Since the bias is mathematically the zeroth-order term of the polynomial, we can combine weight and bias into a single 2D tensor.
Arguments are the same as torch.nn.Conv1d.
Note
This implementation potentially introduces a small overhead from splitting the combined weight and recombining its gradients. This should be trivial compared to the computational cost of LLM training. If it becomes a concern, a kernel can be developed to eliminate the overhead.
Note
Similar flattening logic can be applied to N-D convolution, but since we have no use cases for it in LLMs yet, it is not supported, even though the __init__() function is general enough to handle N-D convolution.
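The flattening idea can be sketched as follows. The class name and initialization here are illustrative, not the library's actual implementation: each output channel's kernel is flattened into one row, with the bias appended as a final column, giving a single 2D parameter an optimizer can treat like a Linear weight:

```python
import torch
import torch.nn.functional as F

class FlatConv1d(torch.nn.Module):
    """Sketch: conv1d whose weight and bias live in one 2D tensor."""

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.in_channels = in_channels
        self.kernel_size = kernel_size
        # One row per output channel: [flattened kernel | bias].
        self.flat_weight = torch.nn.Parameter(
            torch.randn(out_channels, in_channels * kernel_size + 1) * 0.02
        )

    def forward(self, x):
        # Split the 2D parameter back into conv weight and bias.
        w, b = self.flat_weight[:, :-1], self.flat_weight[:, -1]
        w = w.view(-1, self.in_channels, self.kernel_size)
        return F.conv1d(x, w, b)
```

The split in `forward` is the small overhead the note above refers to: the convolution itself is unchanged, only the parameter storage is reshaped.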
- extra_repr()[source]#
Return the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
- Return type:
- forward(x)[source]#
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.