core.dist_checkpointing.optimizer#

Helpers for defining sharding for optimizer states based on existing sharding for model parameters.

Module Contents#

Functions#

get_optim_param_to_id_map

Generate mapping from optimizer param to optimizer state id.

get_param_id_to_sharded_param_map

Generate mapping from optimizer state ids to model sharded parameters.

make_sharded_optimizer_tensor

Build a ShardedTensor or ShardedTensorFactory for an optimizer param based on the corresponding model param.

optim_state_to_sharding_state

Turn an optimizer state dict into a sharded state dict, in place, based on the model sharded state dict.

Data#

API#

core.dist_checkpointing.optimizer.logger#

‘getLogger(…)’

core.dist_checkpointing.optimizer.KEEP_VARS_HINT#

‘ Make sure state dict contains original torch.nn.Parameters (not pure torch.Tensors) by passing `kee…’

core.dist_checkpointing.optimizer.get_optim_param_to_id_map(
optim_params_iter: Iterable[torch.nn.Parameter],
) → Dict[int, int]#

Generate mapping from optimizer param to optimizer state id.
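A minimal usage sketch, assuming the module is importable as megatron.core.dist_checkpointing.optimizer (adjust the import to your installation); the model and optimizer below are placeholders, and the id(param) key convention mirrors the standard torch.optim state_dict param mapping:

```python
import torch
from megatron.core.dist_checkpointing.optimizer import get_optim_param_to_id_map

# Placeholder model/optimizer just to have parameters to iterate over.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters())

# Iterate the params in the same order the optimizer tracks them.
param_to_id = get_optim_param_to_id_map(
    param for group in optimizer.param_groups for param in group['params']
)
# Keys are id(param) values; the ids enumerate the params across param groups,
# e.g. param_to_id[id(model.weight)] == 0 and param_to_id[id(model.bias)] == 1.
```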

core.dist_checkpointing.optimizer.get_param_id_to_sharded_param_map(
model_sharded_state_dict: core.dist_checkpointing.mapping.ShardedStateDict,
optim_params_iter: Iterable[torch.nn.Parameter],
) → Dict[int, Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory]]#

Generate mapping from optimizer state ids to model sharded parameters.

Parameters:
  • model_sharded_state_dict – sharded state dict with all model sharded tensors (can have any structure)

  • optim_params_iter – iterable over the model parameters tracked by the optimizer, iterated in the same order as the optimizer's parameters.

Returns:

mapping from optimizer state ids to model sharded parameters.

Return type:

Dict[int, Union[ShardedTensor, ShardedTensorFactory]]
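A hedged example of building the mapping; model and optimizer are placeholders, and sharded_state_dict() is assumed to be provided by a Megatron Core module:

```python
from megatron.core.dist_checkpointing.optimizer import get_param_id_to_sharded_param_map

# `model` is assumed to expose sharded_state_dict() (as Megatron Core modules do);
# `optimizer` is a torch.optim optimizer created over the same parameters.
model_sharded_sd = model.sharded_state_dict()

id_to_sharded_param = get_param_id_to_sharded_param_map(
    model_sharded_state_dict=model_sharded_sd,
    optim_params_iter=(
        param for group in optimizer.param_groups for param in group['params']
    ),
)
# id_to_sharded_param maps each optimizer state id to the ShardedTensor
# (or ShardedTensorFactory) of the corresponding model parameter.
```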

core.dist_checkpointing.optimizer.make_sharded_optimizer_tensor(
model_param: Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory],
optim_param: torch.Tensor,
prefix: str,
) → Union[core.dist_checkpointing.mapping.ShardedTensor, core.dist_checkpointing.mapping.ShardedTensorFactory]#

Build a ShardedTensor or ShardedTensorFactory for an optimizer param based on the corresponding model param.

Parameters:
  • model_param (Union[ShardedTensor, ShardedTensorFactory]) – model param

  • optim_param (torch.Tensor) – corresponding optimizer param

  • prefix (str) – optimizer prefix for the ShardedTensor or ShardedTensorFactory

Returns:

wrapped optimizer parameter

Return type:

Union[ShardedTensor, ShardedTensorFactory]
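An illustrative sketch for a single Adam moment tensor; param, param_id, and the prefix string are hypothetical, and id_to_sharded_param is assumed to come from get_param_id_to_sharded_param_map above:

```python
from megatron.core.dist_checkpointing.optimizer import make_sharded_optimizer_tensor

# Hypothetical inputs: `param_id` is the optimizer state id of `param`,
# and id_to_sharded_param was built with get_param_id_to_sharded_param_map.
model_sharded_param = id_to_sharded_param[param_id]
exp_avg = optimizer.state[param]['exp_avg']  # plain torch.Tensor

sharded_exp_avg = make_sharded_optimizer_tensor(
    model_param=model_sharded_param,
    optim_param=exp_avg,
    prefix='optimizer.state.exp_avg',  # illustrative prefix, not a required value
)
# sharded_exp_avg mirrors the model param's sharding but points at the exp_avg data.
```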

core.dist_checkpointing.optimizer.optim_state_to_sharding_state(
optim_state_dict: core.dist_checkpointing.mapping.StateDict,
id_to_sharded_param_map: Dict[int, core.dist_checkpointing.mapping.ShardedTensor],
exclude_keys: Tuple[str] = (),
)#

Turn an optimizer state dict into a sharded state dict, in place, based on the model sharded state dict.

Can be used to add sharding information to the most common optimizer state dict layout. Creates a separate ShardedTensor for each key in optim_state_dict['state'] (e.g. for torch.optim.Adam there will be separate tensors for exp_avg and exp_avg_sq).

Parameters:
  • optim_state_dict (StateDict) – optimizer state dict with state parameters under the state key and group hyperparameters under the param_groups -> params key.

  • id_to_sharded_param_map (Dict[int, ShardedTensor]) – mapping from optimizer param ids to model sharded tensors. Can be generated with get_param_id_to_sharded_param_map function.

  • exclude_keys (Tuple[str]) – optimizer state keys to exclude from the final state dict.

Returns:

state dict is modified in place

Return type:

None
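Putting the pieces together, a hedged end-to-end sketch of sharding an optimizer state dict before saving; model, optimizer, and the checkpoint directory are placeholders, and the final dist_checkpointing.save call is shown only to illustrate where the sharded optimizer state typically ends up:

```python
from megatron.core import dist_checkpointing
from megatron.core.dist_checkpointing.optimizer import (
    get_param_id_to_sharded_param_map,
    optim_state_to_sharding_state,
)

model_sharded_sd = model.sharded_state_dict()   # model's sharded state dict
optim_state = optimizer.state_dict()            # plain torch.optim state dict

id_to_sharded_param = get_param_id_to_sharded_param_map(
    model_sharded_state_dict=model_sharded_sd,
    optim_params_iter=(
        param for group in optimizer.param_groups for param in group['params']
    ),
)

# Modifies optim_state in place: each entry under optim_state['state']
# (e.g. exp_avg, exp_avg_sq for Adam) becomes a ShardedTensor.
optim_state_to_sharding_state(optim_state, id_to_sharded_param)

# The sharded model and optimizer states can then be saved together, e.g.:
dist_checkpointing.save(
    {'model': model_sharded_sd, 'optimizer': optim_state},
    '/path/to/checkpoint_dir',  # placeholder path
)
```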