core.dist_checkpointing.validation

Module Contents

Classes

StrictHandling

Determines handling of load mismatch (non-empty “unexpected” or “missing” keys).

Functions

parse_strict_flag

Parse a user-passed strict flag from a string into a StrictHandling instance.

validate_integrity_and_strict_load

Validates sharding integrity and potential mismatches with the checkpoint.

verify_checkpoint_and_load_strategy

Verifies if checkpoint metadata exists and matches given strategies.

adjust_non_strict_load

Adjusts the sharded state dict by removing keys that do not exist in the checkpoint.

_determine_missing_and_unexpected_keys

Determines load mismatches based on metadata.

maybe_report_missing_and_unexpected_keys

Raises or logs an error in case missing or unexpected keys are non-empty.

_validate_common_state_dict

Validate consistency across ranks for the common state dict.

validate_sharding_integrity

Validate if the ShardedTensors and ShardedObjects from multiple processes define correct sharding.

_validate_sharding_for_key

_compute_shards_access

_validate_objects_for_key

Ensure uniqueness of saved objects.

determine_global_metadata

Exchanges local metadata with all_gather_object to determine global metadata.

validate_sharded_objects_handling

Checks if either of the passed strategies can handle sharded objects.

Data

API

core.dist_checkpointing.validation.logger

‘getLogger(…)’

core.dist_checkpointing.validation._LocalMetadata

None

core.dist_checkpointing.validation._GlobalMetadata

None

class core.dist_checkpointing.validation.StrictHandling(*args, **kwds)

Bases: enum.Enum

Determines handling of load mismatch (non-empty “unexpected” or “missing” keys).

Different flags carry different implications for performance and behaviour and are divided into two groups:

  • *_UNEXPECTED

  • *_ALL

The first group ignores missing keys (keys present in the checkpoint but missing from the sharded state dict) and exists to avoid inter-rank metadata exchange. Note that the metadata exchange happens anyway with the load(..., validate_access_integrity=True) flag, in which case the *_ALL options are recommended: they provide a more thorough check with no performance penalty relative to the *_UNEXPECTED group.

All options except for the first one (ASSUME_OK_UNEXPECTED) require extra disk access before the load in order to remove unexpected keys from the sharded state dict requested to load.


ASSUME_OK_UNEXPECTED

‘assume_ok_unexpected’

LOG_UNEXPECTED

‘log_unexpected’

LOG_ALL

‘log_all’

RAISE_UNEXPECTED

‘raise_unexpected’

RAISE_ALL

‘raise_all’

RETURN_UNEXPECTED

‘return_unexpected’

RETURN_ALL

‘return_all’

IGNORE_ALL

‘ignore_all’

static requires_explicit_ckpt_mismatch_check(
val: core.dist_checkpointing.validation.StrictHandling,
) → bool

Whether a given strict flag involves mismatch check against the checkpoint.

static requires_global_app_metadata(
val: core.dist_checkpointing.validation.StrictHandling,
) → bool

Whether a given strict option requires global metadata for validation.

static requires_returning_mismatch_keys(
val: core.dist_checkpointing.validation.StrictHandling,
) → bool

Whether a given strict option results in an extra return value from the load function.
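The grouping described above can be sketched as a standalone enum. StrictHandlingSketch is a hypothetical mirror of the documented values, not the library class, and the bodies of the three static methods are assumptions inferred from the descriptions: the mismatch check applies to every option except ASSUME_OK_UNEXPECTED, global metadata is assumed to be needed by the *_ALL options, and extra return values are assumed to come only from the RETURN_* options.

```python
from enum import Enum


class StrictHandlingSketch(Enum):
    """Hypothetical mirror of the documented StrictHandling values."""

    ASSUME_OK_UNEXPECTED = 'assume_ok_unexpected'
    LOG_UNEXPECTED = 'log_unexpected'
    LOG_ALL = 'log_all'
    RAISE_UNEXPECTED = 'raise_unexpected'
    RAISE_ALL = 'raise_all'
    RETURN_UNEXPECTED = 'return_unexpected'
    RETURN_ALL = 'return_all'
    IGNORE_ALL = 'ignore_all'

    @staticmethod
    def requires_explicit_ckpt_mismatch_check(val: 'StrictHandlingSketch') -> bool:
        # Every option except ASSUME_OK_UNEXPECTED compares against checkpoint metadata.
        return val != StrictHandlingSketch.ASSUME_OK_UNEXPECTED

    @staticmethod
    def requires_global_app_metadata(val: 'StrictHandlingSketch') -> bool:
        # Assumption: the *_ALL group needs metadata from all ranks to find missing keys.
        return val.name.endswith('_ALL')

    @staticmethod
    def requires_returning_mismatch_keys(val: 'StrictHandlingSketch') -> bool:
        # Assumption: only RETURN_* options add mismatch keys to the load result.
        return val.name.startswith('RETURN_')
```

Deriving the groups from member names keeps the sketch short; listing members explicitly would be more robust if flags with other naming patterns were added.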

core.dist_checkpointing.validation.parse_strict_flag(
strict: Union[str, core.dist_checkpointing.validation.StrictHandling],
) → core.dist_checkpointing.validation.StrictHandling

Parse a user-passed strict flag from a string into a StrictHandling instance.

Parameters:

strict (str, StrictHandling) – strict flag to parse. If it is already a StrictHandling instance, this function is a no-op.

Returns:

enum instance

Return type:

StrictHandling
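A minimal sketch of the parsing behavior, assuming that string flags are matched against the enum values listed above (so 'raise_all' maps to RAISE_ALL). StrictHandlingStub is a hypothetical stand-in with only a few members, and the TypeError for other input types is an assumption of this sketch.

```python
from enum import Enum
from typing import Union


class StrictHandlingStub(Enum):
    """Hypothetical stand-in with a subset of the documented values."""

    ASSUME_OK_UNEXPECTED = 'assume_ok_unexpected'
    LOG_UNEXPECTED = 'log_unexpected'
    RAISE_ALL = 'raise_all'


def parse_strict_flag_sketch(strict: Union[str, StrictHandlingStub]) -> StrictHandlingStub:
    """Convert a user-provided strict flag into an enum member."""
    if isinstance(strict, StrictHandlingStub):
        # Already parsed: no-op.
        return strict
    if isinstance(strict, str):
        # Lookup by value; raises ValueError for unknown strings.
        return StrictHandlingStub(strict)
    raise TypeError(f'Invalid strict flag type: {type(strict)}')
```

Passing an unknown string such as 'strict' fails with the ValueError raised by the Enum value lookup.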

core.dist_checkpointing.validation.validate_integrity_and_strict_load(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
strict: core.dist_checkpointing.validation.StrictHandling,
validate_access_integrity: bool,
local_metadata: Optional[core.dist_checkpointing.validation._LocalMetadata] = None,
global_metadata: Optional[core.dist_checkpointing.validation._GlobalMetadata] = None,
ckpt_sharded_metadata: Optional[megatron.core.dist_checkpointing.serialization.CkptShardedMetadata] = None,
) → Tuple[megatron.core.dist_checkpointing.mapping.ShardedStateDict, Set[str], Set[str]]

Validates sharding integrity and potential mismatches with the checkpoint.

validate_access_integrity controls the sharding integrity check (orthogonal to strictness checking), which verifies the runtime completeness of sharded_state_dict independently of the actual checkpoint.

The strict flag controls how mismatches between the sharded state dict requested for loading and the actual checkpoint are handled. See the StrictHandling docs for details on flag behavior and performance implications (disk interactions or inter-rank communication).

Parameters:
  • sharded_state_dict (ShardedStateDict) – sharded state dict to verify.

  • strict (StrictHandling) – flag determining how to handle sharded keys mismatch.

  • validate_access_integrity (bool) – whether to perform sharding validation.

  • local_metadata (_LocalMetadata, optional) – local sharded state dict metadata. Defaults to None, in which case it’s determined based on sharded_state_dict.

  • global_metadata (_GlobalMetadata, optional) – global sharded state dict metadata (exchanged between ranks). Defaults to None, in which case “missing” keys are not determined.

  • ckpt_sharded_metadata (CkptShardedMetadata, optional) – sharded metadata from the checkpoint. Defaults to None, which only makes sense for the StrictHandling.ASSUME_OK_UNEXPECTED strict value.

Returns:

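tuple of: the sharded state dict without unexpected keys, the missing keys, and the unexpected keys. Missing keys are equal on all ranks; unexpected keys might differ across ranks. Note that missing keys may be empty even when some checkpoint keys are not requested: the *_UNEXPECTED strict values do not determine missing keys, and neither does a call without global_metadata.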

Return type:

Tuple[ShardedStateDict, Set[str], Set[str]]
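The following sketch approximates the strict-handling part of this function on plain key sets instead of a ShardedStateDict, with strict values given as the lowercase strings listed in StrictHandling. It skips the validate_access_integrity check entirely, and CheckpointingError is a hypothetical stand-in for the library's exception. all_ranks_requested plays the role of global_metadata: without it, missing keys stay empty.

```python
import logging
from typing import Iterable, Optional, Set, Tuple

logger = logging.getLogger(__name__)


class CheckpointingError(Exception):
    """Hypothetical stand-in for the library's CheckpointingException."""


def strict_load_keys_sketch(
    requested_keys: Set[str],
    ckpt_keys: Set[str],
    strict: str,
    all_ranks_requested: Optional[Iterable[Set[str]]] = None,
) -> Tuple[Set[str], Set[str], Set[str]]:
    """Return (keys to load, missing keys, unexpected keys) for one rank."""
    if strict == 'assume_ok_unexpected':
        # No checkpoint lookup: rely on the load strategy to fail on bad keys.
        return set(requested_keys), set(), set()

    # Unexpected: requested by this rank but absent from the checkpoint (local info only).
    unexpected = requested_keys - ckpt_keys
    # Missing: present in the checkpoint but requested by no rank (needs global info).
    missing: Set[str] = set()
    if strict.endswith('_all') and all_ranks_requested is not None:
        missing = ckpt_keys - set().union(*all_ranks_requested)

    # Unexpected keys are dropped before the actual load.
    to_load = requested_keys - unexpected

    if strict == 'ignore_all':
        missing, unexpected = set(), set()
    elif strict.startswith(('raise_', 'log_')) and (missing or unexpected):
        msg = f'missing={sorted(missing)} unexpected={sorted(unexpected)}'
        if strict.startswith('raise_'):
            raise CheckpointingError(msg)
        logger.warning(msg)
    # return_* options simply hand the mismatches back to the caller.
    return to_load, missing, unexpected
```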

core.dist_checkpointing.validation.verify_checkpoint_and_load_strategy(
checkpoint_dir: str,
sharded_strategy: Union[megatron.core.dist_checkpointing.strategies.base.LoadShardedStrategy, Tuple[str, int], None] = None,
common_strategy: Union[megatron.core.dist_checkpointing.strategies.base.LoadCommonStrategy, Tuple[str, int], None] = None,
) → Tuple[megatron.core.dist_checkpointing.strategies.base.LoadShardedStrategy, megatron.core.dist_checkpointing.strategies.base.LoadCommonStrategy]

Verifies if checkpoint metadata exists and matches given strategies.

If no strategies are passed, they are determined based on the checkpoint metadata.

Parameters:
  • checkpoint_dir (str) – checkpoint directory

  • sharded_strategy (LoadShardedStrategy, Tuple[str, int], optional) – sharded load strategy to check for compatibility with the checkpoint content. If None, the default sharded load strategy for the checkpoint backend is returned.

  • common_strategy (LoadCommonStrategy, Tuple[str, int], optional) – common load strategy to check for compatibility with the checkpoint content. If None, the default common load strategy for the checkpoint backend is returned.
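A sketch of the strategy-resolution rule for the sharded strategy, assuming the checkpoint metadata has already been read into a (backend, version) pair. The registry DEFAULT_SHARDED_STRATEGIES, the backend name 'example_backend', and the strategy names are hypothetical; strategies are modeled as plain strings, and a concrete strategy passed by the caller is returned without the compatibility checks the real function performs.

```python
from typing import Dict, Optional, Tuple, Union

# Hypothetical registry of default load strategies per (backend, version).
DEFAULT_SHARDED_STRATEGIES: Dict[Tuple[str, int], str] = {
    ('example_backend', 1): 'ExampleLoadShardedStrategy',
}


class CheckpointingError(Exception):
    """Hypothetical stand-in for the library's CheckpointingException."""


def resolve_sharded_strategy_sketch(
    ckpt_backend: Optional[Tuple[str, int]],
    sharded_strategy: Union[str, Tuple[str, int], None] = None,
    defaults: Dict[Tuple[str, int], str] = DEFAULT_SHARDED_STRATEGIES,
) -> str:
    """Return a strategy name compatible with the checkpoint backend.

    ckpt_backend is the (backend, version) pair assumed to be read from the
    checkpoint metadata; None models a directory without such metadata.
    """
    if ckpt_backend is None:
        raise CheckpointingError('checkpoint metadata not found')
    if sharded_strategy is None:
        # No strategy given: use the default for the checkpoint's own backend.
        sharded_strategy = ckpt_backend
    if isinstance(sharded_strategy, tuple):
        backend, _version = sharded_strategy
        if backend != ckpt_backend[0]:
            raise CheckpointingError(f'backend mismatch: {backend} vs {ckpt_backend[0]}')
        if sharded_strategy not in defaults:
            raise CheckpointingError(f'no default strategy for {sharded_strategy}')
        return defaults[sharded_strategy]
    # A concrete strategy is returned unchanged here; the real function would
    # also check its compatibility with the checkpoint.
    return sharded_strategy
```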

core.dist_checkpointing.validation.adjust_non_strict_load(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
sharded_keys_to_remove: Set[str],
) → megatron.core.dist_checkpointing.mapping.ShardedStateDict

Adjusts the sharded state dict by removing keys that do not exist in the checkpoint.

Parameters:
  • sharded_state_dict (ShardedStateDict) – sharded state dict to modify

  • sharded_keys_to_remove (Set[str]) – keys to remove from the state dict

Returns:

state dict without the ShardedBase objects whose keys are in sharded_keys_to_remove

Return type:

ShardedStateDict
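To illustrate the adjustment, here is a minimal sketch over a nested dict/list structure. FakeShardedBase is a hypothetical stand-in for ShardedBase that carries only a key, and the sketch returns a new structure rather than assuming whether the library modifies its input in place.

```python
from dataclasses import dataclass
from typing import Any, Set


@dataclass
class FakeShardedBase:
    """Hypothetical stand-in for a ShardedBase leaf; only `key` matters here."""

    key: str


def _drop(value: Any, keys_to_remove: Set[str]) -> bool:
    # True for leaves whose key is scheduled for removal.
    return isinstance(value, FakeShardedBase) and value.key in keys_to_remove


def adjust_non_strict_load_sketch(state_dict: Any, keys_to_remove: Set[str]) -> Any:
    """Return a copy of a nested dict/list without the matching leaves."""
    if isinstance(state_dict, dict):
        return {
            name: adjust_non_strict_load_sketch(value, keys_to_remove)
            for name, value in state_dict.items()
            if not _drop(value, keys_to_remove)
        }
    if isinstance(state_dict, list):
        return [
            adjust_non_strict_load_sketch(value, keys_to_remove)
            for value in state_dict
            if not _drop(value, keys_to_remove)
        ]
    # Other leaves are kept unchanged.
    return state_dict
```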

core.dist_checkpointing.validation._determine_missing_and_unexpected_keys(
ckpt_sharded_metadata: megatron.core.dist_checkpointing.serialization.CkptShardedMetadata,
local_metadata: core.dist_checkpointing.validation._LocalMetadata,
global_metadata: Optional[core.dist_checkpointing.validation._GlobalMetadata] = None,
) → Tuple[Set[str], Set[str]]

Determines load mismatches based on metadata.

There is an asymmetry between “unexpected” and “missing” keys. Unexpected keys can be determined from local metadata alone. Missing keys must be determined from global metadata, since other ranks might access different keys than the current rank. Consequently, the return value of this function differs across ranks: “missing_keys” are equal on all ranks, but “unexpected_keys” might differ.

Parameters:
  • ckpt_sharded_metadata (CkptShardedMetadata) – sharded state dict (without data) constructed based on the checkpoint content

  • local_metadata (_LocalMetadata) – list of local ShardedBase objects requested to be loaded by this rank

  • global_metadata (_GlobalMetadata, optional) – list of global ShardedBase objects requested to be loaded by all ranks. Defaults to None, in which case returned “missing” keys are empty.

Returns:

missing and unexpected keys. Missing keys are equal on all ranks; unexpected keys might differ across ranks. If global_metadata is None, the returned missing keys are empty.

Return type:

Tuple[Set[str], Set[str]]
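The asymmetry can be shown with plain set arithmetic. In this sketch, string keys stand in for the ShardedBase objects in the metadata (an assumption for illustration), and global_keys holds one list of requested keys per rank.

```python
from typing import List, Optional, Set, Tuple


def determine_missing_and_unexpected_sketch(
    ckpt_keys: Set[str],
    local_keys: List[str],
    global_keys: Optional[List[List[str]]] = None,
) -> Tuple[Set[str], Set[str]]:
    """Return (missing, unexpected) keys as seen by one rank."""
    # Unexpected: requested locally but absent from the checkpoint (rank-local result).
    unexpected = set(local_keys) - ckpt_keys
    if global_keys is None:
        # Without global metadata, missing keys cannot be determined.
        return set(), unexpected
    # Missing: in the checkpoint but requested by no rank (same result on every rank).
    requested_anywhere = {key for rank_keys in global_keys for key in rank_keys}
    return ckpt_keys - requested_anywhere, unexpected
```

With two ranks requesting ['a', 'x'] and ['b'] from a checkpoint holding {'a', 'b', 'c'}, both ranks report 'c' as missing, but only rank 0 reports 'x' as unexpected.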

core.dist_checkpointing.validation.maybe_report_missing_and_unexpected_keys(
missing_keys: Set[str],
unexpected_keys: Set[str],
raise_error: bool = True,
) → None

Raises or logs an error in case missing or unexpected keys are non-empty.

Parameters:
  • missing_keys (Set[str]) – missing keys in the state dict

  • unexpected_keys (Set[str]) – unexpected keys in the state dict

  • raise_error (bool, optional) – If True, raises an error on mismatch. Otherwise, logs the mismatch with WARNING level.

Returns:

None

Raises:

CheckpointingException – if raise_error is True and missing or unexpected keys are non-empty

core.dist_checkpointing.validation._validate_common_state_dict(
common_state_dict: megatron.core.dist_checkpointing.mapping.CommonStateDict,
) → None

Validate consistency across ranks for the common state dict.

The common state dict is saved only on rank 0, so this function checks that it is consistent across ranks before saving.

Parameters:

common_state_dict – the common state dict present on all ranks
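A single-process sketch of the comparison, assuming each rank's common state dict has already been collected on rank 0 (in a real run, a collective such as torch.distributed.gather_object could do this). Only top-level keys are compared, and the function merely returns the differences; whether the library logs or raises on a mismatch is not stated here, so that decision is left to the caller.

```python
from typing import Any, Dict, List


def diff_common_state_dicts_sketch(per_rank: List[Dict[str, Any]]) -> Dict[int, Dict[str, list]]:
    """Compare every rank's common state dict with rank 0's copy.

    per_rank[i] stands for the dict that rank i would send to rank 0.
    """
    reference = per_rank[0]
    diffs = {}
    for rank, other in enumerate(per_rank[1:], start=1):
        only_rank0 = sorted(reference.keys() - other.keys())
        only_this_rank = sorted(other.keys() - reference.keys())
        mismatched = sorted(k for k in reference.keys() & other.keys() if reference[k] != other[k])
        if only_rank0 or only_this_rank or mismatched:
            diffs[rank] = {
                'only_rank0': only_rank0,
                'only_this_rank': only_this_rank,
                'mismatched': mismatched,
            }
    return diffs
```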

core.dist_checkpointing.validation.validate_sharding_integrity(
global_metadata: core.dist_checkpointing.validation._GlobalMetadata,
common_state_dict: megatron.core.dist_checkpointing.mapping.CommonStateDict = None,
) → None

Validate if the ShardedTensors and ShardedObjects from multiple processes define correct sharding.

Local ShardedTensor and ShardedObject metadata is exchanged with torch.distributed.all_gather_object, and then the process with global rank 0 checks that the main replicas of the shards:

  • cover the whole global tensors

  • don’t overlap

Parameters:
  • global_metadata (_GlobalMetadata) – ShardedTensor and ShardedObject objects from all ranks.

  • common_state_dict (CommonStateDict) – The common state dict stored by rank 0

Returns:

None

Raises:

CheckpointingException – for an invalid access pattern
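The two conditions above amount to every chunk of the global tensor being written by exactly one main replica. The sketch below checks this on a grid of chunk offsets; the Shard tuple layout and the convention that replica_id == 0 marks the main replica are assumptions of the sketch, and ValueError stands in for the library's exception.

```python
from collections import Counter
from itertools import product
from typing import List, Tuple

# (rank, chunk offset in the fragment grid, replica_id) for one shard of a tensor.
Shard = Tuple[int, Tuple[int, ...], int]


def check_sharding_sketch(axis_fragmentations: Tuple[int, ...], shards: List[Shard]) -> None:
    """Raise ValueError unless main replicas cover every chunk exactly once."""
    # Assumption: replica_id == 0 marks the main replica of a shard.
    counts = Counter(offset for _rank, offset, replica_id in shards if replica_id == 0)
    grid = set(product(*(range(n) for n in axis_fragmentations)))
    outside = set(counts) - grid
    if outside:
        raise ValueError(f'offsets outside the chunk grid: {sorted(outside)}')
    # Coverage: every chunk needs at least one main replica.
    uncovered = sorted(cell for cell in grid if counts[cell] == 0)
    if uncovered:
        raise ValueError(f'chunks not covered by any main replica: {uncovered}')
    # No overlap: no chunk may be written by two main replicas.
    overlapping = sorted(cell for cell in grid if counts[cell] > 1)
    if overlapping:
        raise ValueError(f'chunks written by several main replicas: {overlapping}')
```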

core.dist_checkpointing.validation._validate_sharding_for_key(
rank_sharding: List[Tuple[int, megatron.core.dist_checkpointing.ShardedTensor]],
)

core.dist_checkpointing.validation._compute_shards_access(rank_sharding)

core.dist_checkpointing.validation._validate_objects_for_key(
sharded_objects: List[megatron.core.dist_checkpointing.mapping.ShardedObject],
)

Ensure uniqueness of saved objects.
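A minimal sketch of the uniqueness check, assuming each saved object is identified by a per-shard string key (called unique_key here) and that only main replicas (replica_id == 0) are saved. ValueError stands in for the library's exception.

```python
from collections import Counter
from typing import List, Tuple

# (rank, unique_key, replica_id) for each ShardedObject sharing one key.
ObjEntry = Tuple[int, str, int]


def validate_objects_sketch(entries: List[ObjEntry]) -> None:
    """Raise ValueError if two main replicas would save the same object shard."""
    # Assumption: only main replicas (replica_id == 0) are written to disk.
    saved = [unique_key for _rank, unique_key, replica_id in entries if replica_id == 0]
    duplicates = {key: n for key, n in Counter(saved).items() if n > 1}
    if duplicates:
        raise ValueError(f'duplicate ShardedObject keys: {duplicates}')
```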

core.dist_checkpointing.validation.determine_global_metadata(
sharded_state_dict: megatron.core.dist_checkpointing.mapping.ShardedStateDict,
) → Tuple[core.dist_checkpointing.validation._LocalMetadata, core.dist_checkpointing.validation._GlobalMetadata]

Exchanges local metadata with all_gather_object to determine global metadata.

Parameters:

sharded_state_dict (ShardedStateDict) – local sharded state dict

Returns:

local and global ShardedBase objects with stripped data

Return type:

Tuple[_LocalMetadata, _GlobalMetadata]
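A sketch of the exchange, assuming ShardedBase objects can be modeled as a frozen dataclass whose data field is cleared before sending. The optional gather argument is a hypothetical hook for single-process testing; without it, the sketch calls torch.distributed.all_gather_object, which requires an initialized process group.

```python
from dataclasses import dataclass, replace
from typing import Any, Callable, List, Optional, Tuple


@dataclass(frozen=True)
class FakeShard:
    """Hypothetical stand-in for a ShardedBase object."""

    key: str
    data: Any = None


def determine_global_metadata_sketch(
    local_shards: List[FakeShard],
    gather: Optional[Callable[[list], List[list]]] = None,
) -> Tuple[List[FakeShard], List[List[FakeShard]]]:
    # Strip the data so only lightweight metadata is exchanged.
    local_metadata = [replace(shard, data=None) for shard in local_shards]
    if gather is None:
        import torch.distributed as dist

        # Every rank receives the list of every rank, indexed by rank.
        global_metadata = [None] * dist.get_world_size()
        dist.all_gather_object(global_metadata, local_metadata)
    else:
        # Injected fake collective for single-process use.
        global_metadata = gather(local_metadata)
    return local_metadata, global_metadata
```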

core.dist_checkpointing.validation.validate_sharded_objects_handling(
sharded_strategy: Union[megatron.core.dist_checkpointing.strategies.base.SaveShardedStrategy, megatron.core.dist_checkpointing.strategies.base.LoadShardedStrategy],
common_strategy: Union[megatron.core.dist_checkpointing.strategies.base.SaveCommonStrategy, megatron.core.dist_checkpointing.strategies.base.LoadCommonStrategy],
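) → None

Checks if either of the passed strategies can handle sharded objects.

Parameters:
  • sharded_strategy (Union[SaveShardedStrategy, LoadShardedStrategy]) – sharded strategy to validate

  • common_strategy (Union[SaveCommonStrategy, LoadCommonStrategy]) – common strategy to validate

Returns: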

None

Raises:

CheckpointingException – if neither strategy can handle ShardedObjects
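A minimal sketch of the check, assuming each strategy exposes a boolean can_handle_sharded_objects attribute; FakeStrategy and CheckpointingError are hypothetical stand-ins.

```python
from dataclasses import dataclass


@dataclass
class FakeStrategy:
    """Hypothetical stand-in for a save or load strategy."""

    name: str
    can_handle_sharded_objects: bool


class CheckpointingError(Exception):
    """Hypothetical stand-in for the library's CheckpointingException."""


def validate_sharded_objects_handling_sketch(
    sharded_strategy: FakeStrategy, common_strategy: FakeStrategy
) -> None:
    # ShardedObjects must be stored by at least one of the two strategies.
    if not (
        sharded_strategy.can_handle_sharded_objects
        or common_strategy.can_handle_sharded_objects
    ):
        raise CheckpointingError(
            f'neither {sharded_strategy.name} nor {common_strategy.name} can handle ShardedObjects'
        )
```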