emerging_optimizers.soap#

SOAP#

class emerging_optimizers.soap.soap.SOAP(
params,
lr,
betas=(0.9, 0.95),
shampoo_beta=0.95,
eps=1e-08,
weight_decay=0.01,
*,
weight_decay_method='decoupled',
nesterov=False,
precondition_frequency=1,
adam_warmup_steps=0,
correct_bias=True,
fp32_matmul_prec='high',
use_eigh=False,
qr_fp32_matmul_prec='high',
use_adaptive_criteria=False,
adaptive_update_tolerance=1e-07,
power_iter_steps=1,
max_update_rms=0.0,
use_kl_shampoo=False,
correct_shampoo_beta_bias=None,
)[source]#

Implements a variant of the SOAP (ShampoO with Adam in the Preconditioner eigenbasis) algorithm.

SOAP (https://arxiv.org/abs/2409.11321) is a preconditioned optimizer that combines the benefits of Shampoo’s non-diagonal preconditioning with Adam’s adaptive learning rates. It uses gradient correlation matrix eigenbasis-based preconditioning to adapt to the local geometry of the optimization landscape.

Parameters:
  • params (Iterable[Tensor] | Iterable[dict[str, Any]] | Iterable[tuple[str, Tensor]]) – Iterable of parameters to optimize or dicts defining parameter groups

  • lr (float) – The learning rate to use

  • betas (tuple[float, float]) – Inner Adam’s betas parameters (b1, b2)

  • shampoo_beta (float) – Beta for the moving average of the Kronecker factor matrices (L and R in the paper); used instead of betas[1] if >= 0

  • eps (float) – Inner Adam’s epsilon for numerical stability

  • weight_decay (float) – Weight decay coefficient

  • weight_decay_method (Literal['decoupled', 'independent', 'l2']) – Method to apply weight decay, see WeightDecayMixin for more details.

  • nesterov (bool) – Whether to use Nesterov momentum in the inner Adam (https://cs229.stanford.edu/proj2015/054_report.pdf, https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ)

  • precondition_frequency (int | Callable[[int], int]) – How often to update the preconditioner. Can be an integer for fixed frequency or a callable function that takes the current step as input and returns the frequency.

  • adam_warmup_steps (int) – How many steps to skip preconditioning in the beginning (i.e. use standard AdamW updates)

  • correct_bias (bool) – Whether to use bias correction in Inner Adam and Kronecker factor matrices EMA

  • fp32_matmul_prec (Literal['highest', 'high', 'medium']) – Precision of the matmul operations in optimizer states GEMM operations

  • use_eigh (bool) – Whether to use full symmetric eigendecomposition (eigh) to compute the eigenbasis. If False, use orthogonal iteration to compute the eigenbasis.

  • qr_fp32_matmul_prec (Literal['highest', 'high', 'medium']) – Precision of the matmul operations in QR decomposition.

  • use_adaptive_criteria (bool) – Whether to use criteria to determine if eigenbasis update is needed

  • adaptive_update_tolerance (float) – Tolerance threshold for the update criteria. Only used if use_adaptive_criteria is True.

  • power_iter_steps (int) – Number of power iteration steps to perform before QR decomposition. More steps can lead to better convergence but increased computation time.

  • max_update_rms (float) – Clip the update RMS to this value (0 means no clipping).

  • use_kl_shampoo (bool) – Whether to use KL-Shampoo correction.

  • correct_shampoo_beta_bias (bool | None) – Whether to bias-correct the shampoo beta. Decoupled from correct_bias for testability, because the reference implementation of SOAP does not bias-correct the shampoo beta.
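Example

The core update described above can be sketched in plain torch. This is a simplified single step for illustration only: the real optimizer refreshes the eigenbases only every precondition_frequency steps, applies bias correction and weight decay, and keeps its Adam state in the eigenbasis. Function and variable names below are hypothetical, not part of the library API:

```python
import torch


def soap_step_sketch(param, grad, L, R, exp_avg, exp_avg_sq,
                     lr=1e-3, beta1=0.9, beta2=0.95,
                     shampoo_beta=0.95, eps=1e-8):
    # Accumulate gradient statistics into the Kronecker factors (L and R).
    L.mul_(shampoo_beta).add_(grad @ grad.T, alpha=1 - shampoo_beta)
    R.mul_(shampoo_beta).add_(grad.T @ grad, alpha=1 - shampoo_beta)
    # Eigenbases of the Kronecker factors (recomputed every step here for brevity).
    _, QL = torch.linalg.eigh(L)
    _, QR = torch.linalg.eigh(R)
    # Rotate the gradient into the eigenbasis.
    g_rot = QL.T @ grad @ QR
    # Adam moments maintained in the rotated space.
    exp_avg.mul_(beta1).add_(g_rot, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(g_rot, g_rot, value=1 - beta2)
    update = exp_avg / (exp_avg_sq.sqrt() + eps)
    # Rotate the update back to the original space and apply it.
    param.add_(QL @ update @ QR.T, alpha=-lr)


torch.manual_seed(0)
param = torch.zeros(4, 6)
grad = torch.randn(4, 6)
L, R = torch.zeros(4, 4), torch.zeros(6, 6)
exp_avg, exp_avg_sq = torch.zeros(4, 6), torch.zeros(4, 6)
soap_step_sketch(param, grad, L, R, exp_avg, exp_avg_sq)
```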

step(closure: None = None) None[source]#
step(closure: Callable[[], float]) float

Performs a single optimization step.

Parameters:

closure – A closure that reevaluates the model and returns the loss.

emerging_optimizers.soap.soap.precondition(x, eigenbasis_list=None, dims=None)[source]#

Projects the gradient to and from the eigenbases of the kronecker factor matrices.

This function performs tensor contractions between the input gradient and kronecker factor eigenbases.

Note

For 2D tensors, matmul could be used instead of tensordot for legibility. However, this code has historically used tensordot, as does the reference implementation. It is difficult to match matmul and tensordot outputs exactly because of underlying floating-point arithmetic differences, so tensordot is kept for consistency.

Parameters:
  • x (Tensor) – Input tensor to be preconditioned

  • eigenbasis_list (list[Tensor] | None) – List of eigenbases for preconditioning. Each matrix should be a square matrix of eigenvectors.

  • dims (list[list[int]] | None) – Dimensions for tensor contraction. Default is [[0], [0]] which contracts the first dimension of grad with the first dimension of each eigenbasis matrix, for projecting into the eigenbasis. Use [[0], [1]] for projecting back to original space.

Return type:

Tensor

Example

>>> x = torch.randn(10, 20)
>>> Q = torch.randn(10, 10)
>>> precondition(x, [Q], dims=[[0], [0]])
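The note above can be checked directly in plain torch: chaining tensordot over both eigenbases with dims=[[0], [0]] reproduces the matmul form for a 2D tensor, up to floating-point rounding. The loop below illustrates the contraction semantics and assumes precondition applies one tensordot per eigenbasis, as in the reference SOAP implementation:

```python
import torch

grad = torch.randn(10, 20, dtype=torch.float64)
QL = torch.linalg.qr(torch.randn(10, 10, dtype=torch.float64))[0]
QR = torch.linalg.qr(torch.randn(20, 20, dtype=torch.float64))[0]

# Contract dim 0 of the running result with dim 0 of each eigenbasis in turn;
# after both contractions the dimensions cycle back to their original order.
out = grad
for Q in (QL, QR):
    out = torch.tensordot(out, Q, dims=([0], [0]))

# For 2D tensors this matches the matmul form Q_L^T @ G @ Q_R.
reference = QL.T @ grad @ QR
```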
emerging_optimizers.soap.soap.init_kronecker_factors(grad_shape, device=None)[source]#

Initializes the kronecker factor matrices for the SOAP optimizer.

This function creates the initial Kronecker factor matrices (L and R) used for preconditioning. It creates a square kronecker factor matrix for each dimension of the 2D gradient shape.

Note

The Kronecker factors are always initialized in float32 (unless the default dtype is set otherwise), as their accumulation and decomposition are not safe in lower precisions.

Parameters:
  • grad_shape (Size) – Shape of the gradient tensor. Must be 2D. Determines the size of the kronecker factor matrices.

  • device (device | None) – Device on which to create the kronecker factor matrices.

Returns:

Tuple of kronecker factor matrices (L and R in paper).

Return type:

tuple[Tensor, Tensor]

Example

>>> # For a 2D tensor (weight matrix)
>>> grad_shape = torch.Size([10, 20])
>>> precond_2d = init_kronecker_factors(grad_shape)
>>> print(len(precond_2d))  # 2
>>> print(precond_2d[0].shape)  # (10, 10)
>>> print(precond_2d[1].shape)  # (20, 20)
emerging_optimizers.soap.soap.update_kronecker_factors(kronecker_factor_list, grad, shampoo_beta)[source]#

Updates the preconditioner matrices using gradient outer products.

This function updates the Kronecker factor matrices (L and R) used for preconditioning by computing and accumulating gradient outer products. kronecker_factor_list is updated in place.

Parameters:
  • kronecker_factor_list (list[Tensor]) – List of preconditioner matrices (L and R) to update. Each matrix should be square and match the corresponding dimension of grad.

  • grad (Tensor) – Gradient tensor of the parameter being optimized

  • shampoo_beta (float) – Momentum coefficient for updating preconditioners. Controls how much weight to give to new vs old gradient statistics.

Return type:

None

Example

>>> grad = torch.randn(10, 20)
>>> L = torch.zeros(10, 10)
>>> R = torch.zeros(20, 20)
>>> update_kronecker_factors([L, R], grad, shampoo_beta=0.95)
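Continuing the example, the in-place EMA described above can be sketched in a few lines of torch. The (1 - shampoo_beta) weighting on the new outer products is an assumption of this sketch; consult the source for the exact accumulation used:

```python
import torch


def update_factors_sketch(kronecker_factor_list, grad, shampoo_beta):
    # L accumulates row statistics (G G^T); R accumulates column statistics (G^T G).
    outer_products = (grad @ grad.T, grad.T @ grad)
    for factor, outer in zip(kronecker_factor_list, outer_products):
        # In-place exponential moving average of the gradient outer products.
        factor.mul_(shampoo_beta).add_(outer, alpha=1 - shampoo_beta)


grad = torch.randn(10, 20)
L, R = torch.zeros(10, 10), torch.zeros(20, 20)
update_factors_sketch([L, R], grad, shampoo_beta=0.95)
```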
emerging_optimizers.soap.soap.update_kronecker_factors_kl_shampoo(
kronecker_factor_list,
grad,
shampoo_beta,
eigenbasis_list,
eps,
eigval_exp=-1.0,
)[source]#

Updates the kronecker factor matrices in place using KL-Shampoo correction.

Implements the Kullback–Leibler minimization from https://arxiv.org/pdf/2509.03378.

Parameters:
  • kronecker_factor_list (list[Tensor]) – List of preconditioner matrices (L and R) to update.

  • grad (Tensor) – Gradient tensor of the parameter being optimized

  • shampoo_beta (float) – Momentum coefficient for updating preconditioners.

  • eigenbasis_list (list[Tensor]) – List of orthonormal eigenbases of the kronecker factor matrices

  • eps (float) – Small offset for numerical stability.

  • eigval_exp (float) – Exponent of the eigenvalues.

Return type:

None

emerging_optimizers.soap.soap.update_eigenbasis_and_exp_avgs(
kronecker_factor_list,
eigenbasis_list,
exp_avg_sq,
exp_avg,
use_eigh=False,
power_iter_steps=1,
)[source]#

Updates the eigenbases and moving averages.

This function performs an update of the eigenbases (QL and QR) used for preconditioning. It follows these steps:

  1. Projects exp_avg back to the original basis

  2. Updates the eigenbases using QR decomposition and power iteration (orthogonal iteration)

  3. Projects exp_avg into the new eigenbasis

Parameters:
  • kronecker_factor_list (list[Tensor]) – List of preconditioner matrices (L and R) that define the optimization landscape. These are updated with gradient statistics.

  • eigenbasis_list (list[Tensor]) – List of current eigenbases (QL and QR) used for preconditioning. These will be updated by this function.

  • exp_avg_sq (Tensor) – Inner Adam’s second moment tensor, used for scaling the preconditioner updates. This tensor is modified in-place.

  • exp_avg (Tensor) – Inner Adam’s first moment tensor, used for tracking gradient momentum. This tensor is modified in-place.

  • use_eigh (bool) – Whether to use full symmetric eigendecomposition (eigh) to compute the eigenbasis. If False, use orthogonal iteration to compute the eigenbasis.

  • power_iter_steps (int) – Number of power iteration steps to perform before QR decomposition. More steps can lead to better convergence but increased computation time.

Returns:

  • Updated list of eigenbases (QL and QR)

  • Updated exp_avg tensor projected to the new eigenbasis

  • Updated exp_avg_sq tensor

Return type:

tuple[list[Tensor], Tensor, Tensor]

Example

>>> L = torch.randn(10, 10)
>>> R = torch.randn(20, 20)
>>> QL = torch.randn(10, 10)
>>> QR = torch.randn(20, 20)
>>> exp_avg_sq = torch.randn(10, 20)
>>> exp_avg = torch.randn(10, 20)
>>> updated_eigenbasis_list, updated_exp_avg, updated_exp_avg_sq = update_eigenbasis_and_exp_avgs(
...     [L, R], [QL, QR], exp_avg_sq, exp_avg)
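The QR-based refinement in step 2 (orthogonal iteration) can be illustrated in plain torch. This is a sketch of the numerical idea, not the library routine; the synthetic matrix below has well-separated eigenvalues, so the iteration converges quickly:

```python
import torch


def orthogonal_iteration(factor, eigenbasis, steps=1):
    # Each step: power iteration (multiply by the factor), then QR to re-orthonormalize.
    for _ in range(steps):
        eigenbasis = torch.linalg.qr(factor @ eigenbasis)[0]
    return eigenbasis


torch.manual_seed(0)
# SPD matrix with known, well-separated eigenvalues 10, 9, ..., 1.
basis = torch.linalg.qr(torch.randn(10, 10, dtype=torch.float64))[0]
eigvals = torch.arange(10, 0, -1, dtype=torch.float64)
factor = basis @ torch.diag(eigvals) @ basis.T

# Start from a random orthonormal guess and refine it.
Q0 = torch.linalg.qr(torch.randn(10, 10, dtype=torch.float64))[0]
Q = orthogonal_iteration(factor, Q0, steps=200)

# Q now approximately diagonalizes the factor: Q^T K Q is close to diag(10, ..., 1).
approx = Q.T @ factor @ Q
```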

REKLS#

class emerging_optimizers.soap.rekls.REKLS(
params,
lr,
betas=(0.9, 0.95),
shampoo_beta=0.95,
eps=1e-08,
weight_decay=0.01,
*,
weight_decay_method='decoupled',
)[source]#

REKLS (Realtime Eigen Kullback-Leibler Soap) optimizer.

REKLS is a variant of SOAP that uses an up-to-date eigenbasis computed by eigendecomposition. The eigenbasis is "up to date" because the current step's gradient is accumulated into the Kronecker factors before the eigenbasis update.

Note

Refer to SOAP for detailed documentation of arguments.

emerging_optimizers.soap.soap_utils#

emerging_optimizers.soap.soap_utils.all_eigenbases_met_criteria(
kronecker_factor_list,
eigenbasis_list,
adaptive_update_tolerance=1e-07,
)[source]#

Checks if every eigenbasis in the list meets the adaptive update tolerance criteria.

Parameters:
  • kronecker_factor_list (list[Tensor]) – List of Kronecker factor matrices

  • eigenbasis_list (list[Tensor]) – List of orthonormal eigenbases of the kronecker factor matrices

  • adaptive_update_tolerance (float) – Tolerance threshold for the normalized diagonal component of approximated eigenvalue matrix.

Returns:

True if all eigenbases meet the criteria (no update needed), False otherwise.

Return type:

bool
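One plausible form of this check can be sketched in plain torch (an assumption for illustration, not the library's exact criterion): if the eigenbasis Q is still a good fit for the Kronecker factor K, then Q^T K Q is nearly diagonal, so the normalized off-diagonal mass can be compared against the tolerance:

```python
import torch


def eigenbasis_met_criterion_sketch(kronecker_factor, eigenbasis, tolerance=1e-7):
    # Rotate the factor into the candidate eigenbasis; a good basis yields
    # a nearly diagonal matrix.
    rotated = eigenbasis.T @ kronecker_factor @ eigenbasis
    off_diag = rotated - torch.diag(torch.diag(rotated))
    # Small normalized off-diagonal mass means no eigenbasis update is needed.
    return (off_diag.norm() / rotated.norm()).item() <= tolerance


torch.manual_seed(0)
A = torch.randn(8, 8, dtype=torch.float64)
K = A @ A.T + torch.eye(8, dtype=torch.float64)  # symmetric positive definite
exact_basis = torch.linalg.eigh(K)[1]   # exact eigenbasis meets the criterion
stale_basis = torch.linalg.qr(torch.randn(8, 8, dtype=torch.float64))[0]
```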

emerging_optimizers.soap.soap_utils.get_eigenbasis_eigh(kronecker_factor_list, eps=None)[source]#

Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.

Parameters:
  • kronecker_factor_list (list[Tensor]) – List of matrices to compute the eigenbases of

  • eps (float | None) – Small offset for numerical stability.

Returns:

List of orthonormal kronecker factor eigenbases matrices

Return type:

list[Tensor]

Example

# Create sample Kronecker factors (symmetric positive definite matrices)
k_factor1 = torch.randn(4, 4)
k_factor1 = k_factor1 @ k_factor1.T  # Make symmetric positive definite
k_factor2 = torch.randn(5, 5)
k_factor2 = k_factor2 @ k_factor2.T  # Make symmetric positive definite

# Get orthogonal matrices for these factors
ortho_matrices = get_eigenbasis_eigh([k_factor1, k_factor2])
# ortho_matrices[0] has shape [4, 4] and ortho_matrices[1] has shape [5, 5]
emerging_optimizers.soap.soap_utils.get_eigenbasis_qr(
kronecker_factor_list,
eigenbasis_list,
exp_avg_sq,
power_iter_steps=1,
)[source]#

Updates the eigenbases of the preconditioner using power iteration and QR.

Computes using multiple rounds of power iteration followed by QR decomposition (orthogonal iteration).

Parameters:
  • kronecker_factor_list (list[Tensor]) – List containing the preconditioner factors (\(GG^T\) and \(G^TG\))

  • eigenbasis_list (list[Tensor]) – List containing eigenbases (\(Q_L\) and \(Q_R\))

  • exp_avg_sq (Tensor) – Inner Adam's second moment (exp_avg_sq).

  • power_iter_steps (int) – Number of power iteration steps to perform before QR decomposition. More steps can lead to better convergence but increased computation time.

Returns:

Tuple of the updated list of orthonormal Kronecker factor eigenbases and the updated (sorted) inner Adam second moment.

Return type:

tuple[list[Tensor], Tensor]

Example

# Create sample Kronecker factors (symmetric positive definite matrices)
n, m = 10, 20
k_factor1 = torch.randn(n, n)
k_factor1 = k_factor1 @ k_factor1.T  # Make symmetric positive definite
k_factor2 = torch.randn(m, m)
k_factor2 = k_factor2 @ k_factor2.T  # Make symmetric positive definite

# Get orthogonal matrices for these kronecker factors
kronecker_factor_list = [k_factor1, k_factor2]
eigenbasis_list = get_eigenbasis_eigh(kronecker_factor_list)

# Perturb the kronecker factor matrices, simulating the effect of gradient updates
perturbation = 1e-2*torch.randn(n, m)
perturbed_kronecker_factor_list = [None, None]
perturbed_kronecker_factor_list[0] = k_factor1 + perturbation@perturbation.T
perturbed_kronecker_factor_list[1] = k_factor2 + perturbation.T@perturbation

# Initialize exp_avg_sq tensor
exp_avg_sq = torch.randn(n, m).abs()

# Refine the orthogonal matrices using QR
updated_ortho_matrices, updated_exp_avg_sq = get_eigenbasis_qr(
    perturbed_kronecker_factor_list,
    eigenbasis_list,
    exp_avg_sq
)