core.distributed.torch_fully_sharded_data_parallel#

Module Contents#

Classes#

TorchFullyShardedDataParallel

Enables fully sharded data parallelism by wrapping the given model with the PyTorch FSDP2 API (https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md). To use this class, PyTorch version >= 2.4.0 is required.

API#

class core.distributed.torch_fully_sharded_data_parallel.TorchFullyShardedDataParallel(
config: core.transformer.transformer_config.TransformerConfig,
ddp_config: core.distributed.distributed_data_parallel_config.DistributedDataParallelConfig,
module: torch.nn.Module,
sub_modules_to_wrap: Set[torch.nn.Module] = {TransformerLayer, LanguageModelEmbedding, RotaryEmbedding, tensor_parallel.ColumnParallelLinear},
disable_bucketing: bool = False,
process_group: Optional[torch.distributed.ProcessGroup] = None,
)#

Bases: core.distributed.data_parallel_base._BaseDataParallel

Enables fully sharded data parallelism by wrapping the given model with the PyTorch FSDP2 API (https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md). To use this class, PyTorch version >= 2.4.0 is required.

Parameters:
  • config – Transformer config object.

  • ddp_config – DistributedDataParallelConfig object.

  • module – Underlying model.

  • sub_modules_to_wrap

    Set of submodule classes to shard with FSDP. Parameters within each submodule are all-gathered just in time. The default set covers the following submodules of the GPT model architecture:

      – TransformerLayer (all Transformer layers)
      – LanguageModelEmbedding (initial embedding layer)
      – RotaryEmbedding (initial RoPE layer)
      – tensor_parallel.ColumnParallelLinear (final output layer)

    Users can also set a _fsdp_modules attribute on a submodule to mark additional submodules to shard with FSDP.

  • process_group – Optional ProcessGroup to use for distributed operations. If None (default), the data parallel process group will be obtained from parallel_state.get_data_parallel_group(with_context_parallel=True).
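The selection of submodules to wrap can be pictured with a plain-Python sketch (no torch dependency; the `Module` stand-in, the `collect_fsdp_modules` helper, and the example module tree are illustrative assumptions, not the actual Megatron-Core implementation):

```python
class Module:
    """Minimal stand-in for torch.nn.Module with traversable children."""

    def __init__(self, **children):
        self._children = children

    def children(self):
        return self._children.values()


class TransformerLayer(Module):          # stand-in for a default wrap target
    pass


class LanguageModelEmbedding(Module):    # stand-in for a default wrap target
    pass


def collect_fsdp_modules(root, sub_modules_to_wrap):
    """Return the submodules to shard with FSDP: instances of the classes in
    `sub_modules_to_wrap`, plus any extras a submodule opts in through its
    `_fsdp_modules` attribute."""
    to_wrap = []
    stack = [root]
    while stack:
        module = stack.pop()
        if isinstance(module, tuple(sub_modules_to_wrap)):
            to_wrap.append(module)
        # Submodules may list additional modules to shard via `_fsdp_modules`.
        to_wrap.extend(getattr(module, "_fsdp_modules", []))
        stack.extend(module.children())
    return to_wrap


head = Module()
head._fsdp_modules = [head]  # explicitly opt this module in
model = Module(
    embed=LanguageModelEmbedding(),
    layer0=TransformerLayer(),
    layer1=TransformerLayer(),
    head=head,
)
wrapped = collect_fsdp_modules(model, {TransformerLayer, LanguageModelEmbedding})
```

Here `wrapped` contains the embedding, both Transformer layers, and the explicitly opted-in head, mirroring how the default set and the `_fsdp_modules` attribute combine.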

Initialization

load_state_dict(state_dict, strict=True)#

No-op because, with FSDP2, tensors are already loaded in place by _load_base_checkpoint.
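The no-op behavior can be sketched as follows (a hypothetical mini-version for illustration; the real class wraps a torch.nn.Module and its parameters are populated in place during checkpoint loading):

```python
class ShardedModel:
    """Stand-in for a module whose shards were already filled in place."""

    def __init__(self):
        self.weight = [1.0, 2.0]  # represents an already-loaded shard


class FSDP2WrapperSketch:
    """Hypothetical sketch of the wrapper's load_state_dict contract."""

    def __init__(self, module):
        self.module = module

    def load_state_dict(self, state_dict, strict=True):
        # Intentional no-op: with FSDP2, _load_base_checkpoint has already
        # written the checkpoint tensors into the sharded parameters in
        # place, so there is nothing left to copy here.
        pass


model = ShardedModel()
wrapper = FSDP2WrapperSketch(model)
wrapper.load_state_dict({"weight": [9.9, 9.9]})
# The call leaves model.weight untouched.
```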