Source code for nemo_rl.utils.native_checkpoint

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Checkpoint management utilities for HF models."""

import os
from typing import Any, Optional

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful
from transformers import AutoConfig, AutoTokenizer


## modified from pytorch tutorial https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
[docs] class ModelState(Stateful): """Helper class for tracking model state in distributed checkpointing. This class is compliant with the Stateful protocol, allowing DCP to automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs. Args: model: The PyTorch model to track. """ def __init__(self, model): self.model = model
[docs] def state_dict(self): """Get the model's state dictionary. Returns: dict: Dictionary containing the model's state dict with CPU offloading enabled. """ # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT model_state_dict = get_model_state_dict( self.model, options=torch.distributed.checkpoint.state_dict.StateDictOptions( cpu_offload=True ), ) return model_state_dict
[docs] def load_state_dict(self, state_dict): """Load the state dictionary into the model. Args: state_dict (dict): State dictionary to load. """ # sets our state dicts on the model, now that we've loaded set_model_state_dict( self.model, state_dict, )
[docs] class OptimizerState(Stateful): """Helper class for tracking optimizer state in distributed checkpointing. This class is compliant with the Stateful protocol, allowing DCP to automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs. Args: model: The PyTorch model associated with the optimizer. optimizer: The optimizer to track. scheduler: Optional learning rate scheduler. """ def __init__(self, model, optimizer, scheduler=None): self.model = model self.optimizer = optimizer self.scheduler = scheduler
[docs] def state_dict(self): """Get the optimizer and scheduler state dictionaries. Returns: dict: Dictionary containing the optimizer and scheduler state dicts with CPU offloading enabled. """ # this line automatically manages FSDP FQN's, as well as sets the default state dict type to FSDP.SHARDED_STATE_DICT optimizer_state_dict = get_optimizer_state_dict( self.model, self.optimizer, options=torch.distributed.checkpoint.state_dict.StateDictOptions( cpu_offload=True ), ) state_dict = { "optim": optimizer_state_dict, } if self.scheduler is not None: state_dict["sched"] = self.scheduler.state_dict() return state_dict
[docs] def load_state_dict(self, state_dict): """Load the state dictionaries into the optimizer and scheduler. Args: state_dict (dict): State dictionary containing optimizer and scheduler states to load. """ # sets our state dicts on the optimizer, now that we've loaded set_optimizer_state_dict( self.model, self.optimizer, state_dict["optim"], ) ## load the scheduler state if it exists if "sched" in state_dict: self.scheduler.load_state_dict(state_dict["sched"])
[docs] def save_checkpoint( model, weights_path: str, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[Any] = None, optimizer_path: Optional[str] = None, tokenizer: Optional[Any] = None, tokenizer_path: Optional[str] = None, ) -> None: """Save a checkpoint of the model and optionally optimizer state. Args: model: The PyTorch model to save weights_path: Path to save model weights optimizer: Optional optimizer to save scheduler: Optional scheduler to save optimizer_path: Path to save optimizer state (required if optimizer provided) """ model_state = {"model": ModelState(model)} dcp.save(model_state, checkpoint_id=weights_path) if optimizer is not None: if optimizer_path is None: raise ValueError( "optimizer_path must be provided when saving optimizer state" ) optimizer_state = {"optim": OptimizerState(model, optimizer, scheduler)} dcp.save(optimizer_state, checkpoint_id=optimizer_path) if tokenizer is not None: if tokenizer_path is None: raise ValueError( "tokenizer_path must be provided when saving tokenizer state" ) tokenizer.save_pretrained(tokenizer_path)
[docs] def load_checkpoint( model, weights_path: str, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[Any] = None, optimizer_path: Optional[str] = None, ) -> None: """Load a model weights and optionally optimizer state. Args: model: The PyTorch model whose weights to update weights_path: Path to load model weights from optimizer: Optional optimizer to load state into scheduler: Optional scheduler to load state into optimizer_path: Path to load optimizer state from (required if optimizer provided) """ print(f"Loading weights from {weights_path}") model_state_dict = {"model": ModelState(model)} dcp.load(state_dict=model_state_dict, checkpoint_id=weights_path) if optimizer is not None: if optimizer_path is None: raise ValueError( "optimizer_path must be provided when loading optimizer state" ) print(f"Loading optimizer from {optimizer_path}") optimizer_state_dict = {"optim": OptimizerState(model, optimizer, scheduler)} dcp.load(state_dict=optimizer_state_dict, checkpoint_id=optimizer_path)
[docs] def convert_dcp_to_hf( dcp_ckpt_path: str, hf_ckpt_path: str, model_name_or_path: str, tokenizer_name_or_path: str, overwrite: bool = False, ): """Convert a Torch DCP checkpoint to a Hugging Face checkpoint. This is not an optimized utility. If checkpoint is too large, consider saving DCP during training and using this utility to convert to HF format. Args: dcp_ckpt_path (str): Path to DCP checkpoint hf_ckpt_path (str): Path to save HF checkpoint model_name_or_path (str): Model name or path for config tokenizer_name_or_path (str, optional): Tokenizer name or path. Defaults to model_name_or_path if None. overwrite (bool, optional): Whether to overwrite existing checkpoint. Defaults to False. Returns: str: Path to the saved HF checkpoint Raises: FileExistsError: If HF checkpoint already exists and overwrite is False """ if os.path.exists(hf_ckpt_path) and not overwrite: raise FileExistsError( f"HF checkpoint already exists at {hf_ckpt_path}. Delete it to run or set overwrite=True." ) os.makedirs(hf_ckpt_path, exist_ok=True) weights_path = os.path.join(hf_ckpt_path, "pytorch_model.bin") dcp_to_torch_save(dcp_ckpt_path, weights_path) # Need to reload and save b/c the state dict is scoped inside the model key {"model": actual_state_dict} state_dict = torch.load(weights_path) assert set(state_dict.keys()) == {"model"}, ( f"We expect that the state dict only has the top level model key, but found: {state_dict.keys()}" ) torch.save(state_dict["model"], weights_path) config = AutoConfig.from_pretrained(model_name_or_path) config.save_pretrained(hf_ckpt_path) # TODO: After the following PR gets merged: # https://github.com/NVIDIA/NeMo-RL/pull/148/files # tokenizer should be copied from policy/tokenizer/* instead of relying on the model name # We can expose a arg at the top level --tokenizer_path to plumb that through. # This is more stable than relying on the current NeMo-RL get_tokenizer() which can # change release to release. tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path) tokenizer.save_pretrained(hf_ckpt_path) return hf_ckpt_path