Source code for nemo_automodel._transformers.v4_patches.layer_types

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Widen transformers' ``ALLOWED_LAYER_TYPES`` so legacy custom configs load.

Some community models (e.g. ``nvidia/Nemotron-Flash-1B``) ship a custom
``configuration_*.py`` whose ``layer_types`` entries (e.g. ``'deltanet'``,
``'f'``, ``'m2'``, ``'a'``) are not in the upstream allow-list. Loading them
via ``AutoConfig.from_pretrained`` triggers ``validate_layer_type`` and raises
``StrictDataclassClassValidationError`` before model instantiation.

The validator performs a module-global lookup of ``ALLOWED_LAYER_TYPES`` at
call time, so rebinding it in place takes effect on subsequent validations.
"""

from __future__ import annotations

import importlib.abc
import logging
import sys
from typing import Iterable

logger = logging.getLogger(__name__)

_TARGET_MODULE = "transformers.configuration_utils"

DEFAULT_EXTRA_LAYER_TYPES: tuple[str, ...] = (
    "deltanet",
    "f",
    "m2",
    "a",
)

_PATCHED: bool = False


[docs] def patch_allowed_layer_types(extra: Iterable[str] = DEFAULT_EXTRA_LAYER_TYPES) -> bool: """Extend ``transformers.configuration_utils.ALLOWED_LAYER_TYPES`` in place. Idempotent and best-effort: any failure (missing attribute, transformers not installed, unexpected container type) is logged and swallowed so the caller's import path is not broken. Args: extra: Layer-type names to append if not already present. Returns: ``True`` if the tuple was modified on this call, ``False`` otherwise. """ global _PATCHED if _PATCHED: return False try: from transformers import configuration_utils as cu except ImportError: logger.debug("[v4_patches.layer_types] transformers not importable; skipping.") return False except Exception as exc: logger.warning("[v4_patches.layer_types] transformers import failed: %s", exc) return False existing = getattr(cu, "ALLOWED_LAYER_TYPES", None) if existing is None: logger.debug("[v4_patches.layer_types] ALLOWED_LAYER_TYPES missing; nothing to patch.") _PATCHED = True return False try: existing_set = set(existing) except TypeError: logger.warning( "[v4_patches.layer_types] ALLOWED_LAYER_TYPES is not iterable (%s); skipping.", type(existing).__name__, ) return False try: additions = tuple(lt for lt in extra if lt not in existing_set) except TypeError: logger.warning("[v4_patches.layer_types] `extra` is not iterable; skipping.") return False if not additions: _PATCHED = True return False try: cu.ALLOWED_LAYER_TYPES = tuple(existing) + additions except Exception as exc: logger.warning("[v4_patches.layer_types] failed to rebind ALLOWED_LAYER_TYPES: %s", exc) return False _PATCHED = True logger.info("[v4_patches.layer_types] extended ALLOWED_LAYER_TYPES with %s", additions) return True
_HOOK_INSTALLED: bool = False
[docs] class _LayerTypesPatchFinder(importlib.abc.MetaPathFinder): """Meta-path finder that applies ``patch_allowed_layer_types`` the moment ``transformers.configuration_utils`` finishes loading. The finder only intercepts the single target module name. For that name it delegates to the remaining finders to obtain a real spec, then wraps the loader's ``exec_module`` so the patch runs once, post-exec. All deviations from the happy path (missing loader, unexpected spec shape, patch failure) are swallowed with a log — the finder must never break imports. """ def __init__(self) -> None: self._applied = False
[docs] def find_spec(self, fullname, path=None, target=None): if fullname != _TARGET_MODULE or self._applied: return None real_spec = None for finder in sys.meta_path: # Skip any instance of our own finder class to avoid wrapping our # own wrapped loader (can happen if the hook is installed twice). if isinstance(finder, _LayerTypesPatchFinder): continue find_spec = getattr(finder, "find_spec", None) if find_spec is None: continue try: real_spec = find_spec(fullname, path, target) except Exception: continue if real_spec is not None: break if real_spec is None or real_spec.loader is None: return real_spec loader = real_spec.loader original_exec_module = getattr(loader, "exec_module", None) if original_exec_module is None: return real_spec finder_self = self def exec_module(module): try: original_exec_module(module) finally: finder_self._applied = True try: patch_allowed_layer_types() except Exception as exc: logger.warning("[v4_patches.layer_types] post-import patch failed: %s", exc) try: loader.exec_module = exec_module # type: ignore[method-assign] except Exception as exc: logger.debug("[v4_patches.layer_types] could not wrap loader.exec_module: %s", exc) return real_spec
[docs] def install_layer_types_patch_hook() -> bool: """Ensure ``patch_allowed_layer_types`` runs before any call to ``PreTrainedConfig.validate_layer_type``. Two paths: * If ``transformers.configuration_utils`` is already loaded, patch immediately. * Otherwise, register a meta-path finder that patches on first import. Idempotent; safe to call multiple times. Returns: ``True`` if a hook was installed (or the patch was applied directly on this call), ``False`` if a previous call already arranged one. """ global _HOOK_INSTALLED if _HOOK_INSTALLED: return False if _TARGET_MODULE in sys.modules: _HOOK_INSTALLED = True patch_allowed_layer_types() return True try: sys.meta_path.insert(0, _LayerTypesPatchFinder()) except Exception as exc: logger.warning("[v4_patches.layer_types] failed to install import hook: %s", exc) return False _HOOK_INSTALLED = True return True
_VALIDATOR_RELAXED: bool = False
[docs] def _noop_validate_layer_type(self): # noqa: ARG001 """No-op replacement for ``PretrainedConfig.validate_layer_type``.""" return None
_noop_validate_layer_type.__name__ = "validate_layer_type"
[docs] def relax_layer_types_validator() -> bool: """Disable ``PretrainedConfig.validate_layer_type`` and its registered copies. Some upstream configs (e.g. ``stepfun-ai/Step-3.5-Flash``) ship ``layer_types`` whose length differs from ``num_hidden_layers``; newer transformers rejects such configs, and ``huggingface_hub`` wraps that ``ValueError`` in ``StrictDataclassClassValidationError``. Suppressing the validator lets downstream code load the config unmodified. The ``huggingface_hub`` ``@strict`` decorator freezes validator references on each class into ``__class_validators__`` at decoration time, so swapping ``PretrainedConfig.validate_layer_type`` alone is insufficient: every already-decorated subclass keeps its own list. This helper rewrites both the class attribute and each list entry across the live subclass tree, then newly-decorated subclasses pick up the no-op automatically via ``getattr(cls, "validate_layer_type")`` at their decoration time. Idempotent and best-effort: failures (missing attribute, transformers not installed) are logged and swallowed so the caller is never broken. Returns: ``True`` if the patch was applied on this call, ``False`` otherwise. """ global _VALIDATOR_RELAXED if _VALIDATOR_RELAXED: return False try: from transformers import PretrainedConfig except ImportError: logger.debug("[v4_patches.layer_types] transformers not importable; skipping relax.") return False except Exception as exc: logger.warning("[v4_patches.layer_types] transformers import failed: %s", exc) return False if not hasattr(PretrainedConfig, "validate_layer_type"): logger.debug("[v4_patches.layer_types] validate_layer_type missing; nothing to relax.") _VALIDATOR_RELAXED = True return False try: PretrainedConfig.validate_layer_type = _noop_validate_layer_type except Exception as exc: logger.warning("[v4_patches.layer_types] failed to rebind validate_layer_type: %s", exc) return False def _relax_subtree(cls) -> None: validators = cls.__dict__.get("__class_validators__") if isinstance(validators, list): for i, v in enumerate(validators): if getattr(v, "__name__", None) == "validate_layer_type": validators[i] = _noop_validate_layer_type try: subs = cls.__subclasses__() except Exception: return for sub in subs: _relax_subtree(sub) try: _relax_subtree(PretrainedConfig) except Exception as exc: logger.warning("[v4_patches.layer_types] subclass walk failed: %s", exc) return False _VALIDATOR_RELAXED = True logger.info("[v4_patches.layer_types] relaxed validate_layer_type to no-op") return True
__all__ = [ "DEFAULT_EXTRA_LAYER_TYPES", "install_layer_types_patch_hook", "patch_allowed_layer_types", "relax_layer_types_validator", ]