emerging_optimizers.scalar_optimizers

emerging_optimizers.scalar_optimizers.update_functions.calculate_adam_update(
grad,
exp_avg,
exp_avg_sq,
*,
betas,
eps,
correct_bias,
nesterov,
step,
)

Performs the Adam update.

This function computes one step of Adam.

The update rule is as follows:

\[m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\ \hat{m}_t = \frac{m_t}{1 - \beta_1^t} \\ \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\ \text{update} = \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \]
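As a minimal sketch of the rule above, the hypothetical helper below applies it to a single scalar element in pure Python. It is not the library's implementation (which operates on tensors and updates the moment buffers in place), and Nesterov momentum is omitted because the formula does not show how it is applied.

```python
import math


def adam_update_scalar(g, m, v, *, beta1, beta2, eps, step, correct_bias=True):
    """One Adam step on a single scalar element (hypothetical helper, no Nesterov).

    Returns (update, new_m, new_v); the library works on tensors in place instead.
    """
    m = beta1 * m + (1 - beta1) * g        # first-moment EMA
    v = beta2 * v + (1 - beta2) * g * g    # second-moment EMA
    if correct_bias:
        m_hat = m / (1 - beta1**step)
        v_hat = v / (1 - beta2**step)
    else:
        m_hat, v_hat = m, v
    return m_hat / (math.sqrt(v_hat) + eps), m, v


# With bias correction, the first step has magnitude close to 1 when |g| >> eps.
u, m, v = adam_update_scalar(0.5, 0.0, 0.0, beta1=0.9, beta2=0.999, eps=1e-8, step=1)
```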
Parameters:
  • grad (Tensor) – The gradient tensor.

  • exp_avg (Tensor) – The accumulated first moment of the gradient (modified in place).

  • exp_avg_sq (Tensor) – The accumulated second moment of the gradient (modified in place).

  • betas (tuple[float, float]) – The EMA beta coefficients (beta1, beta2).

  • eps (float) – Epsilon for the second-moment denominator.

  • correct_bias (bool) – Whether to apply Adam-style bias correction.

  • nesterov (bool) – Whether to use Nesterov momentum.

  • step (int) – Current optimizer step (1-based), used for bias correction.

Returns:

The Adam update.

Return type:

Tensor

emerging_optimizers.scalar_optimizers.update_functions.calculate_ademamix_update(
grad,
exp_avg_fast,
exp_avg_slow,
exp_avg_sq,
*,
betas,
eps,
correct_bias,
step,
num_beta_slow_warmup_steps,
num_alpha_warmup_steps,
alpha=2,
)

Performs the AdEMAMix update.

This function computes one step of AdEMAMix. Based on apple/ml-ademamix and https://arxiv.org/abs/2409.03137.

The update rule is as follows:

\[m_t^{\text{fast}} = \beta_{\text{fast}} m_{t-1}^{\text{fast}} + (1 - \beta_{\text{fast}}) g_t \\ m_t^{\text{slow}} = \beta_{\text{slow}} m_{t-1}^{\text{slow}} + (1 - \beta_{\text{slow}}) g_t \\ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\ \hat{m}_t^{\text{fast}} = \frac{m_t^{\text{fast}}}{1 - \beta_{\text{fast}}^t} \\ \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\ \text{update} = \frac{\hat{m}_t^{\text{fast}} + \alpha m_t^{\text{slow}}}{\sqrt{\hat{v}_t} + \epsilon} \]
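A scalar sketch of the same rule, assuming both warmup schedules are disabled so that beta_slow and alpha are constants. The helper name and argument layout are hypothetical. It makes visible that the slow EMA is not bias-corrected: at step 1 it contributes only alpha * (1 - beta_slow) * g to the numerator.

```python
import math


def ademamix_update_scalar(g, m_fast, m_slow, v, *, beta_fast, beta2, beta_slow,
                           alpha, eps, step, correct_bias=True):
    """One AdEMAMix step on a scalar with the warmup schedules disabled (hypothetical).

    Only the fast EMA and the second moment are bias-corrected; the slow EMA is not.
    Returns (update, new_m_fast, new_m_slow, new_v).
    """
    m_fast = beta_fast * m_fast + (1 - beta_fast) * g
    m_slow = beta_slow * m_slow + (1 - beta_slow) * g
    v = beta2 * v + (1 - beta2) * g * g
    if correct_bias:
        m_fast_hat = m_fast / (1 - beta_fast**step)
        v_hat = v / (1 - beta2**step)
    else:
        m_fast_hat, v_hat = m_fast, v
    update = (m_fast_hat + alpha * m_slow) / (math.sqrt(v_hat) + eps)
    return update, m_fast, m_slow, v


# Step 1 with g = 1: the bias-corrected fast EMA gives ~1, the slow EMA only ~alpha * 1e-4.
u, m_fast, m_slow, v = ademamix_update_scalar(
    1.0, 0.0, 0.0, 0.0, beta_fast=0.9, beta2=0.999, beta_slow=0.9999,
    alpha=2.0, eps=1e-8, step=1,
)
```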
Parameters:
  • grad (Tensor) – The gradient tensor.

  • exp_avg_fast (Tensor) – The accumulated first moment with the fast time constant (modified in place).

  • exp_avg_slow (Tensor) – The accumulated first moment with the slow time constant (modified in place).

  • exp_avg_sq (Tensor) – The accumulated second moment of the gradient (modified in place).

  • betas (tuple[float, float, float]) – The EMA beta coefficients (beta_fast, beta2, beta_slow_final).

  • eps (float) – Epsilon for the second-moment denominator.

  • correct_bias (bool) – Whether to apply Adam-style bias correction.

  • step (int) – Current optimizer step (1-based), used for bias correction and the schedulers below.

  • num_beta_slow_warmup_steps (int | None) – Number of warmup steps used to ramp beta_slow toward beta_slow_final. None disables the schedule.

  • num_alpha_warmup_steps (int | None) – Number of warmup steps used to ramp alpha toward its final value. None disables the schedule.

  • alpha (float) – Coefficient for mixing the slow first moment into the update. When scheduled, this is the final value.
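The parameter list does not spell out the shape of the two warmup schedules. The sketch below assumes the ones described in the AdEMAMix paper: alpha ramps linearly from 0, and beta_slow is interpolated linearly in half-life space, where half_life(beta) = log(0.5) / log(beta). Both helpers are hypothetical, and starting beta_slow from beta_fast in the usage line is likewise an assumption; the library's schedules may differ.

```python
import math


def linear_alpha_warmup(step, alpha_final, num_warmup_steps):
    """Assumed schedule: alpha grows linearly from 0 to alpha_final (hypothetical helper)."""
    if num_warmup_steps is None or step >= num_warmup_steps:
        return alpha_final
    return alpha_final * step / num_warmup_steps


def half_life_beta_warmup(step, beta_start, beta_final, num_warmup_steps):
    """Assumed schedule: interpolate linearly between the half-lives of two betas."""
    if num_warmup_steps is None or step >= num_warmup_steps:
        return beta_final

    def half_life(beta):
        # Number of steps after which an EMA weight decays by one half.
        return math.log(0.5) / math.log(beta)

    def beta_from_half_life(h):
        return 0.5 ** (1.0 / h)

    a = step / num_warmup_steps
    return beta_from_half_life((1 - a) * half_life(beta_start) + a * half_life(beta_final))


# Halfway through warmup, ramping from beta_fast = 0.9 toward beta_slow_final = 0.9999.
beta_mid = half_life_beta_warmup(500, 0.9, 0.9999, 1000)
```

Interpolating half-lives rather than betas makes the averaging window grow linearly: halfway through warmup, beta is already about 0.9998, whereas a linear ramp in beta would give roughly 0.95.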

Returns:

The AdEMAMix update.

Return type:

Tensor

emerging_optimizers.scalar_optimizers.update_functions.calculate_laprop_update(
grad,
exp_avg,
exp_avg_sq,
*,
betas,
eps,
correct_bias,
step,
)

Performs the LaProp/normalized SGD with momentum update.

LaProp can be seen as RMSProp with a momentum term, or as normalized SGD with momentum. Based on Z-T-WANG/LaProp-Optimizer and https://arxiv.org/abs/2002.04839.

The update rule is as follows:

\[v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\ \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\ g'_t = \frac{g_t}{\sqrt{\hat{v}_t} + \epsilon} \\ m_t = \beta_1 m_{t-1} + (1 - \beta_1) g'_t \\ \hat{m}_t = \frac{m_t}{1 - \beta_1^t} \\ \text{update} = \hat{m}_t \]
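A scalar sketch (hypothetical helper, not the library's code) that makes the ordering difference from Adam explicit: the gradient is normalized by the bias-corrected second moment first, and momentum is then accumulated on the normalized gradient.

```python
import math


def laprop_update_scalar(g, m, v, *, beta1, beta2, eps, step, correct_bias=True):
    """One LaProp step on a single scalar element (hypothetical helper).

    Returns (update, new_m, new_v).
    """
    v = beta2 * v + (1 - beta2) * g * g
    v_hat = v / (1 - beta2**step) if correct_bias else v
    g_normalized = g / (math.sqrt(v_hat) + eps)     # normalize before momentum
    m = beta1 * m + (1 - beta1) * g_normalized      # momentum on normalized gradient
    m_hat = m / (1 - beta1**step) if correct_bias else m
    return m_hat, m, v


u1, m, v = laprop_update_scalar(0.5, 0.0, 0.0, beta1=0.9, beta2=0.999, eps=1e-8, step=1)
u2, m, v = laprop_update_scalar(0.5, m, v, beta1=0.9, beta2=0.999, eps=1e-8, step=2)
```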
Parameters:
  • grad (Tensor) – The gradient tensor.

  • exp_avg (Tensor) – The accumulated first moment of the gradient (modified in place).

  • exp_avg_sq (Tensor) – The accumulated second moment of the gradient (modified in place).

  • betas (tuple[float, float]) – The EMA beta coefficients (beta1, beta2).

  • eps (float) – Epsilon for the second-moment denominator.

  • correct_bias (bool) – Whether to apply Adam-style bias correction.

  • step (int) – Current optimizer step (1-based), used for bias correction.

Returns:

The LaProp update.

Return type:

Tensor

emerging_optimizers.scalar_optimizers.update_functions.calculate_lion_update(grad, exp_avg, *, betas)

Performs the Lion update.

This function computes one step of Lion.

The update rule is as follows:

\[\text{update} = \text{sign}(\beta_1 m_{t-1} + (1 - \beta_1) g_t) \\ m_t = \beta_2 m_{t-1} + (1 - \beta_2) g_t \]
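A scalar sketch with a hypothetical helper. The point to notice is the ordering in the rule above: the sign is taken of an interpolation that uses the previous momentum m_{t-1}, and the momentum buffer is updated with beta2 only afterwards.

```python
def lion_update_scalar(g, m, *, beta1, beta2):
    """One Lion step on a single scalar element (hypothetical helper).

    Returns (update, new_m). The sign is taken before the momentum is updated.
    """
    interp = beta1 * m + (1 - beta1) * g          # uses m_{t-1}
    update = float((interp > 0) - (interp < 0))   # sign(), with sign(0) == 0
    m = beta2 * m + (1 - beta2) * g               # momentum EMA uses beta2
    return update, m


# Momentum still dominates a small opposing gradient, so the update stays +1.
u, m = lion_update_scalar(-0.5, 1.0, beta1=0.9, beta2=0.99)
```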
Parameters:
  • grad (Tensor) – The gradient tensor.

  • exp_avg (Tensor) – The accumulated first moment of the gradient (modified in place).

  • betas (tuple[float, float]) – The EMA beta coefficients (beta1, beta2). beta1 controls the sign-update interpolation; beta2 controls the momentum EMA.

Returns:

The Lion update.

Return type:

Tensor

emerging_optimizers.scalar_optimizers.update_functions.calculate_madam_update(
grad,
exp_avg,
exp_avg_sq_scaled,
*,
betas,
correct_bias,
step,
scale_log2,
)

Performs the magnitude-aware Adam (MAdam) update.

Vanilla Adam adds an eps to sqrt(v_hat) so the denominator never goes to zero. The cost is that when the natural scale of sqrt(v_hat) approaches eps, eps quietly reshapes the update (see docs/primer/epsilon.md). MAdam removes eps entirely and prevents division by zero in two other ways:

  1. Magnitude-aware storage. The second moment is stored in scaled form v'_t = s * EMA(g_t^2) = EMA((sqrt(s) * g_t)^2), with s = 2 ** scale_log2. Multiplying g by sqrt(s) before squaring keeps the squares of tiny gradients above the fp32 underflow boundary, so v'_t stays non-zero whenever any prior gradient was non-zero. scale_log2 is constrained to even integers so that sqrt(s) = 2 ** (scale_log2 // 2) is itself an exact power of two; the pre-square multiplication is then a bare exponent shift with no rounding.

  2. All-zero mask from exp_avg. Parameters whose gradient has been exactly zero from step 1 onward have exp_avg == 0 (the EMA started at zero and never received non-zero input). For those entries both the numerator and the denominator are zero; we mask the update to zero rather than dividing.
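To see the storage trick concretely, the snippet below emulates fp32 rounding by round-tripping Python floats through struct's "f" format; this emulation is an illustration choice, not how the library works. A gradient entry of 1e-25 squares to about 1e-50, below the smallest fp32 subnormal (about 1.4e-45), while scaling by sqrt(s) = 2 ** 32 first keeps the square representable, and the scaling itself loses nothing.

```python
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value; tiny values flush to 0.0."""
    return struct.unpack("f", struct.pack("f", x))[0]


g = to_fp32(1e-25)                    # a tiny but non-zero fp32 gradient entry
scale_log2 = 64                       # even, so sqrt(s) = 2 ** 32 exactly
sqrt_s = 2.0 ** (scale_log2 // 2)

plain_sq = to_fp32(g * g)             # ~1e-50 underflows in fp32 and becomes 0.0
scaled = to_fp32(sqrt_s * g)          # pure exponent shift, no rounding
scaled_sq = to_fp32(scaled * scaled)  # ~1.8e-31, comfortably representable

print(plain_sq, scaled_sq > 0.0, scaled / sqrt_s == g)
```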

The update rule is:

\[v'_t = \beta_2 v'_{t-1} + (1 - \beta_2) \left( \sqrt{s}\, g_t \right)^2 \\ m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ \hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}'_t = \frac{v'_t}{1 - \beta_2^t} \\ \text{update} = \begin{cases} \dfrac{\sqrt{s}\, \hat{m}_t}{\sqrt{\hat{v}'_t}} & m_t \ne 0 \\ 0 & m_t = 0 \end{cases} \]

Note

The mask uses exp_avg rather than grad so a parameter whose gradient is zero on the current step but was non-zero earlier still receives a momentum-driven update.

Note

If a non-zero gradient is so small that even (sqrt(s) * g)^2 underflows to zero, exp_avg_sq_scaled can become zero while exp_avg is non-zero. The update will then be inf / nan for those entries; it is the caller’s job to pick a scale_log2 large enough for the model’s gradient regime.
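A scalar sketch of the full MAdam step, again emulating fp32 state storage with struct. The helper is hypothetical and is not the library's implementation; it shows the all-zero mask and the momentum-driven update on a step whose own gradient is zero, as described in the notes above.

```python
import math
import struct


def to_fp32(x: float) -> float:
    """Round to the nearest fp32 value (stands in for fp32 tensor storage)."""
    return struct.unpack("f", struct.pack("f", x))[0]


def madam_update_scalar(g, m, v_scaled, *, beta1, beta2, step, scale_log2, correct_bias=True):
    """One MAdam step on a single scalar element (hypothetical helper).

    Returns (update, new_m, new_v_scaled). States are rounded to fp32 after each EMA.
    """
    if scale_log2 % 2 != 0:
        raise ValueError("scale_log2 should be an even integer")
    sqrt_s = 2.0 ** (scale_log2 // 2)
    scaled_g = to_fp32(sqrt_s * g)
    v_scaled = to_fp32(beta2 * v_scaled + (1 - beta2) * to_fp32(scaled_g * scaled_g))
    m = to_fp32(beta1 * m + (1 - beta1) * g)
    if m == 0.0:
        # All-zero mask: numerator and denominator are both zero, so emit no update.
        return 0.0, m, v_scaled
    m_hat = m / (1 - beta1**step) if correct_bias else m
    v_hat = v_scaled / (1 - beta2**step) if correct_bias else v_scaled
    # No eps. If v_hat underflowed to 0 while m != 0, Python raises ZeroDivisionError
    # here, where fp32 tensor division would produce inf instead.
    return sqrt_s * m_hat / math.sqrt(v_hat), m, v_scaled


hp = dict(beta1=0.9, beta2=0.999, scale_log2=64)
u0, _, _ = madam_update_scalar(0.0, 0.0, 0.0, step=1, **hp)      # never-touched entry
u1, m1, v1 = madam_update_scalar(1e-25, 0.0, 0.0, step=1, **hp)  # tiny gradient
u2, _, _ = madam_update_scalar(0.0, m1, v1, step=2, **hp)        # zero grad, live momentum
```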

Parameters:
  • grad (Tensor) – The gradient tensor.

  • exp_avg (Tensor) – The accumulated first moment of the gradient (modified in place).

  • exp_avg_sq_scaled (Tensor) – The accumulated scaled second moment, storing s * EMA(g_t^2) (modified in place). Allocate as zeros_like(p); the caller is responsible for using the same scale across steps.

  • betas (tuple[float, float]) – The EMA beta coefficients (beta1, beta2).

  • correct_bias (bool) – Whether to apply Adam-style bias correction.

  • step (int) – Current optimizer step (1-based), used for bias correction.

  • scale_log2 (float) – log2 of the magnitude scaling factor s = 2 ** scale_log2 used for the second-moment storage. When it is an even integer, sqrt(s) = 2 ** (scale_log2 // 2) is exactly representable in floating point.

Returns:

The MAdam update.

Return type:

Tensor

emerging_optimizers.scalar_optimizers.update_functions.calculate_signum_update(
grad,
exp_avg,
*,
momentum,
correct_bias,
nesterov,
step,
use_shape_scaling=False,
)

Performs the sign-SGD or Signum update.

This function computes one step of sign-SGD or Signum. Based on https://arxiv.org/abs/1802.04434. When using sign-SGD with shape scaling, the general recommendation is to set \(lr = \text{adam lr} \cdot \text{network width} \cdot \frac{2}{\text{rows} + \text{cols}}\). This enables learning-rate transfer under width scaling (https://arxiv.org/abs/2506.07254v1).
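Evaluating the recommended learning-rate formula for two layer shapes may help. The helper is hypothetical and assumes network_width is the width the model is being scaled by. For a square layer whose side equals the network width, the formula reduces to the Adam learning rate.

```python
import math


def signum_shape_scaled_lr(adam_lr, network_width, rows, cols):
    """lr = adam_lr * network_width * 2 / (rows + cols) (hypothetical helper)."""
    return adam_lr * network_width * 2 / (rows + cols)


# Square layer whose side equals the network width: lr equals adam_lr.
lr_square = signum_shape_scaled_lr(1e-3, 1024, 1024, 1024)
# Rectangular 1024 x 4096 layer in the same network: a smaller lr.
lr_wide = signum_shape_scaled_lr(1e-3, 1024, 1024, 4096)
print(math.isclose(lr_square, 1e-3), lr_wide)
```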

The update rule is as follows:

\[m_t = \beta m_{t-1} + (1 - \beta) g_t \\ \hat{m}_t = \frac{m_t}{1 - \beta^t} \\ \text{update} = \text{sign}(\hat{m}_t) \]
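A scalar sketch of the Signum rule with a hypothetical helper; Nesterov momentum is omitted because the formula does not show how it is applied, and the optional shape argument is an illustrative stand-in for use_shape_scaling. Because 1 - beta^t > 0 for 0 <= beta < 1 and t >= 1, bias correction never flips the sign, so correct_bias does not change this sketch's output.

```python
def signum_update_scalar(g, m, *, momentum, step, correct_bias=True, shape=None):
    """One Signum step on a single scalar element (hypothetical helper, no Nesterov).

    shape=(rows, cols) scales the update by 2 / (rows + cols). Returns (update, new_m).
    """
    m = momentum * m + (1 - momentum) * g
    m_hat = m / (1 - momentum**step) if correct_bias else m
    update = float((m_hat > 0) - (m_hat < 0))   # sign(m_hat), with sign(0) == 0
    if shape is not None:
        rows, cols = shape
        update *= 2 / (rows + cols)
    return update, m


u, m = signum_update_scalar(-0.3, 0.0, momentum=0.9, step=1)
u_scaled, _ = signum_update_scalar(-0.3, 0.0, momentum=0.9, step=1, shape=(4, 4))
```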
Parameters:
  • grad (Tensor) – The gradient tensor.

  • exp_avg (Tensor) – The accumulated first moment of the gradient (modified in place).

  • momentum (float) – The EMA decay coefficient for the momentum buffer (single scalar).

  • correct_bias (bool) – Whether to apply Adam-style bias correction.

  • nesterov (bool) – Whether to use Nesterov momentum.

  • step (int) – Current optimizer step (1-based), used for bias correction.

  • use_shape_scaling (bool) – Whether to scale the update by 2 / (rows + cols) for a (rows, cols) tensor.

Returns:

The sign-SGD/Signum update.

Return type:

Tensor

emerging_optimizers.scalar_optimizers.update_functions.calculate_sim_ademamix_update(
grad,
exp_avg,
exp_avg_sq,
*,
betas,
eps,
correct_bias,
step,
num_beta_fast_warmup_steps,
min_beta_fast,
alpha=2,
)

Performs the simplified AdEMAMix update.

This function computes one step of simplified AdEMAMix. Based on DepenM/Simplified-AdEMAMix and https://arxiv.org/abs/2409.03137.

The update rule is as follows:

\[m_t = \beta_{\text{fast}} m_{t-1} + g_t \\ v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2 \\ \hat{m}_t = \frac{m_t}{(1 - \beta_{\text{fast}}^t) / (1 - \beta_{\text{fast}})} \\ \hat{v}_t = \frac{v_t}{1 - \beta_2^t} \\ \text{update} = \frac{\alpha g_t + \hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} \]
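A scalar sketch with beta_fast held constant; the warmup from min_beta_fast is omitted and the helper is hypothetical. The bias-correction denominator (1 - beta^t) / (1 - beta) equals 1 + beta + ... + beta^(t-1), the total weight the undamped accumulator has collected, so a constant gradient g gives m_hat = g (up to rounding) and, for g > 0, an update of about alpha + 1 at every step.

```python
import math


def sim_ademamix_update_scalar(g, m, v, *, beta_fast, beta2, alpha, eps, step,
                               correct_bias=True):
    """One simplified-AdEMAMix step on a scalar, beta_fast held constant (hypothetical).

    Returns (update, new_m, new_v).
    """
    m = beta_fast * m + g                    # accumulator: no (1 - beta_fast) factor
    v = beta2 * v + (1 - beta2) * g * g
    if correct_bias:
        # (1 - beta^t) / (1 - beta) == 1 + beta + ... + beta^(t-1)
        m_hat = m / ((1 - beta_fast**step) / (1 - beta_fast))
        v_hat = v / (1 - beta2**step)
    else:
        m_hat, v_hat = m, v
    update = (alpha * g + m_hat) / (math.sqrt(v_hat) + eps)
    return update, m, v


# Constant gradient for five steps: m_hat tracks g, so the update stays near alpha + 1.
m = v = 0.0
for t in range(1, 6):
    u, m, v = sim_ademamix_update_scalar(0.5, m, v, beta_fast=0.9, beta2=0.999,
                                         alpha=2.0, eps=1e-8, step=t)
```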
Parameters:
  • grad (Tensor) – The gradient tensor.

  • exp_avg (Tensor) – The accumulated first moment of the gradient (modified in place).

  • exp_avg_sq (Tensor) – The accumulated second moment of the gradient (modified in place).

  • betas (tuple[float, float]) – The EMA beta coefficients (beta_fast_final, beta2).

  • eps (float) – Epsilon for the second-moment denominator.

  • correct_bias (bool) – Whether to apply Adam-style bias correction.

  • step (int) – Current optimizer step (1-based), used for bias correction and the beta_fast schedule.

  • num_beta_fast_warmup_steps (int | None) – Number of warmup steps used to ramp beta_fast from min_beta_fast toward beta_fast_final. None disables the schedule.

  • min_beta_fast (float) – Initial beta_fast value used at the start of the warmup schedule.

  • alpha (float) – Coefficient for mixing the current gradient into the update.

Returns:

The simplified-AdEMAMix update.

Return type:

Tensor