Optimizer Parameters Scheduler#

This API calculates the learning rate and weight decay for the optimizer.

Module contents#

Learning rate decay and weight decay increment functions.

class core.optimizer_param_scheduler.OptimizerParamScheduler(
optimizer: megatron.core.optimizer.MegatronOptimizer,
init_lr: float,
max_lr: float,
min_lr: float,
lr_warmup_steps: int,
lr_decay_steps: int,
lr_decay_style: str,
start_wd: float,
end_wd: float,
wd_incr_steps: int,
wd_incr_style: str,
use_checkpoint_opt_param_scheduler: bool | None = True,
override_opt_param_scheduler: bool | None = False,
wsd_decay_steps: int | None = None,
lr_wsd_decay_style: str | None = None,
)#

Bases: object

Anneals the learning rate and weight decay.

Parameters:
  • optimizer (MegatronOptimizer) – the optimizer to be used

  • init_lr (float) – initial learning rate

  • max_lr (float) – maximum learning rate

  • min_lr (float) – minimum learning rate

  • lr_warmup_steps (int) – number of warmup steps

  • lr_decay_steps (int) – number of decay steps

  • lr_decay_style (str) – decay style for learning rate

  • start_wd (float) – initial weight decay

  • end_wd (float) – final weight decay

  • wd_incr_steps (int) – number of weight decay increment steps

  • wd_incr_style (str) – weight decay increment style

  • use_checkpoint_opt_param_scheduler (bool, optional) – whether to use the checkpoint values for the optimizer param scheduler

  • override_opt_param_scheduler (bool, optional) – whether to override the optimizer param scheduler values with the class values

  • wsd_decay_steps (int, optional) – number of decay steps for the warmup-stable-decay (WSD) learning rate schedule

  • lr_wsd_decay_style (str, optional) – decay style for the learning rate during the WSD decay steps

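A minimal construction sketch is shown below. The optimizer object is assumed to be a MegatronOptimizer instance built elsewhere in the training setup, and the concrete step counts and style strings ("cosine", "linear") are illustrative values, not defaults.

    from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler

    # `optimizer` is assumed to be a megatron.core.optimizer.MegatronOptimizer
    # instance constructed elsewhere in the training setup.
    scheduler = OptimizerParamScheduler(
        optimizer=optimizer,
        init_lr=0.0,              # learning rate at step 0, ramped up during warmup
        max_lr=3e-4,              # peak learning rate reached at the end of warmup
        min_lr=3e-5,              # floor learning rate after decay
        lr_warmup_steps=2000,
        lr_decay_steps=500000,
        lr_decay_style="cosine",
        start_wd=0.0,             # initial weight decay
        end_wd=0.1,               # final weight decay
        wd_incr_steps=500000,
        wd_incr_style="linear",
    )
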
get_lr(param_group: dict) → float#

Compute the learning rate for a parameter group. Learning rate decay functions follow https://openreview.net/pdf?id=BJYwwY9ll, pg. 4.

Parameters:

param_group (dict) – parameter group from the optimizer.
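
As an illustration of the schedule shape only (not the library's exact implementation), linear warmup to max_lr followed by cosine decay to min_lr can be sketched as below; cosine_lr_sketch is a hypothetical helper.

    import math

    def cosine_lr_sketch(step, init_lr, max_lr, min_lr, warmup_steps, decay_steps):
        """Illustrative sketch: linear warmup then cosine decay (hypothetical helper)."""
        if warmup_steps > 0 and step <= warmup_steps:
            # Linear warmup from init_lr to max_lr.
            return init_lr + (max_lr - init_lr) * step / warmup_steps
        if step >= decay_steps:
            return min_lr
        # Cosine decay from max_lr down to min_lr over the remaining steps.
        decay_ratio = (step - warmup_steps) / (decay_steps - warmup_steps)
        coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
        return min_lr + coeff * (max_lr - min_lr)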

get_wd() → float#

Compute the current weight decay value according to the configured increment style.
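
On the weight decay side, a linear ramp from start_wd to end_wd over wd_incr_steps can be sketched as below; linear_wd_sketch is a hypothetical helper illustrating one possible increment style, not the library code.

    def linear_wd_sketch(step, start_wd, end_wd, incr_steps):
        """Illustrative sketch of a linear weight decay ramp (hypothetical helper)."""
        if step >= incr_steps:
            return end_wd
        return start_wd + (end_wd - start_wd) * step / incr_steps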

load_state_dict(state_dict: dict) → None#

Load the state dict.

Parameters:

state_dict (dict) – state dict to be loaded

state_dict() → dict#

Return the state dict.

step(increment: int) → None#

Set the learning rate for all parameter groups.

Parameters:

increment (int) – number of steps to increment
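
A typical usage pattern, sketched below, calls step() once per training iteration and round-trips the scheduler state through state_dict() / load_state_dict() when checkpointing. The names run_train_step, num_iterations, and checkpoint are hypothetical placeholders for the surrounding training code.

    # `scheduler` is an OptimizerParamScheduler; `run_train_step`,
    # `num_iterations`, and `checkpoint` are hypothetical placeholders.
    for _ in range(num_iterations):
        run_train_step()               # forward/backward + optimizer update
        scheduler.step(increment=1)    # advance LR and weight decay by one step

    # Persist the scheduler state with the rest of the checkpoint.
    checkpoint["opt_param_scheduler"] = scheduler.state_dict()

    # Restore it when resuming training.
    scheduler.load_state_dict(checkpoint["opt_param_scheduler"])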