nemo_rl.utils.native_checkpoint

Checkpoint management utilities for HF models.

Module Contents

Classes

ModelState

Helper class for tracking model state in distributed checkpointing.

OptimizerState

Helper class for tracking optimizer state in distributed checkpointing.

Functions

save_checkpoint

Save a checkpoint of the model and optionally optimizer state.

load_checkpoint

Load model weights and, optionally, optimizer state.

convert_dcp_to_hf

Convert a Torch DCP checkpoint to a Hugging Face checkpoint.

API

class nemo_rl.utils.native_checkpoint.ModelState(model)

Bases: torch.distributed.checkpoint.stateful.Stateful

Helper class for tracking model state in distributed checkpointing.

This class is compliant with the Stateful protocol, allowing DCP to automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs.

Parameters:

model – The PyTorch model to track.

Initialization

state_dict()

Get the model’s state dictionary.

Returns:

Dictionary containing the model’s state dict with CPU offloading enabled.

Return type:

dict

load_state_dict(state_dict)

Load the state dictionary into the model.

Parameters:

state_dict (dict) – State dictionary to load.
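To illustrate what compliance with the Stateful protocol means, here is a minimal sketch in plain Python with no torch dependency. ToyModel, ToyModelState, toy_save, and toy_load are hypothetical names invented for this example; real code would go through the dcp.save/load APIs, which are not reproduced here. The sketch only shows that the checkpointing routine needs nothing more from the wrapper than state_dict() and load_state_dict().

```python
import json
import os
import tempfile


class ToyModel:
    """Hypothetical stand-in for a model: just a dict of named weights."""

    def __init__(self, weights):
        self.weights = dict(weights)


class ToyModelState:
    """Hypothetical wrapper exposing the two methods the protocol requires."""

    def __init__(self, model):
        self.model = model

    def state_dict(self):
        # Return a copy so callers cannot mutate the model by accident.
        return dict(self.model.weights)

    def load_state_dict(self, state_dict):
        self.model.weights = dict(state_dict)


def toy_save(stateful, path):
    # The saver never touches the model directly; it only calls state_dict().
    with open(path, "w") as f:
        json.dump(stateful.state_dict(), f)


def toy_load(stateful, path):
    # The loader hands the stored dict back through load_state_dict().
    with open(path) as f:
        stateful.load_state_dict(json.load(f))


def roundtrip_demo():
    """Save one toy model and restore its weights into another."""
    source = ToyModel({"w": 1.5, "b": -0.25})
    target = ToyModel({"w": 0.0, "b": 0.0})
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "toy.json")
        toy_save(ToyModelState(source), path)
        toy_load(ToyModelState(target), path)
    return target.weights
```

Because the save and load routines depend only on those two methods, the same routines could serve any object that implements them, which is the property the Stateful base class provides to DCP.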

class nemo_rl.utils.native_checkpoint.OptimizerState(model, optimizer, scheduler=None)

Bases: torch.distributed.checkpoint.stateful.Stateful

Helper class for tracking optimizer state in distributed checkpointing.

This class is compliant with the Stateful protocol, allowing DCP to automatically call state_dict/load_state_dict as needed in the dcp.save/load APIs.

Parameters:
  • model – The PyTorch model associated with the optimizer.

  • optimizer – The optimizer to track.

  • scheduler – Optional learning rate scheduler.

Initialization

state_dict()

Get the optimizer and scheduler state dictionaries.

Returns:

Dictionary containing the optimizer and scheduler state dicts with CPU offloading enabled.

Return type:

dict

load_state_dict(state_dict)

Load the state dictionaries into the optimizer and scheduler.

Parameters:

state_dict (dict) – State dictionary containing optimizer and scheduler states to load.
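The idea of bundling optimizer and scheduler state into one dictionary, with the scheduler being optional, can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the class's actual implementation: ToyStateful and ToyOptimizerState are hypothetical, the keys "optim" and "sched" are invented for the example, and the model argument that the real class takes is omitted.

```python
class ToyStateful:
    """Hypothetical object with state_dict()/load_state_dict(),
    standing in for an optimizer or a scheduler."""

    def __init__(self, **state):
        self.state = dict(state)

    def state_dict(self):
        return dict(self.state)

    def load_state_dict(self, state_dict):
        self.state = dict(state_dict)


class ToyOptimizerState:
    """Hypothetical wrapper that bundles two state dicts into one."""

    def __init__(self, optimizer, scheduler=None):
        self.optimizer = optimizer
        self.scheduler = scheduler

    def state_dict(self):
        # The key names "optim" and "sched" are assumptions for this sketch.
        bundle = {"optim": self.optimizer.state_dict()}
        if self.scheduler is not None:
            bundle["sched"] = self.scheduler.state_dict()
        return bundle

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict["optim"])
        # Restore the scheduler only if one is tracked and its state was saved.
        if self.scheduler is not None and "sched" in state_dict:
            self.scheduler.load_state_dict(state_dict["sched"])
```

Keeping both states behind a single wrapper means the saving code handles one object, whether or not a scheduler is in use.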

nemo_rl.utils.native_checkpoint.save_checkpoint(
    model,
    weights_path: str,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[Any] = None,
    optimizer_path: Optional[str] = None,
    tokenizer: Optional[Any] = None,
    tokenizer_path: Optional[str] = None,
) → None

Save a checkpoint of the model and optionally optimizer state.

Parameters:
  • model – The PyTorch model to save

  • weights_path – Path to save model weights

  • optimizer – Optional optimizer to save

  • scheduler – Optional scheduler to save

  • optimizer_path – Path to save optimizer state (required if optimizer provided)

  • tokenizer – Optional tokenizer to save

  • tokenizer_path – Path to save the tokenizer

nemo_rl.utils.native_checkpoint.load_checkpoint(
    model,
    weights_path: str,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[Any] = None,
    optimizer_path: Optional[str] = None,
) → None

Load model weights and, optionally, optimizer state.

Parameters:
  • model – The PyTorch model whose weights to update

  • weights_path – Path to load model weights from

  • optimizer – Optional optimizer to load state into

  • scheduler – Optional scheduler to load state into

  • optimizer_path – Path to load optimizer state from (required if optimizer provided)
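The rule that optimizer_path is required whenever an optimizer is provided can be expressed as a small guard. The sketch below is an assumption about how such a check might look, not the library's code: check_optimizer_args is a hypothetical helper, and the choice of ValueError is a guess because the documentation does not name the exception.

```python
from typing import Any, Optional


def check_optimizer_args(optimizer: Optional[Any], optimizer_path: Optional[str]) -> None:
    """Hypothetical guard: reject an optimizer that has no path to go with it."""
    if optimizer is not None and optimizer_path is None:
        # ValueError is an assumption; the documentation does not name the exception.
        raise ValueError("optimizer_path must be provided when optimizer is given")
```

A caller who only wants model weights can leave both arguments as None and the guard passes without effect.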

nemo_rl.utils.native_checkpoint.convert_dcp_to_hf(
    dcp_ckpt_path: str,
    hf_ckpt_path: str,
    model_name_or_path: str,
    tokenizer_name_or_path: str,
    overwrite: bool = False,
)

Convert a Torch DCP checkpoint to a Hugging Face checkpoint.

This is not an optimized utility. If the checkpoint is too large, consider saving DCP checkpoints during training and using this utility to convert them to HF format afterward.

Parameters:
  • dcp_ckpt_path (str) – Path to DCP checkpoint

  • hf_ckpt_path (str) – Path to save HF checkpoint

  • model_name_or_path (str) – Model name or path for config

  • tokenizer_name_or_path (str, optional) – Tokenizer name or path. Defaults to model_name_or_path if None.

  • overwrite (bool, optional) – Whether to overwrite existing checkpoint. Defaults to False.

Returns:

Path to the saved HF checkpoint

Return type:

str

Raises:

FileExistsError – If HF checkpoint already exists and overwrite is False
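The overwrite behavior described above can be sketched with the standard library alone. prepare_output_dir is a hypothetical helper written for illustration; it is an assumption that overwriting means removing the existing output before writing, since the documentation only states that FileExistsError is raised when the checkpoint exists and overwrite is False.

```python
import os
import shutil


def prepare_output_dir(path: str, overwrite: bool = False) -> str:
    """Hypothetical helper: create `path`, refusing to clobber it unless overwrite=True."""
    if os.path.exists(path):
        if not overwrite:
            raise FileExistsError(f"HF checkpoint already exists at {path}")
        # Assumption: overwriting replaces whatever is currently at the path.
        if os.path.isdir(path):
            shutil.rmtree(path)
        else:
            os.remove(path)
    os.makedirs(path)
    return path
```

Calling the helper twice on the same path with the default overwrite=False raises FileExistsError on the second call, which mirrors the documented failure mode.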