Source code for nemo_automodel.config.loader

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import importlib
import importlib.util
import os
import sys

import yaml


def translate_value(v):
    """
    Convert a string token into the corresponding Python object.

    Non-string inputs are returned unchanged. For strings, this function first
    checks for a handful of special symbols (None/true/false), then falls back
    to `ast.literal_eval`, and finally returns the original string if parsing
    fails.

    Args:
        v: The raw value to translate. Only `str` values are parsed.

    Returns:
        The translated Python value, which may be:
        - None, True, or False for the special symbols
        - an int, float, tuple, list, dict, etc. if `ast.literal_eval` succeeds
        - the original value `v` if it is not a string or cannot be parsed
    """
    # Only strings can be parsed. Anything else (ints, lists, dicts, sets,
    # callables, ...) is already a Python object. Checking this first also
    # avoids a TypeError from the dict lookup below when `v` is unhashable.
    if not isinstance(v, str):
        return v

    special_symbols = {
        'none': None,
        'None': None,
        'true': True,
        'True': True,
        'false': False,
        'False': False,
    }
    if v in special_symbols:
        return special_symbols[v]

    try:
        # smart-cast literals: numbers, dicts, lists, tuples, strings, ...
        return ast.literal_eval(v)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # These are the exceptions literal_eval documents for malformed input;
        # fall back to the raw string.
        return v
[docs] def _resolve_target(dotted_path: str): """ Resolve a dotted path to a Python object. 1) Find the longest importable module prefix. 2) getattr() the rest. 3) If that fails, fall back to scanning sys.path for .py or package dirs. """ parts = dotted_path.split(".") # 1) Try longest‐prefix module import + getattr the rest for i in range(len(parts), 0, -1): module_name = ".".join(parts[:i]) remainder = parts[i:] try: module = importlib.import_module(module_name) except ModuleNotFoundError: continue # we got a module; now walk its attributes try: obj = module for name in remainder: obj = getattr(obj, name) return obj except AttributeError: # we imported module_name but one of the remainder attrs failed raise ImportError( f"Module '{module_name}' loaded, " f"but cannot resolve attribute '{'.'.join(remainder)}' in '{dotted_path}'" ) # 2) Fallback: scan sys.path for a .py file or package dir matching parts[:-1] for base in sys.path: pkg_dir = os.path.join(base, *parts[:-1]) candidates = [ pkg_dir + ".py", os.path.join(pkg_dir, "__init__.py"), ] for cand in candidates: if not os.path.isfile(cand): continue module_name = "_dynamic_" + "_".join(parts[:-1]) spec = importlib.util.spec_from_file_location(module_name, cand) mod = importlib.util.module_from_spec(spec) sys.modules[module_name] = mod spec.loader.exec_module(mod) try: return getattr(mod, parts[-1]) except AttributeError: raise ImportError( f"Loaded '{cand}' as module but no attribute '{parts[-1]}'" ) # 3) Give up raise ImportError(f"Cannot resolve target: {dotted_path}")
class ConfigNode:
    """
    A configuration node that wraps a dictionary (or parts of it) from a YAML file.

    Nested dictionaries become nested ``ConfigNode`` objects and lists are
    wrapped element-wise, so values can be accessed as attributes
    (``cfg.model.name``). Nodes carrying a ``_target_`` key can be turned
    into objects with :meth:`instantiate`.
    """

    def __init__(self, d):
        """Initialize the ConfigNode.

        Args:
            d (dict): A dictionary representing configuration options.
        """
        self.__dict__ = {k: self._wrap(k, v) for k, v in d.items()}

    def _wrap(self, k, v):
        """Wrap a configuration value based on its type.

        - dicts become ``ConfigNode`` objects;
        - lists are wrapped element-wise;
        - string values under keys ending in ``_fn`` are resolved to Python
          objects via ``_resolve_target``;
        - everything else goes through ``translate_value``.

        Args:
            k (str): The key corresponding to the value ('' for list items).
            v: The value to be wrapped.

        Returns:
            The wrapped value.
        """
        if isinstance(v, dict):
            return ConfigNode(v)
        if isinstance(v, list):
            return [self._wrap('', i) for i in v]
        # Only dotted-path strings can be resolved. A `_fn` value that is null
        # or already an object (e.g. a callable passed to `set_by_dotted`)
        # would otherwise crash inside `_resolve_target`.
        if k.endswith('_fn') and isinstance(v, str):
            return _resolve_target(v)
        return translate_value(v)

    def instantiate(self, *args, **kwargs):
        """Instantiate the target object specified in the configuration.

        The "_target_" attribute is either a dotted-path string, resolved with
        ``_resolve_target``, or an already-resolved callable. The target is
        called with ``*args`` and with every other entry of this node as a
        keyword argument.

        Args:
            *args: Positional arguments for the target instantiation.
            **kwargs: Keyword arguments to override or add to the configuration values.

        Returns:
            The instantiated object.

        Raises:
            AttributeError: If no "_target_" attribute is found in the configuration.
        """
        if not hasattr(self, "_target_"):
            raise AttributeError("No _target_ found to instantiate")
        target = self._target_
        func = target if callable(target) else _resolve_target(target)

        # Prepare kwargs from config
        config_kwargs = {}
        for k, v in self.__dict__.items():
            if k == '_target_':
                continue
            if k.endswith('_fn'):
                # Already resolved in `_wrap`; pass through untouched.
                config_kwargs[k] = v
            else:
                config_kwargs[k] = self._instantiate_value(v)

        # Override/add with passed kwargs
        config_kwargs.update(kwargs)
        return func(*args, **config_kwargs)

    def _instantiate_value(self, v):
        """
        Recursively instantiate configuration values.

        Nodes with a ``_target_`` are instantiated, other nodes become plain
        dicts, lists are processed element-wise, and strings are passed through
        ``translate_value``. Any other value is returned unchanged.

        Args:
            v: The configuration value.

        Returns:
            The instantiated value.
        """
        if isinstance(v, ConfigNode):
            return v.instantiate() if hasattr(v, "_target_") else v.to_dict()
        if isinstance(v, list):
            return [self._instantiate_value(i) for i in v]
        if isinstance(v, str):
            # NOTE(review): values were already translated in `_wrap`, so this
            # re-translates strings a second time. It is kept for values
            # assigned directly to attributes; confirm that is intended.
            return translate_value(v)
        # Non-strings are already Python objects (and may be unhashable).
        return v

    def to_dict(self):
        """
        Convert the configuration node back to a dictionary.

        Returns:
            dict: A dictionary representation of the configuration node.
        """
        return {k: self._unwrap(v) for k, v in self.__dict__.items()}

    def _unwrap(self, v):
        """
        Recursively convert wrapped configuration values to basic Python types.

        Args:
            v: The configuration value.

        Returns:
            The unwrapped value.
        """
        if isinstance(v, ConfigNode):
            return v.to_dict()
        if isinstance(v, list):
            return [self._unwrap(i) for i in v]
        return v

    def get(self, key, default=None):
        """
        Retrieve a configuration value using a dotted key.

        Nested nodes are traversed by key and lists by integer index (e.g.
        ``"layers.0.size"``). If any component of the path is missing, or the
        path continues past a leaf value, `default` is returned.

        Args:
            key (str): The dotted path key.
            default: A default value to return if the key is not found.

        Returns:
            The configuration value or the default value.
        """
        parts = key.split(".")
        current = self
        # TODO(@akoumparouli): reduce?
        for p in parts:
            # Traverse dictionaries (ConfigNode)
            if isinstance(current, ConfigNode):
                if p in current.__dict__:
                    current = current.__dict__[p]
                else:
                    return default
            # Traverse lists by numeric index
            elif isinstance(current, list):
                try:
                    idx = int(p)
                    current = current[idx]
                except (ValueError, IndexError):
                    return default
            else:
                # Reached a leaf but path still has components
                return default
        return current

    def set_by_dotted(self, dotted_key: str, value):
        """
        Set a value in the config using a dotted key.

        Intermediate nodes are created as needed. Any existing non-node value
        along the path is replaced by an empty ConfigNode. The final value is
        wrapped the same way as values passed to the constructor.
        e.g. set_by_dotted("foo.bar.abc", 1) will ensure self.foo.bar.abc == 1
        """
        parts = dotted_key.split(".")
        node = self
        # walk / create intermediate ConfigNodes
        for p in parts[:-1]:
            if p not in node.__dict__ or not isinstance(node.__dict__[p], ConfigNode):
                node.__dict__[p] = ConfigNode({})
            node = node.__dict__[p]
        # wrap the final leaf value
        node.__dict__[parts[-1]] = node._wrap(parts[-1], value)

    def __repr__(self, level=0):
        """
        Return a string representation of the configuration node with indentation.

        Args:
            level (int): The current indentation level.

        Returns:
            str: An indented string representation of the configuration.
        """
        indent = " " * level
        lines = [
            f"{indent}{key}: {self._repr_value(value, level)}"
            for key, value in self.__dict__.items()
        ]
        # NOTE(review): the '#path: ' prefix is emitted for every node,
        # including nested ones, where it lands before the first child's
        # indentation; confirm this is intended.
        return "\n#path: " + "\n".join(lines) + f"\n{indent}"

    def _repr_value(self, value, level):
        """
        Format a configuration value for the string representation.

        Args:
            value: The configuration value.
            level (int): The indentation level.

        Returns:
            str: A formatted string representation of the value.
        """
        if isinstance(value, ConfigNode):
            return value.__repr__(level + 1)
        elif isinstance(value, list):
            return "[\n" + \
                "\n".join([f"{' ' * (level + 1)}{self._repr_value(i, level + 1)}" for i in value]) \
                + f"\n{' ' * level}]"
        else:
            return repr(value)

    def __str__(self):
        """
        Return a string representation of the configuration node.

        Returns:
            str: The string representation.
        """
        return self.__repr__(level=0)

    def __contains__(self, key):
        """
        Check if a dotted key exists in the configuration.

        Uses the same traversal rules as :meth:`get` (nested nodes by key,
        lists by integer index). A path that continues past a leaf value is
        reported as missing.

        Args:
            key (str): The dotted key to check.

        Returns:
            bool: True if the key exists, False otherwise.
        """
        # A private sentinel distinguishes "missing" from a stored None value.
        missing = object()
        return self.get(key, missing) is not missing
def load_yaml_config(path):
    """
    Load a YAML configuration file and convert it to a ConfigNode.

    Args:
        path (str): The path to the YAML configuration file.

    Returns:
        ConfigNode: A configuration node representing the YAML file. An empty
        document yields an empty node.

    Raises:
        TypeError: If the top-level YAML value is not a mapping.
    """
    # YAML is Unicode text; read it as UTF-8 explicitly instead of relying on
    # the platform's locale-dependent default encoding.
    with open(path, "r", encoding="utf-8") as f:
        raw = yaml.safe_load(f)
    # `safe_load` returns None for an empty document (or one that is only comments).
    if raw is None:
        raw = {}
    if not isinstance(raw, dict):
        raise TypeError(
            f"Expected a mapping at the top level of '{path}', got {type(raw).__name__}"
        )
    return ConfigNode(raw)