nemo_automodel._transformers.auto_model#
NeMo Auto Model classes.
Drop-in replacements for transformers.AutoModelFor* that add custom-kernel
patching, distributed infrastructure, PEFT, quantization, and checkpointing.
Heavy-lifting helpers live in sibling modules:
- `kernel_patches` – SDPA / Liger kernel patching
- `model_init` – model class resolution and instantiation
- `infrastructure` – MeshContext, sharding, PEFT/quant application
Module Contents#
Classes#
- `_BaseNeMoAutoModelClass` – Drop-in replacement for `_BaseAutoModelClass` that includes custom kernels.
- `NeMoAutoModelForCausalLM` – Drop-in replacement for `transformers.AutoModelForCausalLM`.
- `NeMoAutoModelForImageTextToText` – Drop-in replacement for `transformers.AutoModelForImageTextToText`.
- `NeMoAutoModelForMultimodalLM` – Drop-in replacement for `transformers.AutoModelForMultimodalLM`.
- `NeMoAutoModelForSequenceClassification` – Drop-in replacement for `transformers.AutoModelForSequenceClassification`.
- `NeMoAutoModelForTextToWaveform` – Drop-in replacement for `transformers.AutoModelForTextToWaveform`.
- `_NeMoAutoModelForRetrievalBase` – Private shared base for encoder auto-models.
- `NeMoAutoModelBiEncoder` – NeMo AutoModel for bi-encoder embedding tasks with full infrastructure support.
- `NeMoAutoModelCrossEncoder` – NeMo AutoModel for cross-encoder scoring tasks with full infrastructure support.
Functions#
- `_patch_remote_code_compat` – Patch `_finalize_model_loading` for remote-code models written against older transformers.
- `_maybe_dequantize_fp8_for_peft` – Set `dequantize=True` on FP8 quantization configs when PEFT is requested.
Data#
API#
- nemo_automodel._transformers.auto_model.logger#
'getLogger(...)'
- nemo_automodel._transformers.auto_model._MAX_BUILD_RETRIES#
5
- nemo_automodel._transformers.auto_model._remote_code_compat_applied#
False
- nemo_automodel._transformers.auto_model._patch_remote_code_compat()#
Patch `_finalize_model_loading` for remote-code models written against older transformers.

Remote-code models (`trust_remote_code=True`) may be incompatible with the installed transformers in several ways:
- Missing `all_tied_weights_keys` – set in `post_init()`, which the model may never call.
- Overridden `tie_weights()` with an old signature that doesn't accept the `missing_keys` kwarg added in newer transformers.

This one-time patch wraps `_finalize_model_loading` to fix these issues on the fly. For models that are already compatible the guards are no-ops.
- nemo_automodel._transformers.auto_model._maybe_dequantize_fp8_for_peft(
- hf_native_quant_cfg,
- peft_config,
- pretrained_path,
Set `dequantize=True` on FP8 quantization configs when PEFT is requested.

Returns True if the config was mutated, False otherwise.
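The decision logic can be sketched as below. The config shape is hypothetical; the real helper inspects Hugging Face quantization configs loaded from the checkpoint.

```python
# Illustrative sketch: mutate the quantization config only when it is FP8
# and a PEFT config was supplied. QuantCfg is a hypothetical stand-in.
from dataclasses import dataclass


@dataclass
class QuantCfg:
    quant_method: str
    dequantize: bool = False


def maybe_dequantize_fp8_for_peft(quant_cfg, peft_config):
    """Return True if the config was mutated, False otherwise."""
    if quant_cfg is None or peft_config is None:
        return False
    if quant_cfg.quant_method != "fp8":
        return False
    if quant_cfg.dequantize:
        return False  # already set; nothing to change
    quant_cfg.dequantize = True
    return True
```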
- class nemo_automodel._transformers.auto_model._BaseNeMoAutoModelClass#
Bases: `transformers.models.auto.auto_factory._BaseAutoModelClass`

Drop-in replacement for `_BaseAutoModelClass` that includes custom kernels.

The class only overrides `from_pretrained` and `from_config` to add the optional `use_liger_kernel` flag. If the flag is `True` (default) and the Liger kernel is available, the model's attention layers are monkey-patched in place. If patching fails for any reason, the call is retried once with `use_liger_kernel=False` so that users still obtain a functional model.

TODO(@akoumpa): extend this beyond liger_kernel.
Notes:#
No changes are made to the model’s public API; forward signatures, generation utilities, and weight shapes remain identical.
Only decoder-style (causal) architectures are currently supported by the Liger patch. Unsupported models will silently fall back.
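The patch-then-fallback behaviour described above can be sketched as follows. The builder and patch hook are hypothetical stand-ins for the real model construction and Liger patching.

```python
# Sketch of retry-with-fallback: try the optional kernel patch, and on any
# failure rebuild with the flag disabled so the caller still gets a working
# model. Names are illustrative, not the actual implementation.
def build_model(name, use_liger_kernel=True, _patch=None):
    model = {"name": name, "patched": False}
    if use_liger_kernel:
        try:
            (_patch or (lambda m: None))(model)  # apply the kernel patch
            model["patched"] = True
        except Exception:
            # Retry once with the kernel patch disabled.
            return build_model(name, use_liger_kernel=False)
    return model
```

This mirrors the documented contract: a failed patch never surfaces to the caller; it only costs one rebuild without the optimization.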
- classmethod _from_pretrained_parent_class(*args, **kwargs)#
- classmethod _from_config_parent_class(*args, **kwargs)#
- classmethod _build_model(
- pretrained_model_name_or_path_or_config,
- *model_args,
- is_hf_model,
- use_liger_kernel,
- use_sdpa_patching,
- sdpa_method,
- torch_dtype,
- attn_implementation,
- quantization_config,
- force_hf,
- model_wrapper,
- autopipeline,
- parallelize_fn,
- qat_quantizer,
- mesh,
- loss_fn,
- peft_config,
- fp8_config,
- compile_config,
- load_base_model,
- _retry_depth=0,
- **kwargs,
Shared model-building logic for `from_pretrained` and `from_config`.

Handles pre-load overrides, meta-device initialization, model init with attention-fallback retry, kernel patching (Liger, SDPA) with retry, and full infrastructure application (sharding, PEFT, quantization, checkpointing).

All caller-specific setup (config resolution, infrastructure instantiation, `is_hf_model` determination) is done by `from_pretrained` / `from_config` before delegating here.
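The `_retry_depth` parameter together with the module-level `_MAX_BUILD_RETRIES` bounds the fallback retries. A minimal sketch of that pattern, with hypothetical failure injection:

```python
# Sketch of the bounded retry in _build_model: each fallback re-entry bumps
# _retry_depth, and building aborts once _MAX_BUILD_RETRIES is reached so a
# persistently failing configuration cannot recurse forever.
_MAX_BUILD_RETRIES = 5  # mirrors the module-level constant documented above


def build(config, _retry_depth=0, _fail_times=0):
    if _retry_depth >= _MAX_BUILD_RETRIES:
        raise RuntimeError("exceeded max build retries")
    if _fail_times > 0:
        # A build step failed; retry with adjusted settings, one level deeper.
        return build(config, _retry_depth + 1, _fail_times - 1)
    return {"config": config, "retries": _retry_depth}
```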
- classmethod from_pretrained(
- pretrained_model_name_or_path,
- *model_args,
- use_liger_kernel: bool = True,
- use_sdpa_patching: bool = True,
- sdpa_method: Optional[List[Union[torch.nn.attention.SDPBackend, str]]] = None,
- torch_dtype='auto',
- attn_implementation: str = DEFAULT_ATTN_IMPLEMENTATION,
- quantization_config=None,
- force_hf: bool = False,
- device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- tp_plan: Optional[dict] = None,
- distributed_config: Optional[nemo_automodel.components.distributed.config.DistributedConfig] = None,
- pipeline_config: Optional[nemo_automodel.components.distributed.pipelining.config.PipelineConfig] = None,
- qat_config: Optional[nemo_automodel.components.quantization.qat.QATConfig] = None,
- moe_config: Optional[nemo_automodel.components.moe.config.MoEParallelizerConfig] = None,
- activation_checkpointing: bool = False,
- peft_config: Optional[dict] = None,
- fp8_config: Optional[nemo_automodel.components.quantization.fp8.FP8Config] = None,
- compile_config: Optional[nemo_automodel.components.utils.compile_utils.CompileConfig] = None,
- **kwargs,
Instantiate and (optionally) patch a causal-language model.
This is a light wrapper around `transformers.AutoModelForCausalLM.from_pretrained` that can automatically apply Liger and/or SDPA (scaled-dot-product attention) kernel optimizations, as well as PEFT, quantization, and distributed parallelism.
- Parameters:
pretrained_model_name_or_path (str | os.PathLike) – Hugging Face hub repo ID or local path accepted by `AutoModelForCausalLM.from_pretrained`.
*model_args – Positional arguments forwarded verbatim to `AutoModelForCausalLM.from_pretrained`.
use_liger_kernel (bool, default=True) – If `True`, try to patch the model with Liger kernels for faster inference/training.
use_sdpa_patching (bool, default=True) – If `True`, patch the model with SDPA-based attention optimizations.
sdpa_method (list[SDPBackend | str] | None, optional) – Explicit list of SDPA back-ends to consider when `use_sdpa_patching=True`. Accepts both `SDPBackend` enum values and string names (e.g. `["flash_attention", "efficient_attention"]`). When `None`, auto-selects based on CP and activation checkpointing.
torch_dtype (str | torch.dtype | Literal["auto"], default="auto") – Data type passed to the underlying `from_pretrained` call.
attn_implementation (str, optional) – Which attention implementation to use (e.g., `"flash_attention_2"`, `"eager"`). Only applied when the base model supports this kwarg. Defaults to `"flash_attention_2"`; if flash attention is not available, falls back to `"sdpa"`.
quantization_config (optional) – BitsAndBytesConfig object that specifies all quantization settings. If provided, quantization is applied to the model.
force_hf (bool, default=False) – If `True`, force the Hugging Face model implementation. If `False`, the model is loaded with the custom model implementation when one is available.
moe_mesh (DeviceMesh | None, optional) – FSDP2-only. Device mesh for expert parallelism. ep_size is inferred from this. Default: None.
tp_plan (dict | None, optional) – Custom tensor parallel plan. If provided, overrides the tp_plan on distributed_config. Default: None.
distributed_config (FSDP2Config | MegatronFSDPConfig | DDPConfig | None, optional) – Strategy-specific distributed training configuration. Default: None.
pipeline_config (PipelineConfig | None, optional) – Pipeline parallelism configuration including loss_fn. Default: None.
qat_config (QATConfig | None, optional) – Quantization-Aware Training configuration. Default: None.
moe_config (MoEParallelizerConfig | None, optional) – MoE parallelizer configuration. Default: None.
activation_checkpointing (bool, default=False) – Enable activation checkpointing for transformer blocks to reduce memory usage. Default: False.
peft_config (dict | None, optional) – PEFT/LoRA configuration dictionary. If provided, LoRA adapters will be applied to the model. Default: None.
fp8_config (FP8Config | None, optional) – FP8 quantization configuration. If provided, FP8 quantization will be applied. Default: None.
compile_config (CompileConfig | None, optional) – Configuration for torch.compile. If provided, the model will be compiled. Default: None.
**kwargs –
Additional keyword arguments. Notable ones include:
has_packed_sequence (bool): Whether using packed sequences. Default: False.
cache_dir (str): Cache directory for model weights.
- Returns:
The loaded (and possibly patched) model instance with all infrastructure applied.
- Return type:
transformers.PreTrainedModel
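Since `sdpa_method` accepts a mix of backend enum values and string names, the normalization can be sketched as below. A stand-in enum replaces `torch.nn.attention.SDPBackend` so the sketch stays self-contained; the real code maps onto the torch enum.

```python
# Sketch of normalizing sdpa_method entries, which may mix backend enums
# and string names such as "flash_attention". SDPBackend here is a
# hypothetical stand-in for torch.nn.attention.SDPBackend.
from enum import Enum


class SDPBackend(Enum):
    FLASH_ATTENTION = 1
    EFFICIENT_ATTENTION = 2
    MATH = 3


def normalize_sdpa_methods(methods):
    out = []
    for m in methods:
        if isinstance(m, SDPBackend):
            out.append(m)
        else:
            # "flash_attention" -> SDPBackend.FLASH_ATTENTION
            out.append(SDPBackend[m.upper()])
    return out
```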
- classmethod from_config(
- config,
- *model_args,
- use_liger_kernel: bool = True,
- use_sdpa_patching: bool = True,
- sdpa_method: Optional[List[Union[torch.nn.attention.SDPBackend, str]]] = None,
- torch_dtype: Union[str, torch.dtype] = 'auto',
- attn_implementation: str = DEFAULT_ATTN_IMPLEMENTATION,
- quantization_config=None,
- force_hf: bool = False,
- device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- tp_plan: Optional[dict] = None,
- distributed_config: Optional[nemo_automodel.components.distributed.config.DistributedConfig] = None,
- pipeline_config: Optional[nemo_automodel.components.distributed.pipelining.config.PipelineConfig] = None,
- qat_config: Optional[nemo_automodel.components.quantization.qat.QATConfig] = None,
- moe_config: Optional[nemo_automodel.components.moe.config.MoEParallelizerConfig] = None,
- activation_checkpointing: bool = False,
- peft_config: Optional[dict] = None,
- fp8_config: Optional[nemo_automodel.components.quantization.fp8.FP8Config] = None,
- compile_config: Optional[nemo_automodel.components.utils.compile_utils.CompileConfig] = None,
- **kwargs,
Instantiate a model from a `transformers.PretrainedConfig` (no pretrained weights). Accepts the same infrastructure arguments as `from_pretrained`.

See `from_pretrained` for full parameter documentation.
- Parameters:
config (transformers.PretrainedConfig | str) – The configuration object used to build the model. If config is passed as a string (e.g., a model id or local checkpoint path), a config is created internally via AutoConfig.
torch_dtype (str | torch.dtype, default="auto") – Data type for model parameters. If "auto", defaults to `torch.bfloat16`.
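The two `from_config`-specific behaviours above (string-or-config input, and the `"auto"` dtype default) can be sketched as follows. The classes and dtype strings are stand-ins so the sketch needs neither transformers nor torch.

```python
# Sketch of from_config's input handling: accept either a ready config
# object or a string model id, building a config internally for strings
# (the real code uses transformers.AutoConfig). FakeConfig is hypothetical.
class FakeConfig:
    def __init__(self, model_id):
        self.model_id = model_id


def resolve_config(config):
    if isinstance(config, str):
        # model id / local path -> build a config object
        return FakeConfig(config)
    return config  # already a config object


def resolve_dtype(torch_dtype):
    # "auto" maps to bfloat16 when building from a bare config
    return "torch.bfloat16" if torch_dtype == "auto" else str(torch_dtype)
```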
- class nemo_automodel._transformers.auto_model.NeMoAutoModelForCausalLM#
Bases: `nemo_automodel._transformers.auto_model._BaseNeMoAutoModelClass`, `transformers.AutoModelForCausalLM`

Drop-in replacement for `transformers.AutoModelForCausalLM` that includes custom kernels.

The class only overrides `from_pretrained` and `from_config` to add the optional `use_liger_kernel` flag. If the flag is `True` (default) and the Liger kernel is available, the model's attention layers are monkey-patched in place. If patching fails for any reason, the call is retried once with `use_liger_kernel=False` so that users still obtain a functional model.

TODO(@akoumpa): extend this beyond liger_kernel.
Notes:#
No changes are made to the model’s public API; forward signatures, generation utilities, and weight shapes remain identical.
Only decoder-style (causal) architectures are currently supported by the Liger patch. Unsupported models will silently fall back.
Examples:#
model = NeMoAutoModelForCausalLM.from_pretrained("gpt2")  # try Liger
model = NeMoAutoModelForCausalLM.from_pretrained("gpt2", use_liger_kernel=False)  # skip Liger
- class nemo_automodel._transformers.auto_model.NeMoAutoModelForImageTextToText#
Bases: `nemo_automodel._transformers.auto_model._BaseNeMoAutoModelClass`, `transformers.AutoModelForImageTextToText`

Drop-in replacement for `transformers.AutoModelForImageTextToText` with custom kernels.

The class only overrides `from_pretrained` and `from_config` to add the optional `use_liger_kernel` flag. If the flag is `True` (default) and the Liger kernel is available, the model's attention layers are monkey-patched in place. If patching fails for any reason, the call is retried once with `use_liger_kernel=False` so that users still obtain a functional model.

@akoumpa: currently only supporting liger_kernel for demonstration purposes.
Notes:#
No changes are made to the model’s public API; forward signatures, generation utilities, and weight shapes remain identical.
Only decoder-style (causal) architectures are currently supported by the Liger patch. Unsupported models will silently fall back.
Examples:#
model = NeMoAutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")  # try Liger
model = NeMoAutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", use_liger_kernel=False)  # skip Liger
- class nemo_automodel._transformers.auto_model.NeMoAutoModelForMultimodalLM#
Bases: `nemo_automodel._transformers.auto_model._BaseNeMoAutoModelClass`, `transformers.AutoModelForMultimodalLM`

Drop-in replacement for `transformers.AutoModelForMultimodalLM` with custom kernels.
- class nemo_automodel._transformers.auto_model.NeMoAutoModelForSequenceClassification#
Bases: `nemo_automodel._transformers.auto_model._BaseNeMoAutoModelClass`, `transformers.AutoModelForSequenceClassification`

Drop-in replacement for `transformers.AutoModelForSequenceClassification` with custom kernels.

The class only overrides `from_pretrained` and `from_config` to add the optional `use_liger_kernel` flag. If the flag is `True` (default) and the Liger kernel is available, the model's attention layers are monkey-patched in place. If patching fails for any reason, the call is retried once with `use_liger_kernel=False` so that users still obtain a functional model.

@akoumpa: currently only supporting liger_kernel for demonstration purposes.
Notes:#
No changes are made to the model’s public API; forward signatures, generation utilities, and weight shapes remain identical.
Only decoder-style (causal) architectures are currently supported by the Liger patch. Unsupported models will silently fall back.
Examples:#
model = NeMoAutoModelForSequenceClassification.from_pretrained("bert-base-uncased")  # try Liger
model = NeMoAutoModelForSequenceClassification.from_pretrained("bert-base-uncased", use_liger_kernel=False)  # skip Liger
- class nemo_automodel._transformers.auto_model.NeMoAutoModelForTextToWaveform#
Bases: `nemo_automodel._transformers.auto_model._BaseNeMoAutoModelClass`, `transformers.AutoModelForTextToWaveform`

Drop-in replacement for `transformers.AutoModelForTextToWaveform` with custom kernels.

The class only overrides `from_pretrained` and `from_config` to add the optional `use_liger_kernel` flag. If the flag is `True` (default) and the Liger kernel is available, the model's attention layers are monkey-patched in place. If patching fails for any reason, the call is retried once with `use_liger_kernel=False` so that users still obtain a functional model.

@akoumpa: currently only supporting liger_kernel for demonstration purposes.
Notes:#
No changes are made to the model’s public API; forward signatures, generation utilities, and weight shapes remain identical.
Only decoder-style (causal) architectures are currently supported by the Liger patch. Unsupported models will silently fall back.
Examples:#
model = NeMoAutoModelForTextToWaveform.from_pretrained("facebook/musicgen-small")  # try Liger
model = NeMoAutoModelForTextToWaveform.from_pretrained("facebook/musicgen-small", use_liger_kernel=False)  # skip Liger
- class nemo_automodel._transformers.auto_model._NeMoAutoModelForRetrievalBase#
Private shared base for encoder auto-models.
Subclasses set `_ENCODER_CLS_NAME` to select the concrete encoder class from `nemo_automodel._transformers.retrieval`.
- _ENCODER_CLS_NAME: Optional[str]#
None
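The select-by-class-name pattern can be sketched as below; the real code resolves the name against `nemo_automodel._transformers.retrieval`, while this stand-alone sketch uses a hypothetical registry.

```python
# Sketch of resolving a concrete encoder class from a string name set by
# the subclass. The registry and class names here are illustrative.
_REGISTRY = {}


def register(cls):
    _REGISTRY[cls.__name__] = cls
    return cls


@register
class BiEncoderModel:
    pass


class RetrievalBase:
    _ENCODER_CLS_NAME = None  # subclasses override with a class name

    @classmethod
    def encoder_cls(cls):
        if cls._ENCODER_CLS_NAME is None:
            raise NotImplementedError("subclass must set _ENCODER_CLS_NAME")
        return _REGISTRY[cls._ENCODER_CLS_NAME]


class AutoBiEncoder(RetrievalBase):
    _ENCODER_CLS_NAME = "BiEncoderModel"
```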
- classmethod from_pretrained(
- pretrained_model_name_or_path: str,
- attn_implementation: str = DEFAULT_ATTN_IMPLEMENTATION,
- use_liger_kernel: bool = True,
- use_sdpa_patching: bool = True,
- sdpa_method: Optional[List[torch.nn.attention.SDPBackend]] = None,
- torch_dtype='auto',
- device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- tp_plan: Optional[dict] = None,
- distributed_config: Optional[nemo_automodel.components.distributed.config.DistributedConfig] = None,
- moe_config: Optional[nemo_automodel.components.moe.config.MoEParallelizerConfig] = None,
- compile_config: Optional[nemo_automodel.components.utils.compile_utils.CompileConfig] = None,
- peft_config: Optional[dict] = None,
- **kwargs,
Load an encoder model with infrastructure (FSDP, PEFT, kernel patching, etc.).
This method builds an encoder via the subclass's `_ENCODER_CLS_NAME`, applies kernel patching, and then applies all infrastructure (FSDP, checkpointing, etc.) through `apply_model_infrastructure()`.
- Parameters:
pretrained_model_name_or_path – Path to pretrained model or model identifier.
attn_implementation – Attention implementation to use (e.g., `"flash_attention_2"`, `"sdpa"`, `"eager"`). Defaults to `DEFAULT_ATTN_IMPLEMENTATION` (`"flash_attention_2"` when flash-attn is installed, otherwise `"sdpa"`).
use_liger_kernel – Whether to apply Liger kernel optimizations.
use_sdpa_patching – Whether to apply SDPA patching.
sdpa_method – SDPA backend methods to use.
torch_dtype – Data type passed to the underlying model initialization.
device_mesh – Pre-created device mesh for distributed training.
moe_mesh – Device mesh for expert parallelism (FSDP2 only).
tp_plan – Custom tensor parallel plan; overrides distributed_config.tp_plan.
distributed_config – Strategy-specific distributed training configuration.
moe_config – MoE parallelizer configuration.
compile_config – Configuration for torch.compile.
peft_config – PEFT/LoRA configuration dictionary.
**kwargs – Additional arguments passed to the encoder's `build()` method.
- Returns:
Encoder model instance with loaded weights and all infrastructure applied.
Notes:#
If kernel patching fails, the method retries with adjusted parameters.
- class nemo_automodel._transformers.auto_model.NeMoAutoModelBiEncoder#
Bases: `nemo_automodel._transformers.auto_model._NeMoAutoModelForRetrievalBase`

NeMo AutoModel for bi-encoder embedding tasks with full infrastructure support.

Wraps `BiEncoderModel.build()` with kernel patching, PEFT, FSDP, and other distributed infrastructure via `apply_model_infrastructure()`.
Examples:#
model = NeMoAutoModelBiEncoder.from_pretrained("meta-llama/Llama-3.2-1B")
model = NeMoAutoModelBiEncoder.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    pooling="cls",
    l2_normalize=False,
    distributed_config=FSDP2Config(),
)
- _ENCODER_CLS_NAME#
'BiEncoderModel'
- classmethod from_pretrained(
- pretrained_model_name_or_path: str,
- pooling: str = 'avg',
- l2_normalize: bool = True,
- **kwargs,
Load a bi-encoder model with infrastructure.
Accepts all arguments from `_NeMoAutoModelForRetrievalBase.from_pretrained` plus the bi-encoder-specific parameters below.
- Parameters:
pretrained_model_name_or_path – Path to pretrained model or model identifier.
pooling – Pooling strategy (`'avg'`, `'cls'`, `'last'`, etc.).
l2_normalize – Whether to L2-normalize embeddings.
**kwargs – Forwarded to `_NeMoAutoModelForRetrievalBase.from_pretrained`.
- Returns:
BiEncoderModel instance with loaded weights and all infrastructure applied.
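The pooling strategies named above can be sketched over per-token embeddings as follows, using plain lists instead of tensors. The real bi-encoder also honours the attention mask when averaging; that detail is omitted here.

```python
# Sketch of the 'avg', 'cls', and 'last' pooling strategies plus optional
# L2 normalization. Pure-Python stand-in for the tensor implementation.
import math


def pool(token_embs, strategy="avg"):
    if strategy == "cls":
        return token_embs[0]   # first token's embedding
    if strategy == "last":
        return token_embs[-1]  # last token's embedding
    # "avg": element-wise mean over tokens
    n = len(token_embs)
    return [sum(col) / n for col in zip(*token_embs)]


def l2_normalize(vec):
    norm = math.sqrt(sum(x * x for x in vec)) or 1.0
    return [x / norm for x in vec]
```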
- class nemo_automodel._transformers.auto_model.NeMoAutoModelCrossEncoder#
Bases: `nemo_automodel._transformers.auto_model._NeMoAutoModelForRetrievalBase`

NeMo AutoModel for cross-encoder scoring tasks with full infrastructure support.

Wraps `CrossEncoderModel.build()` with kernel patching, PEFT, FSDP, and other distributed infrastructure via `apply_model_infrastructure()`.
Examples:#
model = NeMoAutoModelCrossEncoder.from_pretrained("meta-llama/Llama-3.2-1B")
model = NeMoAutoModelCrossEncoder.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    distributed_config=FSDP2Config(),
)
- _ENCODER_CLS_NAME#
'CrossEncoderModel'