core.dist_checkpointing.optimizer#
Helpers for defining sharding for optimizer states based on existing sharding for model parameters.
Module Contents#
Functions#
- get_optim_param_to_id_map – Generate mapping from optimizer param to optimizer state id.
- get_param_id_to_sharded_param_map – Generate mapping from optimizer state ids to model sharded parameters.
- make_sharded_optimizer_tensor – Build a ShardedTensor or ShardedTensorFactory for an optimizer param based on the model param.
- optim_state_to_sharding_state – Turn optimizer state dict into sharded state dict based on model state dict, in place.
Data#
API#
- core.dist_checkpointing.optimizer.logger#
‘getLogger(…)’
- core.dist_checkpointing.optimizer.KEEP_VARS_HINT#
‘ Make sure state dict contains original torch.nn.Parameters (not pure torch.Tensors) by passing `kee…’
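The truncated hint above refers to the keep_vars flag of torch.nn.Module.state_dict. A minimal plain-PyTorch sketch of the difference:

```python
import torch

# By default state_dict() returns detached tensors; keep_vars=True keeps the
# original torch.nn.Parameter objects, which matters when matching model
# params to the optimizer's params by identity.
model = torch.nn.Linear(4, 2)

print(type(model.state_dict()['weight']))                # torch.Tensor
print(type(model.state_dict(keep_vars=True)['weight']))  # torch.nn.Parameter
```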
- core.dist_checkpointing.optimizer.get_optim_param_to_id_map(optim_params_iter: Iterable[torch.nn.Parameter])
Generate mapping from optimizer param to optimizer state id.
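A minimal sketch of what such a mapping looks like, as an illustration only (not the library's implementation): state ids are taken to be each param's position in the optimizer's iteration order, keyed here by id(param), matching the integer ids torch.optim uses in optimizer.state_dict()['state'].

```python
import torch

# Illustration: number params by their position in the optimizer's
# iteration order, keyed by id(param) for identity-based lookup.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters())

optim_params_iter = (p for group in optimizer.param_groups for p in group['params'])
param_to_id = {id(p): idx for idx, p in enumerate(optim_params_iter)}

print(param_to_id)  # two entries, mapping weight and bias ids to 0 and 1
```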
- core.dist_checkpointing.optimizer.get_param_id_to_sharded_param_map(model_sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict, optim_params_iter: Iterable[torch.nn.Parameter])
Generate mapping from optimizer state ids to model sharded parameters.
- Parameters:
model_sharded_state_dict – sharded state dict with all model sharded tensors (can have any structure).
optim_params_iter – iterable over the model parameters tracked by the optimizer. Must yield parameters in the same order as the optimizer tracks them.
- Returns:
mapping from optimizer state ids to model sharded parameters.
- Return type:
Dict[int, Union[ShardedTensor, ShardedTensorFactory]]
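A hedged usage sketch: `model` and `optimizer` are assumed to exist, `model.sharded_state_dict()` is assumed to return the model's ShardedStateDict, and the import path `megatron.core.dist_checkpointing.optimizer` is inferred from this page's module path.

```python
from megatron.core.dist_checkpointing.optimizer import get_param_id_to_sharded_param_map

id_to_sharded_param_map = get_param_id_to_sharded_param_map(
    model_sharded_state_dict=model.sharded_state_dict(),
    # Must yield params in the same order the optimizer tracks them:
    optim_params_iter=(p for group in optimizer.param_groups for p in group['params']),
)
```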
- core.dist_checkpointing.optimizer.make_sharded_optimizer_tensor(model_param: Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory], optim_param: torch.Tensor, prefix: str)
Build a ShardedTensor or ShardedTensorFactory for an optimizer param based on the model param.
- Parameters:
model_param (Union[ShardedTensor, ShardedTensorFactory]) – model param
optim_param (torch.Tensor) – corresponding optimizer param
prefix (str) – optimizer prefix for the ShardedTensor or ShardedTensorFactory
- Returns:
wrapped optimizer parameter
- Return type:
Union[ShardedTensor, ShardedTensorFactory]
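A sketch of wrapping a single optimizer state tensor; `model_sharded_param` (the model weight's ShardedTensor or factory) and `exp_avg` (a matching Adam state tensor of the same local shape) are assumed to exist, and the prefix is a hypothetical key.

```python
from megatron.core.dist_checkpointing.optimizer import make_sharded_optimizer_tensor

# Reuse the model param's sharding for the optimizer's exp_avg buffer;
# only the key prefix distinguishes the resulting sharded tensor.
sharded_exp_avg = make_sharded_optimizer_tensor(
    model_param=model_sharded_param,   # assumed ShardedTensor/factory
    optim_param=exp_avg,               # assumed same-shaped state tensor
    prefix='optimizer.state.exp_avg',  # hypothetical prefix
)
```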
- core.dist_checkpointing.optimizer.optim_state_to_sharding_state(optim_state_dict: core.dist_checkpointing.mapping.StateDict, id_to_sharded_param_map: Dict[int, core.dist_checkpointing.mapping.ShardedTensor], exclude_keys: Tuple[str] = ())
Turn optimizer state dict into sharded state dict based on model state dict, in place.
Can be used to add sharding information to the most common optimizer state dicts. Creates a separate ShardedTensor for each key in optim_state_dict['state'] (e.g. for torch.optim.Adam there will be separate tensors for exp_avg and exp_avg_sq).
- Parameters:
optim_state_dict (StateDict) – optimizer state dict with state parameters under the state key and group hyperparameters under the param_groups -> params key.
id_to_sharded_param_map (Dict[int, ShardedTensor]) – mapping from optimizer param ids to model sharded tensors. Can be generated with the get_param_id_to_sharded_param_map function.
exclude_keys (Tuple[str]) – optimizer state keys to exclude from the final state dict.
- Returns:
state dict is modified in place
- Return type:
None
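Putting the pieces together, a hedged end-to-end sketch (same assumptions as above: existing `model` and `optimizer`, inferred import path):

```python
from megatron.core.dist_checkpointing.optimizer import (
    get_param_id_to_sharded_param_map,
    optim_state_to_sharding_state,
)

optim_state_dict = optimizer.state_dict()
id_to_sharded_param_map = get_param_id_to_sharded_param_map(
    model_sharded_state_dict=model.sharded_state_dict(),
    optim_params_iter=(p for group in optimizer.param_groups for p in group['params']),
)

# Modifies optim_state_dict in place: every tensor under ['state'] becomes a
# ShardedTensor that reuses the model param's sharding. Scalar entries such
# as Adam's 'step' could be skipped with exclude_keys=('step',).
optim_state_to_sharding_state(optim_state_dict, id_to_sharded_param_map)
```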