core.dist_checkpointing.validation#
Module Contents#
Classes#
StrictHandling | Determines handling of load mismatch (non-empty “unexpected” or “missing” keys).
Functions#
parse_strict_flag | Parse user passed strict flag from a string to StrictHandling instance.
validate_integrity_and_strict_load | Validates sharding integrity and potential mismatches with the checkpoint.
verify_checkpoint_and_load_strategy | Verifies if checkpoint metadata exists and matches given strategies.
adjust_non_strict_load | Adjusts sharded state dict removing keys not existing in the checkpoint.
_determine_missing_and_unexpected_keys | Determines load mismatches based on metadata.
maybe_report_missing_and_unexpected_keys | Raises or logs an error in case missing or unexpected keys are non-empty.
_validate_common_state_dict | Validate consistency across ranks for the common state dict.
validate_sharding_integrity | Validate if the ShardedTensors and ShardedObjects from multiple processes define correct sharding.
_validate_objects_for_key | Ensure uniqueness of saved objects.
determine_global_metadata | Exchanges local metadata with all_gather_object to determine global metadata.
validate_sharded_objects_handling | Checks if either of the passed strategies can handle sharded objects.
Data#
API#
- core.dist_checkpointing.validation.logger#
‘getLogger(…)’
- core.dist_checkpointing.validation._LocalMetadata#
None
- core.dist_checkpointing.validation._GlobalMetadata#
None
- class core.dist_checkpointing.validation.StrictHandling(*args, **kwds)#
Bases: enum.Enum
Determines handling of load mismatch (non-empty “unexpected” or “missing” keys).
Different flags carry different implications on performance and behaviour and are divided into two groups:
- *_UNEXPECTED
- *_ALL
The first group ignores missing keys (keys present in the checkpoint but missing in the sharded state dict); it exists in order to avoid inter-rank metadata exchange. Note that the metadata exchange happens anyway when load(..., validate_access_integrity=True) is used, in which case the *_ALL option is recommended, as it provides a more thorough check with no performance penalty with respect to the *_UNEXPECTED group.
All options except for the first one (ASSUME_OK_UNEXPECTED) require extra disk access before the load in order to remove unexpected keys from the sharded state dict requested to load.
- ASSUME_OK_UNEXPECTED#
‘assume_ok_unexpected’
- LOG_UNEXPECTED#
‘log_unexpected’
- LOG_ALL#
‘log_all’
- RAISE_UNEXPECTED#
‘raise_unexpected’
- RAISE_ALL#
‘raise_all’
- RETURN_UNEXPECTED#
‘return_unexpected’
- RETURN_ALL#
‘return_all’
- IGNORE_ALL#
‘ignore_all’
- static requires_explicit_ckpt_mismatch_check( ) bool#
Whether a given strict flag involves mismatch check against the checkpoint.
- static requires_global_app_metadata( ) bool#
Whether a given strict option requires global metadata for validation.
- static requires_returning_mismatch_keys( ) bool#
Whether a given strict option results in extra return value from the load function.
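The grouping of flags described above can be modeled with a small stand-in enum. This is a minimal sketch, not the library's implementation: `StrictHandlingSketch` and the three helper functions are hypothetical names, the member values are copied from the listing above, and the helper bodies are assumptions inferred from the descriptions of the static methods.

```python
from enum import Enum


class StrictHandlingSketch(Enum):
    """Hypothetical stand-in mirroring the documented member values."""

    ASSUME_OK_UNEXPECTED = "assume_ok_unexpected"
    LOG_UNEXPECTED = "log_unexpected"
    LOG_ALL = "log_all"
    RAISE_UNEXPECTED = "raise_unexpected"
    RAISE_ALL = "raise_all"
    RETURN_UNEXPECTED = "return_unexpected"
    RETURN_ALL = "return_all"
    IGNORE_ALL = "ignore_all"


def needs_ckpt_mismatch_check(flag: StrictHandlingSketch) -> bool:
    # Assumption: every option except ASSUME_OK_UNEXPECTED reads
    # checkpoint metadata before the load.
    return flag is not StrictHandlingSketch.ASSUME_OK_UNEXPECTED


def needs_global_metadata(flag: StrictHandlingSketch) -> bool:
    # Assumption: only the *_ALL group needs metadata from every rank,
    # because only that group reports missing keys.
    return flag.name.endswith("_ALL")


def returns_mismatch_keys(flag: StrictHandlingSketch) -> bool:
    # Assumption: only the RETURN_* options add an extra return value.
    return flag.name.startswith("RETURN_")
```

Under these assumptions, `RAISE_UNEXPECTED` needs a checkpoint check but no inter-rank exchange, while `RAISE_ALL` needs both.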
- core.dist_checkpointing.validation.parse_strict_flag(
- strict: Union[str, core.dist_checkpointing.validation.StrictHandling],
Parse user passed strict flag from a string to StrictHandling instance.
- Parameters:
strict (str, StrictHandling) – strict flag to parse. If already an instance of StrictHandling, this function is a noop.
- Returns:
enum instance
- Return type:
StrictHandling
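The parsing behaviour described above (string to enum, no-op for an existing enum instance) can be sketched as follows. The `Strict` enum is a hypothetical two-member stand-in for StrictHandling, `parse_strict_sketch` is a hypothetical name, and the error raised for an unknown string is an assumption rather than documented behaviour.

```python
from enum import Enum
from typing import Union


class Strict(Enum):
    """Hypothetical two-member stand-in for StrictHandling."""

    LOG_ALL = "log_all"
    RAISE_ALL = "raise_all"


def parse_strict_sketch(strict: Union[str, Strict]) -> Strict:
    if isinstance(strict, Strict):
        # Already an enum instance: nothing to do.
        return strict
    try:
        # Enum lookup by value maps the string to its member.
        return Strict(strict)
    except ValueError as exc:
        # Assumption: an unknown string is reported as a ValueError.
        raise ValueError(f"Invalid strict flag: {strict!r}") from exc
```

Accepting both types lets callers pass a command-line string or an enum member through the same code path.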
- core.dist_checkpointing.validation.validate_integrity_and_strict_load(
- sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
- strict: core.dist_checkpointing.validation.StrictHandling,
- validate_access_integrity: bool,
- local_metadata: Optional[core.dist_checkpointing.validation._LocalMetadata] = None,
- global_metadata: Optional[core.dist_checkpointing.validation._GlobalMetadata] = None,
- ckpt_sharded_metadata: Optional[megatron.core.dist_checkpointing.serialization.CkptShardedMetadata] = None,
Validates sharding integrity and potential mismatches with the checkpoint.
validate_access_integrity controls sharding integrity check (orthogonal to strictness checking) which verifies sharded_state_dict runtime completeness (in isolation from the actual checkpoint).
strict flag controls handling of mismatches between the requested sharded state dict to load and the actual checkpoint. See StrictHandling docs for details regarding flag behavior and performance implications (disk interactions or inter-rank communication).
- Parameters:
sharded_state_dict (ShardedStateDict) – sharded state dict to verify.
strict (StrictHandling) – flag determining how to handle sharded keys mismatch.
validate_access_integrity (bool) – whether to perform sharding validation.
local_metadata (_LocalMetadata, optional) – local sharded state dict metadata. Defaults to None, in which case it’s determined based on sharded_state_dict.
global_metadata (_GlobalMetadata, optional) – global sharded state dict metadata (exchanged between ranks). Defaults to None, in which case “missing” keys are not determined.
ckpt_sharded_metadata (CkptShardedMetadata, optional) – sharded metadata from the checkpoint. Defaults to None, which only makes sense for the StrictHandling.ASSUME_OK_UNEXPECTED strict value.
- Returns:
tuple of: sharded state dict without unexpected keys, missing and unexpected keys. Missing keys are equal on all ranks, unexpected keys might differ across ranks. Additionally, missing keys might be erroneously empty (depending on strict value).
- Return type:
Tuple[ShardedStateDict, Set[str], Set[str]]
- core.dist_checkpointing.validation.verify_checkpoint_and_load_strategy(
- checkpoint_dir: str,
- sharded_strategy: Union[megatron.core.dist_checkpointing.strategies.base.LoadShardedStrategy, Tuple[str, int], None] = None,
- common_strategy: Union[megatron.core.dist_checkpointing.strategies.base.LoadCommonStrategy, Tuple[str, int], None] = None,
Verifies if checkpoint metadata exists and matches given strategies.
If no strategies are passed, they are determined based on the checkpoint metadata.
- Parameters:
checkpoint_dir (str) – checkpoint directory
sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional) – sharded load strategy to be verified if compatible with the checkpoint content. If None, the default sharded load strategy for the checkpoint backend will be returned.
common_strategy (LoadCommonStrategy, Tuple[str, int], optional) – common load strategy to be verified if compatible with the checkpoint content. If None, the default common load strategy for the checkpoint backend will be returned.
- core.dist_checkpointing.validation.adjust_non_strict_load(
- sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
- sharded_keys_to_remove: Set[str],
Adjusts sharded state dict removing keys not existing in the checkpoint.
- Parameters:
sharded_state_dict (ShardedStateDict) – sharded state dict to modify
sharded_keys_to_remove (Set[str]) – keys to remove from the state dict
- Returns:
state dict without ShardedBase objects with specified keys
- Return type:
ShardedStateDict
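The key removal described above can be illustrated on a plain nested structure. This is a sketch under stated assumptions: `FakeSharded` is a hypothetical stand-in for a ShardedBase object (only its `key` matters here), `drop_sharded_keys` is a hypothetical name, and returning a pruned copy rather than modifying in place is a choice made for this example.

```python
from dataclasses import dataclass
from typing import Any, Set


@dataclass
class FakeSharded:
    """Hypothetical stand-in for a ShardedBase object."""

    key: str


def _is_removed(value: Any, keys_to_remove: Set[str]) -> bool:
    return isinstance(value, FakeSharded) and value.key in keys_to_remove


def drop_sharded_keys(state: Any, keys_to_remove: Set[str]) -> Any:
    """Return a copy of a nested dict/list without the listed sharded leaves."""
    if isinstance(state, dict):
        return {
            name: drop_sharded_keys(value, keys_to_remove)
            for name, value in state.items()
            if not _is_removed(value, keys_to_remove)
        }
    if isinstance(state, list):
        return [
            drop_sharded_keys(value, keys_to_remove)
            for value in state
            if not _is_removed(value, keys_to_remove)
        ]
    # Non-sharded leaves (ints, strings, ...) are kept untouched.
    return state


state = {
    "model": {"w": FakeSharded("model.w"), "extra": FakeSharded("model.extra")},
    "step": 10,
}
pruned = drop_sharded_keys(state, {"model.extra"})
```

Only sharded leaves are filtered; ordinary values such as `step` pass through unchanged.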
- core.dist_checkpointing.validation._determine_missing_and_unexpected_keys(
- ckpt_sharded_metadata: megatron.core.dist_checkpointing.serialization.CkptShardedMetadata,
- local_metadata: core.dist_checkpointing.validation._LocalMetadata,
- global_metadata: Optional[core.dist_checkpointing.validation._GlobalMetadata] = None,
Determines load mismatches based on metadata.
There is an asymmetry between “unexpected” and “missing” keys. Unexpected keys can be determined based only on local metadata. Missing keys must be based on global metadata, since other ranks might access different keys than the current rank. In consequence, the return value of this function is different on each rank: “missing_keys” are equal, but “unexpected_keys” might differ across ranks.
- Parameters:
ckpt_sharded_metadata (CkptShardedMetadata) – sharded state dict (without data) constructed based on the checkpoint content
local_metadata (_LocalMetadata) – list of local ShardedBase objects requested to be loaded by this rank
global_metadata (_GlobalMetadata, optional) – list of global ShardedBase objects requested to be loaded by all ranks. Defaults to None, in which case returned “missing” keys are empty.
- Returns:
missing and unexpected keys. Missing keys are equal on all ranks, unexpected keys might differ across ranks. If passed global_metadata is empty, returned missing keys are empty as well.
- Return type:
Tuple[Set[str], Set[str]]
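The asymmetry between “unexpected” and “missing” keys reduces to two set differences. The sketch below uses plain string keys instead of ShardedBase objects and models global metadata as one list of keys per rank; both simplifications, and the name `mismatch_keys`, are assumptions for illustration.

```python
from typing import Iterable, List, Optional, Set, Tuple


def mismatch_keys(
    ckpt_keys: Iterable[str],
    local_keys: Iterable[str],
    global_keys: Optional[List[List[str]]] = None,
) -> Tuple[Set[str], Set[str]]:
    ckpt = set(ckpt_keys)
    # Unexpected: requested by this rank but absent from the checkpoint.
    unexpected = set(local_keys) - ckpt
    if global_keys is None:
        # Without global metadata, missing keys cannot be determined.
        missing: Set[str] = set()
    else:
        # Missing: present in the checkpoint but requested by no rank.
        requested = {key for rank_keys in global_keys for key in rank_keys}
        missing = ckpt - requested
    return missing, unexpected


missing, unexpected = mismatch_keys(
    ckpt_keys={"a", "b", "c"},
    local_keys=["a", "x"],
    global_keys=[["a", "x"], ["b"]],
)
```

Here key `c` is missing because no rank requests it, and key `x` is unexpected on this rank only.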
- core.dist_checkpointing.validation.maybe_report_missing_and_unexpected_keys(
- missing_keys: Set[str],
- unexpected_keys: Set[str],
- raise_error: bool = True,
Raises or logs an error in case missing or unexpected keys are non-empty.
- Parameters:
missing_keys (Set[str]) – missing keys in the state dict
unexpected_keys (Set[str]) – unexpected keys in the state dict
raise_error (bool) – If True, raises error on mismatch. Otherwise, logs mismatch with WARNING level.
- Returns:
None
- Raises:
CheckpointingException – if raise_error is True and at least one of missing_keys or unexpected_keys is non-empty.
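The raise-or-log behaviour can be sketched with the standard logging module. `MismatchError` is a hypothetical stand-in for CheckpointingException, `report_mismatch` is a hypothetical name, and the message format is an assumption.

```python
import logging
from typing import Set

logger = logging.getLogger(__name__)


class MismatchError(Exception):
    """Hypothetical stand-in for CheckpointingException."""


def report_mismatch(
    missing_keys: Set[str], unexpected_keys: Set[str], raise_error: bool = True
) -> None:
    if not missing_keys and not unexpected_keys:
        return  # nothing to report
    parts = []
    if unexpected_keys:
        parts.append(f"unexpected keys: {sorted(unexpected_keys)}")
    if missing_keys:
        parts.append(f"missing keys: {sorted(missing_keys)}")
    message = "; ".join(parts)
    if raise_error:
        raise MismatchError(message)
    # Otherwise the mismatch is only logged at WARNING level.
    logger.warning(message)
```

Sorting the keys keeps the message stable across runs, which helps when comparing logs from different ranks.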
- core.dist_checkpointing.validation._validate_common_state_dict(
- common_state_dict: megatron.core.dist_checkpointing.mapping.CommonStateDict,
Validate consistency across ranks for the common state dict.
We save the common state dict only on rank 0. We validate to make sure that the common dict is consistent across ranks before saving.
- Parameters:
common_state_dict – The common state dict present in all ranks
- core.dist_checkpointing.validation.validate_sharding_integrity(
- global_metadata: core.dist_checkpointing.validation._GlobalMetadata,
- common_state_dict: megatron.core.dist_checkpointing.mapping.CommonStateDict = None,
Validate if the ShardedTensors and ShardedObjects from multiple processes define correct sharding.
Local ShardedTensor and ShardedObject metadata is exchanged with torch.distributed.all_gather_object, and then the process with global rank 0 checks if main replicas of the shards:
- cover the whole global tensors
- don’t overlap
- Parameters:
global_metadata (_GlobalMetadata) – ShardedTensor and ShardedObject objects from all ranks.
common_state_dict (CommonStateDict) – The common state dict stored by rank 0
- Returns:
None
- Raises:
CheckpointingException – for invalid access pattern
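The two conditions checked on rank 0 (full coverage, no overlap) can be illustrated in one dimension by counting how many shards touch each element. This is a simplified sketch: the real check operates on N-dimensional ShardedTensor metadata, `check_1d_sharding` is a hypothetical name, shards are modeled as (offset, length) pairs, and ValueError stands in for CheckpointingException.

```python
from typing import List, Tuple


def check_1d_sharding(global_size: int, shards: List[Tuple[int, int]]) -> None:
    """Raise ValueError unless the shards tile [0, global_size) exactly once."""
    access_count = [0] * global_size
    for offset, length in shards:
        if offset < 0 or length < 0 or offset + length > global_size:
            raise ValueError(f"shard ({offset}, {length}) is out of bounds")
        for index in range(offset, offset + length):
            access_count[index] += 1
    if any(count == 0 for count in access_count):
        raise ValueError("shards do not cover the whole global tensor")
    if any(count > 1 for count in access_count):
        raise ValueError("shards overlap")


# Two halves of an 8-element tensor: valid sharding, no exception.
check_1d_sharding(8, [(0, 4), (4, 4)])
```

An access count of exactly one for every element expresses both conditions at once: zero means a gap, more than one means an overlap.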
- core.dist_checkpointing.validation._validate_sharding_for_key(
- rank_sharding: List[Tuple[int, megatron.core.dist_checkpointing.ShardedTensor]],
- core.dist_checkpointing.validation._compute_shards_access(rank_sharding)#
- core.dist_checkpointing.validation._validate_objects_for_key(
- sharded_objects: List[megatron.core.dist_checkpointing.mapping.ShardedObject],
Ensure uniqueness of saved objects.
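A uniqueness check of this kind can be sketched by counting identifiers. The assumption here is that each saved object can be reduced to a string identifier; `find_duplicates` and the identifiers in the example are hypothetical.

```python
from collections import Counter
from typing import List


def find_duplicates(object_ids: List[str]) -> List[str]:
    """Return identifiers that occur more than once, in sorted order."""
    counts = Counter(object_ids)
    return sorted(obj_id for obj_id, count in counts.items() if count > 1)


duplicates = find_duplicates(["rng/0", "rng/1", "rng/0"])
```

A non-empty result would indicate that two processes intend to save the same object.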
- core.dist_checkpointing.validation.determine_global_metadata(
- sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
Exchanges local metadata with all_gather_object to determine global metadata.
- Parameters:
sharded_state_dict (ShardedStateDict) – local sharded state dict
- Returns:
local and global ShardedBase objects with stripped data
- Return type:
Tuple[_LocalMetadata, _GlobalMetadata]
- core.dist_checkpointing.validation.validate_sharded_objects_handling(
- sharded_strategy: Union[megatron.core.dist_checkpointing.strategies.base.SaveShardedStrategy, megatron.core.dist_checkpointing.strategies.base.LoadShardedStrategy],
- common_strategy: Union[megatron.core.dist_checkpointing.strategies.base.SaveCommonStrategy, megatron.core.dist_checkpointing.strategies.base.LoadCommonStrategy],
Checks if either of the passed strategies can handle sharded objects.
- Parameters:
sharded_strategy (Union[SaveShardedStrategy, LoadShardedStrategy]) – sharded strategy used for saving/loading
common_strategy (Union[SaveCommonStrategy, LoadCommonStrategy]) – common strategy used for saving/loading
- Returns:
None
- Raises:
CheckpointingException – if both strategies can’t handle ShardedObjects
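The rule above (at least one of the two strategies must support sharded objects) can be sketched with stand-in strategy objects. `FakeStrategy`, its `handles_sharded_objects` attribute, and `check_sharded_object_support` are hypothetical names, and RuntimeError stands in for CheckpointingException.

```python
from dataclasses import dataclass


@dataclass
class FakeStrategy:
    """Hypothetical stand-in for a save/load strategy."""

    name: str
    handles_sharded_objects: bool


def check_sharded_object_support(sharded: FakeStrategy, common: FakeStrategy) -> None:
    # One capable strategy is enough; fail only when neither qualifies.
    if sharded.handles_sharded_objects or common.handles_sharded_objects:
        return
    raise RuntimeError(
        f"Neither {sharded.name} nor {common.name} can handle sharded objects"
    )


# Passes because the sharded strategy supports sharded objects.
check_sharded_object_support(
    FakeStrategy("sharded", True), FakeStrategy("common", False)
)
```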