nemo_automodel._transformers.model_init#
Model resolution and initialization helpers.
Functions for resolving which model class to use (custom vs HF), downloading weights, applying config overrides, and instantiating the model.
Module Contents#
Functions#
| Function | Description |
|---|---|
| `no_hf_meta_device` | Disable HuggingFace’s meta device in get_init_context so the model is built on a real device. |
| `_filter_meta_device_from_init_context` | Remove torch.device(‘meta’) from the HF init context list when we want real-device init. |
| `_patched_get_init_context` | Wrapper around PreTrainedModel.get_init_context that strips the meta device when requested. |
| `_get_mixin_wrapped_class` | Get a class that combines HFCheckpointingMixin with the original model class. |
| `local_torch_dtype` | Locally change the torch default dtype to `dtype`, restoring the old one on exit. |
| `_is_config_compatible_with_custom_model` | Check if a HuggingFace config is compatible with our custom model implementation. |
| `_resolve_custom_model_cls_for_config` | Resolve the custom model class for config, if the config is compatible. |
| `get_hf_config` | Get the HF config for the model. |
| `_load_config_with_layer_types_fix` | Load an HF config after truncating `layer_types` to `num_hidden_layers`. |
| `get_is_hf_model` | Determine whether the model should use the HF (not custom) implementation. |
| `_setup_bnb_loading_kwargs` | Configure kwargs for HF from_pretrained to work with BitsAndBytes quantization. |
| `_resolve_model_dir` | Resolve a HF repo id or local path to a local directory with model files. |
| `_has_safetensors` | Check whether a model directory contains safetensors checkpoint files. |
| `_stream_load_bnb_weights` | Load safetensor shards one at a time, quantizing BnB Params4bit on the fly. |
| `_streaming_bnb_supported` | Whether streaming BnB can safely load HF safetensors directly into the target class. |
| `_init_model_bnb_streaming` | Create the model on the meta device, replace Linear→Linear4bit, stream-load and quantize. |
| `_get_model_tensor` | Return a parameter or buffer by its fully-qualified state-dict key. |
| `_restore_loaded_model_dtype` | Restore each loaded tensor to the exact dtype stored in the checkpoint. |
| `get_architectures` | Get the architectures from the HF config. |
| `_get_init_param_names` | Best-effort extraction of explicit init parameter names (excluding `self`). |
| `_consume_config_overrides` | Mimic HF from_pretrained behavior: treat config-related kwargs as config overrides, not model init kwargs. |
| `_filter_kwargs_for_init` | Filter kwargs down to what `model_cls.__init__` explicitly accepts. |
| `resolve_sdpa_method` | Resolve the SDPA backend list from config strings or runtime constraints. |
Data#
API#
- nemo_automodel._transformers.model_init.logger#
‘getLogger(…)’
- nemo_automodel._transformers.model_init._hf_meta_device_disabled#
‘local(…)’
- nemo_automodel._transformers.model_init._get_hf_meta_device_disabled()#
- nemo_automodel._transformers.model_init.no_hf_meta_device()#
Disable HuggingFace’s meta device in get_init_context so model is built on real device.
- nemo_automodel._transformers.model_init._filter_meta_device_from_init_context(contexts)#
Remove torch.device(‘meta’) from HF init context list when we want real-device init.
- nemo_automodel._transformers.model_init._patched_get_init_context(cls, *args, **kwargs)#
Wrapper around PreTrainedModel.get_init_context that strips meta device when requested.
- nemo_automodel._transformers.model_init._original_get_init_context#
None
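Taken together, these entries implement an opt-out for HF’s meta-device initialization. A minimal sketch of how such a mechanism could be wired; how the patch is actually installed onto `PreTrainedModel.get_init_context` is omitted, and the details below are illustrative, not the real implementation:

```python
import contextlib
import threading

import torch

_hf_meta_device_disabled = threading.local()


def _get_hf_meta_device_disabled() -> bool:
    # Thread-local flag set by no_hf_meta_device().
    return getattr(_hf_meta_device_disabled, "value", False)


@contextlib.contextmanager
def no_hf_meta_device():
    # Force real-device init for any HF code path that consults get_init_context.
    _hf_meta_device_disabled.value = True
    try:
        yield
    finally:
        _hf_meta_device_disabled.value = False


def _filter_meta_device_from_init_context(contexts):
    # Drop any torch.device("meta") context manager from the list HF assembled.
    return [c for c in contexts if not (isinstance(c, torch.device) and c.type == "meta")]


_original_get_init_context = None  # filled in when the patch is applied


def _patched_get_init_context(cls, *args, **kwargs):
    contexts = _original_get_init_context(cls, *args, **kwargs)
    if _get_hf_meta_device_disabled():
        contexts = _filter_meta_device_from_init_context(contexts)
    return contexts
```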
- nemo_automodel._transformers.model_init._get_mixin_wrapped_class(model_class: type) type#
Get a class that combines HFCheckpointingMixin with the original model class.
If the class already has the mixin, returns it unchanged.
- Parameters:
model_class – The original model class (e.g., LlamaForCausalLM)
- Returns:
A class that inherits from both HFCheckpointingMixin and model_class
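A hedged sketch of how such a wrapper could be built dynamically; the generated class name and the import location of `HFCheckpointingMixin` are assumptions:

```python
# Import path is an assumption; HFCheckpointingMixin is defined elsewhere in nemo_automodel.
from nemo_automodel.components.checkpoint.checkpointing import HFCheckpointingMixin


def _get_mixin_wrapped_class(model_class: type) -> type:
    # Already wrapped: return unchanged, as documented above.
    if issubclass(model_class, HFCheckpointingMixin):
        return model_class
    # Put the mixin first in the MRO so its checkpointing hooks take precedence.
    return type(f"NeMoAutoModel{model_class.__name__}", (HFCheckpointingMixin, model_class), {})
```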
- nemo_automodel._transformers.model_init.local_torch_dtype(
- dtype: torch.dtype,
- model_class_name: str | None = None,
- default_dtype: torch.dtype = torch.bfloat16,
Locally change the torch default dtype to `dtype`, and restore the old one upon exiting the context. If `model_class_name` is provided, it is used to give a more helpful error message if `dtype` is not valid.
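A minimal sketch of such a context manager; the validation and error-message behavior shown here (a plain `TypeError`) is assumed, not copied from the implementation:

```python
import contextlib

import torch


@contextlib.contextmanager
def local_torch_dtype(dtype, model_class_name=None, default_dtype=torch.bfloat16):
    if not isinstance(dtype, torch.dtype):
        hint = f" for {model_class_name}" if model_class_name else ""
        raise TypeError(f"Invalid dtype {dtype!r}{hint}; expected a torch.dtype such as {default_dtype}.")
    previous = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(previous)


# Usage: parameters created inside the block default to the requested dtype.
with local_torch_dtype(torch.float16):
    layer = torch.nn.Linear(8, 8)  # weight and bias allocated in float16
```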
- nemo_automodel._transformers.model_init._is_config_compatible_with_custom_model(
- arch_name: str,
- config,
Check if a HuggingFace config is compatible with our custom model implementation.
Some architectures (e.g., NemotronHForCausalLM) are shared between different model versions (v2 vs v3) but our custom implementation only supports specific versions. This function validates that the config has the required attributes for the custom implementation.
- Parameters:
arch_name – The architecture name (e.g., “NemotronHForCausalLM”)
config – The HuggingFace config object
- Returns:
True if the config is compatible with our custom implementation, False otherwise
- nemo_automodel._transformers.model_init._resolve_custom_model_cls_for_config(config)#
Resolve the custom model class for config, if the config is compatible.
- nemo_automodel._transformers.model_init.get_hf_config(
- pretrained_model_name_or_path,
- attn_implementation,
- **kwargs,
Get the HF config for the model.
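Typical usage; the model id and extra kwargs below are illustrative, and forwarding `trust_remote_code` through `**kwargs` is an assumption:

```python
from nemo_automodel._transformers.model_init import get_hf_config

config = get_hf_config(
    "meta-llama/Llama-3.2-1B",            # repo id or local path (illustrative)
    attn_implementation="flash_attention_2",
    trust_remote_code=False,              # assumed to be forwarded via **kwargs
)
print(config.architectures)
```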
- nemo_automodel._transformers.model_init._load_config_with_layer_types_fix(
- pretrained_model_name_or_path,
- attn_implementation,
- trust_remote_code,
- **kwargs,
Load an HF config after truncating `layer_types` to `num_hidden_layers`.
Works around buggy upstream configs whose `layer_types` list is longer than `num_hidden_layers` (e.g. stepfun-ai/Step-3.5-Flash).
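The core of the workaround can be sketched as follows; how the real helper applies `attn_implementation` is not shown here:

```python
from transformers import AutoConfig


def _load_config_with_layer_types_fix(pretrained_model_name_or_path, trust_remote_code=False, **kwargs):
    config = AutoConfig.from_pretrained(
        pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
    )
    layer_types = getattr(config, "layer_types", None)
    num_layers = getattr(config, "num_hidden_layers", None)
    if layer_types is not None and num_layers is not None and len(layer_types) > num_layers:
        # Buggy upstream configs ship more per-layer entries than hidden layers; truncate.
        config.layer_types = list(layer_types)[:num_layers]
    return config
```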
- nemo_automodel._transformers.model_init.get_is_hf_model(config, force_hf)#
Determine whether the model should use the HF (not custom) implementation.
- nemo_automodel._transformers.model_init._download_model_weights(hf_config, pretrained_model_name_or_path)#
- nemo_automodel._transformers.model_init._setup_bnb_loading_kwargs(kwargs: dict) None#
Configure kwargs for HF from_pretrained to work with BitsAndBytes quantization.
Sets `device_map` so HF loads and quantizes per-shard on the current GPU, and disables the async weight loader introduced in transformers v5, which can materialize many full-precision tensors concurrently before the quantizer runs, causing OOM on memory-constrained systems.
- nemo_automodel._transformers.model_init._resolve_model_dir(pretrained_model_name_or_path: str) str#
Resolve a HF repo id or local path to a local directory with model files.
- nemo_automodel._transformers.model_init._has_safetensors(model_dir: str) bool#
Check whether a model directory contains safetensors checkpoint files.
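A combined sketch of these two helpers, assuming `huggingface_hub.snapshot_download` is used for remote repos; the allow-pattern list is illustrative:

```python
import glob
import os

from huggingface_hub import snapshot_download


def _resolve_model_dir(pretrained_model_name_or_path: str) -> str:
    # Local directories are used as-is; anything else is treated as a hub repo id.
    if os.path.isdir(pretrained_model_name_or_path):
        return pretrained_model_name_or_path
    return snapshot_download(
        pretrained_model_name_or_path,
        allow_patterns=["*.safetensors", "*.json"],  # illustrative filter
    )


def _has_safetensors(model_dir: str) -> bool:
    return bool(glob.glob(os.path.join(model_dir, "*.safetensors")))
```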
- nemo_automodel._transformers.model_init._stream_load_bnb_weights(model, model_dir, device, torch_dtype)#
Load safetensor shards one-at-a-time, quantizing BnB Params4bit on the fly.
Peak memory ≈ (accumulated quantized weights) + (one bf16 weight tensor) instead of (full bf16 model) with standard HF loading.
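A condensed sketch of the shard-streaming idea; quantization settings, buffers, tied weights, and error handling are omitted, and it assumes that moving a bitsandbytes `Params4bit` onto the GPU triggers quantization:

```python
import glob
import os

import bitsandbytes as bnb
import torch
from safetensors import safe_open


def _stream_load_bnb_weights(model, model_dir, device, torch_dtype):
    params = dict(model.named_parameters())
    for shard in sorted(glob.glob(os.path.join(model_dir, "*.safetensors"))):
        with safe_open(shard, framework="pt", device="cpu") as f:
            for key in f.keys():
                if key not in params:
                    continue
                tensor = f.get_tensor(key).to(torch_dtype)
                module_path, _, leaf = key.rpartition(".")
                module = model.get_submodule(module_path)
                if isinstance(params[key], bnb.nn.Params4bit):
                    # Moving a Params4bit to the GPU quantizes it, so only one
                    # full-precision weight tensor is resident at a time.
                    # quant_type is illustrative; the real helper derives it from quantization_config.
                    new_param = bnb.nn.Params4bit(tensor, requires_grad=False, quant_type="nf4").to(device)
                else:
                    new_param = torch.nn.Parameter(tensor.to(device), requires_grad=False)
                module.register_parameter(leaf, new_param)
```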
- nemo_automodel._transformers.model_init._streaming_bnb_supported(cls, hf_config) bool#
Whether streaming BnB can safely load HF safetensors directly into the target class.
The streaming loader maps safetensors keys 1:1 onto `model.named_parameters()`. Two cases break that 1:1 assumption and must fall back to the standard HF loader:
- Automodel’s custom implementations fuse projections (e.g. MoE `mlp.experts.gate_up_proj`) and rely on a `state_dict_adapter` to translate HF-style keys on load. Detected via the `HFCheckpointingMixin` marker.
- Vanilla HF classes whose safetensors use a legacy layout that HF’s loader reshapes/renames at load time (e.g. Mixtral `block_sparse_moe.experts.*.w1` → fused `mlp.experts.gate_up_proj`). Detected via HF’s per-model-type `get_checkpoint_conversion_mapping`; any non-empty mapping means the streaming path would leave fused tensors on the meta device.
- nemo_automodel._transformers.model_init._init_model_bnb_streaming(
- cls,
- pretrained_model_name_or_path,
- hf_config,
- attn_implementation,
- torch_dtype,
- quantization_config,
- **kwargs,
Create model on meta device, replace Linear→Linear4bit, stream-load+quantize.
This avoids materializing the full bf16 model in memory, which is critical for unified-memory systems (e.g. DGX Spark) where CPU and GPU share the same physical memory pool.
Returns `(is_custom_model=False, model)` so the caller treats it like an HF-loaded model with weights already present.
- nemo_automodel._transformers.model_init._get_model_tensor(model, name: str)#
Return a parameter or buffer by its fully-qualified state-dict key.
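A small sketch, assuming state-dict keys follow the usual `module.path.param_name` layout:

```python
def _get_model_tensor(model, name: str):
    module_path, _, leaf = name.rpartition(".")
    owner = model.get_submodule(module_path) if module_path else model
    tensor = getattr(owner, leaf, None)
    if tensor is None:
        raise KeyError(f"{name!r} is neither a parameter nor a buffer of {type(model).__name__}")
    return tensor
```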
- nemo_automodel._transformers.model_init._restore_loaded_model_dtype(
- model,
- pretrained_model_name_or_path,
- hf_config,
- quantization_config,
- load_kwargs,
Restore each loaded tensor to the exact dtype stored in the checkpoint.
Some modules allocate parameters in a wider dtype than the checkpoint. HuggingFace then copies the checkpoint tensor into that existing tensor, which upcasts the loaded value. We fix that by re-inspecting checkpoint tensor dtypes per key and restoring each loaded parameter/buffer to the dtype that was actually stored in the file.
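A sketch of the fix-up pass with a simplified signature (the real helper also receives the config, quantization config, and load kwargs). It re-reads each shard to learn the stored dtype, which the actual implementation may avoid, and it reuses the `_get_model_tensor` and `_resolve_model_dir` helpers documented in this module:

```python
import glob
import os

from safetensors import safe_open


def _restore_loaded_model_dtype_sketch(model, pretrained_model_name_or_path):
    model_dir = _resolve_model_dir(pretrained_model_name_or_path)
    for shard in sorted(glob.glob(os.path.join(model_dir, "*.safetensors"))):
        with safe_open(shard, framework="pt", device="cpu") as f:
            for key in f.keys():
                try:
                    target = _get_model_tensor(model, key)
                except (KeyError, AttributeError):
                    continue  # key renamed or fused by a state-dict adapter
                stored = f.get_tensor(key)
                if target.is_floating_point() and target.dtype != stored.dtype:
                    # HF copied the checkpoint tensor into a wider pre-allocated tensor;
                    # cast it back to the dtype actually stored in the file.
                    target.data = target.data.to(stored.dtype)
```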
- nemo_automodel._transformers.model_init.__init_model(
- cls,
- pretrained_model_name_or_path_or_config,
- attn_implementation,
- torch_dtype,
- quantization_config,
- force_hf,
- *model_args,
- **kwargs,
- nemo_automodel._transformers.model_init._tie_weights_nemo(model)#
- nemo_automodel._transformers.model_init._init_model(
- cls,
- pretrained_model_name_or_path_or_config,
- attn_implementation,
- torch_dtype,
- quantization_config,
- force_hf,
- *model_args,
- **kwargs,
- nemo_automodel._transformers.model_init.get_architectures(hf_config)#
Get the architectures from the HF config.
- nemo_automodel._transformers.model_init._get_init_param_names(model_cls) set[str]#
Best-effort extraction of explicit init parameter names (excluding `self`).
Returns an empty set if the signature cannot be inspected.
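A plausible sketch using `inspect`; treating `*args`/`**kwargs` as non-explicit is an assumption consistent with the description above:

```python
import inspect


def _get_init_param_names(model_cls) -> set[str]:
    try:
        params = inspect.signature(model_cls.__init__).parameters
    except (TypeError, ValueError):
        return set()
    return {
        name
        for name, p in params.items()
        if name != "self" and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
    }
```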
- nemo_automodel._transformers.model_init._consume_config_overrides(
- config,
- kwargs: dict,
- *,
- init_param_names: set[str] | None = None,
Mimic HF from_pretrained behavior: treat config-related kwargs as config overrides, not model init kwargs.
For custom model implementations we instantiate via `model_cls(config, **kwargs)`, so passing config flags like `output_hidden_states` would crash. This helper moves such keys onto the config and removes them from `kwargs`.
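A minimal sketch of that move; the exact precedence rules (e.g. what happens when a key is both a constructor parameter and a config attribute) are assumptions:

```python
def _consume_config_overrides(config, kwargs: dict, *, init_param_names: set[str] | None = None) -> None:
    init_param_names = init_param_names or set()
    for key in list(kwargs):
        # Keys the constructor does not explicitly accept, but the config knows about,
        # are treated as config overrides (e.g. output_hidden_states=True).
        if key not in init_param_names and hasattr(config, key):
            setattr(config, key, kwargs.pop(key))
```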
- nemo_automodel._transformers.model_init._filter_kwargs_for_init(model_cls, kwargs: dict) dict#
Filter kwargs down to what `model_cls.__init__` explicitly accepts.
If the constructor has a `**kwargs` parameter (VAR_KEYWORD) or the signature cannot be inspected, returns kwargs unchanged.
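A direct sketch of that rule:

```python
import inspect


def _filter_kwargs_for_init(model_cls, kwargs: dict) -> dict:
    try:
        params = inspect.signature(model_cls.__init__).parameters
    except (TypeError, ValueError):
        return kwargs
    # A **kwargs catch-all means the constructor accepts anything: pass through unchanged.
    if any(p.kind is p.VAR_KEYWORD for p in params.values()):
        return kwargs
    return {k: v for k, v in kwargs.items() if k in params}
```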
- nemo_automodel._transformers.model_init.resolve_sdpa_method(
- sdpa_method: list | None = None,
- device_mesh=None,
- activation_checkpointing: bool = False,
Resolve SDPA backend list from config strings or runtime constraints.
When `sdpa_method` is provided (e.g. from YAML), string values are converted to `torch.nn.attention.SDPBackend` enum members. Already-resolved `SDPBackend` values are passed through unchanged. When `None`, automatic defaults are applied based on context parallelism and activation checkpointing settings. Valid string values (case-insensitive): `flash_attention`, `efficient_attention`, `math`, `cudnn_attention`.
- Parameters:
sdpa_method – List of backend name strings or SDPBackend enum values, or `None` to use automatic defaults.
device_mesh – Device mesh for distributed training.
activation_checkpointing – Whether activation checkpointing is enabled.
- Returns:
Ordered list of `SDPBackend` members, or `None` to use PyTorch’s default selection.
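Example usage, assuming the string-to-enum mapping described above; the resolved list can then be handed to `torch.nn.attention.sdpa_kernel` (the flash backend below requires a CUDA device):

```python
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

from nemo_automodel._transformers.model_init import resolve_sdpa_method

# Strings (e.g. from YAML) resolve to SDPBackend members, per the docstring above:
backends = resolve_sdpa_method(["flash_attention", "math"])
# expected: [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH]

q = k = v = torch.randn(1, 4, 16, 32, device="cuda", dtype=torch.bfloat16)
with sdpa_kernel(backends):
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
```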