nemo_automodel.checkpoint.stateful_wrappers#
Module Contents#
Classes#
| ModelState | Helper class for tracking model state in distributed checkpointing. |
| OptimizerState | Helper class for tracking optimizer state in distributed checkpointing. |
Functions#
| _drop_outer_prefix | Remove the first occurrence of prefix on every key in-place. |
| _add_outer_prefix | Prepend prefix once to every key in-place (inverse of _drop_outer_prefix). |
| _get_lm_head_weight_and_name | |
Data#
API#
- nemo_automodel.checkpoint.stateful_wrappers._PREFIX#
'model.'
- nemo_automodel.checkpoint.stateful_wrappers._drop_outer_prefix(
- sd: dict[str, Any],
- prefix: str = _PREFIX,
- )#
Remove the first occurrence of prefix on every key in-place.
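The helper below is an illustrative sketch of the documented behavior, not the module's actual implementation. It interprets "first occurrence" as str.replace(prefix, "", 1) and mutates the state dict in place; the key names used in the example are made up for demonstration.

```python
from typing import Any

_PREFIX = "model."


def drop_outer_prefix_sketch(sd: dict[str, Any], prefix: str = _PREFIX) -> None:
    """Remove the first occurrence of ``prefix`` from every key, mutating ``sd`` (sketch)."""
    for key in list(sd.keys()):
        if prefix in key:
            sd[key.replace(prefix, "", 1)] = sd.pop(key)


state = {"model.embed_tokens.weight": 0, "model.layers.0.q_proj.weight": 1}
drop_outer_prefix_sketch(state)
assert set(state) == {"embed_tokens.weight", "layers.0.q_proj.weight"}
```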
- nemo_automodel.checkpoint.stateful_wrappers._add_outer_prefix(
- sd: dict[str, Any],
- prefix: str = _PREFIX,
- )#
Prepend prefix once to every key in-place (inverse of _drop_outer_prefix).
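For symmetry, here is a sketch of the inverse operation under the same assumptions. Reading "once" as "skip keys that already carry the prefix" is an interpretation, not something this page states.

```python
from typing import Any


def add_outer_prefix_sketch(sd: dict[str, Any], prefix: str = "model.") -> None:
    """Prepend ``prefix`` once to every key, mutating ``sd`` (inverse of the drop sketch)."""
    for key in list(sd.keys()):
        # Interpretation: avoid double-prefixing keys that already start with the prefix.
        if not key.startswith(prefix):
            sd[prefix + key] = sd.pop(key)


state = {"embed_tokens.weight": 0}
add_outer_prefix_sketch(state)
assert "model.embed_tokens.weight" in state
```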
- nemo_automodel.checkpoint.stateful_wrappers._get_lm_head_weight_and_name(
- model: torch.nn.Module,
- )#
- class nemo_automodel.checkpoint.stateful_wrappers.ModelState(
- model: torch.nn.Module,
- serialization_format: nemo_automodel.checkpoint._backports.filesystem.SerializationFormat,
- is_peft: bool = False,
- )#
Bases: torch.distributed.checkpoint.stateful.Stateful
Helper class for tracking model state in distributed checkpointing.
This class is compliant with the Stateful protocol, allowing DCP to automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs.
- Parameters:
model – The PyTorch model to track.
Initialization
Initialize a ModelState instance for distributed checkpointing.
The constructor records the model reference, detects whether the model ties its language-model head to the input embeddings, and stores the desired serialization backend so that DCP can correctly save and restore the model’s parameters and buffers.
- Parameters:
model (torch.nn.Module) – The PyTorch model whose state should be captured during checkpointing.
serialization_format (SerializationFormat) – Backend/format to use when persisting the model state (e.g., torch, safetensors).
is_peft (bool) – Whether the model is a PEFT (parameter-efficient fine-tuning) model.
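A minimal usage sketch follows. It assumes a single-process run, an illustrative checkpoint path, and that the backported SerializationFormat enum exposes a SAFETENSORS member (as its upstream torch counterpart does); none of these specifics come from this page.

```python
import torch
import torch.distributed.checkpoint as dcp

from nemo_automodel.checkpoint._backports.filesystem import SerializationFormat
from nemo_automodel.checkpoint.stateful_wrappers import ModelState

model = torch.nn.Linear(8, 8)

# SerializationFormat.SAFETENSORS is assumed; use whichever member your install provides.
model_state = ModelState(model, serialization_format=SerializationFormat.SAFETENSORS)

# ModelState implements the Stateful protocol, so dcp.save/dcp.load invoke its
# state_dict()/load_state_dict() automatically.
dcp.save({"model": model_state}, checkpoint_id="/tmp/ckpt_model")  # illustrative path
dcp.load({"model": model_state}, checkpoint_id="/tmp/ckpt_model")
```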
- class nemo_automodel.checkpoint.stateful_wrappers.OptimizerState(
- model: torch.nn.Module,
- optimizer: torch.optim.Optimizer,
- scheduler: Optional[Any] = None,
- )#
Bases: torch.distributed.checkpoint.stateful.Stateful
Helper class for tracking optimizer state in distributed checkpointing.
This class is compliant with the Stateful protocol, allowing DCP to automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs.
- Parameters:
model – The PyTorch model associated with the optimizer.
optimizer – The optimizer to track.
scheduler – Optional learning rate scheduler.
Initialization
Initialize an OptimizerState instance.
The constructor simply stores references to the model, optimizer, and (optionally) learning-rate scheduler so that their state can be captured and restored by the Distributed Checkpointing (DCP) framework.
- Parameters:
model (torch.nn.Module) – The neural-network model whose parameters the optimizer updates. Keeping the reference allows DCP to re-establish the model–optimizer relationship when loading a checkpoint.
optimizer (torch.optim.Optimizer) – Optimizer whose internal buffers (e.g., momentum, Adam moments, step counters) need to be saved and restored.
scheduler (Optional[Any], optional) – Learning-rate scheduler to track alongside the optimizer. Pass None if no scheduler is used.
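In the same spirit, a hedged sketch of checkpointing optimizer and scheduler state with OptimizerState; the state-dict key, checkpoint path, and hyperparameters are illustrative, not prescribed by this API.

```python
import torch
import torch.distributed.checkpoint as dcp

from nemo_automodel.checkpoint.stateful_wrappers import OptimizerState

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

optim_state = OptimizerState(model, optimizer, scheduler=scheduler)

# Save now; on resume, rebuild the same wrapper around the fresh model/optimizer/scheduler
# and let DCP restore their internal buffers in place.
dcp.save({"optim": optim_state}, checkpoint_id="/tmp/ckpt_optim")  # illustrative path
dcp.load({"optim": optim_state}, checkpoint_id="/tmp/ckpt_optim")
```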