nemo_automodel._transformers.model_init#

Model resolution and initialization helpers.

Functions for resolving which model class to use (custom vs HF), downloading weights, applying config overrides, and instantiating the model.

Module Contents#

Functions#

_get_mixin_wrapped_class

Get a class that combines HFCheckpointingMixin with the original model class.

local_torch_dtype

Locally change the torch default dtype to dtype, and restore the previous default upon exiting the context. If model_class_name is provided, it is included in the error message when dtype is not valid.

_is_config_compatible_with_custom_model

Check if a HuggingFace config is compatible with our custom model implementation.

get_hf_config

Get the HF config for the model.

get_is_hf_model

Resolve the trust_remote_code default and determine whether the model is HF-based.

_download_model_weights

_init_model

get_architectures

Get the architectures from the HF config.

_get_init_param_names

Best-effort extraction of explicit init parameter names (excluding self).

_consume_config_overrides

Mimic HF from_pretrained behavior: treat config-related kwargs as config overrides, not model init kwargs.

_filter_kwargs_for_init

Filter kwargs down to what model_cls.__init__ explicitly accepts.

Data#

API#

nemo_automodel._transformers.model_init.logger#

‘getLogger(…)’

nemo_automodel._transformers.model_init._get_mixin_wrapped_class(model_class: type) → type#

Get a class that combines HFCheckpointingMixin with the original model class.

If the class already has the mixin, returns it unchanged.

Parameters:

model_class – The original model class (e.g., LlamaForCausalLM)

Returns:

A class that inherits from both HFCheckpointingMixin and model_class
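
For orientation, here is a minimal sketch of how such a wrapper could be built. It is illustrative only, not the module's actual code (which may, for example, cache the generated class); the placeholder mixin stands in for the real HFCheckpointingMixin.

```python
class HFCheckpointingMixin:  # placeholder standing in for the real mixin
    ...


def _get_mixin_wrapped_class_sketch(model_class: type) -> type:
    # Illustrative sketch only, not the library's implementation.
    if issubclass(model_class, HFCheckpointingMixin):
        # Already wrapped: return the class unchanged.
        return model_class
    # Build a new class with the mixin first in the MRO so its
    # checkpointing hooks take precedence over the model's methods.
    return type(model_class.__name__, (HFCheckpointingMixin, model_class), {})
```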

nemo_automodel._transformers.model_init.local_torch_dtype(
dtype: torch.dtype,
model_class_name: str | None = None,
default_dtype: torch.dtype = torch.bfloat16,
)#

Locally change the torch default dtype to dtype, and restore the previous default upon exiting the context. If model_class_name is provided, it is included in the error message when dtype is not valid.
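
As a rough illustration of the described behavior (ignoring the default_dtype fallback parameter), a context manager along these lines might look as follows; this is a sketch, not the module's code.

```python
import contextlib

import torch


@contextlib.contextmanager
def local_torch_dtype_sketch(dtype: torch.dtype, model_class_name: str | None = None):
    # Validate up front so the error can mention the model class, if known.
    if not isinstance(dtype, torch.dtype):
        hint = f" for {model_class_name}" if model_class_name else ""
        raise ValueError(f"Invalid torch dtype{hint}: {dtype!r}")
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Always restore the previous default dtype on exit.
        torch.set_default_dtype(previous)
```

Typical usage would be `with local_torch_dtype_sketch(torch.bfloat16): model = SomeModel(config)`, so that parameters created inside the block use the requested dtype.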

nemo_automodel._transformers.model_init._is_config_compatible_with_custom_model(
arch_name: str,
config,
) → bool#

Check if a HuggingFace config is compatible with our custom model implementation.

Some architectures (e.g., NemotronHForCausalLM) are shared between different model versions (v2 vs. v3), but our custom implementation only supports specific versions. This function validates that the config has the required attributes for the custom implementation.

Parameters:
  • arch_name – The architecture name (e.g., “NemotronHForCausalLM”)

  • config – The HuggingFace config object

Returns:

True if the config is compatible with our custom implementation, False otherwise
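
A hypothetical sketch of such a check is shown below; the architecture-to-attribute mapping is invented for illustration and does not reflect the real requirements.

```python
# Hypothetical required-attribute table; the real attribute names may differ.
_REQUIRED_CONFIG_ATTRS = {
    "NemotronHForCausalLM": ("hybrid_override_pattern",),  # assumed requirement
}


def _is_config_compatible_sketch(arch_name: str, config) -> bool:
    required = _REQUIRED_CONFIG_ATTRS.get(arch_name, ())
    # Compatible only if every attribute the custom implementation needs exists.
    return all(hasattr(config, attr) for attr in required)
```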

nemo_automodel._transformers.model_init.get_hf_config(
pretrained_model_name_or_path,
attn_implementation,
**kwargs,
)#

Get the HF config for the model.
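
A rough approximation of what such a helper does, assuming it wraps transformers' AutoConfig and records the requested attention backend on the config; the real helper may thread attn_implementation and the extra kwargs through differently.

```python
from transformers import AutoConfig


def get_hf_config_sketch(pretrained_model_name_or_path, attn_implementation=None, **kwargs):
    # Sketch only, not the module's actual code.
    config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
    if attn_implementation is not None:
        # Assumption: store the requested attention backend on the config for later use.
        config._attn_implementation = attn_implementation
    return config
```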

nemo_automodel._transformers.model_init.get_is_hf_model(config, force_hf)#

Resolve the trust_remote_code default and determine whether the model is HF-based.
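
The exact heuristic is not documented here. As a purely hypothetical sketch, one way to decide could be an explicit force_hf override, falling back to whether the config's listed architectures resolve to classes shipped with transformers; the actual checks may differ.

```python
import transformers


def get_is_hf_model_sketch(config, force_hf: bool | None) -> bool:
    # Purely hypothetical decision logic.
    if force_hf is not None:
        return force_hf  # explicit override wins
    archs = getattr(config, "architectures", None) or []
    # Treat the model as HF-based when every listed architecture is a stock
    # transformers class.
    return all(hasattr(transformers, arch) for arch in archs)
```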

nemo_automodel._transformers.model_init._download_model_weights(hf_config, pretrained_model_name_or_path)#
nemo_automodel._transformers.model_init._init_model(
cls,
pretrained_model_name_or_path_or_config,
attn_implementation,
torch_dtype,
quantization_config,
force_hf,
*model_args,
**kwargs,
)#
nemo_automodel._transformers.model_init.get_architectures(hf_config)#

Get the architectures from the HF config.
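
A minimal sketch of reading that field, assuming it maps directly onto the config's architectures attribute:

```python
def get_architectures_sketch(hf_config) -> list[str]:
    # HF configs list their model classes under `architectures`.
    return list(getattr(hf_config, "architectures", None) or [])
```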

nemo_automodel._transformers.model_init._get_init_param_names(model_cls) → set[str]#

Best-effort extraction of explicit init parameter names (excluding self).

Returns an empty set if the signature cannot be inspected.
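
A best-effort inspection along these lines could be written with the standard inspect module; this is a sketch of the described behavior, not the module's exact code.

```python
import inspect


def _get_init_param_names_sketch(model_cls) -> set[str]:
    try:
        sig = inspect.signature(model_cls.__init__)
    except (TypeError, ValueError):
        return set()  # signature cannot be inspected
    return {
        name
        for name, param in sig.parameters.items()
        if name != "self"
        and param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
    }
```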

nemo_automodel._transformers.model_init._consume_config_overrides(
config,
kwargs: dict,
*,
init_param_names: set[str] | None = None,
) → None#

Mimic HF from_pretrained behavior: treat config-related kwargs as config overrides, not model init kwargs.

For custom model implementations we instantiate via model_cls(config, **kwargs), so passing config flags like output_hidden_states would crash. This helper moves such keys onto the config and removes them from kwargs.
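
A hedged sketch of that behavior, where "config-related" is approximated as any key the config already defines and the constructor does not explicitly accept; how the real helper decides may differ.

```python
def _consume_config_overrides_sketch(config, kwargs: dict, *, init_param_names=None) -> None:
    init_param_names = init_param_names or set()
    for key in list(kwargs):
        # Move the value onto the config instead of passing it to __init__.
        if hasattr(config, key) and key not in init_param_names:
            setattr(config, key, kwargs.pop(key))
```

For example, output_hidden_states=True would be set on the config and removed from kwargs before model_cls(config, **kwargs) is called.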

nemo_automodel._transformers.model_init._filter_kwargs_for_init(model_cls, kwargs: dict) → dict#

Filter kwargs down to what model_cls.__init__ explicitly accepts.

If the constructor has a **kwargs parameter (VAR_KEYWORD) or signature cannot be inspected, returns kwargs unchanged.
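
The described behavior can be sketched with inspect.signature as follows (illustrative only):

```python
import inspect


def _filter_kwargs_for_init_sketch(model_cls, kwargs: dict) -> dict:
    try:
        sig = inspect.signature(model_cls.__init__)
    except (TypeError, ValueError):
        return kwargs  # cannot inspect: pass everything through unchanged
    params = list(sig.parameters.values())
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return kwargs  # constructor takes **kwargs: nothing to filter out
    accepted = {p.name for p in params if p.name != "self"}
    return {k: v for k, v in kwargs.items() if k in accepted}
```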