# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""Learning rate decay and weight decay incr functions."""
import logging
import math
from typing import Optional
from torch.optim.optimizer import Optimizer
logger = logging.getLogger(__name__)
class OptimizerParamScheduler:
"""
Anneals learning rate and weight decay.
Args:
optimizer (Optimizer): the optimizer to be used
init_lr (float): initial learning rate
max_lr (float): maximum learning rate
min_lr (float): minimum learning rate
lr_warmup_steps (int): number of warmup steps
lr_decay_steps (int): number of decay steps
lr_decay_style (str): decay style for learning rate
start_wd (float): initial weight decay
end_wd (float): final weight decay
wd_incr_steps (int): number of weight decay increment steps
wd_incr_style (str): weight decay increment style
use_checkpoint_opt_param_scheduler (bool, optional): whether to use the checkpoint values
for the optimizer param scheduler. Defaults to True.
override_opt_param_scheduler (bool, optional): whether to override the optimizer param
scheduler values with the class values. Defaults to False.
wsd_decay_steps (int, optional): number of steps in the decay phase of the WSD
(warmup-stable-decay) schedule. Defaults to None.
lr_wsd_decay_style (str, optional): decay style for the learning rate during the
WSD decay phase. Defaults to None.
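Example:
    A minimal usage sketch; the optimizer, model, and hyperparameter values
    below are illustrative, not defaults::

        import torch
        model = torch.nn.Linear(8, 8)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.0, weight_decay=0.0)
        scheduler = OptimizerParamScheduler(
            optimizer,
            init_lr=0.0, max_lr=1e-3, min_lr=1e-5,
            lr_warmup_steps=100, lr_decay_steps=1000, lr_decay_style="cosine",
            start_wd=0.0, end_wd=0.1, wd_incr_steps=1000, wd_incr_style="linear",
        )
        for _ in range(1000):
            optimizer.step()
            scheduler.step(increment=1)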
"""
def __init__(
self,
optimizer: Optimizer,
init_lr: float,
max_lr: float,
min_lr: float,
lr_warmup_steps: int,
lr_decay_steps: int,
lr_decay_style: str,
start_wd: float,
end_wd: float,
wd_incr_steps: int,
wd_incr_style: str,
use_checkpoint_opt_param_scheduler: Optional[bool] = True,
override_opt_param_scheduler: Optional[bool] = False,
wsd_decay_steps: Optional[int] = None,
lr_wsd_decay_style: Optional[str] = None,
) -> None:
"""
Constructor for OptimizerParamScheduler.
See the class docstring above for descriptions of the arguments.
"""
# Class values.
self.optimizer = optimizer
self.init_lr = init_lr
self.max_lr = float(max_lr)
self.min_lr = min_lr
assert self.min_lr >= 0.0
assert self.max_lr >= self.min_lr
assert self.init_lr <= self.max_lr
self.lr_warmup_steps = lr_warmup_steps
self.num_steps = 0
self.lr_decay_steps = lr_decay_steps
self.wsd_decay_steps = wsd_decay_steps
self.lr_wsd_decay_style = lr_wsd_decay_style
assert self.lr_decay_steps > 0
assert self.lr_warmup_steps < self.lr_decay_steps
self.lr_decay_style = lr_decay_style
if self.lr_decay_style == "WSD":
assert self.wsd_decay_steps is not None
self.start_wd = start_wd
self.end_wd = end_wd
assert self.start_wd >= 0.0
assert self.end_wd >= self.start_wd
self.wd_incr_steps = wd_incr_steps
self.wd_incr_style = wd_incr_style
self.override_opt_param_scheduler = override_opt_param_scheduler
self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
if self.override_opt_param_scheduler:
assert not self.use_checkpoint_opt_param_scheduler, "both override and use-checkpoint are set."
# Set the learning rate
self.step(0)
logger.info("learning rate decay style: {}".format(self.lr_decay_style))
def get_wd(self) -> float:
"""
Compute and return the weight decay for the current step.
"""
if self.num_steps > self.wd_incr_steps:
return self.end_wd
if self.wd_incr_style == "constant":
assert self.start_wd == self.end_wd
return self.end_wd
incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
assert incr_ratio >= 0.0
assert incr_ratio <= 1.0
delta_wd = self.end_wd - self.start_wd
if self.wd_incr_style == "linear":
coeff = incr_ratio
elif self.wd_incr_style == "cosine":
coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
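# The (1 - incr_ratio) phase makes coeff rise from 0.0 at incr_ratio == 0 to 1.0 at incr_ratio == 1.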
else:
raise Exception(f"{self.wd_incr_style} weight decay increment style is not supported.")
return self.start_wd + coeff * delta_wd
def get_lr(self, param_group: dict) -> float:
"""
Learning rate decay functions from: https://openreview.net/pdf?id=BJYwwY9ll pg. 4.
Args:
param_group (dict): parameter group from the optimizer.
"""
max_lr = param_group.get("max_lr", self.max_lr)
min_lr = param_group.get("min_lr", self.min_lr)
# Use linear warmup for the initial part.
if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
return self.init_lr + ((max_lr - self.init_lr) * float(self.num_steps) / float(self.lr_warmup_steps))
# If the learning rate is constant, just return the initial value.
if self.lr_decay_style == "constant":
return max_lr
# For any steps larger than `self.lr_decay_steps`, use `min_lr`.
if self.num_steps > self.lr_decay_steps:
return min_lr
# If we are done with the warmup period, use the decay style.
if self.lr_decay_style == "inverse-square-root":
warmup_steps = max(self.lr_warmup_steps, 1)
num_steps = max(self.num_steps, 1)
lr = max_lr * warmup_steps**0.5 / (num_steps**0.5)
return max(min_lr, lr)
num_steps_ = self.num_steps - self.lr_warmup_steps
decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
decay_ratio = float(num_steps_) / float(decay_steps_)
assert decay_ratio >= 0.0
assert decay_ratio <= 1.0
delta_lr = max_lr - min_lr
coeff = None
if self.lr_decay_style == "linear":
coeff = 1.0 - decay_ratio
elif self.lr_decay_style == "cosine":
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
elif self.lr_decay_style == "WSD":
wsd_anneal_start_ = self.lr_decay_steps - self.wsd_decay_steps
if self.num_steps <= wsd_anneal_start_:
coeff = 1.0
else:
wsd_steps = self.num_steps - wsd_anneal_start_
wsd_decay_ratio = float(wsd_steps) / float(self.wsd_decay_steps)
if self.lr_wsd_decay_style == "linear":
coeff = 1.0 - wsd_decay_ratio
elif self.lr_wsd_decay_style == "cosine":
coeff = 0.5 * (math.cos(math.pi * wsd_decay_ratio) + 1.0)
elif self.lr_wsd_decay_style == "exponential":
coeff = (2.0 * math.pow(0.5, wsd_decay_ratio)) - 1.0
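# Exponential: coeff is 1.0 at wsd_decay_ratio == 0 and 0.0 at wsd_decay_ratio == 1.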
elif self.lr_wsd_decay_style == "minus_sqrt":
coeff = 1.0 - math.sqrt(wsd_decay_ratio)
else:
raise Exception(f"{self.lr_decay_style} decay style is not supported.")
assert coeff is not None
return min_lr + coeff * delta_lr
def step(self, increment: int) -> None:
"""
Set the learning rate and weight decay for all parameter groups.
Args:
increment (int): number of steps to increment
"""
self.num_steps += increment
new_wd = self.get_wd()
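# Optional per-group multipliers ("lr_mult", "wd_mult") scale the scheduled
# values for individual parameter groups; both default to 1.0.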
for param_group in self.optimizer.param_groups:
new_lr = self.get_lr(param_group)
param_group["lr"] = new_lr * param_group.get("lr_mult", 1.0)
param_group["weight_decay"] = new_wd * param_group.get("wd_mult", 1.0)
def state_dict(self) -> dict:
"""
Return the state dict.
"""
state_dict = {
"max_lr": self.max_lr,
"lr_warmup_steps": self.lr_warmup_steps,
"num_steps": self.num_steps,
"lr_decay_style": self.lr_decay_style,
"lr_decay_steps": self.lr_decay_steps,
"min_lr": self.min_lr,
"start_wd": self.start_wd,
"end_wd": self.end_wd,
"wd_incr_style": self.wd_incr_style,
"wd_incr_steps": self.wd_incr_steps,
}
return state_dict
def _check_and_set(self, cls_value: float, sd_value: float, name: str) -> float:
"""
Resolve a value from the class input and the checkpoint, honoring the
override/use-checkpoint settings.
Args:
cls_value (float): class value
sd_value (float): checkpoint value
name (str): name of the parameter
Returns:
float: the value to use (the class value when overriding, otherwise the checkpoint value)
"""
if self.override_opt_param_scheduler:
logger.info("overriding {} value to {}".format(name, cls_value))
return cls_value
if not self.use_checkpoint_opt_param_scheduler:
assert cls_value == sd_value, (
f"OptimizerParamScheduler: class input value {cls_value} and checkpoint"
f"value {sd_value} for {name} do not match"
)
logger.info("using checkpoint value {} for {}".format(sd_value, name))
return sd_value
def load_state_dict(self, state_dict: dict) -> None:
"""
Load the state dict.
Args:
state_dict (dict): state dict to be loaded
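Example (an illustrative resume-from-checkpoint round trip; ``scheduler`` and
``restored_scheduler`` are hypothetical instances)::

    state = scheduler.state_dict()              # saved alongside the model checkpoint
    restored_scheduler.load_state_dict(state)   # restores lr/wd settings and replays num_steps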
"""
if "start_lr" in state_dict:
max_lr_ = state_dict["start_lr"]
else:
max_lr_ = state_dict["max_lr"]
self.max_lr = self._check_and_set(self.max_lr, max_lr_, "learning rate")
self.min_lr = self._check_and_set(self.min_lr, state_dict["min_lr"], "minimum learning rate")
if "warmup_iter" in state_dict:
lr_warmup_steps_ = state_dict["warmup_iter"]
elif "warmup_steps" in state_dict:
lr_warmup_steps_ = state_dict["warmup_steps"]
else:
lr_warmup_steps_ = state_dict["lr_warmup_steps"]
self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps, lr_warmup_steps_, "warmup iterations")
if "end_iter" in state_dict:
lr_decay_steps_ = state_dict["end_iter"]
elif "decay_steps" in state_dict:
lr_decay_steps_ = state_dict["decay_steps"]
else:
lr_decay_steps_ = state_dict["lr_decay_steps"]
self.lr_decay_steps = self._check_and_set(self.lr_decay_steps, lr_decay_steps_, "total number of iterations")
if "decay_style" in state_dict:
lr_decay_style_ = state_dict["decay_style"]
else:
lr_decay_style_ = state_dict["lr_decay_style"]
self.lr_decay_style = self._check_and_set(self.lr_decay_style, lr_decay_style_, "learning rate decay style")
if "num_iters" in state_dict:
num_steps = state_dict["num_iters"]
else:
num_steps = state_dict["num_steps"]
self.step(increment=num_steps)
if "start_wd" in state_dict:
self.start_wd = self._check_and_set(self.start_wd, state_dict["start_wd"], "start weight decay")
self.end_wd = self._check_and_set(self.end_wd, state_dict["end_wd"], "end weight decay")
self.wd_incr_steps = self._check_and_set(
self.wd_incr_steps,
state_dict["wd_incr_steps"],
"total number of weight decay iterations",
)
self.wd_incr_style = self._check_and_set(
self.wd_incr_style, state_dict["wd_incr_style"], "weight decay incr style"
)