Source code for emerging_optimizers.orthogonalized_optimizers.scion

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from absl import logging
from torch.optim.optimizer import ParamsT

from emerging_optimizers.orthogonalized_optimizers.muon import get_muon_scale_factor
from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer


class Scion(OrthogonalizedOptimizer):
    """Scion: Stochastic CondItional descent with Operator Norms

    Scion runs standard SGD-momentum and then performs an orthogonalization post-processing step,
    in which each 2D parameter's update is replaced with the nearest orthogonal matrix. To
    efficiently orthogonalize each update, Newton-Schulz iteration is used, which has the
    advantage that it may be stably run on tensor cores on GPUs.

    This implementation incorporates `step_size` and `spectral_radius`; refer to the Scion paper,
    which views weight decay as constrained optimization via Frank-Wolfe.

    References:
        - *Training Deep Learning Models with Norm-Constrained LMOs.* arXiv:2502.07529 (2025).
          [`arXiv:2502.07529 <https://arxiv.org/abs/2502.07529>`_]

    Warning:
        - This optimizer requires that all parameters passed in are 2D.
        - It should not be used for the embedding layer, the final fully connected layer, or any
          1-D parameters; those should be optimized by the LMO appropriate for that layer. For
          example, the LMO for 1-D parameters is scaled by the `ell_inf` radius.

    Args:
        params: Iterable of parameters to optimize or dicts defining parameter groups.
        lr: The learning rate used by the internal SGD.
        momentum_beta: The momentum used by the internal SGD.
        fp32_matmul_prec: Precision of the GEMM operations performed on the optimizer states.
        coefficient_type: The type of coefficient set to use for the Newton-Schulz iteration.
            Can be one of ["simple", "quintic", "polar_express"].
        num_ns_steps: The number of iteration steps to use in the Newton-Schulz iteration.
        spectral_radius: The spectral radius to use for the update; the LMO output is scaled by
            this value.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 3e-4,
        momentum_beta: float = 0.95,
        *,
        fp32_matmul_prec: str = "medium",
        coefficient_type: str = "quintic",
        num_ns_steps: int = 5,
        spectral_radius: float = 1.0,
    ) -> None:
        if num_ns_steps < 1:
            raise ValueError(f"num_ns_steps must be at least 1, got {num_ns_steps}")

        # Add checks for weight decay arguments to enable the Frank-Wolfe step.
        logging.info(
            "Scion does not use weight decay. Setting weight_decay to 1 and weight_decay_method to decoupled."
        )
        weight_decay = 1
        weight_decay_method = "decoupled"

        logging.info("Scion does not use Nesterov momentum. Setting use_nesterov to False.")
        use_nesterov = False

        def scaled_orthogonalize_fn(grad: torch.Tensor) -> torch.Tensor:
            logging.debug(
                f"Orthogonalizing grad with {num_ns_steps} steps, {coefficient_type} coefficient, "
                f"spectral_radius={spectral_radius}"
            )
            orth_grad = newton_schulz(grad, steps=num_ns_steps, coefficient_type=coefficient_type, use_syrk=False)
            width_factor = get_muon_scale_factor(grad.size(-2), grad.size(-1), mode="unit_rms_norm")
            return orth_grad * width_factor * spectral_radius

        super().__init__(
            params,
            lr,
            momentum_beta,
            weight_decay,
            use_nesterov=use_nesterov,
            weight_decay_method=weight_decay_method,  # type: ignore[arg-type]
            fp32_matmul_prec=fp32_matmul_prec,
            scaled_orthogonalize_fn=scaled_orthogonalize_fn,
        )
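
A minimal usage sketch, not part of the module itself: it assumes `OrthogonalizedOptimizer` follows the standard `torch.optim.Optimizer` interface (`zero_grad` / `backward` / `step`), and the model, layer sizes, and hyperparameters are hypothetical. Per the warning in the class docstring, only 2D weight matrices are handed to Scion.

    import torch
    from torch import nn

    from emerging_optimizers.orthogonalized_optimizers.scion import Scion

    model = nn.Sequential(
        nn.Linear(512, 1024, bias=False),
        nn.ReLU(),
        nn.Linear(1024, 512, bias=False),
    )

    # Only 2D parameters go to Scion; embeddings, the final fully connected layer,
    # and 1-D parameters should be optimized separately with the appropriate LMO.
    matrix_params = [p for p in model.parameters() if p.ndim == 2]
    optimizer = Scion(matrix_params, lr=3e-4, momentum_beta=0.95, num_ns_steps=5, spectral_radius=1.0)

    x = torch.randn(8, 512)
    loss = model(x).square().mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()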