# Source code for nemo_automodel.utils.yaml_utils

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import functools
import inspect
from collections.abc import Generator
from contextlib import contextmanager

import yaml


@contextmanager
def safe_yaml_representers() -> Generator[None, None, None]:
    """
    Context manager for safely adding and removing custom YAML representers.

    Temporarily adds custom representers for functions, classes, and other
    objects to the YAML SafeDumper, and restores the original representers
    when exiting the context.

    The representers are registered on the ``yaml.SafeDumper`` class itself
    (not on a dumper instance), so any code dumping with ``yaml.SafeDumper``
    while the context is active sees them as well.

    Usage:
        with safe_yaml_representers():
            yaml_str = yaml.safe_dump(my_complex_object)

    Yields:
        None. The context is used purely for its registration side effect.
    """
    # Snapshot both registries so they can be put back verbatim on exit.
    original_representers = yaml.SafeDumper.yaml_representers.copy()
    original_multi_representers = yaml.SafeDumper.yaml_multi_representers.copy()

    try:
        # Register custom representers

        # Partial representer
        yaml.SafeDumper.add_representer(functools.partial, _partial_representer)

        # Enum representer (multi-representer so every Enum subclass matches)
        yaml.SafeDumper.add_multi_representer(enum.Enum, _enum_representer)

        # Function representer: plain Python functions ...
        yaml.SafeDumper.add_representer(type(lambda: ...), _function_representer)
        # ... and classes whose metaclass is exactly `type` (type(object) is type).
        yaml.SafeDumper.add_representer(type(object), _function_representer)

        # Try to add torch dtype representer if available.
        # ImportError (the parent of ModuleNotFoundError) also covers a package
        # that is installed but cannot be imported; in that case there is
        # nothing of that type to represent, so skipping is safe.
        try:
            import torch

            yaml.SafeDumper.add_representer(torch.dtype, _torch_dtype_representer)
        except ImportError:
            pass

        # Try to add GenerationConfig representer if available.
        # A transformers release that lacks GenerationConfig raises a plain
        # ImportError ("cannot import name"), not ModuleNotFoundError.
        try:
            from transformers import GenerationConfig

            yaml.SafeDumper.add_representer(GenerationConfig, _generation_config_representer)
        except ImportError:
            pass

        # General object representer: fallback for everything not matched above.
        yaml.SafeDumper.add_multi_representer(object, _safe_object_representer)

        yield
    finally:
        # Restore original representers, even if the body raised.
        yaml.SafeDumper.yaml_representers = original_representers
        yaml.SafeDumper.yaml_multi_representers = original_multi_representers


def _function_representer(dumper, data):
    """Represent functions and classes in YAML as an un-called ``_target_``.

    Args:
        dumper: The YAML dumper driving serialization; only its
            ``represent_data`` method is used here.
        data: The function or class to serialize. It must expose
            ``__qualname__``.

    Returns:
        Whatever ``dumper.represent_data`` returns for a mapping holding
        ``_target_`` (dotted ``module.qualname`` path) and ``_call_: False``.
    """
    module = inspect.getmodule(data)
    # inspect.getmodule returns None when the defining module cannot be located
    # (e.g. it is not registered in sys.modules); fall back to the module name
    # recorded on the object instead of failing on `None.__name__`.
    module_name = module.__name__ if module is not None else getattr(data, "__module__", None)
    target = f"{module_name}.{data.__qualname__}" if module_name else data.__qualname__

    value = {
        "_target_": target,
        "_call_": False,
    }
    return dumper.represent_data(value)


def _torch_dtype_representer(dumper, data):
    """Serialize a torch dtype as an un-called ``_target_`` mapping.

    Args:
        dumper: The YAML dumper driving serialization; only its
            ``represent_data`` method is used here.
        data: The dtype to serialize; its ``str()`` form becomes the target.

    Returns:
        Whatever ``dumper.represent_data`` returns for the mapping.
    """
    dtype_name = str(data)
    return dumper.represent_data({"_target_": dtype_name, "_call_": False})


def _safe_object_representer(dumper, data):
    """General object representer for YAML.

    This function is a fallback for objects that don't have specific representers.

    If the object has a ``__qualname__`` attribute (functions, classes, ...), the
    object itself is the target and ``_call_`` is False. Otherwise the target is
    the object's ``__class__`` and ``_call_`` is True, indicating that the target
    should be called to create an instance.

    The ``_target_`` is ``f"{module_name}.{qualname}"``. The module name comes from
    ``inspect.getmodule``; when that returns None (module not locatable) the
    ``__module__`` attribute is used instead, and if no module name is available
    at all the bare qualname is emitted.

    Args:
        dumper (yaml.Dumper): The YAML dumper to use for serialization.
        data (Any): The data to serialize.

    Returns:
        The YAML representation of the data.
    """
    # Decide target vs. instance purely on the presence of __qualname__, so a
    # failed module lookup can no longer be mistaken for "this is an instance".
    if hasattr(data, "__qualname__"):
        obj, call = data, False
    else:
        obj, call = data.__class__, True

    module = inspect.getmodule(obj)
    module_name = module.__name__ if module is not None else getattr(obj, "__module__", None)
    target = f"{module_name}.{obj.__qualname__}" if module_name else obj.__qualname__

    value = {
        "_target_": target,
        "_call_": call,
    }
    return dumper.represent_data(value)


def _partial_representer(dumper, data):
    """Represent functools.partial objects in YAML.

    Args:
        dumper: The YAML dumper driving serialization; only its
            ``represent_data`` method is used here.
        data: The ``functools.partial`` object to serialize. Its wrapped
            callable must expose ``__qualname__``.

    Returns:
        Whatever ``dumper.represent_data`` returns for a mapping holding
        ``_target_`` (dotted path of the wrapped callable), ``_partial_: True``,
        ``_args_`` (positional arguments as a list) and one extra key per
        keyword argument bound in the partial.
    """
    # Get the underlying function
    func = data.func

    module = inspect.getmodule(func)
    # inspect.getmodule returns None when the defining module cannot be located;
    # fall back to the recorded __module__ rather than failing on `None.__name__`.
    module_name = module.__name__ if module is not None else getattr(func, "__module__", None)
    target = f"{module_name}.{func.__qualname__}" if module_name else func.__qualname__

    # Create a dictionary representation
    value = {
        "_target_": target,
        "_partial_": True,
        "_args_": list(data.args),
    }

    # Add keyword arguments if any exist. They are merged last, so a keyword
    # named like one of the reserved keys above would replace it.
    if data.keywords:
        value |= data.keywords

    return dumper.represent_data(value)


def _enum_representer(dumper, data):
    """Serialize an enum member as a call to its class with the member's value.

    Args:
        dumper: The YAML dumper driving serialization; only its
            ``represent_data`` method is used here.
        data: The enum member to serialize.

    Returns:
        Whatever ``dumper.represent_data`` returns for a mapping holding
        ``_target_`` (dotted path of the enum class), ``_call_: True`` and
        ``_args_`` containing the member's ``value``.
    """
    owner = data.__class__
    module_name = inspect.getmodule(owner).__name__
    dotted_path = f"{module_name}.{owner.__qualname__}"

    return dumper.represent_data(
        {
            "_target_": dotted_path,
            "_call_": True,
            "_args_": [data.value],
        }
    )


def _generation_config_representer(dumper, data):
    """Serialize a transformers GenerationConfig as a ``from_dict`` call.

    Args:
        dumper: The YAML dumper driving serialization; only its
            ``represent_data`` method is used here.
        data: The configuration object to serialize; it must provide
            ``to_dict()``.

    Returns:
        Whatever ``dumper.represent_data`` returns for a mapping holding
        ``_target_`` (dotted path of the class with ``.from_dict`` appended),
        ``_call_: True`` and ``config_dict`` set to ``data.to_dict()``.
    """
    config_cls = data.__class__
    module_name = inspect.getmodule(config_cls).__name__
    factory_path = f"{module_name}.{config_cls.__qualname__}.from_dict"

    return dumper.represent_data(
        {
            "_target_": factory_path,
            "_call_": True,
            "config_dict": data.to_dict(),
        }
    )