nemo_rl.utils.native_checkpoint#
Checkpoint management utilities for Hugging Face models.
Module Contents#
Classes#
- ModelState – Helper class for tracking model state in distributed checkpointing.
- OptimizerState – Helper class for tracking optimizer state in distributed checkpointing.
Functions#
- save_checkpoint – Save a checkpoint of the model and optionally optimizer state.
- load_checkpoint – Load model weights and optionally optimizer state.
- convert_dcp_to_hf – Convert a Torch DCP checkpoint to a Hugging Face checkpoint.
API#
- class nemo_rl.utils.native_checkpoint.ModelState(model)[source]#
Bases: torch.distributed.checkpoint.stateful.Stateful
Helper class for tracking model state in distributed checkpointing.
This class is compliant with the Stateful protocol, allowing DCP to automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs.
- Parameters:
model – The PyTorch model to track.
Initialization
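For illustration, a minimal Stateful-compliant model wrapper might look like the sketch below. The class name ModelStateSketch and its use of the torch.distributed.checkpoint.state_dict helpers are assumptions for illustration, not the actual nemo_rl implementation:

```python
import torch
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful


class ModelStateSketch(Stateful):
    """Hypothetical sketch of a Stateful model wrapper; not nemo_rl's code."""

    def __init__(self, model: torch.nn.Module):
        self.model = model

    def state_dict(self):
        # Gather a DCP-friendly model state dict (handles FSDP/DDP wrappers).
        return get_model_state_dict(self.model)

    def load_state_dict(self, state_dict):
        # Restore model parameters in place from a loaded state dict.
        set_model_state_dict(self.model, model_state_dict=state_dict)
```

Because the wrapper satisfies the Stateful protocol, dcp.save and dcp.load can call its state_dict/load_state_dict methods automatically.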
- class nemo_rl.utils.native_checkpoint.OptimizerState(model, optimizer, scheduler=None)[source]#
Bases: torch.distributed.checkpoint.stateful.Stateful
Helper class for tracking optimizer state in distributed checkpointing.
This class is compliant with the Stateful protocol, allowing DCP to automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs.
- Parameters:
model – The PyTorch model associated with the optimizer.
optimizer – The optimizer to track.
scheduler – Optional learning rate scheduler.
Initialization
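A parallel sketch for the optimizer wrapper, again hypothetical (the dictionary keys "optim" and "sched" are illustrative, not nemo_rl's actual layout):

```python
import torch
from torch.distributed.checkpoint.state_dict import (
    get_optimizer_state_dict,
    set_optimizer_state_dict,
)
from torch.distributed.checkpoint.stateful import Stateful


class OptimizerStateSketch(Stateful):
    """Hypothetical sketch of a Stateful optimizer wrapper; not nemo_rl's code."""

    def __init__(self, model, optimizer, scheduler=None):
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler

    def state_dict(self):
        # Optimizer state must be gathered relative to the model's parameters.
        state = {"optim": get_optimizer_state_dict(self.model, self.optimizer)}
        if self.scheduler is not None:
            state["sched"] = self.scheduler.state_dict()
        return state

    def load_state_dict(self, state_dict):
        set_optimizer_state_dict(
            self.model, self.optimizer, optim_state_dict=state_dict["optim"]
        )
        if self.scheduler is not None and "sched" in state_dict:
            self.scheduler.load_state_dict(state_dict["sched"])
```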
- nemo_rl.utils.native_checkpoint.save_checkpoint(
- model,
- weights_path: str,
- optimizer: Optional[torch.optim.Optimizer] = None,
- scheduler: Optional[Any] = None,
- optimizer_path: Optional[str] = None,
- tokenizer: Optional[Any] = None,
- tokenizer_path: Optional[str] = None,
- )[source]#
Save a checkpoint of the model and optionally optimizer state.
- Parameters:
model – The PyTorch model to save
weights_path – Path to save model weights
optimizer – Optional optimizer to save
scheduler – Optional scheduler to save
optimizer_path – Path to save optimizer state (required if optimizer provided)
tokenizer – Optional tokenizer to save
tokenizer_path – Path to save tokenizer (required if tokenizer provided)
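A hedged usage sketch follows; the model name, paths, and optimizer setup are placeholders, and in a real run this would execute inside the distributed training job:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from nemo_rl.utils.native_checkpoint import save_checkpoint

# Illustrative setup; the model name and paths are placeholders.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# ... training steps ...

save_checkpoint(
    model,
    weights_path="ckpt/step_100/weights",
    optimizer=optimizer,
    optimizer_path="ckpt/step_100/optim",  # required because optimizer is passed
    tokenizer=tokenizer,
    tokenizer_path="ckpt/step_100/tokenizer",
)
```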
- nemo_rl.utils.native_checkpoint.load_checkpoint(
- model,
- weights_path: str,
- optimizer: Optional[torch.optim.Optimizer] = None,
- scheduler: Optional[Any] = None,
- optimizer_path: Optional[str] = None,
- )[source]#
Load model weights and optionally optimizer state.
- Parameters:
model – The PyTorch model whose weights to update
weights_path – Path to load model weights from
optimizer – Optional optimizer to load state into
scheduler – Optional scheduler to load state into
optimizer_path – Path to load optimizer state from (required if optimizer provided)
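A matching restore sketch, using the same placeholder paths as the save example above:

```python
import torch
from transformers import AutoModelForCausalLM

from nemo_rl.utils.native_checkpoint import load_checkpoint

# Rebuild the same model/optimizer structure that was checkpointed.
model = AutoModelForCausalLM.from_pretrained("gpt2")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Restore weights and optimizer state in place before resuming training.
load_checkpoint(
    model,
    weights_path="ckpt/step_100/weights",
    optimizer=optimizer,
    optimizer_path="ckpt/step_100/optim",
)
```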
- nemo_rl.utils.native_checkpoint.convert_dcp_to_hf(
- dcp_ckpt_path: str,
- hf_ckpt_path: str,
- model_name_or_path: str,
- tokenizer_name_or_path: str,
- overwrite: bool = False,
- ) -> str[source]#
Convert a Torch DCP checkpoint to a Hugging Face checkpoint.
This is not an optimized utility. If the checkpoint is too large, consider saving the DCP checkpoint during training and using this utility offline to convert it to HF format.
- Parameters:
dcp_ckpt_path (str) – Path to DCP checkpoint
hf_ckpt_path (str) – Path to save HF checkpoint
model_name_or_path (str) – Model name or path for config
tokenizer_name_or_path (str, optional) – Tokenizer name or path. Defaults to model_name_or_path if None.
overwrite (bool, optional) – Whether to overwrite existing checkpoint. Defaults to False.
- Returns:
Path to the saved HF checkpoint
- Return type:
str
- Raises:
FileExistsError β If HF checkpoint already exists and overwrite is False
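An illustrative conversion sketch; the paths and the gpt2 model name are placeholders:

```python
from transformers import AutoModelForCausalLM

from nemo_rl.utils.native_checkpoint import convert_dcp_to_hf

# Convert a DCP checkpoint saved during training into HF format.
hf_path = convert_dcp_to_hf(
    dcp_ckpt_path="ckpt/step_100/weights",
    hf_ckpt_path="ckpt/step_100/hf",
    model_name_or_path="gpt2",
    tokenizer_name_or_path="gpt2",
    overwrite=True,  # with False, raises FileExistsError if the path exists
)

# The converted checkpoint loads with standard HF APIs.
model = AutoModelForCausalLM.from_pretrained(hf_path)
```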