Checkpointing#
The `bridge.training.config.CheckpointConfig` class controls model checkpointing behavior, including saving and loading checkpoints, checkpoint formats, and various optimization features.
Overview#
Megatron Bridge uses Megatron Core's distributed checkpointing system, which is designed for large-scale training across multiple GPUs and nodes. The distributed checkpoint approach saves the state of a distributed training job by sharding checkpoint data across multiple files, reducing memory overhead and improving GPU utilization during save/load operations.
Distributed Checkpointing Benefits#
- Memory Efficiency: Instead of gathering all model parameters and optimizer states on a single rank, distributed checkpointing saves data directly from each rank, significantly reducing memory requirements during checkpointing.
- Parallelism Flexibility: The system provides flexibility to resume training using different parallelism strategies. You can change tensor parallelism, pipeline parallelism, or data parallelism sizes between checkpoint save and load operations.
- Scalability: Handles all types of parallelism, including:
  - Data Parallelism (DP): Replicates the model across multiple GPUs with different data batches
  - Tensor Parallelism (TP): Distributes individual layer parameters across GPUs
  - Pipeline Parallelism (PP): Assigns consecutive layers to different GPUs
  - Context Parallelism (CP): Shards tensors along the sequence dimension for long sequences
  - Expert Parallelism (EP): Distributes MoE expert weights across GPUs
- Performance: The distributed optimizer shards optimizer states and master parameters across data-parallel ranks instead of replicating them, reducing memory usage and communication overhead.
Save Configuration#
| Parameter | Type | Default | Description |
|---|---|---|---|
| | | | Output directory to save checkpoints to |
| | | | Number of iterations between persistent checkpoint saves |
| | | | Whether to save optimizer state |
| | | | Whether to save random number generator state |
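As a rough illustration of how the save settings fit together, the sketch below constructs a `CheckpointConfig` for periodic persistent saves. The keyword names used here are assumptions made for the example; the authoritative attribute names are listed in the `CheckpointConfig` API reference.

```python
# Illustrative sketch only: the field names below are assumptions, not confirmed
# CheckpointConfig attributes. Check the API reference for the actual names.
from bridge.training.config import CheckpointConfig  # import path as referenced on this page

ckpt_cfg = CheckpointConfig(
    save="/results/my_run/checkpoints",  # assumed name: output directory for checkpoint saves
    save_interval=1000,                  # assumed name: persistent save every 1000 iterations
)
```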
Asynchronous Saving#
Asynchronous saving allows training to continue while checkpoint data is persisted to disk in the background, reducing the impact of checkpointing on training throughput.
| Parameter | Type | Default | Description |
|---|---|---|---|
| | | | Enable asynchronous checkpoint saving (requires the `torch_dist` checkpoint format) |
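A sketch of enabling asynchronous saving, extending the earlier example. The field names `async_save` and `ckpt_format` are assumptions for illustration; `torch_dist` is the distributed checkpoint format described later on this page.

```python
# Illustrative sketch only: field names are assumptions, not confirmed attributes.
from bridge.training.config import CheckpointConfig  # import path as referenced on this page

async_ckpt_cfg = CheckpointConfig(
    save="/results/my_run/checkpoints",  # assumed name: checkpoint output directory
    save_interval=1000,                  # assumed name: save every 1000 iterations
    async_save=True,                     # assumed name: persist checkpoint data in the background
    ckpt_format="torch_dist",            # assumed name: async saving requires the torch_dist format
)
```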
Load Configuration#
| Parameter | Type | Default | Description |
|---|---|---|---|
| | | | Directory containing a model checkpoint to load |
| | | | Whether to load optimizer state from the checkpoint |
| | | | Whether to load random number generator state from the checkpoint |
| | | | Load main parameters from the checkpoint (use with …) |
| | | | Specific checkpoint step to load from |
| | | | Exit if the specified checkpoint is not found, instead of falling back to random initialization |
| `dist_ckpt_strictness` | | `assume_ok_unexpected` | Handling of key mismatches during distributed checkpoint load |
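For resuming a run, a configuration along the following lines points at an existing checkpoint directory. Again, the keyword names are assumptions made for this sketch rather than confirmed `CheckpointConfig` attributes.

```python
# Illustrative sketch only: field names are assumptions, not confirmed attributes.
from bridge.training.config import CheckpointConfig  # import path as referenced on this page

resume_cfg = CheckpointConfig(
    load="/results/my_run/checkpoints",  # assumed name: directory with saved checkpoints
    ckpt_step=2000,                      # assumed name: load this specific iteration
    exit_on_missing_checkpoint=True,     # assumed name: abort rather than randomly initialize
)
```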
Checkpoint Loading Strictness#
When loading distributed checkpoints, there may be mismatches between the keys in the saved checkpoint and what the current model expects. This can happen when resuming training with different parallelism settings, model configurations, or software versions. The `dist_ckpt_strictness` parameter controls how these mismatches are handled:
- `assume_ok_unexpected`: Assume unexpected keys are acceptable (default, most permissive)
- `log_unexpected`: Log unexpected keys but continue loading
- `log_all`: Log all key mismatches for debugging
- `raise_unexpected`: Raise an error on unexpected keys (stricter validation)
- `raise_all`: Raise an error on any key mismatch (strictest validation)
- `return_unexpected`: Return information about unexpected keys
- `return_all`: Return information about all key mismatches
- `ignore_all`: Ignore all key mismatches completely
Fine-tuning Configuration#
| Parameter | Type | Default | Description |
|---|---|---|---|
| | | | Directory containing a pretrained model checkpoint for fine-tuning |
Checkpoint Format#
| Parameter | Type | Default | Description |
|---|---|---|---|
| | | | Checkpoint format (`torch_dist` is the PyTorch distributed checkpoint format) |
Performance Optimizations#
| Parameter | Type | Default | Description |
|---|---|---|---|
| | | | Apply full save parallelization across data-parallel ranks |
| | | | Apply full load parallelization across data-parallel ranks |
| | | | Assume constant model/optimizer structure across successive checkpoint saves to enable performance optimizations |
Checkpoint Contents#
The checkpoint includes the following components when using the `torch_dist` checkpoint format:

- Model parameters and optimizer states: Stored across `.distcp` files to support distributed training.
- Training state: Captures the current iteration count, number of consumed samples, and the state of the learning rate scheduler.
- Configuration: Serialized as a YAML file (`run_config.yaml`) containing the complete `ConfigContainer`.
- Dataloader states: Ensures deterministic resumption of data iteration.
- Metadata: Used for validating and correctly loading the checkpoint.
Megatron Bridge creates checkpoints with the following directory structure:
```
checkpoint_dir/
├── latest_train_state.pt                      # Latest training state (top-level)
├── iter_N/                                    # Checkpoint at iteration N
│   ├── __0_0.distcp                           # Distributed checkpoint shards: maps to PyTorch DCP weights format
│   ├── __0_1.distcp                           # Contains model parameters, optimizer states
│   ├── __1_0.distcp
│   ├── __1_1.distcp
│   ├── ...
│   ├── .metadata                              # PyTorch DCP checkpoint metadata
│   ├── common.pt                              # MCore dist ckpt states saved from rank 0
│   ├── metadata.json                          # MCore dist ckpt metadata
│   ├── run_config.yaml                        # Serialized ConfigContainer
│   ├── train_state.pt                         # Number of steps, consumed samples, etc.
│   └── dataloader_state/                      # Data iterator states
│       ├── train_dataloader_dprank000.pt      # DP rank 0 dataloader state
│       ├── train_dataloader_dprank001.pt      # DP rank 1 dataloader state
│       ├── train_dataloader_dprank002.pt      # DP rank 2 dataloader state
│       └── ...                                # One file per DP rank
```
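Because the sidecar files in this layout are ordinary YAML and PyTorch files, a checkpoint directory can be inspected without Megatron Bridge's own loading machinery. The helper below is a read-only sketch based on the layout shown above; it assumes PyYAML is installed and does not touch the `.distcp` shards.

```python
from pathlib import Path

import torch
import yaml  # PyYAML, assumed to be available


def summarize_checkpoint(iter_dir: str) -> None:
    """Print a quick, read-only summary of one iter_N/ checkpoint directory."""
    ckpt = Path(iter_dir)

    # run_config.yaml is the serialized ConfigContainer and is plain YAML.
    run_config = yaml.safe_load((ckpt / "run_config.yaml").read_text())
    print("top-level config sections:", list(run_config))

    # train_state.pt holds the iteration count, consumed samples, etc. It may
    # contain arbitrary Python objects, hence weights_only=False.
    train_state = torch.load(ckpt / "train_state.pt", map_location="cpu", weights_only=False)
    print("train state:", train_state)

    # Model parameters and optimizer states live in the .distcp shards; just count them.
    print("number of .distcp shards:", len(list(ckpt.glob("*.distcp"))))


# Example usage (the path is a placeholder for a real iter_N directory):
# summarize_checkpoint("checkpoint_dir/iter_N")
```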
Local Checkpointing#
Local checkpointing saves model checkpoints directly to storage on each node (e.g., local SSDs or RAM disks), instead of relying solely on a shared network filesystem. This approach can significantly speed up the saving process and reduce the load on shared storage infrastructure.
Local checkpointing leverages the NVIDIA Resiliency Extension and provides several key features:
- Local Saving: Each node saves its part of the checkpoint locally, reducing network I/O and improving save performance.
- Synchronous and Asynchronous Support: Saving can happen synchronously or asynchronously, mirroring the configuration used for global checkpoints.
- Automatic Cleanup: Handles the removal of outdated or incomplete local checkpoints automatically.
- Optional Replication: For multi-node jobs, checkpoints are replicated to other nodes to allow recovery even if a node fails after saving. Single-node jobs do not use replication.
- Automated Loading: When resuming, the framework automatically finds the latest valid checkpoint, comparing local and global checkpoints, and retrieves any needed parts across nodes.
Non-Persistent Checkpointing Configuration#
| Parameter | Type | Default | Description |
|---|---|---|---|
| | | | Number of iterations between non-persistent checkpoint saves |
| | | | Type of non-persistent checkpointing |
| | | | Directory for global non-persistent checkpoints |
| | | | Directory for local non-persistent checkpoints |
| | | | Algorithm for local non-persistent checkpointing |
Replication and Fault Tolerance#
| Parameter | Type | Default | Description |
|---|---|---|---|
| | | | Enable replication of local checkpoints across ranks |
| | | | Spacing between ranks storing replicas |
| | | | Number of machines storing a replica of each rank's data |
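To make the spacing and factor parameters concrete, the snippet below sketches one plausible placement rule: each rank's local checkpoint data is also held by ranks offset by multiples of the jump value. This is only an illustration of the idea described above; the actual placement is decided by the NVIDIA Resiliency Extension and may differ in detail, including whether the factor counts the original copy.

```python
def candidate_replica_ranks(rank: int, world_size: int, jump: int, factor: int) -> list[int]:
    """Illustrative only: assume each rank's data is replicated onto ranks spaced
    `jump` apart, with `factor` total copies (original included). The real
    placement logic lives in the NVIDIA Resiliency Extension."""
    return [(rank + k * jump) % world_size for k in range(1, factor)]


# Example: 16 ranks, replicas spaced 4 ranks apart, 2 copies in total.
# Under this assumed rule, rank 3's data would also be held by rank 7.
print(candidate_replica_ranks(rank=3, world_size=16, jump=4, factor=2))  # [7]
```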