# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Supported optimizer configs
"""

from dataclasses import dataclass, field
from typing import List, Any

import torch
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING


def _list_default(*values: Any) -> Any:
    """Return a dataclass field whose default is a fresh list of ``values``.

    A new list is built on every instantiation, so instances never share
    the same mutable default.
    """
    return field(default_factory=lambda: list(values))


@dataclass
class OptimizerConf:
    """Base structured config shared by every optimizer config below."""

    # Left unannotated on purpose: this makes it a plain class attribute
    # rather than a dataclass field. Each subclass declares its own
    # annotated ``_target_`` field.
    _target_ = MISSING
    # Names associated with the gradient compute/apply steps.
    # NOTE(review): the consumer of these names is not visible in this
    # file -- confirm against the trainer before renaming.
    _params_: Any = field(
        default_factory=lambda: {
            "compute_gradients": "adam_compute_gradients",
            "apply_gradients": "adam_apply_gradients",
        }
    )


@dataclass
class AdamConf(OptimizerConf):
    """Config targeting ``torch.optim.Adam``."""

    _target_: str = "torch.optim.Adam"
    lr: float = 1.0e-3
    betas: List[float] = _list_default(0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0
    amsgrad: bool = False


@dataclass
class SGDConf(OptimizerConf):
    """Config targeting ``torch.optim.SGD``."""

    _target_: str = "torch.optim.SGD"
    lr: float = 1.0e-3
    momentum: float = 1.0e-2
    dampening: float = 0
    weight_decay: float = 0
    nesterov: bool = False


@dataclass
class AdahessianConf(OptimizerConf):
    """Config targeting ``torch_optimizer.Adahessian``.

    Overrides ``_params_`` with the ``adahess_*`` gradient step names.
    """

    _target_: str = "torch_optimizer.Adahessian"
    lr: float = 1.0e-1
    betas: List[float] = _list_default(0.9, 0.999)
    eps: float = 1e-4
    weight_decay: float = 0.0
    hessian_power: float = 1.0
    _params_: Any = field(
        default_factory=lambda: {
            "compute_gradients": "adahess_compute_gradients",
            "apply_gradients": "adahess_apply_gradients",
        }
    )


@dataclass
class BFGSConf(OptimizerConf):
    """Config targeting ``torch.optim.LBFGS``.

    Overrides ``_params_`` with the ``bfgs_*`` gradient step names.
    """

    _target_: str = "torch.optim.LBFGS"
    lr: float = 1.0
    max_iter: int = 1000
    max_eval: Any = None
    tolerance_grad: float = 1e-7
    tolerance_change: float = 1e-9
    history_size: int = 100
    line_search_fn: Any = None  # effectively Union[None, str]
    _params_: Any = field(
        default_factory=lambda: {
            "compute_gradients": "bfgs_compute_gradients",
            "apply_gradients": "bfgs_apply_gradients",
        }
    )


@dataclass
class AdadeltaConf(OptimizerConf):
    """Config targeting ``torch.optim.Adadelta``."""

    _target_: str = "torch.optim.Adadelta"
    lr: float = 1.0
    rho: float = 0.9
    eps: float = 1e-6
    weight_decay: float = 0


@dataclass
class AdagradConf(OptimizerConf):
    """Config targeting ``torch.optim.Adagrad``."""

    _target_: str = "torch.optim.Adagrad"
    lr: float = 1.0e-2
    lr_decay: float = 0
    weight_decay: float = 0
    initial_accumulator_value: float = 0
    eps: float = 1e-10


@dataclass
class AdamWConf(OptimizerConf):
    """Config targeting ``torch.optim.AdamW``."""

    _target_: str = "torch.optim.AdamW"
    lr: float = 1.0e-3
    betas: List[float] = _list_default(0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.01
    amsgrad: bool = False


@dataclass
class SparseAdamConf(OptimizerConf):
    """Config targeting ``torch.optim.SparseAdam``."""

    _target_: str = "torch.optim.SparseAdam"
    lr: float = 1.0e-3
    betas: List[float] = _list_default(0.9, 0.999)
    eps: float = 1e-8


@dataclass
class AdamaxConf(OptimizerConf):
    """Config targeting ``torch.optim.Adamax``."""

    _target_: str = "torch.optim.Adamax"
    lr: float = 2.0e-3
    betas: List[float] = _list_default(0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0


@dataclass
class ASGDConf(OptimizerConf):
    """Config targeting ``torch.optim.ASGD``."""

    _target_: str = "torch.optim.ASGD"
    lr: float = 1.0e-2
    lambd: float = 1.0e-4
    alpha: float = 0.75
    t0: float = 1000000.0
    weight_decay: float = 0


@dataclass
class NAdamConf(OptimizerConf):
    """Config targeting ``torch.optim.NAdam``."""

    _target_: str = "torch.optim.NAdam"
    lr: float = 2.0e-3
    betas: List[float] = _list_default(0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0
    momentum_decay: float = 0.004


@dataclass
class RAdamConf(OptimizerConf):
    """Config targeting ``torch.optim.RAdam``."""

    _target_: str = "torch.optim.RAdam"
    lr: float = 1.0e-3
    betas: List[float] = _list_default(0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0


@dataclass
class RMSpropConf(OptimizerConf):
    """Config targeting ``torch.optim.RMSprop``."""

    _target_: str = "torch.optim.RMSprop"
    lr: float = 1.0e-2
    alpha: float = 0.99
    eps: float = 1e-8
    weight_decay: float = 0
    momentum: float = 0
    centered: bool = False


@dataclass
class RpropConf(OptimizerConf):
    """Config targeting ``torch.optim.Rprop``."""

    _target_: str = "torch.optim.Rprop"
    lr: float = 1.0e-2
    etas: List[float] = _list_default(0.5, 1.2)
    step_sizes: List[float] = _list_default(1.0e-6, 50)


@dataclass
class A2GradExpConf(OptimizerConf):
    """Config targeting ``torch_optimizer.A2GradExp``."""

    _target_: str = "torch_optimizer.A2GradExp"
    # lr is not supported by this optimizer; it is kept so that
    # schedulers do not fail.
    lr: float = 1e-2
    beta: float = 10.0
    lips: float = 10.0


@dataclass
class A2GradIncConf(OptimizerConf):
    """Config targeting ``torch_optimizer.A2GradInc``."""

    _target_: str = "torch_optimizer.A2GradInc"
    # lr is not supported by this optimizer; it is kept so that
    # schedulers do not fail.
    lr: float = 1e-2
    beta: float = 10.0
    lips: float = 10.0


@dataclass
class A2GradUniConf(OptimizerConf):
    """Config targeting ``torch_optimizer.A2GradUni``."""

    _target_: str = "torch_optimizer.A2GradUni"
    # lr is not supported by this optimizer; it is kept so that
    # schedulers do not fail.
    lr: float = 1e-2
    beta: float = 10.0
    lips: float = 10.0


@dataclass
class AccSGDConf(OptimizerConf):
    """Config targeting ``torch_optimizer.AccSGD``."""

    _target_: str = "torch_optimizer.AccSGD"
    lr: float = 1.0e-3
    kappa: float = 1000.0
    xi: float = 10.0
    small_const: float = 0.7
    weight_decay: float = 0


@dataclass
class AdaBeliefConf(OptimizerConf):
    """Config targeting ``torch_optimizer.AdaBelief``."""

    _target_: str = "torch_optimizer.AdaBelief"
    lr: float = 1.0e-3
    betas: List[float] = _list_default(0.9, 0.999)
    eps: float = 1.0e-3
    weight_decay: float = 0
    amsgrad: bool = False
    weight_decouple: bool = False
    fixed_decay: bool = False
    rectify: bool = False


@dataclass
class AdaBoundConf(OptimizerConf):
    """Config targeting ``torch_optimizer.AdaBound``."""

    _target_: str = "torch_optimizer.AdaBound"
    lr: float = 1.0e-3
    betas: List[float] = _list_default(0.9, 0.999)
    final_lr: float = 0.1
    gamma: float = 1e-3
    eps: float = 1e-8
    weight_decay: float = 0
    amsbound: bool = False


@dataclass
class AdaModConf(OptimizerConf):
    """Config targeting ``torch_optimizer.AdaMod``."""

    _target_: str = "torch_optimizer.AdaMod"
    lr: float = 1.0e-3
    betas: List[float] = _list_default(0.9, 0.999)
    beta3: float = 0.999
    eps: float = 1e-8
    weight_decay: float = 0


@dataclass
class AdafactorConf(OptimizerConf):
    """Config targeting ``torch_optimizer.Adafactor``."""

    _target_: str = "torch_optimizer.Adafactor"
    lr: float = 1.0e-3
    eps2: List[float] = _list_default(1e-30, 1e-3)
    clip_threshold: float = 1.0
    decay_rate: float = -0.8
    beta1: Any = None
    weight_decay: float = 0
    scale_parameter: bool = True
    relative_step: bool = True
    warmup_init: bool = False


@dataclass
class AdamPConf(OptimizerConf):
    """Config targeting ``torch_optimizer.AdamP``."""

    _target_: str = "torch_optimizer.AdamP"
    lr: float = 1.0e-3
    betas: List[float] = _list_default(0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0
    delta: float = 0.1
    wd_ratio: float = 0.1


@dataclass
class AggMoConf(OptimizerConf):
    """Config targeting ``torch_optimizer.AggMo``."""

    _target_: str = "torch_optimizer.AggMo"
    lr: float = 1.0e-3
    betas: List[float] = _list_default(0.0, 0.9, 0.99)
    weight_decay: float = 0


@dataclass
class ApolloConf(OptimizerConf):
    """Config targeting ``torch_optimizer.Apollo``."""

    _target_: str = "torch_optimizer.Apollo"
    lr: float = 1.0e-2
    beta: float = 0.9
    eps: float = 1e-4
    warmup: int = 0
    init_lr: float = 0.01
    weight_decay: float = 0


@dataclass
class DiffGradConf(OptimizerConf):
    """Config targeting ``torch_optimizer.DiffGrad``."""

    _target_: str = "torch_optimizer.DiffGrad"
    lr: float = 1.0e-3
    betas: List[float] = _list_default(0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0


@dataclass
class LambConf(OptimizerConf):
    """Config targeting ``torch_optimizer.Lamb``."""

    _target_: str = "torch_optimizer.Lamb"
    lr: float = 1.0e-3
    betas: List[float] = _list_default(0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0


@dataclass
class MADGRADConf(OptimizerConf):
    """Config targeting ``torch_optimizer.MADGRAD``."""

    _target_: str = "torch_optimizer.MADGRAD"
    lr: float = 1.0e-2
    momentum: float = 0.9
    weight_decay: float = 0
    eps: float = 1e-6


@dataclass
class NovoGradConf(OptimizerConf):
    """Config targeting ``torch_optimizer.NovoGrad``."""

    _target_: str = "torch_optimizer.NovoGrad"
    lr: float = 1.0e-3
    betas: List[float] = _list_default(0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0
    grad_averaging: bool = False
    amsgrad: bool = False


@dataclass
class PIDConf(OptimizerConf):
    """Config targeting ``torch_optimizer.PID``."""

    _target_: str = "torch_optimizer.PID"
    lr: float = 1.0e-3
    momentum: float = 0
    dampening: float = 0
    weight_decay: float = 1e-2
    integral: float = 5.0
    derivative: float = 10.0


@dataclass
class QHAdamConf(OptimizerConf):
    """Config targeting ``torch_optimizer.QHAdam``."""

    _target_: str = "torch_optimizer.QHAdam"
    lr: float = 1.0e-3
    betas: List[float] = _list_default(0.9, 0.999)
    nus: List[float] = _list_default(1.0, 1.0)
    weight_decay: float = 0
    decouple_weight_decay: bool = False
    eps: float = 1e-8


@dataclass
class QHMConf(OptimizerConf):
    """Config targeting ``torch_optimizer.QHM``."""

    _target_: str = "torch_optimizer.QHM"
    lr: float = 1.0e-3
    momentum: float = 0
    nu: float = 0.7
    weight_decay: float = 1e-2
    weight_decay_type: str = "grad"


@dataclass
class RangerConf(OptimizerConf):
    """Config targeting ``torch_optimizer.Ranger``."""

    _target_: str = "torch_optimizer.Ranger"
    lr: float = 1.0e-3
    alpha: float = 0.5
    k: int = 6
    # Capital "N" here, unlike RangerVAConf.n_sma_threshhold; the spelling
    # is a config key and is kept exactly as in the original.
    N_sma_threshhold: int = 5
    betas: List[float] = _list_default(0.95, 0.999)
    eps: float = 1e-5
    weight_decay: float = 0


@dataclass
class RangerQHConf(OptimizerConf):
    """Config targeting ``torch_optimizer.RangerQH``."""

    _target_: str = "torch_optimizer.RangerQH"
    lr: float = 1.0e-3
    betas: List[float] = _list_default(0.9, 0.999)
    nus: List[float] = _list_default(0.7, 1.0)
    weight_decay: float = 0
    k: int = 6
    alpha: float = 0.5
    decouple_weight_decay: bool = False
    eps: float = 1e-8


@dataclass
class RangerVAConf(OptimizerConf):
    """Config targeting ``torch_optimizer.RangerVA``."""

    _target_: str = "torch_optimizer.RangerVA"
    lr: float = 1.0e-3
    alpha: float = 0.5
    k: int = 6
    # Lower-case "n" here, unlike RangerConf.N_sma_threshhold.
    n_sma_threshhold: int = 5
    betas: List[float] = _list_default(0.95, 0.999)
    eps: float = 1e-5
    weight_decay: float = 0
    amsgrad: bool = True
    transformer: str = "softplus"
    smooth: int = 50
    grad_transformer: str = "square"


@dataclass
class SGDPConf(OptimizerConf):
    """Config targeting ``torch_optimizer.SGDP``."""

    _target_: str = "torch_optimizer.SGDP"
    lr: float = 1.0e-3
    momentum: float = 0
    dampening: float = 0
    weight_decay: float = 1e-2
    nesterov: bool = False
    delta: float = 0.1
    wd_ratio: float = 0.1


@dataclass
class SGDWConf(OptimizerConf):
    """Config targeting ``torch_optimizer.SGDW``."""

    _target_: str = "torch_optimizer.SGDW"
    lr: float = 1.0e-3
    momentum: float = 0
    dampening: float = 0
    weight_decay: float = 1e-2
    nesterov: bool = False


@dataclass
class SWATSConf(OptimizerConf):
    """Config targeting ``torch_optimizer.SWATS``."""

    _target_: str = "torch_optimizer.SWATS"
    lr: float = 1.0e-1
    betas: List[float] = _list_default(0.9, 0.999)
    eps: float = 1e-3
    weight_decay: float = 0
    amsgrad: bool = False
    nesterov: bool = False


@dataclass
class ShampooConf(OptimizerConf):
    """Config targeting ``torch_optimizer.Shampoo``."""

    _target_: str = "torch_optimizer.Shampoo"
    lr: float = 1.0e-1
    momentum: float = 0
    weight_decay: float = 0
    epsilon: float = 1e-4
    update_freq: int = 1


@dataclass
class YogiConf(OptimizerConf):
    """Config targeting ``torch_optimizer.Yogi``."""

    _target_: str = "torch_optimizer.Yogi"
    lr: float = 1.0e-2
    betas: List[float] = _list_default(0.9, 0.999)
    eps: float = 1e-3
    initial_accumulator: float = 1e-6
    weight_decay: float = 0


def register_optimizer_configs() -> None:
    """Store every optimizer config in the ``optimizer`` ConfigStore group.

    Each config class defined in this module is registered on the
    ``ConfigStore`` singleton under its short name (e.g. ``"adam"``).
    """
    # (name, node) pairs, stored in exactly this order.
    registry = (
        ("adam", AdamConf),
        ("sgd", SGDConf),
        ("adahessian", AdahessianConf),
        ("bfgs", BFGSConf),
        ("adadelta", AdadeltaConf),
        ("adagrad", AdagradConf),
        ("adamw", AdamWConf),
        ("sparse_adam", SparseAdamConf),
        ("adamax", AdamaxConf),
        ("asgd", ASGDConf),
        ("nadam", NAdamConf),
        ("radam", RAdamConf),
        ("rmsprop", RMSpropConf),
        ("rprop", RpropConf),
        ("a2grad_exp", A2GradExpConf),
        ("a2grad_inc", A2GradIncConf),
        ("a2grad_uni", A2GradUniConf),
        ("accsgd", AccSGDConf),
        ("adabelief", AdaBeliefConf),
        ("adabound", AdaBoundConf),
        ("adamod", AdaModConf),
        ("adafactor", AdafactorConf),
        ("adamp", AdamPConf),
        ("aggmo", AggMoConf),
        ("apollo", ApolloConf),
        ("diffgrad", DiffGradConf),
        ("lamb", LambConf),
        ("madgrad", MADGRADConf),
        ("novograd", NovoGradConf),
        ("pid", PIDConf),
        ("qhadam", QHAdamConf),
        ("qhm", QHMConf),
        ("ranger", RangerConf),
        ("ranger_qh", RangerQHConf),
        ("ranger_va", RangerVAConf),
        ("sgdp", SGDPConf),
        ("sgdw", SGDWConf),
        ("swats", SWATSConf),
        ("shampoo", ShampooConf),
        ("yogi", YogiConf),
    )

    store = ConfigStore.instance()
    for short_name, conf_cls in registry:
        store.store(
            group="optimizer",
            name=short_name,
            node=conf_cls,
        )