User Guide#
Quick Start#
This guide for Megatron Core walks you through the following tasks:
Initialize Megatron Core on two GPUs.
Build a GPT model with a tensor model parallel size of two and a pipeline parallel size of one.
Train the model for five iterations using Megatron Core schedules.
Save the model using the distributed checkpoint format.
Load the model.
Note
The following sample was tested using Megatron Core version 0.8.0 and NGC PyTorch Container version 24.02.
Set Up Your Environment#
Run a new Docker container.
Clone the Megatron GitHub repo in it.
docker run --ipc=host --shm-size=512m --gpus 2 -it nvcr.io/nvidia/pytorch:24.02-py3
git clone https://github.com/NVIDIA/Megatron-LM.git
cd Megatron-LM
pip install -U setuptools packaging
pip install --no-build-isolation .[dev]
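To confirm that the installation succeeded, you can print the package version from inside the container (the exact version string depends on the branch you cloned):
python -c "import megatron.core; print(megatron.core.__version__)"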
For a more comprehensive overview of the different installation methods, refer to the Installation Guide.
Write Your First Training Loop#
In this task, you create a sample GPT model split across tensors (tensor model parallelism) on two GPUs, and run a forward pass through it using the MockGPTDataset helper class provided in Megatron Core.
Note
All of the following steps are included in the run_simple_mcore_train_loop.py script. To run it:
PYTHONPATH=$PYTHON_PATH:./megatron torchrun --nproc-per-node 2 examples/run_simple_mcore_train_loop.py
Initialize distributed training and set up model parallelism:
The following utility, when called, initializes your distributed setup:
import os
import torch
from megatron.core import parallel_state

def initialize_distributed(tensor_model_parallel_size=1, pipeline_model_parallel_size=1):
    # Torch setup for distributed training
    rank = int(os.environ['LOCAL_RANK'])
    world_size = torch.cuda.device_count()
    torch.cuda.set_device(rank)
    torch.distributed.init_process_group(world_size=world_size, rank=rank)

    # Megatron Core distributed training initialization
    parallel_state.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size)
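As a quick, optional sanity check (not part of the training loop), you can print the parallel layout that each rank sees after initialization. The snippet below continues from the code above and assumes the script is launched with torchrun on two GPUs:
import torch
from megatron.core import parallel_state

initialize_distributed(tensor_model_parallel_size=2, pipeline_model_parallel_size=1)

# Each rank reports the tensor and pipeline parallel group sizes it belongs to.
print(f"rank {torch.distributed.get_rank()}: "
      f"TP size {parallel_state.get_tensor_model_parallel_world_size()}, "
      f"PP size {parallel_state.get_pipeline_model_parallel_world_size()}")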
Set up the GPT model:
Use the following code snippet to create a GPT model. For a list of other configurations that you can pass into the model, open and review transformer_config.py.
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec

def model_provider():
    """Build the model."""
    transformer_config = TransformerConfig(
        num_layers=2,
        hidden_size=12,
        num_attention_heads=4,
        use_cpu_initialization=True,
        pipeline_dtype=torch.float32)

    gpt_model = GPTModel(
        config=transformer_config,
        transformer_layer_spec=get_gpt_layer_local_spec(),
        vocab_size=100,
        max_sequence_length=64)

    return gpt_model
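Because the configuration above is intentionally tiny (2 layers, hidden size 12), the resulting model is very small. As an optional check, you can count its parameters; this assumes initialize_distributed() from the previous step has already been called:
# Optional check: inspect the size of the toy GPT model defined above.
gpt_model = model_provider()
num_params = sum(p.numel() for p in gpt_model.parameters())
print(f"Toy GPT model parameters: {num_params}")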
Set up the GPT mock dataset:
Use the following code snippet to explore the mock dataset utility.
To train the model using your own data, use the GPTDataset class in gpt_dataset.py. For more information about the Megatron Core data pipeline, see the data pipeline readme.md.
import torch
from torch.utils.data import DataLoader

from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig, MockGPTDataset
from megatron.training.tokenizer.tokenizer import _NullTokenizer
from megatron.core.datasets.utils import compile_helpers

_SEQUENCE_LENGTH = 64

def get_train_data_iterator():
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        if torch.distributed.get_rank() == 0:
            compile_helpers()
        torch.distributed.barrier()
    else:
        compile_helpers()

    config = GPTDatasetConfig(
        random_seed=0,
        sequence_length=_SEQUENCE_LENGTH,
        reset_position_ids=False,
        reset_attention_mask=False,
        eod_mask_loss=False,
        tokenizer=_NullTokenizer(vocab_size=_SEQUENCE_LENGTH),
    )

    datasets = BlendedMegatronDatasetBuilder(
        MockGPTDataset, [1000, None, None], lambda: True, config
    ).build()

    train_dataloader = DataLoader(datasets[0], batch_size=8, shuffle=True)

    train_iterator = iter(train_dataloader)

    return train_iterator
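If you want to see what the iterator produces, you can pull a single batch and print the field shapes. The field names match the ones consumed by the forward step in the next step; this snippet assumes distributed initialization has already run:
# Optional: inspect one mock batch. Expect fields such as 'tokens', 'labels',
# 'loss_mask', 'attention_mask', and 'position_ids'.
train_iterator = get_train_data_iterator()
batch = next(train_iterator)
print({name: tuple(tensor.shape) for name, tensor in batch.items()})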
Add a forward step function:
Megatron Core uses schedules.py to run the model. Define a forward step function that takes the data iterator and the model as input and produces the output tensor and a loss function.
from functools import partial

def forward_step_func(data_iterator, model):

    def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
        losses = output_tensor.float()
        loss_mask = loss_mask.view(-1).float()
        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
        # If you have data parallel, reduce loss across data parallel groups.
        # If pipeline parallel, loss computation is done only in the last stage.
        return loss, {'lm loss': loss}

    # `device` is defined at module scope in the main function below.
    data = next(data_iterator)
    tokens = data['tokens'].to(device)
    attention_mask = data['attention_mask'].to(device)
    position_ids = data['position_ids'].to(device)
    labels = data['labels'].to(device)
    loss_mask = data['loss_mask'].to(device)

    output_tensor = model(tokens, position_ids, attention_mask, labels=labels)

    return output_tensor, partial(loss_func, loss_mask)
Define functions to save and load distributed checkpoints:
Megatron Core uses distributed checkpoints for loading and saving models. This allows you to convert the model from one parallel setting to another when you load it. For example, a model trained with tensor model parallel size 2 can be loaded again with tensor model parallel size 4.

from megatron.core import dist_checkpointing

def save_distributed_checkpoint(checkpoint_path, gpt_model):
    sharded_state_dict = gpt_model.sharded_state_dict(prefix='')
    dist_checkpointing.save(sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path)

def load_distributed_checkpoint(checkpoint_path, gpt_model):
    sharded_state_dict = gpt_model.sharded_state_dict(prefix='')
    checkpoint = dist_checkpointing.load(sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path)
    gpt_model.load_state_dict(checkpoint)
    return gpt_model
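As a sketch of the parallel-state conversion described above, a separate job could rebuild the same model under a different layout and load the checkpoint written by the training loop below. The snippet reuses the helpers defined in this guide and is illustrative only; the checkpoint path and parallel sizes are assumptions you would adapt to your setup:
# Illustrative: in a new job (for example on a single GPU), initialize a different
# parallel layout, rebuild the model, and load the checkpoint saved earlier.
initialize_distributed(tensor_model_parallel_size=1, pipeline_model_parallel_size=1)
gpt_model = model_provider()
gpt_model = load_distributed_checkpoint(checkpoint_path='/workspace/ckpt', gpt_model=gpt_model)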
Add the main function:
The following code snippet is the main function that needs to go into your script. It runs the model for five iterations, saves it, and loads it back.
from pathlib import Path
from torch.optim import Adam
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed

if __name__ == "__main__":
    initialize_distributed(tensor_model_parallel_size=2, pipeline_model_parallel_size=1)
    model_parallel_cuda_manual_seed(123)

    gpt_model = model_provider()
    device = torch.device("cuda")
    gpt_model.to(device)

    optim = Adam(gpt_model.parameters())

    train_iterator = get_train_data_iterator()

    forward_backward_func = get_forward_backward_func()

    # Running the model for 5 iterations
    for _ in range(5):
        optim.zero_grad()

        losses_reduced = forward_backward_func(
            forward_step_func=forward_step_func,
            data_iterator=train_iterator,
            model=gpt_model,
            num_microbatches=1,
            seq_length=64,
            micro_batch_size=8,
            decoder_seq_length=64,
            forward_only=False)

        optim.step()

        print(f'Losses reduced : {losses_reduced}')

    # Saving the model
    save_distributed_checkpoint(gpt_model=gpt_model, checkpoint_path='/workspace/ckpt')

    # Loading the model
    gpt_model = load_distributed_checkpoint(gpt_model=gpt_model, checkpoint_path='/workspace/ckpt')
    gpt_model.to(device)
    print('Successfully loaded the model')
Review Advanced Examples#
To review more advanced examples, explore pretrain_gpt.py. It has more complex training loops and includes the following Megatron Core features:
Pipeline parallelism
Context parallelism
RoPE embeddings
Mixture of experts
Installing Megatron Core#
Megatron Core maintains a lightweight installation and minimizes conflicts by keeping its core dependencies (torch, numpy, and packaging) to a minimum. This is achieved through “import-guarding,” where additional dependencies are only verified and loaded when the specific features that require them are actively used.
There are two ways to extend Megatron Core with the requirements that unlock the performance needed for large-scale distributed training: using an NGC PyTorch container or installing from source. While installation into an NGC PyTorch container may simplify the experience by shipping with pre-installed, performance-optimized dependencies, a source installation gives more freedom and customization options. The following sections take a look at both.
Before we dive into the fully featured installation process, let's take a quick detour through the basic installation.
Basic installation#
Megatron Core ships released wheels to PyPI bi-monthly.
pip install megatron-core
Additionally, there are weekly pre-release wheels:
pip install --pre megatron-core
Specific commits can be installed from the official NVIDIA/Megatron-LM GitHub repository:
pip install git+https://github.com/NVIDIA/Megatron-LM.git@${COMMIT}
Each installation method has complete feature parity for a given version.
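If you need a reproducible environment, you can also pin an exact release; the version below is only an example and matches the one this guide was tested with:
pip install megatron-core==0.8.0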
Installation inside an NGC PyTorch container#
The NGC PyTorch container includes NVIDIA system-level dependencies such as NCCL, CUDA, and cuDNN, which provide lower-level GPU support. It also comes with Python libraries specifically optimized and compiled for these software versions. Two key libraries for Megatron Core are a performance-optimized version of PyTorch, which incorporates advanced performance features not yet available in upstream Meta PyTorch at the time of release, and NVIDIA Transformer Engine.
To get started, run the following commands:
# On the host machine
docker run --rm -it --gpus all nvcr.io/nvidia/pytorch:XX.YY-py3
# Inside the container
pip install megatron-core
:bulb: For the most recently tested NGC PyTorch image, see the file .gitlab/stages/01.build.yml. The stable release branches are named core_rX.Y.Z.
For a complete installation of Megatron Core with all features, follow these steps:
# Inside the container
pip install -U setuptools packaging
pip install --no-build-isolation megatron-core[dev]
:bulb: We add the argument --no-build-isolation since many dependencies, such as transformer-engine, need to be aligned with the pre-installed CUDA and PyTorch versions. By removing Python's default build isolation, we expose the installation process to the host and its software versions. As a result, the compiler is able to build the source specifically for those versions.
This command also installs libraries such as flash-infer, mamba-ssm, and grouped-gemm. Depending on your CUDA and PyTorch environment versions, the installation could take anywhere from a few seconds to over thirty minutes.
This situation arises because most dependencies offer a wide array of pre-compiled wheels compatible with various combinations of CUDA, PyTorch, and their respective library versions. When a suitable pre-compiled wheel is found, installation is nearly instantaneous. Conversely, if no such wheel exists, the local host machine must compile the source-distributed wheel into a binary. Feel free to raise an issue at NVIDIA/Megatron-LM if you run into a slow installation, and we will check whether we can accelerate it for your use case.
The dev extras option includes all dependencies validated by Megatron-LM's internal CI. This may be more extensive than necessary for your specific needs. You can review the requirements files at NVIDIA/Megatron-LM to select only the dependencies relevant to your use case.
# Inside the container
# Example to only install support for hybrid models
pip install --no-build-isolation \
megatron-core \
"mamba-ssm~=2.2" \
"causal-conv1d~=1.5" \
"nv-grouped-gemm~=1.1"
Installation inside a vanilla Ubuntu container#
While pre-configured NGC PyTorch containers are often suitable, some use cases may require a custom container. Other noteworthy NGC containers include NGC cuda and NGC cuda-dl-base.
For educational purposes, the following section details a “bare-metal” installation within a plain Ubuntu environment. This demonstration aims to provide users with sufficient knowledge to manage installations within pre-configured NGC containers.
Preliminary requirements#
The software stack used for this guide was configured with Ubuntu 24.04, CUDA 12.8, cuDNN 9.1, Python 3.12, PyTorch 2.8, and Transformer Engine 2.5.0.
Starting the container#
docker run --rm -it --entrypoint bash ubuntu:24.04
Installing Python#
We will install Python 3.12 development headers, Python 3.12 venv for virtual environment support, and pip for installing additional packages. For convenience, update-alternatives will be used to set python as the default command instead of python3.12.
apt-get update
apt-get install -y software-properties-common
add-apt-repository ppa:deadsnakes/ppa -y
apt-get install -y python3.12-dev python3.12-venv python3-pip
update-alternatives --install /usr/bin/python python /usr/bin/python3 1
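A quick check that the interpreter and pip are set up as expected (version numbers may differ slightly):
python --version      # should report Python 3.12.x
python -m pip --version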
Installing the CUDA Toolkit#
To establish a clean CUDA development environment on Ubuntu 24.04, we begin by installing essential tools such as wget, curl, git, and cmake for software downloading and building. We then remove any existing CUDA/NVIDIA repositories to prevent conflicts. Subsequently, NVIDIA's official CUDA keyring is retrieved and installed, which securely integrates the latest CUDA repository into the system. The final step involves installing the CUDA Toolkit 12.8 (comprising the compiler, runtime, and libraries), cuDNN 9 (for GPU-accelerated deep learning primitives), and CUTLASS (a template library for high-performance matrix operations).
# Install tools
apt-get update
apt-get install -y wget curl git cmake
rm /etc/apt/sources.list.d/cuda*.list || true
rm /etc/apt/sources.list.d/nvidia-cuda.list || true
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
dpkg -i cuda-keyring_1.1-1_all.deb
rm cuda-keyring_1.1-1_all.deb
apt-get update
apt-get install -y cuda-toolkit-12-8 \
libcudnn9-cuda-12 \
libcutlass-dev
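The toolkit installs under a versioned directory such as /usr/local/cuda-12.8, and nvcc is usually not on the PATH by default. A quick, optional verification might look like this:
# Optional: expose the toolkit binaries and confirm the compiler version.
export PATH=/usr/local/cuda-12.8/bin:$PATH
nvcc --version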
Python libraries#
Finally, we can set up a virtual Python environment and run a feature-complete installation of Megatron Core:
python -m venv .venv
source .venv/bin/activate
# Run this first to install basic dependencies and build-requirements for step two
pip install megatron-core
# Run this for the feature-complete install
pip install --no-build-isolation megatron-core[dev]
Testing correctness#
After successfully installing Megatron Core and its dependencies, we can validate the environment with the following commands.
To test Megatron Core, the following Python snippet should run successfully:
import megatron.core
print(megatron.core.__version__)
To test Transformer Engine, the following Python snippet should run successfully:
import transformer_engine
import transformer_engine.pytorch
print(transformer_engine.__version__)
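As an additional optional check, you can confirm that PyTorch detects the GPUs, provided the container was started with GPU access (for example, with --gpus all):
import torch

# Prints True and the first GPU's name when the container can see the GPUs.
print(torch.cuda.is_available())
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))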
Summary#
This guide has aimed to facilitate the installation and operational understanding of Megatron Core, including its continuous integration and deployment mechanisms. We trust this resource will prove valuable in the seamless development of large-scale LLMs using Megatron Core. Your insights and feedback are highly valued as we continue to enhance this tool. We encourage you to share your experiences, report any issues, or propose improvements by engaging with our GitHub community at github.com/NVIDIA/Megatron-LM/issues. Your contributions are instrumental in shaping the future development of Megatron Core.
Multi-Storage Client (MSC) Integration#
The Multi-Storage Client (MSC) provides a unified interface for reading datasets and storing checkpoints from both filesystems (e.g., local disk, NFS, Lustre) and object storage providers such as S3, GCS, OCI, Azure, AIStore, and SwiftStack.
This guide walks you through the following tasks:
Install and configure MSC.
Train models directly using datasets in object storage.
Save and load model checkpoints to and from object storage.
Installation#
MSC is published as the multi-storage-client package on PyPI. The base client supports POSIX file systems by default, and there are extras for each storage service that provide the necessary package dependencies for the corresponding storage provider.
# POSIX file systems.
pip install multi-storage-client
# AWS S3 and S3-compatible object stores.
pip install "multi-storage-client[boto3]"
# Google Cloud Storage (GCS).
pip install "multi-storage-client[google-cloud-storage]"
Configuration File#
MSC uses a YAML configuration file to define how it connects to object storage systems. This design allows you to specify one or more storage profiles, each representing a different storage backend or bucket. MSC keeps your training scripts clean and portable by centralizing details in a config file. There is no need to hardcode access keys, bucket names, or other provider-specific options directly into your code.
Here’s an example configuration:
profiles:
my-profile:
storage_provider:
type: s3
options:
# Set the bucket/container name as the base_path
base_path: my-bucket
region_name: us-west-2
# Optional credentials (can also use environment variables for S3)
credentials_provider:
type: S3Credentials
options:
access_key: ${AWS_ACCESS_KEY}
secret_key: ${AWS_SECRET_KEY}
cache:
size: 500G # Maximum cache size
location: /tmp/msc_cache # Cache directory on filesystem
To tell MSC where to find this file, set the following environment variable before running your Megatron-LM script:
export MSC_CONFIG=/path/to/msc_config.yaml
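Once MSC_CONFIG points to a valid configuration, you can do a quick read test from Python. This is a minimal sketch that assumes the top-level msc.open helper from the multi-storage-client package and an object that already exists under the my-profile profile; adjust the path to your data:
import multistorageclient as msc

# Hypothetical object key; replace with a path that exists in your bucket.
with msc.open("msc://my-profile/dataset/train/data.bin", "rb") as f:
    print(f.read(16))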
MSC URL Format#
MSC uses a custom URL scheme to identify and access files across different object storage providers. This scheme makes it easy to reference data and checkpoints without worrying about the underlying storage implementation. An MSC URL has the following structure:
msc://<profile-name>/<path/to/object>
Components:
msc://: This is the scheme identifier indicating the path should be interpreted by the Multi-Storage Client.
<profile-name>: This corresponds to a named profile defined in your YAML configuration file under the profiles section. Each profile specifies the storage provider (e.g., S3, GCS), credentials, and storage-specific options such as the bucket name or base path.
<path/to/object>: This is the logical path to the object or directory within the storage provider, relative to the base_path configured in the profile. It behaves similarly to a path in a local filesystem but maps to object keys or blobs in the underlying storage system.
Example:
Given the following profile configuration:
profiles:
my-profile:
storage_provider:
type: s3
options:
base_path: my-bucket
The MSC URL:
msc://my-profile/dataset/train/data.bin
is interpreted as accessing the object with the key dataset/train/data.bin inside the S3 bucket named my-bucket. If this were a GCS or OCI profile instead, MSC would apply the appropriate backend logic based on the profile definition, but your code using the MSC URL would remain unchanged.
This abstraction allows training scripts to reference storage resources uniformly—whether they’re hosted on AWS, GCP, Oracle, or Azure—just by switching profiles in the config file.
Train from Object Storage#
To train with datasets stored in object storage, use an MSC URL with the --data-path argument. This URL references a dataset stored under a profile defined in your MSC configuration file.
In addition, Megatron-LM requires the --object-storage-cache-path argument when reading from object storage. This path is used to cache the .idx index files associated with IndexedDataset, which are needed for efficient data access.
python pretrain_gpt.py \
--object-storage-cache-path /path/to/object_store_cache \
--data-cache-path /path/to/data_cache \
--data-path msc://my-profile/datasets/text_document \
--no-mmap-bin-files
Note
All four arguments must be provided when training with datasets in object storage using MSC.
Save and Load Checkpoints from Object Storage#
MSC can be used to save and load model checkpoints directly from object storage by specifying MSC URLs for the --save and --load arguments. This allows you to manage checkpoints in object storage.
python pretrain_gpt.py \
--save msc://my-profile/checkpoints \
--load msc://my-profile/checkpoints \
--save-interval 1000
Note
Only the torch_dist checkpoint format is currently supported when saving to or loading from MSC URLs.
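If your Megatron-LM version exposes a checkpoint-format flag, you can make this explicit; this is a hedged example, since torch_dist is typically the default already:
python pretrain_gpt.py \
--save msc://my-profile/checkpoints \
--load msc://my-profile/checkpoints \
--ckpt-format torch_dist \
--save-interval 1000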
Disable MSC#
By default, MSC integration is automatically enabled when the multi-storage-client library is installed. MSC is also used for regular filesystem paths (like /filesystem_mountpoint/path in --data-path, --save, or --load) even when not using explicit MSC URLs. MSC functions as a very thin abstraction layer with negligible performance impact when used with regular paths, so there is typically no need to disable it. If you need to disable MSC, you can do so using the --disable-msc flag:
python pretrain_gpt.py --disable-msc
Performance Considerations#
When using object storage with MSC, there are a few important performance implications to keep in mind:
Reading Datasets
Reading training datasets directly from object storage is typically slower than reading from local disk. This is primarily due to:
High latency of object storage systems, especially for small and random read operations (e.g., reading samples from .bin files).
HTTP-based protocols used by object stores (e.g., S3 GET with range requests), which are slower than local filesystem I/O.
To compensate for this latency, it is recommended to increase the number of data loading workers using the --num-workers argument in your training command:
python pretrain_gpt.py --num-workers 8 ...
Increasing the number of workers allows more parallel reads from object storage, helping to mask I/O latency and maintain high GPU utilization during training.
Checkpoint Loading
When using MSC to load checkpoints from object storage, it is important to configure the cache section in your MSC configuration file. This local cache is used to store downloaded checkpoint data and metadata, which significantly reduces load time and memory usage.
Example:
cache:
size: 500G
location: /tmp/msc_cache
For optimal performance, configure the cache directory on a high-speed local storage device such as an NVMe SSD.
Additional Resources and Advanced Configuration#
Refer to the MSC Configuration Documentation for complete documentation on MSC configuration options, including detailed information about supported storage providers, credentials management, and advanced caching strategies.
MSC supports collecting observability metrics and traces to help monitor and debug data access patterns during training. These metrics can help you identify bottlenecks in your data loading pipeline, optimize caching strategies, and monitor resource utilization when training with large datasets in object storage. For more information about MSC’s observability features, see the MSC Observability Documentation.
MSC offers an experimental Rust client that bypasses Python's Global Interpreter Lock (GIL) to significantly improve performance for multi-threaded I/O operations. The Rust client supports AWS S3, SwiftStack, and Google Cloud Storage, enabling true concurrent execution for much better performance compared to the Python implementation. To enable it, add rust_client: {} to your storage provider configuration. For more details, see the MSC Rust Client Documentation.