bridge.training.checkpointing#
Input/output checkpointing.
Module Contents#
Classes#

| Class | Description |
|---|---|
| CheckpointType | Types of checkpoints to save. |

Functions#

| Function | Description |
|---|---|
| set_checkpoint_version | Set the global checkpoint version number. |
| get_checkpoint_version | Get the global checkpoint version number. |
| find_checkpoint_rank_0 | Find the checkpoint directory for a given iteration, assuming distributed checkpoints. |
| read_metadata | Read the metadata from the Megatron-LM tracker file. |
| _extract_megatron_lm_args_from_state_dict | Extract and convert legacy Megatron-LM args from checkpoint state_dict to Megatron-Bridge config format. |
| schedule_async_save | Schedule the async save request. |
| maybe_finalize_async_save | Finalizes active async save calls. |
| is_empty_async_queue | Check if async calls queue is empty. This result is consistent across ranks. |
| get_rng_state | Get the random number generator states for all necessary libraries. |
| save_checkpoint | Save a model checkpoint. |
| cleanup_old_non_persistent_checkpoint | Clean up old non-persistent checkpoints in a directory. |
| maybe_save_dataloader_state | Save the dataloader state if the iterator supports it. |
| _generate_model_state_dict | Generate the model subset of the state dictionary to be saved in a checkpoint. |
| generate_state_dict | Generate the state dictionary to be saved in a checkpoint. |
| _load_model_weights_from_checkpoint | Load model weights from a checkpoint. |
| load_checkpoint | Load a model checkpoint. |
| _load_model_state_dict | Helper function to load state dict with fallback for missing extra states. |
| _load_checkpoint_from_path | Load a checkpoint from a given path. |
| init_async_checkpoint_worker | Initialize the async checkpoint worker if enabled. |
| init_checkpointing_context | Initialize the checkpointing context, primarily for local checkpointing support. |
| apply_peft_adapter_filter_to_state_dict | Filter state dict to contain only PEFT adapter parameters in model sections. |
| _is_model_section | Check if a checkpoint section contains model parameters. |
| _transpose_first_dim | Helper function to transpose first dimension of tensor t. |
| _get_non_persistent_iteration | Get iteration number from non-persistent checkpoint. |
| _load_non_persistent_base_checkpoint | Load the base state_dict from a non-persistent distributed checkpoint. |
| _load_global_dist_base_checkpoint | Load the base state_dict from the given directory containing the global distributed checkpoint. |
| _load_base_checkpoint | Load the base state_dict from the given directory. |
| _build_sharded_state_dict_metadata | Builds metadata used for sharded_state_dict versioning. |
| _get_train_state_from_state_dict | Create a TrainState from the state dict from a Megatron-LM checkpoint. |
Data#
API#
- bridge.training.checkpointing.TRACKER_PREFIX#
‘latest’
- bridge.training.checkpointing._CHECKPOINT_VERSION#
None
- bridge.training.checkpointing.logger#
‘getLogger(…)’
- bridge.training.checkpointing._NON_PERSISTENT_CKPT_SUBDIR#
‘non_persistent’
- bridge.training.checkpointing.set_checkpoint_version(value: float) None #
Set the global checkpoint version number.
- Parameters:
value – The checkpoint version number (e.g., 3.0).
- bridge.training.checkpointing.get_checkpoint_version() Optional[float] #
Get the global checkpoint version number.
- Returns:
The checkpoint version number, or None if not set.
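A minimal usage sketch (hedged; the import path megatron.bridge.training.checkpointing is assumed from the module name above, and the version value 3.0 is purely illustrative):

```python
# Record and query the global checkpoint format version.
from megatron.bridge.training.checkpointing import (
    get_checkpoint_version,
    set_checkpoint_version,
)

set_checkpoint_version(3.0)            # illustrative version number
assert get_checkpoint_version() == 3.0
```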
- bridge.training.checkpointing.find_checkpoint_rank_0(
- checkpoints_path: str,
- iteration: int,
- release: bool = False,
Find the checkpoint directory for a given iteration, assuming distributed checkpoints.
- Parameters:
checkpoints_path – Base directory where checkpoints are stored.
iteration – The training iteration number.
release – If True, searches within the ‘release’ directory.
- Returns:
The full path to the checkpoint directory if it’s a valid distributed checkpoint, else None.
- bridge.training.checkpointing.read_metadata(tracker_filename: str) tuple[int, bool] #
Read the metadata from the Megatron-LM tracker file.
- Parameters:
tracker_filename – Path to the tracker file.
- Returns:
A tuple containing the iteration number and a boolean indicating if it’s a release checkpoint.
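A hedged sketch of resolving the latest checkpoint with these two helpers; the checkpoint root and the tracker file name are illustrative assumptions, not values mandated by the library:

```python
from megatron.bridge.training.checkpointing import find_checkpoint_rank_0, read_metadata

checkpoints_path = "/checkpoints/my_run"  # hypothetical checkpoint root
# Assumed Megatron-LM tracker file name; adjust to your checkpoint layout.
tracker_file = f"{checkpoints_path}/latest_checkpointed_iteration.txt"

iteration, is_release = read_metadata(tracker_file)
ckpt_dir = find_checkpoint_rank_0(checkpoints_path, iteration, release=is_release)
print(iteration, is_release, ckpt_dir)
```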
- bridge.training.checkpointing._extract_megatron_lm_args_from_state_dict(
- state_dict: dict[str, Any],
Extract and convert legacy Megatron-LM args from checkpoint state_dict to Megatron-Bridge config format.
- Parameters:
state_dict – The loaded checkpoint state dictionary.
- Returns:
A dictionary in Megatron-Bridge config format with the essential fields.
- Raises:
RuntimeError – If args are not found in the state_dict.
- bridge.training.checkpointing.schedule_async_save(
- global_state: megatron.bridge.training.state.GlobalState,
- async_request: megatron.core.dist_checkpointing.strategies.async_utils.AsyncRequest,
Schedule the async save request.
- Parameters:
global_state – The global training state containing the async calls queue.
async_request – the async save request.
- bridge.training.checkpointing.maybe_finalize_async_save(
- global_state: megatron.bridge.training.state.GlobalState,
- ckpt_cfg: megatron.bridge.training.config.CheckpointConfig,
- blocking: bool = False,
- terminate: bool = False,
Finalizes active async save calls.
- Parameters:
global_state – The global training state containing the async calls queue.
ckpt_cfg (CheckpointConfig) – The checkpoint configuration.
blocking (bool, optional) – if True, will wait until all active requests are done. Otherwise, finalizes only the async request that already finished. Defaults to False.
terminate (bool, optional) – if True, the asynchronous queue will be closed as the last action of this function.
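A hedged sketch of the call pattern for driving async saves from a training loop; the GlobalState and CheckpointConfig instances are assumed to come from an already-initialized run, and the wrapper functions are hypothetical:

```python
from megatron.bridge.training.checkpointing import maybe_finalize_async_save


def finalize_completed_saves(global_state, ckpt_cfg) -> None:
    # Called once per training step: only finalizes requests that already finished.
    maybe_finalize_async_save(global_state, ckpt_cfg, blocking=False)


def drain_saves_at_shutdown(global_state, ckpt_cfg) -> None:
    # Called once at exit: waits for all outstanding saves and closes the queue.
    maybe_finalize_async_save(global_state, ckpt_cfg, blocking=True, terminate=True)
```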
- bridge.training.checkpointing.is_empty_async_queue(
- global_state: megatron.bridge.training.state.GlobalState,
Check if async calls queue is empty. This result is consistent across ranks.
- Parameters:
global_state – The global training state containing the async calls queue.
- Returns:
True if the async calls queue is empty (no ongoing async calls), False otherwise.
- Return type:
bool
- bridge.training.checkpointing.get_rng_state(
- data_parallel_random_init: bool,
Get the random number generator states for all necessary libraries.
Collects states from random, numpy, torch, cuda, and the Megatron RNG tracker. Optionally gathers states across data parallel ranks. Always wraps the result in a ShardedObject for distributed checkpointing.
- Parameters:
data_parallel_random_init – If True, gathers RNG states across data parallel ranks.
- Returns:
A ShardedObject containing the RNG states.
- class bridge.training.checkpointing.CheckpointType(*args, **kwds)#
Bases:
enum.Enum
Types of checkpoints to save.
Initialization
- LOCAL#
‘auto(…)’
- GLOBAL#
‘auto(…)’
- bridge.training.checkpointing.save_checkpoint(
- state: megatron.bridge.training.state.GlobalState,
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: Optional[megatron.core.optimizer.MegatronOptimizer],
- opt_param_scheduler: Optional[Any],
- num_floating_point_operations_so_far: int,
- checkpointing_context: Optional[dict[str, Any]] = None,
- pipeline_rank: Optional[int] = None,
- tensor_rank: Optional[int] = None,
- non_persistent_ckpt: bool = False,
- train_data_iterator: Optional[Any] = None,
- preprocess_common_state_dict_fn: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
Save a model checkpoint.
Handles saving the model state, optimizer state, scheduler state, RNG state, and other metadata based on the configuration and checkpoint type (global or local). Supports synchronous and asynchronous saving.
- Parameters:
state – The GlobalState object.
model – The model module(s) to save.
optimizer – The optimizer instance.
opt_param_scheduler – The optimizer parameter scheduler instance.
num_floating_point_operations_so_far – Total FLOPs computed so far.
checkpointing_context – Dictionary to store context across saves (e.g., strategies).
pipeline_rank – Pipeline parallel rank (defaults to current rank).
tensor_rank – Tensor parallel rank (defaults to current rank).
non_persistent_ckpt – If True, saves as a non-persistent checkpoint.
train_data_iterator – The training data iterator (for saving state if supported).
preprocess_common_state_dict_fn – Optional function to preprocess the common state dict before consistency checks in distributed checkpointing.
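A hedged sketch of a periodic persistent save; all arguments are assumed to come from an initialized Megatron-Bridge training run, and the wrapper function is hypothetical:

```python
from megatron.bridge.training.checkpointing import save_checkpoint


def save_global_checkpoint(state, model, optimizer, scheduler,
                           flops_so_far, ckpt_context, train_iterator) -> None:
    # Persistent (global) checkpoint; pass non_persistent_ckpt=True for the
    # lightweight non-persistent variant instead.
    save_checkpoint(
        state,
        model,
        optimizer,
        scheduler,
        num_floating_point_operations_so_far=flops_so_far,
        checkpointing_context=ckpt_context,
        train_data_iterator=train_iterator,
    )
```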
- bridge.training.checkpointing.cleanup_old_non_persistent_checkpoint(
- save_dir: str,
- leave_ckpt_num: int = 1,
- do_async: bool = False,
Clean up old non-persistent checkpoints in a directory.
Keeps the specified number of latest checkpoints and removes older ones. Currently only cleans up directories matching “iter_*”.
- Parameters:
save_dir – The directory containing non-persistent checkpoints.
leave_ckpt_num – The number of latest checkpoints to keep.
do_async – If True, performs cleanup in a background thread.
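A minimal sketch; the directory path is hypothetical:

```python
from megatron.bridge.training.checkpointing import cleanup_old_non_persistent_checkpoint

# Keep the two most recent "iter_*" checkpoints and remove older ones.
cleanup_old_non_persistent_checkpoint(
    "/checkpoints/my_run/non_persistent",  # hypothetical save_dir
    leave_ckpt_num=2,
    do_async=False,
)
```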
- bridge.training.checkpointing.maybe_save_dataloader_state(
- train_iterator: Any,
- iteration: int,
- dataloader_save_path: Optional[str],
Save the dataloader state if the iterator supports it.
Checks if the train_iterator has a save_state method and calls it.
- Parameters:
train_iterator – The training data iterator.
iteration – The current training iteration.
dataloader_save_path – The path where the dataloader state should be saved.
- bridge.training.checkpointing._generate_model_state_dict(
- model: list[megatron.core.transformer.MegatronModule],
- model_sd_kwargs: Optional[dict[str, Any]] = None,
Generate the model subset of the state dictionary to be saved in a checkpoint.
Can be added to the full checkpoint state dictionary with dict.update().
- Parameters:
model – The model module(s).
model_sd_kwargs – Metadata for model state dict generation.
- Returns:
A dictionary containing the model state to be saved.
- bridge.training.checkpointing.generate_state_dict(
- ckpt_cfg: megatron.bridge.training.config.CheckpointConfig,
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: Optional[megatron.core.optimizer.MegatronOptimizer],
- opt_param_scheduler: Optional[Any],
- rng_state: Optional[megatron.core.dist_checkpointing.mapping.ShardedObject],
- iteration: Optional[int] = None,
- optim_sd_kwargs: Optional[dict[str, Any]] = None,
- model_sd_kwargs: Optional[dict[str, Any]] = None,
- rerun_state: Optional[dict[str, Any]] = None,
Generate the state dictionary to be saved in a checkpoint.
- Parameters:
ckpt_cfg – The checkpoint configuration.
model – The model module(s).
optimizer – The optimizer instance.
opt_param_scheduler – The optimizer parameter scheduler instance.
rng_state – Collected RNG states as a ShardedObject.
iteration – The current training iteration.
optim_sd_kwargs – Additional keyword arguments for optimizer state dict generation.
model_sd_kwargs – Metadata for model state dict generation.
rerun_state – State dictionary from the rerun state machine.
- Returns:
A dictionary containing the complete state to be saved.
- bridge.training.checkpointing._load_model_weights_from_checkpoint(
- checkpoint_path: str,
- model: list[megatron.core.transformer.MegatronModule],
- fully_parallel_load: bool = False,
- return_state_dict: bool = False,
- dist_ckpt_strictness: Literal[assume_ok_unexpected, log_unexpected, log_all, raise_unexpected, raise_all, return_unexpected, return_all, ignore_all] = 'assume_ok_unexpected',
- strict: bool = True,
Load model weights from a checkpoint.
MCore distributed checkpoints from both Megatron-Bridge and Megatron-LM are supported. This function duplicates some logic from load_checkpoint() to simplify model loading for inference.
- Parameters:
checkpoint_path – path to a distributed checkpoint.
model – The model module(s) to load weights into.
fully_parallel_load – Apply full load parallelization across DP.
return_state_dict – Skips loading state dict into model and returns model state dict itself. Default False.
dist_ckpt_strictness – Determine handling of key mismatch during checkpoint load.
strict – Whether to enforce strict loading (see torch.nn.Module.load_state_dict).
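A hedged sketch of weight-only loading for inference; model is assumed to be the list of MegatronModule instances already built for the current parallel configuration, and the wrapper function is hypothetical:

```python
from megatron.bridge.training.checkpointing import _load_model_weights_from_checkpoint


def load_weights_for_inference(checkpoint_path: str, model) -> None:
    # Loads weights in place; no optimizer, scheduler, or RNG state is restored.
    _load_model_weights_from_checkpoint(
        checkpoint_path,
        model,
        fully_parallel_load=False,
        strict=True,
    )
```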
- bridge.training.checkpointing.load_checkpoint(
- state: megatron.bridge.training.state.GlobalState,
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: Optional[megatron.core.optimizer.MegatronOptimizer],
- opt_param_scheduler: Optional[Any],
- strict: bool = True,
- checkpointing_context: Optional[dict[str, Any]] = None,
- skip_load_to_model_and_opt: bool = False,
Load a model checkpoint.
Handles loading model state, optimizer state, scheduler state, RNG state, and other metadata based on the configuration and checkpoint type. Supports loading global distributed and local non-persistent checkpoints.
- Parameters:
state – The GlobalState object.
model – The model module(s) to load state into.
optimizer – The optimizer instance to load state into.
opt_param_scheduler – The scheduler instance to load state into.
strict – Whether to enforce strict loading (see torch.nn.Module.load_state_dict).
checkpointing_context – Dictionary to store context across loads (e.g., strategies).
skip_load_to_model_and_opt – If True, only loads metadata (iteration, rng) but skips loading state into model and optimizer modules.
- Returns:
A tuple (iteration, num_floating_point_operations_so_far), where iteration is the training iteration number and num_floating_point_operations_so_far is the total FLOPs computed so far.
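A hedged sketch of resuming from the configured load directory; all inputs are assumed to come from an initialized training setup, and the wrapper function is hypothetical:

```python
from megatron.bridge.training.checkpointing import load_checkpoint


def resume_from_checkpoint(state, model, optimizer, scheduler, ckpt_context):
    # Restores model/optimizer/scheduler state and returns bookkeeping counters.
    iteration, flops_so_far = load_checkpoint(
        state,
        model,
        optimizer,
        scheduler,
        checkpointing_context=ckpt_context,
    )
    return iteration, flops_so_far
```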
- bridge.training.checkpointing._load_model_state_dict(
- module: torch.nn.Module,
- state_dict: dict[str, Any],
- strict: bool,
Helper function to load state dict with fallback for missing extra states.
- bridge.training.checkpointing._load_checkpoint_from_path(
- load_dir: str,
- state: megatron.bridge.training.state.GlobalState,
- model: list[megatron.core.transformer.MegatronModule],
- optimizer: Optional[megatron.core.optimizer.MegatronOptimizer],
- opt_param_scheduler: Optional[Any],
- strict: bool = True,
- checkpointing_context: Optional[dict[str, Any]] = None,
- skip_load_to_model_and_opt: bool = False,
Load a checkpoint from a given path.
- Parameters:
load_dir – The directory containing the checkpoint.
state – The GlobalState object.
model – The model module(s) to load state into.
optimizer – The optimizer instance to load state into.
opt_param_scheduler – The scheduler instance to load state into.
strict – Whether to enforce strict loading (see torch.nn.Module.load_state_dict).
checkpointing_context – Dictionary to store context across loads (e.g., strategies).
skip_load_to_model_and_opt – If True, only loads metadata (iteration, rng) but skips loading state into model and optimizer modules.
- Returns:
A tuple (iteration, num_floating_point_operations_so_far), where iteration is the training iteration number and num_floating_point_operations_so_far is the total FLOPs computed so far.
- bridge.training.checkpointing.init_async_checkpoint_worker(
- global_state: megatron.bridge.training.state.GlobalState,
Initialize the async checkpoint worker if enabled.
Creates a persistent background worker for handling asynchronous checkpoint saves when both async_save and use_persistent_ckpt_worker are enabled in the configuration.
- Parameters:
global_state – The GlobalState instance containing the configuration and async queue.
- bridge.training.checkpointing.init_checkpointing_context(
- checkpoint_config: megatron.bridge.training.config.CheckpointConfig,
Initialize the checkpointing context, primarily for local checkpointing support.
If non_persistent_ckpt_type is set to “local”, this function sets up the LocalCheckpointManager and replication strategy based on the provided checkpoint_config.
- Parameters:
checkpoint_config – The checkpoint configuration object.
- Returns:
A dictionary containing the checkpointing context. This will include a local_checkpoint_manager if local checkpointing is enabled; otherwise it will be an empty dictionary.
- Raises:
RuntimeError – If local checkpointing is configured but the nvidia_resiliency_ext module is not found.
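A hedged sketch of building the context before the first save or load; constructing CheckpointConfig with defaults is an assumption and may need adjusting to your configuration:

```python
from megatron.bridge.training.checkpointing import init_checkpointing_context
from megatron.bridge.training.config import CheckpointConfig

ckpt_cfg = CheckpointConfig()  # assumed default construction; no local checkpointing
checkpointing_context = init_checkpointing_context(ckpt_cfg)
# Empty dict unless non_persistent_ckpt_type == "local" is configured.
print(checkpointing_context)
```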
- bridge.training.checkpointing.apply_peft_adapter_filter_to_state_dict(
- state_dict: dict[str, Any],
- peft_config: megatron.bridge.peft.base.PEFT,
Filter state dict to contain only PEFT adapter parameters in model sections.
This function takes a complete state dict (generated by generate_state_dict) and filters it to retain only PEFT adapter parameters for checkpoint saving. Follows the same key logic pattern as generate_state_dict for consistency.
- Parameters:
state_dict – Complete state dict from generate_state_dict()
peft_config – PEFT configuration for filtering logic
- Returns:
Filtered state dict containing only adapter parameters in model weights, while preserving all non-model metadata (checkpoint_version, iteration, etc.)
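A hedged sketch of producing an adapter-only checkpoint payload; full_state_dict and peft_config are assumed to come from generate_state_dict() and the active PEFT setup, and the wrapper function is hypothetical:

```python
from megatron.bridge.training.checkpointing import apply_peft_adapter_filter_to_state_dict


def adapter_only_state_dict(full_state_dict, peft_config):
    filtered = apply_peft_adapter_filter_to_state_dict(full_state_dict, peft_config)
    # Model sections now hold only adapter parameters; metadata such as
    # "iteration" and "checkpoint_version" is preserved unchanged.
    return filtered
```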
- bridge.training.checkpointing._is_model_section(section_key: str) bool #
Check if a checkpoint section contains model parameters.
Model sections are named:
“model” (single model)
“model0”, “model1”, etc. (pipeline parallel models)
Non-model sections include: “optimizer”, “iteration”, “checkpoint_version”, etc.
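An illustration of the naming convention described above (expected outputs are inferred from the docstring, not verified against the implementation):

```python
from megatron.bridge.training.checkpointing import _is_model_section

for key in ("model", "model0", "model1", "optimizer", "iteration"):
    print(key, _is_model_section(key))
# Expected: True for "model", "model0", "model1"; False for "optimizer", "iteration".
```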
- bridge.training.checkpointing._transpose_first_dim(
- t: torch.Tensor,
- num_splits: int,
- num_splits_first: bool,
- model: torch.nn.Module,
Helper function to transpose first dimension of tensor t.
- bridge.training.checkpointing._get_non_persistent_iteration(
- non_persistent_global_dir: str,
- non_persistent_ckpt_type: typing.Optional[typing.Literal[global, local]] = None,
- checkpointing_context: typing.Optional[dict[str, typing.Any]] = None,
Get iteration number from non-persistent checkpoint.
- bridge.training.checkpointing._load_non_persistent_base_checkpoint(
- non_persistent_global_dir: str,
- ckpt_cfg: megatron.bridge.training.config.CheckpointConfig,
- rank0: bool,
- sharded_state_dict: Optional[dict[str, Any]],
- non_persistent_iteration: int,
- checkpointing_context: Optional[dict[str, Any]] = None,
Load the base state_dict from a non-persistent distributed checkpoint.
- bridge.training.checkpointing._load_global_dist_base_checkpoint(
- load_dir: str,
- ckpt_cfg: megatron.bridge.training.config.CheckpointConfig,
- rank0: bool,
- sharded_state_dict: Optional[dict[str, Any]],
- iteration: int,
- release: bool,
- checkpointing_context: Optional[dict[str, Any]] = None,
Load the base state_dict from the given directory containing the global distributed checkpoint.
- bridge.training.checkpointing._load_base_checkpoint(
- load_dir: Optional[str],
- ckpt_cfg: megatron.bridge.training.config.CheckpointConfig,
- rank0: bool = False,
- sharded_state_dict: Optional[dict[str, Any]] = None,
- checkpointing_context: Optional[dict[str, Any]] = None,
Load the base state_dict from the given directory.
- bridge.training.checkpointing._build_sharded_state_dict_metadata(
- use_distributed_optimizer: bool,
- ckpt_fully_parallel_save: bool,
Builds metadata used for sharded_state_dict versioning.
The whole content metadata is passed to sharded_state_dict model and optimizer methods and therefore affects only the logic behind sharded_state_dict creation. The content metadata should be minimalistic, ideally flat (or with a single nesting level) and with semantically meaningful flag names (e.g. distrib_optim_sharding_type). In particular, a simple integer (or SemVer) versioning flag (e.g. metadata['version'] = 3.4) is discouraged, because the metadata serves for all models and optimizers and it’s practically impossible to enforce a linearly increasing versioning for this whole space.
- Parameters:
use_distributed_optimizer – Whether to use distributed optimizer.
ckpt_fully_parallel_save – Whether to use fully parallel save.
- bridge.training.checkpointing._get_train_state_from_state_dict(
- state_dict: dict[str, Any],
Create a TrainState from the state dict from a Megatron-LM checkpoint.