nemo_automodel._transformers.model_init#

Model resolution and initialization helpers.

Functions for resolving which model class to use (custom vs HF), downloading weights, applying config overrides, and instantiating the model.

Module Contents#

Functions#

_get_hf_meta_device_disabled

no_hf_meta_device

Disable HuggingFace’s meta device in get_init_context so the model is built on a real device.

_filter_meta_device_from_init_context

Remove torch.device('meta') from the HF init context list when we want real-device init.

_patched_get_init_context

Wrapper around PreTrainedModel.get_init_context that strips meta device when requested.

_get_mixin_wrapped_class

Get a class that combines HFCheckpointingMixin with the original model class.

local_torch_dtype

Locally change the torch default dtype to dtype, restoring the previous dtype on exiting the context. If model_class_name is provided, it is used to produce a more helpful error message when dtype is not valid.

_is_config_compatible_with_custom_model

Check if a HuggingFace config is compatible with our custom model implementation.

_resolve_custom_model_cls_for_config

Resolve the custom model class for config, if the config is compatible.

get_hf_config

Get the HF config for the model.

get_is_hf_model

Determine whether the model should use the HF (not custom) implementation.

_download_model_weights

_init_model

get_architectures

Get the architectures from the HF config.

_get_init_param_names

Best-effort extraction of explicit init parameter names (excluding self).

_consume_config_overrides

Mimic HF from_pretrained behavior: treat config-related kwargs as config overrides, not model init kwargs.

_filter_kwargs_for_init

Filter kwargs down to what model_cls.__init__ explicitly accepts.

resolve_sdpa_method

Resolve SDPA backend list from config strings or runtime constraints.

Data#

API#

nemo_automodel._transformers.model_init.logger#

'getLogger(…)'

nemo_automodel._transformers.model_init._hf_meta_device_disabled#

'local(…)'

nemo_automodel._transformers.model_init._get_hf_meta_device_disabled()#
nemo_automodel._transformers.model_init.no_hf_meta_device()#

Disable HuggingFace’s meta device in get_init_context so the model is built on a real device.

nemo_automodel._transformers.model_init._filter_meta_device_from_init_context(contexts)#

Remove torch.device('meta') from the HF init context list when we want real-device init.

nemo_automodel._transformers.model_init._patched_get_init_context(cls, *args, **kwargs)#

Wrapper around PreTrainedModel.get_init_context that strips meta device when requested.

nemo_automodel._transformers.model_init._original_get_init_context#

None

nemo_automodel._transformers.model_init._get_mixin_wrapped_class(model_class: type) → type#

Get a class that combines HFCheckpointingMixin with the original model class.

If the class already has the mixin, returns it unchanged.

Parameters:

model_class – The original model class (e.g., LlamaForCausalLM)

Returns:

A class that inherits from both HFCheckpointingMixin and model_class
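The composition above can be sketched with `type()`. This is a minimal stand-in, not the real implementation: `CheckpointingMixin` and `FakeModel` are hypothetical placeholders for HFCheckpointingMixin and an HF model class such as LlamaForCausalLM; the runtime class-combination pattern is the point.

```python
# Sketch of wrapping a model class with a checkpointing mixin at runtime.
# CheckpointingMixin and FakeModel are illustrative stand-ins.

class CheckpointingMixin:
    def save_checkpoint(self) -> str:
        return f"saved {type(self).__name__}"

class FakeModel:
    def __init__(self, config):
        self.config = config

def get_mixin_wrapped_class(model_class: type) -> type:
    # If the class already carries the mixin, return it unchanged (idempotent).
    if issubclass(model_class, CheckpointingMixin):
        return model_class
    # Build a new class inheriting from both, keeping the original class name.
    return type(model_class.__name__, (CheckpointingMixin, model_class), {})

Wrapped = get_mixin_wrapped_class(FakeModel)
model = Wrapped(config={"hidden_size": 8})
print(model.save_checkpoint())                       # saved FakeModel
print(get_mixin_wrapped_class(Wrapped) is Wrapped)   # True
```

Keeping `model_class.__name__` on the combined class means checkpoints and logs still report the familiar model name rather than a wrapper name.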

nemo_automodel._transformers.model_init.local_torch_dtype(
dtype: torch.dtype,
model_class_name: str | None = None,
default_dtype: torch.dtype = torch.bfloat16,
)#

Locally change the torch default dtype to dtype, restoring the previous dtype on exiting the context. If model_class_name is provided, it is used to produce a more helpful error message when dtype is not valid.

nemo_automodel._transformers.model_init._is_config_compatible_with_custom_model(
arch_name: str,
config,
) → bool#

Check if a HuggingFace config is compatible with our custom model implementation.

Some architectures (e.g., NemotronHForCausalLM) are shared between different model versions (v2 vs v3) but our custom implementation only supports specific versions. This function validates that the config has the required attributes for the custom implementation.

Parameters:
  • arch_name – The architecture name (e.g., “NemotronHForCausalLM”)

  • config – The HuggingFace config object

Returns:

True if the config is compatible with our custom implementation, False otherwise
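A required-attribute check of this kind can be sketched as follows. The registry and the attribute name are hypothetical stand-ins, not the module's real tables; the sketch only shows the validation shape: architectures with no listed requirements are treated as compatible.

```python
# Sketch of checking whether a config carries the attributes a custom
# implementation needs. Registry contents are illustrative only.
from types import SimpleNamespace

# arch name -> config attributes the custom implementation requires
_REQUIRED_CONFIG_ATTRS = {
    "NemotronHForCausalLM": ("hybrid_override_pattern",),  # hypothetical
}

def is_config_compatible_with_custom_model(arch_name: str, config) -> bool:
    required = _REQUIRED_CONFIG_ATTRS.get(arch_name, ())
    # Architectures with no listed requirements are assumed compatible.
    return all(getattr(config, attr, None) is not None for attr in required)

supported = SimpleNamespace(hybrid_override_pattern="M-M-M*")
unsupported = SimpleNamespace()  # lacks the attribute the custom model needs

print(is_config_compatible_with_custom_model("NemotronHForCausalLM", supported))    # True
print(is_config_compatible_with_custom_model("NemotronHForCausalLM", unsupported))  # False
```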

nemo_automodel._transformers.model_init._resolve_custom_model_cls_for_config(config)#

Resolve the custom model class for config, if the config is compatible.

nemo_automodel._transformers.model_init.get_hf_config(
pretrained_model_name_or_path,
attn_implementation,
**kwargs,
)#

Get the HF config for the model.

nemo_automodel._transformers.model_init.get_is_hf_model(config, force_hf)#

Determine whether the model should use the HF (not custom) implementation.

nemo_automodel._transformers.model_init._download_model_weights(hf_config, pretrained_model_name_or_path)#
nemo_automodel._transformers.model_init._init_model(
cls,
pretrained_model_name_or_path_or_config,
attn_implementation,
torch_dtype,
quantization_config,
force_hf,
*model_args,
**kwargs,
)#
nemo_automodel._transformers.model_init.get_architectures(hf_config)#

Get the architectures from the HF config.

nemo_automodel._transformers.model_init._get_init_param_names(model_cls) → set[str]#

Best-effort extraction of explicit init parameter names (excluding self).

Returns an empty set if the signature cannot be inspected.
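A best-effort extraction like this usually rests on inspect.signature. The sketch below is an assumed shape, not the module's code: it collects explicit positional-or-keyword and keyword-only names, skips self and any **kwargs catch-all, and falls back to an empty set when inspection fails.

```python
# Sketch of extracting explicit __init__ parameter names, excluding self.
import inspect

def get_init_param_names(model_cls) -> set[str]:
    try:
        sig = inspect.signature(model_cls.__init__)
    except (TypeError, ValueError):
        return set()  # signature not inspectable: best effort, give up
    return {
        name
        for name, p in sig.parameters.items()
        if name != "self"
        and p.kind in (p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY)
    }

class Demo:
    def __init__(self, config, attn_implementation=None, **kwargs):
        ...

print(sorted(get_init_param_names(Demo)))  # ['attn_implementation', 'config']
```

Note that the **kwargs parameter is deliberately excluded: it is a catch-all, not an explicit parameter name.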

nemo_automodel._transformers.model_init._consume_config_overrides(
config,
kwargs: dict,
*,
init_param_names: set[str] | None = None,
) → None#

Mimic HF from_pretrained behavior: treat config-related kwargs as config overrides, not model init kwargs.

For custom model implementations we instantiate via model_cls(config, **kwargs), so passing config flags like output_hidden_states would crash. This helper moves such keys onto the config and removes them from kwargs.

nemo_automodel._transformers.model_init._filter_kwargs_for_init(model_cls, kwargs: dict) → dict#

Filter kwargs down to what model_cls.__init__ explicitly accepts.

If the constructor has a **kwargs parameter (VAR_KEYWORD) or the signature cannot be inspected, returns kwargs unchanged.
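This filtering can be sketched with inspect as below. This is an assumed implementation shape, with hypothetical demo classes, not the module's code.

```python
# Sketch of filtering kwargs down to what __init__ explicitly accepts,
# passing everything through when the constructor takes **kwargs or the
# signature cannot be inspected.
import inspect

def filter_kwargs_for_init(model_cls, kwargs: dict) -> dict:
    try:
        params = inspect.signature(model_cls.__init__).parameters
    except (TypeError, ValueError):
        return kwargs  # not inspectable: pass through unchanged
    if any(p.kind is p.VAR_KEYWORD for p in params.values()):
        return kwargs  # constructor forwards **kwargs; nothing to filter
    return {k: v for k, v in kwargs.items() if k in params}

class Strict:
    def __init__(self, config, dtype=None):
        self.config, self.dtype = config, dtype

class Permissive:
    def __init__(self, config, **kwargs):
        self.config = config

print(filter_kwargs_for_init(Strict, {"dtype": "bf16", "unknown": 1}))
# {'dtype': 'bf16'}
print(filter_kwargs_for_init(Permissive, {"dtype": "bf16", "unknown": 1}))
# {'dtype': 'bf16', 'unknown': 1}
```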

nemo_automodel._transformers.model_init.resolve_sdpa_method(
sdpa_method: list | None = None,
device_mesh=None,
activation_checkpointing: bool = False,
) → list[torch.nn.attention.SDPBackend] | None#

Resolve SDPA backend list from config strings or runtime constraints.

When sdpa_method is provided (e.g. from YAML), string values are converted to torch.nn.attention.SDPBackend enum members. Already-resolved SDPBackend values are passed through unchanged. When None, automatic defaults are applied based on context parallelism and activation checkpointing settings.

Valid string values (case-insensitive): flash_attention, efficient_attention, math, cudnn_attention.

Parameters:
  • sdpa_method – List of backend name strings or SDPBackend enum values, or None to use automatic defaults.

  • device_mesh – Device mesh for distributed training.

  • activation_checkpointing – Whether activation checkpointing is enabled.

Returns:

Ordered list of SDPBackend members, or None to use PyTorch’s default selection.