core.distributed.fsdp.mcore_fsdp_adapter#
Module Contents#
Classes#
- FullyShardedDataParallel: Fully Sharded Data Parallel (FSDP) wrapper for the Megatron model.
Functions#
- _get_hsdp_tp_mesh
- _get_dp_tp_mesh
- _check_mesh_ranks_and_group_ranks_are_consistent
- _get_rng_state_dict
- _load_rng_state_dict
Data#
- logger
API#
- core.distributed.fsdp.mcore_fsdp_adapter.logger#
'getLogger(…)'
- class core.distributed.fsdp.mcore_fsdp_adapter.FullyShardedDataParallel(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- ddp_config: megatron.core.distributed.distributed_data_parallel_config.DistributedDataParallelConfig,
- module: torch.nn.Module,
- fsdp_unit_modules: Optional[List[torch.nn.Module]] = None,
- disable_bucketing: bool = False,
- device: Optional[torch.device] = None,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#
Bases: megatron.core.distributed.data_parallel_base._BaseDataParallel
Fully Sharded Data Parallel (FSDP) wrapper for the Megatron model.
Initialization
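A minimal construction sketch follows. It assumes torch.distributed and Megatron-Core's model parallel state are already initialized, that `model` is a Megatron-built `torch.nn.Module`, and that the import path mirrors the module path shown on this page; the config values are illustrative placeholders, not recommended settings.

```python
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.distributed.distributed_data_parallel_config import (
    DistributedDataParallelConfig,
)
# Import path assumed to mirror the module path documented on this page.
from core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel

# Placeholder model configuration; real runs use the sizes of the actual model.
config = TransformerConfig(num_layers=2, hidden_size=128, num_attention_heads=8)

# Placeholder DDP/FSDP settings; adjust to your training setup.
ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True)

# `model` is assumed to be an unwrapped Megatron module built from `config`.
fsdp_model = FullyShardedDataParallel(
    config=config,
    ddp_config=ddp_config,
    module=model,
    fsdp_unit_modules=None,   # default: let the adapter choose sharding units
    disable_bucketing=False,
    device=None,              # default device selection
    pg_collection=None,       # default process group collection
)
```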
- load_state_dict(state_dict, strict=True)#
Load the state dictionary into the module.
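A brief usage sketch, assuming `fsdp_model` was built as in the constructor example above and that `checkpoint_path` is a hypothetical file produced by `torch.save`:

```python
import torch

# Hypothetical checkpoint location; substitute your own checkpointing flow.
checkpoint_path = "/tmp/fsdp_model_checkpoint.pt"

state_dict = torch.load(checkpoint_path, map_location="cpu")
# strict=True (the default) raises if keys are missing or unexpected.
fsdp_model.load_state_dict(state_dict, strict=True)
```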
- _fix_tensor_parallel_attributes(module)#
- _init_dist_index(pg_collection)#
Initialize the distributed index for the module.
- stop_communication()#
Stop communication for the module.
- sync_rng_states_across_tp_group()#
Synchronize the tensor parallel random number generator states.
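For example, after each rank restores its checkpoint, a single call realigns the tensor parallel RNG states (using the `fsdp_model` instance from the sketch above):

```python
# Re-synchronize tensor parallel RNG states, e.g. after loading a checkpoint,
# so ranks in the same TP group draw consistent random numbers.
fsdp_model.sync_rng_states_across_tp_group()
```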
- core.distributed.fsdp.mcore_fsdp_adapter._get_hsdp_tp_mesh(outer_fsdp_dp_group, dp_cp_group, tp_group)#
- core.distributed.fsdp.mcore_fsdp_adapter._get_dp_tp_mesh(dp_cp_group, tp_group, ep_size=1)#
- core.distributed.fsdp.mcore_fsdp_adapter._check_mesh_ranks_and_group_ranks_are_consistent(
- mesh_ranks,
- group_ranks,
)#
- core.distributed.fsdp.mcore_fsdp_adapter._get_rng_state_dict()#
- core.distributed.fsdp.mcore_fsdp_adapter._load_rng_state_dict(rng_state_dict)#