# Source code for emerging_optimizers.orthogonalized_optimizers.adaptive_muon

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Callable, Literal, override


if TYPE_CHECKING:
    from typing import overload

import torch
from absl import logging
from torch.optim.optimizer import ParamsT

from emerging_optimizers import mixin as opt_mixin
from emerging_optimizers import registry, utils
from emerging_optimizers.orthogonalized_optimizers import muon, muon_utils
from emerging_optimizers.orthogonalized_optimizers.muon_utils import NSCoeffT
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer
from emerging_optimizers.utils import FP32MatmulPrecT


# Names of the supported second-moment accumulation variants accepted by ``AdaptiveMuon``.
Moment2MethodT = Literal["adamuon", "normuon", "namo"]

# Public API exported by ``from ... import *``.
__all__ = ["AdaptiveMuon", "Moment2MethodT"]


@registry.register_optimizer("adaptive_muon")
class AdaptiveMuon(OrthogonalizedOptimizer):
    """Adaptive Muon optimizer with adaptive second moment (AdaMuon/NorMuon/NAMO variants).

    This class extends Muon by adding adaptive second moment accumulation after raw orthogonalization and
    before Muon's update scaling. This idea was first explored in D.E. Carlson, E. Collins, Ya-Ping Hsieh,
    L. Carin, and V. Cevher. *Preconditioned spectral descent for deep learning.* In Advances in neural
    information processing systems 28 (2015).

    The step() method is overridden to include second moment normalization logic.

    Args:
        params: Iterable of parameters to optimize or dicts defining parameter groups.
        lr: Learning rate.
        momentum: The exponential decay rate for momentum.
        weight_decay: Weight decay coefficient.
        nesterov: Whether to use Nesterov momentum.
        weight_decay_method: The weight decay method to use.
        fp32_matmul_prec: Precision for FP32 matrix multiplication.
        coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration.
        num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration. Must be >= 1.
        scale_mode: The type of scale factor to use for the update.
        extra_scale_factor: The additional scale factor to use for the update.
        moment2_method: Method for second moment accumulation ("adamuon", "normuon", or "namo").

            - "adamuon": Full elementwise second moment (like AdamW).
            - "normuon": Row or column-wise second moment.
            - "namo": Scalar adaptive scaling via Frobenius-norm ratio.

        beta2: The exponential decay rate for second moment. Must be in ``[0, 1)``.
        eps: Small constant for numerical stability. Must be non-negative.

    Raises:
        ValueError: If ``moment2_method`` is not a supported variant, if ``moment2_method="namo"`` is combined
            with ``weight_decay_method="l2"``, if ``num_ns_steps < 1``, if ``beta2`` is outside ``[0, 1)``,
            or if ``eps`` is negative.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float,
        momentum: float,
        weight_decay: float,
        *,
        nesterov: bool = False,
        weight_decay_method: opt_mixin.WeightDecayT = "decoupled",
        fp32_matmul_prec: FP32MatmulPrecT = "medium",
        coefficient_type: NSCoeffT = "quintic",
        num_ns_steps: int = 5,
        scale_mode: muon.MuonScaleT = "spectral",
        extra_scale_factor: float = 1.0,
        moment2_method: Moment2MethodT = "adamuon",
        beta2: float = 0.95,
        eps: float = 1e-8,
    ):
        # Fail fast: the Literal annotation is not enforced at runtime, and an unknown method would otherwise
        # only surface on the first step().
        if moment2_method not in ("adamuon", "normuon", "namo"):
            raise ValueError(f"Invalid second moment method: {moment2_method}")
        if moment2_method == "namo" and weight_decay_method == "l2":
            raise ValueError('moment2_method="namo" is incompatible with weight_decay_method="l2"')
        if num_ns_steps < 1:
            raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")
        # beta2 == 1 would freeze the second moment at zero and blow updates up by ~1/eps.
        if not 0.0 <= beta2 < 1.0:
            raise ValueError(f"beta2 must be in [0, 1), got {beta2}")
        if not 0.0 <= eps:
            raise ValueError(f"eps must be non-negative, got {eps}")

        self.scale_mode = scale_mode
        self.extra_scale_factor = extra_scale_factor

        def raw_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
            # Adaptive variants normalize the raw Newton-Schulz output first; Muon scale is applied after moment2.
            logging.debug(f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient")
            return muon_utils.newton_schulz(
                grad,
                steps=num_ns_steps,
                coefficient_type=coefficient_type,
                use_syrk=False,
            )

        super().__init__(
            params,
            lr,
            momentum,
            weight_decay,
            nesterov=nesterov,
            weight_decay_method=weight_decay_method,
            fp32_matmul_prec=fp32_matmul_prec,
            scaled_orthogonalize_fn=raw_orthogonalize_fn,
        )

        self.moment2_method = moment2_method
        # Register beta2/eps as optimizer defaults so that param groups added later through
        # add_param_group() also receive them (torch fills missing group keys from self.defaults);
        # otherwise step() would raise KeyError for such groups.
        self.defaults.setdefault("beta2", beta2)
        self.defaults.setdefault("eps", eps)
        # Groups created during super().__init__ predate the defaults above, so fill them explicitly.
        # setdefault keeps any per-group override supplied by the caller.
        for group in self.param_groups:
            group.setdefault("beta2", beta2)
            group.setdefault("eps", eps)

    def _apply_muon_scale(self, update: torch.Tensor, size_out: int, size_in: int) -> torch.Tensor:
        """Multiply ``update`` by Muon's shape-dependent scale factor and ``self.extra_scale_factor``."""
        scale_factor = muon.get_muon_scale_factor(size_out, size_in, mode=self.scale_mode)
        logging.debug(f"Applying Muon scale factor {scale_factor}, extra_scale_factor={self.extra_scale_factor}")
        return update * scale_factor * self.extra_scale_factor

    def _match_frobenius_norm(
        self,
        update: torch.Tensor,
        reference: torch.Tensor,
        eps: float,
    ) -> torch.Tensor:
        """Scale update to match the Frobenius norm of reference.

        The norm of ``update`` is clamped from below by ``eps`` to avoid division by zero.
        """
        update_norm = torch.linalg.vector_norm(update)
        reference_norm = torch.linalg.vector_norm(reference)
        return update * (reference_norm / update_norm.clamp_min(eps))

    @torch.no_grad()  # type: ignore[misc]
    @override
    def _init_group(
        self,
        group: dict,
        skip_non_grad_params: bool = True,
    ) -> None:
        """Performs lazy state initialization for parameters.

        Extends the base class to also initialize the second moment buffer.
        The shape of the moment2 buffer depends on the moment2_method:

        - "adamuon": Full elementwise buffer with same shape as parameter
        - "normuon": Reduced shape buffer (averaged along -1 if shape[-2] >= shape[-1], else -2)
        - "namo": Scalar buffer (EMA of squared Frobenius norm of gradient)

        Args:
            group: Parameter group dictionary.
            skip_non_grad_params: If True, skip parameters without gradients.

        Raises:
            ValueError: For "normuon", if a parameter is not 2D.
            TypeError: If ``self.moment2_method`` is not a supported variant.
        """
        for p in group["params"]:
            if skip_non_grad_params and p.grad is None:
                continue

            state = self.state[p]
            # State is created only once per parameter.
            if len(state) == 0:
                state["momentum_buffer"] = torch.zeros_like(p.data)

                if self.moment2_method == "adamuon":
                    # Full elementwise second moment
                    state["moment2_buffer"] = torch.zeros_like(p.data)
                elif self.moment2_method == "normuon":
                    # Row/column-wise second moment - reduced along one dimension
                    if p.data.ndim != 2:
                        raise ValueError(
                            f"{self.__class__.__name__} only supports 2D parameters, got shape {tuple(p.data.shape)}"
                        )
                    avg_dim = -1 if p.data.shape[-2] >= p.data.shape[-1] else -2
                    moment2_shape = list(p.data.shape)
                    moment2_shape[avg_dim] = 1
                    state["moment2_buffer"] = torch.zeros(moment2_shape, dtype=p.data.dtype, device=p.data.device)
                elif self.moment2_method == "namo":
                    state["moment2_buffer"] = torch.zeros(1, dtype=p.data.dtype, device=p.data.device)
                else:
                    raise TypeError(f"Invalid second moment method: {self.moment2_method}")

    def _apply_moment2_normalization(
        self,
        orth_grad: torch.Tensor,
        moment2: torch.Tensor,
        beta2: float,
        eps: float,
        *,
        raw_grad: torch.Tensor | None = None,
        pre_orth_grad: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Apply second moment accumulation and normalization.

        This method supports three variants:

        - "adamuon": Full elementwise second moment (like AdamW, https://arxiv.org/abs/2507.11005)
        - "normuon": Row or column-wise second moment (https://arxiv.org/abs/2510.05491)
        - "namo": Scalar adaptive scaling using Frobenius-norm ratio (https://arxiv.org/abs/2602.17080).
          Scales the orthogonalized momentum by

          .. math::
              \\alpha_t = \\frac{\\|g_t^{\\text{pre-orth}}\\|_F}{\\sqrt{v_t} + \\varepsilon}

          where :math:`v_t` is the EMA of :math:`\\|G_t\\|_F^2`.

        For all methods:

        1. Updates the second moment (in place) as an EMA of (some function of) squared gradients
        2. Returns the adaptively scaled gradient

        Args:
            orth_grad: The raw orthogonalized gradient tensor before Muon update scaling.
            moment2: The second moment buffer from state; updated in place.
            beta2: The exponential decay rate for second moment.
            eps: Small constant for numerical stability.
            raw_grad: (NAMO only) The raw gradient before momentum update.
            pre_orth_grad: (NAMO only) The gradient after momentum/Nesterov, before orthogonalization.

        Returns:
            The adaptively scaled weight update tensor.

        Raises:
            RuntimeError: For "namo", if ``raw_grad`` or ``pre_orth_grad`` is missing.
            TypeError: If ``self.moment2_method`` is not a supported variant.
        """
        if self.moment2_method == "adamuon":
            # AdamMuon: Full elementwise second moment like AdamW
            # Update second moment with EMA of squared orthogonalized gradient
            moment2.lerp_(orth_grad.square(), 1 - beta2)

            # AdamW-style division: grad / (sqrt(moment2) + eps)
            denom = moment2.sqrt() + eps
            return orth_grad / denom

        elif self.moment2_method == "normuon":
            # NorMuon: Row or column-wise second moment.
            # Average along the shorter (or equal) dimension so the buffer keeps one statistic per index of
            # the longer dimension: (m, n) with m >= n -> per-row stats of shape (m, 1); otherwise per-column
            # stats of shape (1, n). This must match the buffer shape created in _init_group.
            avg_dim = -1 if orth_grad.shape[-2] >= orth_grad.shape[-1] else -2
            v_mean = orth_grad.square().mean(dim=avg_dim, keepdim=True)

            # Update second moment with EMA
            moment2.lerp_(v_mean, 1 - beta2)

            # NorMuon uses reciprocal square root with clamping
            step_size = moment2.clamp_min(eps).rsqrt_()
            update = orth_grad * step_size
            # Preserve the raw orthogonalized update norm before applying Muon's scale factor.
            return self._match_frobenius_norm(update, orth_grad, eps)

        elif self.moment2_method == "namo":
            if raw_grad is None or pre_orth_grad is None:
                raise RuntimeError("NAMO requires raw_grad and pre_orth_grad")

            # NAMO: Scalar adaptive scaling via Frobenius-norm ratio
            # v_t = β2 * v_{t-1} + (1 - β2) * ||G_t||_F^2
            grad_norm_square = torch.linalg.vector_norm(raw_grad).square()
            moment2.lerp_(grad_norm_square, 1 - beta2)

            # α_t = ||pre_orth_grad||_F / (sqrt(v_t) + ε)
            pre_orth_norm = torch.linalg.vector_norm(pre_orth_grad)
            alpha_t = pre_orth_norm / (moment2.sqrt() + eps)
            return orth_grad * alpha_t

        else:
            raise TypeError(f"Invalid second moment method: {self.moment2_method}")

    if TYPE_CHECKING:

        @overload
        def step(self, closure: None = ...) -> None: ...

        @overload
        def step(self, closure: Callable[[], float]) -> float: ...

    @torch.no_grad()  # type: ignore[misc]
    @override
    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        """Single optimization step.

        Per parameter: weight decay, momentum (optionally Nesterov), raw Newton-Schulz orthogonalization,
        second-moment normalization, Muon scaling, then the in-place weight update.

        Args:
            closure: Unsupported; must be ``None``.

        Returns:
            Always ``None``.

        Raises:
            ValueError: If ``closure`` is given, or if any parameter in a group is not 2D.
        """
        if closure is not None:
            raise ValueError("closure is not supported")

        for group in self.param_groups:
            self._init_group(group)

            for p in group["params"]:
                if p.dim() != 2:
                    raise ValueError(f"{self.__class__.__name__} only supports 2D parameters")
                grad = p.grad
                if grad is None:
                    continue
                state = self.state[p]

                exp_avg = state["momentum_buffer"]

                self._apply_weight_decay_inplace(
                    p,
                    grad,
                    group["lr"],
                    group["weight_decay"],
                )

                # NAMO needs the gradient as it was before momentum is mixed in.
                raw_grad = grad if self.moment2_method == "namo" else None

                # update momentum buffer with EMA of gradient
                exp_avg.lerp_(grad, 1 - group["momentum"])

                if self.nesterov:
                    # Out-of-place lerp: p.grad (and raw_grad) stay untouched.
                    grad = grad.lerp(exp_avg, group["momentum"])
                else:
                    grad = exp_avg

                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                    group_kwargs = {k: v for k, v in group.items() if k != "params"}
                    orth_grad = self.orthogonalize(p, grad, **group_kwargs)

                update = self._apply_moment2_normalization(
                    orth_grad=orth_grad,
                    moment2=state["moment2_buffer"],
                    beta2=group["beta2"],
                    eps=group["eps"],
                    raw_grad=raw_grad,
                    pre_orth_grad=grad if raw_grad is not None else None,
                )
                update = self._apply_muon_scale(update, update.size(-2), update.size(-1))

                # perform weight update with pre and post weight update functions for subclass customization
                self.pre_weight_update_fn_inplace(p, update)
                p.add_(update, alpha=-group["lr"])
                self.post_weight_update_fn_inplace(p)

        return None