nemo_automodel.checkpoint.stateful_wrappers

Module Contents

Classes

ModelState

Helper class for tracking model state in distributed checkpointing.

OptimizerState

Helper class for tracking optimizer state in distributed checkpointing.

Functions

_drop_outer_prefix

Remove the first occurrence of prefix from every key, in place.

_add_outer_prefix

Prepend prefix once to every key in-place (inverse of _drop_outer_prefix).

_get_lm_head_weight_and_name

Data

_PREFIX

API

nemo_automodel.checkpoint.stateful_wrappers._PREFIX

'model.'

nemo_automodel.checkpoint.stateful_wrappers._drop_outer_prefix(
sd: dict[str, Any],
prefix: str = _PREFIX,
) → None

Remove the first occurrence of prefix from every key, in place.
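The minimal sketch below illustrates the behavior described above. It is a hypothetical stand-in named drop_outer_prefix, not the library's source, and it assumes the literal reading of "first occurrence": str.replace(prefix, "", 1) removes the first match anywhere in the key, not only at its start. The dict is rebuilt and then refilled so that the caller's dict object is modified in place.

```python
from typing import Any

_PREFIX = "model."


def drop_outer_prefix(sd: dict[str, Any], prefix: str = _PREFIX) -> None:
    """Hypothetical stand-in: strip the first occurrence of `prefix` from each key, in place."""
    # Build the renamed mapping first (renaming while iterating would be unsafe),
    # then refill the same dict object so every reference to `sd` sees the change.
    renamed = {key.replace(prefix, "", 1): value for key, value in sd.items()}
    sd.clear()
    sd.update(renamed)


sd = {"model.layers.0.weight": 1, "model.norm.weight": 2, "lm_head.weight": 3}
drop_outer_prefix(sd)
print(sd)  # {'layers.0.weight': 1, 'norm.weight': 2, 'lm_head.weight': 3}
```

Keys that do not contain the prefix, such as lm_head.weight here, pass through unchanged.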

nemo_automodel.checkpoint.stateful_wrappers._add_outer_prefix(
sd: dict[str, Any],
prefix: str = _PREFIX,
) → None

Prepend prefix once to every key in-place (inverse of _drop_outer_prefix).
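As a hedged illustration, the hypothetical add_outer_prefix below prepends the prefix unconditionally. Whether the library skips keys that already carry the prefix is not stated here. With this sketch, a drop-then-add round trip restores the original keys only when every original key started with the prefix.

```python
from typing import Any

_PREFIX = "model."


def add_outer_prefix(sd: dict[str, Any], prefix: str = _PREFIX) -> None:
    """Hypothetical stand-in: prepend `prefix` exactly once to each key, in place."""
    # Rebuild, then refill the original dict object rather than returning a new one.
    renamed = {prefix + key: value for key, value in sd.items()}
    sd.clear()
    sd.update(renamed)


sd = {"layers.0.weight": 1, "norm.weight": 2}
add_outer_prefix(sd)
print(sd)  # {'model.layers.0.weight': 1, 'model.norm.weight': 2}
```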

nemo_automodel.checkpoint.stateful_wrappers._get_lm_head_weight_and_name(
model: torch.nn.Module,
) → Optional[tuple[torch.Tensor, str]]

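This function has no docstring. Its return type suggests it yields an LM-head weight together with its parameter name, or None. The sketch below is a guess at that idea under stated assumptions: find_lm_head_weight is hypothetical, and matching names that end in "lm_head.weight" is an assumption. It highlights one real PyTorch subtlety. By default, named_parameters() reports a shared (tied) parameter only under the first name it reaches, so a head tied to an embedding can disappear from the listing. Passing remove_duplicate=False, which recent PyTorch releases support, keeps every name.

```python
from typing import Optional

import torch
import torch.nn as nn


def find_lm_head_weight(model: nn.Module) -> Optional[tuple[torch.Tensor, str]]:
    """Hypothetical: return (weight, name) of a parameter named like an LM head, else None."""
    # remove_duplicate=False lists tied parameters under every registered name.
    for name, param in model.named_parameters(remove_duplicate=False):
        if name == "lm_head.weight" or name.endswith(".lm_head.weight"):
            return param, name
    return None


# Toy model: the head is tied to an embedding registered earlier in module order.
toy = nn.Module()
toy.model = nn.Module()
toy.model.embed_tokens = nn.Embedding(10, 4)
toy.lm_head = nn.Linear(4, 10, bias=False)
toy.lm_head.weight = toy.model.embed_tokens.weight

found = find_lm_head_weight(toy)
print(found[1])  # lm_head.weight
print([n for n, _ in toy.named_parameters()])  # ['model.embed_tokens.weight']
```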
class nemo_automodel.checkpoint.stateful_wrappers.ModelState(
model: torch.nn.Module,
serialization_format: nemo_automodel.checkpoint._backports.filesystem.SerializationFormat,
is_peft: bool = False,
)

Bases: torch.distributed.checkpoint.stateful.Stateful

Helper class for tracking model state in distributed checkpointing.

This class implements the Stateful protocol, so DCP can call its state_dict and load_state_dict methods automatically during dcp.save and dcp.load.

Parameters:

model – The PyTorch model to track.

Initialization

Initialize a ModelState instance for distributed checkpointing.

The constructor records the model reference, detects whether the model ties its language-model head to the input embeddings, and stores the desired serialization backend so that DCP can correctly save and restore the model’s parameters and buffers.

Parameters:
  • model (torch.nn.Module) – The PyTorch model whose state should be captured during checkpointing.

  • serialization_format (SerializationFormat) – Backend/format to use when persisting the model state (e.g., torch, safetensors).

  • is_peft (bool) – Whether the model is a PEFT (parameter-efficient fine-tuning) model.
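The constructor's tied-embedding detection can be pictured with the sketch below. It is not the library's check: heads_are_tied and TinyLM are hypothetical, and the sketch assumes the attribute names embed_tokens and lm_head. Detection matters because model.state_dict() still lists a tied tensor under both names, and some formats (safetensors, for example) refuse to save tensors that share memory.

```python
import torch.nn as nn


class TinyLM(nn.Module):
    """Hypothetical toy model with an input embedding and an LM head."""

    def __init__(self, vocab: int = 10, dim: int = 4, tie: bool = True) -> None:
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)
        if tie:
            # Tying: the head reuses the embedding's Parameter object.
            self.lm_head.weight = self.embed_tokens.weight


def heads_are_tied(model: nn.Module) -> bool:
    # Same Parameter object, or two tensors starting at the same storage address.
    head, emb = model.lm_head.weight, model.embed_tokens.weight
    return head is emb or head.data_ptr() == emb.data_ptr()


tied, untied = TinyLM(tie=True), TinyLM(tie=False)
print(heads_are_tied(tied), heads_are_tied(untied))  # True False
print(sorted(tied.state_dict()))  # ['embed_tokens.weight', 'lm_head.weight']
```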

state_dict() → dict[str, Any]

Get the model’s state dictionary.

Returns:

Dictionary containing the model’s state dict with CPU offloading enabled.

Return type:

dict

load_state_dict(state_dict: dict[str, Any]) → None

Load the state dictionary into the model.

Parameters:

state_dict (dict) – State dictionary to load.

class nemo_automodel.checkpoint.stateful_wrappers.OptimizerState(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: Optional[Any] = None,
)

Bases: torch.distributed.checkpoint.stateful.Stateful

Helper class for tracking optimizer state in distributed checkpointing.

This class implements the Stateful protocol, so DCP can call its state_dict and load_state_dict methods automatically during dcp.save and dcp.load.

Parameters:
  • model – The PyTorch model associated with the optimizer.

  • optimizer – The optimizer to track.

  • scheduler – Optional learning rate scheduler.

Initialization

Initialize an OptimizerState instance.

The constructor simply stores references to the model, optimizer, and (optionally) learning-rate scheduler so that their state can be captured and restored by the Distributed Checkpointing (DCP) framework.

Parameters:
  • model (torch.nn.Module) – The neural-network model whose parameters the optimizer updates. Keeping the reference allows DCP to re-establish the model–optimizer relationship when loading a checkpoint.

  • optimizer (torch.optim.Optimizer) – Optimizer whose internal buffers (e.g., momentum, Adam moments, step counters) need to be saved and restored.

  • scheduler (Optional[Any], optional) – Learning-rate scheduler to track alongside the optimizer. Pass None if no scheduler is used.

state_dict() → dict[str, Any]

Get the optimizer and scheduler state dictionaries.

Returns:

Dictionary containing the optimizer and scheduler state dicts with CPU offloading enabled.

Return type:

dict

load_state_dict(state_dict: dict[str, Any]) → None

Load the state dictionaries into the optimizer and scheduler.

Parameters:

state_dict (dict) – State dictionary containing optimizer and scheduler states to load.