nemo_automodel.components.checkpoint.stateful_wrappers
API
Helper class for tracking model state in distributed checkpointing.
This class implements the Stateful protocol, so DCP automatically calls its state_dict/load_state_dict methods as needed inside the dcp.save/dcp.load APIs.
Parameters:
The PyTorch model to track.
Refresh tied-head metadata after DCP has normalized module state.
Load the state dictionary into the model.
Parameters:
State dictionary to load.
Get the model’s state dictionary.
Returns: dict[str, Any]
Dictionary containing the model’s state dict with CPU offloading enabled.
Helper class for tracking optimizer state in distributed checkpointing.
This class implements the Stateful protocol, so DCP automatically calls its state_dict/load_state_dict methods as needed inside the dcp.save/dcp.load APIs.
Parameters:
The PyTorch model associated with the optimizer.
The optimizer to track.
Optional learning rate scheduler.
Load the state dictionaries into the optimizer and scheduler.
Parameters:
State dictionary containing optimizer and scheduler states to load.
Get the optimizer and scheduler state dictionaries.
Returns: dict[str, Any]
Dictionary containing the optimizer and scheduler state dicts with CPU offloading enabled.
Prepend prefix once to every key in-place (inverse of _drop_outer_prefix).
Remove the first occurrence of prefix from every key in-place.
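The two prefix helpers can be sketched in plain Python. The names add_outer_prefix and drop_outer_prefix are illustrative; the text above refers to the second as _drop_outer_prefix. This sketch rebuilds the dict rather than renaming keys one at a time, so a freshly prefixed key can never overwrite an existing key that has not been renamed yet.

```python
from typing import Any


def add_outer_prefix(sd: dict[str, Any], prefix: str) -> None:
    """Prepend prefix once to every key, mutating sd in place (sketch)."""
    items = list(sd.items())
    sd.clear()  # rebuild so a new key cannot clobber a not-yet-renamed one
    sd.update((prefix + key, value) for key, value in items)


def drop_outer_prefix(sd: dict[str, Any], prefix: str) -> None:
    """Remove the first occurrence of prefix from every key, in place (sketch)."""
    items = list(sd.items())
    sd.clear()
    sd.update((key.replace(prefix, "", 1), value) for key, value in items)
```

Note that str.replace(prefix, "", 1) removes the first occurrence wherever it appears in the key, not only at the start, which matches the description above. Adding and then dropping the same prefix still round-trips, because the added copy is always the first occurrence.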
Extract only trainable PEFT adapter weights, bypassing DCP.
This function iterates directly over the model's parameters to collect trainable weights. It avoids PyTorch DCP’s state_dict traversal, which fails on (1) BitsAndBytes quantized modules (Params4bit, Int8Params, etc.) and (2) MoE models with expert parallelism, where expert weights are sharded across EP ranks.
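The core idea, walking named_parameters() instead of going through DCP, reduces to a few lines. This sketch uses the hypothetical name get_trainable_state_dict and keeps only parameters with requires_grad=True. It does not reproduce the DTensor and quantized-weight specifics of the real function.

```python
import torch
from torch import nn


def get_trainable_state_dict(model: nn.Module) -> dict[str, torch.Tensor]:
    """Collect only parameters with requires_grad=True, keyed by name (sketch)."""
    return {
        name: param.detach()  # detach so the saved tensors carry no autograd graph
        for name, param in model.named_parameters()
        if param.requires_grad
    }
```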
Check if any MoE expert module in the model has expert parallelism enabled.
After EP initialization, expert modules (GroupedExpertsDeepEP, GroupedExpertsTE)
store ep_size on themselves. A value > 1 signals that expert weights are
sharded across EP ranks and DCP’s state_dict APIs cannot handle them.
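A duck-typed version of this check might look like the following. It assumes only what the description states: expert modules expose an ep_size attribute after EP initialization. The function name is hypothetical, and the real check may instead test for specific classes such as GroupedExpertsDeepEP or GroupedExpertsTE.

```python
from torch import nn


def has_expert_parallelism(model: nn.Module) -> bool:
    """True if any submodule advertises an ep_size greater than 1 (sketch)."""
    # Modules without ep_size (or with ep_size=None) count as unsharded.
    return any((getattr(m, "ep_size", None) or 1) > 1 for m in model.modules())
```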
Check if model has any BitsAndBytes quantized modules.
Check if a module is a BitsAndBytes quantized type.
Detects quantization by checking for the quant_state attribute, which is
common across BitsAndBytes quantized types (Params4bit, Int8Params, etc.).
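The quant_state heuristic can be sketched as below. Both function names are hypothetical. Because Params4bit and Int8Params are parameter types rather than modules, this sketch checks each module and also the parameters it owns directly. Whether the actual helpers inspect parameters this way is an assumption.

```python
from torch import nn


def is_bnb_quantized(obj: object) -> bool:
    """Heuristic from the text above: BitsAndBytes types carry quant_state."""
    return hasattr(obj, "quant_state")


def has_bnb_quantized_modules(model: nn.Module) -> bool:
    """True if any module, or any parameter it owns directly, looks quantized."""
    for module in model.modules():
        if is_bnb_quantized(module):
            return True
        if any(is_bnb_quantized(p) for p in module.parameters(recurse=False)):
            return True
    return False
```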
Reverse of _rename_dora_keys_to_hf: convert HF PEFT key format back to internal names.
For robustness, it handles both the current on-disk format (<module>.lora_magnitude_vector)
and the legacy format that also included .default.weight.
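A sketch of the reverse mapping follows. This reference does not give the internal DoRA key name, so the sketch takes it as a parameter; the value "dora_magnitude" used in the check below is purely hypothetical, as is the function name. Legacy keys ending in .lora_magnitude_vector.default.weight are first reduced to the current on-disk form, and then both forms map to the same internal name.

```python
from typing import Any

HF_DORA_SUFFIX = "lora_magnitude_vector"
LEGACY_TAIL = ".default.weight"


def dora_keys_from_hf(sd: dict[str, Any], internal_suffix: str) -> None:
    """Map HF PEFT DoRA magnitude keys back to an internal suffix, in place."""
    items = list(sd.items())
    sd.clear()
    for key, value in items:
        # Legacy checkpoints kept the adapter name and .weight; drop them first.
        if key.endswith("." + HF_DORA_SUFFIX + LEGACY_TAIL):
            key = key[: -len(LEGACY_TAIL)]
        # Current format: <module>.lora_magnitude_vector -> <module>.<internal_suffix>
        if key.endswith("." + HF_DORA_SUFFIX):
            key = key[: -len(HF_DORA_SUFFIX)] + internal_suffix
        sd[key] = value
```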
Rename DoRA magnitude keys in-place to match HF PEFT’s saved checkpoint format.
HF PEFT’s get_peft_model_state_dict strips the adapter name and the
.weight suffix from lora_magnitude_vector.<adapter>.weight, so the
round-trip format on disk is simply <module>.lora_magnitude_vector.
When loading, set_peft_model_state_dict re-inserts the adapter name
and the .weight suffix automatically, so we must NOT include them here.
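The forward direction can be sketched the same way. The function name and the internal suffix are hypothetical parameters. The sketch only rewrites keys ending in that suffix so they end in lora_magnitude_vector, leaving out the adapter name and .weight because set_peft_model_state_dict adds them back on load.

```python
from typing import Any

HF_DORA_SUFFIX = "lora_magnitude_vector"


def dora_keys_to_hf(sd: dict[str, Any], internal_suffix: str) -> None:
    """Rewrite '<module>.<internal_suffix>' as '<module>.lora_magnitude_vector'."""
    items = list(sd.items())
    sd.clear()
    for key, value in items:
        if key.endswith("." + internal_suffix):
            # No adapter name and no .weight: PEFT re-inserts both on load.
            key = key[: -len(internal_suffix)] + HF_DORA_SUFFIX
        sd[key] = value
```

Passing PEFT's own in-memory suffix, lora_magnitude_vector.default.weight, yields the same on-disk key, which is a quick way to sanity-check the mapping.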
Load trainable PEFT adapter weights into the model, bypassing DCP.
Mirrors _get_peft_state_dict: directly assigns saved tensors to model parameters by name, handling DTensor re-sharding for EP-parallel weights. This avoids DCP’s set_model_state_dict(), which raises KeyError on expert-parallel FQNs.
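Loading by name can be sketched as follows; load_trainable_state_dict is a hypothetical name. Plain parameters are filled with copy_ under no_grad. For a DTensor parameter, the sketch assumes every rank holds the full saved tensor and re-shards it with distribute_tensor using the target's device mesh and placements. That is one possible approach, not necessarily the module's, and only the plain-tensor path runs without an initialized process group.

```python
from typing import Any

import torch
from torch import nn

try:  # public location in recent PyTorch releases
    from torch.distributed.tensor import DTensor, distribute_tensor
except ImportError:
    try:  # older releases kept DTensor under a private module
        from torch.distributed._tensor import DTensor, distribute_tensor
    except ImportError:  # builds without distributed support
        DTensor = distribute_tensor = None


@torch.no_grad()
def load_trainable_state_dict(model: nn.Module, sd: dict[str, Any]) -> list[str]:
    """Copy saved tensors into same-named parameters; return unknown keys."""
    params = dict(model.named_parameters())
    unexpected = []
    for name, value in sd.items():
        param = params.get(name)
        if param is None:
            unexpected.append(name)
            continue
        if DTensor is not None and isinstance(param, DTensor) and not isinstance(value, DTensor):
            # Assumption: the full tensor is present on every rank; shard it like the target.
            value = distribute_tensor(
                value.to(param.device_mesh.device_type),
                param.device_mesh,
                param.placements,
            )
        param.copy_(value)
    return unexpected
```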