emerging_optimizers.scalar_optimizers

emerging_optimizers.scalar_optimizers.calculate_adam_update(
grad: Tensor,
exp_avg: Tensor,
exp_avg_sq: Tensor,
betas: Tuple[float, float],
correct_bias: bool,
use_nesterov: bool,
step: int,
eps: float,
) -> Tensor

Performs the Adam update.

This function computes a single step of the Adam update.

The update rule is as follows:

\[m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\ \hat{m}_t = \frac{m_t}{1 - \beta_1^t} \\ \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\ \text{update} = \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \\ \]
Parameters:
  • grad – The gradient tensor.

  • exp_avg – The accumulated first moment of the gradient.

  • exp_avg_sq – The accumulated second moment of the gradient.

  • betas – The EMA beta coefficients for the Adam update.

  • correct_bias – Whether to correct the bias of the Adam update.

  • use_nesterov – Whether to use Nesterov momentum.

  • step – The current step of the optimizer, used to compute the bias correction terms.

  • eps – The epsilon for the Adam second moment update.

Returns:

The Adam update.
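
A minimal usage sketch for a single parameter tensor, using the documented keyword names. The shapes, hyperparameter values, and learning rate are illustrative only, and whether the moment buffers are updated in place should be verified against the implementation:

    import torch

    from emerging_optimizers.scalar_optimizers import calculate_adam_update

    param = torch.randn(128, 64)
    grad = torch.randn_like(param)
    exp_avg = torch.zeros_like(param)     # first-moment buffer (m)
    exp_avg_sq = torch.zeros_like(param)  # second-moment buffer (v)

    update = calculate_adam_update(
        grad=grad,
        exp_avg=exp_avg,
        exp_avg_sq=exp_avg_sq,
        betas=(0.9, 0.999),
        correct_bias=True,
        use_nesterov=False,
        step=1,
        eps=1e-8,
    )
    param.add_(update, alpha=-1e-3)  # apply with an illustrative learning rate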

emerging_optimizers.scalar_optimizers.calculate_ademamix_update(
grad: Tensor,
exp_avg_fast: Tensor,
exp_avg_slow: Tensor,
exp_avg_sq: Tensor,
num_beta_slow_warmup_steps: int | None,
num_alpha_warmup_steps: int | None,
betas: Tuple[float, float, float],
step: int,
eps: float,
correct_bias: bool,
alpha: float = 2,
) -> Tensor

Performs AdEMAMix update.

This function computes a single step of AdEMAMix. Based on apple/ml-ademamix and https://arxiv.org/abs/2409.03137.

The update rule is as follows:

\[m_t^{\text{fast}} = \beta_{\text{fast}} m_{t-1}^{\text{fast}} + (1 - \beta_{\text{fast}}) g_t \\ m_t^{\text{slow}} = \beta_{\text{slow}} m_{t-1}^{\text{slow}} + (1 - \beta_{\text{slow}}) g_t \\ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\ \hat{m}_t^{\text{fast}} = \frac{m_t^{\text{fast}}}{1 - \beta_{\text{fast}}^t} \\ \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\ \text{update} = \frac{\hat{m}_t^{\text{fast}} + \alpha m_t^{\text{slow}}}{\sqrt{\hat{v}_t} + \epsilon} \]
Parameters:
  • grad – The gradient tensor.

  • exp_avg_fast – The accumulated first moment of the gradient with fast time constant.

  • exp_avg_slow – The accumulated first moment of the gradient with slow time constant.

  • exp_avg_sq – The accumulated second moment of the gradient.

  • num_beta_slow_warmup_steps – Number of warmup steps used to increase beta_slow.

  • num_alpha_warmup_steps – Number of warmup steps used to increase alpha.

  • betas – The EMA beta coefficients for the AdEMAMix update.

  • step – The current step of the optimizer, used to compute the bias correction terms.

  • eps – The epsilon for the Adam second moment update.

  • correct_bias – Whether to correct the bias of the AdEMAMix update.

  • alpha – Coefficient weighting the slow EMA in the update (see the update rule above); when alpha warmup/scheduling is used, this is the final value.

Returns:

The AdEMAMix update.
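
A minimal usage sketch with the documented keyword names. The warmup-step counts, betas, and alpha are illustrative values, and the (beta_fast, beta2, beta_slow) ordering of betas is an assumption to be checked against the implementation:

    import torch

    from emerging_optimizers.scalar_optimizers import calculate_ademamix_update

    param = torch.randn(128, 64)
    grad = torch.randn_like(param)
    exp_avg_fast = torch.zeros_like(param)  # fast first-moment EMA
    exp_avg_slow = torch.zeros_like(param)  # slow first-moment EMA
    exp_avg_sq = torch.zeros_like(param)    # second-moment EMA

    update = calculate_ademamix_update(
        grad=grad,
        exp_avg_fast=exp_avg_fast,
        exp_avg_slow=exp_avg_slow,
        exp_avg_sq=exp_avg_sq,
        num_beta_slow_warmup_steps=1000,  # illustrative; None is assumed to disable warmup
        num_alpha_warmup_steps=1000,      # illustrative; None is assumed to disable warmup
        betas=(0.9, 0.999, 0.9999),       # assumed ordering: (beta_fast, beta2, beta_slow)
        step=1,
        eps=1e-8,
        correct_bias=True,
        alpha=2.0,
    )
    param.add_(update, alpha=-1e-3)  # apply with an illustrative learning rate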

emerging_optimizers.scalar_optimizers.calculate_laprop_update(
grad: Tensor,
exp_avg: Tensor,
exp_avg_sq: Tensor,
correct_bias: bool,
betas: Tuple[float, float],
step: int,
eps: float,
) -> Tensor

Performs the LAProp (normalized SGD with momentum) update.

LAProp can be seen as RMSProp with a momentum term, or as normalized SGD with momentum. Based on Z-T-WANG/LaProp-Optimizer and https://arxiv.org/abs/2002.04839.

The update rule is as follows:

\[v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\ \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\ g'_t = \frac{g_t}{\sqrt{\hat{v}_t} + \epsilon} \\ m_t = \beta_1 m_{t-1} + (1 - \beta_1) g'_t \\ \hat{m}_t = \frac{m_t}{1 - \beta_1^t} \\ \text{update} = \hat{m}_t \]
Parameters:
  • grad – The gradient tensor.

  • exp_avg – The exponential moving average of the gradient.

  • exp_avg_sq – The exponential moving average of the gradient squared.

  • correct_bias – Whether to correct the bias of the LAProp update.

  • betas – The (beta1, beta2) coefficients for the exponential moving averages.

  • step – The current step.

  • eps – The epsilon for the second moment update.

Returns:

The LAProp update.
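
For intuition, a standalone PyTorch sketch of the update rule above (not the library's implementation). The key difference from Adam is that the gradient is normalized by the second-moment EMA before the momentum accumulation:

    import torch

    beta1, beta2, eps, step = 0.9, 0.999, 1e-8, 1
    grad = torch.randn(128, 64)
    exp_avg = torch.zeros_like(grad)     # m: momentum of the normalized gradient
    exp_avg_sq = torch.zeros_like(grad)  # v: second moment of the raw gradient

    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # v_t
    v_hat = exp_avg_sq / (1 - beta2**step)                        # bias-corrected v_t
    normalized_grad = grad / (v_hat.sqrt() + eps)                 # g'_t
    exp_avg.mul_(beta1).add_(normalized_grad, alpha=1 - beta1)    # m_t
    update = exp_avg / (1 - beta1**step)                          # bias-corrected m_t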

emerging_optimizers.scalar_optimizers.calculate_lion_update(
grad: Tensor,
exp_avg: Tensor,
momentum_beta: float,
momentum_beta2: float | None = None,
) -> Tensor

Performs the Lion update.

This function computes a single step of the Lion update.

The update rule is as follows:

\[\text{update} = \text{sign}(\beta_1 m_{t-1} + (1 - \beta_1) g_t) \\ m_t = \beta_2 m_{t-1} + (1 - \beta_2) g_t \]
Parameters:
  • grad – The gradient tensor.

  • exp_avg – The accumulated first moment of the gradient.

  • momentum_beta – The EMA beta coefficient used to interpolate the momentum and gradient for the signed update (beta1 in Lion).

  • momentum_beta2 – The second EMA beta coefficient (beta2 in Lion), used to update the momentum buffer.

Returns:

The Lion update.
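
A minimal usage sketch with the documented keyword names. Per the update rule above, momentum_beta interpolates the momentum and the gradient for the signed update, while momentum_beta2 advances the momentum buffer; the hyperparameter values and learning rate are illustrative, and the in-place behavior of exp_avg should be checked against the implementation:

    import torch

    from emerging_optimizers.scalar_optimizers import calculate_lion_update

    param = torch.randn(128, 64)
    grad = torch.randn_like(param)
    exp_avg = torch.zeros_like(param)  # momentum buffer (m)

    update = calculate_lion_update(
        grad=grad,
        exp_avg=exp_avg,
        momentum_beta=0.9,    # beta1: interpolation for the signed update
        momentum_beta2=0.99,  # beta2: EMA used to advance the momentum buffer
    )
    param.add_(update, alpha=-1e-4)  # Lion typically uses a smaller learning rate than Adam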

emerging_optimizers.scalar_optimizers.calculate_signum_update(
grad: Tensor,
exp_avg: Tensor,
momentum_beta: float,
correct_bias: bool,
use_nesterov: bool,
step: int,
use_shape_scaling: bool = False,
) -> Tensor

Performs the sign-SGD or Signum update.

This function computes a single step of sign-SGD or Signum. Based on https://arxiv.org/abs/1802.04434. When using sign-SGD with shape scaling, the general recommendation is to scale \(lr = \text{adam lr} \cdot \text{network width} \cdot \frac{2}{\text{rows} + \text{cols}}\); this enables learning rate transfer with width scaling (https://arxiv.org/abs/2506.07254v1).

The update rule is as follows:

\[m_t = \beta m_{t-1} + (1 - \beta) g_t \\ \hat{m}_t = \frac{m_t}{1 - \beta^t} \\ \text{update} = \text{sign}(\hat{m}_t) \]
Parameters:
  • grad – The gradient tensor.

  • exp_avg – The accumulated first moment of the gradient.

  • momentum_beta – The EMA beta coefficient for the momentum update.

  • correct_bias – Whether to correct the bias of the momentum update.

  • use_nesterov – Whether to use Nesterov momentum.

  • step – The current step of the optimizer, used to compute the bias correction terms.

  • use_shape_scaling – Whether to scale the update by the shape of the tensor.

Returns:

The sign-SGD/Signum update.
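
A minimal usage sketch with the documented keyword names, including the shape-scaled learning rate recommendation quoted above; adam_lr, the network width, and the parameter shape are illustrative:

    import torch

    from emerging_optimizers.scalar_optimizers import calculate_signum_update

    param = torch.randn(1024, 4096)  # (rows, cols) of a weight matrix
    grad = torch.randn_like(param)
    exp_avg = torch.zeros_like(param)  # momentum buffer (m)

    update = calculate_signum_update(
        grad=grad,
        exp_avg=exp_avg,
        momentum_beta=0.9,
        correct_bias=True,
        use_nesterov=False,
        step=1,
        use_shape_scaling=True,
    )

    # lr = adam_lr * network_width * 2 / (rows + cols), per the recommendation above
    adam_lr, network_width = 1e-3, 4096
    rows, cols = param.shape
    lr = adam_lr * network_width * 2 / (rows + cols)
    param.add_(update, alpha=-lr)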

emerging_optimizers.scalar_optimizers.calculate_sim_ademamix_update(
grad: Tensor,
exp_avg: Tensor,
exp_avg_sq: Tensor,
num_beta_fast_warmup_steps: int | None,
min_beta_fast: float,
betas: Tuple[float, float],
step: int,
eps: float,
correct_bias: bool,
alpha: float = 2,
) -> Tensor

Performs simplified AdEMAMix update.

This function computes a single step of simplified AdEMAMix. Based on DepenM/Simplified-AdEMAMix and https://arxiv.org/abs/2409.03137.

The update rule is as follows:

\[m_t = \beta_{\text{fast}} m_{t-1} + g_t \\ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\ \hat{m}_t = \frac{m_t}{(1 - \beta_{\text{fast}}^t) / (1 - \beta_{\text{fast}})} \\ \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\ \text{update} = \frac{\alpha g_t + \hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \]
Parameters:
  • grad – The gradient tensor.

  • exp_avg – The accumulated first moment of the gradient.

  • exp_avg_sq – The accumulated second moment of the gradient.

  • num_beta_fast_warmup_steps – Number of warmup steps used to increase beta_fast.

  • min_beta_fast – The minimum beta_fast value used at initialization.

  • betas – The EMA beta coefficients for the simplified AdEMAMix update.

  • step – The current step of the optimizer, used to compute the bias correction terms.

  • eps – The epsilon for the Adam second moment update.

  • correct_bias – Whether to correct the bias of the AdEMAMix update.

  • alpha – Coefficient for mixing the current gradient and the EMA (see the update rule above).

Returns:

The simplified-AdEMAMix update.
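
A minimal usage sketch with the documented keyword names. The warmup-step count, min_beta_fast, betas, and alpha are illustrative values, and the (beta_fast, beta2) ordering of betas is an assumption to be checked against the implementation:

    import torch

    from emerging_optimizers.scalar_optimizers import calculate_sim_ademamix_update

    param = torch.randn(128, 64)
    grad = torch.randn_like(param)
    exp_avg = torch.zeros_like(param)     # first-moment accumulator (m)
    exp_avg_sq = torch.zeros_like(param)  # second-moment EMA (v)

    update = calculate_sim_ademamix_update(
        grad=grad,
        exp_avg=exp_avg,
        exp_avg_sq=exp_avg_sq,
        num_beta_fast_warmup_steps=1000,  # illustrative; None is assumed to disable warmup
        min_beta_fast=0.9,
        betas=(0.99, 0.999),              # assumed ordering: (beta_fast, beta2)
        step=1,
        eps=1e-8,
        correct_bias=True,
        alpha=2.0,
    )
    param.add_(update, alpha=-1e-3)  # apply with an illustrative learning rate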