nemo_automodel.components.checkpoint.checkpointing
Module Contents
Classes
Functions
Data
_CONSOLIDATED_SIZE_WARNING_THRESHOLD_BYTES
_DEFAULT_HF_CONSOLIDATED_SHARD_SIZE_BYTES
_MODELS_REQUIRING_BUFFER_REINIT
API
High-level checkpoint manager built on torch.distributed.checkpoint (DCP).
Supports:
- HF sharded safetensors via custom storage reader/writer
- Optional consolidated export (config, generation config, tokenizer)
- PEFT adapter save/load handling
- Async save for torch >= 2.9.0
Also provides DP-aware helpers for saving/loading auxiliary state and utilities to initialize from a base HF checkpoint.
Load a state dictionary from path using DCP or PEFT special-case logic.
Parameters:
Mutable state dict to populate with tensors.
Checkpoint directory path.
Optional HF storage reader for safetensors.
True if loading from a base checkpoint during initialization.
Returns: dict[str, torch.Tensor]
The populated state dictionary (may be replaced for PEFT).
Save a state dictionary to path using DCP or PEFT special-case logic.
- For PEFT model saves: only rank 0 writes adapter_model.safetensors.
- If async mode is enabled, schedule an asynchronous save.
Parameters:
State dict to be serialized.
Checkpoint directory path.
Optional HF storage writer for safetensors sharding.
Returns: Optional[AsyncSaveResponse]
Optional Future object if async mode is enabled.
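The save flow described above has two rules: PEFT saves write on rank 0 only, and async mode returns a handle that the caller waits on later. The following is a minimal standard-library sketch of those rules, not the library's implementation. `save_state`, `rank`, `is_peft`, `executor` and the JSON serialization are hypothetical stand-ins, and a `concurrent.futures.Future` plays the role of the async response.

```python
import json
import os
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional


def _write_state(state: dict, path: str) -> str:
    # JSON stands in for safetensors/DCP serialization in this sketch.
    os.makedirs(path, exist_ok=True)
    out = os.path.join(path, "state.json")
    with open(out, "w") as f:
        json.dump(state, f)
    return out


def save_state(
    state: dict,
    path: str,
    *,
    rank: int,
    is_peft: bool,
    executor: Optional[ThreadPoolExecutor] = None,
) -> Optional[Future]:
    """Hypothetical save helper (not the library's API).

    PEFT saves are written by rank 0 only. When an executor is supplied,
    the write runs in the background and a Future is returned so the
    caller can wait on it later. Otherwise the write is synchronous.
    """
    if is_peft and rank != 0:
        return None  # other ranks skip the adapter write entirely
    if executor is not None:
        return executor.submit(_write_state, state, path)
    _write_state(state, path)
    return None
```

In a training loop, the caller would keep the returned Future and call `.result()` on it before starting the next save. This is the same reason a "wait for the async save to finish" step exists.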
Get the path to the original model from the Hugging Face checkpoint.
Construct a Hugging Face storage reader when loading safetensors or during init.
Prefers the upstream torch.distributed.checkpoint.hf_storage.HuggingFaceStorageReader
when no key_mapping is needed, since it uses safetensors’ native get_slice() for
efficient partial reads (only the bytes for the local DTensor shard are read from disk).
Falls back to the backported reader when key_mapping is required or when the upstream
reader is not available.
Parameters:
Path to the model checkpoint directory or HF snapshot.
Optional key remapping for conversion.
If True, always produce a reader for base HF load.
Whether model_path holds a safetensors checkpoint; computed
from the directory contents when not supplied.
Returns: Optional[_HuggingFaceStorageReader]
Configured storage reader, or None for the default DCP FileSystemReader.
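The reader-selection rules above can be written as a small decision function. This sketch only mirrors the docstring. The string labels `"upstream"` and `"backport"` and the `upstream_available` flag are hypothetical, and the real function constructs reader objects instead of returning labels.

```python
from typing import Optional


def choose_reader(
    is_safetensors: bool,
    is_init_step: bool,
    key_mapping: Optional[dict],
    upstream_available: bool,
) -> Optional[str]:
    """Hypothetical decision logic following the docstring above.

    Returns a label for the reader to build, or None to fall back to
    DCP's default FileSystemReader.
    """
    if not (is_safetensors or is_init_step):
        return None  # plain DCP checkpoint: keep the default reader
    if key_mapping is None and upstream_available:
        # Upstream reader: partial reads via safetensors get_slice().
        return "upstream"
    # Key remapping required, or upstream reader missing: use the backport.
    return "backport"
```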
Construct a Hugging Face storage writer for sharded safetensors.
Parameters:
Optional path for consolidated artifacts.
Optional mapping from FQN to shard index.
Optional mapping from FQN to original HF safetensors dtype string.
Path where the model checkpoint is saved.
If True, consolidate on all ranks on the main process.
Returns: Optional[_HuggingFaceStorageWriter]
Configured _HuggingFaceStorageWriter or None for non-safetensors.
Build FQN to shard index mapping for consolidated HF export.
Uses the base checkpoint index (if present), removes non-persistent keys, and assigns new keys to the last shard by default.
Parameters:
Wrapper exposing the primary model part.
The state dict that will be saved.
Returns: Optional[dict[str, int]]
Mapping from FQN to shard index, or None when not consolidating.
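The three steps above can be sketched as follows: start from the base checkpoint index, drop non-persistent keys, and send new keys to the last shard. This is an illustration under stated assumptions. `build_fqn_to_index` is hypothetical, the base index is assumed to be the `weight_map` of a Hugging Face `model.safetensors.index.json` with `model-0000K-of-0000N.safetensors` filenames, and shard indices are taken as 1-based to match those filenames. The "None when not consolidating" case is left out.

```python
import re
from typing import Optional


def build_fqn_to_index(
    base_weight_map: Optional[dict],
    state_dict_keys: list,
    non_persistent: set,
) -> dict:
    """Hypothetical sketch: FQN -> 1-based shard index for HF export."""
    mapping = {}
    if base_weight_map:
        for fqn, filename in base_weight_map.items():
            # Parse "model-00002-of-00004.safetensors" -> 2.
            m = re.search(r"-(\d+)-of-\d+\.safetensors$", filename)
            mapping[fqn] = int(m.group(1)) if m else 1
    # Non-persistent buffers are never written, so drop them.
    for fqn in non_persistent:
        mapping.pop(fqn, None)
    last = max(mapping.values(), default=1)
    # Keys absent from the base index (e.g. a new head) go to the last shard.
    for fqn in state_dict_keys:
        if fqn not in mapping and fqn not in non_persistent:
            mapping[fqn] = last
    return mapping
```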
Build FQN to original HF safetensors dtype mapping for consolidated export.
Returns None when the run started from config-only weights or the original HF safetensors headers are not available. In that case consolidation keeps the saved checkpoint dtype unless the user explicitly passes CAST_DTYPE to the offline helper.
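The original dtypes come from the safetensors headers, which can be read without loading any tensor data. A safetensors file starts with an 8-byte little-endian unsigned header length, followed by a JSON header that maps each tensor name to its `dtype`, `shape` and `data_offsets`, plus an optional `__metadata__` entry. The helper name below is hypothetical. Only the file layout is taken from the safetensors format.

```python
import json
import struct


def read_safetensors_dtypes(path: str) -> dict:
    """Return {tensor_name: dtype_string} (e.g. "BF16") from a file header."""
    with open(path, "rb") as f:
        # First 8 bytes: little-endian u64 giving the JSON header length.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    return {
        name: info["dtype"]
        for name, info in header.items()
        if name != "__metadata__"  # free-form metadata, not a tensor
    }
```

For a sharded checkpoint, the per-file results would be merged across every file listed in the index.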
Log the final-checkpoint helper hint when consolidated export was disabled.
Write a conservative helper script for offline HF safetensors consolidation.
Wait for the async save to finish.
Close the checkpointer.
Materialize meta-device parameters and initialize model weights.
Moves empty parameter shells to the target device, resets HF initialization flags, calls the model’s weight initialization method, and initializes any PEFT adapters.
Parameters:
Model whose weights should be initialized.
Target device for materialized parameters.
Initialization method for PEFT adapters (e.g. “xavier”).
Load a model from the base Hugging Face checkpoint in parallel.
Parameters:
Model to load state into
Device to load model onto
Root directory of the model cache or snapshots
Name of the model or an absolute path to a snapshot
If True, restore from HF base checkpoint
Load a custom stateful object previously saved with DCP.
Load model weights from model_path.
Behavior:
- For PEFT (non-init): rank 0 reads adapter_model.safetensors, then broadcasts.
- Otherwise: use DCP with a Hugging Face or default storage reader to populate the state dict.
- If the model exposes a state_dict_adapter, convert to/from HF format as needed.
- For models requiring tensor merging (e.g., Mixtral), uses transformers’ conversion mapping.
Parameters:
Model or parallelized model parts to load into.
Path to the model checkpoint directory or HF snapshot.
If True, treat load as initialization from a base checkpoint.
Pass checkpoint_id to DCP if True; disable when using direct HF paths.
Optional key remapping when reading from HF checkpoints.
If True, keep the model’s current initialization for parameters that are absent from the checkpoint instead of requiring an exact key match.
Load the stateful object.
This helper is currently used to load the dataloader and RNG state.
Parameters:
Stateful object to load
Name of the stateful object
Path to load stateful object
Load optimizer (and optional scheduler) state from weights_path/optim using DCP.
Parameters:
Optimizer to populate.
Model providing partitioning context for the optimizer wrapper.
Base directory for checkpoints.
Optional LR scheduler to populate.
Wait for the staging to finish if it is enabled.
Save a custom stateful object through DCP on all ranks.
This is intended for auxiliary objects whose state dict contains
sharded tensors, for example BAGEL EMA shadows under FSDP2. Rank-0
torch.save would only persist rank 0’s local shard; DCP sees the
DTensor metadata and writes all shards correctly.
Save model weights to weights_path/model.
Behavior:
- PEFT: write adapter_model.safetensors and metadata on rank 0.
- Safetensors + consolidation: emit HF artifacts under weights_path/model/consolidated and build a consolidated index.
- Otherwise: use DCP with a Hugging Face or default storage writer to save shards.
Parameters:
Model to checkpoint.
Base directory for checkpoints.
Optional PEFT configuration when saving adapters.
Optional tokenizer to save with consolidated artifacts.
Whether this save is the final scheduled training checkpoint.
Save the stateful object.
This helper is currently used to save the dataloader and RNG state.
Parameters:
Stateful object to save
Name of the stateful object
Path to save stateful object
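A stateful object here is anything that exposes `state_dict()` and `load_state_dict()`. These are the two methods of DCP's `Stateful` protocol. The standard-library sketch below round-trips a Python RNG through such an object. `RNGStateful`, `save_stateful`, `load_stateful` and the pickle file naming are hypothetical stand-ins for the helpers documented above.

```python
import os
import pickle
import random


class RNGStateful:
    """Hypothetical wrapper exposing the Stateful-style method pair."""

    def __init__(self, rng: random.Random):
        self.rng = rng

    def state_dict(self) -> dict:
        return {"python_rng": self.rng.getstate()}

    def load_state_dict(self, state: dict) -> None:
        self.rng.setstate(state["python_rng"])


def save_stateful(obj, name: str, path: str) -> None:
    # Pickle stands in for the library's serialization (an assumption).
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, f"{name}.pkl"), "wb") as f:
        pickle.dump(obj.state_dict(), f)


def load_stateful(obj, name: str, path: str) -> None:
    with open(os.path.join(path, f"{name}.pkl"), "rb") as f:
        obj.load_state_dict(pickle.load(f))
```

After a resume, a restored RNG produces the same sequence it would have produced at save time. This is why the RNG state is checkpointed together with the dataloader position.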
Save optimizer (and optional scheduler) state to weights_path/optim using DCP.
Parameters:
Optimizer whose state will be saved.
Model providing partitioning context for the optimizer wrapper.
Base directory for checkpoints.
Optional LR scheduler to include.
Internal container for async checkpointing state.
One instance is maintained for the model save and one for the optimizer save to keep staging/upload futures and the associated process group and stager together in a single place.
Return the PEFT adapter safetensors path inside a checkpoint dir (local or msc://).
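One way to build the adapter path for both local and `msc://` directories is to always use forward slashes for URL-style paths. Both helpers below are hypothetical. The sketch assumes cloud checkpoints are identified only by the `msc://` prefix.

```python
import os
import posixpath

MSC_PREFIX = "msc://"  # assumption: cloud paths are marked by this prefix


def is_cloud_path(path: str) -> bool:
    return path.startswith(MSC_PREFIX)


def adapter_file_path(checkpoint_dir: str) -> str:
    """Hypothetical helper: adapter_model.safetensors inside checkpoint_dir."""
    name = "adapter_model.safetensors"
    if is_cloud_path(checkpoint_dir):
        # URL-style paths always use "/", whatever the host OS.
        return posixpath.join(checkpoint_dir, name)
    return os.path.join(checkpoint_dir, name)
```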
Apply a transformation function to parameters (and gradients) only.
Mirrors nn.Module.to_empty for parameters while skipping buffers. Respects
future flags controlling in-place vs swap behavior and safely handles
wrapper subclasses.
Parameters:
Module whose parameters are to be transformed.
Callable applied to each parameter (and its gradient).
Whether to recurse into child modules.
Returns: nn.Module
The same module instance after transformation.
Rename state-dict keys using regex-based key_mapping.
This mirrors the renaming logic used by the DCP / HuggingFace storage
reader but operates directly on an in-memory state dict. It is needed
when loading safetensors checkpoints outside of DCP so that HF checkpoint
keys (e.g. language_model.model.X) are translated to the model’s
parameter FQNs (e.g. model.language_model.X).
Parameters:
Original state dict whose keys may need renaming.
{regex_pattern: replacement} pairs applied in order.
Returns: dict[str, torch.Tensor]
A new dict with renamed keys.
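An in-memory version of regex key renaming fits in a few lines with `re.subn`. The docstring says only that the pairs are "applied in order". This sketch assumes the first pattern that matches a key wins and the later patterns are skipped for that key. Treat that precedence as an assumption, not the documented behavior. `rename_keys` is a hypothetical name.

```python
import re


def rename_keys(state_dict: dict, key_mapping: dict) -> dict:
    """Return a new dict whose keys are rewritten by regex key_mapping.

    Patterns are tried in insertion order; the first one that matches a
    key rewrites it (assumed precedence). Unmatched keys are unchanged.
    """
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for pattern, replacement in key_mapping.items():
            new_key, n = re.subn(pattern, replacement, key)
            if n > 0:
                break  # stop at the first matching pattern
        renamed[new_key] = value
    return renamed
```

Applied to the example from the docstring, the mapping `{r"^language_model\.model\.": "model.language_model."}` turns `language_model.model.layers.0.mlp.weight` into `model.language_model.layers.0.mlp.weight`.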
Convert a checkpoint using transformers’ conversion mapping for models that need tensor merging.
This handles MoE models like Mixtral where the checkpoint has individual expert weights but the model uses grouped expert tensors. The transformers library’s WeightConverter operations handle the tensor merging (MergeModulelist, Concatenate).
This function converts the state dict WITHOUT loading it into the model, so it can be used with FSDP-aware loading mechanisms.
Parameters:
The model (used to get conversion mapping and target keys).
Path to the HuggingFace checkpoint directory.
Optional additional key mapping.
Returns: Optional[dict[str, torch.Tensor]]
Converted state dict ready for loading, or None if conversion failed.
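The core of the merge is to group the per-expert tensors by projection and stack them along a new leading expert dimension in index order. The sketch below shows only that stacking step, loosely in the spirit of a module-list merge. It does not reproduce transformers' WeightConverter API or its Concatenate step. The function name and the key layout (`...experts.<i>.<proj>`) are hypothetical, and nested lists stand in for tensors.

```python
import re
from collections import defaultdict


def merge_expert_weights(state_dict: dict) -> dict:
    """Hypothetical sketch: stack per-expert entries into grouped entries.

    "layers.0.experts.3.w1.weight" is grouped under
    "layers.0.experts.w1.weight", ordered by expert index.
    """
    pattern = re.compile(r"^(.*\.experts)\.(\d+)\.(.+)$")
    groups = defaultdict(dict)
    merged = {}
    for key, value in state_dict.items():
        m = pattern.match(key)
        if m is None:
            merged[key] = value  # not an expert weight: pass through
            continue
        prefix, idx, suffix = m.group(1), int(m.group(2)), m.group(3)
        groups[f"{prefix}.{suffix}"][idx] = value
    for grouped_key, experts in groups.items():
        # New leading "expert" dimension, ordered by expert index.
        merged[grouped_key] = [experts[i] for i in sorted(experts)]
    return merged
```

Because this runs on a plain dict before anything touches the model, the result can then go through an FSDP-aware loader, as the docstring describes.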
Assign keys to deterministic size-based shards.
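A deterministic size-based assignment can be built greedily. Visit the keys in sorted order so that every rank computes the same layout, and open a new shard whenever the next key would overflow the limit. This is a sketch under assumptions: `assign_size_based_shards` is hypothetical, shard indices are 1-based, and a single key larger than the limit gets a shard of its own. The library's actual default limit is not restated here.

```python
def assign_size_based_shards(key_sizes: dict, max_shard_bytes: int) -> dict:
    """Hypothetical greedy assignment of keys to 1-based shard indices."""
    assignment = {}
    shard, used = 1, 0
    for key in sorted(key_sizes):  # sorted order keeps ranks in agreement
        size = key_sizes[key]
        if used > 0 and used + size > max_shard_bytes:
            shard, used = shard + 1, 0  # current shard is full: open a new one
        assignment[key] = shard
        used += size
    return assignment
```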
Create directories on all ranks and synchronize across ranks.
Parameters:
One or more directory paths that should exist.
Raise an error if MSC is not installed but a cloud path is used.
Equally divide the state dict keys into num_shards shards.
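An equal split can be computed with `divmod`. Every shard gets `len(keys) // num_shards` keys, and the first `len(keys) % num_shards` shards get one extra. The function name and the list-of-lists return type are assumptions for illustration.

```python
def split_keys_evenly(keys: list, num_shards: int) -> list:
    """Hypothetical sketch: contiguous groups whose sizes differ by <= 1."""
    if num_shards < 1:
        raise ValueError("num_shards must be >= 1")
    base, extra = divmod(len(keys), num_shards)
    shards, start = [], 0
    for i in range(num_shards):
        end = start + base + (1 if i < extra else 0)  # early shards take extras
        shards.append(keys[start:end])
        start = end
    return shards
```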
Return checkpoint FQNs present in metadata.
Return the local HF safetensors reference directory for a model.
Prefer the snapshot directory containing model.safetensors.index.json for
sharded checkpoints. If no index exists but a snapshot directory is present,
return that directory as the single-file safetensors reference path. Return
None when repo_id is None or the repo has no cached snapshot directory.
For example, if the located file is
/opt/models/models--meta-llama--Llama-3.2-3B/snapshots/13afe…/model.safetensors.index.json
this function will return the directory path
/opt/models/models--meta-llama--Llama-3.2-3B/snapshots/13afe…
This will error if the model hasn’t been downloaded or if the cache directory is incorrect.
Parameters:
Path to cache directory
Hugging Face repository ID
Returns: str | None
Path to the snapshot/model directory containing safetensors weights, or None if no cached snapshot is found.
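The lookup can be sketched against the standard Hugging Face cache layout, `<cache_dir>/models--<org>--<name>/snapshots/<revision>/`. The function below is hypothetical. It ignores `refs/` and returns None instead of raising, and when several snapshots exist it picks the first one in sorted order, which is deterministic but arbitrary.

```python
import os
from typing import Optional


def find_snapshot_dir(cache_dir: str, repo_id: Optional[str]) -> Optional[str]:
    """Hypothetical lookup of a cached HF snapshot directory."""
    if repo_id is None:
        return None
    repo_dir = os.path.join(cache_dir, "models--" + repo_id.replace("/", "--"))
    snapshots = os.path.join(repo_dir, "snapshots")
    if not os.path.isdir(snapshots):
        return None  # repo not downloaded into this cache
    candidates = sorted(
        os.path.join(snapshots, d)
        for d in os.listdir(snapshots)
        if os.path.isdir(os.path.join(snapshots, d))
    )
    # Prefer a snapshot that holds a sharded-checkpoint index.
    for cand in candidates:
        if os.path.isfile(os.path.join(cand, "model.safetensors.index.json")):
            return cand
    # Otherwise treat a snapshot as a single-file safetensors directory.
    return candidates[0] if candidates else None
```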
Return the original HF safetensors index total size, if available.
Initialize the PEFT adapters with the scaled weights.
Parameters:
Model to initialize PEFT adapters for
Method to initialize PEFT adapters, e.g. “xavier”. See LinearLoRA for more details.
Return True if path looks like a PyTorch .bin checkpoint.
True if the model has a custom implementation in nemo_automodel/components/models/.
The generic HFCheckpointingMixin (in .common.hf_checkpointing_mixin) is injected into every model by _get_mixin_wrapped_class and does NOT count as a “custom model”. Only actual model implementations (e.g. llama, deepseek_v3) that live under nemo_automodel.components.models qualify.
True if the model was loaded with trust_remote_code (HF dynamic modules).
Return True if path looks like a safetensors checkpoint (so its dtype can be preserved); return False for DCP and other formats.
Load a full (non-sharded) state dict into a potentially FSDP-wrapped model.
Every rank must supply the full state dict. PyTorch’s
set_model_state_dict with full_state_dict=True (but not
broadcast_from_rank0) calls _distribute_state_dict which lets
each rank independently slice its local DTensor shard from the full
tensor — no NCCL collectives are needed.
We intentionally avoid broadcast_from_rank0=True because it
introduces an asymmetric workload: rank 0 does a synchronous CPU→GPU
copy (.to(device)) per tensor while other ranks only do
torch.empty (async allocation). The non-src ranks race ahead
enqueuing hundreds of NCCL broadcasts that rank 0 cannot keep up with,
leading to a 60 s NCCL watchdog timeout.
After loading, floating-point parameters are converted to match the
checkpoint dtype. PyTorch’s set_model_state_dict uses copy
semantics (assign=False) for non-meta parameters, which preserves
the model’s initialisation dtype instead of the checkpoint dtype.
The post-load fixup ensures the safetensors dtype (e.g. bf16) is
honoured.
Parameters:
List of model parts (for pipeline parallelism)
Full state dict with regular tensors. Must be populated on every rank (not just rank 0).
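No collectives are needed because each rank already holds the full tensor and only has to cut out its own piece. The sketch below shows that slicing for a dim-0 shard using `torch.chunk` semantics: each chunk has `ceil(n / world_size)` rows, so trailing ranks may receive fewer rows or none. Our understanding is that DTensor's `Shard(0)` placement splits a dimension this way. A plain list stands in for the tensor, and `local_shard` is a hypothetical name.

```python
import math


def local_shard(full_rows: list, rank: int, world_size: int) -> list:
    """Slice this rank's dim-0 chunk from a full tensor every rank holds."""
    chunk = math.ceil(len(full_rows) / world_size)
    start = min(rank * chunk, len(full_rows))  # trailing ranks may be empty
    return full_rows[start:start + chunk]
```

Every rank runs this on the same input independently, so ranks never wait on each other. This avoids the rank-0 bottleneck that the broadcast path creates.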
Load a HuggingFace .bin checkpoint into a state dict.
Handles single-file (pytorch_model.bin), sharded (pytorch_model.bin.index.json), and glob fallback (*.bin) layouts. Returns None if no .bin files are found.
Parameters:
Path to checkpoint file or directory.
Passed to torch.load. Default True for safety;
set to False for remote-code models whose checkpoints may
contain custom pickled objects.
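The three `.bin` layouts can be resolved to a file list before anything is passed to `torch.load`. This is a sketch: `resolve_bin_files` is hypothetical, and it checks the layouts in the order the docstring lists them (single file, sharded index, glob fallback), which is an assumed precedence. The index is read as a Hugging Face `pytorch_model.bin.index.json`, whose `weight_map` maps each tensor name to a shard filename.

```python
import glob
import json
import os
from typing import Optional


def resolve_bin_files(path: str) -> Optional[list]:
    """Return the .bin files to load for a HF checkpoint, or None."""
    if os.path.isfile(path):
        return [path] if path.endswith(".bin") else None
    single = os.path.join(path, "pytorch_model.bin")
    if os.path.isfile(single):
        return [single]
    index = os.path.join(path, "pytorch_model.bin.index.json")
    if os.path.isfile(index):
        with open(index) as f:
            weight_map = json.load(f)["weight_map"]
        # Many tensors share one shard file: list each file once, sorted.
        return [os.path.join(path, n) for n in sorted(set(weight_map.values()))]
    found = sorted(glob.glob(os.path.join(path, "*.bin")))
    return found or None
```

Each resolved file would then be loaded with `torch.load(..., weights_only=...)`. Keep `weights_only=True` unless the checkpoint comes from a trusted remote-code model that needs custom pickled objects.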
Load a HuggingFace checkpoint into a new state dict so tensor dtypes match the checkpoint (e.g. bf16). Used when loading the base model so FSDP sees uniform dtype instead of the model’s init dtypes (e.g. float32). Prefers safetensors but falls back to .bin files. Returns None if no loadable checkpoint is found.
Parameters:
Path to checkpoint file or directory.
Forwarded to torch.load when loading .bin files.
Load a safetensors checkpoint into a state dict.
Read a safetensors file from a local path or an msc:// cloud path.
Replace non-contiguous tensor values in state_dict with contiguous copies in place.
MoE adapters return non-contiguous strided views into the model’s grouped
expert storage for the optimized load path; safetensors.torch.save
(which the DCP HF storage writer calls) rejects non-contiguous tensors,
so we materialize one tensor at a time here with empty_cache between
iterations. Per-tensor transient is bounded to a single expert weight
instead of allocating the full grouped set up front.
Custom models use state dict adapters to convert the state dict from the Hugging Face format to the native format.
Custom models use state dict adapters to convert the state dict to the Hugging Face format.
Return an MSC filesystem reader for msc:// paths, else the given reader.
Return an MSC filesystem writer for msc:// paths, else the given writer.
True if any parameter is a DTensor (model is already sharded).
Align original HF dtype metadata with the keys that will be exported.
Recompute non-persistent buffers that are not saved in checkpoints.
Non-persistent buffers are not saved in checkpoints, so after meta-device
materialization they contain uninitialized CUDA memory. When
initialize_weights() is skipped (e.g. for Gemma3 to avoid DTensor
issues), these buffers must be recomputed explicitly.
Only runs for models listed in _MODELS_REQUIRING_BUFFER_REINIT to
avoid unexpected side-effects on arbitrary HF Hub models.
Handles four patterns:
- Standard RoPE: a single inv_freq buffer with rope_init_fn + rope_kwargs (e.g. Nemotron-NAS).
- Per-layer-type RoPE: {layer_type}_inv_freq buffers via compute_default_rope_parameters (e.g. Gemma3RotaryEmbedding).
- Scaled embedding: embed_scale buffer on ScaledWordEmbedding modules (Gemma family), recomputed from scalar_embed_scale.
- Vision position IDs: position_ids buffer on vision embedding modules (SigLIP), recomputed from num_positions.
Parameters:
Model to reinitialize non-persistent buffers for.
Device to create the new buffers on.
The config.model_type string. If not in
_MODELS_REQUIRING_BUFFER_REINIT the function is a no-op.
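Two of the recomputed buffers have simple closed forms. Default (unscaled) RoPE uses one inverse frequency per rotated channel pair, `1 / base**(2i / head_dim)`, and vision position IDs are `0 .. num_positions - 1`. The sketch below covers only these two cases with plain lists. It does not model `rope_kwargs` scaling variants or the `embed_scale` buffer, and the function names are hypothetical.

```python
def default_rope_inv_freq(head_dim: int, base: float = 10000.0) -> list:
    """Unscaled RoPE inverse frequencies: 1 / base**(2i / head_dim)."""
    # range(0, head_dim, 2) yields 2i directly, one entry per channel pair.
    return [1.0 / (base ** (i / head_dim)) for i in range(0, head_dim, 2)]


def vision_position_ids(num_positions: int) -> list:
    # SigLIP-style position ids are simply 0 .. num_positions - 1.
    return list(range(num_positions))
```

Real buffers would be created as float or integer tensors on the target device. The point of the sketch is that these values depend only on the config, so they can be rebuilt after meta-device materialization.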
Write a safetensors file to a local path or an msc:// cloud path.
For cloud paths the tensors are serialized to bytes and streamed to the MSC
file handle, since save_file only accepts a local filesystem path.
Whether to output consolidated HF weights along with sharded weights.
Whether to write HF metadata/artifacts for a checkpoint.
Summarize state-dict key mismatches for checkpoint load diagnostics.
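A mismatch summary reduces to two set differences. "Missing" keys are expected by the model but absent from the checkpoint, and "unexpected" keys are present in the checkpoint but unknown to the model. The sketch below, with a hypothetical name and output format, truncates long lists so that load errors stay readable.

```python
def summarize_key_mismatch(model_keys, checkpoint_keys, max_shown: int = 5) -> str:
    """Hypothetical sketch: one-line summary of missing/unexpected keys."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)     # model needs, ckpt lacks
    unexpected = sorted(checkpoint_keys - model_keys)  # ckpt has, model lacks

    def fmt(label: str, keys: list) -> str:
        shown = ", ".join(keys[:max_shown])
        more = f" (+{len(keys) - max_shown} more)" if len(keys) > max_shown else ""
        return f"{label}: {len(keys)}" + (f" [{shown}{more}]" if keys else "")

    return "; ".join([fmt("missing", missing), fmt("unexpected", unexpected)])
```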
Educate users about the cost of inline HF consolidation.
Warn when inline consolidated export is large enough to waste GPU allocation time.
Check if path is a cloud storage path (MSC).
Save a config to a weights path.
Parameters:
Config to save
Path to save config
Move parameters to the specified device without copying storage, skipping buffers.
Mirrors torch.nn.Module.to_empty but applies only to parameters, not buffers.
Parameters:
The module to transform
Target device
Whether to recurse into child modules
Returns: nn.Module
The same module instance