Source code for nemo_automodel.checkpoint.stateful_wrappers

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional

import torch
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    get_optimizer_state_dict,
    set_model_state_dict,
    set_optimizer_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful

from nemo_automodel.checkpoint._backports.filesystem import SerializationFormat

_PREFIX = "model."

def _drop_outer_prefix(sd: dict[str, Any], prefix: str = _PREFIX) -> None:
    """
    Remove the leading `prefix` from every key, in-place.
    """
    for k in list(sd.keys()):
        if k.startswith(prefix):
            sd[k[len(prefix):]] = sd.pop(k)


def _add_outer_prefix(sd: dict[str, Any], prefix: str = _PREFIX) -> None:
    """
    Prepend `prefix` once to every key in-place (inverse of `_drop_outer_prefix`).
    """
    for k in list(sd.keys()):
        if not k.startswith(prefix):
            sd[prefix + k] = sd.pop(k)


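# Quick illustration (not part of the original module) of the two prefix helpers above:
# `_drop_outer_prefix` strips a leading "model." from each key in place, and
# `_add_outer_prefix` restores it. The example dict and function name are hypothetical.
def _example_prefix_roundtrip() -> None:
    sd = {"model.embed_tokens.weight": torch.zeros(1)}
    _drop_outer_prefix(sd)   # keys become {"embed_tokens.weight"}
    assert "embed_tokens.weight" in sd
    _add_outer_prefix(sd)    # keys become {"model.embed_tokens.weight"}
    assert "model.embed_tokens.weight" in sd

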
def _get_lm_head_weight_and_name(model: torch.nn.Module) -> tuple[Optional[torch.Tensor], Optional[str]]:
    for name, param in model.named_parameters(remove_duplicate=False):
        if "lm_head" in name and name.endswith(".weight"):
            return param, name
    return None, None


# modified from pytorch tutorial https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html
class ModelState(Stateful):
    """
    Helper class for tracking model state in distributed checkpointing.

    This class is compliant with the Stateful protocol, allowing DCP to automatically
    call state_dict/load_state_dict as needed in the dcp.save/load APIs.

    Args:
        model: The PyTorch model to track.
    """

    def __init__(self, model: torch.nn.Module, serialization_format: SerializationFormat, is_peft: bool = False):
        """
        Initialize a ModelState instance for distributed checkpointing.

        The constructor records the model reference, detects whether the model ties its
        language-model head to the input embeddings, and stores the desired serialization
        backend so that DCP can correctly save and restore the model's parameters and buffers.

        Args:
            model (torch.nn.Module): The PyTorch model whose state should be captured
                during checkpointing.
            serialization_format (SerializationFormat): Backend/format to use when
                persisting the model state (e.g., torch, safetensors).
            is_peft (bool): Whether the model is a PEFT (parameter-efficient fine-tuned) model.
        """
        self.model = model
        self.is_tied_lm_head = getattr(getattr(model, "config", {}), "tie_word_embeddings", False)
        self.serialization_format = serialization_format
        self.is_peft = is_peft

    def state_dict(self) -> dict[str, Any]:
        """
        Get the model's state dictionary.

        Returns:
            dict: Dictionary containing the model's state dict. For PEFT models, a full
                state dict restricted to trainable parameters is returned with CPU
                offloading enabled.
        """
        options = None
        if self.is_peft:
            options = StateDictOptions(
                full_state_dict=True,
                cpu_offload=True,
                ignore_frozen_params=True,
            )
        model_state_dict = get_model_state_dict(self.model, options=options)

        if self.is_tied_lm_head:
            model_state_dict.pop("model.lm_head.weight", None)

        if self.is_peft:
            # HF PEFT models are saved with a "base_model.model." prefix so that they can
            # be loaded correctly with the HF PEFT API.
            _add_outer_prefix(model_state_dict, "base_model.model.")
        elif self.serialization_format == SerializationFormat.SAFETENSORS:
            # Drop the outer "model." prefix to avoid saving keys with a duplicated
            # "model.model." prefix. This is necessary when saving consolidated safetensors,
            # because HF's .from_pretrained() expects the model to be saved with a single
            # "model." prefix. This is not needed for torch serialization.
            _drop_outer_prefix(model_state_dict)

        return model_state_dict

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """
        Load the state dictionary into the model.

        Args:
            state_dict (dict): State dictionary to load.
        """
        options = None
        if self.is_peft:
            _drop_outer_prefix(state_dict, "base_model.model.")
            options = StateDictOptions(strict=False, broadcast_from_rank0=True, full_state_dict=True)
        elif self.serialization_format == SerializationFormat.SAFETENSORS:
            # Undo the prefix-stripping that happened at save-time: DCP removes the
            # container name ("model") when it dispatches the dict to this ModelState,
            # so every key now lacks the leading "model." segment that HuggingFace
            # modules normally carry. Re-add it so that set_model_state_dict can match
            # parameters correctly. This is not needed for torch serialization.
            _add_outer_prefix(state_dict)

        # If we intentionally skipped saving "lm_head.weight" (tied embeddings),
        # PyTorch will complain during load even with strict=False.
        # To be fully compatible we inject a reference tensor so the key exists.
        if self.is_tied_lm_head and not self.is_peft:
            lm_head_weight, lm_head_param_name = _get_lm_head_weight_and_name(self.model)
            if lm_head_param_name not in state_dict:
                # weight tying guarantees this is identical to the embedding weight
                state_dict[lm_head_param_name] = lm_head_weight.detach()

        set_model_state_dict(
            self.model,
            state_dict,
            options=options,
        )


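# Illustrative usage sketch (not part of the original module): one way ModelState can be
# handed to torch.distributed.checkpoint so that DCP drives its state_dict() at save time.
# The function name, `checkpoint_dir`, and the choice of SerializationFormat.SAFETENSORS
# are assumptions for the example; a distributed process group is normally initialized
# before calling dcp.save.
def _example_save_model_state(model: torch.nn.Module, checkpoint_dir: str) -> None:
    import torch.distributed.checkpoint as dcp

    # Wrapping the model in ModelState defers state-dict extraction to DCP.
    model_state = ModelState(model, SerializationFormat.SAFETENSORS, is_peft=False)
    dcp.save({"model": model_state}, checkpoint_id=checkpoint_dir)

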
class OptimizerState(Stateful):
    """
    Helper class for tracking optimizer state in distributed checkpointing.

    This class is compliant with the Stateful protocol, allowing DCP to automatically
    call state_dict/load_state_dict as needed in the dcp.save/load APIs.

    Args:
        model: The PyTorch model associated with the optimizer.
        optimizer: The optimizer to track.
        scheduler: Optional learning rate scheduler.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: Optional[Any] = None,
    ):
        """
        Initialize an OptimizerState instance.

        The constructor simply stores references to the model, optimizer, and (optionally)
        learning-rate scheduler so that their state can be captured and restored by the
        Distributed Checkpointing (DCP) framework.

        Args:
            model (torch.nn.Module): The neural-network model whose parameters the
                optimizer updates. Keeping the reference allows DCP to re-establish the
                model-optimizer relationship when loading a checkpoint.
            optimizer (torch.optim.Optimizer): Optimizer whose internal buffers (e.g.,
                momentum, Adam moments, step counters) need to be saved and restored.
            scheduler (Optional[Any], optional): Learning-rate scheduler to track alongside
                the optimizer. Pass ``None`` if no scheduler is used.
        """
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler

    def state_dict(self) -> dict[str, Any]:
        """
        Get the optimizer and scheduler state dictionaries.

        Returns:
            dict: Dictionary containing the optimizer state dict and, if a scheduler is
                set, the scheduler state dict.
        """
        # get_optimizer_state_dict automatically manages FSDP FQNs and sets the default
        # state dict type to FSDP.SHARDED_STATE_DICT.
        optimizer_state_dict = get_optimizer_state_dict(
            self.model,
            self.optimizer,
        )

        state_dict = {
            "optim": optimizer_state_dict,
        }
        if self.scheduler is not None:
            state_dict["sched"] = self.scheduler.state_dict()

        return state_dict

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """
        Load the state dictionaries into the optimizer and scheduler.

        Args:
            state_dict (dict): State dictionary containing optimizer and scheduler states
                to load.
        """
        # Set the loaded state dict on the optimizer.
        set_optimizer_state_dict(
            self.model,
            self.optimizer,
            state_dict["optim"],
        )

        # Load the scheduler state if it exists.
        if "sched" in state_dict and self.scheduler is not None:
            self.scheduler.load_state_dict(state_dict["sched"])


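# Illustrative usage sketch (not part of the original module): restoring both model and
# optimizer state with DCP. The "model"/"optim" checkpoint keys, the function name, and
# `checkpoint_dir` are assumptions for the example; dcp.load calls load_state_dict() on
# each Stateful entry after reading the checkpoint.
def _example_load_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    checkpoint_dir: str,
) -> None:
    import torch.distributed.checkpoint as dcp

    state = {
        "model": ModelState(model, SerializationFormat.SAFETENSORS),
        "optim": OptimizerState(model, optimizer),
    }
    dcp.load(state, checkpoint_id=checkpoint_dir)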