# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# taken and edited from
# https://github.com/pytorch/pytorch/blob/c13e725edd8dd21406c629bf625f2d6c59ceedd1/torch/distributed/checkpoint/default_planner.py
# pylint: disable=missing-class-docstring, missing-function-docstring,line-too-long
import dataclasses
import io
import logging
import operator
from collections import ChainMap
from functools import reduce
from typing import Any, Optional, Union, cast
import torch
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint._dedup_save_plans import dedup_save_plans
from torch.distributed.checkpoint._nested_dict import (
FLATTEN_MAPPING,
flatten_state_dict,
)
from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.metadata import (
STATE_DICT_TYPE,
STORAGE_TYPES,
BytesStorageMetadata,
ChunkStorageMetadata,
Metadata,
MetadataIndex,
StorageMeta,
TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
LoadPlan,
LoadPlanner,
ReadItem,
SavePlan,
SavePlanner,
WriteItem,
WriteItemType,
)
from torch.distributed.checkpoint.planner_helpers import (
_compare_save_plans,
_create_default_metadata_only_plan,
_create_read_items,
_create_write_items,
_init_state_dict,
_merge_delta_local_plans,
)
from torch.distributed.checkpoint.utils import find_state_dict_object
from torch.distributed.tensor import DTensor
from nemo_automodel.checkpoint._backports import _version
from nemo_automodel.checkpoint._backports.planner_helpers import _contains_usable_plan
logger: logging.Logger = logging.getLogger(__name__)
__all__ = [
"DefaultSavePlanner",
"DefaultLoadPlanner",
"create_default_local_load_plan",
"create_default_global_load_plan",
"create_default_local_save_plan",
"create_default_global_save_plan",
]
# TODO: Update docstrings for default_planner.py
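# Illustrative usage sketch (not part of the upstream module; `model` and the "./ckpt"
# path are placeholders, and a torch.distributed process group is assumed to be
# initialized): these planners are passed to torch.distributed.checkpoint's save/load
# entry points.
#
#   import torch.distributed.checkpoint as dcp
#
#   state_dict = {"model": model.state_dict()}
#   dcp.save(state_dict, checkpoint_id="./ckpt", planner=DefaultSavePlanner())
#   dcp.load(state_dict, checkpoint_id="./ckpt", planner=DefaultLoadPlanner())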
class DefaultSavePlanner(SavePlanner):
mappings: FLATTEN_MAPPING
def __init__(
self,
flatten_state_dict: bool = True,
flatten_sharded_tensors: bool = True,
dedup_replicated_tensors: Optional[bool] = None,
dedup_save_to_lowest_rank: bool = False,
enable_plan_caching: bool = False,
) -> None:
self.flatten_state_dict = flatten_state_dict
self.flatten_sharded_tensors = flatten_sharded_tensors
self.mappings = {}
self.dedup_save_to_lowest_rank = dedup_save_to_lowest_rank
if dedup_replicated_tensors is not None:
logger.warning(
"DefaultSavePlanner's `dedup_replicated_tensors` argument is being "
"deprecated, and no longer has any effect. Please remove this argument "
"from your call."
)
self._cached_plans_key: str = self.__class__.__name__
self._enable_plan_caching = enable_plan_caching
def set_up_planner(
self,
state_dict: STATE_DICT_TYPE,
storage_meta: Optional[StorageMeta] = None,
is_coordinator: bool = False,
) -> None:
if self.flatten_state_dict:
state_dict, self.mappings = flatten_state_dict(state_dict)
if self.flatten_sharded_tensors:
state_dict = _flatten_sharded_tensors(state_dict)
self.state_dict = state_dict
self.is_coordinator = is_coordinator
def create_local_plan(self) -> SavePlan:
plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
if self.flatten_state_dict:
plan = dataclasses.replace(plan, planner_data=self.mappings)
self.plan = plan
if self._enable_plan_caching:
# If plans are equal, we can skip sending the plan to the coordinator.
if self._cached_plans_key in SavePlanner._cached_save_plan and _compare_save_plans(
plan, SavePlanner._cached_save_plan[self._cached_plans_key]
):
logger.info("No change in the local plan. Skipping sending the plan to the coordinator")
return SavePlan([], usable=False)
else:
SavePlanner._cached_save_plan[self._cached_plans_key] = plan
return self.plan
def _dedup_save_plans(self, all_plans: list[SavePlan]) -> list[SavePlan]:
return dedup_save_plans(all_plans, self.dedup_save_to_lowest_rank)
def _create_global_plan(self, all_plans: list[SavePlan]) -> tuple[list[SavePlan], Metadata]:
deduped_plans = self._dedup_save_plans(all_plans)
global_plan, metadata = create_default_global_save_plan(deduped_plans)
if self.flatten_state_dict:
            # dict union (`|`) does not work for Python 3.8 or older versions:
            # merged_mappings = reduce(
            #     lambda x, y: x | y, (p.planner_data for p in global_plan)
            # )
planner_data_dict = [p.planner_data for p in global_plan]
merged_mappings = dict(ChainMap(*planner_data_dict))
metadata = dataclasses.replace(metadata, planner_data=merged_mappings)
if not _validate_global_plan(global_plan, metadata):
raise ValueError("Failed to validate global plan")
return global_plan, metadata
def _create_global_plan_with_caching(
self, all_plans: list[SavePlan]
) -> tuple[list[SavePlan], list[SavePlan], Metadata]:
"""
Create global plan with caching.
Returns a tuple of global_plan_delta, global_plan, metadata.
"""
global_plan_delta: list[SavePlan] = []
if self._cached_plans_key not in SavePlanner._cached_all_plans:
# Case 1: If the plans are not cached, the cache will be hydrated with the
# all_plans, global_plans (Deduped), and metadata.
# Cache the original all_plans
SavePlanner._cached_all_plans[self._cached_plans_key] = all_plans
global_plan, metadata = self._create_global_plan(all_plans)
# Cache the deduped and validated global_plan
SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan
# Cache the metadata
SavePlanner._cached_metadata[self._cached_plans_key] = metadata
# If plans are not cached, global_plan delta will be the same as global plan.
return global_plan, global_plan, metadata
# Case 2: Plans are cached
if not _contains_usable_plan(all_plans):
# Case 2.1: Plans are cached and the local plans have NOT changed (No usable plans).
# Global plan delta will be empty plans to avoid the collective overhead.
# We can reuse the deduped global plan and metadata from the cache directly.
global_plan_delta = [SavePlan([], usable=False)] * len(all_plans)
global_plan = SavePlanner._cached_global_plan[self._cached_plans_key]
metadata = SavePlanner._cached_metadata[self._cached_plans_key]
else:
# Case 2.2: Plans are cached but the local plans have changed.
# We will merge the changed local plans with the cached local plans.
# Updated plans will overwrite the cached plans. New global plan and metadata will be created and cached.
# Global plan delta will be created by comparing the new global plan with the cached global plan.
            # Only the global plan delta (the updated plans) will be scattered back to the
            # ranks to avoid the collective overhead.
merged_plans = _merge_delta_local_plans(SavePlanner._cached_all_plans[self._cached_plans_key], all_plans)
# Cache the updated local plans
SavePlanner._cached_all_plans[self._cached_plans_key] = merged_plans
global_plan, metadata = self._create_global_plan(merged_plans)
if self._cached_plans_key in self._cached_global_plan:
for cached_plan, new_plan in zip(SavePlanner._cached_global_plan[self._cached_plans_key], global_plan):
if _compare_save_plans(cached_plan, new_plan):
global_plan_delta.append(SavePlan([], usable=False))
else:
global_plan_delta.append(new_plan)
# Cache the new global plan and the metadata
SavePlanner._cached_global_plan[self._cached_plans_key] = global_plan
SavePlanner._cached_metadata[self._cached_plans_key] = metadata
return global_plan_delta, global_plan, metadata
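    # Illustrative sketch (assumes the same save setup as in the module-level example;
    # the "./step_N" paths are placeholders): with plan caching enabled, the first save
    # hydrates the class-level caches, and subsequent saves of an unchanged state_dict
    # return empty "unusable" delta plans so the cached global plan and metadata are reused.
    #
    #   planner = DefaultSavePlanner(enable_plan_caching=True)
    #   dcp.save(state_dict, checkpoint_id="./step_0", planner=planner)  # populates the cache
    #   dcp.save(state_dict, checkpoint_id="./step_1", planner=planner)  # reuses cached plans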
def create_global_plan(self, all_plans: list[SavePlan]) -> tuple[list[SavePlan], Metadata]:
global_plan_delta: list[SavePlan] = []
if self._enable_plan_caching:
# If the plans are cached, we only need to send the global plan delta to be scattered
# across ranks. Ranks will use the cached final plans instead.
(
global_plan_delta,
global_plan,
metadata,
) = self._create_global_plan_with_caching(all_plans)
else:
global_plan, metadata = self._create_global_plan(all_plans)
            # If caching is not enabled, the global plan delta will always be the same as
            # the new global plan.
global_plan_delta = global_plan
self.global_plan = global_plan
self.metadata = metadata
return global_plan_delta, self.metadata
def _finish_plan_with_caching(self, new_plan: SavePlan) -> SavePlan:
finished_plan: SavePlan = new_plan
if not new_plan.usable:
finished_plan = SavePlanner._cached_final_save_plan[self._cached_plans_key]
else:
finished_plan = new_plan
SavePlanner._cached_final_save_plan[self._cached_plans_key] = new_plan
return finished_plan
def finish_plan(self, new_plan: SavePlan) -> SavePlan:
finished_plan: SavePlan = new_plan
if self._enable_plan_caching:
finished_plan = self._finish_plan_with_caching(new_plan)
self.plan = finished_plan
return self.plan
def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
        obj = self.lookup_object(write_item.index)
        return self.transform_object(write_item, obj)
def lookup_object(self, index: MetadataIndex) -> Any:
"""Extension from the planner interface to make it easy to extend the default planner."""
return find_state_dict_object(self.state_dict, index)
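    # NOTE: `resolve_data` above calls `transform_object`, which is missing from this
    # excerpt. The sketch below follows the upstream PyTorch default_planner that this
    # file is derived from (BYTE_IO items are serialized with torch.save); treat it as
    # an assumed restoration rather than verified backported code.
    def transform_object(self, write_item: WriteItem, obj: Any) -> Any:
        """Extension from the planner interface to make it easy to extend the default planner."""
        if write_item.type == WriteItemType.BYTE_IO:
            buffer = io.BytesIO()
            torch.save(obj, buffer)
            obj = buffer
        return obj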
class DefaultLoadPlanner(LoadPlanner):
"""
DefaultLoadPlanner that adds multiple features on top of LoadPlanner.
In particular it adds the following:
flatten_state_dict: Handle state_dict with nested dicts
flatten_sharded_tensors: For FSDP in 2D parallel mode
allow_partial_load: If False, will raise a runtime error if a key is present in state_dict, but not in the checkpoint.
"""
original_state_dict: STATE_DICT_TYPE
mappings: FLATTEN_MAPPING
def __init__(
self,
flatten_state_dict: bool = True,
flatten_sharded_tensors: bool = True,
allow_partial_load: bool = False,
) -> None:
self.flatten_state_dict = flatten_state_dict
self.flatten_sharded_tensors = flatten_sharded_tensors
self.original_state_dict = {}
self.mappings = {}
self.allow_partial_load = allow_partial_load
def set_up_planner(
self,
state_dict: STATE_DICT_TYPE,
metadata: Optional[Metadata] = None,
is_coordinator: bool = False,
) -> None:
_init_state_dict(state_dict)
self.original_state_dict = state_dict
if self.flatten_sharded_tensors:
state_dict = _flatten_sharded_tensors(state_dict)
if self.flatten_state_dict:
state_dict, self.mappings = flatten_state_dict(state_dict)
self.state_dict = state_dict
self.metadata = metadata
self.is_coordinator = is_coordinator
def create_local_plan(self) -> LoadPlan:
assert self.metadata is not None
if self.flatten_state_dict:
            # To support checkpoints saved before v2.4, we have to determine
            # whether missing keys are due to the old checkpoint format.
            # The contract is:
            # 1. There are four cases when we find a missing key.
            #    1.1 Actually missing key, and allow_partial_load is False
            #    1.2 Actually missing key, and allow_partial_load is True
            #    1.3 Old checkpoint, and allow_partial_load is False
            #    1.4 Old checkpoint, and allow_partial_load is True
            # 2. If we find a missing key, we first convert the keys back to
            #    the key format of v2.3.
            # 3. If the previously missing keys appear among the v2.3 keys, we
            #    assume this is an old checkpoint.
            # 4. Pass the state_dict to `create_default_local_load_plan()`,
            #    which contains the logic to check missing keys against
            #    allow_partial_load.
            # So for cases 1.2 and 1.4, we delegate the allow_partial_load check to
            # `create_default_local_load_plan()`. The logic here only determines whether
            # the checkpoint belongs to v2.3 (or earlier) or v2.4 (or later).
current_keys = set(self.state_dict.keys())
load_keys = set(self.metadata.state_dict_metadata.keys())
missing_keys = load_keys - current_keys
if missing_keys:
_version._derived_version = "2_3"
old_state_dict, old_mappings = flatten_state_dict(self.original_state_dict)
old_keys = set(old_state_dict.keys())
if old_keys & missing_keys:
self.state_dict, self.mappings = old_state_dict, old_mappings
# _derived_version is only used by flatten_state_dict now.
# Set it back to None so that later we can save to a new version.
_version._derived_version = None
return create_default_local_load_plan(self.state_dict, self.metadata, not self.allow_partial_load)
def create_global_plan(self, global_plan: list[LoadPlan]) -> list[LoadPlan]:
return create_default_global_load_plan(global_plan)
def finish_plan(self, new_plan: LoadPlan) -> LoadPlan:
return new_plan
def load_bytes(self, read_item: ReadItem, value: io.BytesIO) -> None:
if self.flatten_state_dict:
set_element(
self.original_state_dict,
self.mappings[read_item.dest_index.fqn],
torch.load(value, weights_only=False),
)
else:
self.state_dict[read_item.dest_index.fqn] = torch.load(value, weights_only=False)
def resolve_tensor(self, read_item: ReadItem):
tensor = self.lookup_tensor(read_item.dest_index)
return self.transform_tensor(read_item, tensor)
def commit_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> None:
pass
def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor:
"""Extension from the planner interface to make it easy to extend the default planner."""
return find_state_dict_object(self.state_dict, index)
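    # NOTE: `resolve_tensor` above calls `transform_tensor`, which is missing from this
    # excerpt (note that `narrow_tensor_by_index` is imported but otherwise unused).
    # The sketch below follows the upstream PyTorch default_planner; treat it as an
    # assumed restoration rather than verified backported code.
    def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor) -> torch.Tensor:
        """Extension from the planner interface to make it easy to extend the default planner."""
        return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths)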
class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
"""
    Extension of DefaultLoadPlanner that rebuilds the state_dict from the saved metadata.
    Useful for loading a state_dict without first initializing a model, such as
    when converting a DCP checkpoint into a Torch save file.

    N.B. ``state_dict`` must be an empty dictionary when used with this LoadPlanner.

    .. warning::
        Because the entire state dict is initialized, it's recommended to only use
        this LoadPlanner on a single rank or process to avoid OOM.
    """
def __init__(self, keys=None, *args, **kwargs):
self.keys = keys
super().__init__(*args, **kwargs)
def _should_include_key(self, key: str, metadata: Metadata) -> bool:
if self.keys is None:
return True
        if key in self.keys:
            return True
unflattened_keys: list[str] = []
planner_data = metadata.planner_data.get(key)
for unflattened_key in planner_data:
if unflattened_keys:
unflattened_keys.append(".".join([unflattened_keys[-1], str(unflattened_key)]))
else:
unflattened_keys.append(unflattened_key)
if any(unflattened_key in self.keys for unflattened_key in unflattened_keys):
return True
return False
def set_up_planner(
self,
state_dict: STATE_DICT_TYPE,
metadata: Optional[Metadata] = None,
is_coordinator: bool = False,
) -> None:
assert not state_dict
assert metadata is not None
# rebuild the state dict from the metadata
for k, v in metadata.state_dict_metadata.items():
if not self._should_include_key(k, metadata):
continue
if isinstance(v, TensorStorageMetadata):
v = torch.empty(v.size, dtype=v.properties.dtype) # type: ignore[assignment]
if metadata.planner_data is not None and k in metadata.planner_data:
set_element(state_dict, metadata.planner_data[k], v)
else:
state_dict[k] = v
super().set_up_planner(state_dict, metadata, is_coordinator)
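# Illustrative sketch (hypothetical paths; single process): _EmptyStateDictLoadPlanner
# is the kind of planner used by utilities such as
# torch.distributed.checkpoint.format_utils.dcp_to_torch_save, which rebuilds the
# state_dict from checkpoint metadata and then writes a regular torch.save file.
#
#   from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
#   dcp_to_torch_save("./ckpt", "./ckpt.pt")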
def create_default_local_load_plan(state_dict: dict[str, Any], metadata: Metadata, strict: bool = True) -> LoadPlan:
    """
    Create the ``LoadPlan`` used by DefaultLoadPlanner.

    It produces one read item per value in ``state_dict``, using the metadata in ``metadata``.
    The default behavior is to match keys exactly between state_dict and metadata.
    It handles resharding by issuing multiple read requests against storage in order to match
    load requirements.
    """
    requests = []
for fqn, obj in state_dict.items():
        # ignore state_dict keys which do not exist in the checkpoint metadata if strict=False
if fqn not in metadata.state_dict_metadata:
if strict:
raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
else:
continue
md = metadata.state_dict_metadata[fqn]
if isinstance(md, TensorStorageMetadata) and getattr(obj, "size", None) is not None and md.size != obj.size():
raise ValueError(
f"Size mismatch between saved {md.size} and current: {obj.size()} for {fqn}",
)
        # Since DTensor supports submeshes, add an extra check to ensure _create_read_items()
        # is only called when the current rank is part of the mesh for the corresponding DTensor.
if isinstance(obj, DTensor):
if obj.device_mesh.get_coordinate() is not None:
requests += _create_read_items(fqn, md, obj)
else:
requests += _create_read_items(fqn, md, obj)
return LoadPlan(requests)
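# Illustrative sketch (hypothetical checkpoint path): `strict` maps to DefaultLoadPlanner's
# `allow_partial_load` flag (strict == not allow_partial_load). With allow_partial_load=True,
# keys missing from the checkpoint are skipped and keep their initialized values instead of
# raising.
#
#   planner = DefaultLoadPlanner(allow_partial_load=True)
#   dcp.load(state_dict, checkpoint_id="./ckpt", planner=planner)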
def create_default_global_load_plan(
all_plans: list[LoadPlan],
) -> list[LoadPlan]:
"""
Create global load plan used by DefaultLoadPlanner.
    The default load behavior involves no global coordination, and this function
    currently doesn't change the local plans.
"""
return all_plans
def create_default_local_save_plan(state_dict: dict[str, Any], is_coordinator: bool) -> SavePlan:
"""
Create the ``SavePlan`` used by DefaultSavePlanner.
On non-coordinator ranks, this function ignores tensors and non-tensor objects,
only producing writes for ShardedTensor objects.
On the coordinator rank, produce writes for all values.
"""
requests = []
for fqn, obj in state_dict.items():
        # Since DTensor supports submeshes, add an extra check to ensure _create_write_items()
        # is only called when the current rank is part of the mesh for the corresponding DTensor.
if isinstance(obj, DTensor):
if obj.device_mesh.get_coordinate() is not None:
requests += _create_write_items(fqn, obj)
else:
            # For plain tensors and non-tensor values, add the request on all
            # ranks. The coordinator will decide whether to deduplicate the
            # values based on their keys.
requests += _create_write_items(fqn, obj)
return SavePlan(requests)
def create_default_global_save_plan(
all_plans: list[SavePlan],
rewrite_index_hints: bool = True,
) -> tuple[list[SavePlan], Metadata]:
"""
Create the global plan and metadata used by DefaultSavePlanner.
Metadata is produced by concatenating the metadata of all ``WriteItem`` from the supplied plans.
The only global planning change is to update index hints in all ``MetadataIndex`` objects if
``rewrite_index_hints`` is True.
"""
md: dict[str, STORAGE_TYPES] = {}
new_plans = []
for plan in all_plans:
new_items = []
for item in plan.items:
if not item.type == WriteItemType.SHARD:
assert item.index.fqn not in md
if item.type == WriteItemType.BYTE_IO:
md[item.index.fqn] = BytesStorageMetadata()
new_items.append(item)
else:
assert item.tensor_data is not None
tensor_md = cast(
TensorStorageMetadata,
md.setdefault(
item.index.fqn,
TensorStorageMetadata(
properties=item.tensor_data.properties,
size=item.tensor_data.size,
chunks=[],
),
),
)
new_item = item
if rewrite_index_hints:
new_index = dataclasses.replace(item.index, index=len(tensor_md.chunks))
new_item = dataclasses.replace(item, index=new_index)
new_items.append(new_item)
assert item.tensor_data.chunk is not None, f"""
Cannot create MD for tensor without bounds.
FQN: {item.index.fqn}
"""
tensor_md.chunks.append(item.tensor_data.chunk)
new_plans.append(dataclasses.replace(plan, items=new_items))
return (new_plans, Metadata(md))
def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool:
"""Check if two boxes overlap. Tuples are (offset, lengths)."""
# For each dim of each shard, check if one shard resides on the other
# end of second shard with respect to that dim. As an example for a 2D
# shard, we would check if one shard is above or on the left of the
# other shard.
ndims = len(box0.offsets)
for i in range(ndims):
if box0.offsets[i] >= box1.offsets[i] + box1.sizes[i]:
return False
if box1.offsets[i] >= box0.offsets[i] + box0.sizes[i]:
return False
return True
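# Worked example (constructed here for illustration): two 2x4 chunks that tile a
# 4x4 tensor do not overlap, since each starts exactly where the other ends.
#
#   a = ChunkStorageMetadata(offsets=torch.Size([0, 0]), sizes=torch.Size([2, 4]))
#   b = ChunkStorageMetadata(offsets=torch.Size([2, 0]), sizes=torch.Size([2, 4]))
#   _check_box_overlap(a, b)  # False: b starts at row 2, exactly where a ends
#   c = ChunkStorageMetadata(offsets=torch.Size([1, 0]), sizes=torch.Size([2, 4]))
#   _check_box_overlap(a, c)  # True: row 1 is covered by both a and c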
def _check_box_bounds(outer_box_size: torch.Size, inner_box: ChunkStorageMetadata) -> bool:
for i in range(len(outer_box_size)):
if inner_box.offsets[i] < 0:
return False
if inner_box.sizes[i] < 0:
return False
if inner_box.offsets[i] + inner_box.sizes[i] > outer_box_size[i]:
return False
return True
def _validate_global_plan(global_plan: list[SavePlan], metadata: Metadata) -> bool:
all_good = True
for key, value in metadata.state_dict_metadata.items():
if isinstance(value, BytesStorageMetadata):
continue
if len(value.size) == 0:
continue
chunks_volume = 0
for chunk_idx, chunk0 in enumerate(value.chunks):
# Compute the volume
if not _check_box_bounds(value.size, chunk0):
logger.warning(
"""
key:%s has out of bounds chunk:
tensor-size:%s chunk: %s
""",
key,
value.size,
chunk0,
)
all_good = False
chunks_volume += reduce(operator.mul, chunk0.sizes, 1)
# Check for overlap
for chunk1 in value.chunks[chunk_idx + 1 :]:
if _check_box_overlap(chunk0, chunk1):
logger.warning("key:%s has overlapping chunks: %s %s", key, chunk0, chunk1)
all_good = False
        # Check whether the combined chunks cover the whole tensor
tensor_volume = reduce(operator.mul, value.size, 1)
if chunks_volume != tensor_volume:
logger.warning(
"""
key:%s invalid fill tensor-volume:
%s chunks-volume: %s
""",
key,
tensor_volume,
chunks_volume,
)
all_good = False
return all_good