Source code for nemo_automodel.optim.scheduler

# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

"""Learning rate decay and weight decay incr functions."""

import logging
import math
from typing import Optional

from torch.optim.optimizer import Optimizer


logger = logging.getLogger(__name__)


class OptimizerParamScheduler:
    """
    Anneals the learning rate and weight decay.

    Args:
        optimizer (Optimizer): the optimizer to be used
        init_lr (float): initial learning rate
        max_lr (float): maximum learning rate
        min_lr (float): minimum learning rate
        lr_warmup_steps (int): number of warmup steps
        lr_decay_steps (int): number of decay steps
        lr_decay_style (str): decay style for the learning rate
        start_wd (float): initial weight decay
        end_wd (float): final weight decay
        wd_incr_steps (int): number of weight decay increment steps
        wd_incr_style (str): weight decay increment style
        use_checkpoint_opt_param_scheduler (bool, optional): whether to use the checkpoint
            values for the optimizer param scheduler. Defaults to True.
        override_opt_param_scheduler (bool, optional): whether to override the optimizer
            param scheduler values with the class values. Defaults to False.
        wsd_decay_steps (int, optional): number of decay steps for the WSD
            (warmup-stable-decay) learning rate schedule. Defaults to None.
        lr_wsd_decay_style (str, optional): decay style used for the learning rate during
            the final decay phase of the WSD schedule. Defaults to None.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        init_lr: float,
        max_lr: float,
        min_lr: float,
        lr_warmup_steps: int,
        lr_decay_steps: int,
        lr_decay_style: str,
        start_wd: float,
        end_wd: float,
        wd_incr_steps: int,
        wd_incr_style: str,
        use_checkpoint_opt_param_scheduler: Optional[bool] = True,
        override_opt_param_scheduler: Optional[bool] = False,
        wsd_decay_steps: Optional[int] = None,
        lr_wsd_decay_style: Optional[str] = None,
    ) -> None:
        """
        Constructor for OptimizerParamScheduler.

        Args:
            optimizer (Optimizer): the optimizer to be used
            init_lr (float): initial learning rate
            max_lr (float): maximum learning rate
            min_lr (float): minimum learning rate
            lr_warmup_steps (int): number of warmup steps
            lr_decay_steps (int): number of decay steps
            lr_decay_style (str): decay style for the learning rate
            start_wd (float): initial weight decay
            end_wd (float): final weight decay
            wd_incr_steps (int): number of weight decay increment steps
            wd_incr_style (str): weight decay increment style
            use_checkpoint_opt_param_scheduler (bool, optional): whether to use the checkpoint
                values for the optimizer param scheduler. Defaults to True.
            override_opt_param_scheduler (bool, optional): whether to override the optimizer
                param scheduler values with the class values. Defaults to False.
            wsd_decay_steps (int, optional): number of decay steps for the WSD
                (warmup-stable-decay) learning rate schedule. Defaults to None.
            lr_wsd_decay_style (str, optional): decay style used for the learning rate during
                the final decay phase of the WSD schedule. Defaults to None.
        """
        # Class values.
        self.optimizer = optimizer
        self.init_lr = init_lr
        self.max_lr = float(max_lr)
        self.min_lr = min_lr
        assert self.min_lr >= 0.0
        assert self.max_lr >= self.min_lr
        assert self.init_lr <= self.max_lr

        self.lr_warmup_steps = lr_warmup_steps
        self.num_steps = 0
        self.lr_decay_steps = lr_decay_steps
        self.wsd_decay_steps = wsd_decay_steps
        self.lr_wsd_decay_style = lr_wsd_decay_style
        assert self.lr_decay_steps > 0
        assert self.lr_warmup_steps < self.lr_decay_steps

        self.lr_decay_style = lr_decay_style
        if self.lr_decay_style == "WSD":
            assert self.wsd_decay_steps is not None

        self.start_wd = start_wd
        self.end_wd = end_wd
        assert self.start_wd >= 0.0
        assert self.end_wd >= self.start_wd
        self.wd_incr_steps = wd_incr_steps
        self.wd_incr_style = wd_incr_style

        self.override_opt_param_scheduler = override_opt_param_scheduler
        self.use_checkpoint_opt_param_scheduler = use_checkpoint_opt_param_scheduler
        if self.override_opt_param_scheduler:
            assert not self.use_checkpoint_opt_param_scheduler, "both override and use-checkpoint are set."

        # Set the learning rate.
        self.step(0)
        logger.info("learning rate decay style: {}".format(self.lr_decay_style))
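
A minimal construction sketch (hypothetical hyperparameter values, not anything shipped with the module): the scheduler wraps an existing optimizer, and the `self.step(0)` call at the end of the constructor immediately writes the warmup learning rate and initial weight decay into the optimizer's parameter groups.

import torch
from nemo_automodel.optim.scheduler import OptimizerParamScheduler

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)

scheduler = OptimizerParamScheduler(
    optimizer=optimizer,
    init_lr=0.0,            # warmup starts from zero (illustrative choice)
    max_lr=3e-4,
    min_lr=3e-5,
    lr_warmup_steps=100,
    lr_decay_steps=1000,
    lr_decay_style="cosine",
    start_wd=0.1,
    end_wd=0.1,
    wd_incr_steps=1000,
    wd_incr_style="constant",
)
print(optimizer.param_groups[0]["lr"])  # 0.0 at step 0 of the warmup
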

    def get_wd(self) -> float:
        """
        Weight decay increment function.
        """
        if self.num_steps > self.wd_incr_steps:
            return self.end_wd

        if self.wd_incr_style == "constant":
            assert self.start_wd == self.end_wd
            return self.end_wd

        incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
        assert incr_ratio >= 0.0
        assert incr_ratio <= 1.0
        delta_wd = self.end_wd - self.start_wd

        if self.wd_incr_style == "linear":
            coeff = incr_ratio
        elif self.wd_incr_style == "cosine":
            coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
        else:
            raise Exception(f"{self.wd_incr_style} weight decay increment style is not supported.")

        return self.start_wd + coeff * delta_wd
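
A small worked check of get_wd under the "cosine" increment style, with illustrative numbers (not tied to any configuration in this module): at the midpoint of the increment window the coefficient is 0.5, so the weight decay sits halfway between start_wd and end_wd.

import math

start_wd, end_wd, wd_incr_steps = 0.0, 0.1, 100    # illustrative values only
num_steps = 50                                     # midpoint of the increment window
incr_ratio = num_steps / wd_incr_steps
coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)   # "cosine" increment style
print(start_wd + coeff * (end_wd - start_wd))      # ~0.05
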

    def get_lr(self, param_group: dict) -> float:
        """
        Learning rate decay functions from:
        https://openreview.net/pdf?id=BJYwwY9ll pg. 4.

        Args:
            param_group (dict): parameter group from the optimizer.
        """
        max_lr = param_group.get("max_lr", self.max_lr)
        min_lr = param_group.get("min_lr", self.min_lr)

        # Use linear warmup for the initial part.
        if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
            return self.init_lr + ((max_lr - self.init_lr) * float(self.num_steps) / float(self.lr_warmup_steps))

        # If the learning rate is constant, just return the initial value.
        if self.lr_decay_style == "constant":
            return max_lr

        # For any steps larger than `self.lr_decay_steps`, use `min_lr`.
        if self.num_steps > self.lr_decay_steps:
            return min_lr

        # If we are done with the warmup period, use the decay style.
        if self.lr_decay_style == "inverse-square-root":
            warmup_steps = max(self.lr_warmup_steps, 1)
            num_steps = max(self.num_steps, 1)
            lr = max_lr * warmup_steps**0.5 / (num_steps**0.5)
            return max(min_lr, lr)

        num_steps_ = self.num_steps - self.lr_warmup_steps
        decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
        decay_ratio = float(num_steps_) / float(decay_steps_)
        assert decay_ratio >= 0.0
        assert decay_ratio <= 1.0
        delta_lr = max_lr - min_lr

        coeff = None
        if self.lr_decay_style == "linear":
            coeff = 1.0 - decay_ratio
        elif self.lr_decay_style == "cosine":
            coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
        elif self.lr_decay_style == "WSD":
            wsd_anneal_start_ = self.lr_decay_steps - self.wsd_decay_steps
            if self.num_steps <= wsd_anneal_start_:
                coeff = 1.0
            else:
                wsd_steps = self.num_steps - wsd_anneal_start_
                wsd_decay_ratio = float(wsd_steps) / float(self.wsd_decay_steps)
                if self.lr_wsd_decay_style == "linear":
                    coeff = 1.0 - wsd_decay_ratio
                elif self.lr_wsd_decay_style == "cosine":
                    coeff = 0.5 * (math.cos(math.pi * wsd_decay_ratio) + 1.0)
                elif self.lr_wsd_decay_style == "exponential":
                    coeff = (2.0 * math.pow(0.5, wsd_decay_ratio)) - 1.0
                elif self.lr_wsd_decay_style == "minus_sqrt":
                    coeff = 1.0 - math.sqrt(wsd_decay_ratio)
        else:
            raise Exception(f"{self.lr_decay_style} decay style is not supported.")

        assert coeff is not None
        return min_lr + coeff * delta_lr
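
The same kind of spot check for get_lr with the "cosine" decay style, again with illustrative numbers: halfway through the post-warmup decay window the coefficient is 0.5, so the learning rate lands midway between min_lr and max_lr.

import math

max_lr, min_lr = 3e-4, 3e-5                        # illustrative values only
lr_warmup_steps, lr_decay_steps = 100, 1000
num_steps = 550                                    # halfway through the decay window

decay_ratio = (num_steps - lr_warmup_steps) / (lr_decay_steps - lr_warmup_steps)
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)   # "cosine" decay style
print(min_lr + coeff * (max_lr - min_lr))          # ~1.65e-04
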

    def step(self, increment: int) -> None:
        """
        Set the learning rate and weight decay for all parameter groups.

        Args:
            increment (int): number of steps to increment
        """
        self.num_steps += increment
        new_wd = self.get_wd()
        for param_group in self.optimizer.param_groups:
            new_lr = self.get_lr(param_group)
            param_group["lr"] = new_lr * param_group.get("lr_mult", 1.0)
            param_group["weight_decay"] = new_wd * param_group.get("wd_mult", 1.0)
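
Continuing the construction sketch above (model, optimizer, scheduler, and the torch import come from there), step is typically called once per optimizer update so that both the learning rate and the weight decay track num_steps; the forward pass and loss below are placeholders.

for _ in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 16)).pow(2).mean()    # placeholder loss
    loss.backward()
    optimizer.step()
    scheduler.step(increment=1)                       # advance the schedules by one step
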

    def state_dict(self) -> dict:
        """
        Return the state dict.
        """
        state_dict = {
            "max_lr": self.max_lr,
            "lr_warmup_steps": self.lr_warmup_steps,
            "num_steps": self.num_steps,
            "lr_decay_style": self.lr_decay_style,
            "lr_decay_steps": self.lr_decay_steps,
            "min_lr": self.min_lr,
            "start_wd": self.start_wd,
            "end_wd": self.end_wd,
            "wd_incr_style": self.wd_incr_style,
            "wd_incr_steps": self.wd_incr_steps,
        }
        return state_dict

    def _check_and_set(self, cls_value: float, sd_value: float, name: str) -> float:
        """
        Auxiliary function for checking the values in the checkpoint and setting them.

        Args:
            cls_value (float): class value
            sd_value (float): checkpoint value
            name (str): name of the parameter
        """
        if self.override_opt_param_scheduler:
            logger.info("overriding {} value to {}".format(name, cls_value))
            return cls_value

        if not self.use_checkpoint_opt_param_scheduler:
            assert cls_value == sd_value, (
                f"OptimizerParamScheduler: class input value {cls_value} and checkpoint "
                f"value {sd_value} for {name} do not match"
            )
        logger.info("using checkpoint value {} for {}".format(sd_value, name))
        return sd_value

    def load_state_dict(self, state_dict: dict) -> None:
        """
        Load the state dict.

        Args:
            state_dict (dict): state dict to be loaded
        """
        if "start_lr" in state_dict:
            max_lr_ = state_dict["start_lr"]
        else:
            max_lr_ = state_dict["max_lr"]
        self.max_lr = self._check_and_set(self.max_lr, max_lr_, "learning rate")
        self.min_lr = self._check_and_set(self.min_lr, state_dict["min_lr"], "minimum learning rate")

        if "warmup_iter" in state_dict:
            lr_warmup_steps_ = state_dict["warmup_iter"]
        elif "warmup_steps" in state_dict:
            lr_warmup_steps_ = state_dict["warmup_steps"]
        else:
            lr_warmup_steps_ = state_dict["lr_warmup_steps"]
        self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps, lr_warmup_steps_, "warmup iterations")

        if "end_iter" in state_dict:
            lr_decay_steps_ = state_dict["end_iter"]
        elif "decay_steps" in state_dict:
            lr_decay_steps_ = state_dict["decay_steps"]
        else:
            lr_decay_steps_ = state_dict["lr_decay_steps"]
        self.lr_decay_steps = self._check_and_set(self.lr_decay_steps, lr_decay_steps_, "total number of iterations")

        if "decay_style" in state_dict:
            lr_decay_style_ = state_dict["decay_style"]
        else:
            lr_decay_style_ = state_dict["lr_decay_style"]
        self.lr_decay_style = self._check_and_set(self.lr_decay_style, lr_decay_style_, "learning rate decay style")

        if "num_iters" in state_dict:
            num_steps = state_dict["num_iters"]
        else:
            num_steps = state_dict["num_steps"]
        self.step(increment=num_steps)

        if "start_wd" in state_dict:
            self.start_wd = self._check_and_set(self.start_wd, state_dict["start_wd"], "start weight decay")
            self.end_wd = self._check_and_set(self.end_wd, state_dict["end_wd"], "end weight decay")
            self.wd_incr_steps = self._check_and_set(
                self.wd_incr_steps,
                state_dict["wd_incr_steps"],
                "total number of weight decay iterations",
            )
            self.wd_incr_style = self._check_and_set(
                self.wd_incr_style, state_dict["wd_incr_style"], "weight decay incr style"
            )
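
A round-trip sketch, continuing the earlier example: state_dict() captures the schedule settings and progress, and load_state_dict() reconciles each value through _check_and_set and then replays num_steps through step(), so a restored scheduler resumes at the same point.

saved = scheduler.state_dict()

# Rebuild a fresh scheduler with the same (hypothetical) settings, then restore it.
restored = OptimizerParamScheduler(
    optimizer=optimizer,
    init_lr=0.0, max_lr=3e-4, min_lr=3e-5,
    lr_warmup_steps=100, lr_decay_steps=1000, lr_decay_style="cosine",
    start_wd=0.1, end_wd=0.1, wd_incr_steps=1000, wd_incr_style="constant",
)
restored.load_state_dict(saved)
assert restored.num_steps == saved["num_steps"]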