nemo_automodel._transformers.auto_model#
Module Contents#
Classes#
| NeMoAutoModelForCausalLM | Drop-in replacement for transformers.AutoModelForCausalLM that includes custom kernels. |
| NeMoAutoModelForImageTextToText | Drop-in replacement for transformers.AutoModelForImageTextToText with custom kernels. |
Functions#
| _assert_same_signature | Raise AssertionError if the two call signatures differ. |
| patch_attention | Wrap the forward method of obj in an sdpa_kernel context manager. |
Data#
API#
- nemo_automodel._transformers.auto_model.logger#
'getLogger(…)'
- nemo_automodel._transformers.auto_model._assert_same_signature(original, patched)[source]#
Raise AssertionError if the two call signatures differ.
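A minimal sketch of what such a signature check can look like, using only the standard library's `inspect` module; the actual helper may differ in details:

```python
import inspect

def _assert_same_signature(original, patched):
    """Raise AssertionError if the two call signatures differ."""
    # inspect.signature follows functools.wraps metadata, so a properly
    # wrapped patch compares equal to the original callable.
    sig_original = inspect.signature(original)
    sig_patched = inspect.signature(patched)
    assert sig_original == sig_patched, (
        f"signature mismatch: {sig_original} != {sig_patched}"
    )
```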
- nemo_automodel._transformers.auto_model.patch_attention(obj, sdpa_method=None)[source]#
Wrap the `forward` method of `obj` in an `sdpa_kernel` context manager.
- Parameters:
  - obj – Any object with a `.forward(*args, **kwargs)` method.
  - sdpa_method (list[SDPBackend], optional) – Ordered list of SDPBackend implementations to attempt. If None, defaults to [CUDNN_ATTENTION, FLASH_ATTENTION, EFFICIENT_ATTENTION, MATH].
- Returns:
  The same `obj` with its `.forward` method patched.
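The wrapper itself can be sketched as below, assuming PyTorch 2.4+ where `torch.nn.attention.sdpa_kernel` accepts a list of backends; this illustrates the technique rather than reproducing the library's exact implementation:

```python
import functools

from torch.nn.attention import SDPBackend, sdpa_kernel

def patch_attention(obj, sdpa_method=None):
    """Wrap obj.forward in an sdpa_kernel context manager."""
    if sdpa_method is None:
        sdpa_method = [
            SDPBackend.CUDNN_ATTENTION,
            SDPBackend.FLASH_ATTENTION,
            SDPBackend.EFFICIENT_ATTENTION,
            SDPBackend.MATH,
        ]
    original_forward = obj.forward

    @functools.wraps(original_forward)
    def wrapped_forward(*args, **kwargs):
        # Each call now runs under the requested SDPA backends, tried in order.
        with sdpa_kernel(sdpa_method):
            return original_forward(*args, **kwargs)

    obj.forward = wrapped_forward
    return obj
```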
- class nemo_automodel._transformers.auto_model.NeMoAutoModelForCausalLM(*args, **kwargs)[source]#
Bases: `transformers.AutoModelForCausalLM`
Drop-in replacement for `transformers.AutoModelForCausalLM` that includes custom kernels.
The class only overrides `from_pretrained` and `from_config` to add the optional `use_liger_kernel` flag. If the flag is `True` (default) and the Liger kernel is available, the model's attention layers are monkey-patched in place. If patching fails for any reason, the call is retried once with `use_liger_kernel=False` so that users still obtain a functional model.
TODO(@akoumpa): extend this beyond liger_kernel.
Notes:#
No changes are made to the model’s public API; forward signatures, generation utilities, and weight shapes remain identical.
Only decoder-style (causal) architectures are currently supported by the Liger patch. Unsupported models will silently fall back.
Examples:#
>>> model = NeMoAutoModelForCausalLM.from_pretrained("gpt2")  # try Liger
>>> model = NeMoAutoModelForCausalLM.from_pretrained(
...     "gpt2", use_liger_kernel=False)  # skip Liger
Initialization
- classmethod from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)[source]#
Load a pretrained causal language model and (optionally) patch it with custom kernels.
Parameters#
pretrained_model_name_or_path : str or os.PathLike
    Repository ID or local path accepted by `transformers.AutoModelForCausalLM.from_pretrained`.
*model_args
    Positional arguments forwarded verbatim to the superclass.
use_liger_kernel : bool, default True
    Whether to attempt patching the loaded model with Liger kernels.
**kwargs
    Keyword arguments forwarded verbatim to the superclass.
Returns:#
transformers.PreTrainedModel
    The instantiated model, possibly Liger-patched.
Warnings:#
Emits a `logging.warning` if `use_liger_kernel=True` but the Liger package is not available.
Retries#
If patching raises an exception, the method deletes the partially constructed model and recursively reloads it once with `use_liger_kernel=False`.
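The retry behavior can be illustrated with the following hedged sketch. The control flow mirrors the description above, but `_apply_liger_kernel` is a hypothetical placeholder, not the module's real patching helper:

```python
import logging

from transformers import AutoModelForCausalLM

logger = logging.getLogger(__name__)

def _apply_liger_kernel(model):
    # Hypothetical stand-in for the real Liger patching step.
    ...

class PatchedAutoModelForCausalLM(AutoModelForCausalLM):
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        use_liger_kernel = kwargs.pop("use_liger_kernel", True)
        model = super().from_pretrained(
            pretrained_model_name_or_path, *model_args, **kwargs
        )
        if not use_liger_kernel:
            return model
        try:
            _apply_liger_kernel(model)
            return model
        except Exception:
            logger.warning("Liger patching failed; reloading without it.")
            del model  # drop the partially patched model before retrying
            return cls.from_pretrained(
                pretrained_model_name_or_path,
                *model_args,
                use_liger_kernel=False,
                **kwargs,
            )
```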
- classmethod from_config(config, **kwargs)[source]#
Instantiate a model from a config object and (optionally) patch it with custom kernels.
Parameters#
config : transformers.PretrainedConfig
    Configuration used to build the model.
use_liger_kernel : bool, default True
    Whether to attempt patching the instantiated model with Liger kernels.
**kwargs
    Additional keyword arguments forwarded to the superclass.
Returns:#
transformers.PreTrainedModel
    The instantiated model, possibly Liger-patched.
See Also:#
NeMoAutoModelForCausalLM.from_pretrained : Same logic for checkpoint loading.
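A hedged usage sketch: `from_config` builds the model with freshly initialized weights, and the flag behaves the same as in `from_pretrained` (the "gpt2" config here is just an example):

```python
from transformers import AutoConfig

from nemo_automodel._transformers.auto_model import NeMoAutoModelForCausalLM

config = AutoConfig.from_pretrained("gpt2")

# Attempt Liger patching (default behavior).
model = NeMoAutoModelForCausalLM.from_config(config)

# Build the same architecture without attempting any patching.
model = NeMoAutoModelForCausalLM.from_config(config, use_liger_kernel=False)
```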
- class nemo_automodel._transformers.auto_model.NeMoAutoModelForImageTextToText(*args, **kwargs)[source]#
Bases: `transformers.AutoModelForImageTextToText`
Drop-in replacement for `transformers.AutoModelForImageTextToText` with custom kernels.
The class only overrides `from_pretrained` and `from_config` to add the optional `use_liger_kernel` flag. If the flag is `True` (default) and the Liger kernel is available, the model's attention layers are monkey-patched in place. If patching fails for any reason, the call is retried once with `use_liger_kernel=False` so that users still obtain a functional model.
@akoumpa: currently only supporting liger_kernel for demonstration purposes.
Notes:#
No changes are made to the model’s public API; forward signatures, generation utilities, and weight shapes remain identical.
Only decoder-style (causal) architectures are currently supported by the Liger patch. Unsupported models will silently fall back.
Examples:#
>>> model = NeMoAutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")  # try Liger
>>> model = NeMoAutoModelForImageTextToText.from_pretrained(
...     "Qwen/Qwen2.5-VL-3B-Instruct", use_liger_kernel=False)  # skip Liger
Initialization
- classmethod from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)[source]#
Load a pretrained image-text-to-text model and (optionally) patch it with custom kernels.
Parameters#
pretrained_model_name_or_path : str or os.PathLike
    Repository ID or local path accepted by `transformers.AutoModelForImageTextToText.from_pretrained`.
*model_args
    Positional arguments forwarded verbatim to the superclass.
use_liger_kernel : bool, default True
    Whether to attempt patching the loaded model with Liger kernels.
**kwargs
    Keyword arguments forwarded verbatim to the superclass.
Returns:#
transformers.PreTrainedModel
    The instantiated model, possibly Liger-patched.
Warnings:#
Emits a `logging.warning` if `use_liger_kernel=True` but the Liger package is not available.
Retries#
If patching raises an exception, the method deletes the partially constructed model and recursively reloads it once with `use_liger_kernel=False`.
- classmethod from_config(config, **kwargs)[source]#
Instantiate a model from a config object and (optionally) patch it with custom kernels.
Parameters#
config : transformers.PretrainedConfig
    Configuration used to build the model.
use_liger_kernel : bool, default True
    Whether to attempt patching the instantiated model with Liger kernels.
**kwargs
    Additional keyword arguments forwarded to the superclass.
Returns:#
transformers.PreTrainedModel
    The instantiated model, possibly Liger-patched.
See Also:#
NeMoAutoModelForImageTextToText.from_pretrained : Same logic for checkpoint loading.