nemo_rl.utils.checkpoint#

Checkpoint management utilities for the RL algorithm loop.

This module handles checkpointing logic at the algorithm level. Each RL Actor is expected to have its own checkpoint-saving function (called by the algorithm loop).

Module Contents#

Classes#

CheckpointingConfig

Configuration for checkpoint management.

CheckpointManager

Manages model checkpoints during training.

Functions#

_load_checkpoint_history

Load the history of checkpoints and their metrics.

API#

class nemo_rl.utils.checkpoint.CheckpointingConfig[source]#

Bases: typing.TypedDict

Configuration for checkpoint management.

Attributes:
  • enabled (bool) – Whether checkpointing is enabled.

  • checkpoint_dir (os.PathLike) – Directory where checkpoints will be saved.

  • metric_name (str) – Name of the metric used to determine the best checkpoints.

  • higher_is_better (bool) – Whether higher values of the metric indicate better performance.

  • keep_top_k (Optional[int]) – Number of best checkpoints to keep. If None, all checkpoints are kept.

Initialization

Initialize self. See help(type(self)) for accurate signature.

enabled: bool#

None

checkpoint_dir: os.PathLike#

None

metric_name: str#

None

higher_is_better: bool#

None

keep_top_k: Optional[int]#

None
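As an illustration, a dict satisfying these fields might look like the following (the metric name and values here are hypothetical, not taken from any shipped config):

```python
# Hypothetical configuration matching the CheckpointingConfig fields above.
checkpointing_config = {
    "enabled": True,
    "checkpoint_dir": "results/checkpoints",
    "metric_name": "val_reward",   # metric used to rank checkpoints (assumed name)
    "higher_is_better": True,      # higher reward = better checkpoint
    "keep_top_k": 3,               # keep only the 3 best checkpoints
}
```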

class nemo_rl.utils.checkpoint.CheckpointManager(
config: nemo_rl.utils.checkpoint.CheckpointingConfig,
)[source]#

Manages model checkpoints during training.

This class handles creating checkpoint directories and saving training info and configurations. It also provides utilities for keeping only the top-k checkpoints. The checkpointing structure looks like this:

checkpoint_dir/
    step_0/
        training_info.json
        config.yaml
        policy.py (up to the algorithm loop to save here)
        policy_optimizer.py (up to the algorithm loop to save here)
        ...
    step_1/
        ...

Attributes: Derived from the CheckpointingConfig.

Initialization

Initialize the checkpoint manager.

Parameters:

config (CheckpointingConfig)

init_tmp_checkpoint(
step: int,
training_info: Dict[str, Any],
run_config: Optional[Dict[str, Any]] = None,
) os.PathLike[source]#

Initialize a temporary checkpoint directory.

Creates a temporary directory for a new checkpoint and saves training info and configuration. The directory is named 'tmp_step_{step}' and will be renamed to 'step_{step}' when the checkpoint is completed. We do it this way to allow the algorithm loop to save any files it wants to save in a safe, temporary directory.

Parameters:
  • step (int) – The training step number.

  • training_info (Dict[str, Any]) – Dictionary containing training metrics and info.

  • run_config (Optional[Dict[str, Any]]) – Optional configuration for the training run.

Returns:

Path to the temporary checkpoint directory.

Return type:

os.PathLike

finalize_checkpoint(checkpoint_path: os.PathLike) None[source]#

Complete a checkpoint by moving it from temporary to permanent location.

If a checkpoint already exists at the target location (e.g. when resuming training), the old one is overwritten. This also triggers cleanup of old checkpoints based on the keep_top_k setting.

Parameters:

checkpoint_path (os.PathLike) – Path to the temporary checkpoint directory.
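The tmp-then-rename pattern behind `init_tmp_checkpoint` and `finalize_checkpoint` can be sketched in isolation (a minimal standalone sketch of the documented behavior, not the actual CheckpointManager code; file names follow the layout shown earlier):

```python
import json
import os
import shutil
import tempfile

# Sketch: write everything into tmp_step_{step}, then promote it to
# step_{step} once all files are safely in place.
root = tempfile.mkdtemp()
step = 0
tmp_dir = os.path.join(root, f"tmp_step_{step}")
os.makedirs(tmp_dir)

# The algorithm loop can save any files it wants into the temporary dir.
with open(os.path.join(tmp_dir, "training_info.json"), "w") as f:
    json.dump({"step": step, "val_reward": 0.42}, f)  # sample metrics

# Finalizing promotes the checkpoint; an existing step_{step} directory
# (e.g. when resuming training) is overwritten first.
final_dir = os.path.join(root, f"step_{step}")
if os.path.exists(final_dir):
    shutil.rmtree(final_dir)
os.rename(tmp_dir, final_dir)
```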

remove_old_checkpoints(exclude_latest: bool = True) None[source]#

Remove checkpoints that are not in the top-k or latest based on the metric.

If keep_top_k is set, this method removes all checkpoints except the top-k best ones based on the specified metric. The best checkpoints are determined by the metric value and the higher_is_better setting. When multiple checkpoints have the same metric value, more recent checkpoints (higher step numbers) are preferred.

Parameters:

exclude_latest (bool) – Whether to exclude the latest checkpoint from deletion (may result in keeping K+1 checkpoints).
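The ranking rule described above can be sketched as follows (sample steps and metric values are made up, and the actual implementation may differ in detail):

```python
# Rank checkpoints by metric value (respecting higher_is_better), breaking
# ties in favor of more recent steps, then keep the top k.
checkpoints = [
    (10, {"val_reward": 0.50}),
    (20, {"val_reward": 0.70}),
    (30, {"val_reward": 0.70}),  # ties with step 20; the newer step wins
    (40, {"val_reward": 0.40}),
]
metric_name, higher_is_better, keep_top_k = "val_reward", True, 2

sign = 1 if higher_is_better else -1
ranked = sorted(
    checkpoints,
    key=lambda c: (sign * c[1][metric_name], c[0]),
    reverse=True,
)
kept_steps = [step for step, _ in ranked[:keep_top_k]]
# kept_steps == [30, 20]: step 30 wins the tie against step 20 by recency
```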

get_best_checkpoint_path() Optional[str][source]#

Get the path to the best checkpoint based on the metric.

Returns the path to the checkpoint with the best metric value. If no checkpoints exist, returns None. If the metric isn’t found, we warn and return the latest checkpoint.

Returns:

Path to the best checkpoint, or None if no valid checkpoints exist.

Return type:

Optional[str]

get_latest_checkpoint_path() Optional[str][source]#

Get the path to the latest checkpoint.

Returns the path to the checkpoint with the highest step number.

Returns:

Path to the latest checkpoint, or None if no checkpoints exist.

Return type:

Optional[str]

load_training_info(
checkpoint_path: Optional[os.PathLike] = None,
) Optional[Dict[str, Any]][source]#

Load the training info from a checkpoint.

Parameters:

checkpoint_path (Optional[os.PathLike]) – Path to the checkpoint. If None, returns None.

Returns:

Dictionary containing the training info, or None if checkpoint_path is None.

Return type:

Optional[Dict[str, Any]]

nemo_rl.utils.checkpoint._load_checkpoint_history(
checkpoint_dir: pathlib.Path,
) List[Tuple[int, os.PathLike, Dict[str, Any]]][source]#

Load the history of checkpoints and their metrics.

Parameters:

checkpoint_dir (Path) – Directory containing the checkpoints.

Returns:

List of tuples containing (step_number, checkpoint_path, info) for each checkpoint.

Return type:

List[Tuple[int, os.PathLike, Dict[str, Any]]]
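A minimal sketch of this scan, assuming the `step_{N}` directory layout and `training_info.json` files described above (sample data is fabricated; the real function may differ in detail):

```python
import json
import pathlib
import tempfile

# Build a fake checkpoint_dir with two completed checkpoints.
checkpoint_dir = pathlib.Path(tempfile.mkdtemp())
for step, reward in [(0, 0.1), (5, 0.3)]:
    step_dir = checkpoint_dir / f"step_{step}"
    step_dir.mkdir()
    (step_dir / "training_info.json").write_text(json.dumps({"val_reward": reward}))

# Scan for step_{N} directories and pair each step number with its path
# and training info, sorted by step number.
history = sorted(
    (int(p.name.split("_")[1]), p, json.loads((p / "training_info.json").read_text()))
    for p in checkpoint_dir.glob("step_*")
)
# history[-1][0] == 5: the latest step
```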