core.optimizer_param_scheduler#

Learning rate decay and weight decay increment functions.

Module Contents#

Classes#

ParamGroupOverride

Override values for a parameter group. These values may be optimizer-state/scheduler related.

OptimizerParamScheduler

Anneals the learning rate and weight decay.

Functions#

param_group_override_to_tuple

Convert a param group override to a tuple for use as a key in a dictionary.

combine_param_group_overrides

Combine a list of param group overrides into a single param group override.

Data#

logger

API#

core.optimizer_param_scheduler.logger#

'getLogger(...)'

class core.optimizer_param_scheduler.ParamGroupOverride#

Bases: typing.TypedDict

Override values for a parameter group. These values may be optimizer-state/scheduler related.

These are the values consumed by the param_group.get(…) calls in the OptimizerParamScheduler.get_lr and get_wd methods. If you use a custom optimizer or scheduler, you can override those variables instead; a sketch of this lookup follows the attribute list below.

Example:

    param_group_override = ParamGroupOverride(min_lr=1e-4, wd_mult=0.1)
    param_group_override = ParamGroupOverride(newvar=3)  # extra keys are ok at runtime, too


max_lr: float#

None

min_lr: float#

None

start_wd: float#

None

end_wd: float#

None

wd_mult: float#

None
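For illustration, a minimal sketch of how these keys shadow the scheduler-wide defaults at runtime. `resolve_lr_bounds` and its `default_*` arguments are hypothetical names for this example, not part of the module:

    def resolve_lr_bounds(
        param_group: dict, default_max_lr: float, default_min_lr: float
    ) -> tuple[float, float]:
        # An override stored in the param group shadows the scheduler default,
        # mirroring the param_group.get(...) pattern in get_lr/get_wd.
        max_lr = param_group.get("max_lr", default_max_lr)
        min_lr = param_group.get("min_lr", default_min_lr)
        return max_lr, min_lr

    # The group-level min_lr wins over the scheduler-wide default:
    print(resolve_lr_bounds({"min_lr": 1e-5}, default_max_lr=3e-4, default_min_lr=0.0))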

core.optimizer_param_scheduler.param_group_override_to_tuple(
param_group_override: core.optimizer_param_scheduler.ParamGroupOverride | None,
) → tuple[tuple[str, Any], ...] | None#

Convert a param group override to a tuple for use as a key in a dictionary.

The tuple is sorted by the override's keys, so overrides containing the same key-value pairs map to the same tuple regardless of key order; see the sketch below.
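A minimal sketch of the idea; `override_to_tuple` below is a hypothetical re-implementation for illustration, not the module's code:

    from typing import Any

    def override_to_tuple(override: dict | None) -> tuple[tuple[str, Any], ...] | None:
        # Sorting by key makes the result hashable and order-insensitive,
        # so it can serve as a dictionary key.
        if override is None:
            return None
        return tuple(sorted(override.items()))

    a = {"min_lr": 1e-4, "wd_mult": 0.1}
    b = {"wd_mult": 0.1, "min_lr": 1e-4}  # same content, different key order
    assert override_to_tuple(a) == override_to_tuple(b)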

core.optimizer_param_scheduler.combine_param_group_overrides(
param_group_overrides: list[core.optimizer_param_scheduler.ParamGroupOverride | None],
) → core.optimizer_param_scheduler.ParamGroupOverride#

Combine a list of param group overrides into a single param group override.

This function also verifies that the overrides do not conflict, i.e. that no key is assigned two different values; see the sketch after the parameter list below.

Parameters:

param_group_overrides (list[ParamGroupOverride]) – list of param group overrides to combine

Returns:

combined param group override

Return type:

ParamGroupOverride
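A sketch of the merge-and-conflict-check semantics; `combine_overrides` is a hypothetical re-implementation for illustration:

    def combine_overrides(overrides: list[dict | None]) -> dict:
        combined: dict = {}
        for override in overrides:
            if override is None:
                continue
            for key, value in override.items():
                # Two overrides may not assign different values to the same key.
                if key in combined and combined[key] != value:
                    raise ValueError(f"conflicting values for {key!r}")
                combined[key] = value
        return combined

    print(combine_overrides([{"min_lr": 1e-5}, {"wd_mult": 0.1}, None]))
    # -> {'min_lr': 1e-05, 'wd_mult': 0.1}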

class core.optimizer_param_scheduler.OptimizerParamScheduler(
optimizer: megatron.core.optimizer.MegatronOptimizer,
init_lr: float,
max_lr: float,
min_lr: float,
lr_warmup_steps: int,
lr_decay_steps: int,
lr_decay_style: str,
start_wd: float,
end_wd: float,
wd_incr_steps: int,
wd_incr_style: str,
use_checkpoint_opt_param_scheduler: Optional[bool] = True,
override_opt_param_scheduler: Optional[bool] = False,
wsd_decay_steps: Optional[int] = None,
lr_wsd_decay_style: Optional[str] = None,
)#

Anneals the learning rate and weight decay.

Parameters:
  • optimizer (MegatronOptimizer) – the optimizer to be used

  • init_lr (float) – initial learning rate

  • max_lr (float) – maximum learning rate

  • min_lr (float) – minimum learning rate

  • lr_warmup_steps (int) – number of warmup steps

  • lr_decay_steps (int) – number of decay steps

  • lr_decay_style (str) – decay style for learning rate

  • start_wd (float) – initial weight decay

  • end_wd (float) – final weight decay

  • wd_incr_steps (int) – number of weight decay increment steps

  • wd_incr_style (str) – weight decay increment style

  • use_checkpoint_opt_param_scheduler (bool, optional) – whether to use the checkpoint values for the optimizer param scheduler

  • override_opt_param_scheduler (bool, optional) – whether to override the optimizer param scheduler values with the class values

  • wsd_decay_steps (int, optional) – number of decay steps for the warmup-stable-decay (WSD) schedule

  • lr_wsd_decay_style (str, optional) – learning rate decay style for the decay phase of the WSD schedule

Initialization
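A hedged construction example. It assumes `optimizer` is an already-built MegatronOptimizer, and the import path follows the module path shown above; the hyperparameter values are illustrative, not recommendations:

    from megatron.core.optimizer_param_scheduler import OptimizerParamScheduler

    # Linear warmup for 1k steps, then cosine decay to min_lr by step 100k;
    # weight decay ramps linearly from 0.01 to 0.1 over the same horizon.
    scheduler = OptimizerParamScheduler(
        optimizer=optimizer,  # an existing MegatronOptimizer (assumed)
        init_lr=0.0,
        max_lr=3e-4,
        min_lr=3e-5,
        lr_warmup_steps=1_000,
        lr_decay_steps=100_000,
        lr_decay_style="cosine",
        start_wd=0.01,
        end_wd=0.1,
        wd_incr_steps=100_000,
        wd_incr_style="linear",
    )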

get_wd(param_group: Optional[dict] = None) → float#

Compute the weight decay for the current step, honoring per-group overrides; a simplified sketch follows the parameter description.

Parameters:

param_group (dict) – parameter group from the optimizer.
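For intuition, a simplified sketch of the linear increment branch (not the module's exact code, which also supports other styles and per-group overrides):

    def linear_wd(step: int, start_wd: float, end_wd: float, wd_incr_steps: int) -> float:
        # Weight decay moves linearly from start_wd to end_wd over
        # wd_incr_steps, then stays at end_wd.
        ratio = min(step, wd_incr_steps) / wd_incr_steps
        return start_wd + ratio * (end_wd - start_wd)

    print(linear_wd(500, start_wd=0.01, end_wd=0.1, wd_incr_steps=1_000))  # 0.055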

get_lr(param_group: dict) → float#

Compute the learning rate using the decay functions from https://openreview.net/pdf?id=BJYwwY9ll (p. 4); a simplified sketch of the cosine branch follows the parameter description.

Parameters:

param_group (dict) – parameter group from the optimizer.
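For intuition, a simplified sketch of the warmup and cosine-decay branches (not the exact implementation, which also handles other styles and per-group overrides):

    import math

    def cosine_lr(
        step: int, init_lr: float, max_lr: float, min_lr: float,
        warmup_steps: int, decay_steps: int,
    ) -> float:
        if warmup_steps > 0 and step <= warmup_steps:
            # Linear warmup from init_lr up to max_lr.
            return init_lr + (max_lr - init_lr) * step / warmup_steps
        # Cosine anneal from max_lr down to min_lr over the remaining steps.
        ratio = min(step - warmup_steps, decay_steps - warmup_steps) / (
            decay_steps - warmup_steps
        )
        coeff = 0.5 * (math.cos(math.pi * ratio) + 1.0)
        return min_lr + coeff * (max_lr - min_lr)

    print(cosine_lr(50_000, 0.0, 3e-4, 3e-5, 1_000, 100_000))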

step(increment: int) → None#

Set the learning rate and weight decay for all parameter groups; typical usage is sketched below.

Parameters:

increment (int) – number of steps to increment
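Typical call order in a training loop; `train_step`, `optimizer`, `scheduler`, and `train_iters` are assumed to already exist and are placeholders for your own setup:

    # Advance the schedule once per optimizer update.
    for _ in range(train_iters):
        train_step()       # forward/backward pass (assumed helper)
        optimizer.step()   # apply gradients
        scheduler.step(1)  # recompute lr and weight decay for every param group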

state_dict() → dict#

Return the state dict.

_check_and_set(cls_value: float, sd_value: float, name: str) → float#

Auxiliary function that reconciles a class value with the corresponding checkpoint value and returns the one to use; the policy is sketched below.

Parameters:
  • cls_value (float) – class value

  • sd_value (float) – checkpoint value

  • name (str) – name of the parameter
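A sketch of the reconciliation policy this helper applies under the class's two flags (simplified re-implementation for illustration, with the flags passed explicitly rather than read from self):

    def check_and_set(cls_value, sd_value, name, override, use_checkpoint):
        if override:
            # Class values win when overriding the checkpoint is requested.
            return cls_value
        if not use_checkpoint:
            # Without checkpoint reuse, the two sources must agree.
            assert cls_value == sd_value, (
                f"{name}: class value {cls_value} != checkpoint value {sd_value}"
            )
        return sd_value

    print(check_and_set(3e-4, 3e-4, "max_lr", override=False, use_checkpoint=False))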

load_state_dict(state_dict: dict) → None#

Load the state dict.

Parameters:

state_dict (dict) – state dict to load
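A hedged checkpointing sketch, assuming `scheduler` is an existing OptimizerParamScheduler; the dictionary key and file name are illustrative only:

    import torch

    # Save the scheduler state alongside the rest of the checkpoint.
    torch.save({"opt_param_scheduler": scheduler.state_dict()}, "ckpt.pt")

    # Later: restore it. _check_and_set reconciles checkpoint values with the
    # constructor arguments according to the flags described above.
    ckpt = torch.load("ckpt.pt")
    scheduler.load_state_dict(ckpt["opt_param_scheduler"])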