nemo_automodel.components.checkpoint.checkpointing#
Module Contents#
Classes#
- _AsyncSaveContext: Internal container for async checkpointing state.
- CheckpointingConfig: Configuration for checkpointing.
- Checkpointer: High-level checkpoint manager built on torch.distributed.checkpoint (DCP).
Functions#
- _is_geq_torch_2_9: Check if the current torch version is greater than or equal to 2.9.0.
- _is_safetensors_checkpoint: Return True if path looks like a safetensors checkpoint (so we can preserve dtype); else DCP or other.
- get_safetensors_index_path: Return the directory containing the first model.safetensors.index.json found for the given model.
- to_empty_parameters_only: Move parameters to the specified device without copying storage, skipping buffers.
- save_config: Save a config to a weights path.
- _ensure_dirs: Create directories on all ranks and synchronize across ranks.
- _init_peft_adapters: Initialize the PEFT adapters with the scaled weights.
- _reinit_rope_buffers: Recompute non-persistent RoPE inv_freq buffers.
- _apply: Apply a transformation function to parameters (and gradients) only.
- _apply_key_mapping: Rename state-dict keys using a regex-based key_mapping.
- _load_full_state_dict_into_model: Load a full (non-sharded) state dict into a potentially FSDP-wrapped model.
- _convert_checkpoint_with_transformers: Convert a checkpoint using transformers' conversion mapping for models that need tensor merging.
- _maybe_adapt_state_dict_to_hf: Custom models use state dict adapters to convert the state dict to the Hugging Face format.
- _equally_divide_layers: Equally divide the state dict keys into num_shards shards.
- _model_has_dtensors: True if any parameter is a DTensor (model is already sharded).
- _is_custom_model: True if the model has a custom implementation in nemo_automodel/components/models/.
- _load_hf_checkpoint_preserving_dtype: Load a HuggingFace checkpoint (safetensors) into a new state dict so tensor dtypes match the checkpoint (e.g. bf16).
- _maybe_adapt_state_dict_from_hf: Custom models use state dict adapters to convert the state dict from the Hugging Face format to the native format.
API#
- nemo_automodel.components.checkpoint.checkpointing._is_geq_torch_2_9() → bool#
Check if the current torch version is greater than or equal to 2.9.0.
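A version check like this cannot compare version strings lexicographically ("2.10" would sort before "2.9"); a minimal sketch of the idea, assuming the version string comes from torch.__version__ (the helper names below are illustrative, not the module's actual implementation):

```python
import re

def parse_version(version: str) -> tuple:
    # Extract (major, minor, patch); tolerate suffixes such as "+cu121" or ".dev20250101".
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    return tuple(int(part) for part in match.groups())

def is_geq_2_9(version: str) -> bool:
    # Tuple comparison handles multi-digit components: (2, 10, 0) >= (2, 9, 0).
    return parse_version(version) >= (2, 9, 0)
```

Comparing the parsed tuples rather than the raw strings is what makes "2.10.0" correctly register as newer than "2.9.0".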
- nemo_automodel.components.checkpoint.checkpointing._is_safetensors_checkpoint(path: str) → bool#
Return True if path looks like a safetensors checkpoint (so we can preserve dtype); else DCP or other.
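The exact heuristic is internal; a plausible sketch based on file names alone (an assumption, not the function's actual logic):

```python
import os

def looks_like_safetensors_checkpoint(path: str) -> bool:
    # A directory qualifies if it holds an index file or any *.safetensors shard;
    # anything else is treated as DCP or another format.
    if not os.path.isdir(path):
        return False
    entries = os.listdir(path)
    if "model.safetensors.index.json" in entries:
        return True
    return any(name.endswith(".safetensors") for name in entries)
```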
- class nemo_automodel.components.checkpoint.checkpointing._AsyncSaveContext#
Internal container for async checkpointing state.
One instance is maintained for the model save and one for the optimizer save to keep staging/upload futures and the associated process group and stager together in a single place.
- stager: Any | None#
None
- process_group: Any | None#
None
- future: Any | None#
None
- staging_active: bool#
False
- class nemo_automodel.components.checkpoint.checkpointing.CheckpointingConfig#
Configuration for checkpointing.
- enabled: bool#
None
- checkpoint_dir: str | pathlib.Path#
None
- model_save_format: str#
None
- model_cache_dir: str | pathlib.Path#
None
- model_repo_id: str#
None
- save_consolidated: bool#
None
- is_peft: bool#
None
- model_state_dict_keys: list[str]#
None
- is_async: bool#
False
- dequantize_base_checkpoint: bool | None#
None
- original_model_root_dir: str | None#
None
- skip_task_head_prefixes_for_base_model: list[str] | None#
None
- single_rank_consolidation: bool#
False
- staging_dir: str | None#
None
- __post_init__()#
Convert a raw string such as “safetensors” into the right Enum.
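The pattern can be illustrated with a self-contained dataclass; SaveFormat and DemoConfig are hypothetical names standing in for the module's actual enum and config:

```python
import enum
from dataclasses import dataclass

class SaveFormat(enum.Enum):
    SAFETENSORS = "safetensors"
    TORCH_SAVE = "torch_save"

@dataclass
class DemoConfig:
    model_save_format: object  # raw string or SaveFormat

    def __post_init__(self):
        # Normalize a raw string such as "safetensors" into the enum member;
        # an unknown string raises ValueError from the Enum constructor.
        if isinstance(self.model_save_format, str):
            self.model_save_format = SaveFormat(self.model_save_format)
```

Doing the conversion in __post_init__ lets YAML- or CLI-sourced configs pass plain strings while the rest of the code only ever sees the enum.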
- class nemo_automodel.components.checkpoint.checkpointing.Checkpointer(
- config: nemo_automodel.components.checkpoint.checkpointing.CheckpointingConfig,
- dp_rank: int,
- tp_rank: int,
- pp_rank: int,
- moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- )#
High-level checkpoint manager built on torch.distributed.checkpoint (DCP).
Supports:
- HF sharded safetensors via custom storage reader/writer
- Optional consolidated export (config, generation config, tokenizer)
- PEFT adapter save/load handling
- Async save for torch >= 2.9.0
Also provides DP-aware helpers for saving/loading auxiliary state and utilities to initialize from a base HF checkpoint.
Initialization
Initialize the checkpointer.
- Parameters:
config – Checkpointing configuration.
dp_rank – Data parallel rank for the current process.
tp_rank – Tensor parallel rank for the current process.
pp_rank – Pipeline parallel rank for the current process.
moe_mesh – Optional device mesh used for MoE when adapting state dicts.
- save_model(
- model: torch.nn.Module,
- weights_path: str,
- peft_config: Optional[peft.PeftConfig] = None,
- tokenizer: Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None,
- )#
Save model weights to weights_path/model.
Behavior:
- PEFT: write adapter_model.safetensors and metadata on rank 0.
- Safetensors + consolidation: emit HF artifacts under weights_path/model/consolidated and build a consolidated index.
- Otherwise: use DCP with a Hugging Face or default storage writer to save shards.
- Parameters:
model – Model to checkpoint.
weights_path – Base directory for checkpoints.
peft_config – Optional PEFT configuration when saving adapters.
tokenizer – Optional tokenizer to save with consolidated artifacts.
- save_optimizer(
- optimizer: torch.optim.Optimizer,
- model: torch.nn.Module,
- weights_path: str,
- scheduler: Optional[Any] = None,
- )#
Save optimizer (and optional scheduler) state to weights_path/optim using DCP.
- Parameters:
optimizer – Optimizer whose state will be saved.
model – Model providing partitioning context for the optimizer wrapper.
weights_path – Base directory for checkpoints.
scheduler – Optional LR scheduler to include.
- load_optimizer(
- optimizer: torch.optim.Optimizer,
- model: torch.nn.Module,
- weights_path: str,
- scheduler: Optional[Any] = None,
- )#
Load optimizer (and optional scheduler) state from weights_path/optim using DCP.
- Parameters:
optimizer – Optimizer to populate.
model – Model providing partitioning context for the optimizer wrapper.
weights_path – Base directory for checkpoints.
scheduler – Optional LR scheduler to populate.
- load_model(
- model: torch.nn.Module,
- model_path: str,
- is_init_step: bool = False,
- use_checkpoint_id: bool = True,
- key_mapping: Optional[dict[str, str]] = None,
- )#
Load model weights from model_path.
Behavior:
- For PEFT (non-init): rank 0 reads adapter_model.safetensors, then broadcasts.
- Otherwise: use DCP with a Hugging Face or default storage reader to populate the state dict.
- If the model exposes a state_dict_adapter, convert to/from HF format as needed.
- For models requiring tensor merging (e.g., Mixtral), uses transformers' conversion mapping.
- Parameters:
model – Model or parallelized model parts to load into.
model_path – Path to the model checkpoint directory or HF snapshot.
is_init_step – If True, treat load as initialization from a base checkpoint.
use_checkpoint_id – Pass checkpoint_id to DCP if True; disable when using direct HF paths.
key_mapping – Optional key remapping when reading from HF checkpoints.
- static initialize_model_weights(
- model: torch.nn.Module,
- device: torch.device,
- peft_init_method: str | None = None,
- )#
Materialize meta-device parameters and initialize model weights.
Moves empty parameter shells to the target device, resets HF initialization flags, calls the model’s weight initialization method, and initializes any PEFT adapters.
- Parameters:
model – Model whose weights should be initialized.
device – Target device for materialized parameters.
peft_init_method – Initialization method for PEFT adapters (e.g. “xavier”).
- load_base_model(
- model: torch.nn.Module,
- device: torch.device,
- root_dir: str,
- model_name: str | None,
- load_base_model: bool = True,
- )#
Load a model from the base Hugging Face checkpoint in parallel.
- Parameters:
model – Model to load state into
device – Device to load model onto
root_dir – Root directory of the model cache or snapshots
model_name – Name of the model or an absolute path to a snapshot
load_base_model – If True, restore from HF base checkpoint
- maybe_wait_for_staging() → None#
Wait for the staging to finish if it is enabled.
- async_wait() → None#
Wait for the async save to finish.
- save_on_dp_ranks(
- state: Any,
- state_name: str,
- path: str,
- )#
Save the stateful object.
This function is a helper function currently used to save the dataloader and rng state.
- Parameters:
state – Stateful object to save
state_name – Name of the stateful object
path – Path to save stateful object
- load_on_dp_ranks(
- state: Any,
- state_name: str,
- path: str,
- )#
Load the stateful object.
This function is a helper function currently used to load the dataloader and rng state.
- Parameters:
state – Stateful object to load
state_name – Name of the stateful object
path – Path to load stateful object
- close() → None#
Close the checkpointer.
- _do_load(
- state_dict: dict[str, torch.Tensor],
- path: str,
- storage_reader: Optional[nemo_automodel.components.checkpoint._backports.hf_storage._HuggingFaceStorageReader] = None,
- is_init_step: bool = False,
- )#
Load a state dictionary from path using DCP or PEFT special-case logic.
- Parameters:
state_dict – Mutable state dict to populate with tensors.
path – Checkpoint directory path.
storage_reader – Optional HF storage reader for safetensors.
is_init_step – True if loading from a base checkpoint during initialization.
- Returns:
The populated state dictionary (may be replaced for PEFT).
- _do_save(
- state_dict: dict[str, torch.Tensor],
- path: str,
- storage_writer: Optional[nemo_automodel.components.checkpoint._backports.hf_storage._HuggingFaceStorageWriter] = None,
- )#
Save a state dictionary to path using DCP or PEFT special-case logic.
- For PEFT model saves: only rank 0 writes adapter_model.safetensors.
- If async mode is enabled, schedule an asynchronous save.
- Parameters:
state_dict – State dict to be serialized.
path – Checkpoint directory path.
storage_writer – Optional HF storage writer for safetensors sharding.
- Returns:
Optional Future object if async mode is enabled.
- _should_write_consolidated_safetensors() → bool#
Whether to output consolidated HF weights along with sharded weights.
Returns True only for non-PEFT safetensors when consolidation is enabled.
- _should_write_hf_metadata() → bool#
Whether to write the HF artifacts.
- _maybe_build_consolidated_index(
- model_state: nemo_automodel.components.checkpoint.stateful_wrappers.ModelState,
- state_dict: dict[str, torch.Tensor],
- )#
Build FQN to shard index mapping for consolidated HF export.
Uses the base checkpoint index (if present), removes non-persistent keys, and assigns new keys to the last shard by default.
- Parameters:
model_state – Wrapper exposing the primary model part.
state_dict – The state dict that will be saved.
- Returns:
Mapping from FQN to shard index, or None when not consolidating.
- _get_storage_writer(
- consolidated_output_path: Optional[str],
- fqn_to_index_mapping: Optional[dict[str, int]],
- model_path: str,
- consolidate_on_all_ranks: bool = False,
- )#
Construct a Hugging Face storage writer for sharded safetensors.
- Parameters:
consolidated_output_path – Optional path for consolidated artifacts.
fqn_to_index_mapping – Optional mapping from FQN to shard index.
model_path – Path where the model checkpoint is saved.
consolidate_on_all_ranks – If True, consolidate on all ranks on the main process.
- Returns:
Configured _HuggingFaceStorageWriter or None for non-safetensors.
- _get_storage_reader(
- model_path: str,
- key_mapping: Optional[dict[str, str]],
- is_init_step: bool = False,
- )#
Construct a Hugging Face storage reader when loading safetensors or during init.
- Parameters:
model_path – Path to the model checkpoint directory or HF snapshot.
key_mapping – Optional key remapping for conversion.
is_init_step – If True, always produce a reader for base HF load.
- Returns:
Configured _HuggingFaceStorageReader or None for other formats.
- _get_original_model_path() → str | None#
Get the path to the original model from the Hugging Face checkpoint.
- nemo_automodel.components.checkpoint.checkpointing.get_safetensors_index_path(
- cache_dir: str | pathlib.Path | None,
- repo_id: str | None,
- )#
Return the directory containing the first model.safetensors.index.json found for the given model. If no model.safetensors.index.json is found, it returns None.
For example, if the file located is
/opt/models/models--meta-llama--Llama-3.2-3B/snapshots/13afe.../model.safetensors.index.json
this function will return the directory path
/opt/models/models--meta-llama--Llama-3.2-3B/snapshots/13afe...
This will error if the model hasn't been downloaded or if the cache directory is incorrect.
- Parameters:
cache_dir – Path to cache directory
repo_id – Hugging Face repository ID
- Returns:
Path to the directory containing the index file.
- Raises:
FileNotFoundError – If the index file is not found.
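A sketch of the lookup over the Hugging Face cache layout (models--{org}--{name}/snapshots/{hash}/); find_index_dir is a hypothetical stand-in for the documented function:

```python
from pathlib import Path

def find_index_dir(cache_dir: str, repo_id: str):
    # HF caches store snapshots under "models--{org}--{name}/snapshots/<hash>/".
    repo_dir = Path(cache_dir) / ("models--" + repo_id.replace("/", "--"))
    for index_file in sorted(repo_dir.glob("snapshots/*/model.safetensors.index.json")):
        # Return the snapshot directory holding the first index file found.
        return str(index_file.parent)
    return None
```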
- nemo_automodel.components.checkpoint.checkpointing.to_empty_parameters_only(
- model: torch.nn.Module,
- *,
- device: torch.device,
- recurse: bool = True,
- dtype: torch.dtype | None = None,
- )#
Move parameters to the specified device without copying storage, skipping buffers.
Mirrors torch.nn.Module.to_empty but applies only to parameters, not buffers.
- Parameters:
model – The module to transform
device – Target device
recurse – Whether to recurse into child modules
dtype – Optional dtype for the materialized parameters
- Returns:
The same module instance
- nemo_automodel.components.checkpoint.checkpointing.save_config(config: dict[str, Any], weights_path: str) → None#
Save a config to a weights path.
- Parameters:
config – Config to save
weights_path – Path to save config
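A minimal sketch, assuming the config is serialized as JSON under the weights path (the real file name and format are not specified here):

```python
import json
import os

def save_config(config: dict, weights_path: str) -> None:
    # Write the config next to the weights so the checkpoint is self-describing.
    os.makedirs(weights_path, exist_ok=True)
    with open(os.path.join(weights_path, "config.json"), "w") as f:
        json.dump(config, f, indent=2)
```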
- nemo_automodel.components.checkpoint.checkpointing._ensure_dirs(*dirs: Optional[str]) → None#
Create directories on all ranks and synchronize across ranks.
- Parameters:
*dirs – One or more directory paths that should exist.
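The single-rank part of this helper reduces to makedirs with exist_ok=True; the cross-rank synchronization (e.g. a torch.distributed barrier) is shown only as a comment, since it needs an initialized process group:

```python
import os

def ensure_dirs(*dirs) -> None:
    # Create each requested directory; exist_ok avoids races when
    # several ranks on the same node attempt the same mkdir.
    for d in dirs:
        if d is not None:
            os.makedirs(d, exist_ok=True)
    # torch.distributed.barrier()  # in the real multi-rank setting, so no
    # rank proceeds before the directories exist on shared storage
```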
- nemo_automodel.components.checkpoint.checkpointing._init_peft_adapters(
- model: torch.nn.Module,
- peft_init_method: str,
- )#
Initialize the PEFT adapters with the scaled weights.
- Parameters:
model – Model to initialize PEFT adapters for
peft_init_method – Method to initialize PEFT adapters, e.g. "xavier". See LinearLoRA for more details.
- nemo_automodel.components.checkpoint.checkpointing._reinit_rope_buffers(
- model: torch.nn.Module,
- device: torch.device,
- )#
Recompute non-persistent RoPE inv_freq buffers for Nemotron-NAS models.
- Parameters:
model – Model to reinitialize RoPE buffers for.
device – Device to create the new buffers on.
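Non-persistent buffers are skipped by parameter-only materialization, which is why inv_freq must be recomputed. The standard RoPE formula is inv_freq[i] = 1 / theta^(2i/d); a pure-Python sketch (theta = 10000 and the head dimension below are illustrative assumptions, not values read from the model config):

```python
def rope_inv_freq(head_dim: int, theta: float = 10000.0) -> list:
    # One inverse frequency per pair of dimensions: 1 / theta^(2i / head_dim).
    # Entries start at 1.0 and decay toward 0 for higher dimension indices.
    return [theta ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]
```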
- nemo_automodel.components.checkpoint.checkpointing._apply(module, fn, recurse=True) → torch.nn.Module#
Apply a transformation function to parameters (and gradients) only.
Mirrors nn.Module.to_empty for parameters while skipping buffers. Respects future flags controlling in-place vs. swap behavior and safely handles wrapper subclasses.
- Parameters:
module – Module whose parameters are to be transformed.
fn – Callable applied to each parameter (and its gradient).
recurse – Whether to recurse into child modules.
- Returns:
The same module instance after transformation.
- nemo_automodel.components.checkpoint.checkpointing._apply_key_mapping(
- state_dict: dict[str, torch.Tensor],
- key_mapping: dict[str, str],
- )#
Rename state-dict keys using the regex-based key_mapping.
This mirrors the renaming logic used by the DCP / HuggingFace storage reader but operates directly on an in-memory state dict. It is needed when loading safetensors checkpoints outside of DCP so that HF checkpoint keys (e.g. language_model.model.X) are translated to the model's parameter FQNs (e.g. model.language_model.X).
- Parameters:
state_dict – Original state dict whose keys may need renaming.
key_mapping – {regex_pattern: replacement} pairs applied in order.
- Returns:
A new dict with renamed keys.
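The renaming described above can be sketched with re.sub applied per pattern, in mapping order:

```python
import re

def apply_key_mapping(state_dict: dict, key_mapping: dict) -> dict:
    # Rewrite every key by applying each {pattern: replacement} pair in order;
    # the values (tensors) are carried over untouched.
    renamed = {}
    for key, value in state_dict.items():
        for pattern, replacement in key_mapping.items():
            key = re.sub(pattern, replacement, key)
        renamed[key] = value
    return renamed
```

Because patterns are applied sequentially, a later pattern can match the output of an earlier one, so ordering in the mapping matters.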
- nemo_automodel.components.checkpoint.checkpointing._load_full_state_dict_into_model(
- model_parts: list[torch.nn.Module],
- state_dict: dict[str, torch.Tensor],
- )#
Load a full (non-sharded) state dict into a potentially FSDP-wrapped model.
Every rank must supply the full state dict. PyTorch's set_model_state_dict with full_state_dict=True (but not broadcast_from_rank0) calls _distribute_state_dict, which lets each rank independently slice its local DTensor shard from the full tensor – no NCCL collectives are needed.
We intentionally avoid broadcast_from_rank0=True because it introduces an asymmetric workload: rank 0 does a synchronous CPU→GPU copy (.to(device)) per tensor while other ranks only do torch.empty (async allocation). The non-src ranks race ahead enqueuing hundreds of NCCL broadcasts that rank 0 cannot keep up with, leading to a 60 s NCCL watchdog timeout.
After loading, floating-point parameters are converted to match the checkpoint dtype. PyTorch's set_model_state_dict uses copy semantics (assign=False) for non-meta parameters, which preserves the model's initialization dtype instead of the checkpoint dtype. The post-load fixup ensures the safetensors dtype (e.g. bf16) is honored.
- Parameters:
model_parts – List of model parts (for pipeline parallelism)
state_dict – Full state dict with regular tensors. Must be populated on every rank (not just rank 0).
- nemo_automodel.components.checkpoint.checkpointing._convert_checkpoint_with_transformers(
- model: torch.nn.Module,
- model_path: str,
- key_mapping: Optional[dict[str, str]] = None,
- )#
Convert a checkpoint using transformers’ conversion mapping for models that need tensor merging.
This handles MoE models like Mixtral where the checkpoint has individual expert weights but the model uses grouped expert tensors. The transformers library’s WeightConverter operations handle the tensor merging (MergeModulelist, Concatenate).
This function converts the state dict WITHOUT loading it into the model, so it can be used with FSDP-aware loading mechanisms.
- Parameters:
model – The model (used to get conversion mapping and target keys).
model_path – Path to the HuggingFace checkpoint directory.
key_mapping – Optional additional key mapping.
- Returns:
Converted state dict ready for loading, or None if conversion failed.
- nemo_automodel.components.checkpoint.checkpointing._maybe_adapt_state_dict_to_hf(
- model_part: torch.nn.Module,
- state_dict: dict[str, torch.Tensor],
- quantization: bool = False,
- **kwargs,
- )#
Custom models use state dict adapters to convert the state dict to the Hugging Face format.
- nemo_automodel.components.checkpoint.checkpointing._equally_divide_layers(
- num_shards: int,
- keys: list[str],
- )#
Equally divide the state dict keys into num_shards shards.
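A sketch of such a split into contiguous, near-equal shards (the real helper's exact output shape may differ):

```python
def equally_divide_keys(num_shards: int, keys: list) -> list:
    # Split keys into num_shards contiguous groups whose sizes differ by at
    # most one; earlier shards absorb the remainder.
    base, extra = divmod(len(keys), num_shards)
    shards, start = [], 0
    for shard_idx in range(num_shards):
        size = base + (1 if shard_idx < extra else 0)
        shards.append(keys[start:start + size])
        start += size
    return shards
```

For example, dividing five keys into three shards yields groups of sizes 2, 2, and 1.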
- nemo_automodel.components.checkpoint.checkpointing._model_has_dtensors(module: torch.nn.Module) → bool#
True if any parameter is a DTensor (model is already sharded).
- nemo_automodel.components.checkpoint.checkpointing._is_custom_model(module: torch.nn.Module) → bool#
True if the model has a custom implementation in nemo_automodel/components/models/.
- nemo_automodel.components.checkpoint.checkpointing._load_hf_checkpoint_preserving_dtype(
- model_path: str,
- )#
Load a HuggingFace checkpoint (safetensors) into a new state dict so tensor dtypes match the checkpoint (e.g. bf16). Used when loading the base model so FSDP sees uniform dtype instead of the model’s init dtypes (e.g. float32). Returns None if the path is not a valid safetensors checkpoint.
- nemo_automodel.components.checkpoint.checkpointing._maybe_adapt_state_dict_from_hf(
- model_part: torch.nn.Module,
- state_dict: dict[str, torch.Tensor],
- moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- )#
Custom models use state dict adapters to convert the state dict from the Hugging Face format to the native format.