nemo_automodel.components.checkpoint.utils
Module Contents
Functions
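| Function | Description |
| --- | --- |
| resolve_trust_remote_code | Whitelist NVIDIA models to allow remote code execution. |
| is_tied_word_embeddings | Check if the model’s word embeddings are tied. |
| _normalize_param_name | Strip wrapper-specific prefixes from a parameter name. |
| get_lm_head_weight_and_name | Return the first lm_head.weight parameter found on a model. |
| get_input_embeddings_weight_and_name | Return the input embedding weight and normalized name if present. |
| get_tied_lm_head_source_names | Return candidate checkpoint keys that can source a tied LM head. |
| has_local_tied_lm_head | Return whether the current model partition can locally satisfy a tied LM head. |
| materialize_missing_tied_lm_head | Populate a missing tied lm_head.weight from its embedding source. |
| _get_checkpoint_tensor_dtypes | Inspect checkpoint tensors and return their exact dtypes by key. |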
API
- nemo_automodel.components.checkpoint.utils.resolve_trust_remote_code(pretrained_model_name_or_path)
Whitelist NVIDIA models to allow remote code execution.
- Parameters:
pretrained_model_name_or_path (str) – The name or path of the pretrained model.
- Returns:
True if the model should be loaded with trust_remote_code, False otherwise.
- Return type:
bool
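The whitelist rule itself is not spelled out above. The sketch below assumes the check is a plain prefix match on Hugging Face Hub ids under the nvidia/ organization. It also assumes that local directories are never whitelisted. resolve_trust_remote_code_sketch is a hypothetical stand-in, not the library's implementation.

```python
import os
from typing import Optional


def resolve_trust_remote_code_sketch(pretrained_model_name_or_path: Optional[str]) -> bool:
    """Hypothetical whitelist: trust remote code only for Hub ids under "nvidia/".

    The prefix rule is an assumption made for illustration.
    """
    if not pretrained_model_name_or_path:
        return False
    # In this sketch a local directory is never whitelisted, even if its
    # relative path happens to start with "nvidia/".
    if os.path.isdir(pretrained_model_name_or_path):
        return False
    return pretrained_model_name_or_path.startswith("nvidia/")


print(resolve_trust_remote_code_sketch("nvidia/example-model"))  # True
print(resolve_trust_remote_code_sketch("some-org/example-model"))  # False
```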
- nemo_automodel.components.checkpoint.utils.is_tied_word_embeddings(model: torch.nn.Module) -> bool
Check if the model’s word embeddings are tied.
- Parameters:
model (nn.Module) – The model to check.
- Returns:
True if the model’s word embeddings are tied, False otherwise.
- Return type:
bool
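The sketch below assumes the check reads the Hugging Face-style tie_word_embeddings flag from model.config. The real function may also consult nested configs. The sketch relies on duck typing, so SimpleNamespace objects stand in for real models. is_tied_word_embeddings_sketch is hypothetical.

```python
from types import SimpleNamespace
from typing import Any


def is_tied_word_embeddings_sketch(model: Any) -> bool:
    """Return the config-level tie flag (assumed to live on model.config).

    A model without a config, or a config without the flag, counts as untied.
    """
    config = getattr(model, "config", None)
    return bool(getattr(config, "tie_word_embeddings", False))


tied = SimpleNamespace(config=SimpleNamespace(tie_word_embeddings=True))
untied = SimpleNamespace(config=SimpleNamespace(tie_word_embeddings=False))
print(is_tied_word_embeddings_sketch(tied), is_tied_word_embeddings_sketch(untied))  # True False
```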
- nemo_automodel.components.checkpoint.utils._normalize_param_name(name: str) -> str
Strip wrapper-specific prefixes from a parameter name.
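The exact prefixes handled are not listed above. This sketch assumes the usual PyTorch wrappers:

- DistributedDataParallel adds a leading module.
- torch.compile adds _orig_mod.
- Activation checkpointing adds _checkpoint_wrapped_module.
- FSDP1 adds _fsdp_wrapped_module.

Both normalize_param_name_sketch and that prefix set are assumptions.

```python
# Assumed wrapper path components (torch.compile, activation checkpointing,
# FSDP1); the set handled by _normalize_param_name may differ.
WRAPPER_COMPONENTS = {"_orig_mod", "_checkpoint_wrapped_module", "_fsdp_wrapped_module"}


def normalize_param_name_sketch(name: str) -> str:
    # DDP prepends a single "module." to every name. Strip it only at the
    # front, because "module" can also be a genuine submodule name deeper down.
    if name.startswith("module."):
        name = name[len("module."):]
    # The other wrappers can appear at any depth (e.g. per-layer checkpointing).
    return ".".join(part for part in name.split(".") if part not in WRAPPER_COMPONENTS)


print(normalize_param_name_sketch("_orig_mod.model.layers.0._checkpoint_wrapped_module.mlp.up_proj.weight"))
# model.layers.0.mlp.up_proj.weight
```

Removing the wrapper components wherever they occur handles wrappers applied per layer. Treating module. only as a leading prefix avoids mangling real submodule names.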
- nemo_automodel.components.checkpoint.utils.get_lm_head_weight_and_name(model: torch.nn.Module)
Return the first lm_head.weight parameter found on a model.
- Parameters:
model – Model to inspect.
- Returns:
Tuple of the parameter tensor and its normalized FQN, or (None, None) when the model has no LM head weight.
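The sketch below shows one way to do the lookup: scan named parameters, normalize each name, and return the first one that ends in lm_head.weight. The key detail is remove_duplicate=False, which recent PyTorch releases accept on nn.Module.named_parameters. A tied lm_head shares its tensor with the embedding. With the default remove_duplicate=True it would be reported only under the embedding's name. FakeModel, get_lm_head_weight_and_name_sketch, and the reduced wrapper set are hypothetical.

```python
from typing import Any, Optional, Tuple

# Reduced, assumed wrapper set used for normalization in this sketch.
_WRAPPERS = {"_orig_mod", "_checkpoint_wrapped_module"}


def _strip_wrappers(name: str) -> str:
    return ".".join(p for p in name.split(".") if p not in _WRAPPERS)


def get_lm_head_weight_and_name_sketch(model: Any) -> Tuple[Optional[Any], Optional[str]]:
    # remove_duplicate=False: a tensor shared by embed_tokens and lm_head must
    # still be visible under its lm_head name.
    for name, param in model.named_parameters(remove_duplicate=False):
        normalized = _strip_wrappers(name)
        if normalized == "lm_head.weight" or normalized.endswith(".lm_head.weight"):
            return param, normalized
    return None, None


class FakeModel:
    """Stand-in for nn.Module; it ignores remove_duplicate and yields every name."""

    def __init__(self, params: dict[str, Any]):
        self._params = params

    def named_parameters(self, remove_duplicate: bool = True):
        return iter(self._params.items())


shared = object()  # one tensor registered under two names, as in a tied model
model = FakeModel({"_orig_mod.model.embed_tokens.weight": shared, "_orig_mod.lm_head.weight": shared})
print(get_lm_head_weight_and_name_sketch(model)[1])  # lm_head.weight
```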
- nemo_automodel.components.checkpoint.utils.get_input_embeddings_weight_and_name(model: torch.nn.Module)
Return the input embedding weight and normalized name if present.
- Parameters:
model – Model to inspect.
- Returns:
Tuple of the embedding weight tensor and its normalized FQN, or (None, None) when the current model partition does not own the input embedding.
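The sketch below assumes the Hugging Face get_input_embeddings() accessor. It also assumes that a pipeline stage without the embedding returns None from that accessor. The FQN is recovered by matching identity against named_parameters(). The simple _orig_mod. removal is a minimal stand-in for _normalize_param_name. FakeStage and get_input_embeddings_weight_and_name_sketch are hypothetical.

```python
from types import SimpleNamespace
from typing import Any, Optional, Tuple


def get_input_embeddings_weight_and_name_sketch(model: Any) -> Tuple[Optional[Any], Optional[str]]:
    getter = getattr(model, "get_input_embeddings", None)
    embedding = getter() if callable(getter) else None
    weight = getattr(embedding, "weight", None)
    if weight is None:
        # This partition does not own the input embedding.
        return None, None
    # Recover the FQN by identity; the first registered name wins.
    for name, param in model.named_parameters():
        if param is weight:
            # Minimal stand-in for _normalize_param_name.
            return weight, name.replace("_orig_mod.", "")
    return None, None


class FakeStage:
    """Hypothetical pipeline stage exposing the two accessors used above."""

    def __init__(self, embed_weight: Any, params: dict[str, Any]):
        self._embed = SimpleNamespace(weight=embed_weight) if embed_weight is not None else None
        self._params = params

    def get_input_embeddings(self):
        return self._embed

    def named_parameters(self):
        return iter(self._params.items())


w = object()
first_stage = FakeStage(w, {"_orig_mod.model.embed_tokens.weight": w})
last_stage = FakeStage(None, {"lm_head.weight": object()})
print(get_input_embeddings_weight_and_name_sketch(first_stage)[1])  # model.embed_tokens.weight
print(get_input_embeddings_weight_and_name_sketch(last_stage))  # (None, None)
```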
- nemo_automodel.components.checkpoint.utils.get_tied_lm_head_source_names(model: torch.nn.Module, lm_head_param_name: str | None = None)
Return candidate checkpoint keys that can source a tied LM head.
- Parameters:
model – Model or pipeline stage to inspect.
lm_head_param_name – Optional normalized LM head FQN.
- Returns:
Ordered list of possible source FQNs.
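The candidate order is not documented above. This hypothetical sketch first prefers the embedding name found on the local partition. It then guesses Hugging Face-style embed_tokens keys that share the LM head's own prefix. Duplicates are removed while keeping order. To stay self-contained, it takes names rather than a model.

```python
from typing import Optional


def get_tied_lm_head_source_names_sketch(
    lm_head_param_name: Optional[str] = None,
    local_embed_name: Optional[str] = None,
) -> list[str]:
    candidates: list[str] = []
    if local_embed_name:
        # Best evidence: the embedding actually present on this partition.
        candidates.append(local_embed_name)
    # Reuse whatever prefix sits in front of "lm_head.weight" (e.g. for VLMs).
    prefix = ""
    if lm_head_param_name and (
        lm_head_param_name == "lm_head.weight" or lm_head_param_name.endswith(".lm_head.weight")
    ):
        prefix = lm_head_param_name[: -len("lm_head.weight")]
    # Assumed Hugging Face-style embedding keys, most specific first.
    for suffix in ("model.embed_tokens.weight", "embed_tokens.weight"):
        candidates.append(prefix + suffix)
    # dict.fromkeys deduplicates while preserving insertion order.
    return list(dict.fromkeys(candidates))


print(get_tied_lm_head_source_names_sketch("language_model.lm_head.weight"))
# ['language_model.model.embed_tokens.weight', 'language_model.embed_tokens.weight']
```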
- nemo_automodel.components.checkpoint.utils.has_local_tied_lm_head(model: torch.nn.Module) -> bool
Return whether the current model partition can locally satisfy a tied LM head.
This is intentionally stricter than is_tied_word_embeddings(): pipeline stages often keep the config flag set to True even though lm_head and embed_tokens live on different partitions and therefore cannot be reconstructed from each other locally.
Note: we purposefully do NOT check lm_head.weight is embed_tokens.weight. After FSDP/TP sharding, both are wrapped into separate DTensors and the is-identity is broken, but HF’s tie_weights() can still relink them locally on load. The only case we actually need to distinguish is “is the embedding source present on this partition at all?”, which answers “can we safely omit lm_head.weight during save and rematerialize it on load?”.
- Parameters:
model – Model or pipeline stage to inspect.
- Returns:
True when the model is configured with tied word embeddings AND both the local lm_head and the input embedding live on this partition. False when the config isn’t tied, or when the local partition is missing one of the two (typical for PP non-last / non-first stages).
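The decision reduces to a three-way AND. In the sketch below, a hypothetical StageView record stands in for inspecting a real partition. It shows why the config flag alone is not enough: both pipeline-parallel edge stages keep the flag set but still report False.

```python
from dataclasses import dataclass


@dataclass
class StageView:
    """Hypothetical summary of what one model partition owns."""

    tie_word_embeddings: bool  # config flag, often still True on every PP stage
    has_lm_head: bool
    has_embed_tokens: bool


def has_local_tied_lm_head_sketch(stage: StageView) -> bool:
    # Tied in config AND both halves of the tie present on this partition.
    return stage.tie_word_embeddings and stage.has_lm_head and stage.has_embed_tokens


stages = {
    "unsplit": StageView(True, True, True),
    "pp_first": StageView(True, False, True),
    "pp_last": StageView(True, True, False),
    "untied": StageView(False, True, True),
}
for label, stage in stages.items():
    print(label, has_local_tied_lm_head_sketch(stage))
# unsplit True
# pp_first False
# pp_last False
# untied False
```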
- nemo_automodel.components.checkpoint.utils.materialize_missing_tied_lm_head(state_dict: dict[str, Any], model: torch.nn.Module, *, allow_current_lm_head_fallback: bool = False)
Populate a missing tied lm_head.weight from its embedding source.
Hugging Face checkpoints for tied-embedding models often omit lm_head.weight entirely. That is fine for unsplit models, where tie_weights() can restore the alias, but it breaks pipeline-parallel last stages, which own lm_head but not embed_tokens.
- Parameters:
state_dict – Checkpoint state dict to mutate in place.
model – Target model or pipeline stage.
allow_current_lm_head_fallback – If True, fall back to the current lm_head tensor when the tied source cannot be found in state_dict. This preserves legacy resume behavior for older checkpoints that were saved without a local lm_head.weight.
- Returns:
True if a missing lm_head.weight was materialized, else False.
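The sketch below shows the materialization step over a plain dict. It assumes the caller has already resolved three things: the LM head key, the ordered source keys (for instance from get_tied_lm_head_source_names), and whether the model is tied and owns lm_head locally. The function name and its explicit key arguments are hypothetical; the documented function derives this information from model.

```python
from typing import Any, Optional, Sequence


def materialize_missing_tied_lm_head_sketch(
    state_dict: dict[str, Any],
    lm_head_key: str,
    source_keys: Sequence[str],
    current_lm_head: Optional[Any] = None,
    *,
    allow_current_lm_head_fallback: bool = False,
) -> bool:
    if lm_head_key in state_dict:
        return False  # nothing is missing
    # Prefer a tied source tensor that the checkpoint actually contains.
    for key in source_keys:
        if key in state_dict:
            state_dict[lm_head_key] = state_dict[key]  # alias, mutate in place
            return True
    # Legacy resume path: reuse the tensor currently on the model.
    if allow_current_lm_head_fallback and current_lm_head is not None:
        state_dict[lm_head_key] = current_lm_head
        return True
    return False


sd = {"model.embed_tokens.weight": [1.0, 2.0]}
print(materialize_missing_tied_lm_head_sketch(sd, "lm_head.weight", ["model.embed_tokens.weight"]))  # True
print(sorted(sd))  # ['lm_head.weight', 'model.embed_tokens.weight']
```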
- nemo_automodel.components.checkpoint.utils._get_checkpoint_tensor_dtypes(pretrained_model_name_or_path: str, hf_config: Any, load_kwargs: collections.abc.Mapping[str, object] | None = None)
Inspect checkpoint tensors and return their exact dtypes by key.
This inspects checkpoint metadata only, loading tensors on the meta device, so it preserves per-tensor dtype information without materializing the full checkpoint weights in memory.
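The description above uses meta-device loading. For safetensors checkpoints only, there is an alternative: read the per-tensor dtypes straight from the file header without touching tensor data. The file starts with an 8-byte little-endian header length. A JSON header follows, mapping each tensor name to its dtype, shape, and data offsets. This is a swapped-in technique, not the documented implementation. read_safetensors_dtypes is hypothetical and returns safetensors dtype codes such as BF16 rather than torch.dtype objects. A sharded checkpoint would need a loop over the shards listed in model.safetensors.index.json.

```python
import json
import os
import struct
import tempfile


def read_safetensors_dtypes(path: str) -> dict[str, str]:
    """Return {tensor_name: safetensors dtype code} by parsing the header only."""
    with open(path, "rb") as f:
        # First 8 bytes: little-endian unsigned 64-bit length of the JSON header.
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
    # "__metadata__" is an optional free-form entry, not a tensor.
    return {name: info["dtype"] for name, info in header.items() if name != "__metadata__"}


# Build a tiny one-tensor file by hand: header, then 4 bytes of BF16 data.
header = {
    "__metadata__": {"format": "pt"},
    "lm_head.weight": {"dtype": "BF16", "shape": [2], "data_offsets": [0, 4]},
}
header_bytes = json.dumps(header).encode("utf-8")
with tempfile.NamedTemporaryFile(suffix=".safetensors", delete=False) as tmp:
    tmp.write(struct.pack("<Q", len(header_bytes)) + header_bytes + b"\x00" * 4)
dtypes = read_safetensors_dtypes(tmp.name)
os.remove(tmp.name)
print(dtypes)  # {'lm_head.weight': 'BF16'}
```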