bridge.training.checkpointing

Input/output checkpointing.

Module Contents

Classes

CheckpointType

Types of checkpoints to save.

CheckpointSaveContext

Context containing all state needed for a checkpoint save operation.

CheckpointLoadContext

Context containing all state needed for a checkpoint load operation.

CheckpointManager

Protocol defining the checkpoint manager interface.

DefaultCheckpointManager

Default checkpoint manager that delegates to existing functional code.

Functions

set_checkpoint_version

Set the global checkpoint version number.

get_checkpoint_version

Get the global checkpoint version number.

delete_extra_state

Delete all extra state keys from the model state dictionary.

_get_checkpoint_format

Determine the checkpoint format by examining the checkpoint directory.

find_checkpoint_rank_0

Find the checkpoint directory for a given iteration, assuming distributed checkpoints.

read_metadata

Read the metadata from the Megatron-LM tracker file.

_extract_megatron_lm_args_from_state_dict

Extract and convert legacy Megatron-LM args from checkpoint state_dict to Megatron-Bridge config format.

schedule_async_save

Schedule the async save request.

maybe_finalize_async_save

Finalizes active async save calls.

is_empty_async_queue

Check if async calls queue is empty. This result is consistent across ranks.

get_save_and_finalize_callbacks

Creates an async save request for fsdp_dtensor & torch_dcp with a finalize function.

get_rng_state

Get the random number generator states for all necessary libraries.

create_checkpoint_manager

Factory function to create a checkpoint manager.

save_checkpoint

Save a model checkpoint.

cleanup_old_non_persistent_checkpoint

Clean up old non-persistent checkpoints in a directory.

maybe_save_dataloader_state

Save the dataloader state if the iterator supports it.

save_tokenizer_assets

Save tokenizer files to the checkpoint directory.

_generate_model_state_dict

Generate the model subset of the state dictionary to be saved in a checkpoint.

generate_state_dict

Generate the state dictionary to be saved in a checkpoint.

preprocess_fsdp_dtensor_state_dict

Preprocess FSDP DTensor state dict before saving.

save_fsdp_dtensor_checkpoint

Preprocess and save an FSDP DTensor checkpoint with PyTorch DCP.

_create_fsdp_dtensor_storage_writer

_load_model_weights_from_checkpoint

Load model weights from a checkpoint.

load_checkpoint

Load a model checkpoint.

_deinterleave_glu_tensor

De-interleave SwiGLU fc1 tensor along dim 0: block-interleaved -> contiguous [W_all, V_all].

_interleave_glu_tensor

Interleave SwiGLU fc1 tensor along dim 0: contiguous [W_all, V_all] -> block-interleaved.

_is_swiglu_fc1_checkpoint_key

Return True if key is a selected SwiGLU linear_fc1 weight or bias.

_apply_glu_interleave_to_tensor_data

Run interleave or de-interleave on fc1 weight or bias (identical dim-0 layout).

_process_state_dict_for_glu_interleaving

Process GLU weights and biases in state dict for interleaving or de-interleaving.

_get_model_glu_interleave_sizes

Return routed/dense and shared-expert GLU interleave sizes for this model.

_process_state_dict_for_model_glu_interleaving

Apply GLU interleaving transforms for each model component that needs them.

_load_model_state_dict

Helper function to load state dict with fallback for missing extra states.

_load_checkpoint_from_path

Load a checkpoint from a given path.

init_checkpointing_context

Initialize the checkpointing context, primarily for local checkpointing support.

apply_peft_adapter_filter_to_state_dict

Filter state dict to contain only PEFT adapter parameters in model sections.

_is_model_section

Check if a checkpoint section contains model parameters.

_resolve_checkpoint_iteration

Resolve which checkpoint iteration to load.

_get_non_persistent_iteration

Get iteration number from non-persistent checkpoint.

_load_non_persistent_base_checkpoint

Load the base state_dict from a non-persistent distributed checkpoint.

_load_global_dist_base_checkpoint

Load the base state_dict from the given directory containing the global distributed checkpoint.

_load_base_checkpoint

Load the base state_dict from the given directory.

load_fsdp_dtensor_checkpoint

Load the base state dict from an FSDP DTensor checkpoint.

_build_sharded_state_dict_metadata

Builds metadata used for sharded_state_dict versioning.

_get_train_state_from_state_dict

Create a TrainState from the state dict from a Megatron-LM checkpoint.

Data

API

bridge.training.checkpointing.TRACKER_PREFIX

‘latest’

bridge.training.checkpointing._CHECKPOINT_VERSION

None

bridge.training.checkpointing.logger

‘getLogger(…)’

bridge.training.checkpointing._NON_PERSISTENT_CKPT_SUBDIR

‘non_persistent’

bridge.training.checkpointing._DIRECT_ITERATION_DIR_SENTINEL

None

bridge.training.checkpointing.set_checkpoint_version(value: float) → None

Set the global checkpoint version number.

Parameters:

value – The checkpoint version number (e.g., 3.0).

bridge.training.checkpointing.get_checkpoint_version() → Optional[float]

Get the global checkpoint version number.

Returns:

The checkpoint version number, or None if not set.

bridge.training.checkpointing.delete_extra_state(state_dict)

Delete all extra state keys from the model state dictionary.

This function removes all keys containing ‘_extra_state’ from the model portion of the state dictionary. Use it to discard corrupted or incompatible extra state that would otherwise cause model loading to fail.

Parameters:

state_dict – The state dictionary. It can be either a full checkpoint dict with a “model” key or a model state dict directly.

Returns:

The modified state dictionary with extra state keys removed.
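The removal rule can be pictured with the minimal sketch below. It is a hypothetical re-implementation for illustration only (delete_extra_state_sketch is not a library function), and it assumes that a full checkpoint dict is recognized by the presence of a "model" key.

```python
from typing import Any


def delete_extra_state_sketch(state_dict: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical illustration of removing '_extra_state' keys in place."""
    # Assumed detection rule: full checkpoint dicts keep weights under "model";
    # otherwise the input itself is treated as the model state dict.
    model_sd = state_dict["model"] if "model" in state_dict else state_dict
    # Collect keys first so the dict is not mutated while iterating over it.
    extra_keys = [key for key in model_sd if "_extra_state" in key]
    for key in extra_keys:
        del model_sd[key]
    return state_dict
```

Only the model section is touched, so non-model entries of a full checkpoint dict (such as an iteration counter) survive.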

bridge.training.checkpointing._get_checkpoint_format(checkpoint_path: str) → str

Determine the checkpoint format by examining the checkpoint directory.

Parameters:

checkpoint_path – Path to the checkpoint directory.

Returns:

The checkpoint format string.

bridge.training.checkpointing.find_checkpoint_rank_0(
checkpoints_path: str,
iteration: int,
release: bool = False,
) → Optional[str]

Find the checkpoint directory for a given iteration, assuming distributed checkpoints.

Parameters:
  • checkpoints_path – Base directory where checkpoints are stored.

  • iteration – The training iteration number.

  • release – If True, searches within the ‘release’ directory.

Returns:

The full path to the checkpoint directory if it’s a valid distributed checkpoint, else None.
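To make the lookup concrete, the sketch below builds the directory path that such a search would examine. It assumes the Megatron-LM naming convention (a release subdirectory for release checkpoints, otherwise iter_ followed by the iteration zero-padded to seven digits); checkpoint_dir_for_iteration is a hypothetical helper, and the check for a valid distributed checkpoint is left out.

```python
import os


def checkpoint_dir_for_iteration(checkpoints_path: str, iteration: int, release: bool = False) -> str:
    """Hypothetical helper: build the per-iteration checkpoint directory path."""
    # Assumed convention: "release" for release checkpoints, otherwise
    # "iter_" plus the iteration zero-padded to seven digits.
    directory = "release" if release else f"iter_{iteration:07d}"
    return os.path.join(checkpoints_path, directory)
```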

bridge.training.checkpointing.read_metadata(tracker_filename: str) → tuple[int, bool]

Read the metadata from the Megatron-LM tracker file.

Parameters:

tracker_filename – Path to the tracker file.

Returns:

A tuple containing the iteration number and a boolean indicating if it’s a release checkpoint.
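The tracker file is a small text file. A minimal parsing sketch follows, assuming the Megatron-LM convention that it holds either an integer iteration or the literal string release (reported here as iteration 0); parse_tracker_text is a hypothetical name and operates on the file contents rather than the path.

```python
def parse_tracker_text(text: str) -> tuple[int, bool]:
    """Hypothetical parser for the contents of a Megatron-LM tracker file."""
    metastring = text.strip()
    if metastring == "release":
        # Assumed: release checkpoints carry no iteration number, so report 0.
        return 0, True
    try:
        return int(metastring), False
    except ValueError:
        raise ValueError(f"Invalid tracker file contents: {metastring!r}") from None
```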

bridge.training.checkpointing._extract_megatron_lm_args_from_state_dict(
state_dict: dict[str, Any],
) → dict[str, Any]

Extract and convert legacy Megatron-LM args from checkpoint state_dict to Megatron-Bridge config format.

Parameters:

state_dict – The loaded checkpoint state dictionary.

Returns:

A dictionary in Megatron-Bridge config format with the essential fields.

Raises:

RuntimeError – If args are not found in the state_dict.

bridge.training.checkpointing.schedule_async_save(
global_state: megatron.bridge.training.state.GlobalState,
async_request: megatron.core.dist_checkpointing.strategies.async_utils.AsyncRequest,
) → None

Schedule the async save request.

Parameters:
  • global_state – The global training state containing the async calls queue.

  • async_request – The async save request.

bridge.training.checkpointing.maybe_finalize_async_save(
global_state: megatron.bridge.training.state.GlobalState,
ckpt_cfg: megatron.bridge.training.config.CheckpointConfig,
blocking: bool = False,
terminate: bool = False,
) → None

Finalizes active async save calls.

Parameters:
  • global_state – The global training state containing the async calls queue.

  • ckpt_cfg (CheckpointConfig) – The checkpoint configuration.

  • blocking (bool, optional) – If True, wait until all active requests are done; otherwise, finalize only the requests that have already finished. Defaults to False.

  • terminate (bool, optional) – If True, close the asynchronous queue as the last action of this function. Defaults to False.
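The difference between blocking and non-blocking finalization can be modeled with a toy queue built on concurrent.futures. This is a single-process illustration, not the Megatron async-calls queue: ToyAsyncQueue is hypothetical, it assumes requests are finalized in submission order (stopping at the first unfinished one), and it omits the cross-rank agreement the real implementation needs.

```python
from collections import deque
from concurrent.futures import Future, wait
from typing import Callable


class ToyAsyncQueue:
    """Hypothetical single-process model of an async checkpoint-save queue."""

    def __init__(self) -> None:
        self._calls: deque[tuple[Future, Callable[[], None]]] = deque()
        self.closed = False

    def schedule(self, future: Future, finalize_fn: Callable[[], None]) -> None:
        # Each save is represented by a Future plus a finalize callback.
        self._calls.append((future, finalize_fn))

    def maybe_finalize(self, blocking: bool = False, terminate: bool = False) -> int:
        if blocking:
            # Wait until every scheduled save has completed.
            wait([future for future, _ in self._calls])
        finalized = 0
        # Finalize completed saves in FIFO order; stop at the first pending one.
        while self._calls and self._calls[0][0].done():
            _, finalize_fn = self._calls.popleft()
            finalize_fn()
            finalized += 1
        if terminate:
            self.closed = True
        return finalized

    def is_empty(self) -> bool:
        return not self._calls
```

Calling with blocking=True and terminate=True fits the point where no further saves will be scheduled.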

bridge.training.checkpointing.is_empty_async_queue(
global_state: megatron.bridge.training.state.GlobalState,
) → bool

Check if async calls queue is empty. This result is consistent across ranks.

Parameters:

global_state – The global training state containing the async calls queue.

Returns:

True if the async calls queue is empty (no ongoing async calls), False otherwise.

Return type:

bool

bridge.training.checkpointing.get_save_and_finalize_callbacks(
writer,
save_state_dict_ret,
) → nvidia_resiliency_ext.checkpointing.async_ckpt.core.AsyncRequest

Creates an async save request for fsdp_dtensor & torch_dcp with a finalize function.

bridge.training.checkpointing.get_rng_state(
data_parallel_random_init: bool,
ckpt_format: str = 'torch_dist',
*,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
module_name: str | None = None,
) → megatron.core.dist_checkpointing.mapping.ShardedObject | dict

Get the random number generator states for all necessary libraries.

Collects states from random, numpy, torch, cuda, and the Megatron RNG tracker. Optionally gathers states across data parallel ranks. The return format depends on the checkpoint format.

For torch_dist format with Expert Parallelism (EP > 1), RNG states are sharded by (PP, TP, DP) dimensions since different EP ranks may have different RNG states. Without EP, states are sharded by (PP, TP) with DP rank as replica_id.

Parameters:
  • data_parallel_random_init – If True, gathers RNG states across data parallel ranks.

  • ckpt_format – The checkpoint format being used.

  • pg_collection – Process group collection for accessing parallel ranks/sizes.

  • module_name – Optional module name for MegatronMIMO per-module RNG namespacing. When set, the ShardedObject key becomes "rng_state.{module_name}" to avoid duplicate shard keys across modules that share the same (pp_rank, tp_rank) coordinates from module-local process groups.

Returns:

For torch_dist: a ShardedObject containing the RNG states, sharded by (PP, TP, DP) when EP > 1, or by (PP, TP) with DP as replica_id otherwise. For fsdp_dtensor: a dict mapping (pp_rank, tp_rank) to RNG state lists.

Return type:

megatron.core.dist_checkpointing.mapping.ShardedObject | dict
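The torch_dist sharding rule can be summarized as a small pure function. This sketch is hypothetical: it only computes the key, global shape, global offset, and replica id described above, assumes the un-namespaced key is "rng_state" and that the replica id is 0 when EP > 1, and does not build a real ShardedObject.

```python
from typing import NamedTuple, Optional


class RngShardLayout(NamedTuple):
    key: str
    global_shape: tuple[int, ...]
    global_offset: tuple[int, ...]
    replica_id: int


def rng_shard_layout(
    pp_rank: int, pp_size: int,
    tp_rank: int, tp_size: int,
    dp_rank: int, dp_size: int,
    ep_size: int,
    module_name: Optional[str] = None,
) -> RngShardLayout:
    """Hypothetical summary of how torch_dist RNG state is sharded."""
    # Per-module namespacing avoids duplicate keys across MegatronMIMO modules.
    key = f"rng_state.{module_name}" if module_name else "rng_state"
    if ep_size > 1:
        # EP ranks may hold different RNG states, so each DP rank gets its own shard.
        return RngShardLayout(key, (pp_size, tp_size, dp_size), (pp_rank, tp_rank, dp_rank), 0)
    # Without EP, DP ranks hold identical states and act as replicas.
    return RngShardLayout(key, (pp_size, tp_size), (pp_rank, tp_rank), dp_rank)
```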

class bridge.training.checkpointing.CheckpointType(*args, **kwds)

Bases: enum.Enum

Types of checkpoints to save.

LOCAL

‘auto(…)’

GLOBAL

‘auto(…)’

FSDP_DTENSOR

‘auto(…)’

class bridge.training.checkpointing.CheckpointSaveContext

Context containing all state needed for a checkpoint save operation.

Attributes:
  • state – The GlobalState object containing config, train_state, etc.

  • model – List of model modules (MegatronModule instances).

  • optimizer – The optimizer instance (may be None for inference checkpoints).

  • opt_param_scheduler – The learning rate scheduler instance.

  • num_floating_point_operations_so_far – Cumulative FLOPs computed up to this point.

  • train_data_iterator – Optional training data iterator whose state is saved.

  • non_persistent_ckpt – If True, saves as a non-persistent (temporary) checkpoint.

state: megatron.bridge.training.state.GlobalState

None

model: list[megatron.core.transformer.MegatronModule]

None

optimizer: megatron.core.optimizer.MegatronOptimizer | None

None

opt_param_scheduler: Any | None

None

num_floating_point_operations_so_far: int

None

train_data_iterator: Any | None

None

non_persistent_ckpt: bool

False

pg_collection: megatron.core.process_groups_config.ProcessGroupCollection | None

None

module_name: str | None

None

class bridge.training.checkpointing.CheckpointLoadContext

Context containing all state needed for a checkpoint load operation.

Attributes:
  • state – The GlobalState object containing config, train_state, etc.

  • model – List of model modules to load state into.

  • optimizer – The optimizer instance to load state into.

  • opt_param_scheduler – The learning rate scheduler instance.

  • strict – Whether to enforce strict loading (see torch.nn.Module.load_state_dict).

  • skip_load_to_model_and_opt – If True, only loads metadata but skips loading state into model and optimizer modules.

state: megatron.bridge.training.state.GlobalState

None

model: list[megatron.core.transformer.MegatronModule]

None

optimizer: megatron.core.optimizer.MegatronOptimizer | None

None

opt_param_scheduler: Any | None

None

strict: bool

True

skip_load_to_model_and_opt: bool

False

pg_collection: megatron.core.process_groups_config.ProcessGroupCollection | None

None

module_name: str | None

None

class bridge.training.checkpointing.CheckpointManager(
checkpoint_config: megatron.bridge.training.config.CheckpointConfig,
)

Bases: typing.Protocol

Protocol defining the checkpoint manager interface.

Implement this protocol to create custom checkpoint save/load behavior. The default implementation (DefaultCheckpointManager) delegates to the existing functional checkpoint code.

Initialization

Initialize the checkpoint manager.

Parameters:

checkpoint_config – The checkpoint configuration.

save(
ctx: bridge.training.checkpointing.CheckpointSaveContext,
callback_manager: Optional[megatron.bridge.training.callbacks.CallbackManager],
) → None

Save a checkpoint.

Parameters:

ctx – CheckpointSaveContext containing all state needed for save.

load(
ctx: bridge.training.checkpointing.CheckpointLoadContext,
) → tuple[int, int]

Load a checkpoint.

Parameters:

ctx – CheckpointLoadContext containing all state needed for load.

Returns:

A tuple of (iteration, num_floating_point_operations_so_far). Returns (0, 0) if no checkpoint was loaded.

finalize_async_saves(
state: megatron.bridge.training.state.GlobalState,
blocking: bool = False,
terminate: bool = False,
) → None

Finalize any pending asynchronous checkpoint saves.

Parameters:
  • state – The GlobalState object (needed for async_calls_queue access).

  • blocking – If True, wait for all pending saves to complete.

  • terminate – If True, close the async queue after finalization.
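Because CheckpointManager is a typing.Protocol, a custom manager only needs matching methods; it does not have to inherit from anything. The sketch below is a hypothetical, self-contained stand-in: CheckpointManagerLike mirrors the three documented methods with loosened Any types so it runs without Megatron installed, and RecordingManager records calls instead of writing checkpoints.

```python
from typing import Any, Optional, Protocol, runtime_checkable


@runtime_checkable
class CheckpointManagerLike(Protocol):
    """Hypothetical mirror of the documented CheckpointManager methods."""

    def save(self, ctx: Any, callback_manager: Optional[Any]) -> None: ...

    def load(self, ctx: Any) -> tuple[int, int]: ...

    def finalize_async_saves(self, state: Any, blocking: bool = False, terminate: bool = False) -> None: ...


class RecordingManager:
    """Hypothetical custom manager that records saves instead of writing files."""

    def __init__(self, checkpoint_config: Any) -> None:
        self.checkpoint_config = checkpoint_config
        self.saved_flops: list[int] = []

    def save(self, ctx: Any, callback_manager: Optional[Any]) -> None:
        # A real manager would persist model/optimizer state taken from ctx.
        self.saved_flops.append(ctx.num_floating_point_operations_so_far)

    def load(self, ctx: Any) -> tuple[int, int]:
        # Mirrors the documented contract: (0, 0) means nothing was loaded.
        return (0, 0)

    def finalize_async_saves(self, state: Any, blocking: bool = False, terminate: bool = False) -> None:
        # Saves above are synchronous, so there is nothing to finalize.
        pass
```

Note that isinstance checks against a runtime_checkable protocol only verify that the methods exist, not their signatures. A real class of this shape would be referenced by its fully qualified name through checkpoint_config.custom_manager_class.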

class bridge.training.checkpointing.DefaultCheckpointManager(
checkpoint_config: megatron.bridge.training.config.CheckpointConfig,
)

Default checkpoint manager that delegates to existing functional code.

This implementation wraps the default save_checkpoint and load_checkpoint functions.

The manager owns the checkpointing_context dictionary, which is used to cache strategies and local checkpoint managers across save/load operations.

Attributes:
  • checkpoint_config – The CheckpointConfig instance.

  • _context – Internal context dictionary for caching checkpoint strategies.

Initialization

Initialize the checkpoint manager.

Parameters:

checkpoint_config – The checkpoint configuration.

property checkpointing_context: dict[str, Any]

The internal checkpointing context dictionary.

This context is passed to save/load functions and caches:

  • Save/load strategies for distributed checkpointing

  • Local checkpoint manager (for non-persistent local checkpoints)

  • Cached metadata for constant-structure optimization

save(
ctx: bridge.training.checkpointing.CheckpointSaveContext,
callback_manager: Optional[megatron.bridge.training.callbacks.CallbackManager],
) → None

Save a checkpoint using the default implementation.

Delegates to save_checkpoint function.

Parameters:

ctx – CheckpointSaveContext containing all state needed for save.

load(
ctx: bridge.training.checkpointing.CheckpointLoadContext,
) → tuple[int, int]

Load a checkpoint using the default implementation.

Delegates to load_checkpoint function.

Parameters:

ctx – CheckpointLoadContext containing all state needed for load.

Returns:

A tuple of (iteration, num_floating_point_operations_so_far).

finalize_async_saves(
state: megatron.bridge.training.state.GlobalState,
blocking: bool = False,
terminate: bool = False,
) → None

Finalize any pending asynchronous checkpoint saves.

Parameters:
  • state – The GlobalState object (needed for async_calls_queue access).

  • blocking – If True, wait for all pending saves to complete.

  • terminate – If True, close the async queue after finalization.

bridge.training.checkpointing.create_checkpoint_manager(
checkpoint_config: megatron.bridge.training.config.CheckpointConfig,
) → bridge.training.checkpointing.CheckpointManager

Factory function to create a checkpoint manager.

Creates either the default checkpoint manager or a custom manager based on the checkpoint_config.custom_manager_class setting.

Parameters:

checkpoint_config – The checkpoint configuration. If custom_manager_class is set, it should be a fully qualified class name (e.g., “mypackage.module.MyManager”).

Returns:

A CheckpointManager instance.

Raises:
  • ImportError – If the custom manager module cannot be imported.

  • AttributeError – If the custom manager class is not found in the module.

  • ValueError – If custom_manager_class format is invalid.

  • TypeError – If the custom manager does not implement the CheckpointManager protocol.

Example:

```python
from megatron.bridge.training.checkpointing import create_checkpoint_manager
from megatron.bridge.training.config import CheckpointConfig

# Default manager
config = CheckpointConfig(save="/path/to/checkpoints")
manager = create_checkpoint_manager(config)

# Custom manager
config = CheckpointConfig(
    save="/path/to/checkpoints",
    custom_manager_class="mypackage.checkpoint.MyCheckpointManager",
)
manager = create_checkpoint_manager(config)
```
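The custom_manager_class string is a dotted path whose last component is the class name. The hypothetical sketch below shows one way such a path can be resolved with importlib, raising the ValueError, ImportError, and AttributeError cases listed under Raises; the real factory's checks may differ, and the protocol check that raises TypeError is omitted.

```python
import importlib


def resolve_manager_class(qualified_name: str) -> type:
    """Hypothetical resolver for a 'package.module.ClassName' string."""
    module_path, _, class_name = qualified_name.rpartition(".")
    if not module_path or not class_name:
        # A bare name (no dot) or a trailing dot cannot name a class in a module.
        raise ValueError(f"Expected 'module.ClassName', got {qualified_name!r}")
    module = importlib.import_module(module_path)  # ImportError if the module is missing
    return getattr(module, class_name)  # AttributeError if the class is missing
```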

bridge.training.checkpointing.save_checkpoint(
state: megatron.bridge.training.state.GlobalState,
model: list[megatron.core.transformer.MegatronModule],
optimizer: Optional[megatron.core.optimizer.MegatronOptimizer],
opt_param_scheduler: Optional[Any],
num_floating_point_operations_so_far: int,
checkpointing_context: Optional[dict[str, Any]] = None,
pipeline_rank: Optional[int] = None,
tensor_rank: Optional[int] = None,
non_persistent_ckpt: bool = False,
train_data_iterator: Optional[Any] = None,
preprocess_common_state_dict_fn: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
prebuilt_state_dict: Optional[dict[str, Any]] = None,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
callback_manager: Optional[megatron.bridge.training.callbacks.CallbackManager] = None,
module_name: str | None = None,
) → None

Save a model checkpoint.

Handles saving the model state, optimizer state, scheduler state, RNG state, and other metadata based on the configuration and checkpoint type (global or local). Supports synchronous and asynchronous saving.

Parameters:
  • state – The GlobalState object.

  • model – The model module(s) to save.

  • optimizer – The optimizer instance.

  • opt_param_scheduler – The optimizer parameter scheduler instance.

  • num_floating_point_operations_so_far – Total FLOPs computed so far.

  • checkpointing_context – Dictionary to store context across saves (e.g., strategies).

  • pipeline_rank – Pipeline parallel rank (defaults to current rank).

  • tensor_rank – Tensor parallel rank (defaults to current rank).

  • non_persistent_ckpt – If True, saves as a non-persistent checkpoint.

  • train_data_iterator – The training data iterator (for saving state if supported).

  • preprocess_common_state_dict_fn – Optional function to preprocess the common state dict before consistency checks in distributed checkpointing.

  • prebuilt_state_dict – Optional pre-built state dict. When provided, state dict generation is skipped and this dict is used directly. Used in low-memory save mode, where factories are expanded and the model is deleted before the save.

  • pg_collection – Optional ProcessGroupCollection. When provided, it is used instead of extracting process groups from the model. Required when model is empty (e.g., in low-memory save mode).

  • module_name – Optional MegatronMIMO module name for per-module RNG state namespacing. When set, RNG ShardedObject keys are namespaced to avoid collisions across modules with identical (pp_rank, tp_rank) coordinates.

bridge.training.checkpointing.cleanup_old_non_persistent_checkpoint(
save_dir: str,
leave_ckpt_num: int = 1,
do_async: bool = False,
) → None

Clean up old non-persistent checkpoints in a directory.

Keeps the specified number of latest checkpoints and removes older ones. Currently only cleans up directories matching “iter_*”.

Parameters:
  • save_dir – The directory containing non-persistent checkpoints.

  • leave_ckpt_num – The number of latest checkpoints to keep.

  • do_async – If True, performs cleanup in a background thread.
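The retention rule amounts to ordering the iter_* directories by iteration and dropping all but the newest leave_ckpt_num. The hypothetical helper below only computes which directory names would be removed (no filesystem access and no background thread), and it assumes names of the form iter_ followed by digits.

```python
def dirs_to_remove(dir_names: list[str], leave_ckpt_num: int = 1) -> list[str]:
    """Hypothetical: pick which iter_* directories a cleanup pass would delete."""
    prefix = "iter_"
    # Only directories named iter_<digits> are considered.
    candidates = [
        name for name in dir_names if name.startswith(prefix) and name[len(prefix):].isdigit()
    ]
    # Oldest first, ordered by the numeric iteration rather than by string.
    candidates.sort(key=lambda name: int(name[len(prefix):]))
    # Keep the newest leave_ckpt_num entries; never slice with a negative index.
    cut = max(len(candidates) - max(leave_ckpt_num, 0), 0)
    return candidates[:cut]
```

Actually deleting the selected directories (for example with shutil.rmtree), optionally in a background thread as do_async allows, is left out of the sketch.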

bridge.training.checkpointing.maybe_save_dataloader_state(
model: list[megatron.core.transformer.MegatronModule] | megatron.core.transformer.MegatronModule,
train_iterator: Any,
iteration: int,
dataloader_save_path: str | None = None,
*,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection | None = None,
) → None

Save the dataloader state if the iterator supports it.

Checks if the train_iterator has a save_state method and calls it.

Parameters:
  • train_iterator – The training data iterator.

  • iteration – The current training iteration.

  • dataloader_save_path – The path where the dataloader state should be saved.
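The check described above is duck typing: any iterator that exposes a callable save_state is eligible. A hypothetical sketch follows; it assumes save_state takes the target path as its only argument (which may not match the real iterator interface) and ignores the model, iteration, and pg_collection arguments.

```python
from typing import Any, Optional


def maybe_save_state_sketch(train_iterator: Any, save_path: Optional[str]) -> bool:
    """Hypothetical: call train_iterator.save_state(save_path) if it exists."""
    save_state = getattr(train_iterator, "save_state", None)
    if save_path is None or not callable(save_state):
        # Plain iterators (and calls without a target path) are skipped.
        return False
    save_state(save_path)  # Assumed signature: a single path argument.
    return True
```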

bridge.training.checkpointing.save_tokenizer_assets(
tokenizer: megatron.bridge.training.tokenizers.tokenizer.MegatronTokenizer,
tokenizer_config: megatron.bridge.training.tokenizers.config.TokenizerConfig,
checkpoint_path: str,
) → None

Save tokenizer files to the checkpoint directory.

Always saves tokenizer files to ensure checkpoints are self-contained and portable. Handles both HuggingFace tokenizers and file-based tokenizers. Compatible with MultiStorageClient for cloud storage support.

Parameters:
  • tokenizer – The tokenizer instance to save.

  • tokenizer_config – The tokenizer configuration (used for file-based tokenizers).

  • checkpoint_path – The checkpoint directory path.

bridge.training.checkpointing._generate_model_state_dict(
model: list[megatron.core.transformer.MegatronModule],
model_sd_kwargs: Optional[dict[str, Any]] = None,
ckpt_format: str = 'torch_dist',
*,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection | None = None,
) → dict[str, megatron.core.dist_checkpointing.mapping.ShardedStateDict]

Generate the model subset of the state dictionary to be saved in a checkpoint.

Can be added to the full checkpoint state dictionary with dict.update().

Parameters:
  • model – The model module(s).

  • model_sd_kwargs – Metadata for model state dict generation.

  • ckpt_format – The checkpoint format being used.

Returns:

A dictionary containing the model state to be saved.

bridge.training.checkpointing.generate_state_dict(
ckpt_cfg: megatron.bridge.training.config.CheckpointConfig,
model: list[megatron.core.transformer.MegatronModule],
optimizer: Optional[megatron.core.optimizer.MegatronOptimizer],
opt_param_scheduler: Optional[Any],
rng_state: Optional[megatron.core.dist_checkpointing.mapping.ShardedObject],
iteration: Optional[int] = None,
optim_sd_kwargs: Optional[dict[str, Any]] = None,
model_sd_kwargs: Optional[dict[str, Any]] = None,
rerun_state: Optional[dict[str, Any]] = None,
*,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection | None = None,
) → dict[str, Any]

Generate the state dictionary to be saved in a checkpoint.

Parameters:
  • ckpt_cfg – The checkpoint configuration.

  • model – The model module(s).

  • optimizer – The optimizer instance.

  • opt_param_scheduler – The optimizer parameter scheduler instance.

  • rng_state – Collected RNG states as a ShardedObject.

  • iteration – The current training iteration.

  • optim_sd_kwargs – Additional keyword arguments for optimizer state dict generation.

  • model_sd_kwargs – Metadata for model state dict generation.

  • rerun_state – State dictionary from the rerun state machine.

Returns:

A dictionary containing the complete state to be saved.

bridge.training.checkpointing.preprocess_fsdp_dtensor_state_dict(
cfg,
raw_state_dict: dict[str, Any],
model: megatron.core.transformer.MegatronModule,
) → dict[str, Any]

Preprocess FSDP DTensor state dict before saving.

Handles:

  • FP8 extra state

  • SwiGLU weight splitting

  • GDN (Gated DeltaNet) fused projection splitting (in_proj / conv1d)

  • Expert parameter reindexing for Expert Parallel

  • Uneven DTensor preprocessing

Parameters:
  • cfg – Configuration object

  • raw_state_dict – The state dict to preprocess

  • model – The model instance

Returns:

Preprocessed state dict ready for FSDP DTensor checkpoint save

bridge.training.checkpointing.save_fsdp_dtensor_checkpoint(
checkpoint_path: str | pathlib.Path,
state_dict: dict[str, Any],
*,
cfg: Any,
model: megatron.core.transformer.MegatronModule,
storage_writer: Any | None = None,
barrier: bool = True,
) → Any

Preprocess and save an FSDP DTensor checkpoint with PyTorch DCP.

This exposes the same preprocessing that the Bridge trainer applies on its own save path, so that external trainers that already own the training loop can reuse it.

Parameters:
  • checkpoint_path – Directory to write when storage_writer is not provided.

  • state_dict – Raw FSDP DTensor state dict to save.

  • cfg – Configuration object passed through to FSDP DTensor preprocessing.

  • model – Model chunk used to infer FSDP DTensor preprocessing rules.

  • storage_writer – Optional PyTorch Distributed Checkpoint storage writer.

  • barrier – Whether to run a distributed barrier after the save.

Returns:

The value returned by torch.distributed.checkpoint.save.

bridge.training.checkpointing._create_fsdp_dtensor_storage_writer(
checkpoint_path: str,
) → Any

bridge.training.checkpointing._load_model_weights_from_checkpoint(
checkpoint_path: str,
model: list[megatron.core.transformer.MegatronModule],
fully_parallel_load: bool = False,
return_state_dict: bool = False,
dist_ckpt_strictness: Literal['assume_ok_unexpected', 'log_unexpected', 'log_all', 'raise_unexpected', 'raise_all', 'return_unexpected', 'return_all', 'ignore_all'] = 'assume_ok_unexpected',
strict: bool = True,
) → Optional[Union[megatron.core.dist_checkpointing.serialization.StateDict, tuple[megatron.core.dist_checkpointing.serialization.StateDict, set[str], set[str]]]]

Load model weights from a checkpoint.

MCore distributed checkpoints from both Megatron Bridge and MegatronLM are supported. This function duplicates some logic from load_checkpoint() to simplify model loading for inference.

Parameters:
  • checkpoint_path – path to a distributed checkpoint.

  • model – The model module(s) to load weights into.

  • fully_parallel_load – Apply full load parallelization across DP.

  • return_state_dict – If True, skips loading the state dict into the model and returns the model state dict instead. Defaults to False.

  • dist_ckpt_strictness – Determines how key mismatches are handled during checkpoint load.

  • strict – Whether to enforce strict loading (see torch.nn.Module.load_state_dict).

bridge.training.checkpointing.load_checkpoint(
state: megatron.bridge.training.state.GlobalState,
model: list[megatron.core.transformer.MegatronModule],
optimizer: Optional[megatron.core.optimizer.MegatronOptimizer],
opt_param_scheduler: Optional[Any],
strict: bool = True,
checkpointing_context: Optional[dict[str, Any]] = None,
skip_load_to_model_and_opt: bool = False,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
module_name: str | None = None,
) → tuple[int, int]

Load a model checkpoint.

Handles loading model state, optimizer state, scheduler state, RNG state, and other metadata based on the configuration and checkpoint type. Supports loading global distributed and local non-persistent checkpoints.

Parameters:
  • state – The GlobalState object.

  • model – The model module(s) to load state into.

  • optimizer – The optimizer instance to load state into.

  • opt_param_scheduler – The scheduler instance to load state into.

  • strict – Whether to enforce strict loading (see torch.nn.Module.load_state_dict).

  • checkpointing_context – Dictionary to store context across loads (e.g., strategies).

  • skip_load_to_model_and_opt – If True, only loads metadata (iteration, rng) but skips loading state into model and optimizer modules.

  • pg_collection – Optional ProcessGroupCollection. When provided, uses this instead of extracting from model via get_pg_collection(). Required for MegatronMIMO where model-level PG extraction may not reflect rank-local topology.

  • module_name – Optional MegatronMIMO module name for per-module RNG state namespacing.

Returns:

A tuple (iteration, num_floating_point_operations_so_far), where iteration is the training iteration number and num_floating_point_operations_so_far is the total FLOPs computed so far.

Return type:

tuple[int, int]

bridge.training.checkpointing._deinterleave_glu_tensor(
tensor: torch.Tensor,
interleave_size: int,
) → torch.Tensor

De-interleave SwiGLU fc1 tensor along dim 0: block-interleaved -> contiguous [W_all, V_all].

The same layout applies to linear_fc1.weight (dim 0 plus the remaining dims) and linear_fc1.bias (dim 0 only). Interleaved format (dim 0): [W0:k, V0:k, Wk:2k, Vk:2k, …] with k = interleave_size.

bridge.training.checkpointing._interleave_glu_tensor(
tensor: torch.Tensor,
interleave_size: int,
) → torch.Tensor

Interleave SwiGLU fc1 tensor along dim 0: contiguous [W_all, V_all] -> block-interleaved.
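The two dim-0 layouts can be checked on a toy example. The sketch below uses a plain Python list as a stand-in for dim 0 of a torch.Tensor (rows W0 … Wn-1 followed by V0 … Vn-1); interleave_rows and deinterleave_rows are hypothetical names, and the sketch assumes the half-length is divisible by interleave_size.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def interleave_rows(rows: Sequence[T], interleave_size: int) -> list[T]:
    """Hypothetical: contiguous [W_all, V_all] -> [W0:k, V0:k, Wk:2k, Vk:2k, ...]."""
    half = len(rows) // 2
    w_rows, v_rows = rows[:half], rows[half:]
    out: list[T] = []
    for start in range(0, half, interleave_size):
        out.extend(w_rows[start:start + interleave_size])
        out.extend(v_rows[start:start + interleave_size])
    return out


def deinterleave_rows(rows: Sequence[T], interleave_size: int) -> list[T]:
    """Hypothetical inverse: block-interleaved rows -> contiguous [W_all, V_all]."""
    w_rows: list[T] = []
    v_rows: list[T] = []
    for start in range(0, len(rows), 2 * interleave_size):
        w_rows.extend(rows[start:start + interleave_size])
        v_rows.extend(rows[start + interleave_size:start + 2 * interleave_size])
    return w_rows + v_rows
```

With a real tensor, each "row" carries its trailing dims along unchanged, which is why the weight and the bias can share one code path.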

bridge.training.checkpointing._is_swiglu_fc1_checkpoint_key(
key: str,
*,
include_routed_experts: bool = True,
include_shared_experts: bool = False,
include_dense: bool = True,
) → bool

Return True if key is a selected SwiGLU linear_fc1 weight or bias.

Dense mlp.linear_fc1 is included only when USE_ACT_FUSION_FOR_DENSE=1. Shared experts are enabled separately by moe_shared_expert_glu_interleave_size before this helper is called.

bridge.training.checkpointing._apply_glu_interleave_to_tensor_data(
tensor: torch.Tensor,
interleave_size: int,
interleave: bool,
) → torch.Tensor

Run interleave or de-interleave on fc1 weight or bias (identical dim-0 layout).

bridge.training.checkpointing._process_state_dict_for_glu_interleaving(
model_state_dict: dict[str, Any],
interleave_size: int,
interleave: bool = True,
use_megatron_fsdp: bool = False,
include_routed_experts: bool = True,
include_shared_experts: bool = False,
include_dense: bool = True,
) → dict[str, Any]

Process GLU weights and biases in state dict for interleaving or de-interleaving.

Parameters:
  • model_state_dict – The state dict to process

  • interleave_size – The interleave size to use

  • interleave – If True, interleave from contiguous to interleaved (for loading). If False, de-interleave from interleaved to contiguous (for saving).

bridge.training.checkpointing._get_model_glu_interleave_sizes(
model: list[megatron.core.transformer.MegatronModule],
cfg: megatron.bridge.training.config.ConfigContainer,
) → tuple[Optional[int], Optional[int]]

Return routed/dense and shared-expert GLU interleave sizes for this model.

bridge.training.checkpointing._process_state_dict_for_model_glu_interleaving(
model_state_dict: dict[str, Any],
routed_interleave_size: Optional[int],
shared_interleave_size: Optional[int],
interleave: bool = True,
use_megatron_fsdp: bool = False,
) → dict[str, Any]

Apply GLU interleaving transforms for each model component that needs them.

bridge.training.checkpointing._load_model_state_dict(
module: torch.nn.Module,
state_dict: dict[str, Any],
strict: bool,
)

Helper function to load state dict with fallback for missing extra states.

bridge.training.checkpointing._load_checkpoint_from_path(
load_dir: str,
state: megatron.bridge.training.state.GlobalState,
model: list[megatron.core.transformer.MegatronModule],
optimizer: Optional[megatron.core.optimizer.MegatronOptimizer],
opt_param_scheduler: Optional[Any],
strict: bool = True,
checkpointing_context: Optional[dict[str, Any]] = None,
skip_load_to_model_and_opt: bool = False,
ignore_ckpt_step: bool = False,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
module_name: str | None = None,
) -> tuple[int, int]

Load a checkpoint from a given path.

Parameters:
  • load_dir – The directory containing the checkpoint.

  • state – The GlobalState object.

  • model – The model module(s) to load state into.

  • optimizer – The optimizer instance to load state into.

  • opt_param_scheduler – The scheduler instance to load state into.

  • strict – Whether to enforce strict loading (see torch.nn.Module.load_state_dict).

  • checkpointing_context – Dictionary to store context across loads (e.g., strategies).

  • skip_load_to_model_and_opt – If True, only loads metadata (iteration, rng) but skips loading state into model and optimizer modules.

  • ignore_ckpt_step – If True, ignores the ckpt_step config and loads the latest checkpoint. Used when loading pretrained checkpoints in PEFT scenarios.

  • pg_collection – Optional ProcessGroupCollection. When provided, it is used instead of extracting the process groups from the model via get_pg_collection(). Required for MegatronMIMO, where model-level process-group extraction may not reflect the rank-local topology.

Returns:

A tuple containing

  • iteration: The training iteration number.

  • num_floating_point_operations_so_far: The total FLOPs computed so far.

Return type:

tuple[int, int]

bridge.training.checkpointing.init_checkpointing_context(
checkpoint_config: megatron.bridge.training.config.CheckpointConfig,
) -> dict[str, Any]

Initialize the checkpointing context, primarily for local checkpointing support.

If non_persistent_ckpt_type is set to “local”, this function sets up the LocalCheckpointManager and replication strategy based on the provided checkpoint_config.

Parameters:

checkpoint_config – The checkpoint configuration object.

Returns:

A dictionary containing the checkpointing context. This will include a local_checkpoint_manager if local checkpointing is enabled, otherwise it will be an empty dictionary.

Raises:

RuntimeError – If local checkpointing is configured but the nvidia_resiliency_ext module is not found.
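The documented control flow can be sketched as follows: return an empty context unless non_persistent_ckpt_type is "local", and raise RuntimeError if nvidia_resiliency_ext cannot be found. CheckpointConfigStub and make_local_manager are hypothetical stand-ins. The factory replaces the real LocalCheckpointManager and replication-strategy setup, whose constructor is not shown here.

```python
import importlib.util
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class CheckpointConfigStub:
    """Hypothetical stand-in for CheckpointConfig with only the field used here."""

    non_persistent_ckpt_type: Optional[str] = None


def init_checkpointing_context_sketch(
    cfg: CheckpointConfigStub,
    make_local_manager: Callable[[CheckpointConfigStub], Any],
) -> dict[str, Any]:
    # Anything other than local non-persistent checkpointing needs no context.
    if cfg.non_persistent_ckpt_type != "local":
        return {}
    # Local checkpointing depends on the optional nvidia_resiliency_ext package.
    if importlib.util.find_spec("nvidia_resiliency_ext") is None:
        raise RuntimeError("Local checkpointing requires the nvidia_resiliency_ext package.")
    # `make_local_manager` is a hypothetical factory standing in for the real setup.
    return {"local_checkpoint_manager": make_local_manager(cfg)}
```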

bridge.training.checkpointing.apply_peft_adapter_filter_to_state_dict(
state_dict: dict[str, Any],
peft_config: megatron.bridge.peft.base.PEFT,
) -> dict[str, Any]

Filter state dict to contain only PEFT adapter parameters in model sections.

This function takes a complete state dict (generated by generate_state_dict) and filters it so that only PEFT adapter parameters remain for checkpoint saving. Transformer Engine ._extra_state entries are excluded even when they live under adapter modules. Adapter checkpoint loading already tolerates missing extra-state keys, and those objects can collide under expert-parallel shared-adapter layouts. The key-handling logic follows the same pattern as generate_state_dict for consistency.

Parameters:
  • state_dict – Complete state dict from generate_state_dict()

  • peft_config – PEFT configuration for filtering logic

Returns:

Filtered state dict containing only adapter parameters in model weights, while preserving all non-model metadata (checkpoint_version, iteration, etc.)
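The filtering rule can be sketched as follows. Model sections keep only adapter keys that are not extra-state entries, and every other section passes through untouched. The is_adapter_key predicate is a hypothetical stand-in for the PEFT-config-driven check, and the "model"/"modelN" section test mirrors the naming documented for _is_model_section.

```python
from typing import Any, Callable


def _is_model_key(section_key: str) -> bool:
    # "model" or "model0", "model1", ... (pipeline-parallel chunks).
    return section_key == "model" or (section_key.startswith("model") and section_key[5:].isdigit())


def filter_adapter_state_dict(
    state_dict: dict[str, Any], is_adapter_key: Callable[[str], bool]
) -> dict[str, Any]:
    """Hypothetical sketch: keep adapter weights only; preserve non-model metadata."""
    filtered: dict[str, Any] = {}
    for section, value in state_dict.items():
        if _is_model_key(section) and isinstance(value, dict):
            filtered[section] = {
                k: v
                for k, v in value.items()
                # Drop Transformer Engine extra state even under adapter modules.
                if is_adapter_key(k) and not k.endswith("._extra_state")
            }
        else:
            filtered[section] = value
    return filtered
```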

bridge.training.checkpointing._is_model_section(section_key: str) -> bool

Check if a checkpoint section contains model parameters.

Model sections are named:

  • “model” (single model)

  • “model0”, “model1”, etc. (pipeline parallel models)

Non-model sections include: “optimizer”, “iteration”, “checkpoint_version”, etc.
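The naming rule above fits in a single regular expression. This is a sketch of that rule, not the module's code.

```python
import re

# "model" optionally followed by a pipeline-chunk index, e.g. "model0", "model12".
_MODEL_SECTION_RE = re.compile(r"model\d*")


def is_model_section(section_key: str) -> bool:
    """Sketch: True for "model", "model0", "model1", ...; False for other sections."""
    return _MODEL_SECTION_RE.fullmatch(section_key) is not None
```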

bridge.training.checkpointing._resolve_checkpoint_iteration(
load_dir: str | None,
ckpt_step_override: int | None,
) -> tuple[int, bool]

Resolve which checkpoint iteration to load.

This function determines the checkpoint iteration by:

  1. If load_dir is already a specific iteration directory (detected via is_checkpoint_iteration_directory), return _DIRECT_ITERATION_DIR_SENTINEL so the caller uses load_dir directly without sub-directory resolution.

  2. If ckpt_step_override is specified, validate that the corresponding iter_* sub-directory exists and return the override value directly.

  3. Otherwise, read from the tracker file (latest_train_state.pt or legacy latest_checkpointed_iteration.txt).

Parameters:
  • load_dir – Base checkpoint directory, or a specific iteration directory.

  • ckpt_step_override – User-specified iteration override (from ckpt_step config).

Returns:

Tuple of (iteration, release), where

  • iteration = _DIRECT_ITERATION_DIR_SENTINEL means load_dir is an iteration directory and should be used as-is.

  • iteration = -1 means no checkpoint was found.

  • Any other non-negative value is the resolved iteration number.

Return type:

tuple[int, bool]
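The three resolution steps can be sketched with the standard library alone, under several assumptions. The sentinel value -2, the iter_NNNNNNN (seven-digit, zero-padded) directory naming, and mapping a "release" tracker entry to (0, True) are illustrative and follow common Megatron-LM conventions, not this module's confirmed behavior. Only the legacy text tracker is read, because the latest_train_state.pt tracker needs torch.

```python
import os
import re
from typing import Optional

# Hypothetical placeholder: the real module defines its own sentinel value.
DIRECT_ITERATION_DIR_SENTINEL = -2
# Assumed naming for iteration directories, e.g. "iter_0000100".
_ITER_DIR_RE = re.compile(r"iter_\d+")


def resolve_checkpoint_iteration(
    load_dir: Optional[str], ckpt_step_override: Optional[int]
) -> tuple[int, bool]:
    """Sketch: return (iteration, release) following the three documented steps."""
    if load_dir is None:
        return -1, False

    # Step 1: load_dir already names a single iteration directory.
    if _ITER_DIR_RE.fullmatch(os.path.basename(os.path.normpath(load_dir))):
        return DIRECT_ITERATION_DIR_SENTINEL, False

    # Step 2: an explicit override must have a matching iter_* sub-directory.
    if ckpt_step_override is not None:
        step_dir = os.path.join(load_dir, f"iter_{ckpt_step_override:07d}")
        if not os.path.isdir(step_dir):
            raise FileNotFoundError(f"No checkpoint directory at {step_dir}")
        return ckpt_step_override, False

    # Step 3: read the legacy text tracker (the .pt tracker is omitted here).
    tracker = os.path.join(load_dir, "latest_checkpointed_iteration.txt")
    if not os.path.isfile(tracker):
        return -1, False
    with open(tracker) as f:
        content = f.read().strip()
    if content == "release":
        return 0, True
    return int(content), False
```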

bridge.training.checkpointing._get_non_persistent_iteration(
non_persistent_global_dir: str,
non_persistent_ckpt_type: Optional[typing.Literal['global', 'local']] = None,
checkpointing_context: Optional[dict[str, typing.Any]] = None,
) -> int

Get the iteration number of the non-persistent checkpoint.

bridge.training.checkpointing._load_non_persistent_base_checkpoint(
non_persistent_global_dir: str,
ckpt_cfg: megatron.bridge.training.config.CheckpointConfig,
rank0: bool,
sharded_state_dict: Optional[dict[str, Any]],
non_persistent_iteration: int,
checkpointing_context: Optional[dict[str, Any]] = None,
*,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
) -> tuple[dict[str, Any], str, bool, bridge.training.checkpointing.CheckpointType]

Load the base state_dict from a non-persistent distributed checkpoint.

bridge.training.checkpointing._load_global_dist_base_checkpoint(
load_dir: str,
ckpt_cfg: megatron.bridge.training.config.CheckpointConfig,
rank0: bool,
sharded_state_dict: Optional[dict[str, Any]],
iteration: Optional[int],
release: bool,
checkpoint_path_override: Optional[str] = None,
checkpointing_context: Optional[dict[str, Any]] = None,
is_megatron_mimo: bool = False,
*,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
) -> tuple[dict[str, Any], str, bool, bridge.training.checkpointing.CheckpointType]

Load the base state_dict from the given directory containing the global distributed checkpoint.

Parameters:

checkpoint_path_override – If provided, use this path directly instead of constructing it from load_dir / iteration. Used when load_dir is already a specific iteration directory.
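The override rule can be sketched as a small path helper. The "release" and zero-padded iter_{iteration:07d} directory names follow the Megatron-LM convention and are assumed here, and checkpoint_dir is a hypothetical helper name.

```python
import os
from typing import Optional


def checkpoint_dir(
    load_dir: str,
    iteration: Optional[int],
    release: bool,
    checkpoint_path_override: Optional[str] = None,
) -> str:
    """Hypothetical sketch: choose the directory to load from."""
    # An explicit override (load_dir already is an iteration directory) wins.
    if checkpoint_path_override is not None:
        return checkpoint_path_override
    # Otherwise build <load_dir>/release or <load_dir>/iter_NNNNNNN.
    subdir = "release" if release else f"iter_{iteration:07d}"
    return os.path.join(load_dir, subdir)
```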

bridge.training.checkpointing._load_base_checkpoint(
load_dir: Optional[str],
ckpt_cfg: megatron.bridge.training.config.CheckpointConfig,
rank0: bool = False,
sharded_state_dict: Optional[dict[str, Any]] = None,
checkpointing_context: Optional[dict[str, Any]] = None,
ignore_ckpt_step: bool = False,
cfg: Optional[megatron.bridge.training.config.ConfigContainer] = None,
is_megatron_mimo: bool = False,
*,
pg_collection: megatron.core.process_groups_config.ProcessGroupCollection,
) -> tuple[Optional[dict[str, Any]], str, bool, Optional[bridge.training.checkpointing.CheckpointType]]

Load the base state_dict from the given directory.

Parameters:
  • load_dir – Directory containing the checkpoint.

  • ckpt_cfg – Checkpoint configuration.

  • rank0 – If True, only load rank 0 metadata.

  • sharded_state_dict – State dict for distributed loading.

  • checkpointing_context – Context for caching strategies.

  • ignore_ckpt_step – If True, ignore ckpt_step and load latest. Used for pretrained checkpoints.

  • cfg – Full configuration object (needed for FSDP DTensor preprocessing).

Returns:

Tuple of (state_dict, checkpoint_name, release, ckpt_type).

bridge.training.checkpointing.load_fsdp_dtensor_checkpoint(
load_dir: str,
ckpt_cfg: megatron.bridge.training.config.CheckpointConfig,
rank0: bool,
sharded_state_dict: Optional[dict[str, Any]],
iteration: Optional[int],
release: bool = False,
checkpoint_path_override: Optional[str] = None,
checkpointing_context: Optional[dict[str, Any]] = None,
cfg: Any | None = None,
) -> tuple[dict[str, Any], str, bool, bridge.training.checkpointing.CheckpointType]

Load the base state dict from an FSDP DTensor checkpoint.

This function preprocesses the state dict (handling expert parameters, SWiGLU, and FP8) before loading from the checkpoint, matching the preprocessing applied during save.

Parameters:
  • load_dir – Directory containing the checkpoint.

  • ckpt_cfg – Checkpoint configuration.

  • rank0 – If True, only load rank 0 metadata.

  • sharded_state_dict – State dict for distributed loading.

  • iteration – The checkpoint iteration to load.

  • release – Whether this is a release checkpoint.

  • checkpoint_path_override – If provided, use this path directly instead of constructing it from load_dir / iteration.

  • checkpointing_context – Context for caching strategies.

  • cfg – Full configuration object (needed for preprocessing).

Returns:

Tuple of (state_dict, checkpoint_name, release, ckpt_type).

bridge.training.checkpointing._build_sharded_state_dict_metadata(
use_distributed_optimizer: bool,
cfg: megatron.bridge.training.config.CheckpointConfig,
) -> dict

Builds metadata used for sharded_state_dict versioning.

The whole content metadata is passed to the sharded_state_dict methods of the model and optimizer, and therefore affects only the logic behind sharded_state_dict creation. The content metadata should be minimal, ideally flat (or with a single nesting level), and use semantically meaningful flag names (e.g. distrib_optim_sharding_type). In particular, a simple integer (or SemVer) versioning flag (e.g. metadata['version'] = 3.4) is discouraged: the metadata serves all models and optimizers, and it is practically impossible to enforce linearly increasing versioning across that whole space.

Parameters:
  • use_distributed_optimizer – Whether to use distributed optimizer.

  • cfg – CheckpointConfig.

bridge.training.checkpointing._get_train_state_from_state_dict(
state_dict: dict[str, Any],
) -> megatron.bridge.training.state.TrainState

Create a TrainState from the state dict of a Megatron-LM checkpoint.