core.distributed.fsdp.src.megatron_fsdp.megatron_fsdp#

Module Contents#

Classes#

TrainingState

States of an FSDP parameter group, which are coupled with the sharding activity of parameters and gradients during training.

MegatronFSDP

Fully Sharded Data Parallel training.

RegisterFSDPBackwardFunction

Register a backward function that will be called after the backward pass of the model. This function is used to release the parameters after the backward pass.

Functions#

_replace_module_parameter

Replace a module’s parameter with a new parameter, preserving the hierarchy.

Data#

API#

core.distributed.fsdp.src.megatron_fsdp.megatron_fsdp.logger#

‘getLogger(…)’

class core.distributed.fsdp.src.megatron_fsdp.megatron_fsdp.TrainingState(*args, **kwds)#

Bases: enum.Enum

States of an FSDP parameter group, which are coupled with the sharding activity of parameters and gradients during training.

Initialization

FORWARD#

‘auto(…)’

PRE_BACKWARD#

‘auto(…)’

POST_BACKWARD#

‘auto(…)’

IDLE#

‘auto(…)’

class core.distributed.fsdp.src.megatron_fsdp.megatron_fsdp.MegatronFSDP(
module: torch.nn.Module,
dist_index: core.distributed.fsdp.src.megatron_fsdp.utils.FSDPDistributedIndex,
ddp_config: megatron.core.distributed.distributed_data_parallel_config.DistributedDataParallelConfig = None,
fsdp_unit_modules: Optional[List[torch.nn.Module] | List[str]] = None,
disable_bucketing: bool = False,
device: Optional[torch.device] = None,
calculate_per_token_loss: bool = False,
init_model_with_meta_device: bool = False,
sync_model_each_microbatch: bool = False,
keep_fp8_transpose_cache: bool = False,
nccl_ub: bool = False,
fsdp_double_buffer: bool = False,
disable_symmetric_registration: bool = False,
)#

Bases: torch.nn.Module

Fully Sharded Data Parallel training.

A distributed training wrapper that shards model parameters, gradients, and optimizer states across data-parallel workers. Integrates with MCore’s tensor and expert parallelism features, as well as with native PyTorch parallelism.

The following sharding modes are supported (see the configuration sketch after this list):

  • no_shard: Traditional data parallel training without parameter sharding.

  • optim: Shards optimizer states (and the main weights used for mixed-precision training); this is conceptually close to “ZeRO-1”. The optim_grads and optim_grads_params modes below also shard main weights during mixed-precision training, even though this is not spelled out in their descriptions.

  • optim_grads: Shards gradients and optimizer states; this is conceptually close to “ZeRO-2”.

  • optim_grads_params: Shards parameters, gradients, and optimizer states; this is conceptually close to “ZeRO-3”.
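The sharding mode is typically selected through the DistributedDataParallelConfig passed as ddp_config. A minimal sketch, assuming the field names data_parallel_sharding_strategy, overlap_param_gather, and overlap_grad_reduce (verify these against your Megatron-Core version):

from megatron.core.distributed.distributed_data_parallel_config import (
    DistributedDataParallelConfig,
)

# Assumed field names; check them against your Megatron-Core version.
ddp_config = DistributedDataParallelConfig(
    # One of "no_shard", "optim", "optim_grads", "optim_grads_params"
    data_parallel_sharding_strategy="optim_grads_params",
    overlap_param_gather=True,  # overlap parameter all-gather with compute
    overlap_grad_reduce=True,   # overlap gradient reduce-scatter with compute
)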

Key Features:

  • Compatible with MCore’s tensor, context and expert parallelism

  • Compatible with Native PyTorch’s tensor and context parallelism with DTensor

  • Automatic mixed precision training (BF16/FP8)

  • Gradient accumulation and bucketing

  • Optimized activation recompute with shard-aware communication: When recomputing a whole Transformer layer, gather parameters once for both the recomputation and backward computation

  • Compatible with MCore’s distributed checkpointing as well as native PyTorch checkpointing.

Parameters:
  • module (torch.nn.Module) – Underlying Torch Module.

  • dist_index (FSDPDistributedIndex) – FSDPDistributedIndex object containing references to the process groups and device meshes used by Megatron-FSDP.

  • ddp_config (DistributedDataParallelConfig) – FullyShardedDataParallel configuration dataclass containing a variety of Megatron-derived parameters that control the behavior of Megatron-FSDP.

  • fsdp_unit_modules (List[torch.nn.Module] | List[str]) – List of modules that should be treated as an FSDP Unit, i.e. the minimum releasable model unit. It affects the granularity of the communication parameter grouping and triggers aggregate collective communication in FP8 mixed precision training.

  • device (torch.device) – Target device for the sharded model. Used to migrate all model parameters to an expected device. If init_model_with_meta_device=True, this argument is ignored.

  • init_model_with_meta_device (bool) – Whether to initialize model parameters in shards across all devices of the fsdp_group. Utilized to initialize large models that do not fit on a single device.

  • sync_model_each_microbatch (bool) – Whether to sync parameters and install gradients on each training step. When disabled, Megatron-FSDP will overlap reduce-scatter with subsequent compute and delay HSDP gather and reduce operations until the end of the optimization cycle, which improves performance and throughput with delayed optimization strategies such as gradient accumulation. Defaults to True. Can be modified before the model forward / backward pass via MegatronFSDP.set_model_auto_sync(bool), or controlled with the (no_)sync context managers or the microbatch_count and is_last_microbatch attributes.

  • disable_bucketing – If true, force assign all parameters to a single bucket. If false, use standard bucketing policy: assign parameters to smaller buckets and all-reduce per bucket.

  • keep_fp8_transpose_cache (bool) – Whether to keep the fp8 transpose cache when using Megatron-FSDP. It will use significantly more GPU memory but can improve performance.

  • nccl_ub (bool) – Whether to use NCCL userbuffers for the FSDP communication operands, which uses fewer SMs and results in better overlap with computation. This flag automatically sets fsdp_double_buffer to True, which uses additional GPU memory.

  • fsdp_double_buffer (bool) – Whether to use persistently allocated double buffers for the temporary memory needed in the FSDP communication. This flag is automatically set to True when nccl_ub is True.

  • disable_symmetric_registration (bool) – Whether to disable symmetric (window) registration for NCCL userbuffer registration. When nccl_ub is set, this option forces conventional (local) userbuffer registration.

Examples

>>> model = GPTModel(config)
>>> model = MegatronFSDP(
...     model,
...     dist_index,
...     ddp_config,
...     fsdp_unit_modules = [TransformerLayer, LanguageModelEmbedding],
...     device=torch.device(f"cuda:{torch.cuda.current_device()}"),
...     init_model_with_meta_device=False,
...     disable_bucketing=False,
...     keep_fp8_transpose_cache=False,
...     nccl_ub=False,
...     fsdp_double_buffer=False,
...     disable_symmetric_registration=False,
... )

Initialization

_check_module_parameter_types()#

Check if the module parameters include special parameters such as Megatron-Core Expert Parallel (EP/EXPT) parameters.

_init_fsdp_param_and_grad_buffer()#
_import_class_from_path(class_path: str)#

Helper function to import classes from string paths.

all_gather_and_wait_parameters_ready(
params,
prefetch=True,
prefetch_order=PrefetchOrder.FORWARD_PASS_ORDER,
wait_bucket_ready=True,
)#

All-gather parameters across the data parallel group and wait for the all-gather operation to complete.

_register_fsdp_hooks(root_module)#

Register necessary hooks for Fully Sharded Data Parallel (FSDP) execution on the model.

This function sets up various hooks required for FSDP operations, including parameter resharding/unsharding and gradient handling (a general sketch of this hook pattern follows the note below). The registered hooks are:

  • Pre-forward hook: Unshards parameters before the forward pass

  • Post-forward hook: Reshards parameters after the forward pass

  • Pre-backward hook: Unshards parameters before the backward pass

  • Post-backward hook: Reshards parameters and reduces gradients after the backward pass

Parameters:

root_module – The PyTorch module to register FSDP hooks on

Note:

These hooks are essential for FSDP’s memory efficiency as they manage:

  1. Dynamic parameter sharding/unsharding to reduce memory footprint

  2. Proper gradient synchronization across distributed processes

  3. Gradient accumulation for large batch training

Returns:

None
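For orientation, here is a minimal self-contained sketch of this hook pattern built on PyTorch’s standard hook APIs; it illustrates the idea only and is not Megatron-FSDP’s actual implementation:

import torch
import torch.nn as nn

def _pre_forward(module, args):
    # placeholder: all-gather (unshard) this module's parameters before its forward pass
    pass

def _post_forward(module, args, output):
    # placeholder: release (reshard) the full parameters once the forward pass is done
    return output

root_module = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
for submodule in root_module.modules():
    submodule.register_forward_pre_hook(_pre_forward)
    submodule.register_forward_hook(_post_forward)

root_module(torch.randn(2, 8))  # hooks fire around each submodule's forward call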

no_sync()#

Context manager that turns off gradient synchronization. Note that in gradient-sharding modes, some gradient synchronization still always takes place.

sync()#

Context manager that synchronizes the MegatronFSDP model parameters and gradients every training step as opposed to every optimization cycle.
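A hypothetical gradient-accumulation loop using these context managers; model is a MegatronFSDP instance, while batches, loss_fn, and optimizer are placeholders:

# Skip synchronization on intermediate microbatches; synchronize on the last one.
for i, batch in enumerate(batches):
    is_last_microbatch = (i == len(batches) - 1)
    with (model.sync() if is_last_microbatch else model.no_sync()):
        loss = loss_fn(model(batch))
        loss.backward()

model.finish_grad_sync()  # resolve any outstanding gradient reductions
optimizer.step()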

set_model_auto_sync(sync_model: bool = True)#

Activate or deactivate flag that controls Megatron-FSDP model synchronization. When activated, the model parameters and gradients will be synchronized EVERY training step, i.e. gradient reduction will be waited upon instead of overlapped with subsequent compute, and all-gather + reduce operations across the DP-Outer ProcessGroup will be executed when sharding on DP-Outer during HSDP / HFSDP. Otherwise, MegatronFSDP will perform such synchronizations every optimization cycle depending on is_last_microbatch = True or microbatch_count = 0, which are more flexible but difficult to manage, e.g. microbatch_count and is_last_microbatch can be modified elsewhere for custom training strategies.

Will commonly be called on the final microbatch of a training step, before the model forward pass and gradient backward pass, to ensure that the model gradients (prior to optimizer.step()) and model parameters (prior to distributed checkpointing) are synchronized and representative of the model trained at that particular training step. Keeping MegatronFSDP.model_auto_sync = True for every microbatch will slightly degrade training performance.

Parameters:

sync_model (bool, optional) – Whether to synchronize the model every training step. MegatronFSDP.model_auto_sync will be set to the value of sync_model. Defaults to True. MegatronFSDP.model_auto_sync defaults to False.
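A hypothetical usage pattern, with batches and loss_fn as placeholders:

# Enable full synchronization only on the final microbatch of the training step.
for i, batch in enumerate(batches):
    model.set_model_auto_sync(i == len(batches) - 1)
    loss_fn(model(batch)).backward()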

get_distributed_index() core.distributed.fsdp.src.megatron_fsdp.utils.FSDPDistributedIndex#

Get the distributed environment of Megatron-FSDP, which contains references to the process groups and device meshes used by Megatron-FSDP.

start_param_sync(
*unused,
force_sync: bool = False,
force_dispatch: bool = False,
)#

Initiates param sync (all-gather) communication operations for all model parameters.

By default, when overlap_param_gather is set to True, dispatches asynchronous communication calls; when overlap_param_gather is set to False, calls synchronous communication ops. This default behavior can be overridden using the flags below.

Parameters:
  • force_sync (bool, optional) – force synchronous collective regardless of other settings.

  • force_dispatch (bool, optional) – force dispatch regardless of other settings.

start_grad_sync(*unused)#

Initiates grad sync (all-reduce or reduce-scatter) communication operations for all model gradients.

When overlap_grad_reduce is set to True, dispatches asynchronous communication calls. When overlap_grad_reduce is set to False, calls synchronous communication ops.

synchronize_param_gather()#

Synchronize parameter all-gather operations for all model parameters.

synchronize_gradient_reduce()#

Synchronize gradient reduce-scatter operations for all model gradients.

attach_grad_to_optimizer_state()#

Attach gradients to optimizer named parameters in preparation for optimizer.step().

finish_grad_sync()#

Finishes grad sync (all-reduce or reduce-scatter) communication operations for all model gradients. Call prior to the optimization step to resolve asynchronous gradient reductions.

When overlap_grad_reduce is set to True, waits for asynchronous communication calls to complete. When overlap_grad_reduce is set to False, calls synchronous communication ops.

_replace_param_with_distributed_if_needed()#
_replace_param_with_raw_if_needed()#
_reestablish_shared_weights(old_params, new_params)#

Reestablishes shared (tied) weights in a PyTorch module after parameter replacement.

When iterating over named_parameters(), PyTorch skips parameters that are shared via weight-tying (e.g., lm_head.weight referencing wte.weight). After replacing parameters, these shared weights become independent, causing previously hidden parameters to appear in the parameter list. This function restores the original shared structure by ensuring parameters that were previously tied remain shared.

Parameters:
  • old_params (dict) – Mapping from parameter names to original parameter tensors.

  • new_params (dict) – Mapping from parameter names to new parameter tensors.
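For illustration, a minimal self-contained sketch of weight tying and of re-establishing the tie after a parameter is replaced (placeholder modules, not the library code):

import torch.nn as nn

emb = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
head.weight = emb.weight  # tie: the head shares the embedding weight

# Replacing the embedding weight breaks the tie unless the head is pointed
# back at the same new Parameter object.
new_weight = nn.Parameter(emb.weight.detach().clone())
emb.weight = new_weight
head.weight = new_weight  # re-establish the shared structure
assert emb.weight is head.weight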

scale_gradients(scaling_factor: float)#

Scale all gradients inside the buffers by scaling_factor.

zero_grad_buffer()#

Zeros out all grad buffers. Needs to be called at the beginning of each training iteration alongside optimizer.zero_grad().

install_optimized_model_weights()#

Copies optimized parameter values into the model training parameters managed by Megatron-FSDP. Should be called after the optimizer.step().
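Taken together, a hedged sketch of where these calls sit within one training iteration; the loop structure and the names batches, loss_fn, and optimizer are assumptions, while the method names are the ones documented on this page:

model.zero_grad_buffer()
optimizer.zero_grad()

for batch in batches:
    loss_fn(model(batch)).backward()

model.finish_grad_sync()                 # resolve asynchronous gradient reductions
optimizer.step()
model.install_optimized_model_weights()  # copy optimized values back into model params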

broadcast_params()#

Syncs parameters across all DP ranks.

forward(*inputs, **kwargs)#

Wrapped forward pass of the model managed by FSDP.

class core.distributed.fsdp.src.megatron_fsdp.megatron_fsdp.RegisterFSDPBackwardFunction#

Bases: torch.autograd.Function

Register a backward function that will be called after the backward pass of the model. This function is used to release the parameters after the backward pass.

static forward(ctx, post_backward, *inputs: torch.Tensor)#

Forward pass of the RegisterFSDPBackwardFunction function.

static backward(ctx, *grads: torch.Tensor)#

Backward pass of the RegisterFSDPBackwardFunction function.
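A minimal sketch of this general pattern using only standard torch.autograd.Function machinery; it illustrates the idea of running a callback after the backward pass and is not the library’s exact implementation:

import torch

class _PostBackwardHook(torch.autograd.Function):
    @staticmethod
    def forward(ctx, post_backward, *inputs):
        ctx.post_backward = post_backward
        return inputs

    @staticmethod
    def backward(ctx, *grads):
        ctx.post_backward()  # e.g. release unsharded parameters here
        return (None, *grads)

x = torch.randn(3, requires_grad=True)
(y,) = _PostBackwardHook.apply(lambda: print("post-backward callback fired"), x)
y.sum().backward()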

core.distributed.fsdp.src.megatron_fsdp.megatron_fsdp._replace_module_parameter(module, name, new_param)#

Replace a module’s parameter with a new parameter, preserving the hierarchy.
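As a hedged illustration of the idea (not the actual implementation), replacing a parameter addressed by a dotted name while preserving the module hierarchy might look like the following; replace_module_parameter is a hypothetical stand-in:

import torch
import torch.nn as nn

def replace_module_parameter(module, name, new_param):
    # Walk the dotted name to the submodule that owns the parameter,
    # then reassign the parameter on that submodule under the same name.
    *path, leaf = name.split(".")
    owner = module
    for attr in path:
        owner = getattr(owner, attr)
    setattr(owner, leaf, new_param)

model = nn.Sequential(nn.Linear(4, 4))
replace_module_parameter(model, "0.weight", nn.Parameter(torch.zeros(4, 4)))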