Source code for nemo_automodel.checkpoint.checkpointing

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Checkpoint management utilities for HF models."""

import glob
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional

import torch
import torch.distributed
import torch.distributed.checkpoint as dcp
import torch.nn as nn

from transformers import PreTrainedModel
from safetensors.torch import save_file
from safetensors import safe_open

from nemo_automodel.checkpoint._backports.filesystem import SerializationFormat
from nemo_automodel.checkpoint._backports.hf_storage import (
    _HuggingFaceStorageReader,
    _HuggingFaceStorageWriter,
    get_fqn_to_file_index_mapping,
)
from nemo_automodel.checkpoint.stateful_wrappers import ModelState, OptimizerState


@dataclass
class CheckpointingConfig:
    """Configuration for checkpointing."""

    enabled: bool
    checkpoint_dir: str | Path
    model_save_format: SerializationFormat | str
    model_cache_dir: str | Path
    model_repo_id: str
    save_consolidated: bool
    is_peft: bool

    def __post_init__(self):
        """Convert a raw string such as "safetensors" into the right Enum."""
        if isinstance(self.model_save_format, str):
            self.model_save_format = SerializationFormat[self.model_save_format.upper()]
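
# Usage sketch (illustrative, not part of the original module). The paths and
# repo id below are placeholders; __post_init__ normalizes the string
# "safetensors" to SerializationFormat.SAFETENSORS.
#
#     config = CheckpointingConfig(
#         enabled=True,
#         checkpoint_dir="checkpoints",
#         model_save_format="safetensors",
#         model_cache_dir="/opt/models",
#         model_repo_id="meta-llama/Llama-3.2-3B",
#         save_consolidated=True,
#         is_peft=False,
#     )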

def save_model(
    model: nn.Module | PreTrainedModel,
    weights_path: str,
    checkpoint_config: CheckpointingConfig,
):
    """Save a model state dictionary to a weights path.

    This function can save a model in the following formats:
    - safetensors (in HF format)
    - torch_save (in DCP format)

    Args:
        model: Model to save.
        weights_path: Path to save model weights.
        checkpoint_config: Checkpointing configuration.
    """
    # TODO(@adil-a): Need to add support for PEFT.
    # We also need to eventually add support for HSDP, so we only save on non-duplicate ranks.
    # Add functionality to chunk different layers for different ranks to save.
    # The above functionality will also make it trivial to get a FQN -> rank mapping
    # which doesn't leave out any user-modified layers.
    # This is because we need to create the mapping on the fly from the model state dict.
    model_path = os.path.join(weights_path, "model")
    consolidated_model_path = None
    if checkpoint_config.save_consolidated:
        consolidated_model_path = os.path.join(model_path, "consolidated")

    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        os.makedirs(model_path, exist_ok=True)
        if (
            checkpoint_config.save_consolidated
            and checkpoint_config.model_save_format == SerializationFormat.SAFETENSORS
            and not checkpoint_config.is_peft
        ):
            os.makedirs(consolidated_model_path, exist_ok=True)
            # Save the config.json file alongside the consolidated weights.
            with open(os.path.join(consolidated_model_path, "config.json"), "w") as f:
                f.write(model.config.to_json_string())

    # Ensure all ranks wait for rank 0 to handle directories.
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    model_state = ModelState(model, checkpoint_config.model_save_format, checkpoint_config.is_peft)

    if checkpoint_config.is_peft:
        if not isinstance(model, PreTrainedModel):
            raise ValueError("PEFT checkpointing is only supported for PreTrainedModel")
        if not hasattr(model, "_automodel_peft_config"):
            raise ValueError(
                "PEFT checkpointing is only supported for models that have been trained with PEFT. "
                "Please use the `apply_lora_to_linear_modules` function to apply LoRA to the model."
            )
        peft_config = model._automodel_peft_config
        state_dict = model_state.state_dict()
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            with open(os.path.join(model_path, "adapter_config.json"), "w") as f:
                json.dump(peft_config, f, indent=2, sort_keys=True)
            save_file(state_dict, os.path.join(model_path, "adapter_model.safetensors"))
    elif checkpoint_config.model_save_format == SerializationFormat.SAFETENSORS:
        fqn_to_file_index_mapping = None
        if checkpoint_config.save_consolidated:
            # We first need to find the FQN -> .safetensors file mapping.
            index_path = _get_safetensors_index_path(
                checkpoint_config.model_cache_dir,
                checkpoint_config.model_repo_id,
            )
            if index_path:
                fqn_to_file_index_mapping = get_fqn_to_file_index_mapping(index_path)
                # Add any missing keys from the model state dict. These go to the
                # same file as the last file (or file 1 for single-file models).
                default_index = max(fqn_to_file_index_mapping.values())
                # TODO(@adil-a): This will need to change when we add PP.
                # Maybe we can cache the keys in ModelState.
                for fqn in list(model.state_dict().keys()):
                    if fqn not in fqn_to_file_index_mapping:
                        if model_state.is_tied_lm_head and fqn == "lm_head.weight":
                            continue
                        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
                            print(f"Adding missing key to mapping: {fqn}")
                        fqn_to_file_index_mapping[fqn] = default_index

        storage_writer = _HuggingFaceStorageWriter(
            path=model_path,
            save_sharded=True,
            consolidated_output_path=consolidated_model_path,
            fqn_to_index_mapping=fqn_to_file_index_mapping,
        )
        dcp.save(
            {"model": model_state},
            checkpoint_id=model_path,
            storage_writer=storage_writer,
        )
    elif checkpoint_config.model_save_format == SerializationFormat.TORCH_SAVE:
        dcp.save({"model": model_state}, checkpoint_id=model_path)
    else:
        raise ValueError(f"Unsupported model save format: {checkpoint_config.model_save_format}")
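
# Usage sketch (illustrative, placeholders throughout): saving a Hugging Face
# model with the config above. Sharded weights are written under
# <weights_path>/model; when save_consolidated=True, an HF-format copy is also
# written under <weights_path>/model/consolidated.
#
#     save_model(model, weights_path="checkpoints/step_100", checkpoint_config=config)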

def load_model(
    model: torch.nn.Module | PreTrainedModel,
    weights_path: str,
    checkpoint_config: CheckpointingConfig,
):
    """Load a model state dictionary from a weights path.

    Args:
        model: Model to load state into.
        weights_path: Path to load model weights from.
        checkpoint_config: Checkpointing configuration.
    """
    model_path = os.path.join(weights_path, "model")

    # Validate checkpoint directory.
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model path {model_path} does not exist")

    model_state = ModelState(model, checkpoint_config.model_save_format, checkpoint_config.is_peft)

    if checkpoint_config.is_peft:
        if not isinstance(model, PreTrainedModel):
            raise ValueError("PEFT checkpointing is only supported for PreTrainedModel")
        state_dict = model.state_dict()
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            with safe_open(os.path.join(model_path, "adapter_model.safetensors"), framework="pt") as f:
                state_dict = {k: f.get_tensor(k) for k in f.keys()}
        # Since we're loading the PEFT adapters on rank 0, we don't need to call
        # dcp.load; the call below will broadcast from rank 0 to all other ranks.
        model_state.load_state_dict(state_dict)
    elif checkpoint_config.model_save_format == SerializationFormat.SAFETENSORS:
        storage_reader = _HuggingFaceStorageReader(path=model_path)
        dcp.load(
            state_dict={"model": model_state},
            checkpoint_id=model_path,
            storage_reader=storage_reader,
            planner=dcp.DefaultLoadPlanner(),
        )
    elif checkpoint_config.model_save_format == SerializationFormat.TORCH_SAVE:
        dcp.load(state_dict={"model": model_state}, checkpoint_id=model_path)
    else:
        raise ValueError(f"Unsupported model save format: {checkpoint_config.model_save_format}")
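
# Usage sketch (illustrative): restoring the weights saved above. Pass the same
# checkpoint_config used for saving so the matching format branch (PEFT,
# safetensors, or torch_save) is taken.
#
#     load_model(model, weights_path="checkpoints/step_100", checkpoint_config=config)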

def save_optimizer(
    optimizer: torch.optim.Optimizer,
    model: torch.nn.Module,
    weights_path: str,
    scheduler: Optional[Any] = None,
):
    """Save an optimizer state dictionary to a weights path.

    Args:
        optimizer: Optimizer to save.
        model: Model to save optimizer state for.
        weights_path: Path to save optimizer weights.
        scheduler: Optional scheduler to save.
    """
    optimizer_path = os.path.join(weights_path, "optim")
    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
        os.makedirs(optimizer_path, exist_ok=True)

    optimizer_state = OptimizerState(model, optimizer, scheduler)
    dcp.save({"optim": optimizer_state}, checkpoint_id=optimizer_path)
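
# Usage sketch (illustrative): optimizer (and optional LR scheduler) state is
# written under <weights_path>/optim, next to the model/ subdirectory produced
# by save_model. `optimizer` and `scheduler` are placeholders.
#
#     save_optimizer(optimizer, model, "checkpoints/step_100", scheduler=scheduler)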

def load_optimizer(
    optimizer: torch.optim.Optimizer,
    model: torch.nn.Module,
    weights_path: str,
    scheduler: Optional[Any] = None,
):
    """Load an optimizer state dictionary from a weights path.

    Args:
        optimizer: Optimizer to load state into.
        model: Model to load optimizer state for.
        weights_path: Path to load optimizer weights from.
        scheduler: Optional scheduler to load state into.
    """
    optimizer_path = os.path.join(weights_path, "optim")
    if not os.path.exists(optimizer_path):
        raise FileNotFoundError(f"Optimizer path {optimizer_path} does not exist")

    optimizer_state = {"optim": OptimizerState(model, optimizer, scheduler)}
    dcp.load(state_dict=optimizer_state, checkpoint_id=optimizer_path)
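
# Usage sketch (illustrative): restoring optimizer state. Typically load_model
# is called first so the restored optimizer state re-attaches to the loaded
# parameters.
#
#     load_optimizer(optimizer, model, "checkpoints/step_100", scheduler=scheduler)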

def _get_safetensors_index_path(cache_dir: str, repo_id: str) -> str:
    """Return the directory containing the first `model.safetensors.index.json` found for the given model.

    For example, if the file located is
    /opt/models/models--meta-llama--Llama-3.2-3B/snapshots/13afe.../model.safetensors.index.json,
    this function returns the directory path
    /opt/models/models--meta-llama--Llama-3.2-3B/snapshots/13afe...

    If no index file exists (the case for single-file models), the first
    available snapshot directory is returned instead.

    Args:
        cache_dir: Path to cache directory.
        repo_id: Hugging Face repository ID.

    Returns:
        Path to the directory containing the index file, or the first snapshot
        directory for single-file models.

    Raises:
        FileNotFoundError: If no snapshot directory is found, e.g. because the
            model hasn't been downloaded or the cache directory is incorrect.
    """
    repo_dir = f"models--{repo_id.replace('/', '--')}"
    snapshots_root = Path(cache_dir) / repo_dir / "snapshots"

    # Look for an index file inside any snapshot directory.
    pattern = snapshots_root / "*" / "model.safetensors.index.json"
    matches = glob.glob(str(pattern))
    if matches:
        # Return the directory path that contains the index file.
        return str(Path(matches[0]).parent)

    # Fall back: if there is no index file, return the first available snapshot
    # directory. This is the case for single-file models.
    snapshot_dirs = [p for p in glob.glob(str(snapshots_root / "*")) if Path(p).is_dir()]
    if snapshot_dirs:
        return snapshot_dirs[0]
    raise FileNotFoundError(f"No snapshot directories found in {snapshots_root}")
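
# Usage sketch (illustrative): with a standard Hugging Face hub cache layout
# such as /opt/models/models--meta-llama--Llama-3.2-3B/snapshots/<hash>/, the
# call below returns the snapshot directory that holds the index file.
#
#     snapshot_dir = _get_safetensors_index_path("/opt/models", "meta-llama/Llama-3.2-3B")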