bridge.training.checkpointing#

Input/output checkpointing.

Module Contents#

Classes#

CheckpointType

Types of checkpoints to save.

Functions#

set_checkpoint_version

Set the global checkpoint version number.

get_checkpoint_version

Get the global checkpoint version number.

find_checkpoint_rank_0

Find the checkpoint directory for a given iteration, assuming distributed checkpoints.

read_metadata

Read the metadata from the Megatron-LM tracker file.

_extract_megatron_lm_args_from_state_dict

Extract and convert legacy Megatron-LM args from checkpoint state_dict to Megatron-Bridge config format.

schedule_async_save

Schedule the async save request.

maybe_finalize_async_save

Finalizes active async save calls.

is_empty_async_queue

Check if async calls queue is empty. This result is consistent across ranks.

get_rng_state

Get the random number generator states for all necessary libraries.

save_checkpoint

Save a model checkpoint.

cleanup_old_non_persistent_checkpoint

Clean up old non-persistent checkpoints in a directory.

maybe_save_dataloader_state

Save the dataloader state if the iterator supports it.

_generate_model_state_dict

Generate the model subset of the state dictionary to be saved in a checkpoint.

generate_state_dict

Generate the state dictionary to be saved in a checkpoint.

_load_model_weights_from_checkpoint

Load model weights from a checkpoint.

load_checkpoint

Load a model checkpoint.

_load_model_state_dict

Helper function to load state dict with fallback for missing extra states.

_load_checkpoint_from_path

Load a checkpoint from a given path.

init_async_checkpoint_worker

Initialize the async checkpoint worker if enabled.

init_checkpointing_context

Initialize the checkpointing context, primarily for local checkpointing support.

apply_peft_adapter_filter_to_state_dict

Filter state dict to contain only PEFT adapter parameters in model sections.

_is_model_section

Check if a checkpoint section contains model parameters.

_transpose_first_dim

Helper function to transpose first dimension of tensor t.

_get_non_persistent_iteration

Get iteration number from non-persistent checkpoint.

_load_non_persistent_base_checkpoint

Load the base state_dict from a non-persistent distributed checkpoint.

_load_global_dist_base_checkpoint

Load the base state_dict from the given directory containing the global distributed checkpoint.

_load_base_checkpoint

Load the base state_dict from the given directory.

_build_sharded_state_dict_metadata

Builds metadata used for sharded_state_dict versioning.

_get_train_state_from_state_dict

Create a TrainState from the state dict of a Megatron-LM checkpoint.

Data#

API#

bridge.training.checkpointing.TRACKER_PREFIX#

‘latest’

bridge.training.checkpointing._CHECKPOINT_VERSION#

None

bridge.training.checkpointing.logger#

‘getLogger(…)’

bridge.training.checkpointing._NON_PERSISTENT_CKPT_SUBDIR#

‘non_persistent’

bridge.training.checkpointing.set_checkpoint_version(value: float) None#

Set the global checkpoint version number.

Parameters:

value – The checkpoint version number (e.g., 3.0).

bridge.training.checkpointing.get_checkpoint_version() Optional[float]#

Get the global checkpoint version number.

Returns:

The checkpoint version number, or None if not set.
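A minimal usage sketch of the two version helpers; the value 3.0 is illustrative only:

    from megatron.bridge.training import checkpointing

    # Record the version of the checkpoint format being handled (illustrative value).
    checkpointing.set_checkpoint_version(3.0)

    # Query it later; None means no version has been set yet.
    version = checkpointing.get_checkpoint_version()
    if version is not None and version < 2.0:
        pass  # handle older checkpoint layouts here if needed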

bridge.training.checkpointing.find_checkpoint_rank_0(
checkpoints_path: str,
iteration: int,
release: bool = False,
) Optional[str]#

Find the checkpoint directory for a given iteration, assuming distributed checkpoints.

Parameters:
  • checkpoints_path – Base directory where checkpoints are stored.

  • iteration – The training iteration number.

  • release – If True, searches within the ‘release’ directory.

Returns:

The full path to the checkpoint directory if it’s a valid distributed checkpoint, else None.

bridge.training.checkpointing.read_metadata(tracker_filename: str) tuple[int, bool]#

Read the metadata from the Megatron-LM tracker file.

Parameters:

tracker_filename – Path to the tracker file.

Returns:

A tuple containing the iteration number and a boolean indicating if it’s a release checkpoint.
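For example, the two helpers above can be combined to resolve the latest checkpoint directory. The checkpoint root and tracker filename below are assumptions for illustration, not part of this API:

    import os

    from megatron.bridge.training.checkpointing import find_checkpoint_rank_0, read_metadata

    checkpoints_path = "/path/to/checkpoints"  # hypothetical checkpoint root
    tracker_filename = os.path.join(checkpoints_path, "latest_checkpointed_iteration.txt")  # assumed name

    iteration, release = read_metadata(tracker_filename)
    ckpt_dir = find_checkpoint_rank_0(checkpoints_path, iteration, release=release)
    if ckpt_dir is None:
        print(f"No valid distributed checkpoint found for iteration {iteration}")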

bridge.training.checkpointing._extract_megatron_lm_args_from_state_dict(
state_dict: dict[str, Any],
) dict[str, Any]#

Extract and convert legacy Megatron-LM args from checkpoint state_dict to Megatron-Bridge config format.

Parameters:

state_dict – The loaded checkpoint state dictionary.

Returns:

A dictionary in Megatron-Bridge config format with the essential fields.

Raises:

RuntimeError – If args are not found in the state_dict.

bridge.training.checkpointing.schedule_async_save(
global_state: megatron.bridge.training.state.GlobalState,
async_request: megatron.core.dist_checkpointing.strategies.async_utils.AsyncRequest,
) None#

Schedule the async save request.

Parameters:
  • global_state – The global training state containing the async calls queue.

  • async_request – The async save request to enqueue.

bridge.training.checkpointing.maybe_finalize_async_save(
global_state: megatron.bridge.training.state.GlobalState,
ckpt_cfg: megatron.bridge.training.config.CheckpointConfig,
blocking: bool = False,
terminate: bool = False,
) None#

Finalizes active async save calls.

Parameters:
  • global_state – The global training state containing the async calls queue.

  • ckpt_cfg (CheckpointConfig) – The checkpoint configuration.

  • blocking (bool, optional) – If True, waits until all active requests are done. Otherwise, finalizes only the async requests that have already finished. Defaults to False.

  • terminate (bool, optional) – If True, the asynchronous queue is closed as the last action of this function. Defaults to False.

bridge.training.checkpointing.is_empty_async_queue(
global_state: megatron.bridge.training.state.GlobalState,
) bool#

Check if async calls queue is empty. This result is consistent across ranks.

Parameters:

global_state – The global training state containing the async calls queue.

Returns:

True if the async calls queue is empty (no ongoing async save calls), False otherwise.

Return type:

bool
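Taken together, the three functions above form the async save lifecycle. A hedged sketch, assuming global_state, ckpt_cfg, and async_request are already-constructed GlobalState, CheckpointConfig, and AsyncRequest objects:

    from megatron.bridge.training.checkpointing import (
        is_empty_async_queue,
        maybe_finalize_async_save,
        schedule_async_save,
    )

    # Enqueue a previously built async save request.
    schedule_async_save(global_state, async_request)

    # Inside the training loop: finalize only the requests that have already completed.
    maybe_finalize_async_save(global_state, ckpt_cfg, blocking=False)

    # At shutdown: wait for all outstanding saves and close the queue.
    maybe_finalize_async_save(global_state, ckpt_cfg, blocking=True, terminate=True)
    assert is_empty_async_queue(global_state)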

bridge.training.checkpointing.get_rng_state(
data_parallel_random_init: bool,
) megatron.core.dist_checkpointing.mapping.ShardedObject#

Get the random number generator states for all necessary libraries.

Collects states from random, numpy, torch, cuda, and the Megatron RNG tracker. Optionally gathers states across data parallel ranks. Always wraps the result in a ShardedObject for distributed checkpointing.

Parameters:

data_parallel_random_init – If True, gathers RNG states across data parallel ranks.

Returns:

A ShardedObject containing the RNG states.
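For instance, when assembling a checkpoint state dict manually (see generate_state_dict below), the RNG state can be captured as follows; torch.distributed and Megatron's parallel state are assumed to be initialized already:

    from megatron.bridge.training.checkpointing import get_rng_state

    # Pass True only if data-parallel ranks were seeded differently.
    rng_state = get_rng_state(data_parallel_random_init=False)
    # rng_state is a ShardedObject and can be placed directly into a sharded state dict.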

class bridge.training.checkpointing.CheckpointType(*args, **kwds)#

Bases: enum.Enum

Types of checkpoints to save.

Initialization

LOCAL#

‘auto(…)’

GLOBAL#

‘auto(…)’

bridge.training.checkpointing.save_checkpoint(
state: megatron.bridge.training.state.GlobalState,
model: list[megatron.core.transformer.MegatronModule],
optimizer: Optional[megatron.core.optimizer.MegatronOptimizer],
opt_param_scheduler: Optional[Any],
num_floating_point_operations_so_far: int,
checkpointing_context: Optional[dict[str, Any]] = None,
pipeline_rank: Optional[int] = None,
tensor_rank: Optional[int] = None,
non_persistent_ckpt: bool = False,
train_data_iterator: Optional[Any] = None,
preprocess_common_state_dict_fn: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
) None#

Save a model checkpoint.

Handles saving the model state, optimizer state, scheduler state, RNG state, and other metadata based on the configuration and checkpoint type (global or local). Supports synchronous and asynchronous saving.

Parameters:
  • state – The GlobalState object.

  • model – The model module(s) to save.

  • optimizer – The optimizer instance.

  • opt_param_scheduler – The optimizer parameter scheduler instance.

  • num_floating_point_operations_so_far – Total FLOPs computed so far.

  • checkpointing_context – Dictionary to store context across saves (e.g., strategies).

  • pipeline_rank – Pipeline parallel rank (defaults to current rank).

  • tensor_rank – Tensor parallel rank (defaults to current rank).

  • non_persistent_ckpt – If True, saves as a non-persistent checkpoint.

  • train_data_iterator – The training data iterator (for saving state if supported).

  • preprocess_common_state_dict_fn – Optional function to preprocess the common state dict before consistency checks in distributed checkpointing.
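A hedged sketch of a periodic save inside a training loop; state, model, optimizer, opt_param_scheduler, train_iterator, checkpoint_cfg, iteration, and num_flops are assumed to exist in the surrounding training script, and the save interval is illustrative:

    from megatron.bridge.training.checkpointing import init_checkpointing_context, save_checkpoint

    checkpointing_context = init_checkpointing_context(checkpoint_cfg)  # checkpoint_cfg assumed to exist

    if iteration % 1000 == 0:  # illustrative save interval
        save_checkpoint(
            state,
            model,
            optimizer,
            opt_param_scheduler,
            num_floating_point_operations_so_far=num_flops,
            checkpointing_context=checkpointing_context,
            train_data_iterator=train_iterator,
        )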

bridge.training.checkpointing.cleanup_old_non_persistent_checkpoint(
save_dir: str,
leave_ckpt_num: int = 1,
do_async: bool = False,
) None#

Clean up old non-persistent checkpoints in a directory.

Keeps the specified number of latest checkpoints and removes older ones. Currently only cleans up directories matching “iter_*”.

Parameters:
  • save_dir – The directory containing non-persistent checkpoints.

  • leave_ckpt_num – The number of latest checkpoints to keep.

  • do_async – If True, performs cleanup in a background thread.
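For example, to retain only the two most recent non-persistent checkpoints and delete the rest in a background thread (the path is hypothetical):

    from megatron.bridge.training.checkpointing import cleanup_old_non_persistent_checkpoint

    cleanup_old_non_persistent_checkpoint(
        "/checkpoints/non_persistent",  # hypothetical directory of iter_* checkpoints
        leave_ckpt_num=2,
        do_async=True,
    )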

bridge.training.checkpointing.maybe_save_dataloader_state(
train_iterator: Any,
iteration: int,
dataloader_save_path: Optional[str],
) None#

Save the dataloader state if the iterator supports it.

Checks if the train_iterator has a save_state method and calls it.

Parameters:
  • train_iterator – The training data iterator.

  • iteration – The current training iteration.

  • dataloader_save_path – The path where the dataloader state should be saved.
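A one-line sketch; train_iterator and iteration are assumed to come from the training loop, and the save path is hypothetical:

    from megatron.bridge.training.checkpointing import maybe_save_dataloader_state

    # No-op if the iterator has no save_state() method or if the save path is None.
    maybe_save_dataloader_state(train_iterator, iteration, "/checkpoints/dataloader")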

bridge.training.checkpointing._generate_model_state_dict(
model: list[megatron.core.transformer.MegatronModule],
model_sd_kwargs: Optional[dict[str, Any]] = None,
) dict[str, megatron.core.dist_checkpointing.mapping.ShardedStateDict]#

Generate the model subset of the state dictionary to be saved in a checkpoint.

Can be added to the full checkpoint state dictionary with dict.update().

Parameters:
  • model – The model module(s).

  • model_sd_kwargs – Metadata for model state dict generation.

Returns:

A dictionary containing the model state to be saved.

bridge.training.checkpointing.generate_state_dict(
ckpt_cfg: megatron.bridge.training.config.CheckpointConfig,
model: list[megatron.core.transformer.MegatronModule],
optimizer: Optional[megatron.core.optimizer.MegatronOptimizer],
opt_param_scheduler: Optional[Any],
rng_state: Optional[megatron.core.dist_checkpointing.mapping.ShardedObject],
iteration: Optional[int] = None,
optim_sd_kwargs: Optional[dict[str, Any]] = None,
model_sd_kwargs: Optional[dict[str, Any]] = None,
rerun_state: Optional[dict[str, Any]] = None,
) dict[str, Any]#

Generate the state dictionary to be saved in a checkpoint.

Parameters:
  • ckpt_cfg – The checkpoint configuration.

  • model – The model module(s).

  • optimizer – The optimizer instance.

  • opt_param_scheduler – The optimizer parameter scheduler instance.

  • rng_state – Collected RNG states as a ShardedObject.

  • iteration – The current training iteration.

  • optim_sd_kwargs – Additional keyword arguments for optimizer state dict generation.

  • model_sd_kwargs – Metadata for model state dict generation.

  • rerun_state – State dictionary from the rerun state machine.

Returns:

A dictionary containing the complete state to be saved.
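A hedged sketch of assembling a full checkpoint state dict, assuming ckpt_cfg, model, optimizer, opt_param_scheduler, and iteration already exist:

    from megatron.bridge.training.checkpointing import generate_state_dict, get_rng_state

    rng_state = get_rng_state(data_parallel_random_init=False)
    state_dict = generate_state_dict(
        ckpt_cfg,
        model,
        optimizer,
        opt_param_scheduler,
        rng_state,
        iteration=iteration,
    )
    # state_dict now holds model/optimizer/scheduler/RNG entries plus checkpoint metadata.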

bridge.training.checkpointing._load_model_weights_from_checkpoint(
checkpoint_path: str,
model: list[megatron.core.transformer.MegatronModule],
fully_parallel_load: bool = False,
return_state_dict: bool = False,
dist_ckpt_strictness: Literal['assume_ok_unexpected', 'log_unexpected', 'log_all', 'raise_unexpected', 'raise_all', 'return_unexpected', 'return_all', 'ignore_all'] = 'assume_ok_unexpected',
strict: bool = True,
) Optional[Union[megatron.core.dist_checkpointing.serialization.StateDict, tuple[megatron.core.dist_checkpointing.serialization.StateDict, set[str], set[str]]]]#

Load model weights from a checkpoint.

MCore distributed checkpoints from both Megatron-Bridge and Megatron-LM are supported. This function duplicates some logic from load_checkpoint() to simplify model loading for inference.

Parameters:
  • checkpoint_path – path to a distributed checkpoint.

  • model – The model module(s) to load weights into.

  • fully_parallel_load – Apply full load parallelization across DP.

  • return_state_dict – If True, skips loading the state dict into the model and instead returns the model state dict. Defaults to False.

  • dist_ckpt_strictness – Determines how key mismatches are handled during checkpoint load.

  • strict – Whether to enforce strict loading (see torch.nn.Module.load_state_dict).
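For inference-style loading, the helper can be called directly on a distributed checkpoint path; model is assumed to be a list of already-constructed (unpopulated) MegatronModule instances, and the path is hypothetical:

    from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint

    _load_model_weights_from_checkpoint(
        "/checkpoints/iter_0005000",  # hypothetical distributed checkpoint directory
        model,
        fully_parallel_load=True,
        strict=True,
    )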

bridge.training.checkpointing.load_checkpoint(
state: megatron.bridge.training.state.GlobalState,
model: list[megatron.core.transformer.MegatronModule],
optimizer: Optional[megatron.core.optimizer.MegatronOptimizer],
opt_param_scheduler: Optional[Any],
strict: bool = True,
checkpointing_context: Optional[dict[str, Any]] = None,
skip_load_to_model_and_opt: bool = False,
) tuple[int, int]#

Load a model checkpoint.

Handles loading model state, optimizer state, scheduler state, RNG state, and other metadata based on the configuration and checkpoint type. Supports loading global distributed and local non-persistent checkpoints.

Parameters:
  • state – The GlobalState object.

  • model – The model module(s) to load state into.

  • optimizer – The optimizer instance to load state into.

  • opt_param_scheduler – The scheduler instance to load state into.

  • strict – Whether to enforce strict loading (see torch.nn.Module.load_state_dict).

  • checkpointing_context – Dictionary to store context across loads (e.g., strategies).

  • skip_load_to_model_and_opt – If True, only loads metadata (iteration, rng) but skips loading state into model and optimizer modules.

Returns:

A tuple containing:

  • iteration: The training iteration number.

  • num_floating_point_operations_so_far: The total FLOPs computed so far.

Return type:

tuple[int, int]
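Resuming training is typically a single call; state, model, optimizer, opt_param_scheduler, and checkpointing_context are assumed to be constructed as in the save example above:

    from megatron.bridge.training.checkpointing import load_checkpoint

    iteration, num_flops_so_far = load_checkpoint(
        state,
        model,
        optimizer,
        opt_param_scheduler,
        checkpointing_context=checkpointing_context,
    )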

bridge.training.checkpointing._load_model_state_dict(
module: torch.nn.Module,
state_dict: dict[str, Any],
strict: bool,
)#

Helper function to load state dict with fallback for missing extra states.

bridge.training.checkpointing._load_checkpoint_from_path(
load_dir: str,
state: megatron.bridge.training.state.GlobalState,
model: list[megatron.core.transformer.MegatronModule],
optimizer: Optional[megatron.core.optimizer.MegatronOptimizer],
opt_param_scheduler: Optional[Any],
strict: bool = True,
checkpointing_context: Optional[dict[str, Any]] = None,
skip_load_to_model_and_opt: bool = False,
) tuple[int, int]#

Load a checkpoint from a given path.

Parameters:
  • load_dir – The directory containing the checkpoint.

  • state – The GlobalState object.

  • model – The model module(s) to load state into.

  • optimizer – The optimizer instance to load state into.

  • opt_param_scheduler – The scheduler instance to load state into.

  • strict – Whether to enforce strict loading (see torch.nn.Module.load_state_dict).

  • checkpointing_context – Dictionary to store context across loads (e.g., strategies).

  • skip_load_to_model_and_opt – If True, only loads metadata (iteration, rng) but skips loading state into model and optimizer modules.

Returns:

A tuple containing:

  • iteration: The training iteration number.

  • num_floating_point_operations_so_far: The total FLOPs computed so far.

Return type:

tuple[int, int]

bridge.training.checkpointing.init_async_checkpoint_worker(
global_state: megatron.bridge.training.state.GlobalState,
) None#

Initialize the async checkpoint worker if enabled.

Creates a persistent background worker for handling asynchronous checkpoint saves when both async_save and use_persistent_ckpt_worker are enabled in the configuration.

Parameters:

global_state – The GlobalState instance containing the configuration and async queue.

bridge.training.checkpointing.init_checkpointing_context(
checkpoint_config: megatron.bridge.training.config.CheckpointConfig,
) dict[str, Any]#

Initialize the checkpointing context, primarily for local checkpointing support.

If non_persistent_ckpt_type is set to “local”, this function sets up the LocalCheckpointManager and replication strategy based on the provided checkpoint_config.

Parameters:

checkpoint_config – The checkpoint configuration object.

Returns:

A dictionary containing the checkpointing context. This will include a local_checkpoint_manager if local checkpointing is enabled, otherwise it will be an empty dictionary.

Raises:

RuntimeError – If local checkpointing is configured but the nvidia_resiliency_ext module is not found.
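A minimal sketch; checkpoint_config is assumed to be a constructed CheckpointConfig:

    from megatron.bridge.training.checkpointing import init_checkpointing_context

    # Empty dict unless local (non-persistent) checkpointing is enabled,
    # in which case it carries a local_checkpoint_manager.
    checkpointing_context = init_checkpointing_context(checkpoint_config)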

bridge.training.checkpointing.apply_peft_adapter_filter_to_state_dict(
state_dict: dict[str, Any],
peft_config: megatron.bridge.peft.base.PEFT,
) dict[str, Any]#

Filter state dict to contain only PEFT adapter parameters in model sections.

This function takes a complete state dict (generated by generate_state_dict) and filters it to retain only PEFT adapter parameters for checkpoint saving. Follows the same key logic pattern as generate_state_dict for consistency.

Parameters:
  • state_dict – Complete state dict from generate_state_dict()

  • peft_config – PEFT configuration for filtering logic

Returns:

Filtered state dict containing only adapter parameters in model weights, while preserving all non-model metadata (checkpoint_version, iteration, etc.)
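A hedged sketch of adapter-only checkpoint filtering, assuming state_dict came from generate_state_dict() and peft_config is a configured PEFT instance:

    from megatron.bridge.training.checkpointing import apply_peft_adapter_filter_to_state_dict

    adapter_only_sd = apply_peft_adapter_filter_to_state_dict(state_dict, peft_config)
    # Model sections now contain only adapter weights; non-model metadata such as
    # "iteration" and "checkpoint_version" is preserved unchanged.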

bridge.training.checkpointing._is_model_section(section_key: str) bool#

Check if a checkpoint section contains model parameters.

Model sections are named:

  • “model” (single model)

  • “model0”, “model1”, etc. (pipeline parallel models)

Non-model sections include: “optimizer”, “iteration”, “checkpoint_version”, etc.

bridge.training.checkpointing._transpose_first_dim(
t: torch.Tensor,
num_splits: int,
num_splits_first: bool,
model: torch.nn.Module,
) torch.Tensor#

Helper function to transpose first dimension of tensor t.

bridge.training.checkpointing._get_non_persistent_iteration(
non_persistent_global_dir: str,
non_persistent_ckpt_type: Optional[Literal['global', 'local']] = None,
checkpointing_context: Optional[dict[str, Any]] = None,
) int#

Get iteration number from non-persistent checkpoint.

bridge.training.checkpointing._load_non_persistent_base_checkpoint(
non_persistent_global_dir: str,
ckpt_cfg: megatron.bridge.training.config.CheckpointConfig,
rank0: bool,
sharded_state_dict: Optional[dict[str, Any]],
non_persistent_iteration: int,
checkpointing_context: Optional[dict[str, Any]] = None,
) tuple[dict[str, Any], str, bool, bridge.training.checkpointing.CheckpointType]#

Load the base state_dict from a non-persistent distributed checkpoint.

bridge.training.checkpointing._load_global_dist_base_checkpoint(
load_dir: str,
ckpt_cfg: megatron.bridge.training.config.CheckpointConfig,
rank0: bool,
sharded_state_dict: Optional[dict[str, Any]],
iteration: int,
release: bool,
checkpointing_context: Optional[dict[str, Any]] = None,
) tuple[dict[str, Any], str, bool, bridge.training.checkpointing.CheckpointType]#

Load the base state_dict from the given directory containing the global distributed checkpoint.

bridge.training.checkpointing._load_base_checkpoint(
load_dir: Optional[str],
ckpt_cfg: megatron.bridge.training.config.CheckpointConfig,
rank0: bool = False,
sharded_state_dict: Optional[dict[str, Any]] = None,
checkpointing_context: Optional[dict[str, Any]] = None,
) tuple[Optional[dict[str, Any]], str, bool, Optional[bridge.training.checkpointing.CheckpointType]]#

Load the base state_dict from the given directory.

bridge.training.checkpointing._build_sharded_state_dict_metadata(
use_distributed_optimizer: bool,
ckpt_fully_parallel_save: bool,
) dict#

Builds metadata used for sharded_state_dict versioning.

The whole content metadata is passed to the sharded_state_dict methods of the model and optimizer and therefore affects only the logic behind sharded_state_dict creation. The content metadata should be minimal, ideally flat (or with a single nesting level) and with semantically meaningful flag names (e.g. distrib_optim_sharding_type). In particular, a simple integer (or SemVer) versioning flag (e.g. metadata['version'] = 3.4) is discouraged, because the metadata serves all models and optimizers and it is practically impossible to enforce a linearly increasing version across this whole space.

Parameters:
  • use_distributed_optimizer – Whether to use distributed optimizer.

  • ckpt_fully_parallel_save – Whether to use fully parallel save.
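For illustration only (the exact keys are an implementation detail and not guaranteed by this sketch):

    from megatron.bridge.training.checkpointing import _build_sharded_state_dict_metadata

    metadata = _build_sharded_state_dict_metadata(
        use_distributed_optimizer=True,
        ckpt_fully_parallel_save=True,
    )
    # The result is a small, flat dict of semantically named flags
    # (e.g. something like distrib_optim_sharding_type), not a single version number.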

bridge.training.checkpointing._get_train_state_from_state_dict(
state_dict: dict[str, Any],
) megatron.bridge.training.state.TrainState#

Create a TrainState from the state dict of a Megatron-LM checkpoint.