Source code for emerging_optimizers.orthogonalized_optimizers.muon_hyperball

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, override

import torch

from emerging_optimizers import registry
from emerging_optimizers.orthogonalized_optimizers import muon


# Names exported by ``from ... import *``; only the optimizer class is public.
__all__ = ["MuonHyperball"]


@registry.register_optimizer("muon_hyperball")
class MuonHyperball(muon.Muon):
    """Muon optimizer with hyperball-style norm-preserving weight updates.

    This optimizer extends Muon by performing gradient descent on the sphere manifold while
    preserving the weight norm. The update rule is:

    .. math::

        W_{t+1} = R \\cdot \\text{normalize}(W_t - \\text{lr} \\cdot R \\cdot \\text{normalize}(\\text{update}))

    where :math:`R` is the Frobenius norm of :math:`W_t` (or a user-specified radius).
    This keeps the weight matrix at constant scale while updating.

    Warning:
        This optimizer is experimental and may change in future versions.

    See :class:`~emerging_optimizers.orthogonalized_optimizers.muon.Muon` for full documentation
    of the base Muon optimizer.

    Args:
        *args: Arguments passed to Muon.
        hyperball_eps: Epsilon for numerical stability in normalization. Must be non-negative.
            Default: ``1e-8``.
        hyperball_radius: Fixed radius for the hyperball. If ``None`` (default), the radius of each
            parameter is its own Frobenius norm, measured right before every weight update (the
            post-update renormalization restores that norm, so it stays at the initial value up to
            rounding). If specified, it must be a positive finite number and all parameters are
            rescaled to have this radius at initialization.
        **kwargs: Keyword arguments passed to Muon.

    Raises:
        ValueError: If ``hyperball_eps`` is negative or NaN, if ``hyperball_radius`` is given but is
            not a positive finite number, or if any parameter has zero norm.
    """

    def __init__(
        self,
        *args: Any,
        hyperball_eps: float = 1e-8,
        hyperball_radius: float | None = None,
        **kwargs: Any,
    ) -> None:
        # Reject invalid hyperparameters before touching any parameter. A zero radius would zero
        # out every weight, a negative one would flip the sign of both the weights and the scaled
        # update, and NaN/inf would poison the weights. The negated comparisons also catch NaN.
        if not (hyperball_eps >= 0):
            raise ValueError(f"hyperball_eps must be non-negative, got {hyperball_eps}.")
        if hyperball_radius is not None and not (0 < hyperball_radius < float("inf")):
            raise ValueError(f"hyperball_radius must be a positive finite number, got {hyperball_radius}.")

        # Assigned before the base constructor runs, keeping the original ordering.
        # NOTE(review): presumably so the attributes exist if the base class reads them during
        # construction -- confirm against muon.Muon.
        self.hyperball_eps = hyperball_eps
        self.hyperball_radius = hyperball_radius
        super().__init__(*args, **kwargs)

        # Validate and optionally rescale parameters based on hyperball_radius.
        with torch.no_grad():
            for group in self.param_groups:
                for p in group["params"]:
                    p_norm = p.norm()

                    # A zero-norm parameter has no direction to normalize, so it cannot be
                    # placed on a sphere.
                    if p_norm.item() == 0:
                        raise ValueError(
                            "MuonHyperball requires all parameters to have non-zero norm. "
                            "Found parameter with zero norm."
                        )

                    # Rescale the parameter in place to the requested radius; the clamp guards
                    # against dividing by a tiny (but non-zero) norm.
                    if self.hyperball_radius is not None:
                        p.mul_(self.hyperball_radius / p_norm.clamp_min(self.hyperball_eps))

    @override
    def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None:
        """Store the original weight norm and normalize the update using Frobenius norm.

        After this call ``update`` holds ``R * normalize(update)`` and ``R`` is saved in the
        per-parameter state under ``"hyperball_R"``.

        Args:
            p: The parameter tensor.
            update: The orthogonalized gradient tensor. Modified in place.
        """
        # Use the user-specified radius, or measure R = ||W_t||_F (Frobenius norm) now.
        # ``.item()`` turns the norm into a Python float, so the stored radius is a plain number.
        if self.hyperball_radius is not None:
            radius = self.hyperball_radius
        else:
            radius = p.norm().item()
        self.state[p]["hyperball_R"] = radius

        # Normalize the update in place and scale by R; the clamp avoids dividing by a
        # (near-)zero update norm.
        update_norm = update.norm().clamp_min(self.hyperball_eps)
        update.mul_(radius / update_norm)

    @override
    def post_weight_update_fn_inplace(self, p: torch.Tensor) -> None:
        """Normalize the updated weights and scale back to original norm using Frobenius norm.

        Args:
            p: The parameter tensor (already updated). Modified in place.
        """
        # Retrieve the radius saved by pre_weight_update_fn_inplace for this parameter.
        radius = self.state[p]["hyperball_R"]

        # Project back onto the sphere: p = R * (p / ||p||_F), with the norm clamped for stability.
        p_norm = p.norm().clamp_min(self.hyperball_eps)
        p.mul_(radius / p_norm)