nemo_automodel._transformers.auto_model#
NeMo Auto Model classes.
Drop-in replacements for transformers.AutoModelFor* that add custom-kernel
patching, distributed infrastructure, PEFT, quantization, and checkpointing.
Heavy-lifting helpers live in sibling modules:
kernel_patches – SDPA / Liger kernel patching
model_init – model class resolution and instantiation
infrastructure – MeshContext, sharding, PEFT/quant application
Module Contents#
Classes#
_BaseNeMoAutoModelClass – Drop-in replacement for _BaseAutoModelClass that includes custom kernels.
NeMoAutoModelForCausalLM – Drop-in replacement for transformers.AutoModelForCausalLM that includes custom kernels.
NeMoAutoModelForImageTextToText – Drop-in replacement for transformers.AutoModelForImageTextToText with custom kernels.
NeMoAutoModelForMultimodalLM – Drop-in replacement for transformers.AutoModelForMultimodalLM with custom kernels.
NeMoAutoModelForSequenceClassification – Drop-in replacement for transformers.AutoModelForSequenceClassification with custom kernels.
NeMoAutoModelForTextToWaveform – Drop-in replacement for transformers.AutoModelForTextToWaveform with custom kernels.
NeMoAutoModelBiencoder – NeMo AutoModel for biencoder/embedding tasks with full infrastructure support.
Data#
API#
- nemo_automodel._transformers.auto_model.logger#
'getLogger(…)'
- class nemo_automodel._transformers.auto_model._BaseNeMoAutoModelClass#
Bases: transformers.models.auto.auto_factory._BaseAutoModelClass

Drop-in replacement for _BaseAutoModelClass that includes custom kernels.

The class only overrides from_pretrained and from_config to add the optional use_liger_kernel flag. If the flag is True (default) and the Liger kernel is available, the model's attention layers are monkey-patched in place. If patching fails for any reason, the call is retried once with use_liger_kernel=False so that users still obtain a functional model.

TODO(@akoumpa): extend this beyond liger_kernel.
Notes:#
No changes are made to the model’s public API; forward signatures, generation utilities, and weight shapes remain identical.
Only decoder-style (causal) architectures are currently supported by the Liger patch. Unsupported models will silently fall back.
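The retry-on-failure behavior described above can be sketched as a small wrapper. This is an illustrative assumption, not the library's code: `load_with_kernel_fallback` and `load_model` are hypothetical names, and the real logic lives inside `_build_model`.

```python
def load_with_kernel_fallback(load_model, *args, use_liger_kernel=True, **kwargs):
    """Try loading with Liger patching; on any failure, retry once without it.

    `load_model` stands in for the underlying from_pretrained/from_config call.
    """
    if not use_liger_kernel:
        return load_model(*args, use_liger_kernel=False, **kwargs)
    try:
        return load_model(*args, use_liger_kernel=True, **kwargs)
    except Exception:
        # Patching failed: fall back so the caller still gets a functional model.
        return load_model(*args, use_liger_kernel=False, **kwargs)
```

The key design point is that the fallback is silent and total: any exception during patched loading triggers exactly one unpatched retry.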
- classmethod _from_pretrained_parent_class(*args, **kwargs)#
- classmethod _from_config_parent_class(*args, **kwargs)#
- classmethod _build_model(
- pretrained_model_name_or_path_or_config,
- *model_args,
- is_hf_model,
- use_liger_kernel,
- use_sdpa_patching,
- sdpa_method,
- torch_dtype,
- attn_implementation,
- quantization_config,
- force_hf,
- model_wrapper,
- autopipeline,
- parallelize_fn,
- qat_quantizer,
- mesh,
- loss_fn,
- peft_config,
- fp8_config,
- compile_config,
- load_base_model,
- **kwargs,
)#
Shared model building logic for from_pretrained and from_config.

Handles pre-load overrides, meta-device initialization, model init with attention-fallback retry, kernel patching (Liger, SDPA) with retry, and full infrastructure application (sharding, PEFT, quantization, checkpointing).

All caller-specific setup (config resolution, infrastructure instantiation, is_hf_model determination) is done by from_pretrained / from_config before delegating here.
- classmethod from_pretrained(
- pretrained_model_name_or_path,
- *model_args,
- use_liger_kernel: bool = True,
- use_sdpa_patching: bool = True,
- sdpa_method: Optional[List[torch.nn.attention.SDPBackend]] = None,
- torch_dtype='auto',
- attn_implementation: str = DEFAULT_ATTN_IMPLEMENTATION,
- quantization_config=None,
- force_hf: bool = False,
- device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- tp_plan: Optional[dict] = None,
- distributed_config: Optional[nemo_automodel.components.distributed.config.DistributedConfig] = None,
- pipeline_config: Optional[nemo_automodel.components.distributed.pipelining.config.PipelineConfig] = None,
- qat_config: Optional[nemo_automodel.components.quantization.qat.QATConfig] = None,
- moe_config: Optional[nemo_automodel.components.moe.config.MoEParallelizerConfig] = None,
- activation_checkpointing: bool = False,
- peft_config: Optional[dict] = None,
- fp8_config: Optional[nemo_automodel.components.quantization.fp8.FP8Config] = None,
- compile_config: Optional[nemo_automodel.components.utils.compile_utils.CompileConfig] = None,
- **kwargs,
)#
Instantiate and (optionally) patch a causal-language model.
This is a light wrapper around transformers.AutoModelForCausalLM.from_pretrained that can automatically apply Liger and/or SDPA (scaled dot-product attention) kernel optimizations, as well as PEFT, quantization, and distributed parallelism.

- Parameters:
pretrained_model_name_or_path (str | os.PathLike) – Hugging Face hub repo ID or local path accepted by AutoModelForCausalLM.from_pretrained.
*model_args – Positional arguments forwarded verbatim to AutoModelForCausalLM.from_pretrained.
use_liger_kernel (bool, default=True) – If True, try to patch the model with Liger kernels for faster inference/training.
use_sdpa_patching (bool, default=True) – If True, patch the model with SDPA-based attention optimizations.
sdpa_method (list[SDPBackend] | None, optional) – Explicit list of SDPA back-ends to consider when use_sdpa_patching=True.
torch_dtype (str | torch.dtype | Literal["auto"], default="auto") – Data type passed to the underlying from_pretrained call.
attn_implementation (str, optional) – Which attention implementation to use (e.g., "flash_attention_2", "eager"). Only applied when the base model supports this kwarg. Defaults to "flash_attention_2"; if flash attention is not available, falls back to "sdpa".
quantization_config (optional) – BitsAndBytesConfig configuration object that specifies all quantization settings. If provided, quantization is applied to the model.
force_hf (bool, default=False) – If True, force the use of the HF model implementation. If False, the model is loaded using the custom model implementation when available.
device_mesh (DeviceMesh | None, optional) – Pre-created device mesh for distributed training. Parallelism sizes (tp, pp, cp, ep) are inferred from this. Default: None.
moe_mesh (DeviceMesh | None, optional) – FSDP2-only. Device mesh for expert parallelism. ep_size is inferred from this. Default: None.
tp_plan (dict | None, optional) – Custom tensor parallel plan. If provided, overrides the tp_plan on distributed_config. Default: None.
distributed_config (FSDP2Config | MegatronFSDPConfig | DDPConfig | None, optional) – Strategy-specific distributed training configuration. Default: None.
pipeline_config (PipelineConfig | None, optional) – Pipeline parallelism configuration including loss_fn. Default: None.
qat_config (QATConfig | None, optional) – Quantization-Aware Training configuration. Default: None.
moe_config (MoEParallelizerConfig | None, optional) – MoE parallelizer configuration. Default: None.
activation_checkpointing (bool, default=False) – Enable activation checkpointing for transformer blocks to reduce memory usage. Default: False.
peft_config (dict | None, optional) – PEFT/LoRA configuration dictionary. If provided, LoRA adapters will be applied to the model. Default: None.
fp8_config (FP8Config | None, optional) – FP8 quantization configuration. If provided, FP8 quantization will be applied. Default: None.
compile_config (CompileConfig | None, optional) – Configuration for torch.compile. If provided, the model will be compiled. Default: None.
**kwargs –
Additional keyword arguments. Notable ones include:
has_packed_sequence (bool): Whether using packed sequences. Default: False.
cache_dir (str): Cache directory for model weights.
- Returns:
The loaded (and possibly patched) model instance with all infrastructure applied.
- Return type:
transformers.PreTrainedModel
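As a usage sketch of the parameters above (the model ID and option values are placeholders chosen for illustration, not library defaults beyond those documented):

```python
from nemo_automodel._transformers.auto_model import NeMoAutoModelForCausalLM

# Placeholder model ID; any HF hub repo ID or local path works here.
model = NeMoAutoModelForCausalLM.from_pretrained(
    "gpt2",
    torch_dtype="auto",
    use_liger_kernel=True,          # retried once without Liger if patching fails
    use_sdpa_patching=True,
    activation_checkpointing=True,  # trade compute for memory during training
)
```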
- classmethod from_config(
- config,
- *model_args,
- use_liger_kernel: bool = True,
- use_sdpa_patching: bool = True,
- sdpa_method: Optional[List[torch.nn.attention.SDPBackend]] = None,
- torch_dtype: Union[str, torch.dtype] = 'auto',
- attn_implementation: str = DEFAULT_ATTN_IMPLEMENTATION,
- quantization_config=None,
- force_hf: bool = False,
- device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- tp_plan: Optional[dict] = None,
- distributed_config: Optional[nemo_automodel.components.distributed.config.DistributedConfig] = None,
- pipeline_config: Optional[nemo_automodel.components.distributed.pipelining.config.PipelineConfig] = None,
- qat_config: Optional[nemo_automodel.components.quantization.qat.QATConfig] = None,
- moe_config: Optional[nemo_automodel.components.moe.config.MoEParallelizerConfig] = None,
- activation_checkpointing: bool = False,
- peft_config: Optional[dict] = None,
- fp8_config: Optional[nemo_automodel.components.quantization.fp8.FP8Config] = None,
- compile_config: Optional[nemo_automodel.components.utils.compile_utils.CompileConfig] = None,
- **kwargs,
)#
Instantiate a model from a transformers.PretrainedConfig (no pretrained weights). Accepts the same infrastructure arguments as from_pretrained.

See from_pretrained for full parameter documentation.

- Parameters:
config (transformers.PretrainedConfig | str) – The configuration object used to build the model. If config is passed as a string (e.g., a model ID or local checkpoint path), a config is created internally using AutoConfig.
torch_dtype (str | torch.dtype, default="auto") – Data type for model parameters. If "auto", defaults to torch.bfloat16.
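The "auto" handling described above can be sketched as a small resolver. `resolve_dtype` is a hypothetical helper written for illustration; it only mirrors the documented rule that from_config maps "auto" to torch.bfloat16 and otherwise accepts a dtype or its name:

```python
import torch

def resolve_dtype(torch_dtype="auto"):
    """Map the documented torch_dtype argument to a concrete torch.dtype."""
    if torch_dtype == "auto":
        return torch.bfloat16          # documented default for from_config
    if isinstance(torch_dtype, str):
        return getattr(torch, torch_dtype)  # e.g. "float16" -> torch.float16
    return torch_dtype                 # already a torch.dtype
```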
- class nemo_automodel._transformers.auto_model.NeMoAutoModelForCausalLM#
Bases: nemo_automodel._transformers.auto_model._BaseNeMoAutoModelClass, transformers.AutoModelForCausalLM

Drop-in replacement for transformers.AutoModelForCausalLM that includes custom kernels.

The class only overrides from_pretrained and from_config to add the optional use_liger_kernel flag. If the flag is True (default) and the Liger kernel is available, the model's attention layers are monkey-patched in place. If patching fails for any reason, the call is retried once with use_liger_kernel=False so that users still obtain a functional model.

TODO(@akoumpa): extend this beyond liger_kernel.
Notes:#
No changes are made to the model’s public API; forward signatures, generation utilities, and weight shapes remain identical.
Only decoder-style (causal) architectures are currently supported by the Liger patch. Unsupported models will silently fall back.
Examples:#
model = NeMoAutoModelForCausalLM.from_pretrained("gpt2")  # try Liger
model = NeMoAutoModelForCausalLM.from_pretrained(
    "gpt2", use_liger_kernel=False)  # skip Liger
- class nemo_automodel._transformers.auto_model.NeMoAutoModelForImageTextToText#
Bases: nemo_automodel._transformers.auto_model._BaseNeMoAutoModelClass, transformers.AutoModelForImageTextToText

Drop-in replacement for transformers.AutoModelForImageTextToText with custom kernels.

The class only overrides from_pretrained and from_config to add the optional use_liger_kernel flag. If the flag is True (default) and the Liger kernel is available, the model's attention layers are monkey-patched in place. If patching fails for any reason, the call is retried once with use_liger_kernel=False so that users still obtain a functional model.

@akoumpa: currently only supporting liger_kernel for demonstration purposes.
Notes:#
No changes are made to the model’s public API; forward signatures, generation utilities, and weight shapes remain identical.
Only decoder-style (causal) architectures are currently supported by the Liger patch. Unsupported models will silently fall back.
Examples:#
model = NeMoAutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")  # try Liger
model = NeMoAutoModelForImageTextToText.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct", use_liger_kernel=False)  # skip Liger
- class nemo_automodel._transformers.auto_model.NeMoAutoModelForMultimodalLM#
Bases: nemo_automodel._transformers.auto_model._BaseNeMoAutoModelClass, transformers.AutoModelForMultimodalLM

Drop-in replacement for transformers.AutoModelForMultimodalLM with custom kernels.
- class nemo_automodel._transformers.auto_model.NeMoAutoModelForSequenceClassification#
Bases: nemo_automodel._transformers.auto_model._BaseNeMoAutoModelClass, transformers.AutoModelForSequenceClassification

Drop-in replacement for transformers.AutoModelForSequenceClassification with custom kernels.

The class only overrides from_pretrained and from_config to add the optional use_liger_kernel flag. If the flag is True (default) and the Liger kernel is available, the model's attention layers are monkey-patched in place. If patching fails for any reason, the call is retried once with use_liger_kernel=False so that users still obtain a functional model.

@akoumpa: currently only supporting liger_kernel for demonstration purposes.
Notes:#
No changes are made to the model’s public API; forward signatures, generation utilities, and weight shapes remain identical.
Only decoder-style (causal) architectures are currently supported by the Liger patch. Unsupported models will silently fall back.
Examples:#
model = NeMoAutoModelForSequenceClassification.from_pretrained("bert-base-uncased")  # try Liger
model = NeMoAutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", use_liger_kernel=False)  # skip Liger
- class nemo_automodel._transformers.auto_model.NeMoAutoModelForTextToWaveform#
Bases: nemo_automodel._transformers.auto_model._BaseNeMoAutoModelClass, transformers.AutoModelForTextToWaveform

Drop-in replacement for transformers.AutoModelForTextToWaveform with custom kernels.

The class only overrides from_pretrained and from_config to add the optional use_liger_kernel flag. If the flag is True (default) and the Liger kernel is available, the model's attention layers are monkey-patched in place. If patching fails for any reason, the call is retried once with use_liger_kernel=False so that users still obtain a functional model.

@akoumpa: currently only supporting liger_kernel for demonstration purposes.
Notes:#
No changes are made to the model’s public API; forward signatures, generation utilities, and weight shapes remain identical.
Only decoder-style (causal) architectures are currently supported by the Liger patch. Unsupported models will silently fall back.
Examples:#
model = NeMoAutoModelForTextToWaveform.from_pretrained("facebook/musicgen-small")  # try Liger
model = NeMoAutoModelForTextToWaveform.from_pretrained(
    "facebook/musicgen-small", use_liger_kernel=False)  # skip Liger
- class nemo_automodel._transformers.auto_model.NeMoAutoModelBiencoder#
NeMo AutoModel for biencoder/embedding tasks with full infrastructure support.
This class provides a unified interface for loading biencoder models with support for PEFT, FSDP, TP, CP, FP8, QAT, and other infrastructure features. It uses the BiencoderModel.build() method to create the model and then applies all infrastructure through apply_model_infrastructure().
This class properly integrates with the model registry and applies all kernel patching and infrastructure support.
Examples:#
model = NeMoAutoModelBiencoder.from_pretrained("meta-llama/Llama-3.2-1B")
model = NeMoAutoModelBiencoder.from_pretrained(
    "meta-llama/Llama-3.2-1B",
    distributed_config=FSDP2Config(),
)
- classmethod from_pretrained(
- pretrained_model_name_or_path: str,
- share_encoder: bool = True,
- pooling: str = 'avg',
- l2_normalize: bool = True,
- attn_implementation: str = 'flash_attention_2',
- use_liger_kernel: bool = True,
- use_sdpa_patching: bool = True,
- sdpa_method: Optional[List[torch.nn.attention.SDPBackend]] = None,
- torch_dtype='auto',
- device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- tp_plan: Optional[dict] = None,
- distributed_config: Optional[nemo_automodel.components.distributed.config.DistributedConfig] = None,
- moe_config: Optional[nemo_automodel.components.moe.config.MoEParallelizerConfig] = None,
- compile_config: Optional[nemo_automodel.components.utils.compile_utils.CompileConfig] = None,
- peft_config: Optional[dict] = None,
- **kwargs,
)#
Load a biencoder model from pretrained weights with full infrastructure support.
This method builds a biencoder using BiencoderModel.build(), applies kernel patching, and then applies all infrastructure (FSDP, checkpointing, etc.) through apply_model_infrastructure().
- Parameters:
pretrained_model_name_or_path – Path to pretrained model or model identifier.
share_encoder – Whether to share encoder weights between query and passage.
pooling – Pooling strategy ('avg', 'cls', 'last', etc.).
l2_normalize – Whether to L2 normalize embeddings.
attn_implementation – Attention implementation to use (e.g.,
"flash_attention_2","sdpa","eager"). Defaults to"flash_attention_2".use_liger_kernel – Whether to apply Liger kernel optimizations.
use_sdpa_patching – Whether to apply SDPA patching.
sdpa_method – SDPA backend methods to use.
torch_dtype – Data type passed to the underlying model initialization.
device_mesh – Pre-created device mesh for distributed training.
moe_mesh – Device mesh for expert parallelism (FSDP2 only).
tp_plan – Custom tensor parallel plan; overrides distributed_config.tp_plan.
distributed_config – Strategy-specific distributed training configuration.
moe_config – MoE parallelizer configuration.
compile_config – Configuration for torch.compile.
**kwargs – Additional arguments passed to BiencoderModel.build.
- Returns:
BiencoderModel instance with loaded weights and all infrastructure applied.
Notes:#
If kernel patching fails, the method retries with adjusted parameters.
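The pooling parameter above selects how per-token embeddings are collapsed into a single vector. A minimal sketch of the usual meaning of these strategies, written here as an assumption for illustration (the actual BiencoderModel implementation may differ and operates on tensors, not lists):

```python
def pool(token_embeddings, attention_mask, strategy="avg"):
    """Collapse token vectors into one embedding.

    token_embeddings: list of per-token vectors (lists of floats).
    attention_mask: 1 for real tokens, 0 for padding, aligned with tokens.
    """
    valid = [v for v, m in zip(token_embeddings, attention_mask) if m]
    if strategy == "cls":
        return token_embeddings[0]   # first token (e.g., [CLS])
    if strategy == "last":
        return valid[-1]             # last non-padding token
    if strategy == "avg":            # mask-aware mean over real tokens only
        n = len(valid)
        return [sum(v[i] for v in valid) / n for i in range(len(valid[0]))]
    raise ValueError(f"unknown pooling strategy: {strategy}")
```

With l2_normalize=True, the resulting vector would additionally be scaled to unit length before use in similarity scoring.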