nemo_automodel.components.checkpoint.checkpointing#

Checkpoint management utilities for HF models.

Module Contents#

Classes#

CheckpointingConfig

Configuration for checkpointing.

Functions#

save_model

Save a model state dictionary to a weights path.

load_model_from_base_checkpoint

Load a model from the base Hugging Face checkpoint in parallel.

load_model

Load a model state dictionary from a weights path.

save_optimizer

Save an optimizer state dictionary to a weights path.

load_optimizer

Load an optimizer state dictionary from a weights path.

save_dp_aware_helper

Save the stateful object.

load_dp_aware_helper

Load the stateful object.

save_config

Save a config to a weights path.

get_safetensors_index_path

Return the directory containing the first model.safetensors.index.json found for the given model.

to_empty_parameters_only

Move parameters to the specified device without copying storage, skipping buffers.

_save_peft_adapters

Save PEFT adapters to a weights path.

_get_hf_peft_config

Get the PEFT config in the format expected by Hugging Face.

_get_automodel_peft_metadata

Get the PEFT metadata in the format expected by Automodel.

_extract_target_modules

Extract the target modules from the model.

_init_peft_adapters

Initialize the PEFT adapters with the scaled weights.

_apply

API#

class nemo_automodel.components.checkpoint.checkpointing.CheckpointingConfig#

Configuration for checkpointing.

enabled: bool#

None

checkpoint_dir: str | pathlib.Path#

None

model_save_format: nemo_automodel.components.checkpoint._backports.filesystem.SerializationFormat | str#

None

model_cache_dir: str | pathlib.Path#

None

model_repo_id: str#

None

save_consolidated: bool#

None

is_peft: bool#

None

model_state_dict_keys: list[str]#

None

dequantize_base_checkpoint: bool#

False

__post_init__()#

Convert a raw string such as "safetensors" into the corresponding SerializationFormat enum member.
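
A minimal construction sketch, assuming placeholder paths and repo ID; __post_init__ converts the "safetensors" string into the enum value:

```python
from nemo_automodel.components.checkpoint.checkpointing import CheckpointingConfig

# All paths and the repo ID below are placeholders for illustration only.
ckpt_config = CheckpointingConfig(
    enabled=True,
    checkpoint_dir="/tmp/checkpoints",
    model_save_format="safetensors",   # __post_init__ converts this string to the enum
    model_cache_dir="/opt/models",
    model_repo_id="meta-llama/Llama-3.2-3B",
    save_consolidated=True,
    is_peft=False,
    model_state_dict_keys=[],          # typically the keys of the model's state_dict
    # dequantize_base_checkpoint defaults to False
)
```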

nemo_automodel.components.checkpoint.checkpointing.save_model(
model: torch.nn.Module,
weights_path: str,
checkpoint_config: nemo_automodel.components.checkpoint.checkpointing.CheckpointingConfig,
peft_config: Optional[peft.PeftConfig] = None,
tokenizer: Optional[transformers.tokenization_utils.PreTrainedTokenizerBase] = None,
)#

Save a model state dictionary to a weights path.

This function can save a model in the following formats:

  • safetensors (in HF format)

  • torch_save (in DCP format)

Parameters:
  • model – Model to save

  • weights_path – Path to save model weights

  • checkpoint_config – Checkpointing configuration

  • peft_config – PEFT config

  • tokenizer – Tokenizer. Only saved if checkpoint_config.save_consolidated is True.
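
A hedged usage sketch; the toy nn.Linear stands in for a real Hugging Face model, the weights path is a placeholder, and `ckpt_config` is the CheckpointingConfig built in the sketch above:

```python
import torch
from nemo_automodel.components.checkpoint.checkpointing import save_model

model = torch.nn.Linear(8, 8)  # stand-in for a real HF model in practice

save_model(
    model=model,
    weights_path="/tmp/checkpoints/step_1000/model",  # placeholder path
    checkpoint_config=ckpt_config,
    # peft_config and tokenizer are optional; the tokenizer is written out only
    # when checkpoint_config.save_consolidated is True.
)
```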

nemo_automodel.components.checkpoint.checkpointing.load_model_from_base_checkpoint(
model: torch.nn.Module,
device: torch.device,
is_peft: bool,
root_dir: str,
model_name: str | None,
peft_init_method: str,
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
load_base_model: bool = True,
quantization: bool = False,
)#

Load a model from the base Hugging Face checkpoint in parallel.

Parameters:
  • model – Model to load state into

  • device – Device to load model onto

  • is_peft – Whether the model is a PEFT model

  • root_dir – Root directory of the model

  • model_name – Name of the model
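
A sketch of the call, assuming the base weights live in a local Hugging Face cache; the toy module, cache root, repo ID, and init method are placeholders:

```python
import torch
from nemo_automodel.components.checkpoint.checkpointing import load_model_from_base_checkpoint

# Toy stand-in; a real run would build the actual HF model, often on the meta device
# (see to_empty_parameters_only below).
with torch.device("meta"):
    model = torch.nn.Linear(8, 8)

load_model_from_base_checkpoint(
    model=model,
    device=torch.device("cuda"),
    is_peft=False,
    root_dir="/opt/models",                 # placeholder cache root
    model_name="meta-llama/Llama-3.2-3B",   # placeholder repo ID
    peft_init_method="xavier",              # see _init_peft_adapters / LinearLoRA
)
```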

nemo_automodel.components.checkpoint.checkpointing.load_model(
model: torch.nn.Module,
model_path: str,
model_save_format: nemo_automodel.components.checkpoint._backports.filesystem.SerializationFormat,
*,
is_peft: bool = False,
is_init_step: bool = False,
use_checkpoint_id: bool = True,
key_mapping: Optional[dict[str, str]] = None,
load_peft_adapters: bool = True,
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
quantization: bool = False,
)#

Load a model state dictionary from a weights path.

Parameters:
  • model – Model to load state into

  • model_path – Path to load model weights from

  • model_save_format – Model save format

  • is_peft – Whether the model is a PEFT model

  • is_init_step – Whether the model is being initialized

  • use_checkpoint_id – Whether to use the checkpoint ID

  • key_mapping – Key mapping for the model

  • load_peft_adapters – Whether to load PEFT adapters

  • moe_mesh – MoE mesh for distributed loading
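
A sketch of restoring weights that were saved with save_model above; the path is a placeholder and the enum member name is assumed from the backported SerializationFormat:

```python
import torch
from nemo_automodel.components.checkpoint.checkpointing import load_model
from nemo_automodel.components.checkpoint._backports.filesystem import SerializationFormat

model = torch.nn.Linear(8, 8)  # same stand-in architecture used when saving

load_model(
    model=model,
    model_path="/tmp/checkpoints/step_1000/model",       # placeholder path
    model_save_format=SerializationFormat.SAFETENSORS,   # assumed member name for HF safetensors
)
```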

nemo_automodel.components.checkpoint.checkpointing.save_optimizer(
optimizer: torch.optim.Optimizer,
model: torch.nn.Module,
weights_path: str,
scheduler: Optional[Any] = None,
)#

Save an optimizer state dictionary to a weights path.

Parameters:
  • optimizer – Optimizer to save

  • model – Model to save optimizer state for

  • weights_path – Path to save optimizer weights

  • scheduler – Optional scheduler to save

nemo_automodel.components.checkpoint.checkpointing.load_optimizer(
optimizer: torch.optim.Optimizer,
model: torch.nn.Module,
weights_path: str,
scheduler: Optional[Any] = None,
)#

Load an optimizer state dictionary from a weights path.

Parameters:
  • optimizer – Optimizer to load state into

  • model – Model to load optimizer state for

  • weights_path – Path to load optimizer weights from

  • scheduler – Optional scheduler to load state into
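
A combined sketch for the two optimizer helpers, using a toy model and a placeholder path:

```python
import torch
from nemo_automodel.components.checkpoint.checkpointing import save_optimizer, load_optimizer

model = torch.nn.Linear(8, 8)                       # stand-in model
optimizer = torch.optim.AdamW(model.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)

# Save optimizer (and optional scheduler) state alongside the model checkpoint.
save_optimizer(optimizer, model, "/tmp/checkpoints/step_1000/optim", scheduler=scheduler)

# Later, after rebuilding the same model, optimizer, and scheduler:
load_optimizer(optimizer, model, "/tmp/checkpoints/step_1000/optim", scheduler=scheduler)
```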

nemo_automodel.components.checkpoint.checkpointing.save_dp_aware_helper(
state: Any,
state_name: str,
path: str,
dp_rank: int,
tp_rank: int,
pp_rank: int,
)#

Save the stateful object.

This helper function is currently used to save the dataloader and RNG state.

Parameters:
  • state – Stateful object to save

  • state_name – Name of the stateful object

  • path – Path to save stateful object

  • dp_rank – Data parallel rank

  • tp_rank – Tensor parallel rank

  • pp_rank – Pipeline parallel rank

nemo_automodel.components.checkpoint.checkpointing.load_dp_aware_helper(
state: Any,
state_name: str,
path: str,
dp_rank: int,
)#

Load the stateful object.

This helper function is currently used to load the dataloader and RNG state.

Parameters:
  • state – Stateful object to load

  • state_name – Name of the stateful object

  • path – Path to load stateful object

  • dp_rank – Data parallel rank
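
A sketch with a toy stateful object; it assumes the helpers expect the usual state_dict/load_state_dict protocol, and the hard-coded ranks and path are placeholders that would normally come from the distributed setup:

```python
from nemo_automodel.components.checkpoint.checkpointing import (
    load_dp_aware_helper,
    save_dp_aware_helper,
)

class ToyRNGState:
    """Toy stand-in for the rng/dataloader state objects; assumes a state_dict protocol."""

    def __init__(self):
        self.seed = 1234

    def state_dict(self):
        return {"seed": self.seed}

    def load_state_dict(self, state):
        self.seed = state["seed"]

rng_state = ToyRNGState()

save_dp_aware_helper(rng_state, "rng", "/tmp/checkpoints/step_1000", dp_rank=0, tp_rank=0, pp_rank=0)
load_dp_aware_helper(rng_state, "rng", "/tmp/checkpoints/step_1000", dp_rank=0)
```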

nemo_automodel.components.checkpoint.checkpointing.save_config(config: dict[str, Any], weights_path: str)#

Save a config to a weights path.

Parameters:
  • config – Config to save

  • weights_path – Path to save config

nemo_automodel.components.checkpoint.checkpointing.get_safetensors_index_path(cache_dir: str, repo_id: str) → str#

Return the directory containing the first model.safetensors.index.json found for the given model.

For example, if the file located is

/opt/models/models--meta-llama--Llama-3.2-3B/snapshots/13afe.../model.safetensors.index.json

this function returns the directory path

/opt/models/models--meta-llama--Llama-3.2-3B/snapshots/13afe...

This raises an error if the model has not been downloaded or if the cache directory is incorrect.

Parameters:
  • cache_dir – Path to cache directory

  • repo_id – Hugging Face repository ID

Returns:

Path to the directory containing the index file.

Raises:

FileNotFoundError – If the index file is not found.
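
A usage sketch mirroring the path example above; the cache directory and repo ID are placeholders for a model already present in the local Hugging Face cache:

```python
from nemo_automodel.components.checkpoint.checkpointing import get_safetensors_index_path

snapshot_dir = get_safetensors_index_path(
    cache_dir="/opt/models",
    repo_id="meta-llama/Llama-3.2-3B",
)
print(snapshot_dir)
# e.g. /opt/models/models--meta-llama--Llama-3.2-3B/snapshots/13afe...
```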

nemo_automodel.components.checkpoint.checkpointing.to_empty_parameters_only(
model: torch.nn.Module,
*,
device: torch.device,
recurse: bool = True,
dtype: torch.dtype | None = None,
) → torch.nn.Module#

Move parameters to the specified device without copying storage, skipping buffers.

Mirrors torch.nn.Module.to_empty but applies only to parameters, not buffers.

Parameters:
  • model – The module to transform

  • device – Target device

  • recurse – Whether to recurse into child modules

  • dtype – Optional dtype for the re-created (uninitialized) parameters; if None, each parameter keeps its current dtype

Returns:

The same module instance
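
A small sketch of the meta-device pattern this helper supports; the toy module stands in for a real model:

```python
import torch
from nemo_automodel.components.checkpoint.checkpointing import to_empty_parameters_only

# Build the module without allocating real storage, then materialize only the
# parameters on the target device; buffers are skipped, unlike Module.to_empty.
with torch.device("meta"):
    model = torch.nn.Linear(8, 8)

model = to_empty_parameters_only(model, device=torch.device("cpu"), dtype=torch.bfloat16)
# Parameters now have uninitialized storage; actual weights are filled in by a
# subsequent load (e.g. load_model_from_base_checkpoint).
```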

nemo_automodel.components.checkpoint.checkpointing._save_peft_adapters(
model_state: nemo_automodel.components.checkpoint.stateful_wrappers.ModelState,
peft_config: peft.PeftConfig,
model_path: str,
)#

Save PEFT adapters to a weights path.

nemo_automodel.components.checkpoint.checkpointing._get_hf_peft_config(
peft_config: peft.PeftConfig,
model_state: nemo_automodel.components.checkpoint.stateful_wrappers.ModelState,
) → dict#

Get the PEFT config in the format expected by Hugging Face.

nemo_automodel.components.checkpoint.checkpointing._get_automodel_peft_metadata(peft_config: peft.PeftConfig) → dict#

Get the PEFT metadata in the format expected by Automodel.

nemo_automodel.components.checkpoint.checkpointing._extract_target_modules(model: torch.nn.Module) → list[str]#

Extract the target modules from the model.

Note: When torch.compile is used, module names get prefixed with '_orig_mod.'. This function strips those prefixes to get the original module names.

nemo_automodel.components.checkpoint.checkpointing._init_peft_adapters(model: torch.nn.Module, peft_init_method: str)#

Initialize the PEFT adapters with the scaled weights.

Parameters:
  • model – Model to initialize PEFT adapters for

  • peft_init_method – Method to initialize PEFT adapters, e.g. "xavier". See LinearLoRA for more details.

nemo_automodel.components.checkpoint.checkpointing._apply(module, fn, recurse=True)#