emerging_optimizers.orthogonalized_optimizers
OrthogonalizedOptimizer
- class emerging_optimizers.orthogonalized_optimizers.OrthogonalizedOptimizer(
- params: Iterable[Tensor] | Iterable[dict[str, Any]] | Iterable[tuple[str, Tensor]],
- lr: float,
- momentum_beta: float,
- use_nesterov: bool,
- weight_decay: float,
- use_decoupled_weight_decay: bool,
- split_qkv: bool,
- is_qkv_fn: Callable[[Tensor], bool] | None,
- qkv_split_shapes: tuple[int, int, int] | None,
- fp32_matmul_prec: str,
- orthogonalize_fn: Callable | None = None,
- scale_factor_fn: Callable | None = None,
- **kwargs: Any,
Base class for orthogonalized optimizers.
This class is a wrapper around a base optimizer that performs orthogonalization on the updates. The theoretical foundation of orthogonalization for stochastic gradient descent was developed in the following papers:
Carlson, D., Cevher, V., and Carin, L. Stochastic spectral descent for Restricted Boltzmann Machines. In International Conference on Artificial Intelligence and Statistics (2015a).
Carlson, D., Hsieh, Y.-P., Collins, E., Carin, L., and Cevher, V. Stochastic Spectral Descent for Discrete Graphical Models. In IEEE Journal of Selected Topics in Signal Processing, vol. 10, no. 2, pp. 296-311 (2016).
Carlson, D., Collins, E., Hsieh, Y.-P., Carin, L., and Cevher, V. Preconditioned spectral descent for deep learning. In Neural Information Processing Systems (2015b).
Flynn, T. The duality structure gradient descent algorithm: analysis and applications to neural networks. arXiv preprint arXiv:1708.00523 (2017). [arXiv:1708.00523]
Note
Orthogonalizing Q, K, and V separately when they are fused is supported, but with limitations. The user must provide a function that checks whether a weight tensor holds fused attention parameters (QKV, GQA, etc.), as well as the leading-dimension sizes of the Q, K, and V components. Only one split size is supported, i.e. all attention layers across the network must have the same size.
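The splitting described in this note can be pictured with a small NumPy sketch. It is an illustration only, not the library's code path: `orthogonalize_svd` and `orthogonalize_fused_qkv` are hypothetical helpers, an exact SVD stands in for whatever `orthogonalize_fn` the optimizer is configured with, and the Q, K, V blocks are assumed to be contiguous along the first dimension.

```python
import numpy as np


def orthogonalize_svd(m: np.ndarray) -> np.ndarray:
    """Return U @ V^T from the reduced SVD of m (hypothetical helper)."""
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return u @ vt


def orthogonalize_fused_qkv(update: np.ndarray, qkv_split_shapes) -> np.ndarray:
    """Split a fused QKV update along dim 0, orthogonalize each block, re-fuse."""
    q_rows, k_rows, v_rows = qkv_split_shapes
    assert update.shape[0] == q_rows + k_rows + v_rows
    # np.split expects cut indices, not block sizes.
    blocks = np.split(update, [q_rows, q_rows + k_rows], axis=0)
    return np.concatenate([orthogonalize_svd(b) for b in blocks], axis=0)


rng = np.random.default_rng(0)
# GQA-like layout: Q has 8 rows, K and V have 4 rows each (made-up sizes).
fused = rng.standard_normal((8 + 4 + 4, 16))
out = orthogonalize_fused_qkv(fused, (8, 4, 4))
```

Each block is orthogonalized on its own, so the rows within the Q block are orthonormal, but rows from different blocks are generally not orthogonal to one another.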
- Parameters:
params – Iterable of parameters to optimize or dicts defining parameter groups
lr – The learning rate used by the internal SGD.
momentum_beta – The momentum used by the internal SGD.
use_nesterov – Whether to use Nesterov-style momentum in the internal SGD.
weight_decay – The weight decay used by the optimizer. Defaults to decoupled weight decay; see Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
use_decoupled_weight_decay – Whether to use decoupled weight decay. Defaults to True.
split_qkv – Whether to split fused attention parameters (QKV, GQA, etc.) and orthogonalize the components separately. Defaults to False.
is_qkv_fn – Function to check if a parameter is fused attention parameters (QKV, GQA, etc.).
qkv_split_shapes – For grouped attention parameters (QKV, GQA, etc.), specify the shapes as a tuple of 3 integers representing the sizes of Q, K, V components along the first dimension.
fp32_matmul_prec – Precision of the matmul (GEMM) operations performed on optimizer states.
orthogonalize_fn – Function to orthogonalize the updates.
scale_factor_fn – Function to compute the scale factor for the update.
**kwargs – Arguments passed through to the base optimizer.
Note
Keyword arguments passed through are not checked here. Optimizers that inherit from this class should check them.
- orthogonalize( ) Tensor
Orthogonalize the momentum.
- Parameters:
p – The parameter tensor. It is necessary to pass the parameter tensor in addition to the momentum because a lot of information, attributes for example, is only available on the parameter tensor.
grad – The momentum tensor.
- Returns:
The orthogonalized momentum tensor.
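How the pieces documented above fit together can be sketched in NumPy. This is a conceptual sketch under stated assumptions, not the class's actual implementation: the momentum is written as an exponential moving average and the Nesterov blend is one common convention (both are assumptions), an exact SVD stands in for `orthogonalize_fn`, and `scale_factor_fn` is omitted. `orthogonalized_momentum_step` is a hypothetical name.

```python
import numpy as np


def orthogonalize_svd(m: np.ndarray) -> np.ndarray:
    """Exact orthogonalization U @ V^T (stand-in for orthogonalize_fn)."""
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return u @ vt


def orthogonalized_momentum_step(param, grad, momentum, lr, momentum_beta,
                                 use_nesterov, weight_decay):
    """One conceptual update; returns (new_param, new_momentum)."""
    # Decoupled weight decay: shrink the weights directly, not via the gradient.
    param = param * (1.0 - lr * weight_decay)
    # Momentum as an exponential moving average of gradients (assumed form).
    momentum = momentum_beta * momentum + (1.0 - momentum_beta) * grad
    if use_nesterov:
        # Blend the fresh gradient with the updated buffer (assumed convention).
        update = momentum_beta * momentum + (1.0 - momentum_beta) * grad
    else:
        update = momentum
    # Replace the update with its orthogonalization before applying it.
    return param - lr * orthogonalize_svd(update), momentum


rng = np.random.default_rng(0)
p = rng.standard_normal((4, 6))
g = rng.standard_normal((4, 6))
new_p, new_m = orthogonalized_momentum_step(
    p, g, np.zeros_like(p), lr=0.1, momentum_beta=0.95,
    use_nesterov=True, weight_decay=0.0,
)
```

With weight decay set to zero, the applied change `(p - new_p) / lr` has orthonormal rows, regardless of the gradient's scale.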
Muon
- class emerging_optimizers.orthogonalized_optimizers.Muon(
- params: Iterable[Tensor] | Iterable[dict[str, Any]] | Iterable[tuple[str, Tensor]],
- lr: float = 0.0003,
- momentum_beta: float = 0.95,
- use_nesterov: bool = True,
- weight_decay: float = 0.01,
- use_decoupled_weight_decay: bool = True,
- split_qkv: bool = False,
- is_qkv_fn: Callable[[Tensor], bool] | None = None,
- qkv_split_shapes: tuple[int, int, int] | None = None,
- fp32_matmul_prec: str = 'medium',
- coefficient_type: str = 'quintic',
- num_ns_steps: int = 5,
- scale_mode: str = 'spectral',
- extra_scale_factor: float = 1.0,
Muon: MomentUm Orthogonalized by Newton-Schulz
Muon runs standard SGD-momentum with Nesterov momentum, and then performs an orthogonalization post-processing step, in which each 2D parameter’s update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each update, Newton-Schulz iteration is used, which has the advantage that it may be stably run on tensor cores on GPUs.
Orthogonalization can be viewed as steepest descent in the spectral norm. The theoretical foundation is based on modular duality and norm-constrained optimization.
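The steepest-descent view can be made concrete with a short derivation. It is a standard linear-algebra fact, stated here for orientation and not in the notation of the cited papers. Let \(G = USV^T\) be the reduced SVD of the gradient (or momentum), with singular values \(\sigma_i\), and let \(\langle A, B \rangle = \operatorname{tr}(A^T B)\). Under a spectral-norm budget \(\eta\), the linearized loss is minimized by the orthogonalized direction:

\[
\Delta W^\star \in \arg\min_{\|\Delta W\|_2 \le \eta} \langle G, \Delta W \rangle,
\qquad
\Delta W^\star = -\eta\, U V^T,
\qquad
\langle G, \Delta W^\star \rangle = -\eta \sum_i \sigma_i = -\eta\, \|G\|_*
\]

Here \(\|\cdot\|_2\) is the spectral norm and \(\|\cdot\|_*\) is its dual, the nuclear norm. The minimizer is unique when \(G\) has full rank.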
This implementation incorporates decoupled weight decay. Refer to Scion, which views weight decay as constrained optimization via Frank-Wolfe.
References
Jordan, K. Muon Optimizer Implementation. [GitHub]
Modular Duality in Deep Learning. arXiv:2410.21265 (2024). [arXiv:2410.21265]
Training Deep Learning Models with Norm-Constrained LMOs. arXiv:2502.07529 (2025). [arXiv:2502.07529]
Warning
This optimizer requires that all parameters passed in are 2D.
It should not be used for the embedding layer, the final fully connected layer, or any 1-D parameters; those should all be optimized by a standard method (e.g., AdamW).
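One way to honor this warning is to partition the parameters before building the optimizers. The sketch below works on plain `(name, shape)` pairs so it runs without any framework. `split_for_muon` and the substring filter are hypothetical, and the parameter names are made up for illustration.

```python
def split_for_muon(named_shapes, excluded_substrings=("embed", "lm_head")):
    """Partition (name, shape) pairs into Muon-eligible and fallback names.

    The substring filter is a hypothetical naming convention; adapt it to
    however your model names its embedding and output layers.
    """
    muon, fallback = [], []
    for name, shape in named_shapes:
        is_2d = len(shape) == 2
        is_excluded = any(s in name for s in excluded_substrings)
        (muon if is_2d and not is_excluded else fallback).append(name)
    return muon, fallback


shapes = [
    ("embed_tokens.weight", (32000, 512)),
    ("layers.0.attn.qkv.weight", (1536, 512)),
    ("layers.0.attn.qkv.bias", (1536,)),
    ("layers.0.mlp.fc1.weight", (2048, 512)),
    ("layers.0.norm.weight", (512,)),
    ("lm_head.weight", (32000, 512)),
]
muon_names, fallback_names = split_for_muon(shapes)
```

With a real PyTorch model, the same test would typically be applied to `model.named_parameters()` using `p.ndim`, with the first group passed to Muon and the second to a standard optimizer such as AdamW.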
- Parameters:
params – Iterable of parameters to optimize or dicts defining parameter groups
lr – The learning rate used by the internal SGD.
momentum_beta – The momentum used by the internal SGD.
use_nesterov – Whether to use Nesterov-style momentum in the internal SGD.
weight_decay – The weight decay used by the optimizer. Defaults to decoupled weight decay; see Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
use_decoupled_weight_decay – Whether to use decoupled weight decay. Defaults to True.
split_qkv – Whether to split fused attention parameters (QKV, GQA, etc.) and orthogonalize the components separately. Defaults to False.
is_qkv_fn – Function to check if a parameter is fused attention parameters (QKV, GQA, etc.).
qkv_split_shapes – For grouped attention parameters (QKV, GQA, etc.), specify the shapes as a tuple of 3 integers representing the sizes of Q, K, V components along the first dimension.
fp32_matmul_prec – Precision of the matmul (GEMM) operations performed on optimizer states.
coefficient_type – The type of coefficient set to use for the Newton-Schulz iteration. Can be one of [“simple”, “quintic”, “polar_express”].
num_ns_steps – The number of iteration steps to use in the Newton-Schulz iteration.
scale_mode – The type of scale factor to use for the update. Defaults to “spectral” style scaling.
extra_scale_factor – The additional scale factor to use for the update.
Newton-Schulz
- emerging_optimizers.orthogonalized_optimizers.muon_utils.newton_schulz(
- x: Tensor,
- steps: int,
- coefficient_type: str = 'quintic',
- custom_coefficient_sets: list[tuple[float, float, float]] | None = None,
- eps: float = 1e-07,
- transpose: bool | None = None,
- tp_group: ProcessGroup | None = None,
Use Newton-Schulz iteration to compute the zeroth power / orthogonalization of x.
We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero and minimize variance. For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing the slope at zero even beyond the point where the iteration no longer converges all the way to one everywhere on the interval. This iteration therefore does not produce \(UV^T\) but rather something like \(US'V^T\), where \(S'\) is diagonal with noisy values around 1 and \(USV^T = G\) is the SVD. This turns out not to hurt model performance at all relative to \(UV^T\).
coefficient_type can be one of the following:
“simple”: Default coefficient set.
“quintic”: Quintic iteration with optimized coefficients.
“polar_express”: Polar Express iteration with optimized coefficients.
“custom”: Custom coefficient sets.
- Parameters:
x – The tensor to be orthogonalized.
steps – Number of Newton-Schulz iterations.
coefficient_type – Type of coefficient set to use for the Newton-Schulz iteration.
custom_coefficient_sets – Custom coefficient sets to use for the Newton-Schulz iteration.
eps – Small constant to avoid division by zero.
transpose – Whether to transpose the tensor to perform whitening on the smaller dimension. If None, will be determined based on the size of the tensor.
tp_group – The process group for communication if input is distributed.
- Returns:
The orthogonalization of x.
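The quintic iteration described above can be sketched in NumPy. This is a minimal sketch, not this function's implementation: the coefficients are the ones published in Keller Jordan's Muon implementation and are only assumed to resemble the library's “quintic” set, and `newton_schulz_sketch` is a hypothetical name. The test matrix is built with known singular values so the effect is easy to inspect.

```python
import numpy as np

# Coefficients from Keller Jordan's public Muon implementation (assumption:
# not necessarily identical to this library's "quintic" coefficient set).
A_COEF, B_COEF, C_COEF = 3.4445, -4.7750, 2.0315


def newton_schulz_sketch(x: np.ndarray, steps: int = 5, eps: float = 1e-7) -> np.ndarray:
    """Approximately orthogonalize a 2D array with a quintic Newton-Schulz iteration."""
    transpose = x.shape[0] > x.shape[1]  # work with the smaller Gram matrix
    if transpose:
        x = x.T
    # Frobenius normalization bounds every singular value by 1 before iterating.
    x = x / (np.linalg.norm(x) + eps)
    for _ in range(steps):
        gram = x @ x.T
        poly = B_COEF * gram + C_COEF * (gram @ gram)
        x = A_COEF * x + poly @ x
    return x.T if transpose else x


rng = np.random.default_rng(0)
u, _ = np.linalg.qr(rng.standard_normal((4, 4)))
v, _ = np.linalg.qr(rng.standard_normal((6, 4)))
g = u @ np.diag([3.0, 2.0, 1.0, 0.5]) @ v.T  # 4x6 matrix with known singular values
singular_values = np.linalg.svd(newton_schulz_sketch(g), compute_uv=False)
```

After five steps the singular values are pushed into a band around 1 instead of converging to exactly 1, which matches the \(US'V^T\) behavior described above.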
- emerging_optimizers.orthogonalized_optimizers.muon_utils.newton_schulz_step( ) Tensor
Perform a single Newton-Schulz iteration step.
This function performs a single Newton-Schulz iteration step. It supports distributed input that’s sharded along the smaller (orthogonalize) dimension.
Warning
If distributed, this function doesn’t have the information to verify that X is sharded along the smaller (orthogonalize) dimension. It is the user’s responsibility to ensure that X is sharded correctly.
- Parameters:
X – The tensor to be orthogonalized.
a – The a coefficient.
b – The b coefficient.
c – The c coefficient.
tp_group – The process group to use for the all-reduce.
- Returns:
The orthogonalization of X.
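The signature is elided above, so the role of a, b, and c is not spelled out. Assuming the usual quintic Newton-Schulz form (an assumption, not something confirmed by this page), a single step on a wide matrix \(X\) would read:

\[
A = X X^T,
\qquad
X' = aX + \left(bA + cA^2\right) X
\]

Writing \(X = USV^T\), the step keeps \(U\) and \(V\) and acts on each singular value \(\sigma\) independently:

\[
\sigma' = a\sigma + b\sigma^3 + c\sigma^5
\]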
- emerging_optimizers.orthogonalized_optimizers.muon_utils.newton_schulz_tp(
- x: Tensor,
- steps: int,
- coefficient_type: str,
- tp_group: ProcessGroup,
- partition_dim: int | None = None,
- mode: Literal['duplicated', 'distributed'] = 'duplicated',
Tensor Parallel Newton-Schulz iteration.
This function uses partition_dim to determine along which dimension the input tensor is sharded. Transpose is set based on the partition_dim. If partition_dim is None, the input tensor is not sharded and the function will fall back to the non-TP path.
Warning
If partition_dim is the smaller dimension of the input tensor, distributed mode will run Newton-Schulz along the longer dimension, which wastes compute. Although the partition_dim name is reused, its default value here is None, not -1, and None means the tensor is not partitioned.
Note
This function is designed to provide tensor parallel support for most common use of Newton-Schulz. Many arguments, e.g. custom coefficient sets and custom eps, are not supported.
mode can be one of the following:
“duplicated”: The input tensor is duplicated and orthogonalized on each rank.
“distributed”: The input tensor is partitioned along the partition_dim and orthogonalized on each rank.
- Parameters:
x – The tensor to be orthogonalized. Must have partition_dim and tensor_model_parallel set by TransformerEngine.
steps – Number of Newton-Schulz iterations.
coefficient_type – Type of coefficient set to use for the Newton-Schulz iteration.
partition_dim – The dimension to partition the tensor.
tp_group – The process group for communication if input is distributed.
mode – The mode to use for the Newton-Schulz iteration.