nemo_automodel.components.utils.model_utils


Module Contents

Functions

Name | Description
_freeze_module_by_attribute_and_patterns | Helper function to freeze parameters by attribute name and name patterns.
_get_forward_signature | Best-effort retrieval of model.forward signature.
_get_logical_numel | Return the logical number of elements for a parameter, accounting for quantized (packed) storage.
_get_model_param_stats | Get the number of trainable parameters and the L2 norm of the model.
_supports_logits_to_keep | Check if the model supports logits_to_keep.
_supports_seq_lens | Check if the model’s forward() accepts seq_lens.
apply_parameter_freezing | Apply parameter freezing based on configuration.
cast_mixed_dtype_params_to_bf16 | Cast fp32 parameters and buffers to bf16 for FSDP2 compatibility.
count_model_parameters | Count total and trainable parameters. Safe to call on meta-device models.
enable_radio_vit_fused_attn | Route RADIO ViT attention through F.scaled_dot_product_attention.
filter_forward_kwargs | Drop kwargs that model.forward does not accept.
freeze_deepseek_v4_indexer_params | Freeze DeepSeek V4 indexer params that only feed discrete top-k masks.
freeze_minimax_m3_indexer_params | Freeze MiniMax M3 lightning-indexer params that only feed discrete top-k masks.
freeze_unused_kv_sharing_params | Freeze dead K/V parameters in KV-shared layers.
get_lm_head_module | Return the model’s LM head module, if one can be found.
get_lm_head_weight | Return the model’s LM-head weight, materializing DTensor weights when needed.
init_empty_weights | A context manager under which models are initialized with all parameters on the specified device.
print_trainable_parameters | Print the number of trainable parameters in the model.
resolve_trust_remote_code | Whitelist NVIDIA models to allow remote code execution.
skip_random_init | Context manager to skip random weight initialization when loading pretrained models.
squeeze_input_for_thd | Squeeze batch dimension and prepare inputs for THD (total tokens, heads, head dim) format.

Data

VLM_INPUT_KEYS

logger

API

nemo_automodel.components.utils.model_utils._freeze_module_by_attribute_and_patterns(
model,
attribute_name,
name_patterns
)

Helper function to freeze parameters by attribute name and name patterns.

Parameters:

model

The model to apply freezing to.

attribute_name

Name of the model attribute to freeze (e.g., ‘vision_tower’).

name_patterns

List of patterns to match in module names.
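A minimal sketch of this kind of freezing, assuming the helper disables requires_grad on every parameter under getattr(model, attribute_name) and on any module whose qualified name contains one of name_patterns. The function name freeze_by_attribute_and_patterns_sketch and the toy model are hypothetical; the library's actual matching rules may differ.

```python
import torch.nn as nn


def freeze_by_attribute_and_patterns_sketch(model, attribute_name, name_patterns):
    """Hypothetical sketch: freeze params under an attribute and under pattern-matched modules."""
    # Freeze the whole sub-module if the model exposes it under attribute_name.
    submodule = getattr(model, attribute_name, None)
    if isinstance(submodule, nn.Module):
        for param in submodule.parameters():
            param.requires_grad = False
    # Also freeze any module whose qualified name contains one of the patterns.
    for module_name, module in model.named_modules():
        if any(pattern in module_name for pattern in name_patterns):
            for param in module.parameters():
                param.requires_grad = False


class ToyVLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Linear(4, 4)
        self.projector = nn.Linear(4, 4)
        self.language_model = nn.Linear(4, 4)


model = ToyVLM()
freeze_by_attribute_and_patterns_sketch(model, "vision_tower", ["projector"])
frozen = sorted(n for n, p in model.named_parameters() if not p.requires_grad)
```

Freezing should happen before the optimizer is built, so frozen parameters never enter its parameter groups.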

nemo_automodel.components.utils.model_utils._get_forward_signature(
model: torch.nn.Module
) -> inspect.Signature | None

Best-effort retrieval of model.forward signature.

nemo_automodel.components.utils.model_utils._get_logical_numel(
param
) -> int

Return the logical number of elements for a parameter, accounting for quantized (packed) storage.

For bitsandbytes 4-bit params (Params4bit), the physical tensor packs multiple values per byte. We recover the logical count from the original shape stored in param.quant_state.
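The arithmetic behind this is that 4-bit packing stores two values per byte, so the physical element count is half the logical one. A minimal sketch, assuming only that a quantized parameter carries quant_state.shape with the original shape; FakePacked and FakeDense are hypothetical stand-ins, not bitsandbytes classes.

```python
import math
from types import SimpleNamespace


def logical_numel_sketch(param):
    """Hypothetical sketch: prefer the pre-quantization shape when a quant_state is attached."""
    quant_state = getattr(param, "quant_state", None)
    if quant_state is not None and getattr(quant_state, "shape", None) is not None:
        # Logical element count = product of the original (unpacked) shape.
        return math.prod(quant_state.shape)
    # Unquantized tensors: physical and logical counts agree.
    return param.numel()


class FakePacked:
    """Stand-in for a 4-bit param: two 4-bit values packed per uint8 byte."""
    def __init__(self, shape):
        self.quant_state = SimpleNamespace(shape=shape)
        self._physical = math.prod(shape) // 2

    def numel(self):
        return self._physical


class FakeDense:
    """Stand-in for an ordinary tensor with 12 elements."""
    def numel(self):
        return 12


packed = FakePacked((4096, 4096))
logical = logical_numel_sketch(packed)  # 16777216, while packed.numel() is 8388608
```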

nemo_automodel.components.utils.model_utils._get_model_param_stats(
model: torch.nn.Module
) -> tuple[int, int, float]

Get the number of trainable parameters and the L2 norm of the model.

Parameters:

model
nn.Module

Model to analyze.

Returns: tuple[int, int, float]

Two parameter counts (int) followed by the L2 norm of the model (float).
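How the L2 norm is computed is not specified above. A common definition is the global norm over all parameters, the square root of the sum of squared entries. The sketch below assumes that definition (the real helper may restrict it to trainable parameters); param_l2_norm_sketch is a hypothetical name.

```python
import math

import torch
import torch.nn as nn


def param_l2_norm_sketch(model: nn.Module) -> float:
    """Hypothetical sketch: global L2 norm over all parameters."""
    squared_sum = 0.0
    for param in model.parameters():
        # Accumulate in fp32 so low-precision params do not lose precision.
        squared_sum += param.detach().float().pow(2).sum().item()
    return math.sqrt(squared_sum)


layer = nn.Linear(2, 1)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[3.0, 0.0]]))
    layer.bias.fill_(4.0)

norm = param_l2_norm_sketch(layer)  # sqrt(3^2 + 0^2 + 4^2) = 5.0
```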

nemo_automodel.components.utils.model_utils._supports_logits_to_keep(
model: torch.nn.Module
) -> bool

Check if the model supports logits_to_keep.

Parameters:

model
nn.Module

The model to check.

Returns: bool

True if the model supports logits_to_keep, False otherwise.

nemo_automodel.components.utils.model_utils._supports_seq_lens(
model: torch.nn.Module
) -> bool

Check if the model’s forward() accepts seq_lens.

Returns True if:

  • forward() has an explicit seq_lens parameter, OR
  • forward() has **kwargs (so it won’t crash if seq_lens is passed)

Returns False otherwise (passing seq_lens would raise an “unexpected keyword argument” error).
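The two rules above can be sketched with the standard library's inspect module. This is a minimal sketch, not the library's implementation: supports_seq_lens_sketch and the toy classes are hypothetical, and the sketch's answer for an uninspectable signature (False) is an assumption.

```python
import inspect


def supports_seq_lens_sketch(model) -> bool:
    """Hypothetical sketch: explicit seq_lens parameter or a **kwargs catch-all."""
    try:
        params = inspect.signature(model.forward).parameters.values()
    except (TypeError, ValueError):
        # Uninspectable signature; this sketch conservatively answers False.
        return False
    return any(
        p.name == "seq_lens" or p.kind is inspect.Parameter.VAR_KEYWORD
        for p in params
    )


class ExplicitModel:
    def forward(self, input_ids, seq_lens=None): ...


class KwargsModel:
    def forward(self, input_ids, **kwargs): ...


class StrictModel:
    def forward(self, input_ids, attention_mask=None): ...
```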

nemo_automodel.components.utils.model_utils.apply_parameter_freezing(
model,
freeze_config
)

Apply parameter freezing based on configuration.

Parameters:

model

The model to apply freezing to.

freeze_config

Configuration dict specifying what to freeze.

nemo_automodel.components.utils.model_utils.cast_mixed_dtype_params_to_bf16(
model
)

Cast fp32 parameters and buffers to bf16 for FSDP2 compatibility.
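A minimal sketch of the cast, assuming it touches only fp32 tensors and leaves bf16 and integer tensors alone so the model ends up with a single floating dtype. The function name cast_fp32_to_bf16_sketch is hypothetical; the real helper may skip specific modules that this sketch does not know about.

```python
import torch
import torch.nn as nn


def cast_fp32_to_bf16_sketch(model: nn.Module) -> nn.Module:
    """Hypothetical sketch: cast fp32 params and buffers to bf16 in place."""
    for tensor in list(model.parameters()) + list(model.buffers()):
        # Only fp32 is touched; bf16 tensors and integer buffers are left as-is.
        if tensor.dtype == torch.float32:
            tensor.data = tensor.data.to(torch.bfloat16)
    return model


model = nn.Sequential(nn.Linear(4, 4, dtype=torch.bfloat16), nn.LayerNorm(4))
model.register_buffer("scale", torch.ones(4))
model.register_buffer("step", torch.zeros((), dtype=torch.long))
cast_fp32_to_bf16_sketch(model)
```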

nemo_automodel.components.utils.model_utils.count_model_parameters(
model: torch.nn.Module
) -> tuple[int, int]

Count total and trainable parameters. Safe to call on meta-device models.

Parameters:

model
nn.Module

Model to analyze.

Returns: tuple[int, int]

The total and trainable parameter counts.
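Counting is safe on meta-device models because numel() reads only shape metadata and never touches storage. A minimal sketch under that observation; count_params_sketch is a hypothetical name and, unlike the real helper may, it does not account for quantized (packed) storage.

```python
import torch.nn as nn


def count_params_sketch(model: nn.Module) -> tuple[int, int]:
    """Hypothetical sketch: (total, trainable) parameter counts."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable


# Meta tensors have shapes but no data, so nothing is allocated here.
model = nn.Sequential(nn.Linear(16, 32, device="meta"), nn.Linear(32, 4, device="meta"))
model[1].weight.requires_grad_(False)
total, trainable = count_params_sketch(model)  # (676, 548)
```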

nemo_automodel.components.utils.model_utils.enable_radio_vit_fused_attn(
model
)

Route RADIO ViT attention through F.scaled_dot_product_attention.

RADIO’s timm Attention blocks default to fused_attn=False, which materializes the full (B, H, seq, seq) attention tensor (~5 GiB per block for RADIO-v2-H with dynamic-resolution patch counts). Setting fused_attn=True matches the Megatron-Bridge path, which sets vision_config.use_flash_attn=True via attn_implementation="flash_attention_2".

No-op when the model has no RADIO vision tower.

Parameters:

model

The model to patch in place.
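In recent timm versions, Attention.forward checks the fused_attn flag at call time, so flipping it on existing modules is enough. A minimal sketch, assuming the patch is a plain attribute flip; unlike the real helper, which targets only the RADIO vision tower, this hypothetical enable_fused_attn_sketch walks every submodule, and ToyTimmAttention is a stand-in for timm's class.

```python
import torch.nn as nn


def enable_fused_attn_sketch(model: nn.Module) -> int:
    """Hypothetical sketch: flip fused_attn=False to True on every module exposing the flag."""
    patched = 0
    for module in model.modules():
        if getattr(module, "fused_attn", None) is False:
            module.fused_attn = True
            patched += 1
    return patched  # 0 means the call was a no-op


class ToyTimmAttention(nn.Module):
    """Stand-in for a timm Attention block; only the flag matters here."""
    def __init__(self):
        super().__init__()
        self.fused_attn = False


vision_tower = nn.Sequential(ToyTimmAttention(), nn.Linear(8, 8), ToyTimmAttention())
num_patched = enable_fused_attn_sketch(vision_tower)
```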

nemo_automodel.components.utils.model_utils.filter_forward_kwargs(
model: torch.nn.Module,
kwargs: dict
) -> dict

Drop kwargs that model.forward does not accept.

If the model exposes **kwargs or its signature cannot be inspected, the input kwargs are returned unchanged. The original dict is never mutated.
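The contract above (pass everything through for **kwargs or an uninspectable signature, otherwise keep only named parameters, never mutate the input) can be sketched with inspect. filter_forward_kwargs_sketch and TextModel are hypothetical; the sketch always returns a copy, which satisfies the no-mutation guarantee.

```python
import inspect


def filter_forward_kwargs_sketch(model, kwargs: dict) -> dict:
    """Hypothetical sketch: keep only kwargs that model.forward can accept."""
    try:
        params = inspect.signature(model.forward).parameters
    except (TypeError, ValueError):
        return dict(kwargs)  # cannot inspect: pass everything through as a copy
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)  # **kwargs accepts anything
    return {k: v for k, v in kwargs.items() if k in params}


class TextModel:
    def forward(self, input_ids, attention_mask=None):
        return input_ids


batch = {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "pixel_values": None}
filtered = filter_forward_kwargs_sketch(TextModel(), batch)
```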

nemo_automodel.components.utils.model_utils.freeze_deepseek_v4_indexer_params(
model
)

Freeze DeepSeek V4 indexer params that only feed discrete top-k masks.

nemo_automodel.components.utils.model_utils.freeze_minimax_m3_indexer_params(
model
)

Freeze MiniMax M3 lightning-indexer params that only feed discrete top-k masks.

nemo_automodel.components.utils.model_utils.freeze_unused_kv_sharing_params(
model
)

Freeze dead K/V parameters in KV-shared layers.

Models like Gemma4 E2B/E4B use KV sharing, where the last N layers reuse key/value states from earlier layers. The k_proj, v_proj, k_norm, and v_norm modules still exist in those shared layers but are never used in the forward pass, so their parameters receive no gradients even though the optimizer still tracks them. On checkpoint resume, the distributed checkpoint framework expects optimizer state for every parameter the optimizer was created with, but zero-gradient parameters may have been excluded from the saved state, which causes a RuntimeError.

Calling this function before optimizer creation sets requires_grad=False on the dead parameters so the optimizer never tracks them, keeping save and load consistent.

Parameters:

model

The model (or pipeline-parallel model part).
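A minimal sketch of the freeze-before-optimizer pattern, assuming a flat list of attention layers and a known count of KV-shared trailing layers. The layer layout, ToyAttention, and the num_kv_shared_layers argument are hypothetical; the real helper discovers the shared layers from the model itself.

```python
import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.k_norm = nn.LayerNorm(dim)
        self.v_norm = nn.LayerNorm(dim)


def freeze_kv_sharing_sketch(layers, num_kv_shared_layers):
    """Hypothetical sketch: freeze K/V modules in the last num_kv_shared_layers layers."""
    frozen = []
    for idx in range(len(layers) - num_kv_shared_layers, len(layers)):
        for name in ("k_proj", "v_proj", "k_norm", "v_norm"):
            module = getattr(layers[idx], name, None)
            if module is None:
                continue
            for param in module.parameters():
                param.requires_grad = False
            frozen.append(f"{idx}.{name}")
    return frozen


layers = nn.ModuleList(ToyAttention() for _ in range(4))
frozen = freeze_kv_sharing_sketch(layers, num_kv_shared_layers=2)
# Build the optimizer only after freezing, so dead params are never tracked.
trainable = [p for p in layers.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-3)
```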

nemo_automodel.components.utils.model_utils.get_lm_head_module(
model: torch.nn.Module
) -> torch.nn.Module | None

Return the model’s LM head module, if one can be found.

nemo_automodel.components.utils.model_utils.get_lm_head_weight(
model: torch.nn.Module
) -> torch.Tensor

Return the model’s LM-head weight, materializing DTensor weights when needed.

nemo_automodel.components.utils.model_utils.init_empty_weights()

A context manager under which models are initialized with all parameters on the specified device.

Example:

import torch.nn as nn
from nemo_automodel.components.utils.model_utils import init_empty_weights

with init_empty_weights():
    tst = nn.Linear(100, 100)  # parameters are created on the context manager's target device

Parameters:

device
torch.device

Device to initialize all parameters on.

nemo_automodel.components.utils.model_utils.print_trainable_parameters(
model: torch.nn.Module,
name: str = 'Model'
) -> tuple[int, int]

Print the number of trainable parameters in the model.

Parameters:

model
nn.Module

Model to analyze.

name
str, defaults to 'Model'

Label for the summary header (e.g. "Draft" to distinguish the draft model from the target in speculative-decoding training).

Returns: tuple[int, int]

The parameter counts reported in the printed summary.

nemo_automodel.components.utils.model_utils.resolve_trust_remote_code(
pretrained_model_name_or_path
)

Whitelist NVIDIA models to allow remote code execution.

Parameters:

pretrained_model_name_or_path
str

The name or path of the pretrained model.

Returns:

True if the model should be loaded with trust_remote_code, False otherwise.
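The whitelist itself is not documented here. A minimal sketch, assuming (hypothetically) that "NVIDIA models" means Hugging Face repo IDs under the nvidia/ organization; resolve_trust_remote_code_sketch and the repo IDs below are illustrative, and the real rule may differ.

```python
def resolve_trust_remote_code_sketch(pretrained_model_name_or_path) -> bool:
    """Hypothetical whitelist rule: trust only repo IDs under the nvidia/ org."""
    name = str(pretrained_model_name_or_path)
    # Assumption: the whitelist is a case-insensitive org-prefix check.
    return name.lower().startswith("nvidia/")


decisions = {
    repo: resolve_trust_remote_code_sketch(repo)
    for repo in ("nvidia/example-model", "some-org/example-model")
}
```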

nemo_automodel.components.utils.model_utils.skip_random_init()

Context manager to skip random weight initialization when loading pretrained models.
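One common way to skip random initialization is to temporarily replace the torch.nn.init fill functions with no-ops, since modules such as nn.Linear call them through the torch.nn.init module in reset_parameters(). The sketch below assumes that technique; skip_random_init_sketch is hypothetical and is not necessarily how this helper works. PyTorch also offers torch.nn.utils.skip_init for constructing a single module without initialization.

```python
import contextlib

import torch.nn as nn

# Random fills only; deterministic fills such as zeros_/ones_ are left alone.
_RANDOM_INITS = ("uniform_", "normal_", "kaiming_uniform_", "kaiming_normal_",
                 "xavier_uniform_", "xavier_normal_", "trunc_normal_")


@contextlib.contextmanager
def skip_random_init_sketch():
    """Hypothetical sketch: turn random torch.nn.init fills into no-ops."""
    originals = {name: getattr(nn.init, name) for name in _RANDOM_INITS}

    def _noop(tensor, *args, **kwargs):
        return tensor  # leave the uninitialized storage untouched

    try:
        for name in _RANDOM_INITS:
            setattr(nn.init, name, _noop)
        yield
    finally:
        # Always restore the real initializers, even if construction fails.
        for name, fn in originals.items():
            setattr(nn.init, name, fn)


reference = nn.init.kaiming_uniform_
with skip_random_init_sketch():
    patched_during = nn.init.kaiming_uniform_ is not reference
    layer = nn.Linear(1024, 1024)  # weights stay uninitialized until loaded
restored = nn.init.kaiming_uniform_ is reference
```

This only intercepts calls made through the nn.init module attribute; code that imported an initializer directly by name is not affected.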

nemo_automodel.components.utils.model_utils.squeeze_input_for_thd(
input_ids,
position_ids,
padding_mask,
attn_kwargs,
seqlens_padding_value = -1000
)

Squeeze batch dimension and prepare inputs for Transformer Engine’s THD format (T = total tokens, H = attention heads, D = head dimension).

This function removes the batch dimension from input tensors and processes attention kwargs for use with Transformer Engine’s THD format. It’s typically used when the batch has already been converted to THD format (with batch_size=1 as a placeholder dimension) and that dimension needs to be removed.

The function performs three key operations:

  1. Removes the batch dimension (dim 0) from input tensors
  2. Filters out padding values from cumulative sequence length tensors
  3. Converts max_seqlen from tensor to scalar if needed
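The three steps above can be sketched with plain tensor operations. This is a minimal sketch under stated assumptions, not the library's code: squeeze_for_thd_sketch is hypothetical, it returns a new attn_kwargs dict rather than editing the caller's, and it treats only cu_seqlens and cu_seqlens_padded as sentinel-padded.

```python
import torch


def squeeze_for_thd_sketch(input_ids, position_ids, padding_mask, attn_kwargs,
                           seqlens_padding_value=-1000):
    """Hypothetical sketch of the three steps listed above."""
    # Step 1: drop the placeholder batch dimension (input_ids may be None).
    if input_ids is not None:
        input_ids = input_ids.squeeze(0)
    position_ids = position_ids.squeeze(0)
    padding_mask = padding_mask.squeeze(0)

    out = {}
    for key, value in attn_kwargs.items():
        if isinstance(value, torch.Tensor):
            value = value.squeeze(0)
            # Step 2: remove sentinel padding from the cumulative-length tensors.
            if key in ("cu_seqlens", "cu_seqlens_padded"):
                value = value[value != seqlens_padding_value]
            # Step 3: turn a one-element max_seqlen tensor into a Python int.
            elif key == "max_seqlen" and value.numel() == 1:
                value = int(value.item())
        out[key] = value
    return input_ids, position_ids, padding_mask, out


# Two packed sequences of lengths 3 and 4 behind a batch dimension of 1.
ids = torch.arange(7).unsqueeze(0)
pos = torch.tensor([[0, 1, 2, 0, 1, 2, 3]])
mask = torch.zeros(1, 7, dtype=torch.bool)
kwargs = {"cu_seqlens": torch.tensor([[0, 3, 7, -1000, -1000]]), "max_seqlen": torch.tensor([4])}
ids, pos, mask, kwargs = squeeze_for_thd_sketch(ids, pos, mask, kwargs)
```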

Parameters:

input_ids
torch.Tensor or None

Input token IDs with shape [1, total_tokens] or [1, total_tokens, hidden_dim]. The first dimension will be squeezed. None is permitted when the caller feeds the model via inputs_embeds instead. In that case the embeddings are squeezed inside the model’s forward (the squeezed_for_thd branch in NemotronHModel.forward and analogous code paths), so this helper has nothing to squeeze and returns None in the input_ids slot.

position_ids
torch.Tensor

Position IDs with shape [1, total_tokens]. The first dimension will be squeezed.

padding_mask
torch.Tensor

Padding mask with shape [1, total_tokens]. The first dimension will be squeezed.

attn_kwargs
dict

Dictionary of attention-related tensors. May contain:

  • cu_seqlens: Cumulative sequence lengths [1, num_seqs+1]
  • cu_seqlens_padded: Cumulative padded sequence lengths [1, num_seqs+1]
  • max_seqlen: Maximum sequence length (tensor or int)
  • Other attention parameters (squeezed if they are tensors)
seqlens_padding_value
int, defaults to -1000

Sentinel value marking padding entries in cu_seqlens and cu_seqlens_padded; these entries are filtered out.

Returns:

A tuple containing:

  • input_ids (torch.Tensor or None): Input IDs with batch dimension removed [total_tokens] or [total_tokens, hidden_dim]; None if None was passed in
  • position_ids (torch.Tensor): Position IDs with batch dimension removed [total_tokens]
  • padding_mask (torch.Tensor): Padding mask with batch dimension removed [total_tokens]
  • attn_kwargs (dict): Updated attention kwargs with:
    • Batch dimensions removed from all tensor values
    • Padding values filtered from cu_seqlens and cu_seqlens_padded
    • max_seqlen converted to scalar if it was a tensor
nemo_automodel.components.utils.model_utils.VLM_INPUT_KEYS: tuple[str, ...] = ('input_ids', 'pixel_values', 'image_flags', 'imgs_sizes', 'image_position_ids',...
nemo_automodel.components.utils.model_utils.logger = logging.getLogger(__name__)