Date: 2025-09-25

The custom Fully Sharded Data Parallelism (FSDP) implementation in Megatron-Core is specifically designed to optimize memory consumption and performance for large language models. The core design principles include:

Optimized for Large Language Models: This custom FSDP implementation is tailored to efficiently scale with models containing billions of parameters, ensuring seamless execution and training of massive models.

Efficient Memory Consumption: By strategically sharding optimizer states, gradients, and model parameters, the custom FSDP significantly reduces memory usage. This approach enables the training of models that would otherwise be too large to fit in memory.

Efficient Workflow & Overlapping Communication and Computation: The implementation is engineered to minimize the number of communication steps required during training. It maximizes the overlap between communication and computation, thereby enhancing overall training efficiency and reducing latency.

Support for MCore’s Efficient Training Methods: The custom FSDP seamlessly integrates with Megatron-Core’s advanced parallelism techniques, including tensor parallelism, expert parallelism and context parallelism. Additionally, it supports automatic mixed precision training, further optimizing training performance and efficiency.

The design of Custom FSDP draws inspiration from PyTorch FSDP (Zhao, Yanli, et al.) and MCore’s distributed optimizer. The following introduction to PyTorch FSDP is included to clarify the concepts underlying the custom FSDP design.

In DistributedDataParallel (DDP) training, each process (worker) owns a replica of the model and processes its own batch of data, then uses all-reduce to sum the gradients across workers. In DDP, the model weights and optimizer states are replicated across all workers. FSDP is a type of data parallelism that shards model parameters, optimizer states, and gradients across DDP ranks.

When training with FSDP, the GPU memory footprint on every worker is smaller than when training with DDP. This makes the training of some very large models feasible by allowing larger models or batch sizes to fit on device. The saving comes at the cost of increased communication volume; this overhead is reduced by internal optimizations such as overlapping communication with computation.
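As a rough illustration of the memory saving, the sketch below estimates the per-GPU memory needed for model states under DDP and under full sharding. It assumes mixed-precision Adam at 16 bytes per parameter (2 for a bf16 weight, 2 for a bf16 gradient, 12 for the fp32 master weight and the two Adam moments) and ignores activations, temporarily gathered parameters, and buffer padding beyond a simple ceiling division. The function names are illustrative and are not part of Megatron-Core.

```python
# Assumed byte cost per parameter: bf16 weight + bf16 grad + fp32 Adam states.
BYTES_PER_PARAM = 2 + 2 + 12


def ddp_model_state_bytes(num_params: int) -> int:
    """DDP: every rank holds a full replica of weights, grads and optimizer states."""
    return num_params * BYTES_PER_PARAM


def fsdp_model_state_bytes(num_params: int, dp_size: int) -> int:
    """Full sharding: each rank holds about 1/dp_size of every model state."""
    shard_numel = -(-num_params // dp_size)  # ceiling division
    return shard_numel * BYTES_PER_PARAM


num_params = 7_000_000_000
ddp_gb = ddp_model_state_bytes(num_params) / 1e9
fsdp_gb = fsdp_model_state_bytes(num_params, dp_size=8) / 1e9
```

Under these assumptions, a 7B-parameter model needs 112 GB of model states per GPU with DDP and 14 GB per GPU with 8-way sharding. Real numbers depend on the precision recipe and on how many units are gathered at once.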

Notice that the unit processed in the workflow here is the “FSDP instance 1: N layers”, where an FSDP instance is the smallest FSDP processing unit (also a PyTorch module). This means that we can safely release the module’s weights after using them (that is, after executing the forward or backward pass of this module), because no other computation relies on these weights. This capability is the foundation of FSDP’s layer-by-layer execution and memory-saving strategy. An FSDP instance is also referred to as an FSDP Unit.

It is worth noting that an FSDP instance can correspond to multiple FSDP parameter groups. These groups are separated by Data Parallel (DP) communication groups and the data type of the parameter or gradient. Consequently, an FSDP instance may require several parameter-gather tasks before execution (forward or backward). Each FSDP parameter group corresponds to one Data Parallel Buffer in custom FSDP.
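The grouping rule above can be sketched as follows. This is a minimal illustration, not the Megatron-Core data structure: `ParamInfo`, `build_param_groups`, and the group labels `"dp"` and `"expert_dp"` are hypothetical names chosen for the example. Each resulting group would map to one Data Parallel Buffer and one parameter-gather task.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class ParamInfo:
    name: str
    dtype: str     # e.g. "bf16" or "fp32"
    dp_group: str  # hypothetical label of the DP communication group
    numel: int


def build_param_groups(params):
    """Bucket the parameters of one FSDP unit by (dp_group, dtype)."""
    groups = defaultdict(list)
    for p in params:
        groups[(p.dp_group, p.dtype)].append(p)
    return dict(groups)


# Parameters of a single hypothetical FSDP unit.
unit_params = [
    ParamInfo("attn.qkv.weight", "bf16", "dp", 3 * 1024 * 1024),
    ParamInfo("attn.norm.weight", "fp32", "dp", 1024),
    ParamInfo("moe.expert0.fc1.weight", "bf16", "expert_dp", 4 * 1024 * 1024),
    ParamInfo("attn.proj.weight", "bf16", "dp", 1024 * 1024),
]

groups = build_param_groups(unit_params)
# One parameter-gather task per group is needed before the unit can run.
num_gather_tasks = len(groups)
```

In this example the unit has three parameter groups, so three gathers are issued before its forward or backward pass.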

At a high level, FSDP works as follows:

In constructor:

- Shard model parameters so that each rank keeps only its own shard

In forward path:

- Run all_gather to collect all shards from all ranks and recover the full parameters of this FSDP unit
- Run forward computation
- Discard the parameter shards just collected

In backward path:

- Run all_gather to collect all shards from all ranks and recover the full parameters of this FSDP unit
- Run backward computation
- Run reduce_scatter to sync gradients
- Discard parameters
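The steps above can be sketched with a single-process simulation in which plain Python lists stand in for tensors and a list of lists stands in for the ranks. This is an illustrative toy, not the PyTorch or Megatron-Core implementation: the helper names are hypothetical, the collectives are simulated rather than real `torch.distributed` calls, and the "model" is an elementwise weight vector with loss 0.5 * sum((w * x)^2) so that the backward pass actually needs the gathered weights.

```python
def shard(full, world_size):
    """Constructor step: split a flat list into equal contiguous shards."""
    assert len(full) % world_size == 0, "sketch assumes no padding is needed"
    n = len(full) // world_size
    return [full[r * n:(r + 1) * n] for r in range(world_size)]


def all_gather(shards):
    """Simulated collective: every rank receives all shards concatenated."""
    full = [v for s in shards for v in s]
    return [list(full) for _ in shards]


def reduce_scatter(per_rank_full):
    """Simulated collective: sum elementwise, then rank r keeps shard r."""
    summed = [sum(vals) for vals in zip(*per_rank_full)]
    return shard(summed, len(per_rank_full))


def fsdp_unit_step(weight_shards, inputs):
    """One forward and backward pass of a single FSDP unit on all ranks."""
    ranks = range(len(weight_shards))

    # Forward path: gather, compute, discard the gathered copy.
    full_w = all_gather(weight_shards)
    losses = [0.5 * sum((w * x) ** 2 for w, x in zip(full_w[r], inputs[r]))
              for r in ranks]
    del full_w

    # Backward path: gather again, compute local grads, reduce-scatter, discard.
    full_w = all_gather(weight_shards)
    local_grads = [[w * x ** 2 for w, x in zip(full_w[r], inputs[r])]
                   for r in ranks]  # d/dw of 0.5 * (w * x)^2 is w * x^2
    grad_shards = reduce_scatter(local_grads)
    del full_w
    return losses, grad_shards


weight_shards = shard([1, 2, 3, 4], world_size=2)   # [[1, 2], [3, 4]]
inputs = [[1, 1, 1, 1], [2, 0, 1, 1]]               # one batch per rank
losses, grad_shards = fsdp_unit_step(weight_shards, inputs)
```

After the step, each simulated rank holds only its weight shard and the matching shard of the summed gradient, which is what the optimizer step needs.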

One way to view FSDP’s sharding is to decompose the DDP gradient all-reduce into reduce-scatter and all-gather. Specifically, during the backward pass, FSDP reduces and scatters gradients, ensuring that each rank possesses a shard of the gradients. Then it updates the corresponding shard of the parameters in the optimizer step. Finally, in the subsequent forward pass, it performs an all-gather operation to collect and combine the updated parameter shards.
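The decomposition can be checked with a small single-process simulation. The sketch below uses plain Python lists in place of tensors and hypothetical helper names rather than real `torch.distributed` calls; the SGD-style update and the learning rate are assumptions made for the example. It shows that reduce-scatter followed by all-gather matches all-reduce, and that updating shards and then gathering reproduces the DDP weight update.

```python
def all_reduce(per_rank):
    """Every rank ends up with the elementwise sum over all ranks."""
    summed = [sum(vals) for vals in zip(*per_rank)]
    return [list(summed) for _ in per_rank]


def reduce_scatter(per_rank):
    """Sum elementwise, then rank r keeps only the r-th contiguous shard."""
    world_size = len(per_rank)
    summed = [sum(vals) for vals in zip(*per_rank)]
    n = len(summed) // world_size
    return [summed[r * n:(r + 1) * n] for r in range(world_size)]


def all_gather(shards):
    """Every rank receives the concatenation of all shards."""
    full = [v for s in shards for v in s]
    return [list(full) for _ in shards]


grads = [[1, 2, 3, 4], [10, 20, 30, 40]]  # local gradients on 2 ranks
weights = [100.0, 100.0, 100.0, 100.0]
lr = 0.5

# DDP view: all-reduce the gradients, then update the full replica.
ddp_weights = [w - lr * g for w, g in zip(weights, all_reduce(grads)[0])]

# FSDP view: reduce-scatter in backward, update the local shard in the
# optimizer step, all-gather the updated shards in the next forward.
grad_shards = reduce_scatter(grads)
weight_shards = [weights[0:2], weights[2:4]]
updated_shards = [[w - lr * g for w, g in zip(ws, gs)]
                  for ws, gs in zip(weight_shards, grad_shards)]
fsdp_weights = all_gather(updated_shards)[0]
```

Both views yield the same updated weights; the difference is that FSDP keeps only one shard of the gradients and optimizer state on each rank between the two collectives.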