nemo_automodel.components.checkpoint.checkpointing#

Module Contents#

Classes#

_AsyncSaveContext

Internal container for async checkpointing state.

CheckpointingConfig

Configuration for checkpointing.

Checkpointer

High-level checkpoint manager built on torch.distributed.checkpoint (DCP).

Functions#

_is_geq_torch_2_9

Check if the current torch version is greater than or equal to 2.9.0.

_is_safetensors_checkpoint

Return True if path looks like a safetensors checkpoint (so dtype can be preserved); otherwise the path is a DCP or other checkpoint format.

get_safetensors_index_path

Return the directory containing the first model.safetensors.index.json found for a given model.

to_empty_parameters_only

Move parameters to the specified device without copying storage, skipping buffers.

save_config

Save a config to a weights path.

_ensure_dirs

Create directories on all ranks and synchronize across ranks.

_init_peft_adapters

Initialize the PEFT adapters with the scaled weights.

_reinit_rope_buffers

Recompute non-persistent RoPE inv_freq buffers for Nemotron-NAS models.

_apply

Apply a transformation function to parameters (and gradients) only.

_apply_key_mapping

Rename state-dict keys using regex-based key_mapping.

_load_full_state_dict_into_model

Load a full (non-sharded) state dict into a potentially FSDP-wrapped model.

_convert_checkpoint_with_transformers

Convert a checkpoint using transformers’ conversion mapping for models that need tensor merging.

_maybe_adapt_state_dict_to_hf

Custom models use state dict adapters to convert the state dict to the Hugging Face format.

_equally_divide_layers

Equally divide the state dict keys into num_shards shards.

_model_has_dtensors

True if any parameter is a DTensor (model is already sharded).

_is_custom_model

True if the model has a custom implementation in nemo_automodel/components/models/.

_load_hf_checkpoint_preserving_dtype

Load a HuggingFace checkpoint (safetensors) into a new state dict so tensor dtypes match the checkpoint (e.g. bf16). Used when loading the base model so FSDP sees uniform dtype instead of the model’s init dtypes (e.g. float32). Returns None if the path is not a valid safetensors checkpoint.

_maybe_adapt_state_dict_from_hf

Custom models use state dict adapters to convert the state dict from the Hugging Face format to the native format.

API#

nemo_automodel.components.checkpoint.checkpointing._is_geq_torch_2_9() bool#

Check whether the current torch version is greater than or equal to 2.9.0.
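As an illustration of this version gate, the sketch below compares the numeric release component of a version string against `(2, 9, 0)`. The real implementation presumably inspects `torch.__version__`; `parse_release` and `is_geq_2_9` are hypothetical names used here to keep the example torch-free.

```python
# Illustrative ">= 2.9.0" version gate (hypothetical helper names).
def parse_release(version: str) -> tuple:
    """Extract the leading numeric release tuple, ignoring local/dev suffixes."""
    core = version.split("+")[0]  # drop local version, e.g. "2.9.0+cu121"
    parts = []
    for piece in core.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)

def is_geq_2_9(version: str) -> bool:
    # Tuple comparison handles multi-digit components correctly (2.10 > 2.9).
    return parse_release(version) >= (2, 9, 0)
```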

nemo_automodel.components.checkpoint.checkpointing._is_safetensors_checkpoint(path: str) bool#

Return True if path looks like a safetensors checkpoint (so dtype can be preserved); otherwise the path is a DCP or other checkpoint format.
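The following is a hedged sketch of what such a "looks like safetensors" check might do; the heuristic used here (an index JSON or any `*.safetensors` file in the directory) is an assumption, not the library's exact rule.

```python
# Heuristic sketch: does `path` look like a safetensors checkpoint?
import tempfile
from pathlib import Path

def looks_like_safetensors(path: str) -> bool:
    p = Path(path)
    if not p.is_dir():
        # A single file: go by extension.
        return p.suffix == ".safetensors"
    if (p / "model.safetensors.index.json").exists():
        return True
    return any(p.glob("*.safetensors"))

# Demo: a directory containing one sharded safetensors file.
demo = Path(tempfile.mkdtemp())
(demo / "model-00001-of-00002.safetensors").touch()
has_st = looks_like_safetensors(str(demo))
```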

class nemo_automodel.components.checkpoint.checkpointing._AsyncSaveContext#

Internal container for async checkpointing state.

One instance is maintained for the model save and one for the optimizer save to keep staging/upload futures and the associated process group and stager together in a single place.

stager: Any | None#

None

process_group: Any | None#

None

future: Any | None#

None

staging_active: bool#

False

class nemo_automodel.components.checkpoint.checkpointing.CheckpointingConfig#

Configuration for checkpointing.

enabled: bool#

None

checkpoint_dir: str | pathlib.Path#

None

model_save_format: str#

None

model_cache_dir: str | pathlib.Path#

None

model_repo_id: str#

None

save_consolidated: bool#

None

is_peft: bool#

None

model_state_dict_keys: list[str]#

None

is_async: bool#

False

dequantize_base_checkpoint: bool | None#

None

original_model_root_dir: str | None#

None

skip_task_head_prefixes_for_base_model: list[str] | None#

None

single_rank_consolidation: bool#

False

staging_dir: str | None#

None

__post_init__()#

Convert a raw string such as “safetensors” into the corresponding enum member.
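The string-to-enum conversion in __post_init__ follows a common dataclass pattern, sketched below. SaveFormat and MiniConfig are illustrative stand-ins; the real CheckpointingConfig has many more fields.

```python
# Minimal stand-in for the "__post_init__ converts a raw string into an enum"
# pattern described above (hypothetical SaveFormat/MiniConfig names).
from dataclasses import dataclass
from enum import Enum

class SaveFormat(Enum):
    SAFETENSORS = "safetensors"
    TORCH_SAVE = "torch_save"

@dataclass
class MiniConfig:
    model_save_format: "SaveFormat | str"

    def __post_init__(self):
        # Accept a raw string from YAML/CLI and normalize it to the enum.
        if isinstance(self.model_save_format, str):
            self.model_save_format = SaveFormat(self.model_save_format)

cfg = MiniConfig(model_save_format="safetensors")
```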

class nemo_automodel.components.checkpoint.checkpointing.Checkpointer(
config: nemo_automodel.components.checkpoint.checkpointing.CheckpointingConfig,
dp_rank: int,
tp_rank: int,
pp_rank: int,
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
)#

High-level checkpoint manager built on torch.distributed.checkpoint (DCP).

Supports:

  • HF sharded safetensors via custom storage reader/writer

  • Optional consolidated export (config, generation config, tokenizer)

  • PEFT adapter save/load handling

  • Async save for torch >= 2.9.0

Also provides DP-aware helpers for saving/loading auxiliary state and utilities to initialize from a base HF checkpoint.

Initialization

Initialize the checkpointer.

Parameters:
  • config – Checkpointing configuration.

  • dp_rank – Data parallel rank for the current process.

  • tp_rank – Tensor parallel rank for the current process.

  • pp_rank – Pipeline parallel rank for the current process.

  • moe_mesh – Optional device mesh used for MoE when adapting state dicts.

save_model(
model: torch.nn.Module,
weights_path: str,
peft_config: Optional[peft.PeftConfig] = None,
tokenizer: Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None,
) None#

Save model weights to weights_path/model.

Behavior:

  • PEFT: write adapter_model.safetensors and metadata on rank 0.

  • Safetensors + consolidation: emit HF artifacts under weights_path/model/consolidated and build a consolidated index.

  • Otherwise: use DCP with a Hugging Face or default storage writer to save shards.

Parameters:
  • model – Model to checkpoint.

  • weights_path – Base directory for checkpoints.

  • peft_config – Optional PEFT configuration when saving adapters.

  • tokenizer – Optional tokenizer to save with consolidated artifacts.

save_optimizer(
optimizer: torch.optim.Optimizer,
model: torch.nn.Module,
weights_path: str,
scheduler: Optional[Any] = None,
) None#

Save optimizer (and optional scheduler) state to weights_path/optim using DCP.

Parameters:
  • optimizer – Optimizer whose state will be saved.

  • model – Model providing partitioning context for the optimizer wrapper.

  • weights_path – Base directory for checkpoints.

  • scheduler – Optional LR scheduler to include.

load_optimizer(
optimizer: torch.optim.Optimizer,
model: torch.nn.Module,
weights_path: str,
scheduler: Optional[Any] = None,
) None#

Load optimizer (and optional scheduler) state from weights_path/optim using DCP.

Parameters:
  • optimizer – Optimizer to populate.

  • model – Model providing partitioning context for the optimizer wrapper.

  • weights_path – Base directory for checkpoints.

  • scheduler – Optional LR scheduler to populate.

load_model(
model: torch.nn.Module,
model_path: str,
is_init_step: bool = False,
use_checkpoint_id: bool = True,
key_mapping: Optional[dict[str, str]] = None,
) None#

Load model weights from model_path.

Behavior:

  • For PEFT (non-init): rank 0 reads adapter_model.safetensors, then broadcasts.

  • Otherwise: use DCP with a Hugging Face or default storage reader to populate the state dict.

  • If the model exposes a state_dict_adapter, convert to/from HF format as needed.

  • For models requiring tensor merging (e.g., Mixtral), uses transformers’ conversion mapping.

Parameters:
  • model – Model or parallelized model parts to load into.

  • model_path – Path to the model checkpoint directory or HF snapshot.

  • is_init_step – If True, treat load as initialization from a base checkpoint.

  • use_checkpoint_id – Pass checkpoint_id to DCP if True; disable when using direct HF paths.

  • key_mapping – Optional key remapping when reading from HF checkpoints.

static initialize_model_weights(
model: torch.nn.Module,
device: torch.device,
peft_init_method: str | None = None,
) None#

Materialize meta-device parameters and initialize model weights.

Moves empty parameter shells to the target device, resets HF initialization flags, calls the model’s weight initialization method, and initializes any PEFT adapters.

Parameters:
  • model – Model whose weights should be initialized.

  • device – Target device for materialized parameters.

  • peft_init_method – Initialization method for PEFT adapters (e.g. “xavier”).

load_base_model(
model: torch.nn.Module,
device: torch.device,
root_dir: str,
model_name: str | None,
load_base_model: bool = True,
) None#

Load a model from the base Hugging Face checkpoint in parallel.

Parameters:
  • model – Model to load state into

  • device – Device to load model onto

  • root_dir – Root directory of the model cache or snapshots

  • model_name – Name of the model or an absolute path to a snapshot

  • load_base_model – If True, restore from HF base checkpoint

maybe_wait_for_staging() None#

Wait for asynchronous checkpoint staging to finish, if staging is enabled.

async_wait() None#

Wait for the async save to finish.

save_on_dp_ranks(
state: Any,
state_name: str,
path: str,
) None#

Save a stateful object to disk.

A helper currently used to persist the dataloader and RNG state on data-parallel ranks.

Parameters:
  • state – Stateful object to save

  • state_name – Name of the stateful object

  • path – Path to save stateful object

load_on_dp_ranks(
state: Any,
state_name: str,
path: str,
) None#

Load a stateful object from disk.

A helper currently used to restore the dataloader and RNG state on data-parallel ranks.

Parameters:
  • state – Stateful object to load

  • state_name – Name of the stateful object

  • path – Path to load stateful object
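The per-DP-rank save/load pattern can be sketched without torch as follows. Each rank writes its own file under `path` and restores it via `load_state_dict()`. The file naming, pickle serialization, and the explicit `dp_rank` argument are assumptions for illustration (the real methods take the rank from the Checkpointer).

```python
# Torch-free sketch of DP-rank-local save/load of a stateful object.
import pickle
import tempfile
from pathlib import Path

def save_on_dp_ranks(state, state_name: str, path: str, dp_rank: int) -> None:
    out = Path(path)
    out.mkdir(parents=True, exist_ok=True)
    with open(out / f"{state_name}_dp_rank_{dp_rank}.pkl", "wb") as f:
        pickle.dump(state.state_dict(), f)

def load_on_dp_ranks(state, state_name: str, path: str, dp_rank: int) -> None:
    with open(Path(path) / f"{state_name}_dp_rank_{dp_rank}.pkl", "rb") as f:
        state.load_state_dict(pickle.load(f))

class RngState:
    """Toy stateful object mimicking the dataloader/RNG state protocol."""
    def __init__(self, seed=0):
        self.seed = seed
    def state_dict(self):
        return {"seed": self.seed}
    def load_state_dict(self, sd):
        self.seed = sd["seed"]

ckpt_dir = tempfile.mkdtemp()
save_on_dp_ranks(RngState(seed=123), "rng", ckpt_dir, dp_rank=0)
restored = RngState()
load_on_dp_ranks(restored, "rng", ckpt_dir, dp_rank=0)
```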

close() None#

Close the checkpointer.

_do_load(
state_dict: dict[str, torch.Tensor],
path: str,
storage_reader: Optional[nemo_automodel.components.checkpoint._backports.hf_storage._HuggingFaceStorageReader] = None,
is_init_step: bool = False,
) dict[str, torch.Tensor]#

Load a state dictionary from path using DCP or PEFT special-case logic.

Parameters:
  • state_dict – Mutable state dict to populate with tensors.

  • path – Checkpoint directory path.

  • storage_reader – Optional HF storage reader for safetensors.

  • is_init_step – True if loading from a base checkpoint during initialization.

Returns:

The populated state dictionary (may be replaced for PEFT).

_do_save(
state_dict: dict[str, torch.Tensor],
path: str,
storage_writer: Optional[nemo_automodel.components.checkpoint._backports.hf_storage._HuggingFaceStorageWriter] = None,
) Optional[torch.distributed.checkpoint.state_dict_saver.AsyncSaveResponse]#

Save a state dictionary to path using DCP or PEFT special-case logic.

  • For PEFT model saves: only rank 0 writes adapter_model.safetensors.

  • If async mode is enabled, schedule an asynchronous save.

Parameters:
  • state_dict – State dict to be serialized.

  • path – Checkpoint directory path.

  • storage_writer – Optional HF storage writer for safetensors sharding.

Returns:

Optional Future object if async mode is enabled.

_should_write_consolidated_safetensors() bool#

Whether to output consolidated HF weights along with sharded weights.

Returns True only for non-PEFT safetensors when consolidation is enabled.

_should_write_hf_metadata() bool#

Whether to write the HF artifacts.

_maybe_build_consolidated_index(
model_state: nemo_automodel.components.checkpoint.stateful_wrappers.ModelState,
state_dict: dict[str, torch.Tensor],
) Optional[dict[str, int]]#

Build FQN to shard index mapping for consolidated HF export.

Uses the base checkpoint index (if present), removes non-persistent keys, and assigns new keys to the last shard by default.

Parameters:
  • model_state – Wrapper exposing the primary model part.

  • state_dict – The state dict that will be saved.

Returns:

Mapping from FQN to shard index, or None when not consolidating.
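The "reuse the base index, send new keys to the last shard" rule can be sketched as below; `base_index` and the key names are illustrative, and the real function additionally filters non-persistent keys.

```python
# Sketch of the FQN -> shard-index assignment described above.
def build_index(base_index: dict, state_dict_keys: list) -> dict:
    # Keys known to the base checkpoint keep their shard; new keys (e.g. task
    # heads or adapters) go to the last shard by default.
    last_shard = max(base_index.values()) if base_index else 0
    return {k: base_index.get(k, last_shard) for k in state_dict_keys}

base_index = {"model.embed_tokens.weight": 0, "lm_head.weight": 1}
index = build_index(
    base_index,
    ["model.embed_tokens.weight", "model.new_adapter.weight"],
)
```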

_get_storage_writer(
consolidated_output_path: Optional[str],
fqn_to_index_mapping: Optional[dict[str, int]],
model_path: str,
consolidate_on_all_ranks: bool = False,
) Optional[nemo_automodel.components.checkpoint._backports.hf_storage._HuggingFaceStorageWriter]#

Construct a Hugging Face storage writer for sharded safetensors.

Parameters:
  • consolidated_output_path – Optional path for consolidated artifacts.

  • fqn_to_index_mapping – Optional mapping from FQN to shard index.

  • model_path – Path where the model checkpoint is saved.

  • consolidate_on_all_ranks – If True, perform consolidation on all ranks rather than only on the main process.

Returns:

Configured _HuggingFaceStorageWriter or None for non-safetensors.

_get_storage_reader(
model_path: str,
key_mapping: Optional[dict[str, str]],
is_init_step: bool = False,
) Optional[nemo_automodel.components.checkpoint._backports.hf_storage._HuggingFaceStorageReader]#

Construct a Hugging Face storage reader when loading safetensors or during init.

Parameters:
  • model_path – Path to the model checkpoint directory or HF snapshot.

  • key_mapping – Optional key remapping for conversion.

  • is_init_step – If True, always produce a reader for base HF load.

Returns:

Configured _HuggingFaceStorageReader or None for other formats.

_get_original_model_path(
model_state: nemo_automodel.components.checkpoint.stateful_wrappers.ModelState,
) str | None#

Get the path to the original model from the Hugging Face checkpoint.

nemo_automodel.components.checkpoint.checkpointing.get_safetensors_index_path(
cache_dir: str | pathlib.Path | None,
repo_id: str | None,
) str | None#

Return the directory containing the first model.safetensors.index.json found for a given model.

If no model.safetensors.index.json is found, returns None.

For example, if the file located is

/opt/models/models--meta-llama--Llama-3.2-3B/snapshots/13afe.../model.safetensors.index.json

this function will return the directory path

/opt/models/models--meta-llama--Llama-3.2-3B/snapshots/13afe...

This will error if the model hasn’t been downloaded or if the cache directory is incorrect.

Parameters:
  • cache_dir – Path to cache directory

  • repo_id – Hugging Face repository ID

Returns:

Path to the directory containing the index file.

Raises:

FileNotFoundError – If the index file is not found.
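A hedged sketch of locating the directory holding the first `model.safetensors.index.json` under a HF-style cache layout is shown below; `rglob` traversal order is filesystem-dependent, and the real function's search strategy and error handling may differ.

```python
# Sketch: find the snapshot directory containing the safetensors index file.
import tempfile
from pathlib import Path

def find_index_dir(cache_dir: str):
    for index_file in Path(cache_dir).rglob("model.safetensors.index.json"):
        return str(index_file.parent)  # first match wins
    return None

# Demo with a fabricated HF-cache-like layout.
cache = Path(tempfile.mkdtemp())
snap = cache / "models--org--name" / "snapshots" / "abc123"
snap.mkdir(parents=True)
(snap / "model.safetensors.index.json").touch()
found = find_index_dir(str(cache))
```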

nemo_automodel.components.checkpoint.checkpointing.to_empty_parameters_only(
model: torch.nn.Module,
*,
device: torch.device,
recurse: bool = True,
dtype: torch.dtype | None = None,
) torch.nn.Module#

Move parameters to the specified device without copying storage, skipping buffers.

Mirrors torch.nn.Module.to_empty but applies only to parameters, not buffers.

Parameters:
  • model – The module to transform

  • device – Target device

  • recurse – Whether to recurse into child modules

  • dtype – Optional target dtype for the materialized parameters

Returns:

The same module instance

nemo_automodel.components.checkpoint.checkpointing.save_config(config: dict[str, Any], weights_path: str) None#

Save a config to a weights path.

Parameters:
  • config – Config to save

  • weights_path – Path to save config
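A minimal sketch of "save a config dict next to the weights" follows. The file name and format assumed here (`config.json` under `weights_path`) are illustrative; the real function's output location is not specified above.

```python
# Sketch: persist a config dict alongside checkpoint weights as JSON.
import json
import tempfile
from pathlib import Path

def save_config(config: dict, weights_path: str) -> None:
    out = Path(weights_path)
    out.mkdir(parents=True, exist_ok=True)
    with open(out / "config.json", "w") as f:
        json.dump(config, f, indent=2)

dst = tempfile.mkdtemp()
save_config({"hidden_size": 64}, dst)
reloaded = json.loads((Path(dst) / "config.json").read_text())
```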

nemo_automodel.components.checkpoint.checkpointing._ensure_dirs(*dirs: Optional[str]) None#

Create directories on all ranks and synchronize across ranks.

Parameters:

*dirs – One or more directory paths that should exist.
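Rank-safe directory creation boils down to idempotent `makedirs` calls plus a synchronization point, as sketched below. The distributed barrier is noted as a comment to keep the example torch-free; treat this as an illustration, not the exact implementation.

```python
# Sketch: create directories idempotently across many concurrent ranks.
import os
import tempfile

def ensure_dirs(*dirs) -> None:
    for d in dirs:
        if d is not None:
            # exist_ok=True makes concurrent creation from many ranks safe.
            os.makedirs(d, exist_ok=True)
    # Real implementation: also synchronize, e.g. torch.distributed.barrier()
    # when a process group is initialized, so no rank races ahead.

root = tempfile.mkdtemp()
target = os.path.join(root, "model", "consolidated")
ensure_dirs(target, None, target)  # None entries and repeats are tolerated
```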

nemo_automodel.components.checkpoint.checkpointing._init_peft_adapters(
model: torch.nn.Module,
peft_init_method: str,
) None#

Initialize the PEFT adapters with the scaled weights.

Parameters:
  • model – Model to initialize PEFT adapters for

  • peft_init_method – Method used to initialize PEFT adapters, e.g. “xavier”. See LinearLoRA for details.

nemo_automodel.components.checkpoint.checkpointing._reinit_rope_buffers(
model: torch.nn.Module,
device: torch.device,
) None#

Recompute non-persistent RoPE inv_freq buffers for Nemotron-NAS models.

Parameters:
  • model – Model to reinitialize RoPE buffers for.

  • device – Device to create the new buffers on.

nemo_automodel.components.checkpoint.checkpointing._apply(module, fn, recurse=True) torch.nn.Module#

Apply a transformation function to parameters (and gradients) only.

Mirrors nn.Module.to_empty for parameters while skipping buffers. Respects future flags controlling in-place vs swap behavior and safely handles wrapper subclasses.

Parameters:
  • module – Module whose parameters are to be transformed.

  • fn – Callable applied to each parameter (and its gradient).

  • recurse – Whether to recurse into child modules.

Returns:

The same module instance after transformation.

nemo_automodel.components.checkpoint.checkpointing._apply_key_mapping(
state_dict: dict[str, torch.Tensor],
key_mapping: dict[str, str],
) dict[str, torch.Tensor]#

Rename state-dict keys using regex-based key_mapping.

This mirrors the renaming logic used by the DCP / HuggingFace storage reader but operates directly on an in-memory state dict. It is needed when loading safetensors checkpoints outside of DCP so that HF checkpoint keys (e.g. language_model.model.X) are translated to the model’s parameter FQNs (e.g. model.language_model.X).

Parameters:
  • state_dict – Original state dict whose keys may need renaming.

  • key_mapping – {regex_pattern: replacement} pairs applied in order.

Returns:

A new dict with renamed keys.
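The renaming logic can be sketched as repeated `re.sub` passes over each key. The example pattern below (moving a `language_model.` prefix inside `model.`, as in the description above) is illustrative, not the actual mapping shipped with any model.

```python
# Sketch of regex-based state-dict key renaming.
import re

def apply_key_mapping(state_dict: dict, key_mapping: dict) -> dict:
    renamed = {}
    for key, value in state_dict.items():
        new_key = key
        # Apply each {pattern: replacement} pair in order.
        for pattern, replacement in key_mapping.items():
            new_key = re.sub(pattern, replacement, new_key)
        renamed[new_key] = value
    return renamed

mapping = {r"^language_model\.model\.": "model.language_model."}
sd = {"language_model.model.layers.0.weight": 1, "lm_head.weight": 2}
out = apply_key_mapping(sd, mapping)
```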

nemo_automodel.components.checkpoint.checkpointing._load_full_state_dict_into_model(
model_parts: list[torch.nn.Module],
state_dict: dict[str, torch.Tensor],
) None#

Load a full (non-sharded) state dict into a potentially FSDP-wrapped model.

Every rank must supply the full state dict. PyTorch’s set_model_state_dict with full_state_dict=True (but not broadcast_from_rank0) calls _distribute_state_dict which lets each rank independently slice its local DTensor shard from the full tensor – no NCCL collectives are needed.

We intentionally avoid broadcast_from_rank0=True because it introduces an asymmetric workload: rank 0 does a synchronous CPU→GPU copy (.to(device)) per tensor while other ranks only do torch.empty (async allocation). The non-src ranks race ahead enqueuing hundreds of NCCL broadcasts that rank 0 cannot keep up with, leading to a 60 s NCCL watchdog timeout.

After loading, floating-point parameters are converted to match the checkpoint dtype. PyTorch’s set_model_state_dict uses copy semantics (assign=False) for non-meta parameters, which preserves the model’s initialization dtype instead of the checkpoint dtype. The post-load fixup ensures the safetensors dtype (e.g. bf16) is honored.

Parameters:
  • model_parts – List of model parts (for pipeline parallelism)

  • state_dict – Full state dict with regular tensors. Must be populated on every rank (not just rank 0).

nemo_automodel.components.checkpoint.checkpointing._convert_checkpoint_with_transformers(
model: torch.nn.Module,
model_path: str,
key_mapping: Optional[dict[str, str]] = None,
) Optional[dict[str, torch.Tensor]]#

Convert a checkpoint using transformers’ conversion mapping for models that need tensor merging.

This handles MoE models like Mixtral where the checkpoint has individual expert weights but the model uses grouped expert tensors. The transformers library’s WeightConverter operations handle the tensor merging (MergeModulelist, Concatenate).

This function converts the state dict WITHOUT loading it into the model, so it can be used with FSDP-aware loading mechanisms.

Parameters:
  • model – The model (used to get conversion mapping and target keys).

  • model_path – Path to the HuggingFace checkpoint directory.

  • key_mapping – Optional additional key mapping.

Returns:

Converted state dict ready for loading, or None if conversion failed.

nemo_automodel.components.checkpoint.checkpointing._maybe_adapt_state_dict_to_hf(
model_part: torch.nn.Module,
state_dict: dict[str, torch.Tensor],
quantization: bool = False,
**kwargs,
) dict[str, torch.Tensor]#

Custom models use state dict adapters to convert the state dict to the Hugging Face format.

nemo_automodel.components.checkpoint.checkpointing._equally_divide_layers(
num_shards: int,
keys: list[str],
) dict[str, int]#

Equally divide the state dict keys into num_shards shards.
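An even division of keys into shards might look like the sketch below; the exact chunking rule (contiguous ceil-sized blocks, as here) is an assumption about this helper's behavior.

```python
# Sketch: assign each state-dict key to one of num_shards shards.
def equally_divide_layers(num_shards: int, keys: list) -> dict:
    per_shard = -(-len(keys) // num_shards)  # ceiling division
    return {key: i // per_shard for i, key in enumerate(keys)}

keys = [f"layer.{i}.weight" for i in range(5)]
shard_of = equally_divide_layers(2, keys)  # 3 keys in shard 0, 2 in shard 1
```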

nemo_automodel.components.checkpoint.checkpointing._model_has_dtensors(module: torch.nn.Module) bool#

True if any parameter is a DTensor (model is already sharded).

nemo_automodel.components.checkpoint.checkpointing._is_custom_model(module: torch.nn.Module) bool#

True if the model has a custom implementation in nemo_automodel/components/models/.

nemo_automodel.components.checkpoint.checkpointing._load_hf_checkpoint_preserving_dtype(
model_path: str,
) Optional[dict[str, torch.Tensor]]#

Load a HuggingFace checkpoint (safetensors) into a new state dict so tensor dtypes match the checkpoint (e.g. bf16). Used when loading the base model so FSDP sees uniform dtype instead of the model’s init dtypes (e.g. float32). Returns None if the path is not a valid safetensors checkpoint.

nemo_automodel.components.checkpoint.checkpointing._maybe_adapt_state_dict_from_hf(
model_part: torch.nn.Module,
state_dict: dict[str, torch.Tensor],
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
) dict[str, torch.Tensor]#

Custom models use state dict adapters to convert the state dict from the Hugging Face format to the native format.