Source code for nemo_automodel._transformers.auto_model

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import inspect
import logging
import types

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText

from nemo_automodel import __version__
from nemo_automodel.shared.import_utils import safe_import


HAS_LIGER_KERNEL, liger_kernel_trf = safe_import("liger_kernel.transformers")
logger = logging.getLogger(__name__)


def _assert_same_signature(original, patched):
    """
    Raise AssertionError if the two call signatures differ.
    """
    sig_orig = inspect.signature(original)
    sig_patch = inspect.signature(patched)
    if sig_orig != sig_patch:
        raise AssertionError(
            f"Signature mismatch:\n"
            f" original: {sig_orig}\n"
            f" patched : {sig_patch}"
        )
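
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream module): what the signature
# guard above catches. A wrapper that forgets ``functools.wraps`` exposes a
# ``(*args, **kwargs)`` signature and trips the assertion. Never called at
# import time.
# ---------------------------------------------------------------------------
def _example_signature_guard():
    def original(x, y=1):
        return x + y

    def patched(*args, **kwargs):  # visible signature differs from (x, y=1)
        return original(*args, **kwargs)

    try:
        _assert_same_signature(original, patched)
    except AssertionError as err:
        return err  # the mismatch is detected and reported
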

def patch_attention(obj, sdpa_method=None):
    """
    Wrap the `forward` method of `obj` in an `sdpa_kernel` context manager.

    Args:
        obj: Any object with a `.forward(*args, **kwargs)` method.
        sdpa_method (list[SDPBackend], optional): Ordered list of SDPBackend
            implementations to attempt. If None, defaults to
            [CUDNN_ATTENTION, FLASH_ATTENTION, EFFICIENT_ATTENTION, MATH].

    Returns:
        The same `obj` with its `.forward` method patched.
    """
    if sdpa_method is None:
        sdpa_method = [
            SDPBackend.CUDNN_ATTENTION,
            SDPBackend.FLASH_ATTENTION,
            SDPBackend.EFFICIENT_ATTENTION,
            SDPBackend.MATH,
        ]

    orig_forward = obj.forward

    def patch_method(method):
        func = method.__func__

        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            with sdpa_kernel(sdpa_method):
                return func(self, *args, **kwargs)

        # Guard against methods without a docstring (inspect.getdoc may return None).
        wrapper.__doc__ = "SDPA kernel patch\n" + (inspect.getdoc(method) or "")
        return types.MethodType(wrapper, method.__self__)

    # re-bind
    obj.forward = patch_method(obj.forward)

    # runtime check
    _assert_same_signature(orig_forward, obj.forward)

    return obj
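
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the upstream module): patch_attention
# works on any object exposing a bound ``forward``, e.g. a toy nn.Module. The
# class and tensor shapes below are hypothetical; the function is never called
# at import time.
# ---------------------------------------------------------------------------
def _example_patch_attention():
    import torch.nn as nn

    class _DemoBlock(nn.Module):
        def forward(self, x):
            """Double the input tensor."""
            return x * 2

    block = patch_attention(_DemoBlock(), sdpa_method=[SDPBackend.MATH])
    # The wrapped forward now runs inside sdpa_kernel([SDPBackend.MATH]).
    return block(torch.ones(2, 4))
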

class NeMoAutoModelForCausalLM(AutoModelForCausalLM):
    """
    Drop-in replacement for ``transformers.AutoModelForCausalLM`` that adds custom kernels.

    The class only overrides ``from_pretrained`` and ``from_config`` to add the
    optional ``use_liger_kernel`` flag. If the flag is ``True`` (default) and the
    Liger kernel is available, the model's attention layers are monkey-patched in
    place. If patching fails for any reason, the call is retried once with
    ``use_liger_kernel=False`` so that users still obtain a functional model.

    TODO(@akoumpa): extend this beyond liger_kernel.

    Notes
    -----
    - No changes are made to the model's public API; forward signatures,
      generation utilities, and weight shapes remain identical.
    - Only decoder-style (causal) architectures are currently supported by the
      Liger patch. Unsupported models will silently fall back.

    Examples
    --------
    >>> model = NeMoAutoModelForCausalLM.from_pretrained("gpt2")  # try Liger
    >>> model = NeMoAutoModelForCausalLM.from_pretrained(
    ...     "gpt2", use_liger_kernel=False)  # skip Liger
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load a pretrained causal language model and (optionally) patch it with custom kernels.

        Parameters
        ----------
        pretrained_model_name_or_path : str or os.PathLike
            Repository ID or local path accepted by
            ``transformers.AutoModelForCausalLM.from_pretrained``.
        *model_args
            Positional arguments forwarded verbatim to the superclass.
        use_liger_kernel : bool, default True
            Whether to attempt patching the loaded model with Liger kernels.
        **kwargs
            Keyword arguments forwarded verbatim to the superclass.

        Returns
        -------
        transformers.PreTrainedModel
            The instantiated model, possibly Liger-patched.

        Warnings
        --------
        Emits a ``logging.warning`` if ``use_liger_kernel=True`` but the Liger
        package is not available.

        Retries
        -------
        If patching raises an exception, the method deletes the partially
        constructed model and recursively reloads it once with
        ``use_liger_kernel=False``.
        """
        torch_dtype = kwargs.pop("torch_dtype", torch.bfloat16)
        use_liger_kernel = kwargs.pop("use_liger_kernel", True)
        sdpa_method = kwargs.pop("sdpa_method", None)
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            **kwargs,
            torch_dtype=torch_dtype,
        )
        if use_liger_kernel:
            if not HAS_LIGER_KERNEL:
                logging.warning("Asked to use Liger Kernel, but could not import")
                return model
            try:
                liger_kernel_trf._apply_liger_kernel_to_instance(model=model)
            except Exception:
                del model
                # If patching failed, retry
                return cls.from_pretrained(
                    pretrained_model_name_or_path,
                    *model_args,
                    **kwargs,
                    torch_dtype=torch_dtype,
                    use_liger_kernel=False,
                )
        model = patch_attention(model, sdpa_method)
        model.config.update({"nemo_version": __version__})
        return model

    @classmethod
    def from_config(cls, config, **kwargs):
        """
        Instantiate a model from a config object and (optionally) patch it with custom kernels.

        Parameters
        ----------
        config : transformers.PretrainedConfig
            Configuration used to build the model.
        use_liger_kernel : bool, default True
            Whether to attempt patching the instantiated model with Liger kernels.
        **kwargs
            Additional keyword arguments forwarded to the superclass.

        Returns
        -------
        transformers.PreTrainedModel
            The instantiated model, possibly Liger-patched.

        See Also
        --------
        NeMoAutoModelForCausalLM.from_pretrained : Same logic for checkpoint loading.
        """
        torch_dtype = kwargs.pop("torch_dtype", torch.bfloat16)
        use_liger_kernel = kwargs.pop("use_liger_kernel", True)
        sdpa_method = kwargs.pop("sdpa_method", None)
        model = super().from_config(config, **kwargs, torch_dtype=torch_dtype)
        if use_liger_kernel:
            if not HAS_LIGER_KERNEL:
                logging.warning("Asked to use Liger Kernel, but could not import")
                return model
            try:
                liger_kernel_trf._apply_liger_kernel_to_instance(model=model)
            except Exception:
                del model
                # If patching failed, retry
                return cls.from_config(
                    config, **kwargs, use_liger_kernel=False, torch_dtype=torch_dtype
                )
        model = patch_attention(model, sdpa_method)
        model.config.update({"nemo_version": __version__})
        return model
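
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the upstream module): how the extra
# keyword arguments are consumed by ``from_pretrained``. "gpt2" is only an
# example checkpoint; the function is never called at import time.
# ---------------------------------------------------------------------------
def _example_causal_lm_loading():
    # Default behaviour: attempt the Liger patch and use the default SDPA
    # backend priority list.
    model = NeMoAutoModelForCausalLM.from_pretrained("gpt2")

    # Explicit control: skip the Liger patch and restrict SDPA to one backend.
    model = NeMoAutoModelForCausalLM.from_pretrained(
        "gpt2",
        use_liger_kernel=False,
        sdpa_method=[SDPBackend.EFFICIENT_ATTENTION],
    )
    return model
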

class NeMoAutoModelForImageTextToText(AutoModelForImageTextToText):
    """Drop-in replacement for ``transformers.AutoModelForImageTextToText`` with custom kernels.

    The class only overrides ``from_pretrained`` and ``from_config`` to add the
    optional ``use_liger_kernel`` flag. If the flag is ``True`` (default) and the
    Liger kernel is available, the model's attention layers are monkey-patched in
    place. If patching fails for any reason, the call is retried once with
    ``use_liger_kernel=False`` so that users still obtain a functional model.

    @akoumpa: currently only supporting liger_kernel for demonstration purposes.

    Notes
    -----
    - No changes are made to the model's public API; forward signatures,
      generation utilities, and weight shapes remain identical.
    - Only decoder-style (causal) architectures are currently supported by the
      Liger patch. Unsupported models will silently fall back.

    Examples
    --------
    >>> model = NeMoAutoModelForImageTextToText.from_pretrained(
    ...     "Qwen/Qwen2.5-VL-3B-Instruct")  # try Liger
    >>> model = NeMoAutoModelForImageTextToText.from_pretrained(
    ...     "Qwen/Qwen2.5-VL-3B-Instruct", use_liger_kernel=False)  # skip Liger
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load a pretrained image-text-to-text model and (optionally) patch it with custom kernels.

        Parameters
        ----------
        pretrained_model_name_or_path : str or os.PathLike
            Repository ID or local path accepted by
            ``transformers.AutoModelForImageTextToText.from_pretrained``.
        *model_args
            Positional arguments forwarded verbatim to the superclass.
        use_liger_kernel : bool, default True
            Whether to attempt patching the loaded model with Liger kernels.
        **kwargs
            Keyword arguments forwarded verbatim to the superclass.

        Returns
        -------
        transformers.PreTrainedModel
            The instantiated model, possibly Liger-patched.

        Warnings
        --------
        Emits a ``logging.warning`` if ``use_liger_kernel=True`` but the Liger
        package is not available.

        Retries
        -------
        If patching raises an exception, the method deletes the partially
        constructed model and recursively reloads it once with
        ``use_liger_kernel=False``.
        """
        torch_dtype = kwargs.pop("torch_dtype", torch.bfloat16)
        use_liger_kernel = kwargs.pop("use_liger_kernel", True)
        sdpa_method = kwargs.pop("sdpa_method", None)
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            **kwargs,
            torch_dtype=torch_dtype,
        )
        if use_liger_kernel:
            if not HAS_LIGER_KERNEL:
                logging.warning("Asked to use Liger Kernel, but could not import")
                return model
            try:
                liger_kernel_trf._apply_liger_kernel_to_instance(model=model)
            except Exception:
                del model
                # If patching failed, retry
                return cls.from_pretrained(
                    pretrained_model_name_or_path,
                    *model_args,
                    **kwargs,
                    use_liger_kernel=False,
                    torch_dtype=torch_dtype,
                )
        model = patch_attention(model, sdpa_method)
        model.config.update({"nemo_version": __version__})
        return model

    @classmethod
    def from_config(cls, config, **kwargs):
        """
        Instantiate a model from a config object and (optionally) patch it with custom kernels.

        Parameters
        ----------
        config : transformers.PretrainedConfig
            Configuration used to build the model.
        use_liger_kernel : bool, default True
            Whether to attempt patching the instantiated model with Liger kernels.
        **kwargs
            Additional keyword arguments forwarded to the superclass.

        Returns
        -------
        transformers.PreTrainedModel
            The instantiated model, possibly Liger-patched.

        See Also
        --------
        NeMoAutoModelForImageTextToText.from_pretrained : Same logic for checkpoint loading.
        """
        torch_dtype = kwargs.pop("torch_dtype", torch.bfloat16)
        use_liger_kernel = kwargs.pop("use_liger_kernel", True)
        sdpa_method = kwargs.pop("sdpa_method", None)
        model = super().from_config(config, **kwargs, torch_dtype=torch_dtype)
        if use_liger_kernel:
            if not HAS_LIGER_KERNEL:
                logging.warning("Asked to use Liger Kernel, but could not import")
                return model
            try:
                liger_kernel_trf._apply_liger_kernel_to_instance(model=model)
            except Exception:
                del model
                # If patching failed, retry
                return cls.from_config(
                    config, **kwargs, use_liger_kernel=False, torch_dtype=torch_dtype
                )
        model = patch_attention(model, sdpa_method)
        model.config.update({"nemo_version": __version__})
        return model
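
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the upstream module): building the
# image-text-to-text model from a config instead of a checkpoint. The config
# name matches the class docstring example; nothing below runs at import time.
# ---------------------------------------------------------------------------
def _example_image_text_from_config():
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
    # from_config accepts the same extra flags as from_pretrained.
    model = NeMoAutoModelForImageTextToText.from_config(
        config,
        use_liger_kernel=False,  # skip the Liger patch
    )
    return model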