Source code for nemo_automodel.training.base_recipe

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
from pathlib import Path

import torch
import torch.nn as nn
from torch.optim import Optimizer

from nemo_automodel.checkpoint.checkpointing import (
    load_model,
    load_optimizer,
    save_model,
    save_optimizer,
)


def has_load_restore_state(object):
    """
    Checks whether an object has load_state_dict and state_dict functions.

    TODO: also need to check function signatures.

    Args:
        object (any): the object to check.

    Returns:
        bool: True if the object has callable load_state_dict and state_dict.
    """
    return all(
        callable(getattr(object, attr, None))
        for attr in ('load_state_dict', 'state_dict')
    )
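
# ---------------------------------------------------------------------------
# Illustrative usage sketch (documentation aid, not part of the shipped
# module): `has_load_restore_state` duck-types "stateful" objects, so any
# object exposing both methods qualifies, not just torch classes.
#
#     >>> has_load_restore_state(nn.Linear(2, 2))   # nn.Module defines both
#     True
#     >>> has_load_restore_state(object())          # plain object has neither
#     False
# ---------------------------------------------------------------------------
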
class BaseRecipe:
    """
    BaseRecipe provides checkpoint load/save functionality for recipes.
    """

    def __setattr__(self, key, value):
        """
        Overridden __setattr__ to keep track of stateful classes.

        Args:
            key (str): attribute name.
            value (Any): value assigned.

        Raises:
            ValueError: if an attempt is made to overwrite __state_tracked.
        """
        # assuming no one will do recipe.__dict__['__state_tracked'] = None
        if key == '__state_tracked':
            raise ValueError("cannot set __state_tracked")
        if '__state_tracked' not in self.__dict__:
            self.__dict__['__state_tracked'] = set()
        # Track stateful objects unless they are validation/eval components.
        should_track = (
            isinstance(value, (nn.Module, Optimizer))
            or has_load_restore_state(value)
        )
        if should_track and not any(substr in key.lower() for substr in ("val", "eval", "test")):
            assert key not in self.__dict__['__state_tracked']
            self.__dict__['__state_tracked'].add(key)
        super().__setattr__(key, value)

    def save_checkpoint(self, epoch: int, step: int):
        """
        Save the current training state as a checkpoint.

        Every tracked object that has 'load_state_dict' and 'state_dict' functions is saved.

        Args:
            epoch (int): The current epoch.
            step (int): The current step.
        """
        if not self.checkpoint_config.enabled:
            return
        path = self.checkpoint_config.checkpoint_dir
        path = os.path.join(path, f"epoch_{epoch}_step_{step}")
        os.makedirs(path, exist_ok=True)
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            print(f"Saving checkpoint to {path}", flush=True)

        # TODO(@adil-a): Change this when we create a LR scheduler class
        model, optimizer = None, None
        for key in self.__dict__['__state_tracked']:
            if isinstance(getattr(self, key), nn.Module):
                model = getattr(self, key)
            elif isinstance(getattr(self, key), Optimizer):
                optimizer = getattr(self, key)
            else:
                if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
                    torch.save(
                        getattr(self, key).state_dict(),
                        os.path.join(path, f"{key}.pt"),
                    )
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
        save_model(model, path, self.checkpoint_config)
        save_optimizer(optimizer, model, path)

    def load_checkpoint(self, restore_from: str | None = None):
        """
        Loads the latest checkpoint, or a specific one if `restore_from` is given.

        Args:
            restore_from (str | None): optional checkpoint directory to restore from;
                if None, the latest checkpoint under
                `checkpoint_config.checkpoint_dir` is used.
        """
        if not self.checkpoint_config.enabled:
            if (
                not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
            ) and restore_from is not None:
                print("Enable checkpointing to resume from a checkpoint, skipping...", flush=True)
            return

        if restore_from:
            ckpt_dir = restore_from
        else:
            # Determine the latest checkpoint directory (e.g. ".../step_42").
            ckpt_dir = _find_latest_checkpoint(self.checkpoint_config.checkpoint_dir)
            if ckpt_dir is None:
                return
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            print(f"Loading checkpoint from {ckpt_dir}", flush=True)

        # TODO(@adil-a): Change this when we create a LR scheduler class
        model, optimizer = None, None
        for key in self.__dict__['__state_tracked']:
            if isinstance(getattr(self, key), nn.Module):
                model = getattr(self, key)
            elif isinstance(getattr(self, key), Optimizer):
                optimizer = getattr(self, key)
            else:
                getattr(self, key).load_state_dict(
                    torch.load(
                        os.path.join(ckpt_dir, f"{key}.pt"),
                        weights_only=False,
                    )
                )
        load_model(model, ckpt_dir, self.checkpoint_config)
        load_optimizer(optimizer, model, ckpt_dir)
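
# ---------------------------------------------------------------------------
# Illustrative sketch (documentation aid, not part of the shipped module):
# a minimal recipe showing how `__setattr__` tracking interacts with
# `save_checkpoint` / `load_checkpoint`. `DemoRecipe` and the ad-hoc
# SimpleNamespace config are assumptions for illustration only; real recipes
# use the framework's own checkpoint config, which has more fields than the
# two accessed here (`enabled`, `checkpoint_dir`).
#
#     from types import SimpleNamespace
#
#     class DemoRecipe(BaseRecipe):
#         def __init__(self, ckpt_dir):
#             self.checkpoint_config = SimpleNamespace(          # not tracked (no state_dict)
#                 enabled=True, checkpoint_dir=ckpt_dir
#             )
#             self.model = nn.Linear(4, 4)                       # tracked (nn.Module)
#             self.optimizer = torch.optim.SGD(                  # tracked (Optimizer)
#                 self.model.parameters(), lr=1e-3
#             )
#             self.val_model = nn.Linear(4, 4)                   # NOT tracked ("val" in name)
#
#     recipe = DemoRecipe("/tmp/demo_ckpts")
#     recipe.save_checkpoint(epoch=0, step=10)   # writes .../epoch_0_step_10
#     recipe.load_checkpoint()                   # restores from the latest step dir
# ---------------------------------------------------------------------------
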

def _find_latest_checkpoint(checkpoint_dir):
    """
    Find the latest checkpoint in the checkpoint directory and return it.
    """
    checkpoint_dir = Path(checkpoint_dir)
    if not checkpoint_dir.exists():
        return

    # Accept checkpoints saved as either `step_<num>` or `epoch_<epoch>_step_<num>`
    # (or any other pattern that contains the substring `step_`).
    # This makes the checkpoint loading logic compatible with the naming scheme
    # used in `save_checkpoint`, which currently saves to `epoch_{epoch}_step_{step}`.
    checkpoint_files = list(checkpoint_dir.glob("*step_*"))
    if not checkpoint_files:
        return

    def _step_num(path: Path):
        """Return the numeric step from a path stem of the form step_<int>."""
        m = re.search(r"step_(\d+)$", path.stem)
        return int(m.group(1)) if m else -1

    latest = max(checkpoint_files, key=_step_num)

    # If no directory followed the expected "step_<int>" pattern, _step_num would be
    # -1 for all of them. Treat that as "no valid checkpoint".
    if _step_num(latest) == -1:
        return
    return latest
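
# ---------------------------------------------------------------------------
# Illustrative sketch (documentation aid, not part of the shipped module):
# `_find_latest_checkpoint` orders candidates by the numeric suffix after
# "step_", not by modification time or epoch. The paths below are made up.
#
#     base = Path("/tmp/demo_ckpts")
#     for name in ("epoch_0_step_10", "epoch_0_step_200", "epoch_1_step_30"):
#         (base / name).mkdir(parents=True, exist_ok=True)
#
#     _find_latest_checkpoint(base)   # -> Path("/tmp/demo_ckpts/epoch_0_step_200"),
#                                     #    since 200 is the largest step number
# ---------------------------------------------------------------------------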