core.distributed.torch_fully_sharded_data_parallel#
Module Contents#
Classes#
TorchFullyShardedDataParallel: Enables fully sharded data parallelism by wrapping the given model with the PyTorch FSDP2 API (https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md). Requires PyTorch >= 2.4.0.
API#
- class core.distributed.torch_fully_sharded_data_parallel.TorchFullyShardedDataParallel(
- config: core.transformer.transformer_config.TransformerConfig,
- ddp_config: core.distributed.distributed_data_parallel_config.DistributedDataParallelConfig,
- module: torch.nn.Module,
- sub_modules_to_wrap: Set[torch.nn.Module] = {TransformerLayer, LanguageModelEmbedding, RotaryEmbedding, tensor_parallel.ColumnParallelLinear},
- disable_bucketing: bool = False,
- process_group: Optional[torch.distributed.ProcessGroup] = None,
- )#
Bases:
core.distributed.data_parallel_base._BaseDataParallel

Enables fully sharded data parallelism by wrapping the given model with the PyTorch FSDP2 API: https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md. To use this class, PyTorch >= 2.4.0 is required.
- Parameters:
config – Transformer config object.
ddp_config – DistributedDataParallelConfig object.
module – Underlying model.
sub_modules_to_wrap –
Set of submodules to shard with FSDP. Parameters within each submodule are all-gathered just in time. The default set covers the following submodules of the GPT model architecture:
- TransformerLayer (all Transformer layers)
- LanguageModelEmbedding (initial embedding layer)
- RotaryEmbedding (initial RoPE layer)
- tensor_parallel.ColumnParallelLinear (final output layer)
Users can set the _fsdp_modules attribute on a submodule to mark additional submodules to shard with FSDP.
process_group – Optional ProcessGroup to use for distributed operations. If None (default), the data parallel process group will be obtained from parallel_state.get_data_parallel_group(with_context_parallel=True).
Initialization
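To illustrate the semantics of sub_modules_to_wrap, here is a minimal pure-Python sketch of selecting submodules to shard by type. This is an analogy only: the real class walks a torch.nn.Module tree and applies FSDP2's fully_shard to each match; the Module class, the _wrap_with_fsdp flag, and select_modules_to_wrap below are hypothetical stand-ins, not the actual implementation.

```python
# Hypothetical sketch of sub_modules_to_wrap selection semantics.
# Plain classes stand in for torch.nn.Module; names are illustrative.

class Module:
    def __init__(self, *children):
        self.children = list(children)

    def named_modules(self, prefix=""):
        # Yield (name, module) pairs for self and all descendants,
        # mimicking torch.nn.Module.named_modules.
        name = prefix or self.__class__.__name__
        yield name, self
        for i, child in enumerate(self.children):
            yield from child.named_modules(f"{name}.{i}")

class TransformerLayer(Module): pass
class LanguageModelEmbedding(Module): pass
class Head(Module): pass

def select_modules_to_wrap(root, wrap_types):
    """Return names of submodules whose type is in wrap_types, plus any
    module that opts in via an attribute flag (analogous to the
    _fsdp_modules opt-in described above)."""
    selected = []
    for name, module in root.named_modules():
        if type(module) in wrap_types or getattr(module, "_wrap_with_fsdp", False):
            selected.append(name)
    return selected

model = Module(LanguageModelEmbedding(), TransformerLayer(), TransformerLayer(), Head())
wrapped = select_modules_to_wrap(model, {TransformerLayer, LanguageModelEmbedding})
print(wrapped)  # ['Module.0', 'Module.1', 'Module.2'] — Head is left unsharded
```

Each selected submodule becomes an independent FSDP unit, so its parameters are all-gathered only when that submodule runs, rather than materializing the whole model at once.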
- load_state_dict(state_dict, strict=True)#
No-op, because tensors are already loaded in place by _load_base_checkpoint with FSDP2.
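The no-op pattern above can be sketched in isolation. In this hypothetical stand-in (ShardedWrapper and its methods are illustrative, not the real class), the checkpoint loader mutates the existing storage, so a later load_state_dict call has nothing left to do.

```python
# Hypothetical illustration of why load_state_dict is a no-op:
# the checkpoint loader already wrote tensors in place.

class ShardedWrapper:
    def __init__(self, weights):
        self.weights = dict(weights)  # stand-in for sharded parameter storage

    def _load_base_checkpoint(self, state_dict):
        # In-place load: write checkpoint values into existing storage.
        for key, value in state_dict.items():
            self.weights[key] = value

    def load_state_dict(self, state_dict, strict=True):
        # No-op: tensors were already loaded in place above.
        return None

m = ShardedWrapper({"w": 0.0})
m._load_base_checkpoint({"w": 1.5})
m.load_state_dict({"w": 999.0})  # intentionally does nothing
print(m.weights["w"])  # 1.5
```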