core.distributed.torch_fully_sharded_data_parallel_config#
Module Contents#
Classes#
TorchFullyShardedDataParallelConfig: Configuration for TorchFullyShardedDataParallel.
API#
- class core.distributed.torch_fully_sharded_data_parallel_config.TorchFullyShardedDataParallelConfig#
Bases: megatron.core.distributed.distributed_data_parallel_config.DistributedDataParallelConfig

Configuration for TorchFullyShardedDataParallel.
- reshard_after_forward: Union[bool, int]#
Default: True

Controls what happens to the parameters after the forward pass: True reshards them immediately after forward (freeing the unsharded copies until they are re-gathered for backward), False keeps them unsharded between forward and backward, and an int reshards them to that smaller world size.

See the PyTorch documentation for details: https://github.com/pytorch/pytorch/blob/ac8ddf115065106f038865389a07f2d0c9ed5e11/torch/distributed/fsdp/_fully_shard/_fully_shard.py#L97C31-L97C49
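The bool-or-int choice can be sketched with a stand-in dataclass. This is an illustrative mock, not the real Megatron class (which lives in `megatron.core.distributed.torch_fully_sharded_data_parallel_config` and inherits from `DistributedDataParallelConfig`); the field name and type match the documented attribute, everything else here is assumed for illustration.

```python
from dataclasses import dataclass
from typing import Union


# Hypothetical stand-in mirroring the documented field; the real config
# class carries many more fields inherited from DistributedDataParallelConfig.
@dataclass
class TorchFullyShardedDataParallelConfigSketch:
    reshard_after_forward: Union[bool, int] = True


# True (the default): free unsharded parameters right after forward.
# Lowest memory, but backward must all-gather them again.
memory_saving = TorchFullyShardedDataParallelConfigSketch(reshard_after_forward=True)

# False: keep parameters unsharded between forward and backward.
# Higher memory, no extra all-gather in backward.
comm_saving = TorchFullyShardedDataParallelConfigSketch(reshard_after_forward=False)

# int: reshard to a smaller world size (e.g. the 8 ranks of one node),
# a middle ground between the two bool settings.
intra_node = TorchFullyShardedDataParallelConfigSketch(reshard_after_forward=8)
```

In a real setup the value is passed through to PyTorch's `fully_shard`, so the trade-off is the standard FSDP2 one: resharding after forward saves memory at the cost of an extra all-gather during backward.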