Source code for emerging_optimizers.scalar_optimizers.ademamix

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple

import torch


__all__ = [
    "calculate_sim_ademamix_update",
    "calculate_ademamix_update",
]


[docs] @torch.compile # type: ignore[misc] @torch.no_grad() # type: ignore[misc] def calculate_sim_ademamix_update( grad: torch.Tensor, exp_avg: torch.Tensor, exp_avg_sq: torch.Tensor, num_beta_fast_warmup_steps: Optional[int], min_beta_fast: float, betas: Tuple[float, float], step: int, eps: float, correct_bias: bool, alpha: float = 2, ) -> torch.Tensor: """Performs simplified AdEMAMix update. This function performs the computation of 1 step of simplified AdEMAMix. Based on https://github.com/DepenM/Simplified-AdEMAMix/blob/main/simplified_AdEMAMix.py and https://arxiv.org/abs/2409.03137. The update rule is as follows: .. math:: m_t = \\beta_{\\text{fast}} m_{t-1} + g_t \\\\ v_t = \\beta_2 v_{t-1} + (1 - \\beta_2) g_t^2 \\\\ \\hat{m}_t = \\frac{m_t}{(1 - \\beta_{\\text{fast}}^t) / (1 - \\beta_{\\text{fast}})} \\\\ \\hat{v}_t = \\frac{v_t}{1 - \\beta_2^t} \\\\ \\text{update} = \\frac{\\alpha g_t + \\hat{m}_t}{\\sqrt{\\hat{v}_t} + \\epsilon} Args: grad: The gradient tensor. exp_avg: The accumulated first moment of the gradient. exp_avg_sq: The accumulated second moment of the gradient. num_beta_fast_warmup_steps: Number of warmup steps used to increase beta_fast min_beta_fast: The minimum beta_fast value used at initialization betas: The EMA beta coefficients for the Adam update. step: The current step of the optimizer, used to compute the bias correction terms. eps: The epsilon for the Adam second moment update. correct_bias: Whether to correct the bias of the AdEMAMix update. alpha: Coeficient for mixing the current gradient and EMA. Returns: The simplified-AdEMAMix update. """ beta_fast_final, beta2 = betas # Compute beta_fast based on scheduler if num_beta_fast_warmup_steps is not None: beta_fast = _linear_half_life_warmup_scheduler( step, beta_end=beta_fast_final, beta_start=min_beta_fast, num_warmup_steps=num_beta_fast_warmup_steps ) else: beta_fast = beta_fast_final # Decay the first moment "theory style": https://arxiv.org/abs/2502.02431 exp_avg.mul_(beta_fast).add_(grad, alpha=1.0) # Decay the second moment exponential moving average exp_avg_sq.lerp_(grad.square(), 1 - beta2) if correct_bias: # theory style bias correction bias_correction1 = (1 - beta_fast**step) / (1 - beta_fast) bias_correction2 = 1 - beta2**step else: bias_correction1 = 1 bias_correction2 = 1 # step size correction for optimizer states EMA momentum = exp_avg / bias_correction1 adam_second_moment = exp_avg_sq / bias_correction2 adam_second_moment = adam_second_moment.sqrt() + eps return (alpha * grad + momentum) / adam_second_moment
[docs] @torch.compile # type: ignore[misc] @torch.no_grad() # type: ignore[misc] def calculate_ademamix_update( grad: torch.Tensor, exp_avg_fast: torch.Tensor, exp_avg_slow: torch.Tensor, exp_avg_sq: torch.Tensor, num_beta_slow_warmup_steps: Optional[int], num_alpha_warmup_steps: Optional[int], betas: Tuple[float, float, float], step: int, eps: float, correct_bias: bool, alpha: float = 2, ) -> torch.Tensor: """Performs AdEMAMix update. This function performs the computation of 1 step of AdEMAMix. Based on https://github.com/apple/ml-ademamix/blob/main/pytorch/ademamix.py and https://arxiv.org/abs/2409.03137. The update rule is as follows: .. math:: m_t^{\\text{fast}} = \\beta_{\\text{fast}} m_{t-1}^{\\text{fast}} + (1 - \\beta_{\\text{fast}}) g_t \\\\ m_t^{\\text{slow}} = \\beta_{\\text{slow}} m_{t-1}^{\\text{slow}} + (1 - \\beta_{\\text{slow}}) g_t \\\\ v_t = \\beta_2 v_{t-1} + (1 - \\beta_2) g_t^2 \\\\ \\hat{m}_t^{\\text{fast}} = \\frac{m_t^{\\text{fast}}}{1 - \\beta_{\\text{fast}}^t} \\\\ \\hat{v}_t = \\frac{v_t}{1 - \\beta_2^t} \\\\ \\text{update} = \\frac{\\hat{m}_t^{\\text{fast}} + \\alpha m_t^{\\text{slow}}}{\\sqrt{\\hat{v}_t} + \\epsilon} Args: grad: The gradient tensor. exp_avg_fast: The accumulated first moment of the gradient with fast time constant. exp_avg_slow: The accumulated first moment of the gradient with slow time constant. exp_avg_sq: The accumulated second moment of the gradient. num_beta_slow_warmup_steps: Number of warmup steps used to increase beta_slow num_alpha_warmup_steps: Number of warmup steps used to increase alpha betas: The EMA beta coefficients for the Adam update. step: The current step of the optimizer, used to compute the bias correction terms. eps: The epsilon for the Adam second moment update. correct_bias: Whether to correct the bias of the AdEMAMix update. alpha: Coeficient for mixing the current gradient and EMA, the final value to use in case of scheduling. Returns: The AdEMAMix update. """ beta_fast, beta2, beta_slow_final = betas if num_alpha_warmup_steps is not None: alpha = _linear_warmup_scheduler(step, alpha_end=alpha, alpha_start=0, num_warmup_steps=num_alpha_warmup_steps) else: alpha = alpha # Compute beta_slow based on scheduler with half-life linear warmup # beta_start is usually set to beta_fast if num_beta_slow_warmup_steps is not None: beta_slow = _linear_half_life_warmup_scheduler( step, beta_end=beta_slow_final, beta_start=beta_fast, num_warmup_steps=num_beta_slow_warmup_steps ) else: beta_slow = beta_slow_final if correct_bias: bias_correction1 = 1 - beta_fast**step bias_correction2 = 1 - beta2**step else: bias_correction1 = 1 bias_correction2 = 1 # Decay the fast first moment, slow first moment and second moment with an exponential moving average if beta_fast != 0.0: exp_avg_fast.lerp_(grad, 1 - beta_fast) else: exp_avg_fast = grad exp_avg_slow.lerp_(grad, 1 - beta_slow) exp_avg_sq.lerp_(grad.square(), 1 - beta2) # Correct biases of fast moment and adam second moment, slow moment is not corrected fast_moment = exp_avg_fast / bias_correction1 adam_second_moment = exp_avg_sq / bias_correction2 adam_second_moment = adam_second_moment.sqrt() + eps return (fast_moment + alpha * exp_avg_slow) / adam_second_moment
def _half_life_steps(beta: float, eps: float = 1e-8) -> float: """Function that maps beta to the number of steps to reach 0.5. Equation: f(beta) = log(0.5) / log(beta + eps) - 1 Args: beta: The beta parameter. eps: A small constant to avoid division by zero. Returns: The number of steps to reach 0.5. """ return math.log(0.5) / math.log(beta + eps) - 1 def _inverse_half_life_beta(t: float) -> float: """Maps number of steps to reach 0.5 to beta. Equation: f_inv(t) = 0.5^(1 / (t + 1)) Args: t: The number of steps to reach 0.5. Returns: The beta parameter. """ return math.pow(0.5, 1 / (t + 1)) def _linear_half_life_warmup_scheduler( step: int, beta_end: float, beta_start: float = 0, num_warmup_steps: int = 1 ) -> float: """Half-life linear warmup scheduler for the beta parameter. Equation: beta = f_inv((1 - step / num_warmup_steps) * f(beta_start) + (step / num_warmup_steps) * f(beta_end)) Args: step: The current step of the optimizer. beta_end: The final value of the beta parameter. beta_start: The initial value of the beta parameter. num_warmup_steps: The number of warmup steps. Returns: The value of the beta parameter at the current step. """ if step < num_warmup_steps: a = step / float(num_warmup_steps) return _inverse_half_life_beta((1.0 - a) * _half_life_steps(beta_start) + a * _half_life_steps(beta_end)) return beta_end def _linear_warmup_scheduler(step: int, alpha_end: float, alpha_start: float = 0, num_warmup_steps: int = 1) -> float: """Linear warmup scheduler for the alpha parameter. Equation: alpha = (1 - step / num_warmup_steps) * alpha_start + (step / num_warmup_steps) * alpha_end Args: step: The current step of the optimizer. alpha_end: The final value of the alpha parameter. alpha_start: The initial value of the alpha parameter. num_warmup_steps: The number of warmup steps. Returns: The value of the alpha parameter at the current step. """ if step < num_warmup_steps: a = step / float(num_warmup_steps) return (1.0 - a) * alpha_start + a * alpha_end return alpha_end