emerging_optimizers.orthogonalized_optimizers#

OrthogonalizedOptimizer#

class emerging_optimizers.orthogonalized_optimizers.OrthogonalizedOptimizer(
params,
lr,
momentum,
weight_decay,
*,
nesterov,
weight_decay_method,
fp32_matmul_prec,
scaled_orthogonalize_fn=None,
**kwargs,
)[source]#

Base class for orthogonalized optimizers.

This class is a wrapper around a base optimizer that performs orthogonalization on the updates. The theoretical foundation of orthogonalization for stochastic gradient descent was developed in the following papers:

  • Carlson, D., Cevher, V., and Carin, L. Stochastic spectral descent for Restricted Boltzmann Machines. In International Conference on Artificial Intelligence and Statistics (2015a).

  • Carlson, D., Hsieh, Y.-P., Collins, E., Carin, L., and Cevher, V. Stochastic Spectral Descent for Discrete Graphical Models. In IEEE Journal of Selected Topics in Signal Processing, vol. 10, no. 2, pp. 296-311 (2016).

  • Carlson, D., Collins, E., Hsieh, Y.-P., Carin, L., and Cevher, V. Preconditioned spectral descent for deep learning. In Neural Information Processing Systems (2015b).

  • Flynn, T. The duality structure gradient descent algorithm: analysis and applications to neural networks. arXiv preprint arXiv:1708.00523 (2017). [arXiv:1708.00523]

Note

As a base class, OrthogonalizedOptimizer does not directly support orthogonalizing the components of fused parameters separately. A subclass can override the orthogonalize function to support this, as in the following example.

Split QKV example#
from typing import Any

import torch

from emerging_optimizers.orthogonalized_optimizers import OrthogonalizedOptimizer


class SplitQkvOrthogonalizedOptimizer(OrthogonalizedOptimizer):
    def __init__(self, *args: Any, split_qkv_shapes: list[int], **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Sizes of the Q, K and V blocks along dim 0 of the fused weight.
        self.qkv_split_shapes = split_qkv_shapes

    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        # An alternative is to pass "is_qkv" to scaled_orthogonalize_fn and split inside
        # scaled_orthogonalize_fn.
        if getattr(p, "is_qkv", False) or kwargs.get("is_qkv", False):
            # Orthogonalize Q, K and V separately, then reassemble the fused tensor.
            qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0)
            qkv_orthogonalized = [self.scaled_orthogonalize_fn(g) for g in qkv_grads]
            grad = torch.cat(qkv_orthogonalized, dim=0)
        else:
            grad = self.scaled_orthogonalize_fn(grad)

        return grad
Parameters:
  • params (Iterable[Tensor] | Iterable[dict[str, Any]] | Iterable[tuple[str, Tensor]]) – Iterable of parameters to optimize or dicts defining parameter groups

  • lr (float) – The learning rate used by the internal SGD.

  • momentum (float) – The momentum used by the internal SGD.

  • weight_decay (float) – The weight decay used by the optimizer. Defaults to decoupled weight decay. See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101

  • nesterov (bool) – Whether to use Nesterov-style momentum in the internal SGD.

  • weight_decay_method (Literal['decoupled', 'independent', 'l2']) – Method to apply weight decay, see WeightDecayMixin for more details.

  • fp32_matmul_prec (Literal['highest', 'high', 'medium']) – Precision of the fp32 matmul (GEMM) operations performed on optimizer states.

  • scaled_orthogonalize_fn (Callable | None) – Function to orthogonalize and scale the updates.

  • **kwargs (Any) – Arguments passed through to the base optimizer.

Note

Keyword arguments that are passed through are not checked here. Optimizers that inherit from this class should validate them.
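The three weight_decay_method values listed above can be read roughly as follows. This is a minimal sketch of the conventional meaning of each name, not the WeightDecayMixin code; the helper name apply_weight_decay and the exact formulas (in particular for 'independent') are assumptions, so check WeightDecayMixin before relying on them.

```python
def apply_weight_decay(p, grad, lr, weight_decay, method):
    """Return (p, grad) after one weight-decay step under assumed semantics."""
    if method == "decoupled":
        # AdamW-style: shrink the weight by lr * weight_decay.
        p = p * (1.0 - lr * weight_decay)
    elif method == "independent":
        # Assumed: shrink by weight_decay alone, independent of lr.
        p = p * (1.0 - weight_decay)
    elif method == "l2":
        # Classic L2 regularization: fold the penalty into the gradient.
        grad = grad + weight_decay * p
    else:
        raise ValueError(f"unknown weight_decay_method: {method!r}")
    return p, grad


# Scalar stand-ins for a weight and its gradient.
for method in ("decoupled", "independent", "l2"):
    new_p, new_grad = apply_weight_decay(1.0, 0.0, lr=0.1, weight_decay=0.01, method=method)
    print(method, new_p, new_grad)
```

Under these assumptions, only 'decoupled' ties the decay strength to the learning rate, and only 'l2' sends the penalty through the orthogonalized update path.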

orthogonalize(p, grad, **kwargs)[source]#

Orthogonalize the momentum.

The default orthogonalize function calls scaled_orthogonalize_fn on grad. Subclasses can override this function to implement different orthogonalization logic or to split fused parameters. For example, an override can read attributes from p or values from kwargs to determine whether the parameter is fused and should be split for preconditioning.

Note

N-D parameters can be supported by overriding this function. For example, convolution weight can be supported by reshaping to [output_channels, input_channels * kernel_height * kernel_width], i.e. treating convolution as matrix multiplication with im2col.
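The reshape described in this note can be sketched as follows. This is an illustration written with NumPy rather than torch so that it runs without the library; orthogonalize_2d is a hypothetical stand-in for scaled_orthogonalize_fn that uses an exact SVD instead of Newton-Schulz, and orthogonalize_conv is likewise an assumed name.

```python
import numpy as np


def orthogonalize_2d(m):
    # Stand-in for scaled_orthogonalize_fn: exact polar factor U V^T via SVD.
    u, _, vt = np.linalg.svd(m, full_matrices=False)
    return u @ vt


def orthogonalize_conv(grad):
    # grad has shape [out_channels, in_channels, kernel_h, kernel_w].
    out_channels = grad.shape[0]
    # Flatten to [out_channels, in_channels * kernel_h * kernel_w] (im2col view).
    flat = grad.reshape(out_channels, -1)
    # Orthogonalize the 2D view, then restore the original 4D layout.
    return orthogonalize_2d(flat).reshape(grad.shape)


grad = np.random.default_rng(0).standard_normal((4, 3, 3, 3))
update = orthogonalize_conv(grad)
print(update.shape)
```

An override of orthogonalize would apply the same reshape, call the real scaled_orthogonalize_fn on the 2D view, and reshape the result back.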

Parameters:
  • p (Tensor) – The parameter tensor. It must be passed in addition to the momentum because some information, such as attributes, is only available on the parameter tensor. It is not used by the default orthogonalize function.

  • grad (Tensor) – The momentum tensor.

  • **kwargs (Any) – Keyword arguments of the param_group that p belongs to.

Returns:

The orthogonalized gradient tensor.

Return type:

Tensor

post_weight_update_fn_inplace(p)[source]#

Function called after the final weight update.

Subclasses can override this to implement custom behavior after the weight update. For example, to implement hyperball-style updates that preserve weight norms.

Warning

This function is experimental and may change in future versions.

Parameters:

p (Tensor) – The parameter tensor (already updated).

Return type:

None

pre_weight_update_fn_inplace(p, update)[source]#

Function called before the final weight update.

Subclasses can override this to implement custom behavior before the weight update. For example, to implement hyperball-style updates that preserve weight norms.

Warning

This function is experimental and may change in future versions.

Parameters:
  • p (Tensor) – The parameter tensor.

  • update (Tensor) – The orthogonalized gradient tensor (will be applied as p -= lr * update).

Return type:

None

step(closure: None = None) → None[source]#
step(
closure: Callable[[], float],
) → float

Performs a single optimization step.

Parameters:

closure – A closure that reevaluates the model and returns the loss.

Muon#

class emerging_optimizers.orthogonalized_optimizers.Muon(
params,
lr=0.0003,
momentum=0.95,
weight_decay=0.01,
*,
nesterov=False,
weight_decay_method='decoupled',
fp32_matmul_prec='medium',
coefficient_type='quintic',
num_ns_steps=5,
scale_mode='spectral',
extra_scale_factor=1.0,
use_syrk=False,
)[source]#

Muon: MomentUm Orthogonalized by Newton-Schulz

Muon runs standard SGD with momentum (Nesterov-style when nesterov=True) and then performs an orthogonalization post-processing step, in which each 2D parameter’s update is replaced with the nearest orthogonal matrix. To orthogonalize each update efficiently, Newton-Schulz iteration is used, which has the advantage that it can be run stably on GPU tensor cores.

Orthogonalization can be viewed as steepest descent in the spectral norm. The theoretical foundation is based on modular duality and norm-constrained optimization.

This implementation incorporates decoupled weight decay. See Scion, which views weight decay as constrained optimization via Frank-Wolfe.

References

  • Jordan, K. Muon Optimizer Implementation. [GitHub]

  • Modular Duality in Deep Learning. arXiv:2410.21265 (2024). [arXiv:2410.21265]

  • Training Deep Learning Models with Norm-Constrained LMOs. arXiv:2502.07529 (2025). [arXiv:2502.07529]

Warning

  • This optimizer requires that all parameters passed in are 2D.

  • It should not be used for the embedding layer, the final fully connected layer, or any 1-D parameters; those can all be optimized by a standard method (e.g., AdamW).
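The restrictions in this warning imply a partition of the model parameters before any optimizer is built. A minimal sketch of one way to do that is shown below; split_for_muon and the name fragments "embed" and "lm_head" are hypothetical, and NumPy arrays stand in for parameters (only the ndim attribute is used, which torch tensors also provide).

```python
import numpy as np


def split_for_muon(named_params, exclude=("embed", "lm_head")):
    """Partition (name, tensor) pairs into Muon-eligible and fallback lists."""
    muon, fallback = [], []
    for name, p in named_params:
        # Embedding and output layers are excluded by (assumed) name fragments.
        is_excluded = any(key in name for key in exclude)
        if p.ndim == 2 and not is_excluded:
            muon.append((name, p))
        else:
            # 1-D parameters and excluded layers go to a standard optimizer.
            fallback.append((name, p))
    return muon, fallback


named_params = [
    ("embed.weight", np.zeros((10, 4))),
    ("layers.0.attn.weight", np.zeros((4, 4))),
    ("layers.0.norm.weight", np.zeros(4)),
    ("lm_head.weight", np.zeros((10, 4))),
]
muon_params, fallback_params = split_for_muon(named_params)
print([name for name, _ in muon_params])
print([name for name, _ in fallback_params])
```

With real modules, the first list would be passed to Muon and the second to a standard optimizer such as AdamW.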

Parameters:
  • params (Iterable[Tensor] | Iterable[dict[str, Any]] | Iterable[tuple[str, Tensor]]) – Iterable of parameters to optimize or dicts defining parameter groups

  • lr (float) – The learning rate used by the internal SGD.

  • momentum (float) – The momentum used by the internal SGD.

  • weight_decay (float) – The weight decay used by the optimizer. Defaults to decoupled weight decay. See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101

  • nesterov (bool) – Whether to use Nesterov-style momentum in the internal SGD.

  • weight_decay_method (Literal['decoupled', 'independent', 'l2']) – Method to apply weight decay, see WeightDecayMixin for more details.

  • fp32_matmul_prec (Literal['highest', 'high', 'medium']) – Precision of the fp32 matmul (GEMM) operations performed on optimizer states.

  • coefficient_type (Literal['simple', 'quintic', 'polar_express', 'aol', 'custom']) – The type of coefficient set to use for the Newton-Schulz iteration. The documented choices are “simple”, “quintic”, and “polar_express”; the type annotation also lists “aol” and “custom”.

  • num_ns_steps (int) – The number of iteration steps to use in the Newton-Schulz iteration.

  • scale_mode (Literal['shape_scaling', 'spectral', 'unit_rms_norm']) – The type of scale factor to use for the update. Defaults to “spectral” style scaling.

  • extra_scale_factor (float) – The additional scale factor to use for the update. Setting it to 0.2 can closely match the update RMS norm of AdamW as suggested by https://arxiv.org/abs/2502.16982.

  • use_syrk (bool) – Whether to use the Triton kernel for the Newton-Schulz iteration.

Scion#

class emerging_optimizers.orthogonalized_optimizers.Scion(
params,
lr=0.0003,
momentum=0.95,
*,
fp32_matmul_prec='medium',
coefficient_type='quintic',
num_ns_steps=5,
spectral_radius=1.0,
)[source]#

Scion: Stochastic CondItional descent with Operator Norms

Scion runs standard SGD-momentum and then performs an orthogonalization post-processing step, in which each 2D parameter’s update is replaced with the nearest orthogonal matrix. To efficiently orthogonalize each update, Newton-Schulz iteration is used, which has the advantage that it may be stably run on tensor cores on GPUs.

This implementation incorporates step_size and spectral_radius. See Scion, which views weight decay as constrained optimization via Frank-Wolfe.

References

  • Training Deep Learning Models with Norm-Constrained LMOs. arXiv:2502.07529 (2025). [arXiv:2502.07529]

Warning

  • This optimizer requires that all parameters passed in are 2D.

  • It should not be used for the embedding layer, the final fully connected layer, or any 1-D parameters; those should all be optimized by the appropriate LMO for that layer. For example, the update for 1-D parameters is scaled by the \(\ell_\infty\) radius.

Parameters:
  • params (Iterable[Tensor] | Iterable[dict[str, Any]] | Iterable[tuple[str, Tensor]]) – Iterable of parameters to optimize or dicts defining parameter groups

  • lr (float) – The learning rate used by the internal SGD.

  • momentum (float) – The momentum used by the internal SGD.

  • fp32_matmul_prec (Literal['highest', 'high', 'medium']) – Precision of the fp32 matmul (GEMM) operations performed on optimizer states.

  • coefficient_type (Literal['simple', 'quintic', 'polar_express', 'aol', 'custom']) – The type of coefficient set to use for the Newton-Schulz iteration. The documented choices are “simple”, “quintic”, and “polar_express”; the type annotation also lists “aol” and “custom”.

  • num_ns_steps (int) – The number of iteration steps to use in the Newton-Schulz iteration.

  • spectral_radius (float) – The spectral radius to use for the update. The LMO is scaled by this value.

MOP#

class emerging_optimizers.orthogonalized_optimizers.MOP(
params,
lr=0.0003,
momentum=0.95,
weight_decay=0.01,
*,
nesterov=False,
weight_decay_method='decoupled',
fp32_matmul_prec='highest',
scale_mode='nuclear_norm',
extra_scale_factor=1.0,
)[source]#

MOP: Momentum Orthogonalized by Polar decomposition

Warning

This optimizer is experimental and not yet thoroughly tested.

Parameters:
  • params (Iterable[Tensor] | Iterable[dict[str, Any]] | Iterable[tuple[str, Tensor]]) – Iterable of parameters to optimize or dicts defining parameter groups

  • lr (float) – The learning rate used by the internal SGD.

  • momentum (float) – The momentum used by the internal SGD.

  • weight_decay (float) – The weight decay used by the optimizer. Defaults to decoupled weight decay. See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101

  • nesterov (bool) – Whether to use Nesterov-style momentum in the internal SGD.

  • weight_decay_method (Literal['decoupled', 'independent', 'l2']) – Method to apply weight decay, see WeightDecayMixin for more details.

  • fp32_matmul_prec (Literal['highest', 'high', 'medium']) – Precision of the fp32 matmul (GEMM) operations performed on optimizer states.

  • scale_mode (Literal['shape_scaling', 'spectral', 'unit_rms_norm', 'nuclear_norm']) – The type of scale factor to use for the update. Defaults to nuclear_norm style scaling.

  • extra_scale_factor (float) – The additional scale factor to use for the update.

MuonHyperball#

class emerging_optimizers.orthogonalized_optimizers.MuonHyperball(
*args,
hyperball_eps=1e-08,
hyperball_radius=None,
**kwargs,
)[source]#

Muon optimizer with hyperball-style norm-preserving weight updates.

This optimizer extends Muon by performing gradient descent on the sphere manifold while preserving the weight norm. The update rule is:

\[W_{t+1} = R \cdot \text{normalize}(W_t - \text{lr} \cdot R \cdot \text{normalize}(\text{update}))\]

where \(R\) is the Frobenius norm of \(W_t\) (or a user-specified radius). This keeps the weight matrix at constant scale while updating.
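The update rule above can be written out directly. The following is a minimal NumPy sketch of the formula as stated, assuming normalize divides by the Frobenius norm plus hyperball_eps; the function names are illustrative and this is not the class's actual hook code.

```python
import numpy as np


def normalize(x, eps=1e-8):
    # np.linalg.norm on a 2D array defaults to the Frobenius norm.
    return x / (np.linalg.norm(x) + eps)


def hyperball_step(w, update, lr, radius=None, eps=1e-8):
    # R is the current Frobenius norm of W unless a fixed radius is given.
    r = np.linalg.norm(w) if radius is None else radius
    # Take the step with a normalized update scaled by R ...
    stepped = w - lr * r * normalize(update, eps)
    # ... then project back onto the sphere of radius R.
    return r * normalize(stepped, eps)


rng = np.random.default_rng(0)
w = rng.standard_normal((4, 6))
update = rng.standard_normal((4, 6))
w_next = hyperball_step(w, update, lr=0.1)
print(np.linalg.norm(w), np.linalg.norm(w_next))
```

Because the step length is lr times R, the learning rate controls the size of the move relative to the weight norm, while the norm itself stays fixed.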

Warning

This optimizer is experimental and may change in future versions.

See Muon for full documentation of the base Muon optimizer.

Parameters:
  • *args (Any) – Arguments passed to Muon.

  • hyperball_eps (float) – Epsilon for numerical stability in normalization. Default: 1e-8.

  • hyperball_radius (float | None) – Fixed radius for the hyperball. If None (default), uses each parameter’s initial Frobenius norm as its radius. If specified, all parameters will be rescaled to have this radius at initialization.

  • **kwargs (Any) – Keyword arguments passed to Muon.

post_weight_update_fn_inplace(p)[source]#

Normalize the updated weights and scale back to original norm using Frobenius norm.

Parameters:

p (Tensor) – The parameter tensor (already updated).

Return type:

None

pre_weight_update_fn_inplace(p, update)[source]#

Store the original weight norm and normalize the update using Frobenius norm.

Parameters:
  • p (Tensor) – The parameter tensor.

  • update (Tensor) – The orthogonalized gradient tensor.

Return type:

None

PolarGrad#

class emerging_optimizers.orthogonalized_optimizers.PolarGrad(
params,
lr=0.0003,
momentum=0.95,
weight_decay=0.01,
*,
nesterov=False,
weight_decay_method='decoupled',
fp32_matmul_prec='highest',
coefficient_type='quintic',
num_ns_steps=5,
extra_scale_factor=1.0,
)[source]#

PolarGrad: Polar Gradient Methods.

PolarGrad runs standard SGD with momentum (Nesterov-style when nesterov=True) and then performs an orthogonalization post-processing step, in which each 2D parameter’s update is replaced with the nearest orthogonal matrix. The update is also scaled by the nuclear norm of the momentum term. This is equivalent to solving the steepest descent problem with respect to the spectral norm, as opposed to the LMO formulation of Scion and Muon. To orthogonalize each update efficiently, Newton-Schulz iteration is used, which has the advantage that it can be run stably on GPU tensor cores.
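The nuclear-norm scaling described above can be illustrated with an exact SVD. This is a sketch of the idea only: the real optimizer uses Newton-Schulz rather than an SVD and also applies extra_scale_factor, and the name polargrad_direction is an assumption.

```python
import numpy as np


def polargrad_direction(momentum):
    # Thin SVD of the momentum matrix: momentum = U diag(s) V^T.
    u, s, vt = np.linalg.svd(momentum, full_matrices=False)
    # The nuclear norm is the sum of the singular values.
    nuclear_norm = s.sum()
    # Orthogonal polar factor U V^T, scaled by the nuclear norm.
    return nuclear_norm * (u @ vt)


momentum = np.random.default_rng(0).standard_normal((4, 6))
direction = polargrad_direction(momentum)
# All singular values of the direction equal the nuclear norm of the momentum.
print(np.linalg.svd(direction, compute_uv=False))
print(np.linalg.norm(momentum, "nuc"))
```

Unlike the unscaled polar factor, this direction shrinks as the momentum shrinks, so the step size adapts to the gradient magnitude.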

This implementation incorporates decoupled weight decay.

References

  • PolarGrad: A Class of Matrix-Gradient Optimizers from a Unifying Preconditioning Perspective. arXiv:2505.21799 (2025). [arXiv:2505.21799]

  • Lau, T. T.-K. PolarGrad Optimizer Implementation. [polar_grad.py]

Warning

  • This optimizer requires that all parameters passed in are 2D.

  • It should not be used for the embedding layer, the final fully connected layer (with the Newton-Schulz iteration), or any 1-D parameters; those can all be optimized by a standard method (e.g., AdamW).

Parameters:
  • params (Iterable[Tensor] | Iterable[dict[str, Any]] | Iterable[tuple[str, Tensor]]) – Iterable of parameters to optimize or dicts defining parameter groups

  • lr (float) – The learning rate used by the internal SGD.

  • momentum (float) – The momentum used by the internal SGD.

  • weight_decay (float) – The weight decay used by the optimizer. Defaults to decoupled weight decay. See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101

  • nesterov (bool) – Whether to use Nesterov-style momentum in the internal SGD.

  • weight_decay_method (Literal['decoupled', 'independent', 'l2']) – Method to apply weight decay, see WeightDecayMixin for more details.

  • fp32_matmul_prec (Literal['highest', 'high', 'medium']) – Precision of the fp32 matmul (GEMM) operations performed on optimizer states.

  • coefficient_type (Literal['simple', 'quintic', 'polar_express', 'aol', 'custom']) – The type of coefficient set to use for the Newton-Schulz iteration. The documented choices are “simple”, “quintic”, and “polar_express”; the type annotation also lists “aol” and “custom”.

  • num_ns_steps (int) – The number of iteration steps to use in the Newton-Schulz iteration.

  • extra_scale_factor (float) – The additional scale factor to use for the update. Setting it to 0.2 can closely match the update RMS norm of AdamW as suggested by https://arxiv.org/abs/2502.16982.

Newton-Schulz#

emerging_optimizers.orthogonalized_optimizers.muon_utils.newton_schulz(
x,
steps,
coefficient_type='quintic',
custom_coefficient_sets=None,
eps=1e-07,
transpose=None,
tp_group=None,
use_syrk=False,
)[source]#

Use Newton-Schulz iteration to compute the zeroth power / orthogonalization of x.

We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero and minimize variance. To minimize the number of steps, it is empirically effective to keep increasing the slope at zero even beyond the point where the iteration no longer converges all the way to one everywhere on the interval. The iteration therefore does not produce \(UV^T\) but something like \(US'V^T\), where \(S'\) is diagonal with noisy values around 1 and \(USV^T = G\) is the SVD of the input. This turns out not to hurt model performance at all relative to \(UV^T\).
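The behaviour described above can be reproduced with a few lines of NumPy. This sketch assumes the standard quintic update X ← aX + (bA + cA²)X with A = XXᵀ and uses the single coefficient triple published with the original Muon implementation; the mapping of that triple to this library's named coefficient sets, and the function name newton_schulz_sketch, are assumptions.

```python
import numpy as np


def newton_schulz_sketch(g, steps=5, coeffs=(3.4445, -4.7750, 2.0315), eps=1e-7):
    a, b, c = coeffs
    # Dividing by the Frobenius norm bounds every singular value by 1.
    x = g / (np.linalg.norm(g) + eps)
    # Iterate on the wide orientation so that x @ x.T is the smaller Gram matrix.
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T
    for _ in range(steps):
        gram = x @ x.T
        # Applies the odd polynomial a*s + b*s^3 + c*s^5 to each singular value.
        x = a * x + (b * gram + c * gram @ gram) @ x
    return x.T if transposed else x


g = np.random.default_rng(0).standard_normal((8, 16))
out = newton_schulz_sketch(g)
# Singular values end up scattered around 1 rather than exactly at 1.
print(np.linalg.svd(out, compute_uv=False))
```

The printed singular values show the noisy \(S'\) mentioned above: close to 1, but not equal to it.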

Parameter coefficient_type can be one of the following:
  • “simple”: Simple coefficient set. Note that the default value of coefficient_type in the signature is “quintic”.

  • “quintic”: Quintic iteration with optimized coefficients.

  • “polar_express”: Polar Express iteration with optimized coefficients.

  • “custom”: Custom coefficient sets.

Parameters:
  • x (Tensor) – The tensor to be orthogonalized.

  • steps (int) – Number of Newton-Schulz iterations.

  • coefficient_type (Literal['simple', 'quintic', 'polar_express', 'aol', 'custom']) – Type of coefficient set to use for the Newton-Schulz iteration.

  • custom_coefficient_sets (list[tuple[float, float, float]] | None) – Custom coefficient sets to use for the Newton-Schulz iteration.

  • eps (float) – Small constant to avoid division by zero.

  • transpose (bool | None) – Whether to transpose the tensor to perform whitening on the smaller dimension. If None, will be determined based on the size of the tensor.

  • tp_group (ProcessGroup | None) – The process group for communication if input is distributed.

  • use_syrk (bool) – Whether to use the Triton kernel for the Newton-Schulz iteration.

Returns:

The orthogonalization of x.

Return type:

Tensor

emerging_optimizers.orthogonalized_optimizers.muon_utils.newton_schulz_step(X, a, b, c, tp_group=None)[source]#

Perform a single Newton-Schulz iteration step.

This function performs a single Newton-Schulz iteration step. It supports distributed input that’s sharded along the smaller (orthogonalize) dimension.
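A single step with coefficients a, b, c can be sketched as follows for the non-distributed case. The update form X ← aX + (bA + cA²)X with A = XXᵀ is the standard quintic Newton-Schulz step and is an assumption about this function's internals; ns_step is an illustrative name, and the tp_group all-reduce is left out.

```python
import numpy as np


def ns_step(x, a, b, c):
    # Gram matrix over the smaller dimension (x is assumed to be wide).
    gram = x @ x.T
    # One quintic step; the singular vectors of x are left unchanged.
    return a * x + (b * gram + c * gram @ gram) @ x


a, b, c = 3.4445, -4.7750, 2.0315  # illustrative coefficient triple
x = np.random.default_rng(0).standard_normal((4, 8))
x = x / np.linalg.norm(x)  # singular values now lie in (0, 1]

s_before = np.linalg.svd(x, compute_uv=False)
s_after = np.linalg.svd(ns_step(x, a, b, c), compute_uv=False)
# Each singular value s is mapped to a*s + b*s**3 + c*s**5.
predicted = a * s_before + b * s_before**3 + c * s_before**5
print(np.sort(s_after))
print(np.sort(predicted))
```

The step therefore acts as a scalar polynomial on the singular values, which is why repeating it pushes them toward 1.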

Warning

If the input is distributed, this function does not have the information to verify that X is sharded along the smaller (orthogonalize) dimension. It is the user’s responsibility to ensure that X is sharded correctly.

Parameters:
  • X (Tensor) – The tensor to be orthogonalized.

  • a (float) – The a coefficient.

  • b (float) – The b coefficient.

  • c (float) – The c coefficient.

  • tp_group (ProcessGroup | None) – The process group to use for the all-reduce.

Returns:

The result of applying one Newton-Schulz step to X.

Return type:

Tensor

emerging_optimizers.orthogonalized_optimizers.muon_utils.newton_schulz_tp(
x,
steps,
coefficient_type,
tp_group,
partition_dim=None,
tp_mode='duplicated',
)[source]#

Tensor Parallel Newton-Schulz iteration.

This function uses partition_dim to determine along which dimension the input tensor is sharded. Transpose is set based on the partition_dim. If partition_dim is None, the input tensor is not sharded and the function will fall back to the non-TP path.

Warning

If partition_dim is the smaller dimension of the input tensor, distributed mode will run Newton-Schulz along the long dimension, which wastes compute. Although the partition_dim name is reused, its default value here is None (not -1), which means no partition.

Note

This function is designed to provide tensor parallel support for most common use of Newton-Schulz. Many arguments, e.g. custom coefficient sets and custom eps, are not supported.

tp_mode can be one of the following:
  • “duplicated”: The input tensor is duplicated and orthogonalized on each rank.

  • “distributed”: The input tensor is partitioned along the partition_dim and orthogonalized on each rank.

Parameters:
  • x (Tensor) – The tensor to be orthogonalized. Must have partition_dim and tensor_model_parallel set by TransformerEngine.

  • steps (int) – Number of Newton-Schulz iterations.

  • coefficient_type (Literal['simple', 'quintic', 'polar_express', 'aol', 'custom']) – Type of coefficient set to use for the Newton-Schulz iteration.

  • partition_dim (int | None) – The dimension to partition the tensor.

  • tp_group (ProcessGroup) – The process group for communication if input is distributed.

  • tp_mode (Literal['duplicated', 'distributed']) – The mode to use for the Newton-Schulz iteration.

Return type:

Tensor