nemo_automodel.checkpoint._backports.default_planner#

Module Contents#

Classes#

DefaultSavePlanner

DefaultLoadPlanner

DefaultLoadPlanner that adds multiple features on top of LoadPlanner.

_EmptyStateDictLoadPlanner

Extension of DefaultLoadPlanner that rebuilds the state_dict from the saved metadata. Useful for loading a state_dict without first initializing a model, such as when converting a DCP checkpoint into a Torch save file.

Functions#

create_default_local_load_plan

create_default_global_load_plan

Create global load plan used by DefaultLoadPlanner.

create_default_local_save_plan

Create the SavePlan used by DefaultSavePlanner.

create_default_global_save_plan

Create the global plan and metadata used by DefaultSavePlanner.

_create_default_local_metadata

Return the Metadata that would have been produced if DefaultSavePlanner had been used to checkpoint state_dict.

_check_box_overlap

Check if two boxes overlap. Tuples are (offset, lengths).

_check_box_bounds

_validate_global_plan

Data#

API#

nemo_automodel.checkpoint._backports.default_planner.logger: logging.Logger#

'getLogger(…)'

nemo_automodel.checkpoint._backports.default_planner.__all__#

['DefaultSavePlanner', 'DefaultLoadPlanner', 'create_default_local_load_plan', 'create_default_globa…

class nemo_automodel.checkpoint._backports.default_planner.DefaultSavePlanner(
flatten_state_dict: bool = True,
flatten_sharded_tensors: bool = True,
dedup_replicated_tensors: Optional[bool] = None,
dedup_save_to_lowest_rank: bool = False,
enable_plan_caching: bool = False,
)[source]#

Bases: torch.distributed.checkpoint.planner.SavePlanner

mappings: torch.distributed.checkpoint._nested_dict.FLATTEN_MAPPING#

None

set_up_planner(
state_dict: torch.distributed.checkpoint.metadata.STATE_DICT_TYPE,
storage_meta: Optional[torch.distributed.checkpoint.metadata.StorageMeta] = None,
is_coordinator: bool = False,
) None[source]#
create_local_plan() torch.distributed.checkpoint.planner.SavePlan[source]#
_dedup_save_plans(
all_plans: list[torch.distributed.checkpoint.planner.SavePlan],
) list[torch.distributed.checkpoint.planner.SavePlan][source]#
_create_global_plan(
all_plans: list[torch.distributed.checkpoint.planner.SavePlan],
) tuple[list[torch.distributed.checkpoint.planner.SavePlan], torch.distributed.checkpoint.metadata.Metadata][source]#
_create_global_plan_with_caching(
all_plans: list[torch.distributed.checkpoint.planner.SavePlan],
) tuple[list[torch.distributed.checkpoint.planner.SavePlan], list[torch.distributed.checkpoint.planner.SavePlan], torch.distributed.checkpoint.metadata.Metadata][source]#

Create the global plan with caching. Returns a tuple of (global_plan_delta, global_plan, metadata).

create_global_plan(
all_plans: list[torch.distributed.checkpoint.planner.SavePlan],
) tuple[list[torch.distributed.checkpoint.planner.SavePlan], torch.distributed.checkpoint.metadata.Metadata][source]#
_finish_plan_with_caching(
new_plan: torch.distributed.checkpoint.planner.SavePlan,
) torch.distributed.checkpoint.planner.SavePlan[source]#
finish_plan(
new_plan: torch.distributed.checkpoint.planner.SavePlan,
) torch.distributed.checkpoint.planner.SavePlan[source]#
resolve_data(
write_item: torch.distributed.checkpoint.planner.WriteItem,
) Union[torch.Tensor, io.BytesIO][source]#
lookup_object(
index: torch.distributed.checkpoint.metadata.MetadataIndex,
) Any[source]#

Extension from the planner interface to make it easy to extend the default planner.

transform_object(
write_item: torch.distributed.checkpoint.planner.WriteItem,
object: Any,
)[source]#

Extension from the planner interface to make it easy to extend the default planner.
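A minimal usage sketch for the save path. The nested state_dict and the checkpoint path are illustrative assumptions; with a recent torch, torch.distributed.checkpoint.save accepts a custom planner:

```python
import torch
import torch.distributed.checkpoint as dcp

from nemo_automodel.checkpoint._backports.default_planner import DefaultSavePlanner

# Hypothetical nested state_dict; flatten_state_dict=True flattens the nesting
# into dotted keys before the writes are planned.
state_dict = {"model": {"weight": torch.randn(4, 4)}, "step": 10}

dcp.save(
    state_dict,
    storage_writer=dcp.FileSystemWriter("/tmp/example_dcp_ckpt"),  # illustrative path
    planner=DefaultSavePlanner(flatten_state_dict=True),
)
```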

class nemo_automodel.checkpoint._backports.default_planner.DefaultLoadPlanner(
flatten_state_dict: bool = True,
flatten_sharded_tensors: bool = True,
allow_partial_load: bool = False,
)[source]#

Bases: torch.distributed.checkpoint.planner.LoadPlanner

DefaultLoadPlanner that adds multiple features on top of LoadPlanner.

In particular it adds the following:

- flatten_state_dict: handles a state_dict with nested dicts.
- flatten_sharded_tensors: for FSDP in 2D parallel mode.
- allow_partial_load: if False, raises a runtime error when a key is present in state_dict but not in the checkpoint.

Initialization

original_state_dict: torch.distributed.checkpoint.metadata.STATE_DICT_TYPE#

None

mappings: torch.distributed.checkpoint._nested_dict.FLATTEN_MAPPING#

None

set_up_planner(
state_dict: torch.distributed.checkpoint.metadata.STATE_DICT_TYPE,
metadata: Optional[torch.distributed.checkpoint.metadata.Metadata] = None,
is_coordinator: bool = False,
) None[source]#
create_local_plan() torch.distributed.checkpoint.planner.LoadPlan[source]#
create_global_plan(
global_plan: list[torch.distributed.checkpoint.planner.LoadPlan],
) list[torch.distributed.checkpoint.planner.LoadPlan][source]#
finish_plan(
new_plan: torch.distributed.checkpoint.planner.LoadPlan,
) torch.distributed.checkpoint.planner.LoadPlan[source]#
load_bytes(
read_item: torch.distributed.checkpoint.planner.ReadItem,
value: io.BytesIO,
) None[source]#
resolve_tensor(
read_item: torch.distributed.checkpoint.planner.ReadItem,
)[source]#
commit_tensor(
read_item: torch.distributed.checkpoint.planner.ReadItem,
tensor: torch.Tensor,
) None[source]#
lookup_tensor(
index: torch.distributed.checkpoint.metadata.MetadataIndex,
) torch.Tensor[source]#

Extension from the planner interface to make it easy to extend the default planner.

transform_tensor(
read_item: torch.distributed.checkpoint.planner.ReadItem,
tensor: torch.Tensor,
)[source]#

Extension from the planner interface to make it easy to extend the default planner.
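A corresponding load-path sketch. The destination tensors and path are illustrative assumptions; with allow_partial_load=True, keys missing from the checkpoint are skipped rather than raising:

```python
import torch
import torch.distributed.checkpoint as dcp

from nemo_automodel.checkpoint._backports.default_planner import DefaultLoadPlanner

# Pre-allocated destination tensors; dcp.load reads checkpoint data into them in place.
state_dict = {"model": {"weight": torch.empty(4, 4)}, "step": 0}

dcp.load(
    state_dict,
    storage_reader=dcp.FileSystemReader("/tmp/example_dcp_ckpt"),  # illustrative path
    planner=DefaultLoadPlanner(flatten_state_dict=True, allow_partial_load=True),
)
```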

class nemo_automodel.checkpoint._backports.default_planner._EmptyStateDictLoadPlanner(keys=None, *args, **kwargs)[source]#

Bases: nemo_automodel.checkpoint._backports.default_planner.DefaultLoadPlanner

Extension of DefaultLoadPlanner that rebuilds the state_dict from the saved metadata. Useful for loading a state_dict without first initializing a model, such as when converting a DCP checkpoint into a Torch save file.

N.B.: state_dict must be an empty dictionary when used with this LoadPlanner.

Warning: Because the entire state dict is initialized, it is recommended to use this LoadPlanner only on a single rank or process to avoid OOM.

Initialization

_should_include_key(
key: str,
metadata: torch.distributed.checkpoint.metadata.Metadata,
) bool[source]#
set_up_planner(
state_dict: torch.distributed.checkpoint.metadata.STATE_DICT_TYPE,
metadata: Optional[torch.distributed.checkpoint.metadata.Metadata] = None,
is_coordinator: bool = False,
) None[source]#
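This planner is what allows converting a DCP checkpoint without instantiating a model. The sketch below mirrors the pattern used by torch.distributed.checkpoint.format_utils.dcp_to_torch_save, assuming the private _load_state_dict helper is available in your torch version; paths are illustrative:

```python
import torch
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict

from nemo_automodel.checkpoint._backports.default_planner import _EmptyStateDictLoadPlanner

# state_dict must start out empty; the planner rebuilds it from the
# checkpoint metadata, so no model has to be initialized first.
state_dict: dict = {}

_load_state_dict(
    state_dict,
    storage_reader=FileSystemReader("/tmp/example_dcp_ckpt"),  # illustrative path
    planner=_EmptyStateDictLoadPlanner(),
    no_dist=True,  # single process, per the OOM warning above
)

torch.save(state_dict, "/tmp/example_ckpt.pt")  # illustrative path
```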
nemo_automodel.checkpoint._backports.default_planner.create_default_local_load_plan(
state_dict: dict[str, Any],
metadata: torch.distributed.checkpoint.metadata.Metadata,
strict: bool = True,
) torch.distributed.checkpoint.planner.LoadPlan[source]#
nemo_automodel.checkpoint._backports.default_planner.create_default_global_load_plan(
all_plans: list[torch.distributed.checkpoint.planner.LoadPlan],
) list[torch.distributed.checkpoint.planner.LoadPlan][source]#

Create global load plan used by DefaultLoadPlanner.

The default load behavior involves no global coordination, and this function currently does not change the local plans.

nemo_automodel.checkpoint._backports.default_planner.create_default_local_save_plan(
state_dict: dict[str, Any],
is_coordinator: bool,
) torch.distributed.checkpoint.planner.SavePlan[source]#

Create the SavePlan used by DefaultSavePlanner.

On non-coordinator ranks, this function ignores tensors and non-tensor objects, only producing writes for ShardedTensor objects.

On the coordinator rank, it produces writes for all values.

nemo_automodel.checkpoint._backports.default_planner.create_default_global_save_plan(
all_plans: list[torch.distributed.checkpoint.planner.SavePlan],
rewrite_index_hints: bool = True,
) tuple[list[torch.distributed.checkpoint.planner.SavePlan], torch.distributed.checkpoint.metadata.Metadata][source]#

Create the global plan and metadata used by DefaultSavePlanner.

Metadata is produced by concatenating the metadata of all WriteItems from the supplied plans.

The only global planning change is to update index hints in all MetadataIndex objects if rewrite_index_hints is True.
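For illustration, a single-process sketch of how the local and global save-planning helpers fit together (the state_dict contents are assumptions):

```python
import torch

from nemo_automodel.checkpoint._backports.default_planner import (
    create_default_global_save_plan,
    create_default_local_save_plan,
)

# Flat state_dict for illustration; in a real save, DefaultSavePlanner
# flattens nested dicts before these helpers run.
state_dict = {"weight": torch.randn(4, 4), "step": 10}

# Each rank builds its own local plan; with a single process we are the coordinator.
local_plan = create_default_local_save_plan(state_dict, is_coordinator=True)

# The coordinator merges all local plans into per-rank global plans plus the
# checkpoint Metadata that is written alongside the data.
global_plans, metadata = create_default_global_save_plan([local_plan])

print(len(global_plans), list(metadata.state_dict_metadata.keys()))
```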

nemo_automodel.checkpoint._backports.default_planner._create_default_local_metadata(
state_dict: torch.distributed.checkpoint.metadata.STATE_DICT_TYPE,
) torch.distributed.checkpoint.metadata.Metadata[source]#

Return the Metadata that would have been produced if DefaultSavePlanner had been used to checkpoint state_dict.

nemo_automodel.checkpoint._backports.default_planner._check_box_overlap(
box0: torch.distributed.checkpoint.metadata.ChunkStorageMetadata,
box1: torch.distributed.checkpoint.metadata.ChunkStorageMetadata,
) bool[source]#

Check if two boxes overlap. Tuples are (offset, lengths).
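A small sketch of the overlap check on two illustrative 2-D chunks; boxes overlap only if their index ranges intersect along every dimension:

```python
import torch
from torch.distributed.checkpoint.metadata import ChunkStorageMetadata

from nemo_automodel.checkpoint._backports.default_planner import _check_box_overlap

# Two 4x4 chunks of a larger tensor: the first starts at (0, 0), the second at (2, 2),
# so their index ranges intersect along both dimensions.
box0 = ChunkStorageMetadata(offsets=torch.Size([0, 0]), sizes=torch.Size([4, 4]))
box1 = ChunkStorageMetadata(offsets=torch.Size([2, 2]), sizes=torch.Size([4, 4]))

assert _check_box_overlap(box0, box1)  # ranges [0, 4) and [2, 6) intersect in each dim

# Shift the second chunk so it starts where the first ends along dim 0: no overlap.
box2 = ChunkStorageMetadata(offsets=torch.Size([4, 0]), sizes=torch.Size([4, 4]))
assert not _check_box_overlap(box0, box2)
```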

nemo_automodel.checkpoint._backports.default_planner._check_box_bounds(
outer_box_size: torch.Size,
inner_box: torch.distributed.checkpoint.metadata.ChunkStorageMetadata,
) bool[source]#
nemo_automodel.checkpoint._backports.default_planner._validate_global_plan(
global_plan: list[torch.distributed.checkpoint.planner.SavePlan],
metadata: torch.distributed.checkpoint.metadata.Metadata,
) bool[source]#