nemo_automodel.components.checkpoint.checkpointing

Module Contents

Classes

_AsyncSaveContext

Internal container for async checkpointing state.

CheckpointingConfig

Configuration for checkpointing.

Checkpointer

High-level checkpoint manager built on torch.distributed.checkpoint (DCP).

Functions

_is_geq_torch_2_9

Check if the current torch version is greater than or equal to 2.9.0.

_is_safetensors_checkpoint

Return True if path looks like a safetensors checkpoint (so its dtype can be preserved on load); return False for DCP and other formats.

_is_bin_checkpoint

Return True if path looks like a PyTorch .bin checkpoint.

get_safetensors_index_path

Return the directory containing the first model.safetensors.index.json found for the given model.

to_empty_parameters_only

Move parameters to the specified device without copying storage, skipping buffers.

save_config

Save a config to a weights path.

_ensure_dirs

Create directories on all ranks and synchronize across ranks.

_init_peft_adapters

Initialize the PEFT adapters with the scaled weights.

_reinit_non_persistent_buffers

Recompute non-persistent buffers that are not saved in checkpoints.

_apply

Apply a transformation function to parameters (and gradients) only.

_apply_key_mapping

Rename state-dict keys using regex-based key_mapping.

_load_full_state_dict_into_model

Load a full (non-sharded) state dict into a potentially FSDP-wrapped model.

_convert_checkpoint_with_transformers

Convert a checkpoint using transformers’ conversion mapping for models that need tensor merging.

_maybe_adapt_state_dict_to_hf

Convert the state dict to the Hugging Face format via the model's state dict adapter, if it has one (as custom models do).

_equally_divide_layers

Equally divide the state dict keys into num_shards shards.

_model_has_dtensors

True if any parameter is a DTensor (model is already sharded).

_is_custom_model

True if the model has a custom implementation in nemo_automodel/components/models/.

_is_remote_code_model

True if the model was loaded with trust_remote_code (HF dynamic modules).

_load_hf_checkpoint_preserving_dtype

Load a HuggingFace checkpoint into a new state dict so tensor dtypes match the checkpoint (e.g. bf16). Used when loading the base model so FSDP sees uniform dtype instead of the model’s init dtypes (e.g. float32). Prefers safetensors but falls back to .bin files. Returns None if no loadable checkpoint is found.

_load_hf_safetensors_checkpoint

Load a safetensors checkpoint into a state dict.

_load_hf_bin_checkpoint

Load a HuggingFace .bin checkpoint into a state dict.

_maybe_adapt_state_dict_from_hf

Convert the state dict from the Hugging Face format to the native format via the model's state dict adapter, if it has one (as custom models do).

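Data

_MODELS_REQUIRING_BUFFER_REINIT

Model types for which _reinit_non_persistent_buffers recomputes non-persistent buffers.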

API

nemo_automodel.components.checkpoint.checkpointing._is_geq_torch_2_9() → bool

Check if the current torch version is greater than or equal to 2.9.0.
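Comparing versions as strings is a common pitfall ("2.10.0" sorts before "2.9.0" lexically). A minimal sketch of a numeric check, assuming only the major/minor/patch release numbers matter; is_geq_version is a hypothetical name, and the real helper may instead rely on packaging.version:

```python
import re


def is_geq_version(version: str, minimum: tuple[int, int, int] = (2, 9, 0)) -> bool:
    """Return True if the release part of `version` is >= `minimum`.

    Hypothetical helper: local/build suffixes such as "+cu121" are ignored
    because only the leading numeric components are parsed.
    """
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version!r}")
    major, minor, patch = (int(part) if part is not None else 0 for part in match.groups())
    # Tuple comparison is numeric, so 2.10 correctly sorts after 2.9.
    return (major, minor, patch) >= minimum
```

It would typically be called as is_geq_version(torch.__version__). This sketch treats a pre-release such as 2.9.0a0 as 2.9.0, whereas packaging.version orders it before 2.9.0.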

nemo_automodel.components.checkpoint.checkpointing._is_safetensors_checkpoint(path: str) → bool

Return True if path looks like a safetensors checkpoint (so its dtype can be preserved on load); return False for DCP and other formats.

nemo_automodel.components.checkpoint.checkpointing._is_bin_checkpoint(path: str) → bool

Return True if path looks like a PyTorch .bin checkpoint.
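The two path checks can be sketched with pathlib. This assumes, hypothetically, that a path "looks like" a format when it is such a file or a directory containing at least one file with that extension. A DCP directory (which holds .distcp shard files plus a .metadata file) matches neither check:

```python
from pathlib import Path


def looks_like_safetensors(path: str) -> bool:
    """Hypothetical check: a .safetensors file, or a directory containing one."""
    p = Path(path)
    if p.is_file():
        return p.suffix == ".safetensors"
    return p.is_dir() and any(p.glob("*.safetensors"))


def looks_like_bin(path: str) -> bool:
    """Hypothetical check: a .bin file, or a directory containing *.bin weights."""
    p = Path(path)
    if p.is_file():
        return p.suffix == ".bin"
    return p.is_dir() and any(p.glob("*.bin"))
```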

class nemo_automodel.components.checkpoint.checkpointing._AsyncSaveContext

Internal container for async checkpointing state.

One instance is maintained for the model save and one for the optimizer save to keep staging/upload futures and the associated process group and stager together in a single place.

stager: Any | None

None

process_group: Any | None

None

future: Any | None

None

staging_active: bool

False
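A minimal sketch of such a container, assuming the documented fields and a plain concurrent.futures.Future for the pending save. AsyncSaveState and its wait() method are hypothetical, not the class's actual API:

```python
from concurrent.futures import Future
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class AsyncSaveState:
    """Hypothetical stand-in mirroring the documented _AsyncSaveContext fields."""

    stager: Optional[Any] = None
    process_group: Optional[Any] = None
    future: Optional[Future] = None
    staging_active: bool = False

    def wait(self) -> None:
        # Block until the pending save (if any) finishes, re-raising its error,
        # then reset so the next save starts from a clean state.
        if self.future is not None:
            self.future.result()
            self.future = None
        self.staging_active = False


# One instance per save stream: model and optimizer.
model_ctx = AsyncSaveState()
optim_ctx = AsyncSaveState()
```

Keeping one instance per save stream lets the model and optimizer saves each be awaited independently.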

class nemo_automodel.components.checkpoint.checkpointing.CheckpointingConfig

Configuration for checkpointing.

enabled: bool

None

checkpoint_dir: str | pathlib.Path

None

model_save_format: str

None

model_cache_dir: str | pathlib.Path

None

model_repo_id: str

None

save_consolidated: bool

None

is_peft: bool

None

model_state_dict_keys: list[str]

None

is_async: bool

False

dequantize_base_checkpoint: bool | None

None

original_model_root_dir: str | None

None

skip_task_head_prefixes_for_base_model: list[str] | None

None

single_rank_consolidation: bool

False

staging_dir: str | None

None

v4_compatible: bool

False

diffusers_compatible: bool

False

best_metric_key: str

'default'

__post_init__()

Convert a raw string such as “safetensors” into the corresponding Enum member.
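A sketch of the string-to-Enum normalization, assuming a hypothetical SaveFormat enum (the module's actual enum name and members may differ) and a trimmed-down config with only two fields:

```python
from dataclasses import dataclass
from enum import Enum
from typing import Union


class SaveFormat(Enum):
    # Hypothetical enum; the real module's enum name and members may differ.
    SAFETENSORS = "safetensors"
    TORCH_SAVE = "torch_save"


@dataclass
class MiniCheckpointingConfig:
    checkpoint_dir: str
    model_save_format: Union[SaveFormat, str] = SaveFormat.SAFETENSORS

    def __post_init__(self) -> None:
        # Accept raw strings (e.g. from YAML) and normalize them to the Enum;
        # SaveFormat("safetensors") looks the member up by value.
        if isinstance(self.model_save_format, str):
            self.model_save_format = SaveFormat(self.model_save_format)
```

An unknown string such as "bogus" raises ValueError during construction, so a typo in a config file fails early rather than at save time.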

class nemo_automodel.components.checkpoint.checkpointing.Checkpointer(
config: nemo_automodel.components.checkpoint.checkpointing.CheckpointingConfig,
dp_rank: int,
tp_rank: int,
pp_rank: int,
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
)

High-level checkpoint manager built on torch.distributed.checkpoint (DCP).

Supports:

  • HF sharded safetensors via custom storage reader/writer

  • Optional consolidated export (config, generation config, tokenizer)

  • PEFT adapter save/load handling

  • Async save for torch >= 2.9.0

Also provides DP-aware helpers for saving/loading auxiliary state and utilities to initialize from a base HF checkpoint.

Initialization

Initialize the checkpointer.

Parameters:
  • config – Checkpointing configuration.

  • dp_rank – Data parallel rank for the current process.

  • tp_rank – Tensor parallel rank for the current process.

  • pp_rank – Pipeline parallel rank for the current process.

  • moe_mesh – Optional device mesh used for MoE when adapting state dicts.

save_model(
model: torch.nn.Module,
weights_path: str,
peft_config: Optional[peft.PeftConfig] = None,
tokenizer: Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None,
) → None

Save model weights to weights_path/model.

Behavior:

  • PEFT: write adapter_model.safetensors and metadata on rank 0.

  • Safetensors + consolidation: emit HF artifacts under weights_path/model/consolidated and build a consolidated index.

  • Otherwise: use DCP with a Hugging Face or default storage writer to save shards.

Parameters:
  • model – Model to checkpoint.

  • weights_path – Base directory for checkpoints.

  • peft_config – Optional PEFT configuration when saving adapters.

  • tokenizer – Optional tokenizer to save with consolidated artifacts.

save_optimizer(
optimizer: torch.optim.Optimizer,
model: torch.nn.Module,
weights_path: str,
scheduler: Optional[Any] = None,
) → None

Save optimizer (and optional scheduler) state to weights_path/optim using DCP.

Parameters:
  • optimizer – Optimizer whose state will be saved.

  • model – Model providing partitioning context for the optimizer wrapper.

  • weights_path – Base directory for checkpoints.

  • scheduler – Optional LR scheduler to include.
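The save methods above write to fixed subdirectories of weights_path. A small sketch of that layout; checkpoint_layout is a hypothetical helper, and the step directory name used in the test is illustrative:

```python
from pathlib import Path


def checkpoint_layout(weights_path: str, consolidated: bool = False) -> dict[str, Path]:
    """Directories written under weights_path, per the save_model/save_optimizer docs."""
    root = Path(weights_path)
    layout = {
        "model": root / "model",  # save_model: sharded weights or PEFT adapter
        "optim": root / "optim",  # save_optimizer: optimizer (+ scheduler) via DCP
    }
    if consolidated:
        # HF artifacts for safetensors saves with consolidation enabled.
        layout["consolidated"] = root / "model" / "consolidated"
    return layout
```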

load_optimizer(
optimizer: torch.optim.Optimizer,
model: torch.nn.Module,
weights_path: str,
scheduler: Optional[Any] = None,
) → None

Load optimizer (and optional scheduler) state from weights_path/optim using DCP.

Parameters:
  • optimizer – Optimizer to populate.

  • model – Model providing partitioning context for the optimizer wrapper.

  • weights_path – Base directory for checkpoints.

  • scheduler – Optional LR scheduler to populate.

load_model(
model: torch.nn.Module,
model_path: str,
is_init_step: bool = False,
use_checkpoint_id: bool = True,
key_mapping: Optional[dict[str, str]] = None,
) → None

Load model weights from model_path.

Behavior:

  • For PEFT (non-init): rank 0 reads adapter_model.safetensors, then broadcasts.

  • Otherwise: use DCP with a Hugging Face or default storage reader to populate the state dict.

  • If the model exposes a state_dict_adapter, convert to/from HF format as needed.

  • For models requiring tensor merging (e.g., Mixtral), use transformers’ conversion mapping.

Parameters:
  • model – Model or parallelized model parts to load into.

  • model_path – Path to the model checkpoint directory or HF snapshot.

  • is_init_step – If True, treat load as initialization from a base checkpoint.

  • use_checkpoint_id – Pass checkpoint_id to DCP if True; disable when using direct HF paths.

  • key_mapping – Optional key remapping when reading from HF checkpoints.

static initialize_model_weights(
model: torch.nn.Module,
device: torch.device,
peft_init_method: str | None = None,
) → None

Materialize meta-device parameters and initialize model weights.

Moves empty parameter shells to the target device, resets HF initialization flags, calls the model’s weight initialization method, and initializes any PEFT adapters.

Parameters:
  • model – Model whose weights should be initialized.

  • device – Target device for materialized parameters.

  • peft_init_method – Initialization method for PEFT adapters (e.g. “xavier”).

load_base_model(
model: torch.nn.Module,
device: torch.device,
root_dir: str,
model_name: str | None,
load_base_model: bool = True,
) → None

Load a model from the base Hugging Face checkpoint in parallel.

Parameters:
  • model – Model to load state into

  • device – Device to load model onto

  • root_dir – Root directory of the model cache or snapshots

  • model_name – Name of the model or an absolute path to a snapshot

  • load_base_model – If True, restore from HF base checkpoint

maybe_wait_for_staging() → None

Wait for checkpoint staging to finish, if staging is enabled.

async_wait() → None

Wait for the async save to finish.

save_on_dp_ranks(
state: Any,
state_name: str,
path: str,
) → None

Save the stateful object.

This helper is currently used to save the dataloader and RNG state.

Parameters:
  • state – Stateful object to save

  • state_name – Name of the stateful object

  • path – Path to save stateful object

load_on_dp_ranks(
state: Any,
state_name: str,
path: str,
) → None

Load the stateful object.

This helper is currently used to load the dataloader and RNG state.

Parameters:
  • state – Stateful object to load

  • state_name – Name of the stateful object

  • path – Path to load stateful object

close() → None

Close the checkpointer.

_do_load(
state_dict: dict[str, torch.Tensor],
path: str,
storage_reader: Optional[nemo_automodel.components.checkpoint._backports.hf_storage._HuggingFaceStorageReader] = None,
is_init_step: bool = False,
) → dict[str, torch.Tensor]

Load a state dictionary from path using DCP or PEFT special-case logic.

Parameters:
  • state_dict – Mutable state dict to populate with tensors.

  • path – Checkpoint directory path.

  • storage_reader – Optional HF storage reader for safetensors.

  • is_init_step – True if loading from a base checkpoint during initialization.

Returns:

The populated state dictionary (may be replaced for PEFT).

_do_save(
state_dict: dict[str, torch.Tensor],
path: str,
storage_writer: Optional[nemo_automodel.components.checkpoint._backports.hf_storage._HuggingFaceStorageWriter] = None,
) → Optional[torch.distributed.checkpoint.state_dict_saver.AsyncSaveResponse]

Save a state dictionary to path using DCP or PEFT special-case logic.

  • For PEFT model saves: only rank 0 writes adapter_model.safetensors.

  • If async mode is enabled, schedule an asynchronous save.

Parameters:
  • state_dict – State dict to be serialized.

  • path – Checkpoint directory path.

  • storage_writer – Optional HF storage writer for safetensors sharding.

Returns:

An AsyncSaveResponse handle if async mode is enabled; otherwise None.

_should_write_consolidated_safetensors() → bool

Whether to output consolidated HF weights along with sharded weights.

Returns True only for non-PEFT safetensors when consolidation is enabled.

_should_write_hf_metadata() → bool

Whether to write the HF artifacts.

_maybe_build_consolidated_index(
model_state: nemo_automodel.components.checkpoint.stateful_wrappers.ModelState,
state_dict: dict[str, torch.Tensor],
) → Optional[dict[str, int]]

Build an FQN-to-shard-index mapping for the consolidated HF export.

Uses the base checkpoint index (if present), removes non-persistent keys, and assigns new keys to the last shard by default.

Parameters:
  • model_state – Wrapper exposing the primary model part.

  • state_dict – The state dict that will be saved.

Returns:

Mapping from FQN to shard index, or None when not consolidating.
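The three steps above can be sketched with plain dicts. This assumes the base index has already been parsed into an FQN-to-shard-number mapping and that shard numbers start at 1; build_consolidated_index is a hypothetical name:

```python
def build_consolidated_index(
    base_index: dict[str, int],
    state_dict_keys: list[str],
    non_persistent_keys: set[str],
) -> dict[str, int]:
    """Reuse the base FQN -> shard mapping, drop non-persistent keys,
    and place keys missing from the base index in the last shard."""
    mapping = {k: idx for k, idx in base_index.items() if k not in non_persistent_keys}
    last_shard = max(mapping.values(), default=1)
    for key in state_dict_keys:
        if key not in mapping and key not in non_persistent_keys:
            mapping[key] = last_shard
    return mapping
```

For example, a new task-head key that is absent from the base checkpoint lands in the highest-numbered shard.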

_get_storage_writer(
consolidated_output_path: Optional[str],
fqn_to_index_mapping: Optional[dict[str, int]],
model_path: str,
consolidate_on_all_ranks: bool = False,
) → Optional[nemo_automodel.components.checkpoint._backports.hf_storage._HuggingFaceStorageWriter]

Construct a Hugging Face storage writer for sharded safetensors.

Parameters:
  • consolidated_output_path – Optional path for consolidated artifacts.

  • fqn_to_index_mapping – Optional mapping from FQN to shard index.

  • model_path – Path where the model checkpoint is saved.

  • consolidate_on_all_ranks – If True, consolidate on all ranks on the main process.

Returns:

Configured _HuggingFaceStorageWriter or None for non-safetensors.

_get_storage_reader(
model_path: str,
key_mapping: Optional[dict[str, str]],
is_init_step: bool = False,
) → Optional[nemo_automodel.components.checkpoint._backports.hf_storage._HuggingFaceStorageReader]

Construct a Hugging Face storage reader when loading safetensors or during init.

Prefers the upstream torch.distributed.checkpoint.hf_storage.HuggingFaceStorageReader when no key_mapping is needed, since it uses safetensors’ native get_slice() for efficient partial reads (only the bytes for the local DTensor shard are read from disk). Falls back to the backported reader when key_mapping is required or when the upstream reader is not available.

Parameters:
  • model_path – Path to the model checkpoint directory or HF snapshot.

  • key_mapping – Optional key remapping for conversion.

  • is_init_step – If True, always produce a reader for base HF load.

Returns:

Configured storage reader or None for other formats.
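The reader-selection rules can be summarized as a small decision function. This is a sketch of the documented logic only; choose_storage_reader is hypothetical and returns labels instead of constructing readers:

```python
from typing import Optional


def choose_storage_reader(
    is_safetensors: bool,
    is_init_step: bool,
    key_mapping: Optional[dict[str, str]],
    upstream_available: bool,
) -> Optional[str]:
    """Return which reader the documented selection rules would pick."""
    if not (is_safetensors or is_init_step):
        return None  # other formats: no HF reader is constructed
    if not key_mapping and upstream_available:
        # Upstream reader reads only the local shard's bytes via get_slice().
        return "upstream HuggingFaceStorageReader"
    # Key remapping needed, or upstream reader missing in this torch build.
    return "backported _HuggingFaceStorageReader"
```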

_get_original_model_path(
model_state: nemo_automodel.components.checkpoint.stateful_wrappers.ModelState,
) → str | None

Get the path to the original model from the Hugging Face checkpoint.

nemo_automodel.components.checkpoint.checkpointing.get_safetensors_index_path(
cache_dir: str | pathlib.Path | None,
repo_id: str | None,
) → str | None

Return the directory containing the first model.safetensors.index.json found for the given model.

If no model.safetensors.index.json is found, it returns None.

For example, if the file located is

/opt/models/models--meta-llama--Llama-3.2-3B/snapshots/13afe.../model.safetensors.index.json

this function will return the directory path

/opt/models/models--meta-llama--Llama-3.2-3B/snapshots/13afe...

It raises an error if the model has not been downloaded or if the cache directory is incorrect.

Parameters:
  • cache_dir – Path to cache directory

  • repo_id – Hugging Face repository ID

Returns:

Path to the directory containing the index file.

Raises:

FileNotFoundError – If the index file is not found.
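A minimal sketch over the standard Hugging Face hub cache layout (models--<org>--<name>/snapshots/<revision>/). It reflects one possible reading of the behavior described above: it raises FileNotFoundError when the model's cache folder is missing and returns None when the folder exists without an index file. find_safetensors_index_dir is a hypothetical name:

```python
from pathlib import Path
from typing import Optional


def find_safetensors_index_dir(cache_dir: str, repo_id: str) -> Optional[str]:
    """Return the snapshot directory holding model.safetensors.index.json, if any."""
    # Hub cache folders encode "org/name" as "models--org--name".
    model_dir = Path(cache_dir) / ("models--" + repo_id.replace("/", "--"))
    if not model_dir.is_dir():
        # Model not downloaded, or cache_dir points somewhere else.
        raise FileNotFoundError(f"No cached snapshot for {repo_id!r} under {cache_dir!r}")
    # Sort so "first" is deterministic across filesystems.
    matches = sorted(model_dir.glob("snapshots/*/model.safetensors.index.json"))
    return str(matches[0].parent) if matches else None
```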

nemo_automodel.components.checkpoint.checkpointing.to_empty_parameters_only(
model: torch.nn.Module,
*,
device: torch.device,
recurse: bool = True,
dtype: torch.dtype | None = None,
) → torch.nn.Module

Move parameters to the specified device without copying storage, skipping buffers.

Mirrors torch.nn.Module.to_empty but applies only to parameters, not buffers.

Parameters:
  • model – The module to transform

  • device – Target device

  • recurse – Whether to recurse into child modules

Returns:

The same module instance

nemo_automodel.components.checkpoint.checkpointing.save_config(config: dict[str, Any], weights_path: str) → None

Save a config to a weights path.

Parameters:
  • config – Config to save

  • weights_path – Path to save config

nemo_automodel.components.checkpoint.checkpointing._ensure_dirs(*dirs: Optional[str]) → None

Create directories on all ranks and synchronize across ranks.

Parameters:

*dirs – One or more directory paths that should exist.
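A sketch of the create-then-synchronize pattern, assuming None entries are skipped. ensure_dirs is hypothetical, and its barrier argument stands in for a collective such as torch.distributed.barrier() so the example runs without a process group:

```python
import os
from typing import Callable, Optional


def ensure_dirs(*dirs: Optional[str], barrier: Optional[Callable[[], None]] = None) -> None:
    """Create each non-None directory on this rank, then synchronize."""
    for d in dirs:
        if d is not None:
            # exist_ok=True tolerates a directory another rank already created.
            os.makedirs(d, exist_ok=True)
    if barrier is not None:
        barrier()
```

The barrier keeps any rank from moving on to writes before every rank has finished creating its directories.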

nemo_automodel.components.checkpoint.checkpointing._init_peft_adapters(
model: torch.nn.Module,
peft_init_method: str,
) → None

Initialize the PEFT adapters with the scaled weights.

Parameters:
  • model – Model to initialize PEFT adapters for

  • peft_init_method – Method used to initialize PEFT adapters (e.g. “xavier”). See LinearLoRA for more details.

nemo_automodel.components.checkpoint.checkpointing._MODELS_REQUIRING_BUFFER_REINIT: frozenset[str]

'frozenset(…)'

nemo_automodel.components.checkpoint.checkpointing._reinit_non_persistent_buffers(
model: torch.nn.Module,
device: torch.device,
model_type: str | None = None,
) → None

Recompute non-persistent buffers that are not saved in checkpoints.

Non-persistent buffers are not saved in checkpoints, so after meta-device materialization they contain uninitialized CUDA memory. When initialize_weights() is skipped (e.g. for Gemma3 to avoid DTensor issues), these buffers must be recomputed explicitly.

Only runs for models listed in _MODELS_REQUIRING_BUFFER_REINIT to avoid unexpected side-effects on arbitrary HF Hub models.

Handles four patterns:

  1. Standard RoPE: a single inv_freq buffer with rope_init_fn + rope_kwargs (e.g. Nemotron-NAS).

  2. Per-layer-type RoPE: {layer_type}_inv_freq buffers via compute_default_rope_parameters (e.g. Gemma3RotaryEmbedding).

  3. Scaled embedding: embed_scale buffer on ScaledWordEmbedding modules (Gemma family), recomputed from scalar_embed_scale.

  4. Vision position IDs: position_ids buffer on vision embedding modules (SigLIP), recomputed from num_positions.

Parameters:
  • model – Model to reinitialize non-persistent buffers for.

  • device – Device to create the new buffers on.

  • model_type – The config.model_type string. If not in _MODELS_REQUIRING_BUFFER_REINIT the function is a no-op.
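Several of these buffers are pure functions of config values, which is why recomputing them is cheap. A plain-Python sketch of two such formulas: the default (unscaled) RoPE inverse frequencies, 1 / base^(2i/d) for i = 0, …, d/2 − 1, and vision position IDs 0, …, num_positions − 1. The function names are hypothetical, and scaled RoPE variants (selected via rope_kwargs) are not covered:

```python
def default_rope_inv_freq(head_dim: int, base: float = 10000.0) -> list[float]:
    """Default RoPE inverse frequencies: 1 / base**(j / head_dim) for j = 0, 2, 4, ..."""
    return [1.0 / (base ** (j / head_dim)) for j in range(0, head_dim, 2)]


def vision_position_ids(num_positions: int) -> list[int]:
    """Position IDs 0 .. num_positions - 1, one per vision patch position."""
    return list(range(num_positions))
```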

nemo_automodel.components.checkpoint.checkpointing._apply(module, fn, recurse=True) → torch.nn.Module

Apply a transformation function to parameters (and gradients) only.

Mirrors nn.Module._apply (the helper behind nn.Module.to_empty) for parameters while skipping buffers. Respects future flags controlling in-place vs. swap behavior and safely handles wrapper subclasses.

Parameters:
  • module – Module whose parameters are to be transformed.

  • fn – Callable applied to each parameter (and its gradient).

  • recurse – Whether to recurse into child modules.

Returns:

The same module instance after transformation.

nemo_automodel.components.checkpoint.checkpointing._apply_key_mapping(
state_dict: dict[str, torch.Tensor],
key_mapping: dict[str, str],
) → dict[str, torch.Tensor]

Rename state-dict keys using regex-based key_mapping.

This mirrors the renaming logic used by the DCP / HuggingFace storage reader but operates directly on an in-memory state dict. It is needed when loading safetensors checkpoints outside of DCP so that HF checkpoint keys (e.g. language_model.model.X) are translated to the model’s parameter FQNs (e.g. model.language_model.X).

Parameters:
  • state_dict – Original state dict whose keys may need renaming.

  • key_mapping – {regex_pattern: replacement} pairs applied in order.

Returns:

A new dict with renamed keys.
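A minimal sketch of regex-based renaming on an in-memory dict, using the example keys above. It assumes, as transformers' own checkpoint key renaming does, that only the first matching pattern is applied to each key; apply_key_mapping is a hypothetical name:

```python
import re


def apply_key_mapping(state_dict: dict, key_mapping: dict[str, str]) -> dict:
    """Rename keys with {regex_pattern: replacement} pairs, tried in order."""
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        for pattern, replacement in key_mapping.items():
            new_key, n_subs = re.subn(pattern, replacement, key)
            if n_subs > 0:
                break  # first matching pattern wins
        renamed[new_key] = value
    return renamed
```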

nemo_automodel.components.checkpoint.checkpointing._load_full_state_dict_into_model(
model_parts: list[torch.nn.Module],
state_dict: dict[str, torch.Tensor],
) → None

Load a full (non-sharded) state dict into a potentially FSDP-wrapped model.

Every rank must supply the full state dict. PyTorch’s set_model_state_dict with full_state_dict=True (but not broadcast_from_rank0) calls _distribute_state_dict which lets each rank independently slice its local DTensor shard from the full tensor – no NCCL collectives are needed.

We intentionally avoid broadcast_from_rank0=True because it introduces an asymmetric workload: rank 0 does a synchronous CPU→GPU copy (.to(device)) per tensor while other ranks only do torch.empty (async allocation). The non-src ranks race ahead enqueuing hundreds of NCCL broadcasts that rank 0 cannot keep up with, leading to a 60 s NCCL watchdog timeout.

After loading, floating-point parameters are converted to match the checkpoint dtype. PyTorch’s set_model_state_dict uses copy semantics (assign=False) for non-meta parameters, which preserves the model’s initialisation dtype instead of the checkpoint dtype. The post-load fixup ensures the safetensors dtype (e.g. bf16) is honoured.

Parameters:
  • model_parts – List of model parts (for pipeline parallelism)

  • state_dict – Full state dict with regular tensors. Must be populated on every rank (not just rank 0).
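With the full tensor on every rank, each rank can compute its own slice locally, which is why no collectives are needed. A plain-Python sketch for a row-wise (Shard(0)) split, assuming torch.chunk-style sizing, which is how DTensor splits a sharded dimension; local_row_shard is hypothetical and only illustrates the idea behind the per-tensor slicing:

```python
import math


def local_row_shard(full_rows: list, rank: int, world_size: int) -> list:
    """Rows this rank keeps when a full tensor is split row-wise.

    Each shard holds ceil(n / world_size) rows, so trailing ranks may get
    fewer rows, or none.
    """
    chunk = math.ceil(len(full_rows) / world_size)
    start = min(rank * chunk, len(full_rows))
    return full_rows[start:start + chunk]
```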

nemo_automodel.components.checkpoint.checkpointing._convert_checkpoint_with_transformers(
model: torch.nn.Module,
model_path: str,
key_mapping: Optional[dict[str, str]] = None,
) → Optional[dict[str, torch.Tensor]]

Convert a checkpoint using transformers’ conversion mapping for models that need tensor merging.

This handles MoE models like Mixtral where the checkpoint has individual expert weights but the model uses grouped expert tensors. The transformers library’s WeightConverter operations handle the tensor merging (MergeModulelist, Concatenate).

This function converts the state dict WITHOUT loading it into the model, so it can be used with FSDP-aware loading mechanisms.

Parameters:
  • model – The model (used to get conversion mapping and target keys).

  • model_path – Path to the HuggingFace checkpoint directory.

  • key_mapping – Optional additional key mapping.

Returns:

Converted state dict ready for loading, or None if conversion failed.
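A plain-Python sketch of the merge step, assuming Mixtral-style checkpoint keys (model.layers.<i>.block_sparse_moe.experts.<j>.w1.weight). Lists stand in for tensors (a real conversion would use torch.stack), and the merged key name is hypothetical. transformers' WeightConverter also handles operations such as Concatenate, which this sketch omits:

```python
import re
from collections import defaultdict

EXPERT_KEY = re.compile(r"^(?P<prefix>.+)\.experts\.(?P<idx>\d+)\.(?P<proj>\w+)\.weight$")


def merge_expert_weights(state_dict: dict) -> dict:
    """Group per-expert weights into one stacked entry per (prefix, projection)."""
    merged: dict = {}
    groups: dict = defaultdict(dict)
    for key, value in state_dict.items():
        m = EXPERT_KEY.match(key)
        if m is None:
            merged[key] = value  # non-expert weights pass through unchanged
            continue
        groups[(m["prefix"], m["proj"])][int(m["idx"])] = value
    for (prefix, proj), by_idx in groups.items():
        # Order by expert index so position i in the stack is expert i.
        # Hypothetical merged key format: "<prefix>.experts.<proj>".
        merged[f"{prefix}.experts.{proj}"] = [by_idx[i] for i in sorted(by_idx)]
    return merged
```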

nemo_automodel.components.checkpoint.checkpointing._maybe_adapt_state_dict_to_hf(
model_part: torch.nn.Module,
state_dict: dict[str, torch.Tensor],
quantization: bool = False,
**kwargs,
) → dict[str, torch.Tensor]

Convert the state dict to the Hugging Face format via the model's state dict adapter, if it has one (as custom models do).

nemo_automodel.components.checkpoint.checkpointing._equally_divide_layers(
num_shards: int,
keys: list[str],
) → dict[str, int]

Equally divide the state dict keys into num_shards shards.
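One way to divide keys equally is to give each shard a contiguous run of keys, with shard sizes differing by at most one. In this sketch, the 1-based shard numbers (matching model-00001-of-0000N style file names) and the name equally_divide are assumptions:

```python
def equally_divide(num_shards: int, keys: list[str]) -> dict[str, int]:
    """Map each key to a shard in 1..num_shards, in contiguous balanced runs."""
    if num_shards < 1:
        raise ValueError("num_shards must be >= 1")
    base, extra = divmod(len(keys), num_shards)
    mapping: dict[str, int] = {}
    start = 0
    for shard in range(1, num_shards + 1):
        # The first `extra` shards take one additional key each.
        size = base + (1 if shard <= extra else 0)
        for key in keys[start:start + size]:
            mapping[key] = shard
        start += size
    return mapping
```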

nemo_automodel.components.checkpoint.checkpointing._model_has_dtensors(module: torch.nn.Module) → bool

True if any parameter is a DTensor (model is already sharded).

nemo_automodel.components.checkpoint.checkpointing._is_custom_model(module: torch.nn.Module) → bool

True if the model has a custom implementation in nemo_automodel/components/models/.

The generic HFCheckpointingMixin (in .common.hf_checkpointing_mixin) is injected into every model by _get_mixin_wrapped_class and does NOT count as a “custom model”. Only actual model implementations (e.g. llama, deepseek_v3) that live under nemo_automodel.components.models qualify.

nemo_automodel.components.checkpoint.checkpointing._is_remote_code_model(module: torch.nn.Module) → bool

True if the model was loaded with trust_remote_code (HF dynamic modules).
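Both checks can be sketched by inspecting where a model's classes are defined. This assumes the mixin's module is nemo_automodel.components.models.common.hf_checkpointing_mixin, and it relies on trust_remote_code classes being imported under transformers' "transformers_modules" package. The helper names are hypothetical, and they take type(model) rather than the model instance:

```python
def is_custom_model_class(cls: type) -> bool:
    """True if any class in the MRO is defined under nemo_automodel.components.models,
    ignoring the generic checkpointing mixin's module."""
    for klass in cls.__mro__:
        module = klass.__module__ or ""
        if module.startswith("nemo_automodel.components.models.") and not module.endswith(
            ".common.hf_checkpointing_mixin"
        ):
            return True
    return False


def is_remote_code_class(cls: type) -> bool:
    """True if the class comes from transformers' dynamic-module package."""
    return (cls.__module__ or "").split(".")[0] == "transformers_modules"
```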

nemo_automodel.components.checkpoint.checkpointing._load_hf_checkpoint_preserving_dtype(
model_path: str,
weights_only: bool = True,
) → Optional[dict[str, torch.Tensor]]

Load a HuggingFace checkpoint into a new state dict so tensor dtypes match the checkpoint (e.g. bf16). Used when loading the base model so FSDP sees uniform dtype instead of the model’s init dtypes (e.g. float32). Prefers safetensors but falls back to .bin files. Returns None if no loadable checkpoint is found.

Parameters:
  • model_path – Path to checkpoint file or directory.

  • weights_only – Forwarded to torch.load when loading .bin files.

nemo_automodel.components.checkpoint.checkpointing._load_hf_safetensors_checkpoint(
model_path: str,
) → Optional[dict[str, torch.Tensor]]

Load a safetensors checkpoint into a state dict.

nemo_automodel.components.checkpoint.checkpointing._load_hf_bin_checkpoint(
model_path: str,
weights_only: bool = True,
) → Optional[dict[str, torch.Tensor]]

Load a HuggingFace .bin checkpoint into a state dict.

Handles single-file (pytorch_model.bin), sharded (pytorch_model.bin.index.json), and glob fallback (*.bin) layouts. Returns None if no .bin files are found.

Parameters:
  • model_path – Path to checkpoint file or directory.

  • weights_only – Passed to torch.load. Default True for safety; set to False for remote-code models whose checkpoints may contain custom pickled objects.
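A sketch of resolving the three layouts to a list of files, assuming the lookup order single file, then index, then glob; resolve_bin_files is hypothetical. Each resolved file would then be read with torch.load(f, map_location="cpu", weights_only=weights_only):

```python
import json
from pathlib import Path
from typing import Optional


def resolve_bin_files(model_path: str) -> Optional[list[Path]]:
    """Return the .bin files to load, or None if there are none."""
    p = Path(model_path)
    if p.is_file():
        return [p] if p.suffix == ".bin" else None
    single = p / "pytorch_model.bin"
    if single.is_file():
        return [single]
    index = p / "pytorch_model.bin.index.json"
    if index.is_file():
        # Sharded layout: weight_map maps parameter name -> shard file name.
        weight_map = json.loads(index.read_text())["weight_map"]
        return [p / name for name in sorted(set(weight_map.values()))]
    matches = sorted(p.glob("*.bin"))  # glob fallback for nonstandard names
    return matches or None
```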

nemo_automodel.components.checkpoint.checkpointing._maybe_adapt_state_dict_from_hf(
model_part: torch.nn.Module,
state_dict: dict[str, torch.Tensor],
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
) → dict[str, torch.Tensor]

Convert the state dict from the Hugging Face format to the native format via the model's state dict adapter, if it has one (as custom models do).