core.distributed.fsdp.src.megatron_fsdp.fully_shard#

Module Contents#

Classes#

ShardingStrategy

IntEnum to track the abbreviated sharding strategy for Megatron-FSDP.

Functions#

fully_shard_model

Fully-shard the model for Megatron-FSDP.

fully_shard_optimizer

Fully shard the optimizer for Megatron-FSDP. This is an in-place operation on the optimizer instance, which modifies the optimizer to call methods exposed by the MegatronFSDP model API.

fully_shard

Fully shard the model and the optimizer for Megatron-FSDP.

Data#

API#

core.distributed.fsdp.src.megatron_fsdp.fully_shard.logger#

‘getLogger(…)’

class core.distributed.fsdp.src.megatron_fsdp.fully_shard.ShardingStrategy#

Bases: enum.IntEnum

IntEnum to track the abbreviated sharding strategy for Megatron-FSDP.

  • 0 or no_shard implies that your model is not sharded. Similar memory usage to DDP.

  • 1 or optim implies that your optimizer state is sharded. Similar to optimizer state sharding in ZeRO-DP.

  • 2 or optim_grads implies that your optimizer state and gradients are sharded. Similar to optimizer state and gradient sharding in ZeRO-2.

  • 3 or optim_grads_params implies that your optimizer state, gradients, and training parameters are sharded. Similar to optimizer state, gradient, and training parameter sharding in ZeRO-3.

Initialization

Initialize self. See help(type(self)) for accurate signature.

NO_SHARD#

0

OPTIM#

1

OPTIM_GRADS#

2

OPTIM_GRADS_PARAMS#

3
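As a rough illustration (not part of the generated reference), the enum members above map one-to-one onto the integer and string forms accepted by zero_dp_strategy. The import path below mirrors the documented module name and may need adjusting to how megatron_fsdp is installed in your environment.

```python
# Sketch only; import path assumed from the module name documented above.
from core.distributed.fsdp.src.megatron_fsdp.fully_shard import ShardingStrategy

assert ShardingStrategy.NO_SHARD == 0
assert ShardingStrategy.OPTIM == 1
assert ShardingStrategy.OPTIM_GRADS == 2
assert ShardingStrategy.OPTIM_GRADS_PARAMS == 3

# Any of these spellings selects ZeRO-3-style sharding in fully_shard(...):
zero_dp_strategy = ShardingStrategy.OPTIM_GRADS_PARAMS  # or 3, or "optim_grads_params"
```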

core.distributed.fsdp.src.megatron_fsdp.fully_shard.fully_shard_model(
module: torch.nn.Module,
device_mesh: torch.distributed.DeviceMesh,
dp_shard_dim: str,
dp_outer_dim: Optional[str] = None,
tp_dim: Optional[str] = None,
hybrid_fsdp_group: Optional[torch.distributed.ProcessGroup] = None,
expt_device_mesh: Optional[torch.distributed.DeviceMesh] = None,
fsdp_unit_modules: Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]] = None,
zero_dp_strategy: str | int = 3,
outer_dp_sharding_strategy: str | int = 0,
device: Optional[torch.device] = None,
init_model_with_meta_device: bool = False,
grad_reduce_in_fp32: bool = False,
preserve_fp32_weights: bool = True,
overlap_grad_reduce: bool = True,
overlap_param_gather: bool = True,
sync_model_each_microbatch: bool = True,
preproc_state_dict_for_dcp_ckpt: bool = True,
check_for_nan_in_grad: bool = True,
average_in_collective: bool = False,
disable_bucketing: bool = False,
calculate_per_token_loss: bool = False,
keep_fp8_transpose_cache: bool = False,
nccl_ub: bool = False,
fsdp_double_buffer: bool = False,
disable_symmetric_registration: bool = False,
)#

Fully-shard the model for Megatron-FSDP.

Parameters:

A subset of the arguments accepted by fully_shard; see fully_shard below for the full parameter descriptions.

Returns:

The wrapped Megatron-FSDP model configured for FSDP.

Return type:

MegatronFSDP
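A minimal usage sketch follows. Assumptions not in the reference above: an 8-GPU single-node run launched with torchrun, MyTransformer and MyTransformerLayer as hypothetical stand-ins for your own modules, and an import path that mirrors the documented module name.

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical import path mirroring the documented module name.
from core.distributed.fsdp.src.megatron_fsdp.fully_shard import fully_shard_model

dist.init_process_group(backend="nccl")  # assumes a torchrun launch

# Single flat data-parallel dimension named "dp_shard" over 8 ranks.
device_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp_shard",))

model = MyTransformer()  # hypothetical torch.nn.Module

fsdp_model = fully_shard_model(
    module=model,
    device_mesh=device_mesh,
    dp_shard_dim="dp_shard",
    fsdp_unit_modules=[MyTransformerLayer],  # hypothetical FSDP unit class
    zero_dp_strategy="optim_grads_params",   # ZeRO-3-style sharding
    device=torch.device("cuda"),
)
```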

core.distributed.fsdp.src.megatron_fsdp.fully_shard.fully_shard_optimizer(
model: core.distributed.fsdp.src.megatron_fsdp.megatron_fsdp.MegatronFSDP,
optimizer: torch.optim.Optimizer,
preproc_state_dict_for_dcp_ckpt: bool = True,
)#

Fully shard the optimizer for Megatron-FSDP. This is an in-place operation on the optimizer instance, which modifies the optimizer to call methods exposed by the MegatronFSDP model API.

Parameters:
  • model (MegatronFSDP) – The Megatron-FSDP model to be fully sharded.

  • optimizer (torch.optim.Optimizer) – The optimizer to be fully sharded.

  • preproc_state_dict_for_dcp_ckpt (bool) – Whether to preprocess the state dict for DCP checkpointing. Defaults to True.
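For example, a sketch assuming fsdp_model is the MegatronFSDP module returned by fully_shard_model above, with the import path mirroring the documented module name:

```python
import torch

# Hypothetical import path mirroring the documented module name.
from core.distributed.fsdp.src.megatron_fsdp.fully_shard import fully_shard_optimizer

# `fsdp_model` is assumed to be the MegatronFSDP module from fully_shard_model(...).
optimizer = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-4)

# In-place operation: the optimizer instance is modified to call methods
# exposed by the MegatronFSDP model API during the training loop.
fully_shard_optimizer(
    model=fsdp_model,
    optimizer=optimizer,
    preproc_state_dict_for_dcp_ckpt=True,
)
```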

core.distributed.fsdp.src.megatron_fsdp.fully_shard.fully_shard(
module: torch.nn.Module,
optimizer: torch.optim.Optimizer,
device_mesh: torch.distributed.DeviceMesh,
dp_shard_dim: str,
dp_outer_dim: Optional[str] = None,
tp_dim: Optional[str] = None,
hybrid_fsdp_group: Optional[torch.distributed.ProcessGroup] = None,
expt_device_mesh: Optional[torch.distributed.DeviceMesh] = None,
fsdp_unit_modules: Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]] = None,
zero_dp_strategy: str | int = 3,
outer_dp_sharding_strategy: str | int = 0,
device: Optional[torch.device] = None,
init_model_with_meta_device: bool = False,
grad_reduce_in_fp32: bool = False,
preserve_fp32_weights: bool = True,
overlap_grad_reduce: bool = True,
overlap_param_gather: bool = True,
sync_model_each_microbatch: bool = True,
preproc_state_dict_for_dcp_ckpt: bool = True,
check_for_nan_in_grad: bool = True,
average_in_collective: bool = False,
disable_bucketing: bool = False,
calculate_per_token_loss: bool = False,
keep_fp8_transpose_cache: bool = False,
nccl_ub: bool = False,
fsdp_double_buffer: bool = False,
disable_symmetric_registration: bool = False,
) tuple[core.distributed.fsdp.src.megatron_fsdp.megatron_fsdp.MegatronFSDP, torch.optim.Optimizer]#

Fully shard the model and the optimizer for Megatron-FSDP.

Wraps the model as a Megatron-FSDP module and modifies the optimizer to be compatible with the Megatron-FSDP training strategy.

Parameters:
  • module (torch.nn.Module) – The PyTorch module fully-sharded and managed by Megatron-FSDP.

  • optimizer (torch.optim.Optimizer) – (Distributed) optimizer for training the model, which is extended to automatically execute necessary Megatron-FSDP operations during the training loop. If not provided, the user is expected to utilize fully_shard_optimizer() or the MegatronFSDP API to manually configure the model for optimization. Defaults to None.

  • device_mesh (DeviceMesh) – Device mesh object defining the topology for distributed training.

  • dp_shard_dim (str) – Name of the data parallel sharding sub-mesh in the device_mesh. Supports a flattened DP-CP sub-mesh, in which case parameters, gradients, and optimizer state will be sharded across both DP and CP ranks. Required to enable the core functionality of Megatron-FSDP.

  • dp_outer_dim (Optional[str]) – Name of the “outer” DP sub-mesh in the device_mesh for hybrid-sharding (HSDP), which supports “DP-Replicate” as well as optimizer state sharding (HFSDP). Defaults to None. Required for HSDP, which is enabled by this argument.

  • tp_dim (Optional[str]) – Name of the tensor parallel sub-mesh in the device_mesh, which is necessary for strided sharding between TP and FSDP (and fully-sharded HSDP) dimensions. Defaults to None. Required if TP is used in the model, or if TransformerEngine layers are utilized, as TE defaults to “TP=1”.

  • hybrid_fsdp_group (Optional[torch.distributed.ProcessGroup]) – Cumulative data parallel process group for hybrid FSDP that can be manufactured by flattening the outer-FSDP (dp_outer_dim) and FSDP (dp_shard_dim) process groups or sub-meshes. Defaults to None. Required for HSDP, i.e. if dp_outer_dim is not None.

  • expt_device_mesh (Optional[DeviceMesh]) – Expert parallel device mesh object defining the topology for MoE distributed training.

  • fsdp_unit_modules (Optional[Sequence[Type[torch.nn.Module]] | Sequence[str]]) – List of (sub-)module classes or (sub-)module class import paths that are “units”, which are torch.nn.Module(s) that are sharded and scheduled by Megatron-FSDP. In particular, FSDP unit module parameters can be “safely” deallocated after the forward() or backward() pass without interfering with other computational operations that rely on those parameters in the complete PyTorch model. This information is utilized by Megatron-FSDP to optimally shard, gather, and overlap communications during the forward and backward pass of the module. Defaults to None, which is peak-memory-equivalent to DDP / “no_shard”.

  • zero_dp_strategy (str | int) –

    Zero-redundancy sharding strategy for sharding data parallel parameters and gradients.

    • "no_shard" / 0: No optimizer, gradient, or parameter sharding. Similar memory usage to DDP.

    • "optim" / 1: Shards optimizer states (and main weights for mixed-precision training), which is conceptually similar to optimizer state sharding in ZeRO-DP.

    • "optim_grads" / 2: Shards gradients and optimizer states, which is conceptually similar to "ZeRO-2".

    • "optim_grads_params" / 3: Shards parameters, gradients, and optimizer states, which is conceptually similar to "ZeRO-3".

    Defaults to "optim_grads_params" / 3.

  • outer_dp_sharding_strategy (str | int) – Sharding strategy for the outer data parallel group in Hybrid Sharded Data Parallel (HSDP). Shares the same semantics as zero_dp_strategy, but only "no_shard" / 0 (DP replication) and "optim" / 1 (optimizer state hybrid sharding) are supported, and "optim" / 1 is only supported when zero_dp_strategy="optim_grads_params". This option is only effective when HSDP is enabled, i.e. when dp_outer_dim is not None. Defaults to "no_shard" / 0, which replicates model parameters across the dp_outer group.

  • device (Optional[torch.device]) – Target device for the sharded model. Used to migrate all parameters in the model to an expected device. If init_model_with_meta_device=True, this argument is ignored. Defaults to None.

  • init_model_with_meta_device (bool) – Whether to initialize the model on the meta device. Utilized to initialize large models that do not fit on a single device, and requires implementing a custom Module.reset_parameters() or Module._reset_parameters() method. Defaults to False.

  • grad_reduce_in_fp32 (bool) – Whether to perform gradient reduction in FP32. Defaults to False.

  • preserve_fp32_weights (bool) – Whether to preserve FP32 optimization weights. Defaults to True.

  • overlap_grad_reduce (bool) – Whether to overlap gradient reduce-scatter (or all-reduce) with backward compute. Defaults to True.

  • overlap_param_gather (bool) – Whether to overlap parameter all-gather with forward and backward compute. Defaults to True.

  • sync_model_each_microbatch (bool) – Whether to sync parameters and install gradients on each training step. When disabled, Megatron-FSDP will overlap reduce-scatter with subsequent compute and delay HSDP gather and reduce operations per optimization cycle, which improves performance and throughput when using delayed optimization strategies such as gradient accumulation. Defaults to True. This behavior can be modified before the model forward / backward pass via MegatronFSDP.set_model_auto_sync(bool), or controlled with the (no_)sync context managers or with microbatch_count and is_last_microbatch.

  • preproc_state_dict_for_dcp_ckpt (bool) – Whether to preprocess the unevenly-sharded state dict for DCP checkpointing, for both the model and the optimizer. Defaults to True.

  • check_for_nan_in_grad (bool) – Whether to check for NaN values in gradients. Defaults to True.

  • average_in_collective (bool) – Whether to average gradients in collective communication. Defaults to False. TODO: This is currently NOT supported!

  • disable_bucketing (bool) – Whether to disable gradient bucketing optimization, which permits more granular and precise communication of parameters and gradients. Defaults to False.

  • calculate_per_token_loss (bool) – Whether to calculate loss per token, which deactivates gradient scaling. Defaults to False.

  • keep_fp8_transpose_cache (bool) – Whether to keep the FP8 transpose cache when using Megatron-FSDP. Defaults to False.

  • nccl_ub (bool) – Whether to use NCCL user buffers (UB) for communication. Defaults to False.

  • fsdp_double_buffer (bool) – Whether to use double buffer for FSDP. Defaults to False.

  • disable_symmetric_registration (bool) – Whether to disable symmetric (window) registration for NCCL UB registration. This option forces conventional (local) UB registration when nccl_ub is set.

Returns:

A tuple of (model, optimizer): the wrapped MegatronFSDP model configured for distributed training, and the Megatron-FSDP-compliant optimizer for training the model.

Return type:

tuple[MegatronFSDP, torch.optim.Optimizer]
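An end-to-end sketch under the same assumptions as the earlier examples (torchrun launch on 8 GPUs; MyTransformer and MyTransformerLayer are hypothetical modules; import path mirrors the documented module name):

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Hypothetical import path mirroring the documented module name.
from core.distributed.fsdp.src.megatron_fsdp.fully_shard import fully_shard

dist.init_process_group(backend="nccl")  # assumes a torchrun launch
device_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp_shard",))

model = MyTransformer()                                    # hypothetical module
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

model, optimizer = fully_shard(
    module=model,
    optimizer=optimizer,
    device_mesh=device_mesh,
    dp_shard_dim="dp_shard",
    fsdp_unit_modules=[MyTransformerLayer],                # hypothetical unit class
    zero_dp_strategy="optim_grads_params",                 # ZeRO-3-style sharding
    device=torch.device("cuda"),
)

# The returned pair is then used in an ordinary training loop; Megatron-FSDP
# schedules its parameter gathers and gradient reduce-scatters around
# forward(), backward(), and optimizer.step().
```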

Note

This implementation uses NVIDIA's FSDP, which includes optimizations specific to the NVIDIA hardware and software stack.
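For hybrid sharding (HSDP), a rough configuration sketch. Assumptions not in the reference above: a 2-node by 8-GPU layout whose two DP dimensions span the entire world, so the flattened hybrid DP process group is simply the default world group; model, optimizer, and the fully_shard import follow the previous sketch.

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# 2 x 8 mesh: replicate or optim-shard across "dp_outer", fully shard within "dp_shard".
device_mesh = init_device_mesh("cuda", (2, 8), mesh_dim_names=("dp_outer", "dp_shard"))

# In this layout the two DP dimensions cover every rank, so the flattened
# hybrid DP process group is just the default world group.
hybrid_fsdp_group = dist.group.WORLD

model, optimizer = fully_shard(
    module=model,
    optimizer=optimizer,
    device_mesh=device_mesh,
    dp_shard_dim="dp_shard",
    dp_outer_dim="dp_outer",
    hybrid_fsdp_group=hybrid_fsdp_group,
    zero_dp_strategy="optim_grads_params",
    outer_dp_sharding_strategy="optim",  # optimizer-state hybrid sharding (HFSDP)
)
```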