Source code for emerging_optimizers.scalar_optimizers.update_functions.madam

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


__all__ = [
    "calculate_madam_update",
]


@torch.compile  # type: ignore[misc]
@torch.no_grad()  # type: ignore[misc]
def calculate_madam_update(
    grad: torch.Tensor,
    exp_avg: torch.Tensor,
    exp_avg_sq_scaled: torch.Tensor,
    *,
    betas: tuple[float, float],
    correct_bias: bool,
    step: int,
    scale_log2: float,
) -> torch.Tensor:
    """Performs the magnitude-aware Adam (MAdam) update.

    Vanilla Adam adds an ``eps`` to ``sqrt(v_hat)`` so that the denominator never reaches
    zero. The cost is that when the natural scale of ``sqrt(v_hat)`` approaches ``eps``,
    ``eps`` silently reshapes the update (see ``docs/primer/epsilon.md``). MAdam removes
    ``eps`` entirely and prevents division by zero in two other ways:

    1. **Magnitude-aware storage.** The second moment is stored in scaled form
       ``v'_t = s * EMA(g_t^2) = EMA((sqrt(s) * g_t)^2)``, with ``s = 2 ** scale_log2``.
       Multiplying ``g`` by ``sqrt(s)`` *before* squaring keeps the squares of tiny
       gradients above the fp32 underflow boundary, so ``v'_t`` stays non-zero whenever
       any prior gradient was non-zero (provided ``scale_log2`` is large enough; see the
       second note below). ``scale_log2`` is constrained to even integers so that
       ``sqrt(s) = 2 ** (scale_log2 // 2)`` is itself an exact power of two; the
       pre-square multiplication is then a pure exponent shift with no rounding.
    2. **All-zero mask from ``exp_avg``.** Parameters whose gradient has been exactly
       zero from step 1 onward have ``exp_avg == 0`` (the EMA started at zero and never
       received a non-zero input). For those entries both the numerator and the
       denominator are zero, so the update is masked to zero instead of being computed
       as ``0 / 0``.

    The update rule (with both bias-correction denominators set to 1 when
    ``correct_bias`` is ``False``) is:

    .. math::

        v'_t = \\beta_2 v'_{t-1} + (1 - \\beta_2) \\left( \\sqrt{s}\\, g_t \\right)^2 \\\\
        m_t = \\beta_1 m_{t-1} + (1 - \\beta_1) g_t \\\\
        \\hat{m}_t = \\frac{m_t}{1 - \\beta_1^t}, \\quad
        \\hat{v}'_t = \\frac{v'_t}{1 - \\beta_2^t} \\\\
        \\text{update} = \\begin{cases}
            \\dfrac{\\sqrt{s}\\, \\hat{m}_t}{\\sqrt{\\hat{v}'_t}} & m_t \\ne 0 \\\\
            0 & m_t = 0
        \\end{cases}

    Because ``sqrt(v'_t) = sqrt(s) * sqrt(v_t)``, the ``sqrt(s)`` factors cancel, and the
    unmasked update equals Adam's ``m_hat / sqrt(v_hat)`` with ``eps = 0``.

    Note:
        The mask uses ``exp_avg`` rather than ``grad``, so a parameter whose gradient is
        zero on the current step but was non-zero earlier still receives a
        momentum-driven update.

    Note:
        If a non-zero gradient is so small that even ``(sqrt(s) * g) ** 2`` underflows,
        ``exp_avg_sq_scaled`` can be zero while ``exp_avg`` is non-zero. The update is
        then ``inf`` / ``nan`` for those entries, so it is the caller's responsibility to
        pick a ``scale_log2`` large enough for the model's gradient regime.

    Args:
        grad: The gradient tensor.
        exp_avg: The accumulated first moment of the gradient (modified in place).
        exp_avg_sq_scaled: The accumulated **scaled** second moment, storing
            ``s * EMA(g_t^2)`` (modified in place). Allocate it as ``zeros_like(p)``; the
            caller is responsible for using the same ``scale_log2`` across steps.
        betas: The EMA beta coefficients ``(beta1, beta2)``.
        correct_bias: Whether to apply Adam-style bias correction.
        step: Current optimizer step (1-based), used for bias correction.
        scale_log2: ``log2`` of the magnitude scaling factor ``s = 2 ** scale_log2`` used
            for the second-moment storage. Must be an even integer (checked by an
            assertion) so that ``sqrt(s) = 2 ** (scale_log2 // 2)`` is exactly
            representable in floating point.

    Returns:
        The MAdam update.
    """
    beta1, beta2 = betas
    assert scale_log2 // 2 == scale_log2 / 2, "scale_log2 should be an even integer"
    grad_scale = 2.0 ** (scale_log2 // 2)

    # First moment as usual; second moment stored scaled. Multiply before squaring
    # so the squares of small gradients stay above the fp32 underflow boundary.
    exp_avg.lerp_(grad, 1 - beta1)
    exp_avg_sq_scaled.lerp_((grad * grad_scale).square(), 1 - beta2)

    # Bias correction for the zero-initialized EMAs
    bias_correction1 = 1.0
    bias_correction2 = 1.0
    if correct_bias:
        bias_correction1 = 1.0 - beta1**step
        bias_correction2 = 1.0 - beta2**step

    momentum = exp_avg / bias_correction1
    second_moment_scaled = exp_avg_sq_scaled / bias_correction2

    # sqrt(second_moment_scaled) = sqrt(s) * sqrt(v_hat); multiplying by grad_scale
    # (= sqrt(s)) cancels the scale, leaving m_hat / sqrt(v_hat).
    out = momentum / second_moment_scaled.sqrt() * grad_scale

    # Entries whose first moment is exactly zero may otherwise hold 0 / 0 = nan.
    zero_mask = exp_avg == 0
    out.masked_fill_(zero_mask, 0.0)
    return out
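The effect of magnitude-aware storage can be observed directly in fp32. In this sketch the gradient value `1e-25` and `scale_log2 = 64` are arbitrary illustrative assumptions. Squaring the raw value underflows to exactly zero, while multiplying by `sqrt(s) = 2 ** 32` first keeps the square representable, and because the scale is a power of two, scaling and unscaling round-trips bit-for-bit.

```python
import torch

# Hypothetical tiny gradient, chosen only for illustration.
g = torch.tensor([1e-25], dtype=torch.float32)

# Plain squaring: 1e-50 is below the smallest fp32 subnormal (~1.4e-45),
# so the result is exactly zero.
plain_sq = g.square()

# MAdam-style: multiply by sqrt(s) = 2 ** (scale_log2 // 2) *before* squaring.
scale_log2 = 64  # assumed value; must be an even integer
grad_scale = 2.0 ** (scale_log2 // 2)  # 2 ** 32, an exact power of two
scaled_sq = (g * grad_scale).square()  # roughly 1.8e-31, a normal fp32 value

# Scaling by a power of two only shifts the exponent, so undoing it recovers
# the original fp32 value exactly (no overflow or underflow occurs here).
roundtrip = (g * grad_scale) / grad_scale

print(plain_sq.item(), scaled_sq.item(), torch.equal(roundtrip, g))
```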
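The claim that the `sqrt(s)` factors cancel can be checked numerically. The sketch below is an uncompiled re-implementation of the same arithmetic as `calculate_madam_update` (`madam_update_reference` is a hypothetical name, not part of this module), run side by side with plain Adam using an unscaled second moment and `eps = 0`. The seed, shape, step count, betas and `scale_log2 = 20` are arbitrary assumptions, and float64 is used so that only rounding noise could separate the two.

```python
import torch


def madam_update_reference(
    grad: torch.Tensor,
    exp_avg: torch.Tensor,
    exp_avg_sq_scaled: torch.Tensor,
    *,
    betas: tuple[float, float],
    correct_bias: bool,
    step: int,
    scale_log2: float,
) -> torch.Tensor:
    """Hypothetical uncompiled sketch of the arithmetic in calculate_madam_update."""
    beta1, beta2 = betas
    assert scale_log2 // 2 == scale_log2 / 2, "scale_log2 should be an even integer"
    grad_scale = 2.0 ** (scale_log2 // 2)
    exp_avg.lerp_(grad, 1 - beta1)
    exp_avg_sq_scaled.lerp_((grad * grad_scale).square(), 1 - beta2)
    bc1 = 1.0 - beta1**step if correct_bias else 1.0
    bc2 = 1.0 - beta2**step if correct_bias else 1.0
    out = (exp_avg / bc1) / (exp_avg_sq_scaled / bc2).sqrt() * grad_scale
    return out.masked_fill_(exp_avg == 0, 0.0)


torch.manual_seed(0)
betas = (0.9, 0.999)
scale_log2 = 20  # assumed value
shape = (4, 3)

# MAdam state (scaled second moment) and plain-Adam state (unscaled, eps = 0).
m = torch.zeros(shape, dtype=torch.float64)
v_scaled = torch.zeros(shape, dtype=torch.float64)
m_ref = torch.zeros(shape, dtype=torch.float64)
v_ref = torch.zeros(shape, dtype=torch.float64)

matches = []
for step in range(1, 6):
    grad = torch.randn(shape, dtype=torch.float64)
    madam = madam_update_reference(
        grad, m, v_scaled, betas=betas, correct_bias=True, step=step, scale_log2=scale_log2
    )
    m_ref.lerp_(grad, 1 - betas[0])
    v_ref.lerp_(grad.square(), 1 - betas[1])
    adam_no_eps = (m_ref / (1 - betas[0] ** step)) / (v_ref / (1 - betas[1] ** step)).sqrt()
    matches.append(torch.allclose(madam, adam_no_eps, rtol=1e-12, atol=0.0))

print(matches)
```

Random normal gradients are never exactly zero, so the mask does not fire here; the comparison isolates the scaling and cancellation.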
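The masking rule and the first note can be illustrated with two hypothetical scalar parameters, inlining the same arithmetic as `calculate_madam_update` (without `torch.compile`). Entry 0 never receives a gradient; entry 1 receives `1e-3` at step 1 and `0` at step 2. The betas and `scale_log2 = 16` are assumed values. After step 2, masking on `exp_avg` zeroes only the never-touched entry, whereas a mask on the current `grad` (shown only for contrast) would also discard entry 1's momentum.

```python
import torch

betas = (0.9, 0.999)
scale_log2 = 16  # assumed; even, so grad_scale = 2 ** 8 is exact
grad_scale = 2.0 ** (scale_log2 // 2)

exp_avg = torch.zeros(2)
exp_avg_sq_scaled = torch.zeros(2)
# Entry 0: always zero. Entry 1: non-zero at step 1, zero at step 2.
grads = [torch.tensor([0.0, 1e-3]), torch.tensor([0.0, 0.0])]

for step, grad in enumerate(grads, start=1):
    exp_avg.lerp_(grad, 1 - betas[0])
    exp_avg_sq_scaled.lerp_((grad * grad_scale).square(), 1 - betas[1])
    bc1 = 1.0 - betas[0] ** step
    bc2 = 1.0 - betas[1] ** step
    # Entry 0 is 0 / 0 here, i.e. nan, before any masking.
    unmasked = (exp_avg / bc1) / (exp_avg_sq_scaled / bc2).sqrt() * grad_scale

# State after step 2.
masked_by_exp_avg = unmasked.masked_fill(exp_avg == 0, 0.0)  # rule used by MAdam
masked_by_grad = unmasked.masked_fill(grad == 0, 0.0)  # alternative, for contrast

print(unmasked, masked_by_exp_avg, masked_by_grad)
```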