# nemo_automodel.components.checkpoint.stateful_wrappers

## Module Contents

### Classes

| Name                                                                                       | Description                                                             |
| ------------------------------------------------------------------------------------------ | ----------------------------------------------------------------------- |
| [`ModelState`](#nemo_automodel-components-checkpoint-stateful_wrappers-ModelState)         | Helper class for tracking model state in distributed checkpointing.     |
| [`OptimizerState`](#nemo_automodel-components-checkpoint-stateful_wrappers-OptimizerState) | Helper class for tracking optimizer state in distributed checkpointing. |

### Functions

| Name                                                                                                                   | Description                                                                                 |
| ---------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------- |
| [`_add_outer_prefix`](#nemo_automodel-components-checkpoint-stateful_wrappers-_add_outer_prefix)                       | Prepend `prefix` once to every key in-place (inverse of `_drop_outer_prefix`).              |
| [`_drop_outer_prefix`](#nemo_automodel-components-checkpoint-stateful_wrappers-_drop_outer_prefix)                     | Remove the *first* occurrence of `prefix` on every key in-place.                            |
| [`_get_lm_head_weight_and_name`](#nemo_automodel-components-checkpoint-stateful_wrappers-_get_lm_head_weight_and_name) | -                                                                                           |
| [`_get_peft_state_dict`](#nemo_automodel-components-checkpoint-stateful_wrappers-_get_peft_state_dict)                 | Extract only trainable PEFT adapter weights, bypassing DCP.                                 |
| [`_has_expert_parallelism`](#nemo_automodel-components-checkpoint-stateful_wrappers-_has_expert_parallelism)           | Check if any MoE expert module in the model has expert parallelism enabled.                 |
| [`_has_quantized_params`](#nemo_automodel-components-checkpoint-stateful_wrappers-_has_quantized_params)               | Check if model has any BitsAndBytes quantized modules.                                      |
| [`_is_quantized_module`](#nemo_automodel-components-checkpoint-stateful_wrappers-_is_quantized_module)                 | Check if a module is a BitsAndBytes quantized type.                                         |
| [`_rename_dora_keys_from_hf`](#nemo_automodel-components-checkpoint-stateful_wrappers-_rename_dora_keys_from_hf)       | Reverse of `_rename_dora_keys_to_hf`: convert HF PEFT key format back to internal names.    |
| [`_rename_dora_keys_to_hf`](#nemo_automodel-components-checkpoint-stateful_wrappers-_rename_dora_keys_to_hf)           | Rename DoRA magnitude keys to match HF PEFT's saved checkpoint format in-place.             |
| [`_safe_op_set_extra_state`](#nemo_automodel-components-checkpoint-stateful_wrappers-_safe_op_set_extra_state)         | -                                                                                           |
| [`_safe_set_extra_state`](#nemo_automodel-components-checkpoint-stateful_wrappers-_safe_set_extra_state)               | -                                                                                           |
| [`_set_peft_state_dict`](#nemo_automodel-components-checkpoint-stateful_wrappers-_set_peft_state_dict)                 | Load trainable PEFT adapter weights into the model, bypassing DCP.                          |

### Data

[`_PREFIX`](#nemo_automodel-components-checkpoint-stateful_wrappers-_PREFIX)

[`_original_op_set_extra_state`](#nemo_automodel-components-checkpoint-stateful_wrappers-_original_op_set_extra_state)

[`_original_set_extra_state`](#nemo_automodel-components-checkpoint-stateful_wrappers-_original_set_extra_state)

### API

```python
class nemo_automodel.components.checkpoint.stateful_wrappers.ModelState(
    model: torch.nn.Module | list[torch.nn.Module],
    is_peft: bool = False,
    is_init_step: bool = False,
    skip_task_head_prefixes: list[str] | None = None
)
```

Helper class for tracking model state in distributed checkpointing.

This class is compliant with the Stateful protocol, allowing DCP to automatically
call `state_dict`/`load_state_dict` as needed in the `dcp.save`/`dcp.load` APIs.

**Parameters:**

`model`: The PyTorch model to track.
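
As a rough illustration of what compliance with a Stateful-style protocol involves, the sketch below defines a local stand-in protocol with the two methods a checkpointer calls. `StatefulLike`, `TinyModel`, and `TinyModelState` are hypothetical names invented for this example; they are not part of `nemo_automodel` or PyTorch, and the real `ModelState` handles far more (PEFT, tied heads, distributed state).

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class StatefulLike(Protocol):
    """Local stand-in for the two-method contract a checkpointer relies on."""

    def state_dict(self) -> dict[str, Any]: ...

    def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...


class TinyModel:
    """Hypothetical model substitute that keeps its weights in a plain dict."""

    def __init__(self) -> None:
        self.weights: dict[str, Any] = {"layer.weight": [1.0, 2.0]}


class TinyModelState:
    """Hypothetical wrapper exposing the model through the two methods."""

    def __init__(self, model: TinyModel) -> None:
        self.model = model

    def state_dict(self) -> dict[str, Any]:
        # Return a shallow copy so callers do not mutate the model by accident.
        return dict(self.model.weights)

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Overwrite the tracked weights with the loaded values.
        self.model.weights.update(state_dict)


wrapper = TinyModelState(TinyModel())
snapshot = wrapper.state_dict()                          # what a save would capture
wrapper.load_state_dict({"layer.weight": [3.0, 4.0]})   # what a load would restore
```

Because the wrapper only needs these two methods, a checkpointing API can treat model and optimizer state uniformly without knowing how each one gathers its tensors.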

```python
nemo_automodel.components.checkpoint.stateful_wrappers.ModelState._get_base_model_state_dict() -> dict[str, typing.Any]
```

```python
nemo_automodel.components.checkpoint.stateful_wrappers.ModelState._refresh_local_tied_lm_head() -> None
```

Refresh tied-head metadata after DCP has normalized module state.

```python
nemo_automodel.components.checkpoint.stateful_wrappers.ModelState._set_base_model_state_dict(
    state_dict: dict[str, typing.Any]
) -> None
```

```python
nemo_automodel.components.checkpoint.stateful_wrappers.ModelState.load_state_dict(
    state_dict: dict[str, typing.Any],
    strict: bool = True
) -> None
```

Load the state dictionary into the model.

**Parameters:**

`state_dict`: State dictionary to load.

```python
nemo_automodel.components.checkpoint.stateful_wrappers.ModelState.state_dict() -> dict[str, typing.Any]
```

Get the model's state dictionary.

**Returns:** `dict[str, Any]`

Dictionary containing the model's state dict with CPU offloading enabled.

```python
class nemo_automodel.components.checkpoint.stateful_wrappers.OptimizerState(
    model: torch.nn.Module | list[torch.nn.Module],
    optimizer: torch.optim.Optimizer,
    scheduler: typing.Optional[typing.Any] = None,
    is_peft: bool = False
)
```

Helper class for tracking optimizer state in distributed checkpointing.

This class is compliant with the Stateful protocol, allowing DCP to automatically
call `state_dict`/`load_state_dict` as needed in the `dcp.save`/`dcp.load` APIs.

**Parameters:**

`model`: The PyTorch model associated with the optimizer.

`optimizer`: The optimizer to track.

`scheduler`: Optional learning rate scheduler.

```python
nemo_automodel.components.checkpoint.stateful_wrappers.OptimizerState.load_state_dict(
    state_dict: dict[str, typing.Any]
) -> None
```

Load the state dictionaries into the optimizer and scheduler.

**Parameters:**

`state_dict`: State dictionary containing optimizer and scheduler states to load.

```python
nemo_automodel.components.checkpoint.stateful_wrappers.OptimizerState.state_dict() -> dict[str, typing.Any]
```

Get the optimizer and scheduler state dictionaries.

**Returns:** `dict[str, Any]`

Dictionary containing the optimizer and scheduler state dicts with CPU offloading enabled.

```python
nemo_automodel.components.checkpoint.stateful_wrappers._add_outer_prefix(
    sd: dict[str, typing.Any],
    prefix: str = _PREFIX,
    skip_keys: list[str] | None = None
) -> None
```

Prepend `prefix` once to every key in-place (inverse of `_drop_outer_prefix`).

```python
nemo_automodel.components.checkpoint.stateful_wrappers._drop_outer_prefix(
    sd: dict[str, typing.Any],
    prefix: str = _PREFIX
) -> None
```

Remove the *first* occurrence of `prefix` on every key in-place.
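
The round trip between `_add_outer_prefix` and `_drop_outer_prefix` can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the function names below are hypothetical, the prefix is assumed to sit at the start of each key, and `skip_keys` is assumed to list keys that are left unchanged.

```python
from typing import Any, Optional

PREFIX = "model."  # mirrors the documented default value of _PREFIX


def add_outer_prefix(
    sd: dict[str, Any],
    prefix: str = PREFIX,
    skip_keys: Optional[list[str]] = None,
) -> None:
    """Prepend `prefix` to every key in-place (assumed skip_keys semantics)."""
    skip = set(skip_keys or ())
    renamed = {(k if k in skip else prefix + k): v for k, v in sd.items()}
    sd.clear()
    sd.update(renamed)


def drop_outer_prefix(sd: dict[str, Any], prefix: str = PREFIX) -> None:
    """Strip one leading `prefix` from every key that has it, in-place."""
    renamed = {
        (k[len(prefix):] if k.startswith(prefix) else k): v for k, v in sd.items()
    }
    sd.clear()
    sd.update(renamed)


sd = {"layers.0.weight": 1, "model.embed.weight": 2}
original = dict(sd)

add_outer_prefix(sd)
with_prefix = dict(sd)   # every key now carries one extra "model."

drop_outer_prefix(sd)    # only the outermost "model." is removed again
```

Stripping only one leading prefix is what makes the two operations inverses: a key such as `model.model.embed.weight` returns to `model.embed.weight` rather than losing both segments.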

```python
nemo_automodel.components.checkpoint.stateful_wrappers._get_lm_head_weight_and_name(
    model: torch.nn.Module
) -> typing.Optional[tuple[torch.Tensor, str]]
```

```python
nemo_automodel.components.checkpoint.stateful_wrappers._get_peft_state_dict(
    model: torch.nn.Module
) -> dict[str, typing.Any]
```

Extract only trainable PEFT adapter weights, bypassing DCP.

This function iterates directly over the model parameters to collect trainable weights.
It avoids PyTorch DCP's `state_dict` traversal, which fails on (1) BitsAndBytes quantized
modules (Params4bit, Int8Params, etc.) and (2) MoE models with expert parallelism,
where expert weights are sharded across EP ranks.
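
The core idea of collecting only trainable weights by iterating over named parameters can be sketched as follows. This is an assumption-laden illustration: `collect_trainable` is a hypothetical helper, and `SimpleNamespace` objects stand in for real parameters so the snippet runs without PyTorch. With a real model, the iterable would typically come from `model.named_parameters()`.

```python
from types import SimpleNamespace
from typing import Any, Iterable


def collect_trainable(named_parameters: Iterable[tuple[str, Any]]) -> dict[str, Any]:
    """Keep only entries whose parameter is marked as trainable."""
    return {
        name: param
        for name, param in named_parameters
        if getattr(param, "requires_grad", False)
    }


# Stand-in parameters: the frozen base weight is skipped, adapter weights are kept.
params = [
    ("layers.0.q_proj.weight", SimpleNamespace(requires_grad=False)),
    ("layers.0.q_proj.lora_A.weight", SimpleNamespace(requires_grad=True)),
    ("layers.0.q_proj.lora_B.weight", SimpleNamespace(requires_grad=True)),
]
adapter_sd = collect_trainable(params)
```

Filtering on `requires_grad` never touches the frozen (possibly quantized or expert-sharded) weights, which is why this style of traversal sidesteps the failure cases listed above.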

```python
nemo_automodel.components.checkpoint.stateful_wrappers._has_expert_parallelism(
    model: torch.nn.Module
) -> bool
```

Check if any MoE expert module in the model has expert parallelism enabled.

After EP initialization, expert modules (GroupedExpertsDeepEP, GroupedExpertsTE)
store `ep_size` on themselves. A value > 1 signals that expert weights are
sharded across EP ranks and that DCP's `state_dict` APIs cannot handle them.
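
A check of this kind might look like the sketch below. It is a hypothetical stand-alone version: `has_expert_parallelism` here takes any iterable of module-like objects (with a real model, likely `model.modules()`), and `SimpleNamespace` objects stand in for expert modules so the snippet runs without PyTorch.

```python
from types import SimpleNamespace
from typing import Any, Iterable


def has_expert_parallelism(modules: Iterable[Any]) -> bool:
    """Return True if any module reports an integer `ep_size` greater than 1."""
    for module in modules:
        ep_size = getattr(module, "ep_size", None)
        if isinstance(ep_size, int) and ep_size > 1:
            return True
    return False


# Stand-in module lists: one without sharded experts, one with ep_size=4.
dense_only = [SimpleNamespace(), SimpleNamespace(ep_size=1)]
with_ep = [SimpleNamespace(), SimpleNamespace(ep_size=4)]
```

Modules that lack the attribute, or that report `ep_size == 1`, are treated as not using expert parallelism.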

```python
nemo_automodel.components.checkpoint.stateful_wrappers._has_quantized_params(
    model: torch.nn.Module
) -> bool
```

Check if model has any BitsAndBytes quantized modules.

```python
nemo_automodel.components.checkpoint.stateful_wrappers._is_quantized_module(
    module: torch.nn.Module
) -> bool
```

Check if a module is a BitsAndBytes quantized type.

Detects quantization by checking for `quant_state` attribute which is
common across BitsAndBytes quantized module types (Params4bit, Int8Params, etc.).

```python
nemo_automodel.components.checkpoint.stateful_wrappers._rename_dora_keys_from_hf(
    sd: dict[str, typing.Any]
) -> None
```

Reverse of `_rename_dora_keys_to_hf`: convert HF PEFT key format back to internal names.

Handles both the current on-disk format (`<module>.lora_magnitude_vector`)
and the legacy format that included `.default.weight` for robustness.

```python
nemo_automodel.components.checkpoint.stateful_wrappers._rename_dora_keys_to_hf(
    sd: dict[str, typing.Any]
) -> None
```

Rename DoRA magnitude keys to match HF PEFT's saved checkpoint format in-place.

HF PEFT's `get_peft_model_state_dict` strips the adapter name and the
`.weight` suffix from `lora_magnitude_vector.<adapter>.<weight>` so the
round-trip format on disk is simply `<module>.lora_magnitude_vector`.
When loading, `set_peft_model_state_dict` re-inserts the adapter name
and the `.weight` suffix automatically, so we must NOT include them here.
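
The renaming in both directions can be sketched as a suffix rewrite over the state-dict keys. The internal key name is not stated in this reference, so `INTERNAL_SUFFIX` below is a hypothetical placeholder chosen for illustration only; the helper names are likewise invented, and the legacy suffix is assumed to be `.lora_magnitude_vector.default.weight`.

```python
from typing import Any

INTERNAL_SUFFIX = ".lora_magnitude"  # hypothetical internal name (assumption)
HF_SUFFIX = ".lora_magnitude_vector"  # on-disk HF PEFT format described above
LEGACY_HF_SUFFIX = ".lora_magnitude_vector.default.weight"  # assumed legacy form


def _rename_suffix(sd: dict[str, Any], old: str, new: str) -> None:
    """Replace a trailing `old` with `new` on every matching key, in-place."""
    for key in [k for k in sd if k.endswith(old)]:
        sd[key[: -len(old)] + new] = sd.pop(key)


def to_hf(sd: dict[str, Any]) -> None:
    """Internal names -> HF on-disk names (no adapter name, no `.weight`)."""
    _rename_suffix(sd, INTERNAL_SUFFIX, HF_SUFFIX)


def from_hf(sd: dict[str, Any]) -> None:
    """HF names -> internal names, accepting current and legacy formats."""
    _rename_suffix(sd, LEGACY_HF_SUFFIX, INTERNAL_SUFFIX)
    _rename_suffix(sd, HF_SUFFIX, INTERNAL_SUFFIX)


sd = {"q_proj.lora_magnitude": 1, "q_proj.lora_A.weight": 2}
to_hf(sd)
hf_view = dict(sd)   # magnitude key now uses the HF on-disk suffix
from_hf(sd)          # back to the (hypothetical) internal naming

legacy = {"q_proj.lora_magnitude_vector.default.weight": 3}
from_hf(legacy)      # legacy keys normalize to the same internal name
```

Keys that do not end in one of the magnitude suffixes, such as the `lora_A` weight above, pass through untouched.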

```python
nemo_automodel.components.checkpoint.stateful_wrappers._safe_op_set_extra_state(
    self,
    state
)
```

```python
nemo_automodel.components.checkpoint.stateful_wrappers._safe_set_extra_state(
    self,
    state
)
```

```python
nemo_automodel.components.checkpoint.stateful_wrappers._set_peft_state_dict(
    model: torch.nn.Module,
    state_dict: dict[str, typing.Any]
) -> None
```

Load trainable PEFT adapter weights into the model, bypassing DCP.

Mirrors `_get_peft_state_dict`: directly assigns saved tensors to model parameters
by name, handling DTensor re-sharding for EP-parallel weights. This avoids
DCP's `set_model_state_dict()`, which raises `KeyError` on expert-parallel FQNs.

```python
nemo_automodel.components.checkpoint.stateful_wrappers._PREFIX = 'model.'
```

```python
nemo_automodel.components.checkpoint.stateful_wrappers._original_op_set_extra_state = te_ops.BasicOperation.set_extra_state
```

```python
nemo_automodel.components.checkpoint.stateful_wrappers._original_set_extra_state = te_base.TransformerEngineBaseModule.set_extra_state
```