nemo_automodel.components.checkpoint.checkpointing#

Module Contents#

Classes#

_AsyncSaveContext

Internal container for async checkpointing state.

CheckpointingConfig

Configuration for checkpointing.

Checkpointer

High-level checkpoint manager built on torch.distributed.checkpoint (DCP).

Functions#

_is_geq_torch_2_9

Check if the current torch version is greater than or equal to 2.9.0.

get_safetensors_index_path

Return the directory containing the first model.safetensors.index.json found for the given model.

to_empty_parameters_only

Move parameters to the specified device without copying storage, skipping buffers.

save_config

Save a config to a weights path.

_ensure_dirs

Create directories on all ranks and synchronize across ranks.

_init_peft_adapters

Initialize the PEFT adapters with the scaled weights.

_apply

Apply a transformation function to parameters (and gradients) only.

_maybe_adapt_state_dict_to_hf

Convert the state dict to the Hugging Face format using the model's state dict adapter, if the model defines one.

_maybe_adapt_state_dict_from_hf

Convert a Hugging Face-format state dict back to the model's native format using its state dict adapter, if the model defines one.

API#

nemo_automodel.components.checkpoint.checkpointing._is_geq_torch_2_9() bool#

Check if the current torch version is greater than or equal to 2.9.0.
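
Example: the same gate can be written with the packaging library (a minimal sketch; the module's actual implementation may differ):

```python
import torch
from packaging.version import Version

def is_geq_torch_2_9() -> bool:
    # "2.9.0+cu121" compares >= "2.9.0"; pre-releases such as "2.9.0a0" do not.
    return Version(torch.__version__) >= Version("2.9.0")
```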

class nemo_automodel.components.checkpoint.checkpointing._AsyncSaveContext#

Internal container for async checkpointing state.

One instance is maintained for the model save and one for the optimizer save, keeping the staging/upload futures, the associated process group, and the stager together in a single place.

stager: Any | None#

None

process_group: Any | None#

None

future: Any | None#

None

staging_active: bool#

False

class nemo_automodel.components.checkpoint.checkpointing.CheckpointingConfig#

Configuration for checkpointing.

enabled: bool#

None

checkpoint_dir: str | pathlib.Path#

None

model_save_format: str#

None

model_cache_dir: str | pathlib.Path#

None

model_repo_id: str#

None

save_consolidated: bool#

None

is_peft: bool#

None

model_state_dict_keys: list[str]#

None

is_async: bool#

False

dequantize_base_checkpoint: bool | None#

None

original_model_root_dir: str | None#

None

skip_task_head_prefixes_for_base_model: list[str] | None#

None

__post_init__()#

Convert a raw string such as "safetensors" into the corresponding Enum member.
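
Example: constructing a config (illustrative values; which fields are required versus defaulted is not shown on this page):

```python
from nemo_automodel.components.checkpoint.checkpointing import CheckpointingConfig

ckpt_cfg = CheckpointingConfig(
    enabled=True,
    checkpoint_dir="/tmp/checkpoints",
    model_save_format="safetensors",  # __post_init__ converts the string to the enum
    model_cache_dir="/tmp/hf_cache",
    model_repo_id="meta-llama/Llama-3.2-3B",
    save_consolidated=True,
    is_peft=False,
    model_state_dict_keys=[],  # typically list(model.state_dict().keys())
    is_async=False,
)
```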

class nemo_automodel.components.checkpoint.checkpointing.Checkpointer(
config: nemo_automodel.components.checkpoint.checkpointing.CheckpointingConfig,
dp_rank: int,
tp_rank: int,
pp_rank: int,
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
)#

High-level checkpoint manager built on torch.distributed.checkpoint (DCP).

Supports:

  • HF sharded safetensors via custom storage reader/writer

  • Optional consolidated export (config, generation config, tokenizer)

  • PEFT adapter save/load handling

  • Async save for torch >= 2.9.0

Also provides DP-aware helpers for saving/loading auxiliary state and utilities to initialize from a base HF checkpoint.

Initialization

Initialize the checkpointer.

Parameters:
  • config – Checkpointing configuration.

  • dp_rank – Data parallel rank for the current process.

  • tp_rank – Tensor parallel rank for the current process.

  • pp_rank – Pipeline parallel rank for the current process.

  • moe_mesh – Optional device mesh used for MoE when adapting state dicts.
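
Example: instantiating a checkpointer from the config above (rank values would normally come from the distributed setup):

```python
from nemo_automodel.components.checkpoint.checkpointing import Checkpointer

checkpointer = Checkpointer(
    config=ckpt_cfg,
    dp_rank=0,
    tp_rank=0,
    pp_rank=0,
    moe_mesh=None,  # only needed when state dict adaptation involves MoE
)
```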

save_model(
model: torch.nn.Module,
weights_path: str,
peft_config: Optional[peft.PeftConfig] = None,
tokenizer: Optional[transformers.tokenization_utils.PreTrainedTokenizerBase] = None,
) None#

Save model weights to weights_path/model.

Behavior:

  • PEFT: write adapter_model.safetensors and metadata on rank 0.

  • Safetensors + consolidation: emit HF artifacts under weights_path/model/consolidated and build a consolidated index.

  • Otherwise: use DCP with a Hugging Face or default storage writer to save shards.

Parameters:
  • model – Model to checkpoint.

  • weights_path – Base directory for checkpoints.

  • peft_config – Optional PEFT configuration when saving adapters.

  • tokenizer – Optional tokenizer to save with consolidated artifacts.
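
Example call (illustrative path; `model` and `tokenizer` are assumed to be already constructed):

```python
checkpointer.save_model(
    model=model,
    weights_path="/tmp/checkpoints/step_100",  # weights are written under .../model
    tokenizer=tokenizer,  # optional; used with consolidated export
)
```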

save_optimizer(
optimizer: torch.optim.Optimizer,
model: torch.nn.Module,
weights_path: str,
scheduler: Optional[Any] = None,
) None#

Save optimizer (and optional scheduler) state to weights_path/optim using DCP.

Parameters:
  • optimizer – Optimizer whose state will be saved.

  • model – Model providing partitioning context for the optimizer wrapper.

  • weights_path – Base directory for checkpoints.

  • scheduler – Optional LR scheduler to include.

load_optimizer(
optimizer: torch.optim.Optimizer,
model: torch.nn.Module,
weights_path: str,
scheduler: Optional[Any] = None,
) None#

Load optimizer (and optional scheduler) state from weights_path/optim using DCP.

Parameters:
  • optimizer – Optimizer to populate.

  • model – Model providing partitioning context for the optimizer wrapper.

  • weights_path – Base directory for checkpoints.

  • scheduler – Optional LR scheduler to populate.
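
Example: save_optimizer and load_optimizer take the same base path, and state lives under weights_path/optim (a sketch, assuming `optimizer`, `model`, and `scheduler` already exist):

```python
base = "/tmp/checkpoints/step_100"

checkpointer.save_optimizer(optimizer, model, base, scheduler=scheduler)
# ...later, restore into freshly constructed optimizer/scheduler objects:
checkpointer.load_optimizer(optimizer, model, base, scheduler=scheduler)
```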

load_model(
model: torch.nn.Module,
model_path: str,
is_init_step: bool = False,
use_checkpoint_id: bool = True,
key_mapping: Optional[dict[str, str]] = None,
) None#

Load model weights from model_path.

Behavior:

  • For PEFT (non-init): rank 0 reads adapter_model.safetensors, then broadcasts.

  • Otherwise: use DCP with a Hugging Face or default storage reader to populate the state dict.

  • If the model exposes a state_dict_adapter, convert to/from HF format as needed.

Parameters:
  • model – Model or parallelized model parts to load into.

  • model_path – Path to the model checkpoint directory or HF snapshot.

  • is_init_step – If True, treat load as initialization from a base checkpoint.

  • use_checkpoint_id – Pass checkpoint_id to DCP if True; disable when using direct HF paths.

  • key_mapping – Optional key remapping when reading from HF checkpoints.
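
Example calls (paths are illustrative):

```python
# Resume from an earlier save_model call:
checkpointer.load_model(model, "/tmp/checkpoints/step_100/model")

# Initialize from a Hugging Face snapshot directory instead:
checkpointer.load_model(
    model,
    model_path="/opt/models/models--meta-llama--Llama-3.2-3B/snapshots/13afe...",
    is_init_step=True,
    use_checkpoint_id=False,  # direct HF paths do not use a DCP checkpoint_id
)
```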

load_base_model(
model: torch.nn.Module,
device: torch.device,
root_dir: str,
model_name: str | None,
peft_init_method: str,
load_base_model: bool = True,
) None#

Load a model from the base Hugging Face checkpoint in parallel.

Parameters:
  • model – Model to load state into

  • device – Device to load model onto

  • root_dir – Root directory of the model cache or snapshots

  • model_name – Name of the model or an absolute path to a snapshot

  • peft_init_method – Initialization method used for PEFT adapters

  • load_base_model – If True, restore from HF base checkpoint

maybe_wait_for_staging() None#

Wait for asynchronous staging to finish if staging is active.

async_wait() None#

Wait for the async save to finish.
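
Example: with is_async=True (torch >= 2.9.0), a training loop would typically synchronize at two points; exact placement depends on the integration:

```python
# Before mutating the weights that are still being staged:
checkpointer.maybe_wait_for_staging()

# Before shutdown, or before scheduling the next save:
checkpointer.async_wait()
```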

save_on_dp_ranks(
state: Any,
state_name: str,
path: str,
) None#

Save the stateful object.

This helper is currently used to save the dataloader and RNG state.

Parameters:
  • state – Stateful object to save

  • state_name – Name of the stateful object

  • path – Path to save stateful object

load_on_dp_ranks(
state: Any,
state_name: str,
path: str,
) None#

Load the stateful object.

This helper is currently used to load the dataloader and RNG state.

Parameters:
  • state – Stateful object to load

  • state_name – Name of the stateful object

  • path – Path to load stateful object
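
Example round trip (a sketch; `dataloader` is assumed to expose state_dict()/load_state_dict() as a stateful object):

```python
base = "/tmp/checkpoints/step_100"

checkpointer.save_on_dp_ranks(dataloader, "dataloader", base)
# ...on resume:
checkpointer.load_on_dp_ranks(dataloader, "dataloader", base)
```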

close() None#

Close the checkpointer.

_do_load(
state_dict: dict[str, torch.Tensor],
path: str,
storage_reader: Optional[nemo_automodel.components.checkpoint._backports.hf_storage._HuggingFaceStorageReader] = None,
is_init_step: bool = False,
) dict[str, torch.Tensor]#

Load a state dictionary from path using DCP or PEFT special-case logic.

Parameters:
  • state_dict – Mutable state dict to populate with tensors.

  • path – Checkpoint directory path.

  • storage_reader – Optional HF storage reader for safetensors.

  • is_init_step – True if loading from a base checkpoint during initialization.

Returns:

The populated state dictionary (may be replaced for PEFT).

_do_save(
state_dict: dict[str, torch.Tensor],
path: str,
storage_writer: Optional[nemo_automodel.components.checkpoint._backports.hf_storage._HuggingFaceStorageWriter] = None,
) Optional[torch.distributed.checkpoint.state_dict_saver.AsyncSaveResponse]#

Save a state dictionary to path using DCP or PEFT special-case logic.

  • For PEFT model saves: only rank 0 writes adapter_model.safetensors.

  • If async mode is enabled, schedule an asynchronous save.

Parameters:
  • state_dict – State dict to be serialized.

  • path – Checkpoint directory path.

  • storage_writer – Optional HF storage writer for safetensors sharding.

Returns:

Optional Future object if async mode is enabled.

_should_write_consolidated_safetensors() bool#

Whether to output consolidated HF weights along with sharded weights.

Returns True only for non-PEFT safetensors when consolidation is enabled.

_should_write_hf_metadata() bool#

Whether to write the HF artifacts.

_maybe_build_consolidated_index(
model_state: nemo_automodel.components.checkpoint.stateful_wrappers.ModelState,
state_dict: dict[str, torch.Tensor],
) Optional[dict[str, int]]#

Build FQN to shard index mapping for consolidated HF export.

Uses the base checkpoint index (if present), removes non-persistent keys, and assigns new keys to the last shard by default.

Parameters:
  • model_state – Wrapper exposing the primary model part.

  • state_dict – The state dict that will be saved.

Returns:

Mapping from FQN to shard index, or None when not consolidating.

_get_storage_writer(
consolidated_output_path: Optional[str],
fqn_to_index_mapping: Optional[dict[str, int]],
model_path: str,
consolidate_on_all_ranks: bool = False,
) Optional[nemo_automodel.components.checkpoint._backports.hf_storage._HuggingFaceStorageWriter]#

Construct a Hugging Face storage writer for sharded safetensors.

Parameters:
  • consolidated_output_path – Optional path for consolidated artifacts.

  • fqn_to_index_mapping – Optional mapping from FQN to shard index.

  • model_path – Path where the model checkpoint is saved.

  • consolidate_on_all_ranks – If True, perform consolidation across all ranks rather than only on the main process.

Returns:

Configured _HuggingFaceStorageWriter or None for non-safetensors.

_get_storage_reader(
model_path: str,
key_mapping: Optional[dict[str, str]],
is_init_step: bool = False,
) Optional[nemo_automodel.components.checkpoint._backports.hf_storage._HuggingFaceStorageReader]#

Construct a Hugging Face storage reader when loading safetensors or during init.

Parameters:
  • model_path – Path to the model checkpoint directory or HF snapshot.

  • key_mapping – Optional key remapping for conversion.

  • is_init_step – If True, always produce a reader for base HF load.

Returns:

Configured _HuggingFaceStorageReader or None for other formats.

_get_original_model_path(
model_state: nemo_automodel.components.checkpoint.stateful_wrappers.ModelState,
) str | None#

Get the path to the original model from the Hugging Face checkpoint.

nemo_automodel.components.checkpoint.checkpointing.get_safetensors_index_path(
cache_dir: str,
repo_id: str | None,
) str | None#

Return the directory containing the first model.safetensors.index.json found for the given model.

If no model.safetensors.index.json is found, None is returned.

For example, if the file located is

/opt/models/models--meta-llama--Llama-3.2-3B/snapshots/13afe.../model.safetensors.index.json

this function will return the directory path

/opt/models/models--meta-llama--Llama-3.2-3B/snapshots/13afe...

This will error if the model hasn’t been downloaded or if the cache directory is incorrect.

Parameters:
  • cache_dir – Path to cache directory

  • repo_id – Hugging Face repository ID

Returns:

Path to the directory containing the index file.

Raises:

FileNotFoundError – If the index file is not found.
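
Example (values are illustrative):

```python
from nemo_automodel.components.checkpoint.checkpointing import get_safetensors_index_path

snapshot_dir = get_safetensors_index_path(
    cache_dir="/opt/models",
    repo_id="meta-llama/Llama-3.2-3B",
)
# e.g. "/opt/models/models--meta-llama--Llama-3.2-3B/snapshots/13afe..."
```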

nemo_automodel.components.checkpoint.checkpointing.to_empty_parameters_only(
model: torch.nn.Module,
*,
device: torch.device,
recurse: bool = True,
dtype: torch.dtype | None = None,
) torch.nn.Module#

Move parameters to the specified device without copying storage, skipping buffers.

Mirrors torch.nn.Module.to_empty but applies only to parameters, not buffers.

Parameters:
  • model – The module to transform

  • device – Target device

  • recurse – Whether to recurse into child modules

Returns:

The same module instance
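
Example: the usual pairing is meta-device construction followed by materialization on the target device (a sketch; real weights still need to be loaded afterwards):

```python
import torch
import torch.nn as nn
from nemo_automodel.components.checkpoint.checkpointing import to_empty_parameters_only

# Build on the meta device so parameters have no real storage yet...
with torch.device("meta"):
    model = nn.Linear(1024, 1024)

# ...then allocate uninitialized parameter storage on the target device.
to_empty_parameters_only(model, device=torch.device("cuda"), dtype=torch.bfloat16)
```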

nemo_automodel.components.checkpoint.checkpointing.save_config(config: dict[str, Any], weights_path: str) None#

Save a config to a weights path.

Parameters:
  • config – Config to save

  • weights_path – Path to save config
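
Example (the on-disk filename and format are implementation details not documented here):

```python
from nemo_automodel.components.checkpoint.checkpointing import save_config

save_config({"max_steps": 1000, "lr": 3e-4}, "/tmp/checkpoints/step_100")
```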

nemo_automodel.components.checkpoint.checkpointing._ensure_dirs(*dirs: Optional[str]) None#

Create directories on all ranks and synchronize across ranks.

Parameters:

*dirs – One or more directory paths that should exist.

nemo_automodel.components.checkpoint.checkpointing._init_peft_adapters(
model: torch.nn.Module,
peft_init_method: str,
) None#

Initialize the PEFT adapters with the scaled weights.

Parameters:
  • model – Model to initialize PEFT adapters for

  • peft_init_method – Method used to initialize PEFT adapters, e.g. "xavier". See LinearLoRA for more details.

nemo_automodel.components.checkpoint.checkpointing._apply(module, fn, recurse=True) torch.nn.Module#

Apply a transformation function to parameters (and gradients) only.

Mirrors nn.Module.to_empty for parameters while skipping buffers. Respects the torch.__future__ flags that control in-place vs. swap behavior and safely handles wrapper tensor subclasses.

Parameters:
  • module – Module whose parameters are to be transformed.

  • fn – Callable applied to each parameter (and its gradient).

  • recurse – Whether to recurse into child modules.

Returns:

The same module instance after transformation.
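
A much-simplified sketch of the traversal (hypothetical code: it ignores the swap semantics and wrapper subclasses the real helper handles):

```python
import torch
import torch.nn as nn

def apply_to_parameters_only(module: nn.Module, fn, recurse: bool = True) -> nn.Module:
    # Recurse first, then rebuild this module's own parameters via fn.
    if recurse:
        for child in module.children():
            apply_to_parameters_only(child, fn, recurse)
    for name, param in list(module.named_parameters(recurse=False)):
        with torch.no_grad():
            new_param = nn.Parameter(fn(param), requires_grad=param.requires_grad)
            if param.grad is not None:
                new_param.grad = fn(param.grad)
        module._parameters[name] = new_param
    return module

# e.g. materialize meta-device parameters on CPU without touching buffers:
# apply_to_parameters_only(model, lambda t: torch.empty_like(t, device="cpu"))
```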

nemo_automodel.components.checkpoint.checkpointing._maybe_adapt_state_dict_to_hf(
model_part: torch.nn.Module,
state_dict: dict[str, torch.Tensor],
quantization: bool = False,
) dict[str, torch.Tensor]#

Convert the state dict to the Hugging Face format using the model's state dict adapter, if the model defines one.

nemo_automodel.components.checkpoint.checkpointing._maybe_adapt_state_dict_from_hf(
model_part: torch.nn.Module,
state_dict: dict[str, torch.Tensor],
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
) dict[str, torch.Tensor]#

Convert a Hugging Face-format state dict back to the model's native format using its state dict adapter, if the model defines one.