# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpoint utilities for saving and loading training state.
Provides :func:`save_checkpoint` and :func:`load_checkpoint` for persisting
and restoring model weights, optimizer/scheduler/scaler state, and arbitrary
metadata. Supports local filesystems and remote stores via ``fsspec``.
When models are wrapped with FSDP or use DTensor/ShardTensor parameters,
:func:`save_checkpoint` and :func:`load_checkpoint` automatically use
PyTorch's distributed checkpoint state-dict APIs to gather and scatter
model and optimizer state. In this *distributed* mode **all ranks** must
call the functions (the collective operations inside the DCP helpers require
it), while only rank 0 performs actual file I/O.
"""
import io
import os
import re
import tarfile
import zipfile
from pathlib import Path, PurePath
from typing import Any
import fsspec
import fsspec.utils
import torch
from torch.amp import GradScaler
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
get_optimizer_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
)
from torch.distributed.fsdp import FSDPModule, ShardingStrategy
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.tensor import DTensor, distribute_tensor
from torch.optim.lr_scheduler import LRScheduler
import physicsnemo
from physicsnemo.core.filesystem import LOCAL_CACHE, _download_cached
from physicsnemo.core.module import Module
from physicsnemo.distributed import DistributedManager
from physicsnemo.utils.capture import _StaticCapture
from physicsnemo.utils.logging import PythonLogger
# Module-level logger shared by the checkpoint helpers in this file for
# warnings (e.g. uninitialized DistributedManager, missing output directory)
# and success messages after files are written.
checkpoint_logging = PythonLogger("checkpoint")
# ---------------------------------------------------------------------------
# Distributed-model detection helpers
# ---------------------------------------------------------------------------
def _is_distributed_model(model: torch.nn.Module) -> bool:
"""
A simple helper function to determine whether to save/load using DCP or not.
Return ``True`` when *model* is FSDP-wrapped or has DTensor params.
FSDP: FSDP1
FSDPModule: FSDP2 (fully_shard)
"""
if isinstance(model, (FSDP, FSDPModule)):
return True
return any(isinstance(p, DTensor) for p in model.parameters())
def _unwrap_ddp_compile(
model: torch.nn.Module, loading: bool = False
) -> torch.nn.Module:
"""
Strip DDP / DataParallel / ``torch.compile`` wrappers, keep FSDP.
"""
if isinstance(
model,
(torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel),
):
model = model.module
if isinstance(model, torch._dynamo.eval_frame.OptimizedModule):
if loading:
checkpoint_logging.warning(
f"Model {type(model._orig_mod).__name__} is already compiled, "
"consider loading first and then compiling."
)
model = model._orig_mod
return model
def _unwrap_fsdp(model: torch.nn.Module) -> torch.nn.Module:
    """
    Peel a single FSDP1 wrapper, if present, and return the wrapped module.

    Only ``FullyShardedDataParallel`` (FSDP1) instances are unwrapped; any
    other module, including FSDP2 ``FSDPModule`` instances, is returned
    unchanged.
    """
    return model.module if isinstance(model, FSDP) else model
def _unwrapped_class_name(model: torch.nn.Module) -> str:
    """
    Return the user-facing class name of *model*, ignoring FSDP wrappers.

    An FSDP1 wrapper is peeled via :func:`_unwrap_fsdp`. FSDP2's
    ``fully_shard`` rebinds ``model.__class__`` to a generated
    ``FSDP{ClassName}`` subclass with bases ``(FSDPModule, original_cls)``.
    In that case the MRO is scanned for the first class that is neither an
    ``FSDPModule`` subclass nor ``torch.nn.Module`` / ``object``. This keeps
    an FSDP2 model from being saved under the synthetic name (e.g.
    ``FSDPFullyConnected.mdlus`` instead of ``FullyConnected.mdlus``).
    """
    target = _unwrap_fsdp(model)
    if not isinstance(target, FSDPModule):
        return type(target).__name__
    ignored_bases = (torch.nn.Module, object)
    user_classes = (
        klass
        for klass in type(target).__mro__
        if not issubclass(klass, FSDPModule) and klass not in ignored_bases
    )
    # Fall back to the (synthetic) runtime class if nothing else qualifies.
    return next(user_classes, type(target)).__name__
def _cpu_offload_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]:
    """
    Return a copy of *state_dict* with every tensor moved to CPU.

    Nested dicts are copied recursively. Every other value, including lists
    and tuples (whose contents are not inspected), is carried over by
    reference.
    """

    def _offload(value: Any) -> Any:
        if isinstance(value, torch.Tensor):
            return value.cpu()
        if isinstance(value, dict):
            return _cpu_offload_state_dict(value)
        return value

    return {key: _offload(value) for key, value in state_dict.items()}
def _materialize_dtensors_to_full(state_dict: dict[str, Any]) -> dict[str, Any]:
"""Replace DTensor entries with their fully-gathered plain tensor.
DCP's ``get_model_state_dict`` / ``get_optimizer_state_dict`` with
``StateDictOptions(full_state_dict=True)`` is documented to gather
DTensor shards into full plain tensors, but for FSDP2 (``fully_shard``)
models on some PyTorch versions the returned dict still contains
DTensors whose ``_local_tensor`` is just the local shard. When that
output is then serialized via ``torch.save`` and reloaded on a model
with a different mesh shape (e.g. ``(ddp=N,)`` -> ``(ddp=1, domain=N)``),
only rank 0's local shard survives the round-trip -- producing a
silent half-sized parameter and a ``size mismatch`` ``RuntimeError``
on ``load_state_dict``.
This helper compensates: every rank calls ``DTensor.full_tensor()``
(a collective all-gather) on any remaining DTensors, so the resulting
dict on every rank contains only plain tensors with the full data.
Without this fix, ``test_checkpointing[None-1-1-1-1-False-unet]`` fails
with ``RuntimeError: found no DeviceMesh from dtensor args for
c10d::broadcast_`` when reloading a checkpoint saved on a different mesh.
**Collective** -- must be called on the same keys in the same order
on every rank.
"""
out: dict[str, Any] = {}
for k, v in state_dict.items():
if isinstance(v, DTensor):
out[k] = v.full_tensor()
elif isinstance(v, dict):
out[k] = _materialize_dtensors_to_full(v)
else:
out[k] = v
return out
def _force_standard_contiguous(state_dict: dict[str, Any]) -> dict[str, Any]:
"""
Before broadcasting model/optim state from rank 0,
converts channels_last tensors to standard NCHW layout via .contiguous().
This is necessary because DCP's broadcast_from_rank0 does not handle channels_last tensors (torch bug)
Needed for both FSDP1 + FSDP2.
"""
out: dict[str, Any] = {}
for k, v in state_dict.items():
if isinstance(v, torch.Tensor) and not isinstance(v, DTensor) and v.dim() > 0:
out[k] = v.contiguous()
elif isinstance(v, dict):
out[k] = _force_standard_contiguous(v)
else:
out[k] = v
return out
def _get_dtensor_param_placements(
model: torch.nn.Module,
) -> dict[str, tuple[Any, tuple[Any, ...]]]:
"""
Builds a map {param_name: (device_mesh, placements)} for every DTensor param.
**Collective** — all ranks must call this together.
"""
native_sd = get_model_state_dict(model, options=StateDictOptions())
info: dict[str, tuple[Any, tuple[Any, ...]]] = {}
for name, value in native_sd.items():
if isinstance(value, DTensor):
info[name] = (value.device_mesh, tuple(value.placements))
return info
def _needs_dcp_broadcast_bypass(
model: torch.nn.Module,
dtensor_plc: dict[str, tuple[Any, tuple[Any, ...]]],
) -> bool:
"""
Return ``True`` when loading should bypass DCP's ``broadcast_from_rank0``.
Use the explicit ``broadcast_object_list`` /
``_redistribute_sd_for_dtensor`` path instead for:
* Plain modules with user-managed DTensor params (domain mesh)
* FSDP1 + NO_SHARD
* FSDP2 with a degenerate mesh axis (size == 1)
* FSDP2 with no materialized DTensor params (e.g. 1-GPU)
Without this fix, loading a checkpoint with FSDP2 on 1-GPU crashes with
``KeyError: 'state.0.step'`` or ``RuntimeError: found no DeviceMesh`` depending on the mesh configuration.
FSDP: FSDP1
FSDPModule: FSDP2
"""
if isinstance(model, FSDPModule):
if not dtensor_plc:
return True
return any(any(s == 1 for s in mesh.shape) for mesh, _ in dtensor_plc.values())
if not dtensor_plc:
return False
if not isinstance(model, FSDP):
return True
if model.sharding_strategy == ShardingStrategy.NO_SHARD:
return True
return False
def _redistribute_sd_for_dtensor(
placements: dict[str, tuple[Any, tuple[Any, ...]]],
state_dict: dict[str, Any],
) -> dict[str, Any]:
"""
Converts tensors in *state_dict* to DTensors matching *placements*.
Entries whose key appears in *placements* are converted to DTensors via
distribute_tensor so that each rank receives its correct local shard.
When *placements* is empty (e.g. FSDP2 on world_size == 1 materialises
no DTensor parameters), the state dict is returned as-is.
Without this fix, cross-mesh checkpoint reloads (e.g. save with
ddp=4, load with ddp=1, domain=4) fail because
distribute_tensor refuses to re-mesh an existing DTensor.
"""
if not placements:
return dict(state_dict)
target_device = next(iter(placements.values()))[0].device_type
out: dict[str, Any] = {}
for key, value in state_dict.items():
if not isinstance(value, torch.Tensor):
out[key] = value
continue
if key in placements:
mesh, plc = placements[key]
if isinstance(value, DTensor):
if value.device_mesh == mesh:
# Same mesh; ``distribute_tensor`` would no-op and
# passing the DTensor straight through avoids an
# unnecessary collective.
out[key] = value
continue
# Cross-mesh: peel back to plain data, then redistribute.
value = value.to_local()
out[key] = distribute_tensor(value.to(mesh.device_type), mesh, list(plc))
elif isinstance(value, DTensor):
# Non-distributed key but the input is still a DTensor; keep
# it as-is so the caller can pass it through to ``set_*_state_dict``
# untouched.
out[key] = value
else:
out[key] = value.to(target_device)
return out
def _redistribute_optim_sd_for_dtensor(
placements: dict[str, tuple[Any, tuple[Any, ...]]],
optim_sd: dict[str, Any],
) -> dict[str, Any]:
"""
Shard optimizer state tensors to local chunks matching model placements.
FSDP's optim_state_dict_to_load expects each optimizer state tensor
(exp_avg, exp_avg_sq, ...) to be a plain tensor whose shape
matches the parameter's local_shape - not a DTensor. We use
distribute_tensor(...).to_local() to extract each rank's shard.
Scalar state entries (e.g. step) are left unchanged.
When placements is empty (e.g. FSDP2 on world_size == 1 materialises no DTensor parameters),
the optimizer state dict is returned as-is - the live optimizer expects plain tensors and the serialized data already matches.
Without this fix, optimizer reload fails on cross-mesh scenarios for the same reason as _redistribute_sd_for_dtensor.
"""
if "state" not in optim_sd:
return optim_sd
if not placements:
return optim_sd
target_device = next(iter(placements.values()))[0].device_type
new_state: dict[str, Any] = {}
for param_name, param_state in optim_sd["state"].items():
if not isinstance(param_state, dict):
new_state[param_name] = param_state
continue
new_ps: dict[str, Any] = {}
mesh_plc = placements.get(param_name)
for k, v in param_state.items():
if not isinstance(v, torch.Tensor) or v.dim() == 0:
new_ps[k] = v
elif isinstance(v, DTensor) and mesh_plc is None:
# No target placement for this entry; leave the DTensor
# untouched so the caller can pass it through.
new_ps[k] = v
elif mesh_plc is not None:
mesh, plc = mesh_plc
if isinstance(v, DTensor):
if v.device_mesh == mesh:
# Already on the target mesh -- ``to_local`` already
# yields the per-rank local shard.
new_ps[k] = v.to_local()
continue
# Cross-mesh DTensor: peel to plain data before
# redistributing. Assumes save materialised the full
# tensor on every rank (see
# :func:`_redistribute_sd_for_dtensor` for the same
# assumption).
v = v.to_local()
new_ps[k] = distribute_tensor(
v.to(mesh.device_type), mesh, list(plc)
).to_local()
else:
new_ps[k] = v.to(target_device)
new_state[param_name] = new_ps
return {**optim_sd, "state": new_state}
def _materialize_optimizer_state_for_dcp(
    optimizer: torch.optim.Optimizer,
    loaded_state: dict[Any, Any],
) -> None:
    """
    Pre-fill ``optimizer.state[p]`` with placeholders before ``set_optimizer_state_dict``.

    Without this, loading a trained checkpoint into a freshly constructed
    FSDP2 optimizer raises ``KeyError: 'state.0.step'`` inside DCP's
    ``_unflatten_state_dict``.

    The first entry of *loaded_state* (keyed by parameter index on save)
    serves as a template for the per-parameter state keys. For every
    parameter in every param group, each template key missing from its state
    gets a placeholder:

    * 0-dim tensor (typically ``step``): a 0-dim zero tensor with the
      reference dtype, on the parameter's device.
    * other tensors (``exp_avg``, ``exp_avg_sq``, ...):
      ``torch.zeros_like(p)``, matching the live parameter's
      dtype/device/sharding.
    * non-tensor values (e.g. Python numbers or ``None`` kept by some
      optimizers): a deep copy of the reference value. The placeholder then
      has the same type as the state that will be loaded into it.

    Keys already present in a parameter's state are left untouched. Nothing
    happens when *loaded_state* is empty or its first entry is not a
    non-empty dict.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer whose ``state`` is populated in place.
    loaded_state : dict
        The ``"state"`` section of the optimizer state dict being loaded.
    """
    import copy

    if not loaded_state:
        return
    sample = next(iter(loaded_state.values()))
    if not isinstance(sample, dict) or not sample:
        return

    for group in optimizer.param_groups:
        for p in group["params"]:
            slot = optimizer.state[p]
            for key, ref in sample.items():
                if key in slot:
                    continue
                if isinstance(ref, torch.Tensor):
                    if ref.dim() == 0:
                        # Scalar state: match the reference dtype so the later
                        # copy inside DCP does not trip on a dtype mismatch.
                        slot[key] = torch.zeros((), dtype=ref.dtype, device=p.device)
                    else:
                        slot[key] = torch.zeros_like(p)
                else:
                    # A tensor placeholder would have the wrong type for
                    # non-tensor state; mirror the reference value instead.
                    slot[key] = copy.deepcopy(ref)
def _fsdp_uses_flat_param_optim(model: torch.nn.Module | None) -> bool:
"""
Returns True only for FSDP1 with use_orig_params=False (FlatParameter optimizer path).
Not needed after dropping FSDP1 support.
"""
if not isinstance(model, FSDP):
return False
return not getattr(model, "_use_orig_params", True)
def _strides_match_channels_last(
shape: tuple[int, ...] | torch.Size,
stride: tuple[int, ...],
) -> bool:
"""
Checks whether a tensor’s strides match canonical channels_last (4D) or channels_last_3d (5D) layout.
Used to determine whether to remap optimizer state for channels_last tensors.
Not needed after dropping FSDP1 support (only affects FSDP1 with use_orig_params=False).
"""
if len(shape) != len(stride):
return False
if len(shape) == 4:
n, c, h, w = shape
return tuple(stride) == (c * h * w, 1, w * c, c)
if len(shape) == 5:
n, c, d, h, w = shape
return tuple(stride) == (d * h * w * c, 1, h * w * c, w * c, c)
return False
def _get_cl_param_fqns(opt_model: torch.nn.Module | None) -> set[str]:
    """
    Find the FQNs of FSDP1 original params stored in channels_last byte order.

    Walks every FSDP1 submodule of *opt_model* and inspects its
    FlatParameter metadata (``_fqns``, ``_shapes``, ``_strides``,
    ``_contiguities``). A param is reported when it is non-contiguous and
    its strides match canonical channels_last (see
    :func:`_strides_match_channels_last`). Returned names have any leading
    ``_orig_mod.`` removed.

    Returns an empty set unless *opt_model* is FSDP1 with
    ``use_orig_params=False``. Not needed after dropping FSDP1 support.
    """
    if not _fsdp_uses_flat_param_optim(opt_model):
        return set()

    found: set[str] = set()
    for module_path, submodule in opt_model.named_modules():
        if not isinstance(submodule, FSDP):
            continue
        flat_param = getattr(submodule, "_flat_param", None)
        if flat_param is None:
            continue
        # DCP's ``_get_fqns`` omits ``_fsdp_wrapped_module`` from parameter
        # FQNs, so drop that segment (and empty segments) from the path.
        kept_parts = [
            part
            for part in module_path.split(".")
            if part and part != "_fsdp_wrapped_module"
        ]
        prefix = "".join(f"{part}." for part in kept_parts)
        records = zip(
            flat_param._fqns,
            flat_param._shapes,
            flat_param._strides,
            flat_param._contiguities,
        )
        for fqn, shape, stride, is_contiguous in records:
            if not is_contiguous and _strides_match_channels_last(shape, stride):
                found.add((prefix + fqn).removeprefix("_orig_mod."))
    return found
def _remap_channels_last_optim_sd(
    opt_model: torch.nn.Module | None,
    optim_sd: dict[str, Any],
) -> dict[str, Any]:
    """
    Fix the FSDP1 FlatParameter flatten/unflatten asymmetry for channels_last state.

    On save, FSDP1 reads bytes in storage order; on load, it reads them in
    logical order. For optimizer state of channels_last conv weights, this
    function permutes each state tensor so the load-side flatten sees the
    correct byte layout. All other non-scalar, non-DTensor state tensors
    are normalized to standard contiguity.

    Only applies to FSDP1 with ``use_orig_params=False``. For any other
    model, or when *optim_sd* has no ``"state"`` key, *optim_sd* is returned
    unchanged. Not needed after dropping FSDP1 support.

    Parameters
    ----------
    opt_model : torch.nn.Module | None
        Model the optimizer tracks (FSDP wrappers intact).
    optim_sd : dict[str, Any]
        Optimizer state dict whose ``"state"`` maps parameter names to
        per-parameter state dicts.

    Returns
    -------
    dict[str, Any]
        Shallow copy of *optim_sd* with a rebuilt ``"state"`` entry, or
        *optim_sd* itself when no remapping applies.
    """
    if "state" not in optim_sd:
        return optim_sd
    if not _fsdp_uses_flat_param_optim(opt_model):
        return optim_sd
    # Param FQNs (with any ``_orig_mod.`` prefix stripped) stored channels_last.
    cl_fqns = _get_cl_param_fqns(opt_model)

    def _normalize(t: torch.Tensor, is_cl_dest: bool) -> torch.Tensor:
        # DTensors and scalar state (e.g. ``step``) are left untouched.
        if isinstance(t, DTensor) or t.dim() == 0:
            return t
        # Force standard contiguity first so any saved-CL bytes are
        # rewritten in NCHW storage order before the layout decision; this
        # makes the subsequent broadcast inside DCP layout-safe whether or
        # not we permute.
        t = t.contiguous()
        if not is_cl_dest:
            return t
        # Reorder the data into channels-last (NHWC / NDHWC) byte order, then
        # reinterpret those bytes under the original logical shape via
        # ``view``, so the FSDP1 load-side flatten reads them as intended.
        if t.dim() == 4:
            return t.permute(0, 2, 3, 1).contiguous().view(*t.shape)
        if t.dim() == 5:
            return t.permute(0, 2, 3, 4, 1).contiguous().view(*t.shape)
        return t

    new_state: dict[str, Any] = {}
    for pname, pstate in optim_sd["state"].items():
        # Non-dict per-param entries are carried over unchanged.
        if not isinstance(pstate, dict):
            new_state[pname] = pstate
            continue
        # ``cl_fqns`` has ``_orig_mod.`` stripped, so strip it here too.
        is_cl_dest = pname.removeprefix("_orig_mod.") in cl_fqns
        new_ps: dict[str, Any] = {}
        for k, v in pstate.items():
            new_ps[k] = _normalize(v, is_cl_dest) if isinstance(v, torch.Tensor) else v
        new_state[pname] = new_ps
    return {**optim_sd, "state": new_state}
def _is_mdlus_archive(path: str) -> bool:
    """
    Check whether *path* is a ``.mdlus`` archive.

    A ``.mdlus`` archive is a tar or zip file containing a ``model.pt``
    member. The path is first resolved through ``_cache_if_needed``
    (presumably fetching remote files into a local cache; it is defined
    outside this view). Files that are neither tar nor zip return ``False``.
    """
    local_file = _cache_if_needed(path)
    if tarfile.is_tarfile(local_file):
        with tarfile.open(local_file, "r") as tar_archive:
            members = tar_archive.getnames()
        return "model.pt" in members
    if zipfile.is_zipfile(local_file):
        with zipfile.ZipFile(local_file, "r") as zip_archive:
            members = zip_archive.namelist()
        return "model.pt" in members
    return False
def _extract_mdlus_state_dict(
    file_name: str, device: str | torch.device = "cpu"
) -> dict[str, Any]:
    """
    Read only the state dict stored in a ``.mdlus`` file.

    The ``model.pt`` member is read from the tar or zip archive without
    instantiating the model. This is used in distributed load, where rank 0
    reads the file and broadcasts the tensors.

    Parameters
    ----------
    file_name : str
        Path or URI of the ``.mdlus`` file (resolved via ``_cache_if_needed``).
    device : str | torch.device, optional
        ``map_location`` passed to ``torch.load``. By default ``"cpu"``.

    Returns
    -------
    dict[str, Any]
        The deserialized ``model.pt`` contents.

    Raises
    ------
    ValueError
        If the tar archive's ``model.pt`` entry is not a regular file.
    """
    cached = _cache_if_needed(file_name)
    fmt = Module._detect_checkpoint_format(cached)
    if fmt == "tar":
        with tarfile.open(cached, "r") as tar:
            member = tar.extractfile("model.pt")
            # ``extractfile`` returns None for non-regular members (dirs, links).
            if member is None:
                raise ValueError(
                    f"'model.pt' in checkpoint {file_name} is not a regular file"
                )
            with member:
                model_bytes = member.read()
    else:
        with zipfile.ZipFile(cached, "r") as archive:
            model_bytes = archive.read("model.pt")
    # SECURITY NOTE: ``weights_only=False`` unpickles arbitrary objects; only
    # load checkpoints from trusted sources.
    return torch.load(io.BytesIO(model_bytes), map_location=device, weights_only=False)
def _get_checkpoint_filename(
    path: str,
    base_name: str = "checkpoint",
    index: int | None = None,
    saving: bool = False,
    model_type: str = "mdlus",
    distributed: bool = False,
) -> str:
    """
    Build the filename for a numbered checkpoint.

    Filenames follow ``{path}/{base_name}.{model_parallel_rank}.{index}{ext}``.
    When *index* is ``None``, the directory is scanned for files matching
    exactly that pattern. Files picked up by the glob but not matching it
    (e.g. ``checkpoint.10.3.pt`` when looking for rank ``1``, or
    ``checkpoint.0.backup.pt``) are ignored. When no matching checkpoint
    exists, index ``0`` is used.

    Parameters
    ----------
    path : str
        Directory containing checkpoint files.
    base_name : str, optional
        Stem used in the filename, by default ``"checkpoint"``.
    index : int | None, optional
        Specific checkpoint index to use. When ``None``, the latest or
        next index is determined automatically.
    saving : bool, optional
        If ``True`` (and ``index is None``), return the *next* available
        filename rather than the latest existing one. By default ``False``.
    model_type : str, optional
        ``"mdlus"`` for :class:`~physicsnemo.core.Module` models,
        ``"pt"`` for vanilla PyTorch models. Determines the file
        extension. By default ``"mdlus"``.
    distributed : bool, optional
        When ``True`` the model_parallel_rank component of the filename is
        forced to ``0`` because FSDP/DTensor distribution is handled by the
        DCP APIs, not per-rank files. By default ``False``.

    Returns
    -------
    str
        Fully qualified checkpoint filename.
    """
    # The model parallel rank is part of the filename so every process in the
    # first model parallel group gets its own file. Without model parallelism
    # the rank component is 0.
    if not DistributedManager.is_initialized():
        checkpoint_logging.warning(
            "`DistributedManager` not initialized already. Initializing now, but this might lead to unexpected errors"
        )
        DistributedManager.initialize()
    manager = DistributedManager()
    if distributed:
        model_parallel_rank = 0
    else:
        model_parallel_rank = (
            manager.group_rank("model_parallel")
            if "model_parallel" in manager.group_names
            else 0
        )

    # pathlib does not support custom schemes (e.g. msc://...), so only
    # resolve to an absolute path for the local filesystem.
    protocol = fsspec.utils.get_protocol(path)
    fs = fsspec.filesystem(protocol)
    if protocol == "file":
        path = str(Path(path).resolve())

    checkpoint_filename = f"{path}/{base_name}.{model_parallel_rank}"
    file_extension = ".mdlus" if model_type == "mdlus" else ".pt"

    if index is not None:
        return f"{checkpoint_filename}.{index}{file_extension}"

    # Otherwise use the latest (or next, when saving) existing index. The glob
    # is deliberately loose, so each hit is validated against the exact name.
    pattern = re.compile(
        rf"^{re.escape(base_name)}\.{model_parallel_rank}\.(\d+){re.escape(file_extension)}$"
    )
    existing_indices = [
        int(match.group(1))
        for fname in fs.glob(checkpoint_filename + "*" + file_extension)
        if (match := pattern.match(PurePath(fname).name)) is not None
    ]
    if existing_indices:
        latest = max(existing_indices)
        chosen = latest + 1 if saving else latest
    else:
        chosen = 0
    return f"{checkpoint_filename}.{chosen}{file_extension}"
def _unique_model_names(
    models: list[torch.nn.Module],
    loading: bool = False,
) -> dict[str, torch.nn.Module]:
    r"""Map a list of models to unique names derived from their class names.

    DDP and ``torch.compile`` wrappers are stripped, but FSDP wrappers are
    kept so the returned modules can be handed to PyTorch's DCP state-dict
    helpers when needed. A class name used by a single model is returned
    as-is. When several models share a class name, each receives a numeric
    suffix in input order (e.g. ``"MyModel0"``, ``"MyModel1"``).

    Parameters
    ----------
    models : list[torch.nn.Module]
        Models to generate names for.
    loading : bool, optional
        Forwarded to the unwrapping step. When ``True``, a warning is
        emitted for already-compiled models. By default ``False``.

    Returns
    -------
    dict[str, torch.nn.Module]
        Mapping from unique name to module (with FSDP intact if present).
    """
    grouped: dict[str, list[torch.nn.Module]] = {}
    for raw_model in models:
        stripped = _unwrap_ddp_compile(raw_model, loading=loading)
        grouped.setdefault(_unwrapped_class_name(stripped), []).append(stripped)

    unique: dict[str, torch.nn.Module] = {}
    for class_name, members in grouped.items():
        if len(members) == 1:
            unique[class_name] = members[0]
        else:
            unique.update({f"{class_name}{i}": m for i, m in enumerate(members)})
    return unique
def save_checkpoint(
    path: Path | str,
    models: torch.nn.Module | list[torch.nn.Module] | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    scheduler: LRScheduler | None = None,
    scaler: GradScaler | None = None,
    epoch: int | None = None,
    metadata: dict[str, Any] | None = None,
    optimizer_model: torch.nn.Module | None = None,
) -> None:
    r"""Save a training checkpoint to disk (or a remote store).

    Up to two categories of files are created inside ``path``:

    * **Model weights** (when ``models`` is provided) - one file per model:
      ``{class_name}{id}.{mp_rank}.{epoch}.{ext}`` where *ext* is
      ``.mdlus`` for :class:`~physicsnemo.core.Module` instances or
      ``.pt`` for plain PyTorch models. When several models share a class
      name, a numeric *id* is appended (``"MyModel0"``, ``"MyModel1"``).
    * **Training state** (when any of ``optimizer`` / ``scheduler`` /
      ``scaler`` is provided, or
      :class:`~physicsnemo.utils.capture._StaticCapture` scalers exist):
      ``checkpoint.{mp_rank}.{epoch}.pt`` containing their combined
      ``state_dict`` entries, plus ``epoch`` and ``metadata``.

    When any model is FSDP-wrapped or contains DTensor/ShardTensor
    parameters the function enters *distributed* mode: all ranks **must**
    call it, state is gathered via DCP collective helpers, and only rank 0
    writes files.

    Use :func:`load_checkpoint` to restore from these files.
    To instantiate *and* load a model in one step (without pre-constructing
    it), use :meth:`~physicsnemo.core.module.Module.from_checkpoint`.

    Parameters
    ----------
    path : Path | str
        Directory in which to store checkpoint files. Created
        automatically for local paths if it does not exist.
    models : torch.nn.Module | list[torch.nn.Module] | None, optional
        Model(s) whose weights should be saved.
    optimizer : torch.optim.Optimizer | None, optional
        Optimizer whose ``state_dict`` should be saved.
    scheduler : LRScheduler | None, optional
        Learning-rate scheduler whose ``state_dict`` should be saved.
    scaler : GradScaler | None, optional
        AMP gradient scaler whose ``state_dict`` should be saved.
        If ``None`` but a
        :class:`~physicsnemo.utils.capture._StaticCapture` scaler exists,
        that scaler's state is saved instead.
    epoch : int | None, optional
        Epoch index to embed in the filename and the checkpoint dict
        (``0`` is a valid epoch and is stored). When ``None``, the next
        available index is used and no ``epoch`` entry is written.
    metadata : dict[str, Any] | None, optional
        Arbitrary key-value pairs persisted alongside the training state
        (e.g. best validation loss, MLflow run ID).
    optimizer_model : torch.nn.Module | None, optional
        The model whose parameters the ``optimizer`` is tracking so that
        parameter unsharding of optimizer state can be performed correctly.
        Only required when multiple models are provided, and at least one of
        them is a distributed model (FSDP/ShardTensor). When ``None``, the
        first distributed model in ``models`` is used. Ignored when *not* in
        distributed mode.

    Examples
    --------
    Save a model together with optimizer and scheduler state:

    >>> import tempfile, os, torch
    >>> from physicsnemo.utils.checkpoint import save_checkpoint
    >>> from physicsnemo.models.mlp import FullyConnected
    >>> model = FullyConnected(in_features=32, out_features=64)
    >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    >>> scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    >>> with tempfile.TemporaryDirectory() as tmpdir:
    ...     save_checkpoint(tmpdir, models=model, optimizer=optimizer,
    ...                     scheduler=scheduler, epoch=1)
    ...     sorted(f for f in os.listdir(tmpdir))
    ['FullyConnected.0.1.mdlus', 'checkpoint.0.1.pt']

    Save at multiple epochs with additional metadata:

    >>> with tempfile.TemporaryDirectory() as tmpdir:
    ...     save_checkpoint(tmpdir, models=model, optimizer=optimizer, epoch=1,
    ...                     metadata={"loss": 0.42, "experiment": "run_01"})
    ...     save_checkpoint(tmpdir, models=model, optimizer=optimizer, epoch=2,
    ...                     metadata={"loss": 0.31, "experiment": "run_01"})
    ...     sorted(f for f in os.listdir(tmpdir))
    ['FullyConnected.0.1.mdlus', 'FullyConnected.0.2.mdlus', 'checkpoint.0.1.pt', 'checkpoint.0.2.pt']
    """
    path = str(path)
    protocol = fsspec.utils.get_protocol(path)
    fs = fsspec.filesystem(protocol)

    # == Prepare models and detect distributed mode ==
    named_models: dict[str, torch.nn.Module] = {}
    is_distributed = False
    # Explicit ``is not None``: container modules such as an empty
    # ``nn.Sequential`` define ``__len__`` and would be skipped by truthiness.
    if models is not None:
        if not isinstance(models, list):
            models = [models]
        named_models = _unique_model_names(models)
        is_distributed = any(_is_distributed_model(m) for m in named_models.values())

    if not DistributedManager.is_initialized():
        checkpoint_logging.warning(
            "`DistributedManager` not initialized already. "
            "Initializing now, but this might lead to unexpected errors"
        )
        DistributedManager.initialize()
    manager = DistributedManager()
    # Distributed mode: all ranks join the collectives, only rank 0 writes.
    # Otherwise every calling process writes.
    should_write = manager.rank == 0 if is_distributed else True

    # Create checkpoint directory (only writing rank)
    if should_write and protocol == "file" and not Path(path).is_dir():
        checkpoint_logging.warning(
            f"Output directory {path} does not exist, will attempt to create"
        )
        Path(path).mkdir(parents=True, exist_ok=True)
    if is_distributed:
        # Synchronize ranks after the rank-0-only directory creation.
        torch.distributed.barrier()

    # == Saving model checkpoint ==
    for name, model in named_models.items():
        inner = _unwrap_fsdp(model)
        is_mdlus = isinstance(inner, physicsnemo.core.Module)
        file_name = _get_checkpoint_filename(
            path,
            name,
            index=epoch,
            saving=True,
            model_type="mdlus" if is_mdlus else "pt",
            distributed=is_distributed,
        )
        if _is_distributed_model(model):
            # cpu_offload is handled manually because the DCP option
            # hangs for FSDP NO_SHARD + DTensor topologies.
            options = StateDictOptions(full_state_dict=True)
            state_dict = get_model_state_dict(model, options=options)
            # FSDP2 (``fully_shard``) does not always honour
            # ``full_state_dict=True``; the result can still contain DTensors
            # with sharded ``_local_tensor`` on each rank. Calling
            # ``full_tensor()`` collectively materialises full plain
            # tensors so the serialized file is mesh-shape-agnostic.
            state_dict = _materialize_dtensors_to_full(state_dict)
            if should_write:
                state_dict = _cpu_offload_state_dict(state_dict)
                if is_mdlus:
                    inner.save(file_name, _state_dict=state_dict)
                else:
                    with fs.open(file_name, "wb") as fp:
                        torch.save(state_dict, fp)
                checkpoint_logging.success(f"Saved model state dictionary: {file_name}")
        elif should_write:
            if is_mdlus:
                inner.save(file_name)
            else:
                with fs.open(file_name, "wb") as fp:
                    torch.save(model.state_dict(), fp)
            checkpoint_logging.success(f"Saved model state dictionary: {file_name}")

    # == Saving training checkpoint ==
    checkpoint_dict: dict[str, Any] = {}
    if optimizer is not None:
        opt_model = None
        if is_distributed:
            if optimizer_model is not None:
                opt_model = optimizer_model
            else:
                opt_model = next(
                    (m for m in named_models.values() if _is_distributed_model(m)),
                    None,
                )
        if opt_model is not None:
            # cpu_offload is handled manually because the DCP option
            # hangs for FSDP NO_SHARD + DTensor topologies.
            options = StateDictOptions(full_state_dict=True)
            opt_state_dict = get_optimizer_state_dict(
                opt_model, optimizer, options=options
            )
            # The same FSDP2 ``full_state_dict`` gap affects optimizer state
            # (``exp_avg``, ``exp_avg_sq``, ...), so gather those too.
            opt_state_dict = _materialize_dtensors_to_full(opt_state_dict)
            if should_write:
                opt_state_dict = _cpu_offload_state_dict(opt_state_dict)
        else:
            opt_state_dict = optimizer.state_dict()

        # Strip out torch dynamo wrapper prefix
        for pg in opt_state_dict.get("param_groups", []):
            param_names = pg.get("param_names")
            if param_names is None:
                continue
            pg["param_names"] = [pn.removeprefix("_orig_mod.") for pn in param_names]
        checkpoint_dict["optimizer_state_dict"] = opt_state_dict

    if scheduler is not None:
        checkpoint_dict["scheduler_state_dict"] = scheduler.state_dict()
    if scaler is not None:
        checkpoint_dict["scaler_state_dict"] = scaler.state_dict()
    if _StaticCapture._amp_scalers:
        checkpoint_dict["static_capture_state_dict"] = _StaticCapture.state_dict()

    output_filename = _get_checkpoint_filename(
        path,
        index=epoch,
        saving=True,
        model_type="pt",
        distributed=is_distributed,
    )
    # ``is not None`` so that epoch 0 is recorded as well.
    if epoch is not None:
        checkpoint_dict["epoch"] = epoch
    if metadata:
        checkpoint_dict["metadata"] = metadata

    if checkpoint_dict and should_write:
        with fs.open(output_filename, "wb") as fp:
            torch.save(checkpoint_dict, fp)
        checkpoint_logging.success(f"Saved training checkpoint: {output_filename}")
[docs]
def load_checkpoint(
    path: Path | str,
    models: torch.nn.Module | list[torch.nn.Module] | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    scheduler: LRScheduler | None = None,
    scaler: GradScaler | None = None,
    epoch: int | None = None,
    metadata_dict: dict[str, Any] | None = None,
    device: str | torch.device = "cpu",
    optimizer_model: torch.nn.Module | None = None,
) -> int:
    r"""Restore training state previously written by :func:`save_checkpoint`.

    The checkpoint directory ``path`` is searched for model files and a
    training-state file. Each provided training object has its state
    dictionary restored from them. Arguments left as ``None`` are ignored.

    If any model is FSDP-wrapped or holds DTensor/ShardTensor parameters,
    loading switches to *distributed* mode. In that mode every rank
    **must** call this function. Rank 0 alone touches the filesystem, and
    model/optimizer state is handed out to all ranks through the DCP
    helpers.

    Parameters
    ----------
    path : Path | str
        Checkpoint directory, as a local path or an ``fsspec`` URI. A
        missing directory causes the load to be skipped and ``0`` to be
        returned.
    models : torch.nn.Module | list[torch.nn.Module] | None, optional
        One or more models to restore. DDP and ``torch.compile`` wrappers
        are removed automatically before loading.
    optimizer : torch.optim.Optimizer | None, optional
        Optimizer to restore.
    scheduler : LRScheduler | None, optional
        Learning-rate scheduler to restore.
    scaler : GradScaler | None, optional
        AMP gradient scaler to restore.
    epoch : int | None, optional
        Index of the checkpoint to load. ``None`` selects the checkpoint
        with the highest index, i.e. the most recent one.
    metadata_dict : dict[str, Any] | None, optional
        When a ``dict`` is given, it is updated **in-place** with the
        metadata stored by :func:`save_checkpoint`.
    device : str | torch.device, optional
        ``map_location`` used while loading tensors. Defaults to ``"cpu"``.
    optimizer_model : torch.nn.Module | None, optional
        Model whose parameters ``optimizer`` tracks. The DCP
        ``set_optimizer_state_dict`` helper needs it in distributed mode.
        When ``None``, the first distributed model in ``models`` is used.
        It has no effect outside distributed mode.

    Returns
    -------
    int
        The epoch recorded in the checkpoint. The value is ``0`` if the
        directory is missing, if no training-state file exists in it, or
        if that file has no ``"epoch"`` entry.

    Raises
    ------
    FileNotFoundError
        If ``path`` exists but is a file rather than a directory
        (non-distributed mode).

    Examples
    --------
    Save and then restore a model, optimizer, and scheduler from a checkpoint:

    >>> import tempfile, torch
    >>> from physicsnemo.utils.checkpoint import save_checkpoint, load_checkpoint
    >>> from physicsnemo.models.mlp import FullyConnected
    >>> model = FullyConnected(in_features=32, out_features=64)
    >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    >>> scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    >>> with tempfile.TemporaryDirectory() as tmpdir:
    ...     save_checkpoint(tmpdir, models=model, optimizer=optimizer,
    ...                     scheduler=scheduler, epoch=1)
    ...     epoch = load_checkpoint(tmpdir, models=model, optimizer=optimizer,
    ...                             scheduler=scheduler)
    ...     epoch
    1

    Load a specific epoch and retrieve saved metadata:

    >>> with tempfile.TemporaryDirectory() as tmpdir:
    ...     save_checkpoint(tmpdir, models=model, optimizer=optimizer, epoch=1,
    ...                     metadata={"loss": 0.42, "experiment": "run_01"})
    ...     save_checkpoint(tmpdir, models=model, optimizer=optimizer, epoch=2,
    ...                     metadata={"loss": 0.31, "experiment": "run_01"})
    ...     meta = {}
    ...     epoch = load_checkpoint(tmpdir, models=model, optimizer=optimizer,
    ...                             epoch=1, metadata_dict=meta)
    ...     epoch
    1
    >>> meta["loss"]
    0.42
    """
    path = str(path)
    fs = fsspec.filesystem(fsspec.utils.get_protocol(path))

    # Normalise ``models`` into a name -> module mapping. Distributed mode
    # is enabled as soon as any one of the models is distributed.
    named_models: dict[str, torch.nn.Module] = {}
    if models:
        model_list = models if isinstance(models, list) else [models]
        named_models = _unique_model_names(model_list, loading=True)
    is_distributed = any(_is_distributed_model(m) for m in named_models.values())

    if not DistributedManager.is_initialized():
        checkpoint_logging.warning(
            "`DistributedManager` not initialized already. "
            "Initializing now, but this might lead to unexpected errors"
        )
        DistributedManager.initialize()
    is_rank0 = DistributedManager().rank == 0

    # Distributed mode: every rank takes part in the collectives.
    if is_distributed:
        return _load_checkpoint_distributed(
            path=path,
            fs=fs,
            named_models=named_models,
            optimizer=optimizer,
            scheduler=scheduler,
            scaler=scaler,
            epoch=epoch,
            metadata_dict=metadata_dict,
            device=device,
            optimizer_model=optimizer_model,
            is_rank0=is_rank0,
        )

    # Local mode from here on.
    if not fs.exists(path):
        checkpoint_logging.warning(
            f"Provided checkpoint directory {path} does not exist, skipping load"
        )
        return 0
    if fs.isfile(path):
        raise FileNotFoundError(
            f"Provided checkpoint directory {path} is a file, not directory"
        )

    # Model weights. physicsnemo modules load from ``.mdlus`` archives via
    # their own ``load``; every other module gets a plain ``.pt`` state dict.
    for name, model in named_models.items():
        inner = _unwrap_fsdp(model)
        is_mdlus_model = isinstance(inner, physicsnemo.core.Module)
        file_name = _get_checkpoint_filename(
            path, name, index=epoch, model_type="mdlus" if is_mdlus_model else "pt"
        )
        if not fs.exists(file_name):
            checkpoint_logging.error(
                f"Could not find valid model file {file_name}, skipping load"
            )
            continue
        if is_mdlus_model:
            inner.load(file_name)
        else:
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # only load checkpoints from trusted sources.
            loaded_sd = torch.load(
                _cache_if_needed(file_name), map_location=device, weights_only=False
            )
            missing_keys, unexpected_keys = model.load_state_dict(loaded_sd)
            if missing_keys:
                checkpoint_logging.warning(
                    f"Missing keys when loading {name}: {missing_keys}"
                )
            if unexpected_keys:
                checkpoint_logging.warning(
                    f"Unexpected keys when loading {name}: {unexpected_keys}"
                )
        checkpoint_logging.success(
            f"Loaded model state dictionary {file_name} to device {device}"
        )

    # Training state: optimizer / scheduler / scaler / epoch / metadata.
    checkpoint_filename = _get_checkpoint_filename(path, index=epoch, model_type="pt")
    if not fs.exists(checkpoint_filename):
        checkpoint_logging.warning(
            "Could not find valid checkpoint file, skipping load"
        )
        return 0
    checkpoint_dict = torch.load(
        _cache_if_needed(checkpoint_filename), map_location=device, weights_only=False
    )
    checkpoint_logging.success(
        f"Loaded checkpoint file {checkpoint_filename} to device {device}"
    )

    # Each entry: (target object, checkpoint key, success message).
    restorers = (
        (optimizer, "optimizer_state_dict", "Loaded optimizer state dictionary"),
        (scheduler, "scheduler_state_dict", "Loaded scheduler state dictionary"),
        (scaler, "scaler_state_dict", "Loaded grad scaler state dictionary"),
    )
    for target, key, message in restorers:
        if target and key in checkpoint_dict:
            target.load_state_dict(checkpoint_dict[key])
            checkpoint_logging.success(message)

    if "static_capture_state_dict" in checkpoint_dict:
        _StaticCapture.load_state_dict(checkpoint_dict["static_capture_state_dict"])
        checkpoint_logging.success("Loaded static capture state dictionary")

    if metadata_dict is not None:
        metadata_dict.update(checkpoint_dict.get("metadata", {}))
    return checkpoint_dict.get("epoch", 0)
[docs]
def load_model_weights(
    model: torch.nn.Module,
    weights_path: str,
    device: str | torch.device = "cpu",
) -> None:
    r"""Load model weights from a single checkpoint file.

    Loads a ``.mdlus`` (or ``.pt``) file directly into *model*, handling
    FSDP and DTensor/ShardTensor distribution automatically. Unlike
    :func:`load_checkpoint` (which expects a checkpoint *directory* with
    numbered files), this function accepts a path to a single file.

    When the model is FSDP-wrapped or has DTensor parameters this is a
    **collective** operation — all ranks must call it. Rank 0 reads the
    file, and the state is distributed with the same strategy that
    :func:`load_checkpoint` uses for model weights:

    * user-managed DTensors (needing the DCP broadcast bypass): the full
      state dict is broadcast manually and redistributed locally;
    * FSDP2 (``FSDPModule``): the full state dict is broadcast manually,
      because ``broadcast_from_rank0`` can hang on multi-rank meshes;
    * otherwise: DCP's ``broadcast_from_rank0``.

    Parameters
    ----------
    model : torch.nn.Module
        The model to load weights into. May be FSDP-wrapped, contain
        DTensor/ShardTensor parameters, or be a plain module.
    weights_path : str
        Path to a ``.mdlus`` or ``.pt`` checkpoint file (local path or
        ``fsspec`` URI).
    device : str | torch.device, optional
        Device for :func:`torch.load` ``map_location``. By default
        ``"cpu"``.
    """
    model = _unwrap_ddp_compile(model, loading=True)
    is_mdlus = _is_mdlus_archive(weights_path)

    if not _is_distributed_model(model):
        inner = _unwrap_fsdp(model)
        if is_mdlus and isinstance(inner, physicsnemo.core.Module):
            inner.load(weights_path)
        else:
            if is_mdlus:
                sd = _extract_mdlus_state_dict(weights_path, device)
            else:
                # Only ``torch.load`` needs a local copy. Skip the (possibly
                # remote) download for ``.mdlus`` files, which are read
                # directly from ``weights_path`` above.
                cached = _cache_if_needed(weights_path)
                sd = torch.load(cached, map_location=device, weights_only=False)
            inner.load_state_dict(sd)
        checkpoint_logging.success(f"Loaded model weights from {weights_path}")
        return

    if not DistributedManager.is_initialized():
        DistributedManager.initialize()
    is_rank0 = DistributedManager().rank == 0

    # Only rank 0 reads the file; other ranks receive the state below.
    state_dict: dict[str, Any] = {}
    if is_rank0:
        if is_mdlus:
            state_dict = _extract_mdlus_state_dict(weights_path, device)
        else:
            cached = _cache_if_needed(weights_path)
            state_dict = torch.load(cached, map_location=device, weights_only=False)

    # Collective: inspect placements on every rank.
    dtensor_plc = _get_dtensor_param_placements(model)
    if _needs_dcp_broadcast_bypass(model, dtensor_plc):
        # broadcast_from_rank0 does not handle user-managed DTensor
        # redistribution, so broadcast the full dict and convert locally.
        sd_list: list[Any] = [state_dict]
        torch.distributed.broadcast_object_list(sd_list, src=0)
        state_dict = _redistribute_sd_for_dtensor(dtensor_plc, sd_list[0])
        options = StateDictOptions(full_state_dict=True)
    elif isinstance(model, FSDPModule):
        # FSDP2: broadcast_from_rank0 can hang on multi-rank meshes, so
        # broadcast explicitly and let DCP shard locally on each rank.
        sd_list = [state_dict]
        torch.distributed.broadcast_object_list(sd_list, src=0)
        state_dict = _force_standard_contiguous(sd_list[0])
        options = StateDictOptions(full_state_dict=True)
    else:
        # Force standard contiguity on rank 0 so DCP's per-tensor broadcast
        # does not permute channels_last params on receive.
        if is_rank0:
            state_dict = _force_standard_contiguous(state_dict)
        options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)
    set_model_state_dict(model, state_dict, options=options)
    checkpoint_logging.success(f"Loaded model weights from {weights_path}")
# ------------------------------------------------------------------
# Distributed load implementation
# ------------------------------------------------------------------
def _load_checkpoint_distributed(
    *,
    path: str,
    fs: fsspec.AbstractFileSystem,
    named_models: dict[str, torch.nn.Module],
    optimizer: torch.optim.Optimizer | None,
    scheduler: LRScheduler | None,
    scaler: GradScaler | None,
    epoch: int | None,
    metadata_dict: dict[str, Any] | None,
    device: str | torch.device,
    optimizer_model: torch.nn.Module | None,
    is_rank0: bool,
) -> int:
    """Distributed load: rank 0 reads files, DCP broadcasts to all ranks.

    Every rank must call this function. All branching decisions that guard
    a collective are taken from values broadcast by rank 0, so every rank
    follows the same sequence of collectives.

    Parameters
    ----------
    path, fs, named_models, optimizer, scheduler, scaler, epoch, \
metadata_dict, device, optimizer_model
        Same meaning as in :func:`load_checkpoint`.
    is_rank0 : bool
        Whether the calling process is rank 0 (the only rank doing I/O).

    Returns
    -------
    int
        Epoch stored in the training checkpoint. ``0`` if the directory or
        training-state file is missing or has no ``"epoch"`` entry.

    Raises
    ------
    FileNotFoundError
        On all ranks, if ``path`` exists but is a file. This mirrors the
        non-distributed path of :func:`load_checkpoint`.
    """
    broadcast_options = StateDictOptions(
        full_state_dict=True, broadcast_from_rank0=True
    )
    full_options = StateDictOptions(full_state_dict=True)

    # --- Rank 0 checks the directory; all ranks act on the result -------
    # A tri-state (rather than a bool) lets a path that is a *file* raise
    # consistently on every rank instead of being reported as "missing".
    dir_status: str | None = None
    if is_rank0:
        if not fs.exists(path):
            dir_status = "missing"
        elif fs.isfile(path):
            dir_status = "file"
        else:
            dir_status = "ok"
    status_list: list[Any] = [dir_status]
    torch.distributed.broadcast_object_list(status_list, src=0)
    dir_status = status_list[0]
    if dir_status == "file":
        raise FileNotFoundError(
            f"Provided checkpoint directory {path} is a file, not directory"
        )
    if dir_status != "ok":
        checkpoint_logging.warning(
            f"Provided checkpoint directory {path} does not exist, skipping load"
        )
        return 0

    # --- Load model checkpoints -----------------------------------------
    # Rank 0: determine which model files exist and load their state dicts
    model_file_info: dict[str, str | None] = {}
    model_state_dicts: dict[str, dict[str, Any]] = {}
    if is_rank0:
        for name, model in named_models.items():
            inner = _unwrap_fsdp(model)
            model_type = "mdlus" if isinstance(inner, physicsnemo.core.Module) else "pt"
            file_name = _get_checkpoint_filename(
                path,
                name,
                index=epoch,
                model_type=model_type,
                distributed=True,
            )
            if fs.exists(file_name):
                model_file_info[name] = file_name
                if isinstance(inner, physicsnemo.core.Module):
                    model_state_dicts[name] = _extract_mdlus_state_dict(
                        file_name, device
                    )
                else:
                    file_to_load = _cache_if_needed(file_name)
                    model_state_dicts[name] = torch.load(
                        file_to_load,
                        map_location=device,
                        weights_only=False,
                    )
            else:
                model_file_info[name] = None

    # Broadcast which model files were found
    info_list: list[Any] = [model_file_info]
    torch.distributed.broadcast_object_list(info_list, src=0)
    model_file_info = info_list[0]

    # Distribute model state dicts via DCP
    for name, model in named_models.items():
        if model_file_info.get(name) is None:
            checkpoint_logging.error(
                f"Could not find valid model file for {name}, skipping load"
            )
            continue
        if _is_distributed_model(model):
            # Collective: inspect native state dict for DTensor placements.
            # This is needed because use_orig_params=False flattens DTensor
            # params into a plain FlatParameter, hiding them from inspection.
            dtensor_plc = _get_dtensor_param_placements(model)
            if _needs_dcp_broadcast_bypass(model, dtensor_plc):
                # broadcast_from_rank0 does not handle user-managed DTensor
                # redistribution (e.g. ShardTensor on a domain mesh), so we
                # broadcast the full state dict ourselves and convert entries
                # to DTensors.
                sd_list: list[Any] = [
                    model_state_dicts.get(name, {}) if is_rank0 else {}
                ]
                torch.distributed.broadcast_object_list(sd_list, src=0)
                sd = _redistribute_sd_for_dtensor(dtensor_plc, sd_list[0])
                set_model_state_dict(model, sd, options=full_options)
            elif isinstance(model, FSDPModule):
                # FSDP2 (fully_shard): DCP's broadcast_from_rank0 can hang
                # for FSDPModule on multi-rank meshes. Broadcast the full
                # state dict explicitly and let DCP handle the DTensor
                # sharding locally on each rank via full_state_dict=True.
                sd_list = [model_state_dicts.get(name, {}) if is_rank0 else {}]
                torch.distributed.broadcast_object_list(sd_list, src=0)
                sd = _force_standard_contiguous(sd_list[0])
                set_model_state_dict(model, sd, options=full_options)
            else:
                # FSDP-managed DTensors (FULL_SHARD/SHARD_GRAD_OP) or no
                # DTensors at all — broadcast_from_rank0 handles both. Force
                # standard contiguity on rank 0 first so the per-tensor
                # broadcast inside DCP doesn't permute channels_last params on
                # receive (see ``_force_standard_contiguous`` for the why).
                sd = model_state_dicts.get(name, {}) if is_rank0 else {}
                if is_rank0:
                    sd = _force_standard_contiguous(sd)
                set_model_state_dict(model, sd, options=broadcast_options)
        else:
            # A mix of distributed and non-distributed models is valid
            # (e.g. a main FSDP model alongside a small auxiliary model).
            sd_list = [model_state_dicts.get(name, {}) if is_rank0 else {}]
            torch.distributed.broadcast_object_list(sd_list, src=0)
            inner = _unwrap_fsdp(model)
            inner.load_state_dict(sd_list[0])
        file_name = model_file_info[name]
        checkpoint_logging.success(
            f"Loaded model state dictionary {file_name} to device {device}"
        )

    # --- Load training checkpoint ---------------------------------------
    checkpoint_filename = _get_checkpoint_filename(
        path, index=epoch, model_type="pt", distributed=True
    )
    # Broadcast file existence so all ranks agree on whether to enter the
    # (collective) optimizer load. Without this, a rundir that has model
    # weights but no training checkpoint -- e.g. fine-tuning from a
    # weights-only export -- would have rank 0 enter ``set_optimizer_state_dict``
    # with an empty dict and trip the "missing 'state'" error inside DCP.
    ckpt_exists = fs.exists(checkpoint_filename) if is_rank0 else None
    ckpt_flags: list[Any] = [ckpt_exists]
    torch.distributed.broadcast_object_list(ckpt_flags, src=0)
    ckpt_exists = ckpt_flags[0]
    if not ckpt_exists:
        checkpoint_logging.warning(
            f"No training checkpoint at {checkpoint_filename}; "
            "skipping optimizer/scheduler/scaler load"
        )
        return 0

    checkpoint_dict: dict[str, Any] = {}
    if is_rank0:
        file_to_load = _cache_if_needed(checkpoint_filename)
        checkpoint_dict = torch.load(
            file_to_load, map_location=device, weights_only=False
        )
        checkpoint_logging.success(
            f"Loaded checkpoint file {checkpoint_filename} to device {device}"
        )

    # The training-state file may exist but lack optimizer state (e.g. it
    # was saved without an optimizer). Broadcast that fact so every rank
    # skips the collective optimizer load together, matching the local
    # path's ``"optimizer_state_dict" in checkpoint_dict`` check.
    has_optim_state = False
    if optimizer:
        optim_flags: list[Any] = [
            "optimizer_state_dict" in checkpoint_dict if is_rank0 else None
        ]
        torch.distributed.broadcast_object_list(optim_flags, src=0)
        has_optim_state = bool(optim_flags[0])

    # Optimizer state via DCP (collective)
    if optimizer and has_optim_state:
        opt_model = optimizer_model or next(
            (m for m in named_models.values() if _is_distributed_model(m)),
            None,
        )
        optim_sd = checkpoint_dict.get("optimizer_state_dict", {}) if is_rank0 else {}
        if opt_model is not None and _is_distributed_model(opt_model):
            dtensor_plc = _get_dtensor_param_placements(opt_model)
            if isinstance(opt_model, FSDPModule) and _needs_dcp_broadcast_bypass(
                opt_model, dtensor_plc
            ):
                # FSDP2 with a degenerate mesh axis (e.g. ddp=1, domain=2):
                # broadcast_from_rank0 hangs on degenerate axes, so broadcast
                # manually and use full_state_dict=True (no broadcast_from_rank0).
                osd_list: list[Any] = [optim_sd]
                torch.distributed.broadcast_object_list(osd_list, src=0)
                optim_sd = osd_list[0]
                optim_sd = _remap_channels_last_optim_sd(opt_model, optim_sd)
                # Pre-populate live optimizer ``state[p]`` so DCP's
                # flatten/unflatten round-trip has the ``state.X.*`` keys
                # the checkpoint provides. Without this, a freshly
                # constructed optimizer (which has empty ``state``) trips
                # ``KeyError: 'state.0.step'`` inside DCP's
                # ``_unflatten_state_dict``. The placeholders are
                # overwritten by the following ``set_optimizer_state_dict``.
                _materialize_optimizer_state_for_dcp(
                    optimizer, optim_sd.get("state", {})
                )
                set_optimizer_state_dict(
                    opt_model, optimizer, optim_sd, options=full_options
                )
            elif isinstance(opt_model, FSDPModule):
                # FSDP2 with a fully non-degenerate mesh (e.g. ddp=2, domain=2):
                # use broadcast_from_rank0 which redistributes shards correctly.
                optim_sd = _remap_channels_last_optim_sd(opt_model, optim_sd)
                set_optimizer_state_dict(
                    opt_model, optimizer, optim_sd, options=broadcast_options
                )
            elif _needs_dcp_broadcast_bypass(opt_model, dtensor_plc):
                # FSDP1 NO_SHARD / plain DTensor: redistribute full tensors to
                # per-rank local shards before loading.
                osd_list = [optim_sd]
                torch.distributed.broadcast_object_list(osd_list, src=0)
                optim_sd = _redistribute_optim_sd_for_dtensor(dtensor_plc, osd_list[0])
                optim_sd = _remap_channels_last_optim_sd(opt_model, optim_sd)
                _materialize_optimizer_state_for_dcp(
                    optimizer, optim_sd.get("state", {})
                )
                set_optimizer_state_dict(
                    opt_model, optimizer, optim_sd, options=full_options
                )
            else:
                # Remap on rank 0 only -- DCP broadcasts the rank-0 dict to
                # the others as part of broadcast_from_rank0.
                optim_sd = _remap_channels_last_optim_sd(opt_model, optim_sd)
                set_optimizer_state_dict(
                    opt_model, optimizer, optim_sd, options=broadcast_options
                )
            checkpoint_logging.success("Loaded optimizer state dictionary")
        else:
            # The optimizer tracks a non-distributed model (or no distributed
            # model was found). Only rank 0 read the checkpoint, so broadcast
            # the optimizer state; otherwise every other rank would silently
            # keep its freshly-initialised optimizer state.
            osd_list = [optim_sd]
            torch.distributed.broadcast_object_list(osd_list, src=0)
            optim_sd = osd_list[0]
            if optim_sd:
                optimizer.load_state_dict(optim_sd)
                checkpoint_logging.success("Loaded optimizer state dictionary")

    # Broadcast remaining training state (scheduler, scaler, epoch, metadata)
    rest: dict[str, Any] = {}
    if is_rank0:
        rest = {k: v for k, v in checkpoint_dict.items() if k != "optimizer_state_dict"}
    rest_list: list[Any] = [rest]
    torch.distributed.broadcast_object_list(rest_list, src=0)
    rest = rest_list[0]

    if scheduler and "scheduler_state_dict" in rest:
        scheduler.load_state_dict(rest["scheduler_state_dict"])
        checkpoint_logging.success("Loaded scheduler state dictionary")
    if scaler and "scaler_state_dict" in rest:
        scaler.load_state_dict(rest["scaler_state_dict"])
        checkpoint_logging.success("Loaded grad scaler state dictionary")
    if "static_capture_state_dict" in rest:
        _StaticCapture.load_state_dict(rest["static_capture_state_dict"])
        checkpoint_logging.success("Loaded static capture state dictionary")

    loaded_epoch = rest.get("epoch", 0)
    if metadata_dict is not None:
        metadata_dict.update(rest.get("metadata", {}))
    return loaded_epoch
[docs]
def get_checkpoint_dir(base_dir: Path | str, model_name: str) -> str:
r"""Build a model-specific checkpoint directory path.
Returns ``"{base_dir}/checkpoints_{model_name}"``, handling both local
paths and ``fsspec`` URIs (e.g. ``msc://``). Always uses ``/`` as the
appended separator so the result is identical across operating systems
and remains a valid URI when ``base_dir`` is a remote scheme. This
matches the path convention used elsewhere in this module (see e.g.
:func:`_get_checkpoint_filename`).
Parameters
----------
base_dir : Path | str
Root directory under which the checkpoint subdirectory is placed.
Any trailing ``/`` or ``\\`` is stripped before concatenation, so
``"foo"``, ``"foo/"``, and (on Windows) ``"foo\\"`` all behave
identically.
model_name : str
Model name used as the directory suffix.
Returns
-------
str
Full path to the checkpoint directory, always joined with ``/``.
"""
base_dir = str(base_dir).rstrip("/\\")
return f"{base_dir}/checkpoints_{model_name}"
def _cache_if_needed(path: str) -> str:
    r"""Return a local filesystem path for ``path``, downloading remote files.

    Paths using the ``"file"`` protocol are returned unchanged. Any other
    protocol is fetched with
    :func:`~physicsnemo.core.filesystem._download_cached`. The download goes
    into a cache subdirectory of ``LOCAL_CACHE`` named after the current
    process id.

    Parameters
    ----------
    path : str
        Checkpoint file location (local path or ``fsspec`` URI).

    Returns
    -------
    str
        A local path that can be passed to :func:`torch.load`.
    """
    if fsspec.utils.get_protocol(path) != "file":
        cache_dir = os.path.join(LOCAL_CACHE, f"checkpoint_pid_{os.getpid()}")
        return _download_cached(path, recursive=False, local_cache_path=cache_dir)
    return path