Source code for emerging_optimizers.riemannian_optimizers.normalized_optimizer

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable

import torch
from torch.optim.optimizer import Optimizer


class ObliqueSGD(Optimizer):
    """SGD optimizer for row- or column-normalized 2D parameters on oblique manifolds.

    This optimizer performs SGD on oblique manifolds, where parameters are constrained
    to have unit-norm rows or columns. It implements Riemannian SGD with manifold-aware
    gradient updates and retraction operations.

    References:
        - An Introduction to Optimization on Smooth Manifolds (Nicolas Boumal)
        - EDM2: https://arxiv.org/abs/2312.02696
        - Jianlin Su: https://kexue.fm/archives/11196
        - Raman et al.: https://arxiv.org/abs/1909.06463
        - Franz Cesista: https://leloykun.github.io/ponder/steepest-descent-stiefel/#6-bonus-a-muon-like-optimizer-for-the-embedding-and-unembedding-layers

    Args:
        lr: Learning rate.
        momentum: Momentum coefficient.
        weight_decay: Weight decay coefficient.
        dim: The dimension to normalize over.
        eps: Epsilon for numerical stability.
    """

    def __init__(
        self,
        params: list[torch.nn.Parameter],
        lr: float = 1e-3,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
        dim: int = 0,
        eps: float = 1e-8,
    ) -> None:
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0 or momentum >= 1.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            dim=dim,
            eps=eps,
        )
        super().__init__(params, defaults)

    @torch.no_grad()  # type: ignore[misc]
    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        """Performs a single optimization step.

        Args:
            closure: A closure that reevaluates the model and returns the loss.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            lr = group["lr"]
            mom = group["momentum"]
            wd = group["weight_decay"]
            dim = group["dim"]
            eps = group["eps"]

            for param in group["params"]:
                if param.grad is None:
                    continue
                if param.ndim != 2:
                    raise ValueError("ObliqueSGD only supports 2D parameters")

                grad = param.grad

                # Initialize momentum buffer if needed
                state = self.state[param]
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(param)
                buf = state["momentum_buffer"]

                # Theory-style momentum: buf <- grad + mom * buf, updated in place so the
                # accumulated history persists across steps.
                buf.mul_(mom).add_(grad)

                # Apply Riemannian gradient update
                _compute_riemannian_grad_and_update(param, buf, dim, lr, wd)

                # Retraction back to the manifold, i.e. the hyper-sphere
                torch.nn.functional.normalize(param, p=2.0, dim=dim, eps=eps, out=param)

        return loss
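# Usage sketch (illustrative only, not part of this module): keeping the columns of a
# small 2D parameter unit-norm while taking ObliqueSGD steps. The seed, shape, toy loss,
# and the helper name `_example_oblique_sgd` are assumptions chosen for demonstration.
def _example_oblique_sgd() -> None:
    torch.manual_seed(0)
    weight = torch.nn.Parameter(torch.randn(8, 4))
    # Start on the manifold: with dim=0, every column has unit norm.
    with torch.no_grad():
        weight /= weight.norm(dim=0, keepdim=True)

    opt = ObliqueSGD([weight], lr=1e-2, momentum=0.9, dim=0)
    for _ in range(5):
        loss = (weight @ torch.ones(4)).pow(2).sum()
        opt.zero_grad()
        loss.backward()
        opt.step()

    # Column norms remain (numerically) 1 after the retraction inside step().
    print(weight.norm(dim=0))
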
class ObliqueAdam(Optimizer):
    """Adam optimizer for row- or column-normalized 2D parameters on oblique manifolds.

    This optimizer adapts an Adam-like algorithm to work on oblique manifolds, where
    parameters are constrained to have unit-norm rows or columns. It combines adaptive
    momentum estimation with Riemannian gradient computation and manifold retraction.
    """

    def __init__(
        self,
        params: list[torch.nn.Parameter],
        lr: float = 1e-3,
        betas: tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
        dim: int = 0,
        eps: float = 1e-8,
        correct_bias: bool = True,
    ) -> None:
        """An Adam-like optimizer for normalized 2D parameters.

        Args:
            lr: The learning rate.
            betas: The coefficients used for computing running averages of the gradient
                and its square.
            weight_decay: The weight decay coefficient.
            dim: The dimension to normalize over.
            eps: The epsilon for numerical stability.
            correct_bias: Whether to correct bias in the Adam-like moment estimates.
        """
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if betas[0] < 0.0 or betas[0] >= 1.0:
            raise ValueError(f"Invalid beta1 value: {betas[0]}")
        if betas[1] < 0.0 or betas[1] >= 1.0:
            raise ValueError(f"Invalid beta2 value: {betas[1]}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
            dim=dim,
            eps=eps,
            correct_bias=correct_bias,
        )
        super().__init__(params, defaults)

    @torch.no_grad()  # type: ignore[misc]
    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        """Performs a single optimization step.

        Args:
            closure: A closure that reevaluates the model and returns the loss.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            lr = group["lr"]
            betas = group["betas"]
            wd = group["weight_decay"]
            dim = group["dim"]
            eps = group["eps"]
            correct_bias = group["correct_bias"]

            for param in group["params"]:
                if param.grad is None:
                    continue
                if param.ndim != 2:
                    raise ValueError("ObliqueAdam only supports 2D parameters")

                state = self.state[param]
                if "step" not in state:
                    state["step"] = 0

                grad = param.grad

                # Initialize moment buffers if needed
                if "exp_avg" not in state:
                    state["exp_avg"] = torch.zeros_like(param)
                if "exp_avg_sq" not in state:
                    state["exp_avg_sq"] = torch.zeros_like(param)
                exp_avg = state["exp_avg"]
                exp_avg_sq = state["exp_avg_sq"]

                # Increment step counter
                state["step"] += 1
                step = state["step"]

                # Update biased first and second moment estimates
                exp_avg.mul_(betas[0]).add_(grad, alpha=1 - betas[0])
                exp_avg_sq.mul_(betas[1]).addcmul_(grad, grad, value=1 - betas[1])

                if correct_bias:
                    # Bias correction for the exponential moving averages of the moments
                    bias_correction1 = 1.0 - betas[0] ** step
                    bias_correction2 = 1.0 - betas[1] ** step
                else:
                    bias_correction1 = 1.0
                    bias_correction2 = 1.0

                # Adam-style normalized gradient: bias-corrected first moment divided by the
                # square root of the bias-corrected second moment.
                norm_grad = (exp_avg / bias_correction1) / (
                    exp_avg_sq.sqrt() / bias_correction2**0.5 + eps
                )

                # Apply Riemannian gradient update
                _compute_riemannian_grad_and_update(param, norm_grad, dim, lr, wd)

                # Retraction back to the manifold, i.e. the hyper-sphere
                torch.nn.functional.normalize(param, p=2.0, dim=dim, eps=eps, out=param)

        return loss
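# Usage sketch (illustrative only, not part of this module): ObliqueAdam with dim=1 keeps
# every row of the parameter on the unit hyper-sphere. The seed, shape, toy loss, and the
# helper name `_example_oblique_adam` are assumptions chosen for demonstration.
def _example_oblique_adam() -> None:
    torch.manual_seed(0)
    weight = torch.nn.Parameter(torch.randn(6, 16))
    # Start on the manifold: with dim=1, every row has unit norm.
    with torch.no_grad():
        weight /= weight.norm(dim=1, keepdim=True)

    opt = ObliqueAdam([weight], lr=1e-3, betas=(0.9, 0.99), dim=1)
    for _ in range(3):
        loss = weight.sum()
        opt.zero_grad()
        loss.backward()
        opt.step()

    # Row norms remain (numerically) 1 after the retraction inside step().
    print(weight.norm(dim=1))
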
def _compute_riemannian_grad_and_update(
    param: torch.Tensor, grad_like: torch.Tensor, dim: int, lr: float, wd: float
) -> None:
    """Compute the Riemannian gradient for the oblique manifold and update the parameter in place.

    Args:
        param: Parameter tensor (2D).
        grad_like: Gradient-like tensor (momentum buffer or normalized gradient).
        dim: The dimension to normalize over.
        lr: Learning rate.
        wd: Weight decay coefficient.
    """
    # Project the Euclidean gradient onto the tangent space of the unit hyper-sphere:
    # riem_grad = grad_like - <param, grad_like> * param, per normalized slice.
    inner = (param * grad_like).sum(dim=dim, keepdim=True)
    riem_grad = torch.add(grad_like, param * inner, alpha=-1)

    # Add decoupled weight decay
    param.mul_(1 - lr * wd)

    # Apply update in-place
    param.add_(riem_grad, alpha=-lr)
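

# Sanity-check sketch (illustrative only, not part of this module): the projection used in
# `_compute_riemannian_grad_and_update` yields an update in the tangent space of the oblique
# manifold, i.e. each unit-norm slice of the parameter is orthogonal to the corresponding
# slice of the Riemannian gradient. The helper name `_check_tangency` is an assumption.
def _check_tangency(dim: int = 0) -> None:
    torch.manual_seed(0)
    param = torch.nn.functional.normalize(torch.randn(8, 4), p=2.0, dim=dim)
    grad = torch.randn(8, 4)
    inner = (param * grad).sum(dim=dim, keepdim=True)
    riem_grad = grad - param * inner
    # <param_slice, riem_grad_slice> should be ~0 for every slice (up to float error).
    print((param * riem_grad).sum(dim=dim).abs().max())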