Optimizer Parameters Scheduler
This API computes the learning rate and weight decay that the optimizer applies at each step.
Learning rate decay and weight decay increment functions.
- class core.optimizer_param_scheduler.OptimizerParamScheduler(optimizer: megatron.core.optimizer.MegatronOptimizer, init_lr: float, max_lr: float, min_lr: float, lr_warmup_steps: int, lr_decay_steps: int, lr_decay_style: str, start_wd: float, end_wd: float, wd_incr_steps: int, wd_incr_style: str, use_checkpoint_opt_param_scheduler: Optional[bool] = True, override_opt_param_scheduler: Optional[bool] = False, wsd_decay_steps: Optional[int] = None, lr_wsd_decay_style: Optional[str] = None)
Bases: object
Anneals the learning rate and weight decay.
- Parameters
optimizer (MegatronOptimizer) – the optimizer to be used
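init_lr (float) – initial learning rate, used at the start of warmup
max_lr (float) – maximum learning rate, reached at the end of warmup
min_lr (float) – minimum learning rate, used once decay has finished
lr_warmup_steps (int) – number of warmup steps
lr_decay_steps (int) – step at which learning rate decay ends, counted from step 0 (warmup steps included)
lr_decay_style (str) – decay style for the learning rate (e.g. constant, linear, cosine, inverse-square-root, WSD)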
start_wd (float) – initial weight decay
end_wd (float) – final weight decay
wd_incr_steps (int) – number of weight decay increment steps
wd_incr_style (str) – weight decay increment style (e.g. constant, linear, cosine)
use_checkpoint_opt_param_scheduler (bool, optional) – whether to use the scheduler values stored in the checkpoint when loading a state dict
override_opt_param_scheduler (bool, optional) – whether to override the scheduler values stored in the checkpoint with the values passed to this class
wsd_decay_steps (int, optional) – number of steps in the final decay (annealing) phase of the WSD (warmup-stable-decay) learning rate schedule
lr_wsd_decay_style (str, optional) – decay style for the learning rate during the WSD decay phase
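A WSD (warmup-stable-decay) schedule splits training into three phases. Assuming that the decay phase occupies the last wsd_decay_steps steps before lr_decay_steps, and that the learning rate is held at min_lr afterwards, a step can be classified as below. The helper wsd_phase is hypothetical and is not part of Megatron Core; it is only a sketch of how the constructor parameters relate.

```python
def wsd_phase(step: int, lr_warmup_steps: int, lr_decay_steps: int,
              wsd_decay_steps: int) -> str:
    """Hypothetical helper: classify a step of a WSD schedule (sketch only)."""
    if lr_warmup_steps > 0 and step <= lr_warmup_steps:
        return "warmup"      # linear ramp from init_lr to max_lr
    if step > lr_decay_steps:
        return "post_decay"  # held at min_lr
    anneal_start = lr_decay_steps - wsd_decay_steps
    if step <= anneal_start:
        return "stable"      # held at max_lr
    return "decay"           # annealed toward min_lr using lr_wsd_decay_style


# 100 warmup steps, decay ends at step 1000, the last 200 steps anneal.
for s in (50, 500, 900, 1200):
    print(s, wsd_phase(s, lr_warmup_steps=100, lr_decay_steps=1000, wsd_decay_steps=200))
```

With these settings the stable phase covers steps 101 to 800 and the decay phase covers steps 801 to 1000.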
- get_lr(param_group: dict) → float
Return the learning rate for param_group at the current step. The decay functions follow https://openreview.net/pdf?id=BJYwwY9ll, p. 4.
- Parameters
param_group (dict) – parameter group from the optimizer.
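The schedule computed by get_lr can be sketched as a pure function of the step count. This is a minimal sketch, not Megatron Core's implementation: it takes scalar arguments instead of a param_group, and the exact formulas (linear warmup, then a coefficient in [0, 1] that interpolates between min_lr and max_lr) are assumptions modeled on the documented parameters. The cosine branch uses the cosine-annealing form from the cited paper; only linear and cosine annealing are sketched for the WSD decay phase. The name sketch_lr is hypothetical.

```python
import math


def sketch_lr(step, init_lr, max_lr, min_lr, lr_warmup_steps, lr_decay_steps,
              lr_decay_style, wsd_decay_steps=None, lr_wsd_decay_style=None):
    """Sketch of a warmup-then-decay learning rate schedule (not Megatron's code)."""
    # Linear warmup from init_lr to max_lr.
    if lr_warmup_steps > 0 and step <= lr_warmup_steps:
        return init_lr + (max_lr - init_lr) * step / lr_warmup_steps
    if lr_decay_style == "constant":
        return max_lr
    # Once the decay window has ended, hold the minimum.
    if step > lr_decay_steps:
        return min_lr
    if lr_decay_style == "inverse-square-root":
        lr = max_lr * math.sqrt(max(lr_warmup_steps, 1)) / math.sqrt(max(step, 1))
        return max(min_lr, lr)
    # Fraction of the decay window (after warmup) that has elapsed, in [0, 1].
    ratio = (step - lr_warmup_steps) / (lr_decay_steps - lr_warmup_steps)
    if lr_decay_style == "linear":
        coeff = 1.0 - ratio
    elif lr_decay_style == "cosine":
        coeff = 0.5 * (math.cos(math.pi * ratio) + 1.0)
    elif lr_decay_style == "WSD":
        anneal_start = lr_decay_steps - wsd_decay_steps
        if step <= anneal_start:
            coeff = 1.0  # stable phase: stay at max_lr
        else:
            r = (step - anneal_start) / wsd_decay_steps
            if lr_wsd_decay_style == "linear":
                coeff = 1.0 - r
            elif lr_wsd_decay_style == "cosine":
                coeff = 0.5 * (math.cos(math.pi * r) + 1.0)
            else:
                raise ValueError(f"unsupported WSD decay style: {lr_wsd_decay_style}")
    else:
        raise ValueError(f"unsupported decay style: {lr_decay_style}")
    # Interpolate between min_lr (coeff = 0) and max_lr (coeff = 1).
    return min_lr + coeff * (max_lr - min_lr)


print(sketch_lr(60, init_lr=0.0, max_lr=1.0, min_lr=0.1,
                lr_warmup_steps=10, lr_decay_steps=110, lr_decay_style="cosine"))
```

Expressing every decay style as a coefficient between 0 and 1 keeps the final interpolation in one place, so min_lr and max_lr are always respected at the ends of the window.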
- get_wd() → float
Return the weight decay for the current step, following the weight decay increment schedule.
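Unlike the learning rate, weight decay is increased from start_wd toward end_wd over wd_incr_steps. A minimal sketch, assuming a linear ramp or a half-cosine ramp and that end_wd is held after the window (for the constant style the sketch assumes start_wd equals end_wd). The name sketch_wd is hypothetical and this is not Megatron Core's code.

```python
import math


def sketch_wd(step, start_wd, end_wd, wd_incr_steps, wd_incr_style):
    """Sketch of a weight decay increment schedule (not Megatron's code)."""
    # After the increment window, hold the final value.
    if step > wd_incr_steps:
        return end_wd
    if wd_incr_style == "constant":
        return end_wd  # assumes start_wd == end_wd for this style
    ratio = step / wd_incr_steps  # fraction of the window elapsed, in [0, 1]
    if wd_incr_style == "linear":
        coeff = ratio
    elif wd_incr_style == "cosine":
        # Half-cosine ramp: 0 at ratio = 0, 1 at ratio = 1.
        coeff = 0.5 * (math.cos(math.pi * (1.0 - ratio)) + 1.0)
    else:
        raise ValueError(f"unsupported increment style: {wd_incr_style}")
    return start_wd + coeff * (end_wd - start_wd)


print(sketch_wd(50, start_wd=0.0, end_wd=0.1, wd_incr_steps=100, wd_incr_style="cosine"))
```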
- load_state_dict(state_dict: dict) → None
Load the state dict.
- Parameters
state_dict (dict) – state dict to be loaded
- state_dict() → dict
Return the state dict.
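The interaction between state_dict, load_state_dict, and the two checkpoint flags can be illustrated with a toy class. Everything here is hypothetical: ToyScheduler, its _check_and_set helper, and the stored keys are illustrative assumptions, as is the rule that when neither flag is set the constructor and checkpoint values must match. The step counter is assumed to always be restored from the checkpoint so that training resumes at the right point in the schedule.

```python
import json


class ToyScheduler:
    """Hypothetical stand-in for a scheduler's checkpoint logic (not Megatron's code)."""

    def __init__(self, max_lr, min_lr, use_checkpoint=True, override=False):
        self.max_lr = max_lr
        self.min_lr = min_lr
        self.num_steps = 0
        self.use_checkpoint = use_checkpoint
        self.override = override

    def state_dict(self):
        return {"max_lr": self.max_lr, "min_lr": self.min_lr, "num_steps": self.num_steps}

    def _check_and_set(self, cls_value, sd_value, name):
        if self.override:
            return cls_value  # constructor value wins
        if not self.use_checkpoint and cls_value != sd_value:
            # Assumption: with neither flag set, the values must agree.
            raise ValueError(f"{name}: constructor {cls_value} != checkpoint {sd_value}")
        return sd_value  # checkpoint value wins

    def load_state_dict(self, sd):
        self.max_lr = self._check_and_set(self.max_lr, sd["max_lr"], "max_lr")
        self.min_lr = self._check_and_set(self.min_lr, sd["min_lr"], "min_lr")
        self.num_steps = sd["num_steps"]  # the step counter is always restored


old = ToyScheduler(max_lr=3e-4, min_lr=3e-5)
old.num_steps = 500
saved = json.loads(json.dumps(old.state_dict()))  # simulate a checkpoint round trip

resumed = ToyScheduler(max_lr=1e-4, min_lr=1e-5)  # use_checkpoint=True by default
resumed.load_state_dict(saved)
print(resumed.max_lr, resumed.num_steps)
```

With the default flags the resumed scheduler adopts the checkpoint's max_lr; passing override=True instead keeps the newly configured value while still resuming from step 500.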
- step(increment: int) → None
Advance the step count by increment, then set the learning rate and weight decay for all parameter groups.
- Parameters
increment (int) – number of steps to increment
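The effect of step on the optimizer can be sketched with a stand-in optimizer that exposes param_groups as a list of dicts, as PyTorch optimizers do. This simplified sketch assumes that the learning rate is computed per group (which is why get_lr receives a param_group) while a single weight decay value is shared by all groups. HypotheticalOptimizer, StepSketch, and the per-group "max_lr" entry used below are illustrative assumptions, not Megatron Core APIs.

```python
from typing import Callable, Dict, List


class HypotheticalOptimizer:
    """Minimal stand-in exposing param_groups like torch.optim optimizers do."""

    def __init__(self, num_groups: int):
        self.param_groups: List[Dict] = [
            {"lr": 0.0, "weight_decay": 0.0} for _ in range(num_groups)
        ]


class StepSketch:
    """Simplified sketch of a scheduler's step() (not Megatron's code)."""

    def __init__(self, optimizer, lr_fn: Callable[[int, dict], float],
                 wd_fn: Callable[[int], float]):
        self.optimizer = optimizer
        self.lr_fn = lr_fn  # plays the role of get_lr(param_group)
        self.wd_fn = wd_fn  # plays the role of get_wd()
        self.num_steps = 0

    def step(self, increment: int) -> None:
        # Advance the counter, then write the new values into every group.
        self.num_steps += increment
        new_wd = self.wd_fn(self.num_steps)
        for group in self.optimizer.param_groups:
            group["lr"] = self.lr_fn(self.num_steps, group)
            group["weight_decay"] = new_wd


opt = HypotheticalOptimizer(num_groups=2)
opt.param_groups[1]["max_lr"] = 0.5  # hypothetical per-group peak learning rate
sched = StepSketch(
    opt,
    lr_fn=lambda step, g: g.get("max_lr", 1.0) * min(1.0, step / 10),  # linear warmup
    wd_fn=lambda step: 0.01,
)
sched.step(increment=5)
print([g["lr"] for g in opt.param_groups])  # [0.5, 0.25]
```

Passing increment rather than always advancing by one lets a caller advance the schedule by several steps at once, for example when resuming from a checkpoint.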