emerging_optimizers.utils
- emerging_optimizers.utils.fp32_matmul_precision(precision='highest')
Context manager that sets the precision of FP32 matmuls within its scope.
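The page does not show how the context manager is built. Below is a minimal sketch, assuming it wraps PyTorch's `torch.set_float32_matmul_precision` (whose accepted values include 'highest', 'high' and 'medium') and restores the previous setting on exit. `fp32_matmul_precision_sketch` is a hypothetical name, not the library's code:

```python
from __future__ import annotations

import contextlib
from collections.abc import Iterator

import torch


@contextlib.contextmanager
def fp32_matmul_precision_sketch(precision: str = "highest") -> Iterator[None]:
    """Hypothetical: set torch's FP32 matmul precision for the duration of a block."""
    previous = torch.get_float32_matmul_precision()
    torch.set_float32_matmul_precision(precision)
    try:
        yield
    finally:
        # Restore the caller's setting even if the body raises.
        torch.set_float32_matmul_precision(previous)


before = torch.get_float32_matmul_precision()
with fp32_matmul_precision_sketch("high"):
    # Inside the block, FP32 matmuls may use TF32 on hardware that supports it.
    inside = torch.get_float32_matmul_precision()
    c = torch.randn(4, 4) @ torch.randn(4, 4)
after = torch.get_float32_matmul_precision()
```

Restoring the old value in a `finally` block keeps the global setting unchanged after the `with` block, even if the body raises.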
emerging_optimizers.utils.eig
- emerging_optimizers.utils.eig.conjugate(a, p, diag=False)
Calculate a similarity transformation.
This function calculates \(B = P^T A P\). It assumes \(P\) is orthogonal, so that \(P^{-1} = P^T\) and \(B\) is a similarity transformation of \(A\).
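A quick numerical check of why the orthogonality assumption matters: for an orthogonal \(P\), \(B = P^T A P\) has the same eigenvalues as \(A\), and choosing \(P\) as the eigenvectors of \(A\) makes \(B\) diagonal. `conjugate_sketch` is a hypothetical stand-in that ignores the `diag` flag, whose behavior is not documented here:

```python
import torch


def conjugate_sketch(a: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for eig.conjugate: returns B = P^T A P."""
    return p.T @ a @ p


torch.manual_seed(0)
n = 5
m = torch.randn(n, n, dtype=torch.float64)
a = m + m.T  # symmetric test matrix
p, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))  # random orthogonal P

b = conjugate_sketch(a, p)
eig_a = torch.linalg.eigvalsh(a)  # ascending eigenvalues of A
eig_b = torch.linalg.eigvalsh(b)  # should match: B is similar to A

# With P = eigenvectors of A, the transformation diagonalizes A.
evals, evecs = torch.linalg.eigh(a)
d = conjugate_sketch(a, evecs)
```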
- emerging_optimizers.utils.eig.eigh_with_fallback(x, force_double=False, eps=None, output_dtype=None)
torch.linalg.eigh() with a double-precision fallback.
Unified wrapper over eigh() with an automatic fallback and an option to force double precision. If the decomposition fails, it is retried in double precision. Eigenvalues are returned in descending order. The second argument of eigh, UPLO, keeps its default value 'L', so only the lower triangle of each matrix is used.
- Parameters:
x (Tensor) – Tensor of shape (*, n, n), where * is zero or more batch dimensions, consisting of symmetric or Hermitian matrices.
force_double (bool) – Force double-precision computation. Default False.
eps (float | None) – Small offset for numerical stability. If None, uses dtype-appropriate values (1e-7 for float32, 1e-15 for float64). Default None.
output_dtype (dtype | None) – Desired output dtype. If None, uses the input dtype. Default None.
- Returns:
Eigenvalues and eigenvectors tuple (eigenvalues in descending order).
- Return type:
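A minimal sketch of the documented behavior, assuming that `eps` is added to the diagonal before decomposing (the page does not say where the offset is applied) and that the fallback is triggered by `torch.linalg.LinAlgError`. `eigh_with_fallback_sketch` is hypothetical and may differ from the library's implementation:

```python
from __future__ import annotations

import torch


def eigh_with_fallback_sketch(
    x: torch.Tensor,
    force_double: bool = False,
    eps: float | None = None,
    output_dtype: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical eigh wrapper: float64 fallback, descending eigenvalues."""
    out_dtype = x.dtype if output_dtype is None else output_dtype
    if eps is None:
        # Documented defaults: 1e-7 for float32, 1e-15 for float64.
        eps = 1e-15 if x.dtype == torch.float64 else 1e-7
    # Assumption: eps is a diagonal shift; eye broadcasts over batch dimensions.
    identity = torch.eye(x.shape[-1], dtype=x.dtype, device=x.device)
    x = x + eps * identity
    if force_double:
        x = x.double()
    try:
        evals, evecs = torch.linalg.eigh(x)  # UPLO defaults to 'L'
    except torch.linalg.LinAlgError:
        # Retry in float64 when the lower-precision solver fails to converge.
        evals, evecs = torch.linalg.eigh(x.double())
    # eigh sorts eigenvalues ascending; flip values and eigenvector columns.
    evals = torch.flip(evals, dims=[-1])
    evecs = torch.flip(evecs, dims=[-1])
    return evals.to(out_dtype), evecs.to(out_dtype)


# Usage on a batch of two symmetric positive semi-definite float32 matrices.
torch.manual_seed(0)
m = torch.randn(2, 6, 6)
x = m @ m.transpose(-2, -1)
vals, vecs = eigh_with_fallback_sketch(x)
vals64, _ = eigh_with_fallback_sketch(x, force_double=True, output_dtype=torch.float64)
```

Flipping both the eigenvalues and the eigenvector columns keeps each eigenvector paired with its eigenvalue after the reordering.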
- emerging_optimizers.utils.eig.met_approx_eigvals_criteria(kronecker_factor, approx_eigvals, tolerance)
Determines whether the eigenbasis of a factor matrix meets the desired criterion.
The approximate-eigenvalue update criterion is defined as \(\|\operatorname{diag}(Q^T K Q)\|_F \geq (1 - \text{tolerance}) \, \|Q^T K Q\|_F\), where \(Q\) is the matrix of approximate eigenvectors and \(K\) is the Kronecker factor (\(L\) or \(R\)).
We use the Kronecker factor and the approximated eigenvalues directly to save compute: since \(Q\) is orthogonal, \(\|Q^T K Q\|_F = \|K\|_F\), so the Frobenius norm of the Kronecker factor equals that of the approximated eigenvalue matrix.
- Parameters:
- Returns:
perform_update – Whether to update the eigenbasis this iteration.
- Return type:
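A sketch of the criterion as stated above. Unlike the documented function, the hypothetical `met_criteria_sketch` takes the basis \(Q\) explicitly so that the projection \(Q^T K Q\) is visible; it still uses \(\|K\|_F\) on the right-hand side, relying on the orthogonal-invariance argument:

```python
import torch


def met_criteria_sketch(
    kronecker_factor: torch.Tensor, q: torch.Tensor, tolerance: float
) -> bool:
    """Hypothetical check of ||diag(Q^T K Q)||_F >= (1 - tolerance) * ||Q^T K Q||_F."""
    approx_eigvals = q.T @ kronecker_factor @ q  # projection of K onto the basis Q
    diag_norm = torch.linalg.vector_norm(torch.diagonal(approx_eigvals))
    # For orthogonal Q, ||Q^T K Q||_F == ||K||_F, so K's norm stands in for it.
    full_norm = torch.linalg.matrix_norm(kronecker_factor)  # Frobenius by default
    return bool(diag_norm >= (1 - tolerance) * full_norm)


torch.manual_seed(0)
m = torch.randn(8, 8, dtype=torch.float64)
k = m @ m.T
_, exact_q = torch.linalg.eigh(k)  # basis that diagonalizes K
random_q, _ = torch.linalg.qr(torch.randn(8, 8, dtype=torch.float64))

exact_ok = met_criteria_sketch(k, exact_q, tolerance=0.01)
random_ok = met_criteria_sketch(k, random_q, tolerance=0.01)
```

With the exact eigenbasis all of the norm sits on the diagonal, so the inequality holds; with a random orthogonal basis a large share of the norm moves off the diagonal and the inequality generally fails.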
- emerging_optimizers.utils.eig.orthogonal_iteration(approx_eigvals, kronecker_factor, eigenbasis, ind, exp_avg_sq, convert_to_float, power_iter_steps)
Computes the eigenbases of the preconditioner using power iteration and QR decomposition.
This function performs multiple rounds of power iteration followed by QR decomposition to recompute the eigenbasis of a preconditioner Kronecker factor. It generalizes the eigenbasis update in SOAP (Vyas et al.), which uses a single step of power iteration.
- Parameters:
approx_eigvals (Tensor) – Projection of the Kronecker factor onto the eigenbasis; should be close to diagonal.
kronecker_factor (Tensor) – Kronecker factor matrix.
eigenbasis (Tensor) – Kronecker factor eigenbasis matrix.
ind (int) – Index of the exp_avg_sq dimension along which the sorting order is applied.
exp_avg_sq (Tensor) – Inner Adam second moment (exp_avg_sq).
convert_to_float (bool) – If True, preconditioner matrices and their corresponding orthonormal matrices are cast to float. Otherwise, they are left in their original dtype. Defaults to False.
power_iter_steps (int) – Number of power iteration steps to perform before QR decomposition. More steps can improve convergence at the cost of more computation.
- Returns:
- A tuple containing:
Q: The updated eigenbasis
exp_avg_sq: The updated (sorted) inner Adam second moment
- Return type:
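A minimal sketch of the procedure described above, modeled on the QR-based eigenbasis refresh in SOAP as we understand it: sort the basis by the estimated eigenvalues on the diagonal of `approx_eigvals`, reorder `exp_avg_sq` along `ind` to match, then run `power_iter_steps` rounds of multiplying by the Kronecker factor. Here each round is followed by a QR re-orthonormalization; the library may orthonormalize less often. The function name is hypothetical and `convert_to_float` is omitted:

```python
import torch


def orthogonal_iteration_sketch(
    approx_eigvals: torch.Tensor,
    kronecker_factor: torch.Tensor,
    eigenbasis: torch.Tensor,
    ind: int,
    exp_avg_sq: torch.Tensor,
    power_iter_steps: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical sketch: sort by estimated eigenvalues, then power iteration + QR."""
    # Estimated eigenvalues are the diagonal of Q^T K Q; order them descending.
    sort_idx = torch.argsort(torch.diagonal(approx_eigvals), descending=True)
    # Reorder the second moment so it stays aligned with the permuted basis.
    exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
    q = eigenbasis[:, sort_idx]
    for _ in range(power_iter_steps):
        # Multiply by K (power iteration), then re-orthonormalize the columns.
        q, _ = torch.linalg.qr(kronecker_factor @ q)
    return q, exp_avg_sq


# Build K = U diag(32, 16, ..., 1) U^T with well-separated eigenvalues.
torch.manual_seed(0)
n = 6
target = 2.0 ** torch.arange(n - 1, -1, -1, dtype=torch.float64)
u, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))
k = u @ torch.diag(target) @ u.T

q0, _ = torch.linalg.qr(torch.randn(n, n, dtype=torch.float64))  # stale basis
exp_avg_sq = torch.rand(n, 3, dtype=torch.float64)  # dim 0 follows K's basis
q, sorted_sq = orthogonal_iteration_sketch(
    q0.T @ k @ q0, k, q0, 0, exp_avg_sq, power_iter_steps=60
)
```

With eigenvalue ratios of 1/2 and 60 rounds, \(Q^T K Q\) becomes numerically diagonal with entries 32, 16, ..., 1. A single round, as in SOAP, only moves a stale basis part of the way toward this fixed point, which is why `power_iter_steps` trades compute for convergence.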