core.optimizer_param_scheduler#
Learning rate decay and weight decay incr functions.
Module Contents#
Classes#
- ParamGroupOverride: Override values for a parameter group. These values may be optimizer-state/scheduler related.
- OptimizerParamScheduler: Anneals learning rate and weight decay.
Functions#
- param_group_override_to_tuple: Convert a param group override to a tuple for use as a key in a dictionary.
- combine_param_group_overrides: Combine a list of param group overrides into a single param group override.
Data#
API#
- core.optimizer_param_scheduler.logger#
‘getLogger(…)’
- class core.optimizer_param_scheduler.ParamGroupOverride#
Bases: typing.TypedDict

Override values for a parameter group. These values may be optimizer-state/scheduler related.
These are the values read later by the param_group.get(…) calls in the OptimizerParamScheduler.get_lr and get_wd methods. If you use a custom optimizer or scheduler, you can override those variables instead.
.. rubric:: Example
param_group_override = ParamGroupOverride(min_lr=1e-4, wd_mult=0.1)
param_group_override = ParamGroupOverride(newvar=3)  # this is ok too
Initialization
Initialize self. See help(type(self)) for accurate signature.
- max_lr: float#
None
- min_lr: float#
None
- start_wd: float#
None
- end_wd: float#
None
- wd_mult: float#
None
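A short sketch of what these fields are for: a scheduler can consult each per-group key with `param_group.get(...)`, falling back to a global default when the group does not override it. The function name and globals below are hypothetical, not part of the library.

```python
def effective_min_lr(param_group: dict, global_min_lr: float) -> float:
    # "min_lr" is one of the ParamGroupOverride keys documented above;
    # groups that do not set it fall back to the scheduler-wide default.
    return param_group.get("min_lr", global_min_lr)

# A group with an override and one without:
special_group = {"params": [], "min_lr": 1e-4}
default_group = {"params": []}
```

This is the pattern the class docstring alludes to: overrides live directly inside the optimizer's `param_groups` dictionaries.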
- core.optimizer_param_scheduler.param_group_override_to_tuple(
- param_group_override: core.optimizer_param_scheduler.ParamGroupOverride | None,
Convert a param group override to a tuple for use as a key in a dictionary.
The tuple is sorted by the keys of the param group override to handle different orderings of the keys in different override dictionaries which still mean the same thing.
- core.optimizer_param_scheduler.combine_param_group_overrides(
- param_group_overrides: list[core.optimizer_param_scheduler.ParamGroupOverride | None],
Combine a list of param group overrides into a single param group override.
This function also verifies that the overrides do not conflict with one another.
- Parameters:
param_group_overrides (list[ParamGroupOverride]) – list of param group overrides to combine
- Returns:
combined param group override
- Return type:
- class core.optimizer_param_scheduler.OptimizerParamScheduler(
- optimizer: megatron.core.optimizer.MegatronOptimizer,
- init_lr: float,
- max_lr: float,
- min_lr: float,
- lr_warmup_steps: int,
- lr_decay_steps: int,
- lr_decay_style: str,
- start_wd: float,
- end_wd: float,
- wd_incr_steps: int,
- wd_incr_style: str,
- use_checkpoint_opt_param_scheduler: Optional[bool] = True,
- override_opt_param_scheduler: Optional[bool] = False,
- wsd_decay_steps: Optional[int] = None,
- lr_wsd_decay_style: Optional[str] = None,
Anneals the learning rate and weight decay.
- Parameters:
optimizer (MegatronOptimizer) – the optimizer to be used
init_lr (float) – initial learning rate
max_lr (float) – maximum learning rate
min_lr (float) – minimum learning rate
lr_warmup_steps (int) – number of warmup steps
lr_decay_steps (int) – number of decay steps
lr_decay_style (str) – decay style for learning rate
start_wd (float) – initial weight decay
end_wd (float) – final weight decay
wd_incr_steps (int) – number of weight decay increment steps
wd_incr_style (str) – weight decay increment style
use_checkpoint_opt_param_scheduler (bool, optional) – whether to use the checkpoint values for the optimizer param scheduler
override_opt_param_scheduler (bool, optional) – whether to override the optimizer param scheduler values with the class values
wsd_decay_steps (int, optional) – number of decay steps for the warmup-stable-decay (WSD) schedule
lr_wsd_decay_style (str, optional) – learning rate decay style used during the WSD decay phase
Initialization
- get_wd(param_group: Optional[dict] = None) float#
Weight decay increment function.
- Parameters:
param_group (dict) – parameter group from the optimizer.
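As a sketch of what a weight-decay increment schedule with a per-group override looks like, assuming a linear style and the `wd_mult` key from ParamGroupOverride (this mirrors the documented behaviour, not the actual implementation):

```python
def wd_at(step: int, start_wd: float, end_wd: float, incr_steps: int,
          wd_mult: float = 1.0) -> float:
    # Linearly move from start_wd to end_wd over incr_steps, then hold;
    # the per-group "wd_mult" override scales the scheduled value.
    frac = min(step, incr_steps) / incr_steps
    return (start_wd + frac * (end_wd - start_wd)) * wd_mult
```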
- get_lr(param_group: dict) float#
Learning rate decay functions from: https://openreview.net/pdf?id=BJYwwY9ll pg. 4
- Parameters:
param_group (dict) – parameter group from the optimizer.
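The common shape behind these parameters is a linear warmup from `init_lr` to `max_lr` followed by a decay toward `min_lr`; a cosine-style sketch (an illustration, not the library code itself):

```python
import math

def lr_at(step: int, init_lr: float, max_lr: float, min_lr: float,
          warmup_steps: int, decay_steps: int) -> float:
    # Linear warmup: init_lr -> max_lr over the first warmup_steps steps.
    if warmup_steps > 0 and step <= warmup_steps:
        return init_lr + (max_lr - init_lr) * step / warmup_steps
    # Cosine decay: max_lr -> min_lr over the remaining decay steps,
    # clamped at min_lr once decay_steps is reached.
    progress = (min(step, decay_steps) - warmup_steps) / (decay_steps - warmup_steps)
    coeff = 0.5 * (math.cos(math.pi * progress) + 1.0)
    return min_lr + coeff * (max_lr - min_lr)
```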
- step(increment: int) None#
Set lr for all parameter groups.
- Parameters:
increment (int) – number of steps to increment
- state_dict() dict#
Return the state dict.
- _check_and_set(cls_value: float, sd_value: float, name: str) float#
Auxiliary function for checking the values in the checkpoint and setting them.
- Parameters:
cls_value (float) – class value
sd_value (float) – checkpoint value
name (str) – name of the parameter
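The reconciliation logic this helper performs can be sketched as follows (a hypothetical standalone version; the real method consults the class's override/use-checkpoint flags):

```python
def check_and_set(cls_value: float, sd_value: float, name: str,
                  override: bool = False) -> float:
    # With override=True the class value wins over the checkpoint value;
    # otherwise the two must agree, and a mismatch is an error.
    if override:
        return cls_value
    if cls_value != sd_value:
        raise ValueError(
            f"{name}: class value {cls_value} != checkpoint value {sd_value}"
        )
    return sd_value
```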
- load_state_dict(state_dict: dict) None#
Load the state dict.
- Parameters:
state_dict (dict) – state dict to be loaded