emerging_optimizers.soap#

SOAP#

class emerging_optimizers.soap.soap.SOAP(
params: Iterable[Parameter],
lr: float = 0.003,
betas: Tuple[float, float] | None = None,
shampoo_beta: float = -1,
eps: float = 1e-08,
weight_decay: float = 0.01,
use_decoupled_weight_decay: bool = True,
use_nesterov: bool = False,
precondition_frequency: int | Callable[[int], int] = 10,
precondition_warmup_steps: int = 0,
adam_warmup_steps: int = 1,
precondition_1d: bool = False,
max_precond_dim: int = 8192,
trace_normalization: bool = False,
normalize_preconditioned_grads: bool = False,
correct_bias: bool = True,
fp32_matmul_prec: str = 'high',
use_eigh: bool = False,
qr_fp32_matmul_prec: str = 'high',
use_adaptive_criteria: bool = False,
adaptive_update_tolerance: float | None = None,
power_iter_steps: int = 1,
max_update_rms: float = 0.0,
)[source]#

Implements a variant of the SOAP (ShampoO with Adam in the Preconditioner's eigenbasis) algorithm.

SOAP (https://arxiv.org/abs/2409.11321) is a preconditioned optimizer that combines the benefits of Shampoo’s non-diagonal preconditioning with Adam’s adaptive learning rates. It runs the adaptive updates in the eigenbasis of the gradient correlation (Kronecker factor) matrices, adapting to the local geometry of the optimization landscape.

Parameters:
  • params – Iterable of parameters to optimize or dicts defining parameter groups

  • lr – The learning rate to use

  • betas – Inner Adam’s betas parameters (b1, b2)

  • shampoo_beta – Beta for the moving average of the Kronecker factor matrices (L and R in the paper); used instead of betas[1] if >= 0

  • eps – Inner Adam’s epsilon for numerical stability

  • weight_decay – Weight decay coefficient

  • use_decoupled_weight_decay – Whether to use decoupled weight decay, see Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101.

  • use_nesterov – Whether to use Nesterov momentum in the inner Adam (https://cs229.stanford.edu/proj2015/054_report.pdf, https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ)

  • precondition_frequency – How often to update the preconditioner. Can be an integer for fixed frequency or a callable function that takes the current step as input and returns the frequency.

  • precondition_warmup_steps – How many steps to warm up the preconditioner (i.e. update every step)

  • adam_warmup_steps – How many steps to skip preconditioning in the beginning (i.e. use standard AdamW updates)

  • precondition_1d – Whether to precondition 1D gradients (like biases).

  • max_precond_dim – Maximum dimension of the preconditioner matrices. Preconditioning is skipped if any tensor dimension exceeds this value.

  • trace_normalization – Whether to normalize the update by the trace of the Kronecker factor matrix

  • normalize_preconditioned_grads – Whether to normalize preconditioned gradients per layer

  • correct_bias – Whether to use bias correction in the inner Adam and in the EMA of the Kronecker factor matrices

  • fp32_matmul_prec – FP32 matmul precision for the GEMM operations on optimizer states

  • use_eigh – Whether to use full symmetric eigendecomposition (eigh) to compute the eigenbasis. If False, use orthogonal iteration to compute the eigenbasis.

  • qr_fp32_matmul_prec – Precision of the matmul operations in QR decomposition.

  • use_adaptive_criteria – Whether to use criteria to determine if eigenbasis update is needed

  • adaptive_update_tolerance – Tolerance threshold for the update criteria. Only used if use_adaptive_criteria is True.

  • power_iter_steps – Number of power iteration steps to perform before QR decomposition. More steps can lead to better convergence but increased computation time.

  • max_update_rms – Clip the update RMS to this value (0 means no clipping).
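
Example

A minimal usage sketch. The import path follows the class path shown above; the hyperparameter values are illustrative only (betas is passed explicitly since its default is None), and the callable precondition_frequency demonstrates the schedule form described above.

>>> import torch
>>> from emerging_optimizers.soap.soap import SOAP
>>> model = torch.nn.Linear(20, 10)
>>> optimizer = SOAP(
...     model.parameters(),
...     lr=3e-3,
...     betas=(0.95, 0.95),
...     precondition_frequency=lambda step: 10 if step < 1000 else 100,
... )
>>> loss = model(torch.randn(4, 20)).pow(2).mean()
>>> loss.backward()
>>> optimizer.step()
>>> optimizer.zero_grad()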

step(
closure: Callable[[], float] | None = None,
) → float | None[source]#

Performs a single optimization step.

Parameters:

closure – A closure that reevaluates the model and returns the loss.
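
Example

A sketch of supplying a closure. The model and optimizer are assumed to be defined as in the usage sketch above, and inputs/targets are hypothetical tensors; the closure returns the loss as a float.

>>> def closure():
...     optimizer.zero_grad()
...     loss = torch.nn.functional.mse_loss(model(inputs), targets)
...     loss.backward()
...     return loss.item()
>>> optimizer.step(closure)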

emerging_optimizers.soap.soap.precondition(
grad: Tensor,
eigenbasis_list: List[Tensor] | None,
dims: List[List[int]] | None = None,
) → Tensor[source]#

Projects the gradient to and from the eigenbases of the kronecker factor matrices.

This function performs tensor contractions between the input gradient and kronecker factor eigenbases.

Parameters:
  • grad – Input tensor to be preconditioned

  • eigenbasis_list – List of eigenbases for preconditioning. Each matrix should be a square matrix of eigenvectors.

  • dims – Dimensions for tensor contraction. Default is [[0], [0]] which contracts the first dimension of grad with the first dimension of each eigenbasis matrix, for projecting into the eigenbasis. Use [[0], [1]] for projecting back to original space.

Example

>>> import torch
>>> grad = torch.randn(10, 20)
>>> Q = torch.randn(10, 10)
>>> projected = precondition(grad, [Q], dims=[[0], [0]])
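
For a 2D gradient G with left/right eigenbases QL and QR, projecting into the eigenbasis amounts to QL.T @ G @ QR, and projecting back amounts to QL @ projected @ QR.T. The sketch below illustrates that identity with orthonormal stand-in bases; it shows the mathematical effect, not this function's internal code.

>>> import torch
>>> G = torch.randn(10, 20)
>>> QL, _ = torch.linalg.qr(torch.randn(10, 10))  # orthonormal stand-in for the left eigenbasis
>>> QR, _ = torch.linalg.qr(torch.randn(20, 20))  # orthonormal stand-in for the right eigenbasis
>>> projected = QL.T @ G @ QR         # project into the eigenbasis
>>> restored = QL @ projected @ QR.T  # project back to the original space
>>> torch.allclose(restored, G, atol=1e-4)
True
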
emerging_optimizers.soap.soap.init_kronecker_factors(
grad: Tensor,
precondition_1d: bool = False,
max_precond_dim: int = 8192,
) → List[Tensor][source]#

Initializes the kronecker factor matrices for the SOAP optimizer.

This function creates the initial Kronecker factor matrices (L and R) used for preconditioning. For 1D tensors (like biases), it can either skip preconditioning or create a single square kronecker factor matrix. For higher dimensional tensors, it creates a square kronecker factor matrix for each dimension.

When precondition_1d is:
  • False (default):
    • 1D tensors (like biases) will skip SOAP preconditioning entirely

    • These parameters will use standard Adam-style updates

    • This is often desirable as biases typically have fewer parameters and simpler optimization landscapes

    • Can improve performance and reduce memory usage

  • True:
    • All parameters, including 1D tensors, will use SOAP preconditioning

    • May be beneficial for certain architectures or training scenarios

Parameters:
  • grad – Gradient tensor used to initialize the kronecker factor matrices. The shape of this tensor determines the size of the kronecker factor matrices.

  • precondition_1d – Whether to create kronecker factor matrices for 1D tensors (like biases). If False, 1D tensors will skip preconditioning.

  • max_precond_dim – Maximum dimension of the preconditioner matrices. Preconditioning is skipped if any tensor dimension exceeds this value.

Returns:

List of kronecker factor matrices (L and R in paper).
  • For 1D tensors with precondition_1d=False: List containing an empty tensor

  • For 1D tensors with precondition_1d=True: List containing a square matrix

  • For higher dimensional tensors: List of square matrices, one per dimension

Return type:

List[torch.Tensor]

Example

>>> import torch
>>> # For a 1D tensor (bias)
>>> grad_1d = torch.randn(10)
>>> precond_1d = init_kronecker_factors(grad_1d, precondition_1d=True)
>>> print(len(precond_1d))  # 1
>>> print(precond_1d[0].shape)  # (10, 10)
>>> # For a 2D tensor (weight matrix)
>>> grad_2d = torch.randn(10, 20)
>>> precond_2d = init_kronecker_factors(grad_2d)
>>> print(len(precond_2d))  # 2
>>> print(precond_2d[0].shape)  # (10, 10)
>>> print(precond_2d[1].shape)  # (20, 20)
emerging_optimizers.soap.soap.update_kronecker_factors(
kronecker_factor_list: List[Tensor],
grad: Tensor,
shampoo_beta: float,
precondition_1d: bool = False,
max_precond_dim: int = 8192,
) → None[source]#

Updates the preconditioner matrices using gradient outer products.

This function updates the Kronecker factor matrices (L and R) used for preconditioning by computing and accumulating gradient outer products. For 1D tensors (like biases), it can optionally skip preconditioning or use a special 1D preconditioning strategy. It modifies the kronecker_factor_list in place.

Parameters:
  • kronecker_factor_list – List of preconditioner matrices (L and R) to update. Each matrix should be square and match the corresponding dimension of grad.

  • grad – Gradient tensor of the parameter being optimized

  • shampoo_beta – Momentum coefficient for updating preconditioners. Controls how much weight to give to new vs old gradient statistics.

  • precondition_1d – Whether to apply preconditioning to 1D tensors (like biases). If False, 1D tensors will skip preconditioning.

  • max_precond_dim – Maximum dimension of the preconditioner matrices. Preconditioning is skipped if any tensor dimension exceeds this value.

Example

>>> import torch
>>> grad = torch.randn(10, 20)
>>> L = torch.zeros(10, 10)
>>> R = torch.zeros(20, 20)
>>> update_kronecker_factors([L, R], grad, shampoo_beta=0.95)
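
Conceptually, each Kronecker factor accumulates an exponential moving average of the gradient's outer product along its own dimension. The sketch below shows that update rule for a 2D gradient; it is illustrative only, and the in-place update performed by this function may weight the terms differently.

>>> import torch
>>> beta = 0.95
>>> grad = torch.randn(10, 20)
>>> L = torch.zeros(10, 10)
>>> R = torch.zeros(20, 20)
>>> L = beta * L + (1 - beta) * (grad @ grad.T)  # left factor: row statistics
>>> R = beta * R + (1 - beta) * (grad.T @ grad)  # right factor: column statistics
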
emerging_optimizers.soap.soap.update_eigenbasis_and_momentum(
kronecker_factor_list: List[Tensor],
eigenbasis_list: List[Tensor],
exp_avg_sq: Tensor,
momentum: Tensor,
use_eigh: bool = False,
use_adaptive_criteria: bool = False,
adaptive_update_tolerance: float | None = None,
power_iter_steps: int = 1,
convert_to_float: bool = True,
) → Tuple[List[Tensor], Tensor, Tensor][source]#

Updates the eigenbases using QR decomposition and power iteration or eigh.

This function performs an update of the eigenbases (QL and QR) used for preconditioning. It follows these steps:

  1. Projects momentum back to the original basis

  2. Updates the eigenbases using QR decomposition and power iteration (orthogonal iteration)

  3. Projects momentum into the new eigenbasis

Parameters:
  • kronecker_factor_list – List of preconditioner matrices (L and R) that define the optimization landscape. These are updated with gradient statistics.

  • eigenbasis_list – List of current eigenbases (QL and QR) used for preconditioning. These will be updated by this function.

  • exp_avg_sq – Inner Adam’s second moment tensor, used for scaling the preconditioner updates. This tensor is modified in-place.

  • momentum – Inner Adam’s first moment tensor, used for tracking gradient momentum. This tensor is modified in-place.

  • use_eigh – Whether to use full symmetric eigendecomposition (eigh) to compute the eigenbasis. If False, use orthogonal iteration to compute the eigenbasis.

  • use_adaptive_criteria – Whether to use criteria to determine if eigenbasis update is needed

  • adaptive_update_tolerance – Tolerance threshold for the update criteria. Only used if use_adaptive_criteria is True.

  • power_iter_steps – Number of power iteration steps to perform before QR decomposition. More steps can lead to better convergence but increased computation time.

  • convert_to_float – Whether to convert the preconditioner matrices and their corresponding orthonormal matrices to float for amortized computation. Otherwise, they are left in their original type.

Returns:

A tuple containing:
  • List[torch.Tensor]: Updated list of eigenbases (QL and QR)

  • torch.Tensor: Updated second moment tensor (exp_avg_sq)

  • torch.Tensor: Updated momentum tensor projected to the new eigenbasis

Return type:

Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor]

Example

>>> import torch
>>> L = torch.randn(10, 10)
>>> R = torch.randn(20, 20)
>>> QL = torch.randn(10, 10)
>>> QR = torch.randn(20, 20)
>>> exp_avg_sq = torch.rand(10, 20)  # second moment is non-negative
>>> momentum = torch.randn(10, 20)
>>> eigenbasis_list, exp_avg_sq, momentum = update_eigenbasis_and_momentum(
...     [L, R], [QL, QR], exp_avg_sq, momentum)
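
For intuition, one round of the orthogonal iteration used when use_eigh=False (power iteration followed by QR re-orthonormalization) can be sketched as follows. This omits the adaptive update criteria and the momentum re-projection, and is not this function's exact implementation.

>>> import torch
>>> A = torch.randn(10, 10)
>>> A = A @ A.T                                  # symmetric PSD stand-in for a Kronecker factor
>>> Q, _ = torch.linalg.qr(torch.randn(10, 10))  # current eigenbasis estimate
>>> power_iter_steps = 1
>>> for _ in range(power_iter_steps):
...     Q = A @ Q                                # power iteration step
>>> Q, _ = torch.linalg.qr(Q)                    # re-orthonormalize via QR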