# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Checkpoint utilities for saving and loading training state.
Provides :func:`save_checkpoint` and :func:`load_checkpoint` for persisting
and restoring model weights, optimizer/scheduler/scaler state, and arbitrary
metadata. Supports local filesystems and remote stores via ``fsspec``.
When models are wrapped with FSDP or use DTensor/ShardTensor parameters,
:func:`save_checkpoint` and :func:`load_checkpoint` automatically use
PyTorch's distributed checkpoint state-dict APIs to gather and scatter
model and optimizer state. In this *distributed* mode **all ranks** must
call the functions (the collective operations inside the DCP helpers require
it), while only rank 0 performs actual file I/O.
"""
import io
import os
import re
import tarfile
import zipfile
from pathlib import Path, PurePath
from typing import Any
import fsspec
import fsspec.utils
import torch
from torch.amp import GradScaler
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
get_model_state_dict,
get_optimizer_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.tensor import DTensor, distribute_tensor
from torch.optim.lr_scheduler import LRScheduler
import physicsnemo
from physicsnemo.core.filesystem import LOCAL_CACHE, _download_cached
from physicsnemo.core.module import Module
from physicsnemo.distributed import DistributedManager
from physicsnemo.utils.capture import _StaticCapture
from physicsnemo.utils.logging import PythonLogger
# Module-level logger named "checkpoint"; the helpers in this file use it for
# their warning / error / success messages.
checkpoint_logging = PythonLogger("checkpoint")
# ---------------------------------------------------------------------------
# Distributed-model detection helpers
# ---------------------------------------------------------------------------
def _is_distributed_model(model: torch.nn.Module) -> bool:
"""Return ``True`` when *model* is FSDP-wrapped or has DTensor params."""
if isinstance(model, FSDP):
return True
return any(isinstance(p, DTensor) for p in model.parameters())
def _unwrap_ddp_compile(
    model: torch.nn.Module, loading: bool = False
) -> torch.nn.Module:
    """Peel DDP / DataParallel and ``torch.compile`` wrappers off *model*.

    FSDP wrappers are deliberately left in place.

    Parameters
    ----------
    model : torch.nn.Module
        Possibly wrapped module.
    loading : bool, optional
        When ``True`` and *model* turns out to be a compiled module, a
        warning is logged suggesting to load before compiling.

    Returns
    -------
    torch.nn.Module
        The module found underneath the stripped wrappers.
    """
    parallel_wrappers = (
        torch.nn.DataParallel,
        torch.nn.parallel.DistributedDataParallel,
    )
    unwrapped = model.module if isinstance(model, parallel_wrappers) else model

    if not isinstance(unwrapped, torch._dynamo.eval_frame.OptimizedModule):
        return unwrapped

    # ``torch.compile`` keeps the user's module under ``_orig_mod``.
    original = unwrapped._orig_mod
    if loading:
        checkpoint_logging.warning(
            f"Model {type(original).__name__} is already compiled, "
            "consider loading first and then compiling."
        )
    return original
def _unwrap_fsdp(model: torch.nn.Module) -> torch.nn.Module:
    """Return the module held by an FSDP wrapper, or *model* itself.

    Only a single FSDP layer is removed; anything that is not an FSDP
    instance is returned untouched.
    """
    return model.module if isinstance(model, FSDP) else model
def _cpu_offload_state_dict(state_dict: dict[str, Any]) -> dict[str, Any]:
    """Build a copy of *state_dict* in which every tensor lives on the CPU.

    Nested dictionaries are processed recursively; tensors are moved with
    ``Tensor.cpu()``; every other value is carried over as-is (not copied).

    Parameters
    ----------
    state_dict : dict[str, Any]
        Possibly nested mapping containing tensors and arbitrary values.

    Returns
    -------
    dict[str, Any]
        New dictionary with the same keys, in the same order.
    """

    def _offload(value: Any) -> Any:
        if isinstance(value, torch.Tensor):
            return value.cpu()
        if isinstance(value, dict):
            return _cpu_offload_state_dict(value)
        return value

    return {key: _offload(value) for key, value in state_dict.items()}
def _get_dtensor_param_placements(
model: torch.nn.Module,
) -> dict[str, tuple[Any, tuple[Any, ...]]]:
"""Map parameter names to ``(device_mesh, placements)`` for DTensor params.
Uses ``get_model_state_dict`` with native (non-full) format so that the
DCP layer unflattens FlatParameters back to original names and preserves
each parameter's DTensor placement. This works correctly for both
``use_orig_params=True`` and ``use_orig_params=False``.
**Collective** — all ranks must call this together.
"""
native_sd = get_model_state_dict(model, options=StateDictOptions())
info: dict[str, tuple[Any, tuple[Any, ...]]] = {}
for name, value in native_sd.items():
if isinstance(value, DTensor):
info[name] = (value.device_mesh, tuple(value.placements))
return info
def _has_non_fsdp_dtensors(
    model: torch.nn.Module,
    dtensor_plc: dict[str, tuple[Any, tuple[Any, ...]]],
) -> bool:
    """Decide whether *dtensor_plc* holds placements that FSDP does not manage.

    FSDP with ``FULL_SHARD`` or ``SHARD_GRAD_OP`` wraps parameters as
    DTensors on its own mesh, and ``broadcast_from_rank0`` handles those
    natively, so no manual redistribution is needed. Only user-created
    DTensors (e.g. ShardTensor on a separate domain mesh) need explicit
    redistribution.

    Parameters
    ----------
    model : torch.nn.Module
        The (possibly FSDP-wrapped) model the placements belong to.
    dtensor_plc : dict[str, tuple[Any, tuple[Any, ...]]]
        Placement map as produced by ``_get_dtensor_param_placements``.

    Returns
    -------
    bool
        ``False`` when *dtensor_plc* is empty, or when *model* is an FSDP
        wrapper using any strategy other than ``NO_SHARD``; ``True``
        otherwise.
    """
    if not dtensor_plc:
        return False
    # Short-circuit keeps ``sharding_strategy`` from being read on non-FSDP models.
    return (
        not isinstance(model, FSDP)
        or model.sharding_strategy == ShardingStrategy.NO_SHARD
    )
def _redistribute_sd_for_dtensor(
placements: dict[str, tuple[Any, tuple[Any, ...]]],
state_dict: dict[str, Any],
) -> dict[str, Any]:
"""Convert plain tensors in *state_dict* to DTensors matching *placements*.
Entries whose key appears in *placements* are converted via
``distribute_tensor`` so that each rank receives its correct local shard.
"""
target_device = next(iter(placements.values()))[0].device_type
out: dict[str, Any] = {}
for key, value in state_dict.items():
if not isinstance(value, torch.Tensor) or isinstance(value, DTensor):
out[key] = value
continue
if key in placements:
mesh, plc = placements[key]
out[key] = distribute_tensor(value.to(mesh.device_type), mesh, list(plc))
else:
out[key] = value.to(target_device)
return out
def _redistribute_optim_sd_for_dtensor(
placements: dict[str, tuple[Any, tuple[Any, ...]]],
optim_sd: dict[str, Any],
) -> dict[str, Any]:
"""Shard optimizer state tensors to local chunks matching model placements.
FSDP's ``optim_state_dict_to_load`` expects each optimizer state tensor
(``exp_avg``, ``exp_avg_sq``, …) to be a **plain tensor** whose shape
matches the parameter's *local* shape — not a DTensor. We use
``distribute_tensor(...).to_local()`` to extract each rank's shard.
Scalar state entries (e.g. ``step``) are left unchanged.
"""
if "state" not in optim_sd:
return optim_sd
target_device = next(iter(placements.values()))[0].device_type
new_state: dict[str, Any] = {}
for param_name, param_state in optim_sd["state"].items():
if not isinstance(param_state, dict):
new_state[param_name] = param_state
continue
new_ps: dict[str, Any] = {}
mesh_plc = placements.get(param_name)
for k, v in param_state.items():
if (
not isinstance(v, torch.Tensor)
or isinstance(v, DTensor)
or v.dim() == 0
):
new_ps[k] = v
elif mesh_plc is not None:
mesh, plc = mesh_plc
new_ps[k] = distribute_tensor(
v.to(mesh.device_type), mesh, list(plc)
).to_local()
else:
new_ps[k] = v.to(target_device)
new_state[param_name] = new_ps
return {**optim_sd, "state": new_state}
def _is_mdlus_archive(path: str) -> bool:
    """Check whether *path* is a ``.mdlus`` archive.

    An archive qualifies when it is a tar or zip file that has a member
    named ``model.pt``. The file actually inspected is whatever
    ``_cache_if_needed(path)`` returns (presumably a local copy for remote
    paths — confirm against that helper).

    Parameters
    ----------
    path : str
        Path or URI of the candidate file.

    Returns
    -------
    bool
        ``True`` for a tar/zip archive containing ``model.pt``, ``False``
        for anything else.
    """
    local_path = _cache_if_needed(path)

    # Tar is probed first; zip is only considered for non-tar files.
    if tarfile.is_tarfile(local_path):
        with tarfile.open(local_path, "r") as tar_archive:
            return "model.pt" in tar_archive.getnames()

    if not zipfile.is_zipfile(local_path):
        return False

    with zipfile.ZipFile(local_path, "r") as zip_archive:
        return "model.pt" in zip_archive.namelist()
def _extract_mdlus_state_dict(
    file_name: str, device: str | torch.device = "cpu"
) -> dict[str, Any]:
    """Load just the ``model.pt`` payload out of a ``.mdlus`` archive.

    Parameters
    ----------
    file_name : str
        Path or URI of the ``.mdlus`` archive.
    device : str | torch.device, optional
        ``map_location`` handed to :func:`torch.load`. By default ``"cpu"``.

    Returns
    -------
    dict[str, Any]
        The object deserialized from the archive's ``model.pt`` member.

    Notes
    -----
    The payload is read with ``weights_only=False``, i.e. it is fully
    unpickled; only open archives from trusted sources.
    """
    local_path = _cache_if_needed(file_name)

    # Anything that is not reported as "tar" is handled as a zip archive.
    if Module._detect_checkpoint_format(local_path) == "tar":
        with tarfile.open(local_path, "r") as tar_archive:
            payload = tar_archive.extractfile("model.pt").read()
    else:
        with zipfile.ZipFile(local_path, "r") as zip_archive:
            payload = zip_archive.read("model.pt")

    return torch.load(io.BytesIO(payload), map_location=device, weights_only=False)
def _get_checkpoint_filename(
    path: str,
    base_name: str = "checkpoint",
    index: int | None = None,
    saving: bool = False,
    model_type: str = "mdlus",
    distributed: bool = False,
) -> str:
    r"""Build the filename for a numbered checkpoint.

    Filenames have the form
    ``{path}/{base_name}.{model_parallel_rank}.{index}{ext}``.

    Resolution logic:

    * **Explicit index** (``index`` is not ``None``): returns that exact
      checkpoint path.
    * **Latest** (``index is None``, ``saving=False``): scans for existing
      checkpoints and returns the one with the largest index.
    * **Next** (``index is None``, ``saving=True``): returns the path for
      the *next* index after the largest existing one.

    When no existing checkpoints are found, the returned path uses index 0.
    Files that match the glob but do not carry a purely numeric index
    (e.g. ``checkpoint.0.best.pt``) are ignored; if only such files exist
    the result is the same as when nothing is found.

    If ``DistributedManager`` has not been initialized yet, it is
    initialized here (with a warning).

    Parameters
    ----------
    path : str
        Directory containing checkpoint files.
    base_name : str, optional
        Stem used in the filename, by default ``"checkpoint"``.
    index : int | None, optional
        Specific checkpoint index to use. When ``None``, the latest or
        next index is determined automatically.
    saving : bool, optional
        If ``True`` (and ``index is None``), return the *next* available
        filename rather than the latest existing one. By default ``False``.
    model_type : str, optional
        ``"mdlus"`` for :class:`~physicsnemo.core.Module` models,
        ``"pt"`` for vanilla PyTorch models. Determines the file
        extension. By default ``"mdlus"``.
    distributed : bool, optional
        When ``True`` the model_parallel_rank component of the filename is
        forced to ``0`` because FSDP/DTensor distribution is handled by the
        DCP APIs, not per-rank files. By default ``False``.

    Returns
    -------
    str
        Fully qualified checkpoint filename.
    """
    # Get model parallel rank so all processes in the first model parallel
    # group can save their checkpoint. Without a "model_parallel" group the
    # rank component is always 0.
    if not DistributedManager.is_initialized():
        checkpoint_logging.warning(
            "`DistributedManager` not initialized already. Initializing now, but this might lead to unexpected errors"
        )
        DistributedManager.initialize()
    manager = DistributedManager()
    if distributed or "model_parallel" not in manager.group_names:
        model_parallel_rank = 0
    else:
        model_parallel_rank = manager.group_rank("model_parallel")

    # pathlib does not support custom schemes (eg: msc://...) so only
    # resolve() to an absolute path for the local "file" protocol.
    protocol = fsspec.utils.get_protocol(path)
    fs = fsspec.filesystem(protocol)
    if protocol == "file":
        path = str(Path(path).resolve())
    prefix = f"{path}/{base_name}.{model_parallel_rank}"

    # File extension for PhysicsNeMo models or PyTorch models
    file_extension = ".mdlus" if model_type == "mdlus" else ".pt"

    # Explicit index: nothing to scan.
    if index is not None:
        return f"{prefix}.{index}{file_extension}"

    # The glob is deliberately loose; the regex then keeps only names of the
    # exact form "<base_name>.<rank>.<digits><ext>".
    index_pattern = re.compile(
        rf"^{re.escape(base_name)}\.{model_parallel_rank}\.(\d+){re.escape(file_extension)}$"
    )
    found_indices: list[int] = []
    for fname in fs.glob(prefix + "*" + file_extension):
        match = index_pattern.match(PurePath(fname).name)
        if match:
            found_indices.append(int(match.group(1)))

    if not found_indices:
        # Covers both "no files" and "only files with invalid names"; the
        # latter used to raise IndexError on ``file_idx[-1]``.
        return f"{prefix}.0{file_extension}"

    latest = max(found_indices)
    chosen = latest + 1 if saving else latest
    return f"{prefix}.{chosen}{file_extension}"
def _unique_model_names(
    models: list[torch.nn.Module],
    loading: bool = False,
) -> dict[str, torch.nn.Module]:
    r"""Assign each model a unique name based on its class name.

    DDP and ``torch.compile`` wrappers are removed, while FSDP wrappers are
    kept so the returned modules can still be handed to PyTorch's DCP
    state-dict helpers. The class name is taken from the module *inside*
    any FSDP wrapper.

    When several models share a class name, each gets a numeric suffix in
    input order (e.g. ``"MyModel0"``, ``"MyModel1"``); a class name that
    occurs once is used without a suffix.

    Parameters
    ----------
    models : list[torch.nn.Module]
        Models to generate names for.
    loading : bool, optional
        When ``True``, emits a warning if a model is already compiled
        (loading into a compiled model can cause issues). By default
        ``False``.

    Returns
    -------
    dict[str, torch.nn.Module]
        Mapping from unique name to module (with FSDP intact if present).
    """
    # Group the stripped models by class name, preserving input order.
    grouped: dict[str, list[torch.nn.Module]] = {}
    for candidate in models:
        stripped = _unwrap_ddp_compile(candidate, loading=loading)
        class_name = type(_unwrap_fsdp(stripped)).__name__
        grouped.setdefault(class_name, []).append(stripped)

    named: dict[str, torch.nn.Module] = {}
    for class_name, members in grouped.items():
        if len(members) == 1:
            named[class_name] = members[0]
            continue
        for position, member in enumerate(members):
            named[class_name + str(position)] = member
    return named
[docs]
def save_checkpoint(
    path: Path | str,
    models: torch.nn.Module | list[torch.nn.Module] | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    scheduler: LRScheduler | None = None,
    scaler: GradScaler | None = None,
    epoch: int | None = None,
    metadata: dict[str, Any] | None = None,
    optimizer_model: torch.nn.Module | None = None,
) -> None:
    r"""Save a training checkpoint to disk (or a remote store).

    Up to two categories of files are created inside ``path``:

    * **Model weights** (when ``models`` is provided) - one file per model:
      ``{class_name}{id}.{mp_rank}.{epoch}.{ext}`` where *ext* is
      ``.mdlus`` for :class:`~physicsnemo.core.Module` instances or
      ``.pt`` for plain PyTorch models. When several models share a class
      name, a numeric *id* is appended (``"MyModel0"``, ``"MyModel1"``).
    * **Training state** (when any of ``optimizer`` / ``scheduler`` /
      ``scaler`` / ``epoch`` / ``metadata`` is provided, or
      :class:`~physicsnemo.utils.capture._StaticCapture` scalers exist):
      ``checkpoint.{mp_rank}.{epoch}.pt`` containing their combined
      ``state_dict`` entries, plus ``epoch`` and ``metadata``.

    When any model is FSDP-wrapped or contains DTensor/ShardTensor
    parameters the function enters *distributed* mode: all ranks **must**
    call it, state is gathered via DCP collective helpers, and only rank 0
    writes files.

    Use :func:`load_checkpoint` to restore from these files.
    To instantiate *and* load a model in one step (without pre-constructing
    it), use :meth:`~physicsnemo.core.module.Module.from_checkpoint`.

    Parameters
    ----------
    path : Path | str
        Directory in which to store checkpoint files. Created
        automatically for local paths if it does not exist.
    models : torch.nn.Module | list[torch.nn.Module] | None, optional
        Model(s) whose weights should be saved.
    optimizer : torch.optim.Optimizer | None, optional
        Optimizer whose ``state_dict`` should be saved.
    scheduler : LRScheduler | None, optional
        Learning-rate scheduler whose ``state_dict`` should be saved.
    scaler : GradScaler | None, optional
        AMP gradient scaler whose ``state_dict`` should be saved.
        Independently of this argument, the state of any existing
        :class:`~physicsnemo.utils.capture._StaticCapture` scalers is
        saved under its own key.
    epoch : int | None, optional
        Epoch index to embed in the filename and the checkpoint dict
        (``0`` is a valid epoch). When ``None``, the next available index
        is used for the filename and no epoch is stored.
    metadata : dict[str, Any] | None, optional
        Arbitrary key-value pairs persisted alongside the training state
        (e.g. best validation loss, MLflow run ID).
    optimizer_model : torch.nn.Module | None, optional
        The model whose parameters the ``optimizer`` is tracking so that
        parameter unsharding of optimizer state can be performed correctly.
        Only required when multiple models are provided, and at least one of
        them is a distributed model (FSDP/ShardTensor). When ``None``, the
        first *distributed* model in ``models`` is used. Ignored when *not*
        in distributed mode.

    Examples
    --------
    Save a model together with optimizer and scheduler state:

    >>> import tempfile, os, torch
    >>> from physicsnemo.utils.checkpoint import save_checkpoint
    >>> from physicsnemo.models.mlp import FullyConnected
    >>> model = FullyConnected(in_features=32, out_features=64)
    >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    >>> scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    >>> with tempfile.TemporaryDirectory() as tmpdir:
    ...     save_checkpoint(tmpdir, models=model, optimizer=optimizer,
    ...                     scheduler=scheduler, epoch=1)
    ...     sorted(f for f in os.listdir(tmpdir))
    ['FullyConnected.0.1.mdlus', 'checkpoint.0.1.pt']

    Save at multiple epochs with additional metadata:

    >>> with tempfile.TemporaryDirectory() as tmpdir:
    ...     save_checkpoint(tmpdir, models=model, optimizer=optimizer, epoch=1,
    ...                     metadata={"loss": 0.42, "experiment": "run_01"})
    ...     save_checkpoint(tmpdir, models=model, optimizer=optimizer, epoch=2,
    ...                     metadata={"loss": 0.31, "experiment": "run_01"})
    ...     sorted(f for f in os.listdir(tmpdir))
    ['FullyConnected.0.1.mdlus', 'FullyConnected.0.2.mdlus', 'checkpoint.0.1.pt', 'checkpoint.0.2.pt']
    """
    path = str(path)
    protocol = fsspec.utils.get_protocol(path)
    fs = fsspec.filesystem(protocol)

    # Prepare models and detect distributed mode.
    # A single module is wrapped *before* any truthiness test: container
    # modules define ``__len__`` and an empty one would otherwise be skipped.
    named_models: dict[str, torch.nn.Module] = {}
    is_distributed = False
    if isinstance(models, torch.nn.Module):
        models = [models]
    if models:
        named_models = _unique_model_names(list(models))
        is_distributed = any(_is_distributed_model(m) for m in named_models.values())

    if not DistributedManager.is_initialized():
        checkpoint_logging.warning(
            "`DistributedManager` not initialized already. "
            "Initializing now, but this might lead to unexpected errors"
        )
        DistributedManager.initialize()
    manager = DistributedManager()
    is_rank0 = manager.rank == 0
    # In distributed mode only rank 0 touches the filesystem.
    should_write = is_rank0 if is_distributed else True

    # Create checkpoint directory (only writing rank)
    if should_write and protocol == "file" and not Path(path).is_dir():
        checkpoint_logging.warning(
            f"Output directory {path} does not exist, will attempt to create"
        )
        Path(path).mkdir(parents=True, exist_ok=True)
    if is_distributed:
        # Keep the other ranks from racing ahead of directory creation.
        torch.distributed.barrier()

    # == Saving model checkpoint ==
    for name, model in named_models.items():
        inner = _unwrap_fsdp(model)
        model_type = "mdlus" if isinstance(inner, physicsnemo.core.Module) else "pt"
        file_name = _get_checkpoint_filename(
            path,
            name,
            index=epoch,
            saving=True,
            model_type=model_type,
            distributed=is_distributed,
        )
        if _is_distributed_model(model):
            # cpu_offload is handled manually because the DCP option
            # hangs for FSDP NO_SHARD + DTensor topologies.
            # The gather itself is collective, so it runs on every rank.
            options = StateDictOptions(full_state_dict=True)
            state_dict = get_model_state_dict(model, options=options)
            if should_write:
                state_dict = _cpu_offload_state_dict(state_dict)
                if isinstance(inner, physicsnemo.core.Module):
                    inner.save(file_name, _state_dict=state_dict)
                else:
                    with fs.open(file_name, "wb") as fp:
                        torch.save(state_dict, fp)
                checkpoint_logging.success(f"Saved model state dictionary: {file_name}")
        else:
            if should_write:
                if isinstance(inner, physicsnemo.core.Module):
                    inner.save(file_name)
                else:
                    with fs.open(file_name, "wb") as fp:
                        torch.save(model.state_dict(), fp)
                checkpoint_logging.success(f"Saved model state dictionary: {file_name}")

    # == Saving training checkpoint ==
    checkpoint_dict: dict[str, Any] = {}
    if optimizer:
        if is_distributed:
            # Explicit ``is None`` test: a module's truthiness follows
            # ``__len__`` for container modules.
            opt_model = (
                optimizer_model
                if optimizer_model is not None
                else next(
                    (m for m in named_models.values() if _is_distributed_model(m)),
                    None,
                )
            )
            if opt_model is not None:
                # cpu_offload is handled manually because the DCP option
                # hangs for FSDP NO_SHARD + DTensor topologies.
                options = StateDictOptions(full_state_dict=True)
                opt_state_dict = get_optimizer_state_dict(
                    opt_model, optimizer, options=options
                )
                if should_write:
                    opt_state_dict = _cpu_offload_state_dict(opt_state_dict)
            else:
                opt_state_dict = optimizer.state_dict()
        else:
            opt_state_dict = optimizer.state_dict()

        # Strip out torch dynamo wrapper prefix
        for pg in opt_state_dict.get("param_groups", []):
            param_names = pg.get("param_names")
            if param_names is None:
                continue
            pg["param_names"] = [pn.removeprefix("_orig_mod.") for pn in param_names]
        checkpoint_dict["optimizer_state_dict"] = opt_state_dict

    if scheduler:
        checkpoint_dict["scheduler_state_dict"] = scheduler.state_dict()
    if scaler:
        checkpoint_dict["scaler_state_dict"] = scaler.state_dict()
    if _StaticCapture._amp_scalers:
        checkpoint_dict["static_capture_state_dict"] = _StaticCapture.state_dict()

    output_filename = _get_checkpoint_filename(
        path,
        index=epoch,
        saving=True,
        model_type="pt",
        distributed=is_distributed,
    )

    # ``is not None`` rather than truthiness so that epoch 0 is recorded too.
    if epoch is not None:
        checkpoint_dict["epoch"] = epoch
    if metadata:
        checkpoint_dict["metadata"] = metadata

    # Nothing collected means no training-state file is written.
    if bool(checkpoint_dict) and should_write:
        with fs.open(output_filename, "wb") as fp:
            torch.save(checkpoint_dict, fp)
        checkpoint_logging.success(f"Saved training checkpoint: {output_filename}")
[docs]
def load_checkpoint(
    path: Path | str,
    models: torch.nn.Module | list[torch.nn.Module] | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    scheduler: LRScheduler | None = None,
    scaler: GradScaler | None = None,
    epoch: int | None = None,
    metadata_dict: dict[str, Any] | None = None,
    device: str | torch.device = "cpu",
    optimizer_model: torch.nn.Module | None = None,
) -> int:
    r"""Load a training checkpoint saved by :func:`save_checkpoint`.

    Scans ``path`` for checkpoint files and restores state dictionaries
    into the provided training objects. Objects that are ``None`` are
    silently skipped.

    When any model is FSDP-wrapped or contains DTensor/ShardTensor
    parameters the function enters *distributed* mode: all ranks **must**
    call it, rank 0 reads files from disk, and model/optimizer state is
    scattered to all ranks via DCP helpers.

    Checkpoint files are deserialized with ``torch.load(...,
    weights_only=False)``, i.e. they are fully unpickled; only load
    checkpoints from trusted sources.

    Parameters
    ----------
    path : Path | str
        Directory containing checkpoint files (local path or ``fsspec``
        URI). If the directory does not exist, the load is skipped and
        ``0`` is returned.
    models : torch.nn.Module | list[torch.nn.Module] | None, optional
        Model(s) whose ``state_dict`` should be restored. DDP and
        ``torch.compile`` wrappers are stripped automatically.
    optimizer : torch.optim.Optimizer | None, optional
        Optimizer whose ``state_dict`` should be restored.
    scheduler : LRScheduler | None, optional
        Learning-rate scheduler whose ``state_dict`` should be restored.
    scaler : GradScaler | None, optional
        AMP gradient scaler whose ``state_dict`` should be restored.
    epoch : int | None, optional
        Specific checkpoint index to load. When ``None``, the checkpoint
        with the largest index (most recent) is loaded.
    metadata_dict : dict[str, Any] | None, optional
        If a ``dict`` is provided, it is updated **in-place** with any
        metadata that was persisted by :func:`save_checkpoint`.
    device : str | torch.device, optional
        Device onto which tensors are mapped during loading. By default
        ``"cpu"``.
    optimizer_model : torch.nn.Module | None, optional
        The model whose parameters the ``optimizer`` is tracking.
        Required by the DCP ``set_optimizer_state_dict`` helper when
        distributed mode is active. When ``None``, the first model in
        ``models`` is used. Ignored when *not* in distributed mode.

    Returns
    -------
    int
        The epoch stored in the checkpoint. Returns ``0`` when:

        * The checkpoint directory does not exist.
        * No training-state file is found inside the directory.
        * The training-state file does not contain an ``"epoch"`` key.

    Raises
    ------
    FileNotFoundError
        In the non-distributed path, when ``path`` exists but is a file
        rather than a directory.

    Examples
    --------
    Save and then restore a model, optimizer, and scheduler from a checkpoint:

    >>> import tempfile, torch
    >>> from physicsnemo.utils.checkpoint import save_checkpoint, load_checkpoint
    >>> from physicsnemo.models.mlp import FullyConnected
    >>> model = FullyConnected(in_features=32, out_features=64)
    >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    >>> scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    >>> with tempfile.TemporaryDirectory() as tmpdir:
    ...     save_checkpoint(tmpdir, models=model, optimizer=optimizer,
    ...                     scheduler=scheduler, epoch=1)
    ...     epoch = load_checkpoint(tmpdir, models=model, optimizer=optimizer,
    ...                             scheduler=scheduler)
    ...     epoch
    1

    Load a specific epoch and retrieve saved metadata:

    >>> with tempfile.TemporaryDirectory() as tmpdir:
    ...     save_checkpoint(tmpdir, models=model, optimizer=optimizer, epoch=1,
    ...                     metadata={"loss": 0.42, "experiment": "run_01"})
    ...     save_checkpoint(tmpdir, models=model, optimizer=optimizer, epoch=2,
    ...                     metadata={"loss": 0.31, "experiment": "run_01"})
    ...     meta = {}
    ...     epoch = load_checkpoint(tmpdir, models=model, optimizer=optimizer,
    ...                             epoch=1, metadata_dict=meta)
    ...     epoch
    1
    >>> meta["loss"]
    0.42
    """
    path = str(path)
    fs = fsspec.filesystem(fsspec.utils.get_protocol(path))

    # Prepare models and detect distributed mode.
    # A single module is wrapped *before* any truthiness test: container
    # modules define ``__len__`` and an empty one would otherwise be skipped.
    named_models: dict[str, torch.nn.Module] = {}
    is_distributed = False
    if isinstance(models, torch.nn.Module):
        models = [models]
    if models:
        named_models = _unique_model_names(list(models), loading=True)
        is_distributed = any(_is_distributed_model(m) for m in named_models.values())

    if not DistributedManager.is_initialized():
        checkpoint_logging.warning(
            "`DistributedManager` not initialized already. "
            "Initializing now, but this might lead to unexpected errors"
        )
        DistributedManager.initialize()
    manager = DistributedManager()
    is_rank0 = manager.rank == 0

    # ------------------------------------------------------------------
    # Distributed load path -- all ranks participate
    # ------------------------------------------------------------------
    if is_distributed:
        return _load_checkpoint_distributed(
            path=path,
            fs=fs,
            named_models=named_models,
            optimizer=optimizer,
            scheduler=scheduler,
            scaler=scaler,
            epoch=epoch,
            metadata_dict=metadata_dict,
            device=device,
            optimizer_model=optimizer_model,
            is_rank0=is_rank0,
        )

    # ------------------------------------------------------------------
    # Non-distributed load path
    # ------------------------------------------------------------------
    if fs.exists(path):
        if fs.isfile(path):
            raise FileNotFoundError(
                f"Provided checkpoint directory {path} is a file, not directory"
            )
    else:
        checkpoint_logging.warning(
            f"Provided checkpoint directory {path} does not exist, skipping load"
        )
        return 0

    # == Loading model checkpoint ==
    for name, model in named_models.items():
        inner = _unwrap_fsdp(model)
        model_type = "mdlus" if isinstance(inner, physicsnemo.core.Module) else "pt"
        file_name = _get_checkpoint_filename(
            path, name, index=epoch, model_type=model_type
        )
        if not fs.exists(file_name):
            # A missing model file is logged and skipped, not fatal.
            checkpoint_logging.error(
                f"Could not find valid model file {file_name}, skipping load"
            )
            continue
        if isinstance(inner, physicsnemo.core.Module):
            inner.load(file_name)
        else:
            file_to_load = _cache_if_needed(file_name)
            # NOTE: weights_only=False unpickles arbitrary objects.
            missing_keys, unexpected_keys = model.load_state_dict(
                torch.load(file_to_load, map_location=device, weights_only=False)
            )
            if missing_keys:
                checkpoint_logging.warning(
                    f"Missing keys when loading {name}: {missing_keys}"
                )
            if unexpected_keys:
                checkpoint_logging.warning(
                    f"Unexpected keys when loading {name}: {unexpected_keys}"
                )
        checkpoint_logging.success(
            f"Loaded model state dictionary {file_name} to device {device}"
        )

    # == Loading training checkpoint ==
    checkpoint_filename = _get_checkpoint_filename(path, index=epoch, model_type="pt")
    if not fs.exists(checkpoint_filename):
        checkpoint_logging.warning(
            "Could not find valid checkpoint file, skipping load"
        )
        return 0

    file_to_load = _cache_if_needed(checkpoint_filename)
    # NOTE: weights_only=False unpickles arbitrary objects.
    checkpoint_dict = torch.load(file_to_load, map_location=device, weights_only=False)
    checkpoint_logging.success(
        f"Loaded checkpoint file {checkpoint_filename} to device {device}"
    )

    # Each piece of state is restored only if both the target object was
    # supplied and the checkpoint actually contains the matching entry.
    if optimizer and "optimizer_state_dict" in checkpoint_dict:
        optimizer.load_state_dict(checkpoint_dict["optimizer_state_dict"])
        checkpoint_logging.success("Loaded optimizer state dictionary")
    if scheduler and "scheduler_state_dict" in checkpoint_dict:
        scheduler.load_state_dict(checkpoint_dict["scheduler_state_dict"])
        checkpoint_logging.success("Loaded scheduler state dictionary")
    if scaler and "scaler_state_dict" in checkpoint_dict:
        scaler.load_state_dict(checkpoint_dict["scaler_state_dict"])
        checkpoint_logging.success("Loaded grad scaler state dictionary")
    if "static_capture_state_dict" in checkpoint_dict:
        _StaticCapture.load_state_dict(checkpoint_dict["static_capture_state_dict"])
        checkpoint_logging.success("Loaded static capture state dictionary")

    loaded_epoch = 0
    if "epoch" in checkpoint_dict:
        loaded_epoch = checkpoint_dict["epoch"]
    if metadata_dict is not None:
        metadata_dict.update(checkpoint_dict.get("metadata", {}))
    return loaded_epoch
[docs]
def load_model_weights(
    model: torch.nn.Module,
    weights_path: str,
    device: str | torch.device = "cpu",
) -> None:
    r"""Load model weights from a single checkpoint file.

    Reads one ``.mdlus`` (or ``.pt``) file straight into *model*, taking
    care of FSDP and DTensor/ShardTensor distribution. In contrast to
    :func:`load_checkpoint`, which works on a checkpoint *directory* of
    numbered files, this function takes the path of a single file.

    For FSDP-wrapped models or models with DTensor parameters this is a
    **collective** operation — every rank must call it. Rank 0 reads the
    file and the state is scattered via DCP helpers.

    Files are deserialized with ``weights_only=False`` (full unpickling);
    only load weights from trusted sources.

    Parameters
    ----------
    model : torch.nn.Module
        The model to load weights into. May be FSDP-wrapped, contain
        DTensor/ShardTensor parameters, or be a plain module. DDP and
        ``torch.compile`` wrappers are stripped first.
    weights_path : str
        Path to a ``.mdlus`` or ``.pt`` checkpoint file (local path or
        ``fsspec`` URI).
    device : str | torch.device, optional
        Device for :func:`torch.load` ``map_location``. By default
        ``"cpu"``.
    """
    model = _unwrap_ddp_compile(model, loading=True)
    is_mdlus = _is_mdlus_archive(weights_path)

    # ---- Plain (non-distributed) model: load locally and return ----------
    if not _is_distributed_model(model):
        inner = _unwrap_fsdp(model)
        if is_mdlus and isinstance(inner, physicsnemo.core.Module):
            inner.load(weights_path)
        else:
            local_copy = _cache_if_needed(weights_path)
            weights = (
                _extract_mdlus_state_dict(weights_path, device)
                if is_mdlus
                else torch.load(local_copy, map_location=device, weights_only=False)
            )
            inner.load_state_dict(weights)
        checkpoint_logging.success(f"Loaded model weights from {weights_path}")
        return

    # ---- Distributed model: rank 0 reads, every rank participates --------
    if not DistributedManager.is_initialized():
        DistributedManager.initialize()
    is_rank0 = DistributedManager().rank == 0

    # Non-zero ranks start from an empty dict and receive data below.
    state_dict: dict[str, Any] = {}
    if is_rank0:
        if is_mdlus:
            state_dict = _extract_mdlus_state_dict(weights_path, device)
        else:
            local_copy = _cache_if_needed(weights_path)
            state_dict = torch.load(
                local_copy, map_location=device, weights_only=False
            )

    dtensor_plc = _get_dtensor_param_placements(model)
    needs_manual_redistribution = _has_non_fsdp_dtensors(model, dtensor_plc)
    if needs_manual_redistribution:
        # Ship the full dict from rank 0 to everyone, then shard it locally
        # according to each parameter's placement.
        payload: list[Any] = [state_dict]
        torch.distributed.broadcast_object_list(payload, src=0)
        state_dict = _redistribute_sd_for_dtensor(dtensor_plc, payload[0])
        options = StateDictOptions(full_state_dict=True)
    else:
        # Let the DCP helper broadcast from rank 0 on its own.
        options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True)

    set_model_state_dict(model, state_dict, options=options)
    checkpoint_logging.success(f"Loaded model weights from {weights_path}")
# ------------------------------------------------------------------
# Distributed load implementation
# ------------------------------------------------------------------
def _load_checkpoint_distributed(
    *,
    path: str,
    fs: fsspec.AbstractFileSystem,
    named_models: dict[str, torch.nn.Module],
    optimizer: torch.optim.Optimizer | None,
    scheduler: LRScheduler | None,
    scaler: GradScaler | None,
    epoch: int | None,
    metadata_dict: dict[str, Any] | None,
    device: str | torch.device,
    optimizer_model: torch.nn.Module | None,
    is_rank0: bool,
) -> int:
    r"""Distributed load: rank 0 reads files, state is shared with all ranks.

    This is a **collective** operation: every rank must call it with the
    same set of models and the same optimizer/scheduler/scaler presence,
    because the body issues ``torch.distributed.broadcast_object_list`` calls
    and DCP ``set_*_state_dict`` calls in a fixed order. Only rank 0 touches
    the filesystem.

    Parameters
    ----------
    path : str
        Checkpoint directory. If rank 0 finds that it does not exist (or is
        a regular file), every rank logs a warning and returns ``0``.
    fs : fsspec.AbstractFileSystem
        Filesystem used by rank 0 for the existence checks.
    named_models : dict[str, torch.nn.Module]
        Models to restore, keyed by the name used to build their checkpoint
        file names.
    optimizer : torch.optim.Optimizer | None
        Optimizer to restore, or ``None`` to skip optimizer state.
    scheduler : LRScheduler | None
        Scheduler to restore from ``"scheduler_state_dict"``, if present.
    scaler : GradScaler | None
        Grad scaler to restore from ``"scaler_state_dict"``, if present.
    epoch : int | None
        Forwarded as ``index`` to ``_get_checkpoint_filename``.
    metadata_dict : dict[str, Any] | None
        If not ``None``, updated in place with the checkpoint's
        ``"metadata"`` entry.
    device : str | torch.device
        ``map_location`` used by rank 0 when reading files.
    optimizer_model : torch.nn.Module | None
        Model the optimizer belongs to. When falsy, the first distributed
        model in ``named_models`` is used.
    is_rank0 : bool
        Whether the calling process is rank 0.

    Returns
    -------
    int
        The ``"epoch"`` value stored in the training checkpoint, or ``0``
        when the directory is missing or no epoch was stored.
    """
    broadcast_options = StateDictOptions(
        full_state_dict=True, broadcast_from_rank0=True
    )
    full_options = StateDictOptions(full_state_dict=True)

    # --- Rank 0 checks directory existence and loads raw data -----------
    dir_exists = (fs.exists(path) and not fs.isfile(path)) if is_rank0 else None
    flags: list[Any] = [dir_exists]
    torch.distributed.broadcast_object_list(flags, src=0)
    dir_exists = flags[0]
    if not dir_exists:
        checkpoint_logging.warning(
            f"Provided checkpoint directory {path} does not exist, skipping load"
        )
        return 0

    # --- Load model checkpoints -----------------------------------------
    # Rank 0: determine which model files exist and load their state dicts
    model_file_info: dict[str, str | None] = {}
    model_state_dicts: dict[str, dict[str, Any]] = {}
    if is_rank0:
        for name, model in named_models.items():
            inner = _unwrap_fsdp(model)
            is_mdlus = isinstance(inner, physicsnemo.core.Module)
            file_name = _get_checkpoint_filename(
                path,
                name,
                index=epoch,
                model_type="mdlus" if is_mdlus else "pt",
                distributed=True,
            )
            if not fs.exists(file_name):
                model_file_info[name] = None
                continue
            model_file_info[name] = file_name
            if is_mdlus:
                model_state_dicts[name] = _extract_mdlus_state_dict(file_name, device)
            else:
                file_to_load = _cache_if_needed(file_name)
                # NOTE: weights_only=False unpickles arbitrary objects; only
                # load checkpoints that come from a trusted source.
                model_state_dicts[name] = torch.load(
                    file_to_load,
                    map_location=device,
                    weights_only=False,
                )

    # Broadcast which model files were found so that every rank takes the
    # same branch (and therefore issues the same collectives) below.
    info_list: list[Any] = [model_file_info]
    torch.distributed.broadcast_object_list(info_list, src=0)
    model_file_info = info_list[0]

    # Distribute model state dicts via DCP
    for name, model in named_models.items():
        if model_file_info.get(name) is None:
            checkpoint_logging.error(
                f"Could not find valid model file for {name}, skipping load"
            )
            continue
        if _is_distributed_model(model):
            # Collective: inspect native state dict for DTensor placements.
            # This is needed because use_orig_params=False flattens DTensor
            # params into a plain FlatParameter, hiding them from inspection.
            dtensor_plc = _get_dtensor_param_placements(model)
            if _has_non_fsdp_dtensors(model, dtensor_plc):
                # broadcast_from_rank0 does not handle user-managed DTensor
                # redistribution (e.g. ShardTensor on a domain mesh), so we
                # broadcast the full state dict ourselves and convert entries
                # to DTensors.
                sd_list: list[Any] = [
                    model_state_dicts.get(name, {}) if is_rank0 else {}
                ]
                torch.distributed.broadcast_object_list(sd_list, src=0)
                sd = _redistribute_sd_for_dtensor(dtensor_plc, sd_list[0])
                set_model_state_dict(model, sd, options=full_options)
            else:
                # FSDP-managed DTensors (FULL_SHARD/SHARD_GRAD_OP) or no
                # DTensors at all — broadcast_from_rank0 handles both.
                sd = model_state_dicts.get(name, {}) if is_rank0 else {}
                set_model_state_dict(model, sd, options=broadcast_options)
        else:
            # A mix of distributed and non-distributed models is valid
            # (e.g. a main FSDP model alongside a small auxiliary model).
            sd_list = [model_state_dicts.get(name, {}) if is_rank0 else {}]
            torch.distributed.broadcast_object_list(sd_list, src=0)
            inner = _unwrap_fsdp(model)
            inner.load_state_dict(sd_list[0])
        file_name = model_file_info[name]
        checkpoint_logging.success(
            f"Loaded model state dictionary {file_name} to device {device}"
        )

    # --- Load training checkpoint ---------------------------------------
    checkpoint_filename = _get_checkpoint_filename(
        path, index=epoch, model_type="pt", distributed=True
    )
    checkpoint_dict: dict[str, Any] = {}
    if is_rank0 and fs.exists(checkpoint_filename):
        file_to_load = _cache_if_needed(checkpoint_filename)
        # NOTE: weights_only=False — see the trust caveat above.
        checkpoint_dict = torch.load(
            file_to_load, map_location=device, weights_only=False
        )
        checkpoint_logging.success(
            f"Loaded checkpoint file {checkpoint_filename} to device {device}"
        )

    # Optimizer state via DCP (collective)
    if optimizer:
        opt_model = optimizer_model or next(
            (m for m in named_models.values() if _is_distributed_model(m)),
            None,
        )
        # Only rank 0 has read the file; other ranks start from an empty dict.
        optim_sd = checkpoint_dict.get("optimizer_state_dict", {}) if is_rank0 else {}
        if opt_model is not None and _is_distributed_model(opt_model):
            dtensor_plc = _get_dtensor_param_placements(opt_model)
            if _has_non_fsdp_dtensors(opt_model, dtensor_plc):
                osd_list: list[Any] = [optim_sd]
                torch.distributed.broadcast_object_list(osd_list, src=0)
                optim_sd = _redistribute_optim_sd_for_dtensor(dtensor_plc, osd_list[0])
                set_optimizer_state_dict(
                    opt_model, optimizer, optim_sd, options=full_options
                )
            else:
                set_optimizer_state_dict(
                    opt_model, optimizer, optim_sd, options=broadcast_options
                )
            checkpoint_logging.success("Loaded optimizer state dictionary")
        else:
            # The optimizer belongs to a non-distributed model, so there is
            # no DCP helper to scatter its state. Broadcast rank 0's copy
            # explicitly (mirroring the non-distributed model path above);
            # otherwise only rank 0 would restore the optimizer and the
            # other ranks would silently continue with fresh state.
            osd_list = [optim_sd]
            torch.distributed.broadcast_object_list(osd_list, src=0)
            optim_sd = osd_list[0]
            if optim_sd:
                optimizer.load_state_dict(optim_sd)
                checkpoint_logging.success("Loaded optimizer state dictionary")

    # Broadcast remaining training state (scheduler, scaler, epoch, metadata)
    rest: dict[str, Any] = {}
    if is_rank0:
        rest = {k: v for k, v in checkpoint_dict.items() if k != "optimizer_state_dict"}
    rest_list: list[Any] = [rest]
    torch.distributed.broadcast_object_list(rest_list, src=0)
    rest = rest_list[0]

    if scheduler and "scheduler_state_dict" in rest:
        scheduler.load_state_dict(rest["scheduler_state_dict"])
        checkpoint_logging.success("Loaded scheduler state dictionary")
    if scaler and "scaler_state_dict" in rest:
        scaler.load_state_dict(rest["scaler_state_dict"])
        checkpoint_logging.success("Loaded grad scaler state dictionary")
    if "static_capture_state_dict" in rest:
        _StaticCapture.load_state_dict(rest["static_capture_state_dict"])
        checkpoint_logging.success("Loaded static capture state dictionary")

    loaded_epoch = rest.get("epoch", 0)
    if metadata_dict is not None:
        # Caller-supplied dict is filled in place with any stored metadata.
        metadata_dict.update(rest.get("metadata", {}))
    return loaded_epoch
[docs]
def get_checkpoint_dir(base_dir: Path | str, model_name: str) -> str:
    r"""Return the checkpoint directory for ``model_name`` under ``base_dir``.

    The result is ``"{base_dir}/checkpoints_{model_name}"``. ``msc://`` URIs
    are joined with a literal ``"/"``; every other path is joined with
    :func:`os.path.join`.

    Parameters
    ----------
    base_dir : Path | str
        Root location that will contain the checkpoint subdirectory.
    model_name : str
        Name appended to the ``checkpoints_`` prefix.

    Returns
    -------
    str
        Path of the model-specific checkpoint directory.
    """
    root = str(base_dir)
    subdir = f"checkpoints_{model_name}"

    # Anything that is not an msc:// URI goes through the OS path join.
    if fsspec.utils.get_protocol(root) != "msc":
        return os.path.join(root, subdir)

    # msc:// URIs always use forward slashes; add one only if it is missing.
    separator = "" if root.endswith("/") else "/"
    return f"{root}{separator}{subdir}"
def _cache_if_needed(path: str) -> str:
    r"""Resolve ``path`` to a file on the local filesystem.

    Paths whose protocol is ``"file"`` are returned as given. Any other
    protocol is passed to ``_download_cached`` with a cache directory under
    ``LOCAL_CACHE`` whose name includes the current process id.

    Parameters
    ----------
    path : str
        Checkpoint file location (local path or ``fsspec`` URI).

    Returns
    -------
    str
        ``path`` itself for local files, otherwise the value returned by
        ``_download_cached``.
    """
    if fsspec.utils.get_protocol(path) == "file":
        return path

    # One cache directory per process id.
    cache_dir = os.path.join(LOCAL_CACHE, f"checkpoint_pid_{os.getpid()}")
    return _download_cached(path, recursive=False, local_cache_path=cache_dir)