# Data Loading at Scale
This guide covers how Megatron’s data pipeline works and how to configure it for efficient training at 256 nodes and beyond. At this scale, the primary bottlenecks are index building and barrier synchronization – not raw data bandwidth.
## How Data Loading Works
Understanding the architecture helps explain why specific flags matter.
Megatron builds three index arrays for each dataset: a document index (shuffled document order), a sample index (mapping samples to document offsets), and a shuffle index (final sample permutation). This happens once during initialization:
1. Rank 0 builds all three indices and writes them to a cache directory as `.npy` files.
2. All ranks synchronize at a `torch.distributed.barrier()`.
3. All other ranks load the cached indices via memory-mapped reads (`numpy.load(mmap_mode='r')`); the full pattern is sketched below.
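A minimal sketch of this build-then-barrier-then-mmap sequence, using the `numpy` and `torch.distributed` calls named above. The helper name and cache layout are illustrative, not Megatron's actual internals:

```python
import os

import numpy as np
import torch.distributed as dist

def load_or_build_index(cache_file: str, build_fn):
    """Illustrative helper; cache_file should end in '.npy'."""
    if dist.get_rank() == 0 and not os.path.exists(cache_file):
        np.save(cache_file, build_fn())        # rank 0 builds and writes
    dist.barrier()                             # all ranks wait for the cache
    return np.load(cache_file, mmap_mode="r")  # read-only, lock-free mmap
```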
After initialization, data access is read-only and lock-free. Each data-parallel rank consumes a disjoint subset of samples, and no cross-rank coordination is needed during training because all ranks derive the same deterministic permutation from a shared random seed.
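The lock-free behavior follows directly from that determinism. A conceptual sketch, not Megatron's actual sampler (whose sharding logic is more involved):

```python
import numpy as np

def shard_for_rank(num_samples: int, seed: int, dp_rank: int, dp_world: int):
    rng = np.random.RandomState(seed)      # identical seed on every rank
    order = rng.permutation(num_samples)   # same permutation everywhere
    return order[dp_rank::dp_world]        # disjoint strided shard, no comms
```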
## The Problem at 256+ Nodes
Three things break down at large node counts:

- **Barrier synchronization:** all ranks block while rank 0 builds indices. On a 512-node job, this means 4,095 GPUs sit idle.
- **Simultaneous memory-mapping:** all ranks `mmap` three large `.npy` files at once after the barrier, causing a burst of page faults and I/O.
- **Per-dataset metadata reads:** when blending many dataset prefixes, startup performs one metadata read per file prefix, which serializes and slows initialization.
## Baseline: Establish Maximum Achievable Performance
Before tuning data loading, establish a performance ceiling by running with `--mock-data`. This bypasses the data pipeline entirely and shows the maximum throughput your configuration can achieve without any dataloader overhead. The gap between `--mock-data` performance and real-data performance tells you exactly how much time the dataloader is costing you.
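As a quick way to quantify that gap, compare throughput between the two runs; the numbers below are purely hypothetical:

```python
# Hypothetical throughputs from a --mock-data run and a real-data run.
mock_tokens_per_s = 1.05e6   # ceiling with the data pipeline bypassed
real_tokens_per_s = 0.92e6   # same configuration with the real dataset

overhead = 1.0 - real_tokens_per_s / mock_tokens_per_s
print(f"dataloader overhead: {overhead:.1%}")  # -> 12.4%
```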
## Recommended Configuration

### Step 1: Consolidate dataset files
A common issue at scale is having datasets split across many small file prefixes. Thousands of 100 MB files perform significantly worse than tens of 10 GB+ files, both for building dataset caches and for runtime file access.
Use the merge tool to consolidate datasets stored as many small prefixes in one directory:
```bash
python tools/merge_datasets.py \
    --input /path/to/input-directory \
    --output-prefix /path/to/output/merged
```
Target at least 10 GB per file. This reduces the number of file descriptors, metadata lookups, and index-building work at initialization.
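To see how far your shards are from that target before merging, a quick check such as the following helps (the directory path is a placeholder):

```python
from pathlib import Path

# List .bin shard sizes, smallest first, to spot files far below ~10 GB.
shards = sorted(Path("/path/to/input-directory").glob("*.bin"),
                key=lambda p: p.stat().st_size)
for shard in shards:
    print(f"{shard.stat().st_size / 1e9:8.2f} GB  {shard.name}")
```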
### Step 2: Pre-build the dataset cache
Build the GPT dataset cache as a separate step before training. This avoids the usual “rank 0 builds, everyone else waits” startup path and is the recommended workflow for large jobs:
```bash
python tools/prepare_cache.py \
    --data-path <your-data-config> \
    --split 99,1,0 \
    --data-cache-path /path/to/cache \
    --global-batch-size <global-batch-size> \
    --seq-length <seq-length> \
    ...
```
If your later training job does not set `--global-batch-size`, or you are preparing the cache on a machine that does not match the future training topology, also pass:

`--prepare-cache-world-size <future-world-size>`
This keeps the prepared cache aligned with the sample counts expected by training.
### Step 3: Optionally pre-build per-dataset metadata

When blending many datasets, generate the `--per-dataset-sequences-path` JSON ahead of time to avoid one metadata read per file prefix at startup:
```bash
python tools/build_sequences_per_dataset.py \
    --data-path <your-data-config> \
    --per-dataset-sequences-path sequences.json
```
### Step 4: Launch training with optimized data loading
Once the cache is ready, enable the fast-path flags:
```bash
torchrun --nproc_per_node=8 --nnodes=512 ... pretrain_gpt.py \
    --dataloader-fast-cache-load \
    --dataloader-defer-npy-index-mmap \
    --per-dataset-sequences-path sequences.json \
    --data-cache-path /path/to/cache \
    --num-workers 2 \
    ...
```
## Flag reference
| Flag | Default | Recommendation | What it does |
|---|---|---|---|
| `--dataloader-fast-cache-load` | off | On | Skips the rank-0 barrier by assuming the cache already exists. All ranks build their dataset views in parallel. This is the single biggest win at scale. |
| `--dataloader-defer-npy-index-mmap` | off | On | Defers memory-mapping of the `.npy` index files so that index loading overlaps with training instead of stampeding at startup. |
| `--per-dataset-sequences-path` | None | Set when blending many datasets | Points to a JSON file mapping each dataset path to its number of sequences, avoiding one metadata read per file prefix at startup. |
| `--data-cache-path` | None | Set | Directory where index `.npy` files are written by rank 0 and read back by all other ranks; should be on a filesystem visible to every node. |
| `--num-workers` | 2 | Keep as small as necessary | Number of DataLoader worker processes. The goal is to satisfy: time to process a batch > time to prepare a batch. This hides dataloader work behind the training step. Increasing beyond what's needed wastes CPU and memory. See the sizing example after this table. |
| `--no-mmap-bin-files` | mmap on | Test both | Memory-mapping of `.bin` data files is the default; this flag switches to regular file reads. The optimal setting depends on your filesystem. |
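To make the `--num-workers` sizing rule concrete, here is a back-of-envelope calculation with made-up timings:

```python
import math

step_time_s = 0.85   # hypothetical time to process one batch (train step)
prep_time_s = 1.40   # hypothetical single-worker time to prepare one batch

# With N workers preparing batches in parallel, a batch becomes ready
# every prep_time_s / N seconds; keep that below the step time.
workers = math.ceil(prep_time_s / step_time_s)
print(workers)  # -> 2: smallest worker count that hides dataloader work
```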
## Object storage (S3 / Multi-Storage Client)
When data lives on S3 or MSC rather than a POSIX filesystem:
- Index files (`.idx`) are cached locally under `object_storage_cache_path`.
- Binary data files (`.bin`) are streamed on-demand in 256 MB chunks, avoiding the need to download entire files (sketched below).
- Set `--no-mmap-bin-files`, since memory-mapping doesn't apply to object storage.
- Ensure the index-cache path is visible wherever the later dataset construction will run.
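Conceptually, the 256 MB chunked streaming behaves like the sketch below; `fetch_chunk` stands in for a ranged object-store GET and is not the actual client API:

```python
CHUNK = 256 * 1024 * 1024  # streaming granularity: 256 MB blocks

def read_range(fetch_chunk, offset: int, length: int) -> bytes:
    """Read `length` bytes at `offset`, fetching only overlapping blocks."""
    out = bytearray()
    while length > 0:
        block_idx, within = divmod(offset, CHUNK)
        block = fetch_chunk(block_idx)       # e.g. one ranged GET per block
        take = min(length, CHUNK - within)
        out += block[within:within + take]
        offset += take
        length -= take
    return bytes(out)
```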
## Scaling Characteristics
| Aspect | Behavior | Why it works |
|---|---|---|
| Cross-rank contention | None after init | All index files are read-only; memory-mapped reads need no locks or cross-rank coordination |
| Sampling determinism | All ranks produce the same permutation | Shared random seed; every rank independently derives the same deterministic shuffle |
| Data-parallel sharding | Each DP rank gets a disjoint subset of samples | No overlap during training; assignment happens in the sampler rather than via extra dataset coordination |
| Index broadcast | Via shared filesystem, not collectives | Rank 0 writes the `.npy` index files once; every other rank reads them from the shared cache, so no collective broadcast is needed |
## Troubleshooting
**Symptom:** Training hangs at startup for minutes.

**Likely cause:** Rank 0 is building indices while all other ranks wait at the barrier.

**Fix:** Pre-build the cache with `tools/prepare_cache.py` and enable `--dataloader-fast-cache-load`.
**Symptom:** Spike in I/O at training start, then normal.

**Likely cause:** All ranks simultaneously memory-mapping index files after the barrier.

**Fix:** Enable `--dataloader-defer-npy-index-mmap` to overlap index loading with training.
**Symptom:** Slow data loading during training (not just startup).

- Run with `--mock-data` to confirm the dataloader is the bottleneck.
- If startup, not steady-state throughput, is the main issue, try `--dataloader-defer-npy-index-mmap`.
- If you are blending many dataset prefixes, try `--per-dataset-sequences-path`.
- Test with `--no-mmap-bin-files`; the optimal setting depends on your filesystem.