emerging_optimizers.orthogonalized_optimizers#

OrthogonalizedOptimizer#

class emerging_optimizers.orthogonalized_optimizers.OrthogonalizedOptimizer(
params,
lr,
momentum_beta,
weight_decay,
*,
use_nesterov,
weight_decay_method,
fp32_matmul_prec,
scaled_orthogonalize_fn=None,
**kwargs,
)[source]#

Base class for orthogonalized optimizers.

This class is a wrapper around a base optimizer that performs orthogonalization on the updates. The theoretical foundation of orthogonalization for stochastic gradient descent was developed in the following papers:

  • Carlson, D., Cevher, V., and Carin, L. Stochastic spectral descent for Restricted Boltzmann Machines. In International Conference on Artificial Intelligence and Statistics (2015a).

  • Carlson, D., Hsieh, Y.-P., Collins, E., Carin, L., and Cevher, V. Stochastic Spectral Descent for Discrete Graphical Models. In IEEE Journal of Selected Topics in Signal Processing, vol. 10, no. 2, pp. 296-311 (2016).

  • Carlson, D., Collins, E., Hsieh, Y.-P., Carin, L., and Cevher, V. Preconditioned spectral descent for deep learning. In Neural Information Processing Systems (2015b).

  • Flynn, T. The duality structure gradient descent algorithm: analysis and applications to neural networks. arXiv preprint arXiv:1708.00523 (2017). [arXiv:1708.00523]

Note

As a base class, OrthogonalizedOptimizer does not directly support orthogonalizing fused parameters separately. Subclasses can override the orthogonalize function to support this; see the example below.

Split QKV example#
from typing import Any

import torch


class SplitQkvOrthogonalizedOptimizer(OrthogonalizedOptimizer):
    def __init__(..., split_qkv_shapes):
        super().__init__(...)
        self.qkv_split_shapes = split_qkv_shapes

    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:

        # Alternative: pass "is_qkv" through to scaled_orthogonalize_fn and do the
        # split inside scaled_orthogonalize_fn.
        if getattr(p, "is_qkv", False) or kwargs.get("is_qkv", False):
            # Split the fused QKV update, orthogonalize each block separately,
            # then concatenate the blocks back along the fused dimension.
            qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0)
            qkv_orthogonalized = [self.scaled_orthogonalize_fn(g) for g in qkv_grads]
            grad = torch.cat(qkv_orthogonalized, dim=0)
        else:
            grad = self.scaled_orthogonalize_fn(grad)

        return grad
Parameters:
  • params (Iterable[Tensor] | Iterable[dict[str, Any]] | Iterable[tuple[str, Tensor]]) – Iterable of parameters to optimize or dicts defining parameter groups

  • lr (float) – The learning rate used by the internal SGD.

  • momentum_beta (float) – The momentum used by the internal SGD.

  • weight_decay (float) – The weight decay used by the optimizer; defaults to decoupled weight decay. See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101

  • use_nesterov (bool) – Whether to use Nesterov-style momentum in the internal SGD.

  • weight_decay_method (Literal['decoupled', 'independent', 'l2']) – Method to apply weight decay, see WeightDecayMixin for more details.

  • fp32_matmul_prec (str) – Precision of the float32 matmul (GEMM) operations on optimizer states.

  • scaled_orthogonalize_fn (Callable | None) – Function to orthogonalize and scale the updates.

  • **kwargs (Any) – Arguments passed through to the base optimizer.

Note

Keyword arguments passed through are not checked here. Optimizers inheriting from this class should check them.
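
Assuming the base class can be instantiated directly with a user-supplied scaled_orthogonalize_fn, the following is a minimal construction sketch based on the signature above. The svd_orthogonalize helper is an illustrative stand-in (in practice the library's Newton-Schulz utilities would be the natural choice), and the hyperparameter values are assumptions, not recommendations.

Construction sketch#
import torch

from emerging_optimizers.orthogonalized_optimizers import OrthogonalizedOptimizer


def svd_orthogonalize(g: torch.Tensor) -> torch.Tensor:
    # Illustrative stand-in: exact orthogonalization via SVD, replacing U S V^T with U V^T.
    u, _, vh = torch.linalg.svd(g, full_matrices=False)
    return u @ vh


linear = torch.nn.Linear(128, 128, bias=False)
opt = OrthogonalizedOptimizer(
    linear.parameters(),
    lr=3e-4,
    momentum_beta=0.95,
    weight_decay=0.01,
    use_nesterov=True,
    weight_decay_method="decoupled",
    fp32_matmul_prec="medium",
    scaled_orthogonalize_fn=svd_orthogonalize,
)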

orthogonalize(p, grad, **kwargs)[source]#

Orthogonalize the momentum.

The default orthogonalize function calls the scaled_orthogonalize_fn with the gradient. Subclasses can override this function to implement different orthogonalization logic as well as to split fused parameters. For example, a scaled_orthogonalize_fn can read attributes from p or from kwargs to determine whether the parameter is a fused parameter that should be split for preconditioning.

Parameters:
  • p (Tensor) – The parameter tensor. It is passed in addition to the momentum because some information, such as tensor attributes, is only available on the parameter tensor, although it is not used by this default orthogonalize function.

  • grad (Tensor) – The momentum tensor.

  • **kwargs (Any) – Keyword arguments of the param_group that p belongs to.

Returns:

The orthogonalized gradient tensor.

Return type:

Tensor

step(closure=None)[source]#

Performs a single optimization step.

Parameters:

closure (Callable[[], float] | None) – A closure that reevaluates the model and returns the loss.

Return type:

float | None
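
A minimal training-loop sketch for step, following the standard torch.optim pattern; model, loss_fn, data_loader, x, and y are placeholders, not part of this library.

Step sketch#
for x, y in data_loader:
    opt.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    opt.step()


# Alternatively, pass a closure that reevaluates the model and returns the loss.
def closure() -> float:
    opt.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    return loss.item()


loss_value = opt.step(closure)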

Muon#

class emerging_optimizers.orthogonalized_optimizers.Muon(
params,
lr=0.0003,
momentum_beta=0.95,
weight_decay=0.01,
*,
use_nesterov=False,
weight_decay_method='decoupled',
fp32_matmul_prec='medium',
coefficient_type='quintic',
num_ns_steps=5,
scale_mode='spectral',
extra_scale_factor=1.0,
use_syrk=False,
)[source]#

Muon: MomentUm Orthogonalized by Newton-Schulz

Muon runs standard SGD-momentum (optionally Nesterov-style) and then performs an orthogonalization post-processing step, in which each 2D parameter’s update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each update, a Newton-Schulz iteration is used, which has the advantage that it can be run stably on GPU tensor cores.

Orthogonalization can be viewed as steepest descent in the spectral norm. The theoretical foundation is based on modular duality and norm-constrained optimization.

This implementation incorporates decoupled weight decay; refer to Scion, which views weight decay as constrained optimization via Frank-Wolfe.

References

  • Jordan, K. Muon Optimizer Implementation. [GitHub]

  • Modular Duality in Deep Learning. arXiv:2410.21265 (2024). [arXiv:2410.21265]

  • Training Deep Learning Models with Norm-Constrained LMOs. arXiv:2502.07529 (2025). [arXiv:2502.07529]

Warning

  • This optimizer requires that all parameters passed in are 2D.

  • It should not be used for the embedding layer, the final fully connected layer, or any 1-D parameters; those should all be optimized by a standard method (e.g., AdamW).

Parameters:
  • params (Iterable[Tensor] | Iterable[dict[str, Any]] | Iterable[tuple[str, Tensor]]) – Iterable of parameters to optimize or dicts defining parameter groups

  • lr (float) – The learning rate used by the internal SGD.

  • momentum_beta (float) – The momentum used by the internal SGD.

  • weight_decay (float) – The weight decay used by the optimizer; defaults to decoupled weight decay. See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101

  • use_nesterov (bool) – Whether to use Nesterov-style momentum in the internal SGD.

  • weight_decay_method (Literal['decoupled', 'independent', 'l2']) – Method to apply weight decay, see WeightDecayMixin for more details.

  • fp32_matmul_prec (str) – Precision of the float32 matmul (GEMM) operations on optimizer states.

  • coefficient_type (str) – The type of coefficient set to use for the Newton-Schulz iteration. Can be one of [“simple”, “quintic”, “polar_express”].

  • num_ns_steps (int) – The number of iteration steps to use in the Newton-Schulz iteration.

  • scale_mode (str) – The type of scale factor to use for the update. Defaults to “spectral” style scaling.

  • extra_scale_factor (float) – The additional scale factor to use for the update.

  • use_syrk (bool) – Whether to use the Triton kernel for the Newton-Schulz iteration.
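
A construction sketch reflecting the warning above: Muon is applied only to a 2-D hidden weight, while the embedding, the output head, and 1-D parameters go to AdamW. The toy modules and hyperparameter values are illustrative assumptions.

Muon construction sketch#
import torch

from emerging_optimizers.orthogonalized_optimizers import Muon

embed = torch.nn.Embedding(1000, 256)
hidden = torch.nn.Linear(256, 256, bias=True)
head = torch.nn.Linear(256, 1000, bias=False)

# Muon only receives the 2-D hidden weight matrix; the embedding, the output head,
# and the 1-D bias are optimized by AdamW, per the warning above.
muon = Muon(
    [hidden.weight],
    lr=3e-4,
    momentum_beta=0.95,
    weight_decay=0.01,
    use_nesterov=True,
    num_ns_steps=5,
    scale_mode="spectral",
)
adamw = torch.optim.AdamW(
    list(embed.parameters()) + list(head.parameters()) + [hidden.bias],
    lr=3e-4,
)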

Scion#

class emerging_optimizers.orthogonalized_optimizers.Scion(
params,
lr=0.0003,
momentum_beta=0.95,
*,
fp32_matmul_prec='medium',
coefficient_type='quintic',
num_ns_steps=5,
spectral_radius=1.0,
)[source]#

Scion: Stochastic CondItional descent with Operator Norms

Scion runs standard SGD-momentum and then performs an orthogonalization post-processing step, in which each 2D parameter’s update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each update, a Newton-Schulz iteration is used, which has the advantage that it can be run stably on GPU tensor cores.

This implementation incorporates step_size and spectral_radius; refer to Scion, which views weight decay as constrained optimization via Frank-Wolfe.

References

  • Training Deep Learning Models with Norm-Constrained LMOs. arXiv:2502.07529 (2025). [arXiv:2502.07529]

Warning

  • This optimizer requires that all parameters passed in are 2D.

  • It should not be used for the embedding layer, the final fully connected layer, or any 1-D parameters; those should all be optimized by the appropriate LMO for that layer. For example, 1-D parameters are scaled by the \(\ell_\infty\) radius.

Parameters:
  • params (Iterable[Tensor] | Iterable[dict[str, Any]] | Iterable[tuple[str, Tensor]]) – Iterable of parameters to optimize or dicts defining parameter groups

  • lr (float) – The learning rate used by the internal SGD.

  • momentum_beta (float) – The momentum used by the internal SGD.

  • fp32_matmul_prec (str) – Precision of the float32 matmul (GEMM) operations on optimizer states.

  • coefficient_type (str) – The type of coefficient set to use for the Newton-Schulz iteration. Can be one of [“simple”, “quintic”, “polar_express”].

  • num_ns_steps (int) – The number of iteration steps to use in the Newton-Schulz iteration.

  • spectral_radius (float) – The spectral radius to use for the update; the LMO is scaled by this spectral radius.
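
A construction sketch under the same 2-D-only constraint as Muon; the module and values are illustrative assumptions. Note that Scion takes no weight_decay argument: the constraint is expressed through spectral_radius, which scales the LMO.

Scion construction sketch#
import torch

from emerging_optimizers.orthogonalized_optimizers import Scion

hidden = torch.nn.Linear(256, 256, bias=False)
opt = Scion(
    [hidden.weight],
    lr=3e-4,
    momentum_beta=0.95,
    coefficient_type="quintic",
    num_ns_steps=5,
    spectral_radius=1.0,
)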

Newton-Schulz#

emerging_optimizers.orthogonalized_optimizers.muon_utils.newton_schulz(
x,
steps,
coefficient_type='quintic',
custom_coefficient_sets=None,
eps=1e-07,
transpose=None,
tp_group=None,
use_syrk=False,
)[source]#

Use Newton-Schulz iteration to compute the zeroth power / orthogonalization of x.

We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero and minimize variance. For the purpose of minimizing steps, it turns out to be empirically effective to keep increasing the slope at zero even beyond the point where the iteration no longer converges all the way to one everywhere on the interval. This iteration therefore does not produce \(UV^T\) but rather something like \(US'V^T\) where \(S'\) is diagonal with noisy values around 1, which turns out not to hurt model performance at all relative to \(UV^T\), where \(USV^T = G\) is the SVD.

Parameter coefficient_type can be one of the following
  • “simple”: Default coefficient set.

  • “quintic”: Quintic iteration with optimized coefficients.

  • “polar_express”: Polar Express iteration with optimized coefficients.

  • “custom”: Custom coefficient sets.

Parameters:
  • x (Tensor) – The tensor to be orthogonalized.

  • steps (int) – Number of Newton-Schulz iterations.

  • coefficient_type (str) – Type of coefficient set to use for the Newton-Schulz iteration.

  • custom_coefficient_sets (list[tuple[float, float, float]] | None) – Custom coefficient sets to use for the Newton-Schulz iteration.

  • eps (float) – Small constant to avoid division by zero.

  • transpose (bool | None) – Whether to transpose the tensor to perform whitening on the smaller dimension. If None, will be determined based on the size of the tensor.

  • tp_group (ProcessGroup | None) – The process group for communication if input is distributed.

  • use_syrk (bool) – Whether to use the Triton kernel for the Newton-Schulz iteration.

Returns:

The orthogonalization of x.

Return type:

Tensor
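
A usage sketch for newton_schulz; the tensor shape and the svdvals check are illustrative assumptions. After a few quintic iterations the singular values of the result should be noisy values around 1, i.e. approximately the \(US'V^T\) described above.

newton_schulz sketch#
import torch

from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz

g = torch.randn(512, 1024)
ortho = newton_schulz(g, steps=5, coefficient_type="quintic")

# Sanity check: singular values of the result should sit close to 1.
print(torch.linalg.svdvals(ortho.float())[:5])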

emerging_optimizers.orthogonalized_optimizers.muon_utils.newton_schulz_step(X, a, b, c, tp_group=None)[source]#

Perform a single Newton-Schulz iteration step.

It supports distributed input that is sharded along the smaller (orthogonalize) dimension.

Warning

If distributed, this function doesn’t have the information to verify that X is sharded along the smaller (orthogonalize) dimension. It is the user’s responsibility to ensure that X is sharded correctly.

Parameters:
  • X (Tensor) – The tensor to be orthogonalized.

  • a (float) – The a coefficient.

  • b (float) – The b coefficient.

  • c (float) – The c coefficient.

  • tp_group (ProcessGroup | None) – The process group to use for the all-reduce.

Returns:

The orthogonalization of X.

Return type:

Tensor
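
For reference, a non-distributed sketch of the math a single quintic step computes, \(X \leftarrow aX + b(XX^T)X + c(XX^T)^2X\), following the public Muon implementation. This is an assumption about the computation, not the library's exact code path, which also handles the sharded, all-reduce case.

Single-step reference sketch#
import torch


def newton_schulz_step_reference(X: torch.Tensor, a: float, b: float, c: float) -> torch.Tensor:
    A = X @ X.mT             # Gram matrix along the orthogonalize dimension
    B = b * A + c * (A @ A)  # quadratic polynomial in the Gram matrix
    return a * X + B @ X     # X <- a*X + b*(X X^T) X + c*(X X^T)^2 X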

emerging_optimizers.orthogonalized_optimizers.muon_utils.newton_schulz_tp(
x,
steps,
coefficient_type,
tp_group,
partition_dim=None,
mode='duplicated',
)[source]#

Tensor Parallel Newton-Schulz iteration.

This function uses partition_dim to determine along which dimension the input tensor is sharded. Transpose is set based on the partition_dim. If partition_dim is None, the input tensor is not sharded and the function will fall back to the non-TP path.

Warning

If partition_dim is the smaller dimension of the input tensor, distributed mode will run Newton-Schulz along the long dimension, which wastes compute. Although we reuse the partition_dim name, the default value here is None, meaning no partition, rather than -1.

Note

This function is designed to provide tensor parallel support for the most common uses of Newton-Schulz. Some arguments, e.g. custom coefficient sets and a custom eps, are not supported.

mode can be one of the following:
  • “duplicated”: The input tensor is duplicated and orthogonalized on each rank.

  • “distributed”: The input tensor is partitioned along the partition_dim and orthogonalized on each rank.

Parameters:
  • x (Tensor) – The tensor to be orthogonalized. Must have partition_dim and tensor_model_parallel attributes set by TransformerEngine.

  • steps (int) – Number of Newton-Schulz iterations.

  • coefficient_type (str) – Type of coefficient set to use for the Newton-Schulz iteration.

  • partition_dim (int | None) – The dimension to partition the tensor.

  • tp_group (ProcessGroup) – The process group for communication if input is distributed.

  • mode (Literal['duplicated', 'distributed']) – The mode to use for the Newton-Schulz iteration.

Return type:

Tensor