nemo_automodel.components.checkpoint.stateful_wrappers

Module Contents

Classes

| Name | Description |
| --- | --- |
| ModelState | Helper class for tracking model state in distributed checkpointing. |
| OptimizerState | Helper class for tracking optimizer state in distributed checkpointing. |

Functions

| Name | Description |
| --- | --- |
| _add_outer_prefix | Prepend prefix once to every key in-place (inverse of _drop_outer_prefix). |
| _drop_outer_prefix | Remove the first occurrence of prefix on every key in-place. |
| _get_lm_head_weight_and_name | - |
| _get_peft_state_dict | Extract only trainable PEFT adapter weights, bypassing DCP. |
| _has_expert_parallelism | Check if any MoE expert module in the model has expert parallelism enabled. |
| _has_quantized_params | Check if model has any BitsAndBytes quantized modules. |
| _is_quantized_module | Check if a module is a BitsAndBytes quantized type. |
| _rename_dora_keys_from_hf | Reverse of _rename_dora_keys_to_hf: convert HF PEFT key format back to internal names. |
| _rename_dora_keys_to_hf | Rename DoRA magnitude keys to match HF PEFT’s saved checkpoint format in-place. |
| _safe_op_set_extra_state | - |
| _safe_set_extra_state | - |
| _set_peft_state_dict | Load trainable PEFT adapter weights into the model, bypassing DCP. |

Data

_PREFIX

_original_op_set_extra_state

_original_set_extra_state

API

class nemo_automodel.components.checkpoint.stateful_wrappers.ModelState(
model: torch.nn.Module | list[torch.nn.Module],
is_peft: bool = False,
is_init_step: bool = False,
skip_task_head_prefixes: list[str] | None = None
)

Helper class for tracking model state in distributed checkpointing.

This class is compliant with the Stateful protocol, allowing DCP to automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs.

Parameters:

model
torch.nn.Module | list[torch.nn.Module]

The PyTorch model to track.

Attributes:

has_local_tied_lm_head = has_local_tied_lm_head(self.model[0])
model
skip_task_head_prefixes = skip_task_head_prefixes or []
uses_tied_lm_head = is_tied_word_embeddings(self.model[0])

nemo_automodel.components.checkpoint.stateful_wrappers.ModelState._get_base_model_state_dict() -> dict[str, typing.Any]

nemo_automodel.components.checkpoint.stateful_wrappers.ModelState._refresh_local_tied_lm_head() -> None

Refresh tied-head metadata after DCP has normalized module state.

nemo_automodel.components.checkpoint.stateful_wrappers.ModelState._set_base_model_state_dict(
state_dict: dict[str, typing.Any]
) -> None

nemo_automodel.components.checkpoint.stateful_wrappers.ModelState.load_state_dict(
state_dict: dict[str, typing.Any],
strict: bool = True
) -> None

Load the state dictionary into the model.

Parameters:

state_dict
dict

State dictionary to load.

nemo_automodel.components.checkpoint.stateful_wrappers.ModelState.state_dict() -> dict[str, typing.Any]

Get the model’s state dictionary.

Returns: dict[str, Any]

Dictionary containing the model’s state dict with CPU offloading enabled.
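As a rough illustration of what "compliant with the Stateful protocol" means, the sketch below defines a stand-in protocol with the two methods named above and a toy wrapper that satisfies it. `StatefulLike`, `ToyModelState`, `toy_save`, and `toy_load` are hypothetical names invented for this example; they are not part of nemo_automodel or PyTorch, and the real `ModelState` does considerably more (PEFT handling, tied heads, CPU offloading).

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class StatefulLike(Protocol):
    """Stand-in for a Stateful-style protocol: two methods, nothing else."""

    def state_dict(self) -> dict[str, Any]: ...

    def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...


class ToyModelState:
    """Hypothetical wrapper that tracks a plain dict of 'weights'."""

    def __init__(self, weights: dict[str, float]) -> None:
        self.weights = weights

    def state_dict(self) -> dict[str, Any]:
        # Return a copy so callers cannot mutate the tracked state.
        return dict(self.weights)

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Overwrite the tracked values in place.
        self.weights.clear()
        self.weights.update(state_dict)


def toy_save(obj: StatefulLike) -> dict[str, Any]:
    """Mimics a checkpointer that only talks to the protocol."""
    return obj.state_dict()


def toy_load(obj: StatefulLike, saved: dict[str, Any]) -> None:
    """Mimics the load side: hand the saved dict back to the object."""
    obj.load_state_dict(saved)
```

Because the save and load helpers depend only on the two protocol methods, any object exposing them can be checkpointed the same way, which is the property the wrapper classes on this page rely on.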

class nemo_automodel.components.checkpoint.stateful_wrappers.OptimizerState(
model: torch.nn.Module | list[torch.nn.Module],
optimizer: torch.optim.Optimizer,
scheduler: typing.Optional[typing.Any] = None,
is_peft: bool = False
)

Helper class for tracking optimizer state in distributed checkpointing.

This class is compliant with the Stateful protocol, allowing DCP to automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs.

Parameters:

model
torch.nn.Module | list[torch.nn.Module]

The PyTorch model associated with the optimizer.

optimizer
torch.optim.Optimizer

The optimizer to track.

scheduler
Optional[Any], defaults to None

Optional learning rate scheduler.

Attributes:

model
optimizer
scheduler

nemo_automodel.components.checkpoint.stateful_wrappers.OptimizerState.load_state_dict(
state_dict: dict[str, typing.Any]
) -> None

Load the state dictionaries into the optimizer and scheduler.

Parameters:

state_dict
dict

State dictionary containing optimizer and scheduler states to load.

nemo_automodel.components.checkpoint.stateful_wrappers.OptimizerState.state_dict() -> dict[str, typing.Any]

Get the optimizer and scheduler state dictionaries.

Returns: dict[str, Any]

Dictionary containing the optimizer and scheduler state dicts with CPU offloading enabled.
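The following is a minimal sketch of how a wrapper might bundle optimizer and scheduler state into one dictionary and restore both from it. The keys `"optim"` and `"sched"`, and the classes `DictBacked` and `ToyOptimizerState`, are assumptions made for illustration; the actual key names and offloading behaviour of `OptimizerState` are not specified on this page.

```python
from typing import Any, Optional


class DictBacked:
    """Hypothetical stand-in for an optimizer or scheduler."""

    def __init__(self, data: dict[str, Any]) -> None:
        self.data = data

    def state_dict(self) -> dict[str, Any]:
        return dict(self.data)

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.data = dict(state_dict)


class ToyOptimizerState:
    """Bundles optimizer and (optional) scheduler state under assumed keys."""

    def __init__(self, optimizer: Any, scheduler: Optional[Any] = None) -> None:
        self.optimizer = optimizer
        self.scheduler = scheduler

    def state_dict(self) -> dict[str, Any]:
        out = {"optim": self.optimizer.state_dict()}
        # The scheduler is optional, so only include it when present.
        if self.scheduler is not None:
            out["sched"] = self.scheduler.state_dict()
        return out

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        self.optimizer.load_state_dict(state_dict["optim"])
        if self.scheduler is not None and "sched" in state_dict:
            self.scheduler.load_state_dict(state_dict["sched"])
```

Keeping both states in one dictionary lets a single save or load call cover everything needed to resume training with the same learning-rate schedule.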

nemo_automodel.components.checkpoint.stateful_wrappers._add_outer_prefix(
sd: dict[str, typing.Any],
prefix: str = _PREFIX,
skip_keys: list[str] | None = None
) -> None

Prepend prefix once to every key in-place (inverse of _drop_outer_prefix).

nemo_automodel.components.checkpoint.stateful_wrappers._drop_outer_prefix(
sd: dict[str, typing.Any],
prefix: str = _PREFIX
) -> None

Remove the first occurrence of prefix on every key in-place.
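The two prefix helpers can be pictured with the sketch below. It is an illustration under stated assumptions, not the library's code: it assumes "once" means keys that already start with the prefix are left alone, that "first occurrence" means a leading prefix, and that `skip_keys` lists exact key names to leave untouched. The function names without the leading underscore are hypothetical.

```python
from typing import Any, Optional

PREFIX = "model."  # mirrors the documented _PREFIX value


def add_outer_prefix(
    sd: dict[str, Any],
    prefix: str = PREFIX,
    skip_keys: Optional[list[str]] = None,
) -> None:
    """Prepend `prefix` to each key in place, skipping prefixed or listed keys."""
    skip = set(skip_keys or [])
    # Iterate over a snapshot because the dict is mutated in the loop.
    for key in list(sd.keys()):
        if key in skip or key.startswith(prefix):
            continue
        sd[prefix + key] = sd.pop(key)


def drop_outer_prefix(sd: dict[str, Any], prefix: str = PREFIX) -> None:
    """Strip one leading `prefix` from each key in place."""
    for key in list(sd.keys()):
        if key.startswith(prefix):
            sd[key[len(prefix):]] = sd.pop(key)
```

Both functions return `None` and mutate the dictionary they receive, matching the documented signatures, so callers should pass a copy if the original mapping must be preserved.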

nemo_automodel.components.checkpoint.stateful_wrappers._get_lm_head_weight_and_name(
model: torch.nn.Module
) -> typing.Optional[tuple[torch.Tensor, str]]

nemo_automodel.components.checkpoint.stateful_wrappers._get_peft_state_dict(
model: torch.nn.Module
) -> dict[str, typing.Any]

Extract only trainable PEFT adapter weights, bypassing DCP.

This function iterates directly over the model parameters to collect the trainable weights. It avoids PyTorch DCP’s state_dict traversal, which fails on (1) BitsAndBytes quantized modules (Params4bit, Int8Params, etc.) and (2) MoE models with expert parallelism, where expert weights are sharded across EP ranks.
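The core idea of collecting only trainable weights by iterating over parameters can be sketched as follows. `collect_trainable` is a hypothetical name, and the sketch takes any iterable of `(name, parameter)` pairs so that it runs without PyTorch; with a real model one would presumably pass `model.named_parameters()`. Any extra handling the actual helper performs (for example for sharded tensors) is not shown.

```python
from types import SimpleNamespace
from typing import Any, Iterable


def collect_trainable(named_params: Iterable[tuple[str, Any]]) -> dict[str, Any]:
    """Keep only parameters whose requires_grad flag is set."""
    return {
        name: param
        for name, param in named_params
        if getattr(param, "requires_grad", False)
    }


# Stand-in parameters: frozen base weight plus two trainable adapter weights.
toy_params = [
    ("layer.weight", SimpleNamespace(requires_grad=False)),
    ("layer.lora_A.weight", SimpleNamespace(requires_grad=True)),
    ("layer.lora_B.weight", SimpleNamespace(requires_grad=True)),
]
trainable = collect_trainable(toy_params)
```

Filtering on `requires_grad` never touches the frozen base weights, which is why this style of traversal sidesteps the quantized and expert-sharded tensors mentioned above.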

nemo_automodel.components.checkpoint.stateful_wrappers._has_expert_parallelism(
model: torch.nn.Module
) -> bool

Check if any MoE expert module in the model has expert parallelism enabled.

After EP initialization, expert modules (GroupedExpertsDeepEP, GroupedExpertsTE) store ep_size on themselves. A value > 1 signals that expert weights are sharded across EP ranks and DCP’s state_dict APIs cannot handle them.
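A minimal sketch of the check described above: scan the modules and report whether any of them carries an `ep_size` greater than 1. The function name is hypothetical and it accepts a plain iterable of objects (with a real model, presumably `model.modules()`); treating a missing or `None` `ep_size` as 1 is an assumption of this sketch.

```python
from types import SimpleNamespace
from typing import Any, Iterable


def has_expert_parallelism(modules: Iterable[Any]) -> bool:
    """Return True if any module reports ep_size > 1."""
    # A missing or None ep_size is treated as 1 (no expert parallelism).
    return any((getattr(m, "ep_size", None) or 1) > 1 for m in modules)


# Stand-in modules: a dense layer and an expert module sharded over 4 ranks.
dense = SimpleNamespace()
experts = SimpleNamespace(ep_size=4)
```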

nemo_automodel.components.checkpoint.stateful_wrappers._has_quantized_params(
model: torch.nn.Module
) -> bool

Check if model has any BitsAndBytes quantized modules.

nemo_automodel.components.checkpoint.stateful_wrappers._is_quantized_module(
module: torch.nn.Module
) -> bool

Check if a module is a BitsAndBytes quantized type.

Detects quantization by checking for the quant_state attribute, which is common across BitsAndBytes quantized types (Params4bit, Int8Params, etc.).
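The attribute-based detection can be sketched with duck typing, as below. Whether the real helper inspects the module itself or its weight is not stated here, so this sketch checks both; that choice, and the function names, are assumptions made for illustration.

```python
from types import SimpleNamespace
from typing import Any, Iterable


def is_quantized_module(module: Any) -> bool:
    """True if the module, or its weight, exposes a quant_state attribute."""
    weight = getattr(module, "weight", None)
    return hasattr(module, "quant_state") or hasattr(weight, "quant_state")


def has_quantized_params(modules: Iterable[Any]) -> bool:
    """True if any module in the iterable looks quantized."""
    return any(is_quantized_module(m) for m in modules)


# Stand-ins: a regular layer and a layer whose weight carries quant_state.
plain_layer = SimpleNamespace(weight=SimpleNamespace())
quant_layer = SimpleNamespace(weight=SimpleNamespace(quant_state=object()))
```

Checking for an attribute rather than a class avoids importing the quantization library when it is not installed.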

nemo_automodel.components.checkpoint.stateful_wrappers._rename_dora_keys_from_hf(
sd: dict[str, typing.Any]
) -> None

Reverse of _rename_dora_keys_to_hf: convert HF PEFT key format back to internal names.

For robustness, handles both the current on-disk format (<module>.lora_magnitude_vector) and the legacy format that included .default.weight.

nemo_automodel.components.checkpoint.stateful_wrappers._rename_dora_keys_to_hf(
sd: dict[str, typing.Any]
) -> None

Rename DoRA magnitude keys to match HF PEFT’s saved checkpoint format in-place.

HF PEFT’s get_peft_model_state_dict strips the adapter name and the .weight suffix from lora_magnitude_vector.<adapter>.weight, so the round-trip format on disk is simply <module>.lora_magnitude_vector. When loading, set_peft_model_state_dict re-inserts the adapter name and the .weight suffix automatically, so we must NOT include them here.
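The round trip between internal and on-disk key names might look like the sketch below. The on-disk suffix and the legacy `.default.weight` form come from the descriptions above; the internal suffix `.lora_magnitude` is purely hypothetical (the internal key name is not given on this page), as are the function names.

```python
from typing import Any

HF_SUFFIX = ".lora_magnitude_vector"  # current on-disk format
LEGACY_SUFFIX = ".lora_magnitude_vector.default.weight"  # legacy on-disk format
INTERNAL_SUFFIX = ".lora_magnitude"  # hypothetical internal name


def rename_dora_keys_to_hf(sd: dict[str, Any]) -> None:
    """Rename internal magnitude keys to the on-disk format, in place."""
    for key in list(sd.keys()):
        if key.endswith(INTERNAL_SUFFIX):
            sd[key[: -len(INTERNAL_SUFFIX)] + HF_SUFFIX] = sd.pop(key)


def rename_dora_keys_from_hf(sd: dict[str, Any]) -> None:
    """Rename on-disk magnitude keys (current or legacy) back, in place."""
    for key in list(sd.keys()):
        # Check the longer legacy suffix first so it is not missed.
        for suffix in (LEGACY_SUFFIX, HF_SUFFIX):
            if key.endswith(suffix):
                sd[key[: -len(suffix)] + INTERNAL_SUFFIX] = sd.pop(key)
                break
```

Keys that do not end in one of the magnitude suffixes pass through unchanged, so the functions can be applied to a full adapter state dict.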

nemo_automodel.components.checkpoint.stateful_wrappers._safe_op_set_extra_state(
self,
state
)

nemo_automodel.components.checkpoint.stateful_wrappers._safe_op_set_extra_state... 
self,
state
)

nemo_automodel.components.checkpoint.stateful_wrappers._set_peft_state_dict(
model: torch.nn.Module,
state_dict: dict[str, typing.Any]
) -> None

Load trainable PEFT adapter weights into the model, bypassing DCP.

Mirrors _get_peft_state_dict: assigns saved tensors directly to model parameters by name, handling DTensor re-sharding for EP-parallel weights. This avoids DCP’s set_model_state_dict(), which raises KeyError on expert-parallel FQNs.

nemo_automodel.components.checkpoint.stateful_wrappers._PREFIX = 'model.'
nemo_automodel.components.checkpoint.stateful_wrappers._original_op_set_extra_state = te_ops.BasicOperation.set_extra_state
nemo_automodel.components.checkpoint.stateful_wrappers._original_set_extra_state = te_base.TransformerEngineBaseModule.set_extra_state