
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, override

import torch
import torch.optim as optim
from absl import logging
from torch.optim.optimizer import ParamsT

from emerging_optimizers import utils


_args_doc = """params: Iterable of parameters to optimize or dicts defining parameter groups
        lr: The learning rate used by the internal SGD.
        momentum_beta: The momentum used by the internal SGD.
        use_nesterov: Whether to use Nesterov-style momentum in the internal SGD.
        weight_decay: The weight decay used by the optimizer, default to be decoupled weight decay.
            See Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101
        use_decoupled_weight_decay: Whether to use decoupled weight decay, default to be True.
        split_qkv: Whether parameter is fused attention parameters (QKV, GQA, etc.), default to be False.
        is_qkv_fn: Function to check if a parameter is fused attention parameters (QKV, GQA, etc.).
        qkv_split_shapes: For grouped attention parameters (QKV, GQA, etc.), specify the shapes as a tuple of 3 integers
            representing the sizes of Q, K, V components along the first dimension.
        fp32_matmul_prec: Precision of the matmul operations in optimizer states GEMM operations.
"""


class OrthogonalizedOptimizer(optim.Optimizer):
    """Base class for orthogonalized optimizers.

    This class is a wrapper around a base optimizer that performs orthogonalization on the updates.

    The theoretical foundation of orthogonalization for stochastic gradient descent was developed by the
    following papers:

    - Carlson, D., Cevher, V., and Carin, L. *Stochastic spectral descent for Restricted Boltzmann
      Machines.* In International Conference on Artificial Intelligence and Statistics (2015a).
    - Carlson, D., Hsieh, Y.-P., Collins, E., Carin, L., and Cevher, V. *Stochastic Spectral Descent for
      Discrete Graphical Models.* In IEEE Journal of Selected Topics in Signal Processing, vol. 10, no. 2,
      pp. 296-311 (2016).
    - Carlson, D., Collins, E., Hsieh, Y.-P., Carin, L., and Cevher, V. *Preconditioned spectral descent
      for deep learning.* In Neural Information Processing Systems (2015b).
    - Flynn, T. *The duality structure gradient descent algorithm: analysis and applications to neural
      networks.* arXiv preprint arXiv:1708.00523 (2017).
      [`arXiv:1708.00523 <https://arxiv.org/abs/1708.00523>`_]

    Note:
        Orthogonalizing Q, K, and V separately when they are fused is supported, but with limitations. The
        user must provide a function to check whether a weight tensor holds fused attention parameters
        (QKV, GQA, etc.) as well as the leading dimensions of the Q, K, V components. Only one split size
        is supported, i.e. all attention layers across the network must have the same size.

    Args:
        {_args_doc}
        orthogonalize_fn: Function to orthogonalize the updates.
        scale_factor_fn: Function to compute the scale factor for the update.
        **kwargs: Arguments passed through to the base optimizer.

    Note:
        Keyword arguments passed through are not checked here. Optimizers inheriting from this class
        should check them.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float,
        momentum_beta: float,
        use_nesterov: bool,
        weight_decay: float,
        use_decoupled_weight_decay: bool,
        split_qkv: bool,
        is_qkv_fn: Callable[[torch.Tensor], bool] | None,
        qkv_split_shapes: tuple[int, int, int] | None,
        fp32_matmul_prec: str,
        orthogonalize_fn: Callable | None = None,
        scale_factor_fn: Callable | None = None,
        **kwargs: Any,
    ):
        if orthogonalize_fn is None:
            logging.warning("orthogonalize_fn not provided. Using noop")
            orthogonalize_fn = torch.nn.Identity()
        if scale_factor_fn is None:
            logging.warning("scale_factor_fn not provided. Using default scale_factor_fn.")

            def return_one(*args, **kwargs):  # type: ignore[no-untyped-def]
                return 1.0

            scale_factor_fn = return_one

        if split_qkv:
            assert is_qkv_fn is not None, "is_qkv_fn must be provided when split_qkv is True"
            assert qkv_split_shapes is not None, "qkv_split_shapes must be provided when split_qkv is True"
            if len(qkv_split_shapes) != 3:
                raise ValueError(
                    f"qkv_split_shapes must be a tuple of 3 integers, got {len(qkv_split_shapes)} elements"
                )
            if not all(isinstance(s, int) for s in qkv_split_shapes):
                raise ValueError(f"All elements in qkv_split_shapes must be integers, got {qkv_split_shapes}")
            if any(s <= 0 for s in qkv_split_shapes):
                raise ValueError(f"All elements in qkv_split_shapes must be positive, got {qkv_split_shapes}")

        self.split_qkv = split_qkv
        self.is_qkv_fn = is_qkv_fn
        self.qkv_split_shapes = qkv_split_shapes
        self.fp32_matmul_prec = fp32_matmul_prec

        default_args_dict = dict(
            lr=lr,
            momentum_beta=momentum_beta,
            use_nesterov=use_nesterov,
            weight_decay=weight_decay,
            use_decoupled_weight_decay=use_decoupled_weight_decay,
            **kwargs,
        )
        super().__init__(params, default_args_dict)

        self.orthogonalize_fn = orthogonalize_fn
        self.scale_factor_fn = scale_factor_fn

    @torch.no_grad()  # type: ignore[misc]
    @override
    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        """Performs a single optimization step.

        Args:
            closure: A closure that reevaluates the model and returns the loss.
        """
        if closure is None:
            loss = None
        else:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.dim() == 1:
                    raise ValueError(f"{self.__class__.__name__} does not support 1D parameters")

                grad = p.grad
                if grad is None:
                    continue

                state = self.state[p]

                # initialize momentum buffer
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(grad)
                # Subsequent updates to exp_avg are all in place, so it is not assigned back to state.
                exp_avg = state["momentum_buffer"]

                # Apply weight decay
                if group["weight_decay"] > 0.0:
                    if group["use_decoupled_weight_decay"]:
                        # Apply decoupled weight decay
                        p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
                    else:
                        # add L2 regularization before preconditioning (i.e. adding a squared loss term)
                        grad += group["weight_decay"] * p

                # update momentum buffer with EMA of gradient
                exp_avg.lerp_(grad, 1 - group["momentum_beta"])

                # include nesterov momentum
                if group["use_nesterov"]:
                    grad = grad.lerp(exp_avg, group["momentum_beta"])
                else:
                    grad = exp_avg

                with utils.fp32_matmul_precision(self.fp32_matmul_prec):
                    grad = self.orthogonalize(p, grad)

                # perform weight update
                # scale is applied to have update RMS == 1
                p.add_(grad, alpha=-group["lr"])

        return loss

    def orthogonalize(self, p: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
        """Orthogonalize the momentum.

        Args:
            p: The parameter tensor. It is necessary to pass the param tensor in addition to the momentum
                because some information, such as attributes, is only available on the param tensor.
            grad: The momentum tensor.

        Returns:
            The orthogonalized gradient tensor.
        """
        if self.split_qkv and self.is_qkv_fn(p):  # type: ignore[misc]
            logging.log_first_n(logging.INFO, f"split qkv with {p.shape} to {self.qkv_split_shapes}", 1)
            # split grouped attention parameters (e.g., QKV, GQA, etc.)
            qkv_grads = torch.split(grad, self.qkv_split_shapes, dim=0)
            # Apply Newton-Schulz to each component
            qkv_whitened = [self.orthogonalize_fn(g) for g in qkv_grads]
            qkv_scales = [self.scale_factor_fn(g.size(0), g.size(1)) for g in qkv_grads]
            # Apply individual scales to each component and concatenate
            grad = torch.cat([whitened * scale for whitened, scale in zip(qkv_whitened, qkv_scales)])
        else:
            grad = self.orthogonalize_fn(grad) * self.scale_factor_fn(grad.size(0), grad.size(1))

        return grad

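# --- Illustrative sketch, not part of this module ---------------------------------------------------
# One possible `orthogonalize_fn` is a Newton-Schulz iteration of the kind used by Muon-style
# optimizers: it maps the momentum matrix toward the nearest semi-orthogonal matrix without an
# explicit SVD. The quintic coefficients below follow the widely used Muon variant; the function
# name, step count, and bfloat16 compute dtype are illustrative assumptions, not library API.
def _newton_schulz_orthogonalize(g: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    assert g.dim() == 2, "expects a 2D momentum matrix"
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g.to(torch.bfloat16)
    # Iterate on the wider orientation so the Gram matrix is as small as possible.
    transposed = x.size(0) > x.size(1)
    if transposed:
        x = x.mT
    # Normalize so the spectral norm is <= 1, which the iteration requires to converge.
    x = x / (x.norm() + eps)
    for _ in range(steps):
        gram = x @ x.mT
        x = a * x + (b * gram + c * gram @ gram) @ x
    if transposed:
        x = x.mT
    return x.to(g.dtype)
# -----------------------------------------------------------------------------------------------------
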
OrthogonalizedOptimizer.__doc__ = OrthogonalizedOptimizer.__doc__.format(_args_doc=_args_doc) # type: ignore[union-attr]
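

# --- Illustrative usage sketch, not part of the library ----------------------------------------------
# A minimal subclass wiring the Newton-Schulz orthogonalizer above into OrthogonalizedOptimizer.
# The class name, hyperparameters, scale factor (a common heuristic for keeping the update RMS
# close to 1), and the "highest" value for fp32_matmul_prec (assumed to follow
# torch.set_float32_matmul_precision conventions) are assumptions for demonstration only.
if __name__ == "__main__":
    import math

    class _ExampleOrthogonalizedSGD(OrthogonalizedOptimizer):
        def __init__(self, params: ParamsT, lr: float = 1e-2, momentum_beta: float = 0.95) -> None:
            super().__init__(
                params,
                lr=lr,
                momentum_beta=momentum_beta,
                use_nesterov=True,
                weight_decay=0.0,
                use_decoupled_weight_decay=True,
                # Set split_qkv=True (with is_qkv_fn and qkv_split_shapes) to orthogonalize fused
                # attention parameters per Q/K/V component.
                split_qkv=False,
                is_qkv_fn=None,
                qkv_split_shapes=None,
                fp32_matmul_prec="highest",
                orthogonalize_fn=_newton_schulz_orthogonalize,
                scale_factor_fn=lambda rows, cols: math.sqrt(max(1.0, rows / cols)),
            )

    # A bias-free linear layer keeps every parameter 2D, which the optimizer requires.
    model = torch.nn.Linear(64, 32, bias=False)
    opt = _ExampleOrthogonalizedSGD(model.parameters())
    model(torch.randn(8, 64)).sum().backward()
    opt.step()
# -----------------------------------------------------------------------------------------------------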