# Source code for nemo_automodel.cli.utils

#!/usr/bin/env python3
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from functools import lru_cache
from pathlib import Path

import yaml

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)
# ``recipes`` directory that sits next to this file's parent package directory
# (this file's resolved location, two levels up, joined with "recipes").
_RECIPES_DIR = Path(__file__).resolve().parents[1] / "recipes"


@lru_cache(maxsize=1)
def _discover_recipe_classes() -> dict[str, str]:
    """Scan ``nemo_automodel/recipes/`` for recipe classes.

    Every ``*.py`` file under the recipes directory whose file name does not
    start with ``_`` is scanned as text (it is not imported). The scan looks for
    unindented ``class`` statements whose name contains ``Recipe``.
    ``BaseRecipe`` is excluded. Abstractness is not checked.

    Returns a mapping from bare class name to fully-qualified dotted path, e.g.
    ``{"TrainFinetuneRecipeForNextTokenPrediction":
    "nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction"}``.

    Files are visited in sorted path order, so the result is deterministic.
    If the same class name is defined in several modules, the module whose path
    sorts last wins.

    The result is memoized with ``lru_cache``. Later changes on disk are not
    seen until ``_discover_recipe_classes.cache_clear()`` is called. The same
    dict object is returned on every call, so callers must not mutate it.
    """
    # Anchored at column 0 (MULTILINE), so nested classes are ignored.
    class_pattern = re.compile(r"^class\s+(\w*Recipe\w*)\s*[\(:]", re.MULTILINE)
    registry: dict[str, str] = {}
    # Dotted module names are built relative to the directory that contains the
    # top-level package (two levels above ``recipes/``).
    pkg_root = _RECIPES_DIR.parent.parent
    # rglob() yields files in filesystem-dependent order; sort for reproducibility.
    for py_file in sorted(_RECIPES_DIR.rglob("*.py")):
        # Skip private modules and package initialisers such as ``__init__.py``.
        if py_file.name.startswith("_"):
            continue
        module_dotted = ".".join(py_file.relative_to(pkg_root).with_suffix("").parts)
        # Python source files are UTF-8 by default (PEP 3120); don't depend on
        # the platform's locale encoding.
        source = py_file.read_text(encoding="utf-8")
        for match in class_pattern.finditer(source):
            cls_name = match.group(1)
            if cls_name == "BaseRecipe":
                continue
            registry[cls_name] = f"{module_dotted}.{cls_name}"
    return registry
def resolve_recipe_name(raw: str) -> str:
    """Map a user-supplied recipe name to an importable dotted path.

    Two spellings are accepted:

    - A bare class name, e.g. ``"TrainFinetuneRecipeForNextTokenPrediction"``.
      It is looked up in the registry built by ``_discover_recipe_classes()``.
    - Anything containing a ``.``, e.g.
      ``"nemo_automodel.recipes.llm.train_ft.TrainFinetuneRecipeForNextTokenPrediction"``.
      It is treated as already fully qualified and returned unchanged, without
      validation.

    Raises:
        ValueError: if a bare name is not in the registry. The message lists
            every known short name.
    """
    # Dotted input is assumed to be a full import path; skip the registry scan.
    if "." in raw:
        return raw

    known_recipes = _discover_recipe_classes()
    if raw not in known_recipes:
        choices = "\n".join(" - " + short_name for short_name in sorted(known_recipes))
        raise ValueError(f"Unknown recipe class '{raw}'. Available short names:\n{choices}")
    return known_recipes[raw]
def load_yaml(file_path):
    """Load a YAML file with ``yaml.safe_load`` and return the parsed document.

    Args:
        file_path: Path (``str`` or path-like) of the YAML file to read.

    Returns:
        The top-level YAML value. This is typically a ``dict`` for config files,
        but it may be any YAML type (list, scalar, ...). It is ``None`` for an
        empty file.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist (logged, then re-raised).
        yaml.YAMLError: if the content is not valid YAML (logged, then re-raised).
        UnicodeDecodeError: if the file is not valid UTF-8 (not logged).
    """
    try:
        # Pin the encoding so parsing does not depend on the platform's locale
        # default (e.g. cp1252 on Windows).
        with open(file_path, "r", encoding="utf-8") as file:
            return yaml.safe_load(file)
    except FileNotFoundError:
        logger.error("File '%s' was not found.", file_path)
        raise
    except yaml.YAMLError as e:
        logger.error("parsing YAML file '%s' failed: %s", file_path, e)
        raise