nemo_automodel._transformers.auto_model#

NeMo Auto Model classes.

Drop-in replacements for transformers.AutoModelFor* that add custom-kernel patching, distributed infrastructure, PEFT, quantization, and checkpointing.

Heavy-lifting helpers live in sibling modules:

  • kernel_patches – SDPA / Liger kernel patching

  • model_init – model class resolution and instantiation

  • infrastructure – MeshContext, sharding, PEFT/quant application

Module Contents#

Classes#

_BaseNeMoAutoModelClass

Drop-in replacement for _BaseAutoModelClass that includes custom-kernels.

NeMoAutoModelForCausalLM

Drop-in replacement for transformers.AutoModelForCausalLM that includes custom-kernels.

NeMoAutoModelForImageTextToText

Drop-in replacement for transformers.AutoModelForImageTextToText with custom-kernels.

NeMoAutoModelForMultimodalLM

Drop-in replacement for transformers.AutoModelForMultimodalLM with custom-kernels.

NeMoAutoModelForSequenceClassification

Drop-in replacement for transformers.AutoModelForSequenceClassification with custom-kernels.

NeMoAutoModelForTextToWaveform

Drop-in replacement for transformers.AutoModelForTextToWaveform with custom-kernels.

Data#

API#

nemo_automodel._transformers.auto_model.logger#

‘getLogger(…)’

class nemo_automodel._transformers.auto_model._BaseNeMoAutoModelClass#

Bases: transformers.models.auto.auto_factory._BaseAutoModelClass

Drop-in replacement for _BaseAutoModelClass that includes custom-kernels.

The class only overrides from_pretrained and from_config to add the optional use_liger_kernel flag. If the flag is True (default) and the Liger kernel is available, the model’s attention layers are monkey-patched in place. If patching fails for any reason, the call is retried once with use_liger_kernel=False so that users still obtain a functional model.

TODO(@akoumpa): extend this beyond liger_kernel.

Notes:#

  • No changes are made to the model’s public API; forward signatures, generation utilities, and weight shapes remain identical.

  • Only decoder-style (causal) architectures are currently supported by the Liger patch. Unsupported models will silently fall back.
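The retry behavior described above can be pictured with a small, library-free sketch. `build_with_liger_fallback`, `build_fn`, and `fake_build` are hypothetical names used only for illustration; they are not part of nemo_automodel, and the real implementation may differ in what it catches and logs.

```python
import logging

logger = logging.getLogger(__name__)


def build_with_liger_fallback(build_fn, use_liger_kernel=True):
    """Call build_fn(use_liger_kernel=...), retrying once without Liger on failure.

    build_fn is a hypothetical stand-in for the real model-building routine.
    """
    try:
        return build_fn(use_liger_kernel=use_liger_kernel)
    except Exception as exc:
        if not use_liger_kernel:
            # Liger was already disabled, so the failure has another cause.
            raise
        logger.warning("Liger patching failed (%s); retrying without it.", exc)
        return build_fn(use_liger_kernel=False)


def fake_build(use_liger_kernel):
    """Toy builder that fails whenever Liger patching is requested."""
    if use_liger_kernel:
        raise RuntimeError("liger kernel unavailable")
    return {"patched_with_liger": False}


# The first attempt raises, the single retry succeeds without Liger.
model = build_with_liger_fallback(fake_build)
```

Because the retry happens at most once and only when Liger was requested, a failure unrelated to Liger still surfaces to the caller on the second attempt.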

classmethod _from_pretrained_parent_class(*args, **kwargs)#
classmethod _from_config_parent_class(*args, **kwargs)#
classmethod _build_model(
pretrained_model_name_or_path_or_config,
*model_args,
is_hf_model,
use_liger_kernel,
use_sdpa_patching,
sdpa_method,
torch_dtype,
attn_implementation,
quantization_config,
force_hf,
model_wrapper,
autopipeline,
parallelize_fn,
qat_quantizer,
mesh,
loss_fn,
peft_config,
fp8_config,
compile_config,
load_base_model,
**kwargs,
)#

Shared model building logic for from_pretrained and from_config.

Handles pre-load overrides, meta-device initialization, model init with attention-fallback retry, kernel patching (Liger, SDPA) with retry, and full infrastructure application (sharding, PEFT, quantization, checkpointing).

All caller-specific setup (config resolution, infrastructure instantiation, is_hf_model determination) is done by from_pretrained / from_config before delegating here.
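As a rough illustration of the "model init with attention-fallback retry" step, the sketch below tries the requested attention implementation and then falls back to others. The fallback order (`"sdpa"`, then `"eager"`), the `ValueError` trigger, and the helper names `init_with_attention_fallback` and `toy_init` are all assumptions made for this example, not the library's actual logic.

```python
def init_with_attention_fallback(init_fn, attn_implementation, fallbacks=("sdpa", "eager")):
    """Try init_fn with the requested attention backend, then each fallback in turn.

    init_fn is a hypothetical callable standing in for model instantiation.
    Returns a (model, attn_implementation_used) tuple.
    """
    candidates = [attn_implementation]
    candidates += [name for name in fallbacks if name != attn_implementation]
    last_error = None
    for name in candidates:
        try:
            return init_fn(attn_implementation=name), name
        except ValueError as exc:
            # Remember the failure and move on to the next candidate.
            last_error = exc
    raise RuntimeError(f"no attention implementation worked: {candidates}") from last_error


def toy_init(attn_implementation):
    """Toy initializer that rejects flash attention, as if it were not installed."""
    if attn_implementation == "flash_attention_2":
        raise ValueError("flash attention is not installed")
    return f"model[{attn_implementation}]"


model, used = init_with_attention_fallback(toy_init, "flash_attention_2")
```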

classmethod from_pretrained(
pretrained_model_name_or_path,
*model_args,
use_liger_kernel: bool = True,
use_sdpa_patching: bool = True,
sdpa_method: Optional[List[torch.nn.attention.SDPBackend]] = None,
torch_dtype='auto',
attn_implementation: str = DEFAULT_ATTN_IMPLEMENTATION,
quantization_config=None,
force_hf: bool = False,
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
tp_plan: Optional[dict] = None,
distributed_config: Optional[nemo_automodel.components.distributed.config.DistributedConfig] = None,
pipeline_config: Optional[nemo_automodel.components.distributed.pipelining.config.PipelineConfig] = None,
qat_config: Optional[nemo_automodel.components.quantization.qat.QATConfig] = None,
moe_config: Optional[nemo_automodel.components.moe.config.MoEParallelizerConfig] = None,
activation_checkpointing: bool = False,
peft_config: Optional[dict] = None,
fp8_config: Optional[nemo_automodel.components.quantization.fp8.FP8Config] = None,
compile_config: Optional[nemo_automodel.components.utils.compile_utils.CompileConfig] = None,
**kwargs,
) → transformers.PreTrainedModel#

Instantiate and (optionally) patch a causal-language model.

This is a light wrapper around transformers.AutoModelForCausalLM.from_pretrained that can automatically apply Liger and/or SDPA (scaled-dot-product attention) kernel optimizations, as well as PEFT, quantization, and distributed parallelism.

Parameters:
  • pretrained_model_name_or_path (str | os.PathLike) – Hugging Face hub repo ID or local path accepted by AutoModelForCausalLM.from_pretrained.

  • *model_args – Positional arguments forwarded verbatim to AutoModelForCausalLM.from_pretrained.

  • use_liger_kernel (bool, default=True) – If True, try to patch the model with Liger kernels for faster inference/training.

  • use_sdpa_patching (bool, default=True) – If True, patch the model with SDPA-based attention optimizations.

  • sdpa_method (list[SDPBackend] | None, optional) – Explicit list of SDPA back-ends to consider when use_sdpa_patching=True.

  • torch_dtype (str | torch.dtype | Literal["auto"], default="auto") – Data type passed to the underlying from_pretrained call.

  • attn_implementation (str, optional) – Attention implementation to use (e.g., "flash_attention_2", "eager"). Only applied when the base model supports this kwarg. Defaults to "flash_attention_2"; if flash attention is not available, falls back to "sdpa".

  • quantization_config (optional) – BitsAndBytesConfig configuration object that specifies all quantization settings. If provided, quantization will be applied to the model.

  • force_hf (bool, default=False) – If True, force the use of HF model implementation. If False, the model will be loaded using the custom model implementation if available.

  • device_mesh (DeviceMesh | None, optional) – Pre-created device mesh for distributed training. Parallelism sizes (tp, pp, cp, ep) are inferred from this. Default: None.

  • moe_mesh (DeviceMesh | None, optional) – FSDP2-only. Device mesh for expert parallelism. ep_size is inferred from this. Default: None.

  • tp_plan (dict | None, optional) – Custom tensor parallel plan. If provided, overrides the tp_plan on distributed_config. Default: None.

  • distributed_config (FSDP2Config | MegatronFSDPConfig | DDPConfig | None, optional) – Strategy-specific distributed training configuration. Default: None.

  • pipeline_config (PipelineConfig | None, optional) – Pipeline parallelism configuration including loss_fn. Default: None.

  • qat_config (QATConfig | None, optional) – Quantization-Aware Training configuration. Default: None.

  • moe_config (MoEParallelizerConfig | None, optional) – MoE parallelizer configuration. Default: None.

  • activation_checkpointing (bool, default=False) – Enable activation checkpointing for transformer blocks to reduce memory usage.

  • peft_config (dict | None, optional) – PEFT/LoRA configuration dictionary. If provided, LoRA adapters will be applied to the model. Default: None.

  • fp8_config (FP8Config | None, optional) – FP8 quantization configuration. If provided, FP8 quantization will be applied. Default: None.

  • compile_config (CompileConfig | None, optional) – Configuration for torch.compile. If provided, the model will be compiled. Default: None.

  • **kwargs

    Additional keyword arguments. Notable ones include:

    • has_packed_sequence (bool): Whether using packed sequences. Default: False.

    • cache_dir (str): Cache directory for model weights.

Returns:

The loaded (and possibly patched) model instance with all infrastructure applied.

Return type:

transformers.PreTrainedModel
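A minimal usage sketch for the parameters documented above. The helper names `conservative_load_kwargs` and `load_causal_lm` are hypothetical, and the particular combination of flags is only an example, not a recommended configuration. The heavy imports are deferred into `load_causal_lm`, so the kwargs helper can be inspected without torch or nemo_automodel installed.

```python
def conservative_load_kwargs(cache_dir=None):
    """Collect documented from_pretrained keyword arguments for a cautious load."""
    kwargs = {
        "use_liger_kernel": False,         # skip Liger patching
        "use_sdpa_patching": True,         # keep SDPA attention patching
        "torch_dtype": "auto",
        "force_hf": True,                  # use the Hugging Face implementation
        "activation_checkpointing": True,  # trade compute for memory
    }
    if cache_dir is not None:
        kwargs["cache_dir"] = cache_dir
    return kwargs


def load_causal_lm(model_id="gpt2", **overrides):
    """Load a model; requires nemo_automodel, torch, and transformers."""
    from torch.nn.attention import SDPBackend
    from nemo_automodel._transformers.auto_model import NeMoAutoModelForCausalLM

    kwargs = conservative_load_kwargs()
    # Restrict SDPA patching to an explicit list of back-ends.
    kwargs["sdpa_method"] = [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]
    kwargs.update(overrides)
    return NeMoAutoModelForCausalLM.from_pretrained(model_id, **kwargs)
```

Calling `load_causal_lm("gpt2", use_liger_kernel=True)` would override the cautious default for a single call.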

classmethod from_config(
config,
*model_args,
use_liger_kernel: bool = True,
use_sdpa_patching: bool = True,
sdpa_method: Optional[List[torch.nn.attention.SDPBackend]] = None,
torch_dtype: Union[str, torch.dtype] = 'auto',
attn_implementation: str = DEFAULT_ATTN_IMPLEMENTATION,
quantization_config=None,
force_hf: bool = False,
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
tp_plan: Optional[dict] = None,
distributed_config: Optional[nemo_automodel.components.distributed.config.DistributedConfig] = None,
pipeline_config: Optional[nemo_automodel.components.distributed.pipelining.config.PipelineConfig] = None,
qat_config: Optional[nemo_automodel.components.quantization.qat.QATConfig] = None,
moe_config: Optional[nemo_automodel.components.moe.config.MoEParallelizerConfig] = None,
activation_checkpointing: bool = False,
peft_config: Optional[dict] = None,
fp8_config: Optional[nemo_automodel.components.quantization.fp8.FP8Config] = None,
compile_config: Optional[nemo_automodel.components.utils.compile_utils.CompileConfig] = None,
**kwargs,
) → transformers.PreTrainedModel#

Instantiate a model from a transformers.PretrainedConfig (no pretrained weights). Accepts the same infrastructure arguments as from_pretrained.

See from_pretrained for full parameter documentation.

Parameters:
  • config (transformers.PretrainedConfig | str) – The configuration object used to build the model. If config is passed as a string (e.g., model-id / local checkpoint), it will create a config internally using AutoConfig.

  • torch_dtype (str | torch.dtype, default="auto") – Data type for model parameters. If "auto", defaults to torch.bfloat16.
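The two parameters above can be sketched as follows. `effective_dtype_name` is a hypothetical helper that only restates the documented "auto" rule using plain strings; it is not how the library resolves dtypes internally. `build_from_config` is likewise hypothetical and shows one way to build a model with freshly initialized weights from an explicit config object.

```python
def effective_dtype_name(torch_dtype="auto"):
    """Restate the documented rule: "auto" means bfloat16 for from_config."""
    if torch_dtype == "auto":
        return "bfloat16"
    # str(torch.float16) is "torch.float16"; bare names pass through unchanged.
    return str(torch_dtype).removeprefix("torch.")


def build_from_config(model_id="gpt2", torch_dtype="auto"):
    """Build a model without pretrained weights; requires nemo_automodel and transformers."""
    from transformers import AutoConfig
    from nemo_automodel._transformers.auto_model import NeMoAutoModelForCausalLM

    # Resolve the config explicitly; per the docs, passing the string directly
    # to from_config would create the config internally via AutoConfig.
    config = AutoConfig.from_pretrained(model_id)
    return NeMoAutoModelForCausalLM.from_config(config, torch_dtype=torch_dtype)
```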

class nemo_automodel._transformers.auto_model.NeMoAutoModelForCausalLM#

Bases: nemo_automodel._transformers.auto_model._BaseNeMoAutoModelClass, transformers.AutoModelForCausalLM

Drop-in replacement for transformers.AutoModelForCausalLM that includes custom-kernels.

The class only overrides from_pretrained and from_config to add the optional use_liger_kernel flag. If the flag is True (default) and the Liger kernel is available, the model’s attention layers are monkey-patched in place. If patching fails for any reason, the call is retried once with use_liger_kernel=False so that users still obtain a functional model.

TODO(@akoumpa): extend this beyond liger_kernel.

Notes:#

  • No changes are made to the model’s public API; forward signatures, generation utilities, and weight shapes remain identical.

  • Only decoder-style (causal) architectures are currently supported by the Liger patch. Unsupported models will silently fall back.

Examples:#

>>> model = NeMoAutoModelForCausalLM.from_pretrained("gpt2")  # try Liger
>>> model = NeMoAutoModelForCausalLM.from_pretrained(
...     "gpt2", use_liger_kernel=False)  # skip Liger

class nemo_automodel._transformers.auto_model.NeMoAutoModelForImageTextToText#

Bases: nemo_automodel._transformers.auto_model._BaseNeMoAutoModelClass, transformers.AutoModelForImageTextToText

Drop-in replacement for transformers.AutoModelForImageTextToText with custom-kernels.

The class only overrides from_pretrained and from_config to add the optional use_liger_kernel flag. If the flag is True (default) and the Liger kernel is available, the model’s attention layers are monkey-patched in place. If patching fails for any reason, the call is retried once with use_liger_kernel=False so that users still obtain a functional model.

@akoumpa: currently only supporting liger_kernel for demonstration purposes.

Notes:#

  • No changes are made to the model’s public API; forward signatures, generation utilities, and weight shapes remain identical.

  • Only decoder-style (causal) architectures are currently supported by the Liger patch. Unsupported models will silently fall back.

Examples:#

>>> model = NeMoAutoModelForImageTextToText.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")  # try Liger
>>> model = NeMoAutoModelForImageTextToText.from_pretrained(
...     "Qwen/Qwen2.5-VL-3B-Instruct", use_liger_kernel=False)  # skip Liger

class nemo_automodel._transformers.auto_model.NeMoAutoModelForMultimodalLM#

Bases: nemo_automodel._transformers.auto_model._BaseNeMoAutoModelClass, transformers.AutoModelForMultimodalLM

Drop-in replacement for transformers.AutoModelForMultimodalLM with custom-kernels.

class nemo_automodel._transformers.auto_model.NeMoAutoModelForSequenceClassification#

Bases: nemo_automodel._transformers.auto_model._BaseNeMoAutoModelClass, transformers.AutoModelForSequenceClassification

Drop-in replacement for transformers.AutoModelForSequenceClassification with custom-kernels.

The class only overrides from_pretrained and from_config to add the optional use_liger_kernel flag. If the flag is True (default) and the Liger kernel is available, the model’s attention layers are monkey-patched in place. If patching fails for any reason, the call is retried once with use_liger_kernel=False so that users still obtain a functional model.

@akoumpa: currently only supporting liger_kernel for demonstration purposes.

Notes:#

  • No changes are made to the model’s public API; forward signatures, generation utilities, and weight shapes remain identical.

  • Only decoder-style (causal) architectures are currently supported by the Liger patch. Unsupported models will silently fall back.

Examples:#

>>> model = NeMoAutoModelForSequenceClassification.from_pretrained("bert-base-uncased")  # try Liger
>>> model = NeMoAutoModelForSequenceClassification.from_pretrained(
...     "bert-base-uncased", use_liger_kernel=False)  # skip Liger

class nemo_automodel._transformers.auto_model.NeMoAutoModelForTextToWaveform#

Bases: nemo_automodel._transformers.auto_model._BaseNeMoAutoModelClass, transformers.AutoModelForTextToWaveform

Drop-in replacement for transformers.AutoModelForTextToWaveform with custom-kernels.

The class only overrides from_pretrained and from_config to add the optional use_liger_kernel flag. If the flag is True (default) and the Liger kernel is available, the model’s attention layers are monkey-patched in place. If patching fails for any reason, the call is retried once with use_liger_kernel=False so that users still obtain a functional model.

@akoumpa: currently only supporting liger_kernel for demonstration purposes.

Notes:#

  • No changes are made to the model’s public API; forward signatures, generation utilities, and weight shapes remain identical.

  • Only decoder-style (causal) architectures are currently supported by the Liger patch. Unsupported models will silently fall back.

Examples:#

>>> model = NeMoAutoModelForTextToWaveform.from_pretrained("facebook/musicgen-small")  # try Liger
>>> model = NeMoAutoModelForTextToWaveform.from_pretrained(
...     "facebook/musicgen-small", use_liger_kernel=False)  # skip Liger