core.distributed.fsdp.mcore_fsdp_adapter#
Module Contents#
Classes#
- FullyShardedDataParallel: Fully Sharded Data Parallel (FSDP) wrapper for the Megatron model.
Functions#
Data#
API#
- core.distributed.fsdp.mcore_fsdp_adapter.logger#
getLogger(...)
- class core.distributed.fsdp.mcore_fsdp_adapter.FullyShardedDataParallel(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- ddp_config: megatron.core.distributed.distributed_data_parallel_config.DistributedDataParallelConfig,
- module: torch.nn.Module,
- fsdp_unit_modules: Optional[List[torch.nn.Module]] = None,
- disable_bucketing: bool = False,
- device: Optional[torch.device] = None,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#
Bases:
megatron.core.distributed.data_parallel_base._BaseDataParallel

Fully Sharded Data Parallel (FSDP) wrapper for the Megatron model.
Initialization
- _MODULE_TYPE_REGISTRY: Dict[str, set]#
None
- load_state_dict(state_dict, strict=True)#
Load the state dictionary into the module.
- _detect_parallelism_type(
- param_name: str,
- module: torch.nn.Module,
Infer tensor-parallelism type for a parameter under a given module (forked from Megatron-Bridge).
- Returns:
"column", "row", or "replicated" if a type can be inferred, else None.
- _annotate_tensor_parallelism(root_module: torch.nn.Module) None#
Annotate parameters under root_module with inferred tensor-parallel metadata.
Each parameter that can be classified will get a
`_tensor_parallel_mode` attribute set to one of: "column", "row", or "replicated".
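The detection and annotation pair above can be sketched as follows. This is a minimal pure-Python illustration, not the adapter's actual implementation: the `Param` stand-in, the registry contents, and the helper names are assumptions mirroring the idea behind `_MODULE_TYPE_REGISTRY`, `_detect_parallelism_type`, and `_annotate_tensor_parallelism`.

```python
class Param:
    """Minimal stand-in for torch.nn.Parameter (hypothetical)."""
    def __init__(self, name):
        self.name = name
        self._tensor_parallel_mode = None

# Hypothetical registry mapping owning-module class names to a TP mode,
# in the spirit of _MODULE_TYPE_REGISTRY; entries are illustrative.
MODULE_TYPE_REGISTRY = {
    "ColumnParallelLinear": "column",
    "RowParallelLinear": "row",
    "LayerNorm": "replicated",
}

def detect_parallelism_type(param_name, module_cls_name):
    """Infer "column", "row", or "replicated" from the owning module type,
    returning None when no classification applies."""
    return MODULE_TYPE_REGISTRY.get(module_cls_name)

def annotate_tensor_parallelism(named_params):
    """Stamp each classifiable parameter with its inferred TP mode."""
    for param, owner_cls in named_params:
        mode = detect_parallelism_type(param.name, owner_cls)
        if mode is not None:
            param._tensor_parallel_mode = mode

w = Param("decoder.linear_qkv.weight")
b = Param("decoder.norm.bias")
annotate_tensor_parallelism([(w, "ColumnParallelLinear"), (b, "LayerNorm")])
print(w._tensor_parallel_mode)  # column
print(b._tensor_parallel_mode)  # replicated
```

Unclassifiable parameters are simply left unannotated, matching the "else None" contract of the detection step.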
- _init_dist_index(pg_collection)#
Initialize the distributed index for the module.
- stop_communication()#
Stop communication for the module.
- sync_rng_states_across_tp_group()#
Synchronize the tensor parallel random number generator states.
- core.distributed.fsdp.mcore_fsdp_adapter._get_hsdp_tp_mesh(
- outer_fsdp_dp_group,
- dp_cp_group,
- tp_group,
- ep_size=1,
)#
- core.distributed.fsdp.mcore_fsdp_adapter._get_dp_tp_mesh(dp_cp_group, tp_group, ep_size=1)#
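A 2-D (dp, tp) mesh of this kind maps global ranks with the tensor-parallel dimension fastest-varying, i.e. `rank = dp_idx * tp_size + tp_idx`. A minimal sketch, assuming that conventional Megatron rank ordering (the helper name below is ours, not the adapter's API):

```python
def dp_tp_mesh(world_size, tp_size):
    """Return mesh[dp_idx][tp_idx] -> global rank, with TP innermost."""
    assert world_size % tp_size == 0, "world size must divide evenly by tp_size"
    dp_size = world_size // tp_size
    return [[dp * tp_size + tp for tp in range(tp_size)]
            for dp in range(dp_size)]

mesh = dp_tp_mesh(world_size=8, tp_size=2)
print(mesh)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```

Each row is one data-parallel replica's set of tensor-parallel ranks; each column is one data-parallel group.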
- core.distributed.fsdp.mcore_fsdp_adapter._check_mesh_ranks_and_group_ranks_are_consistent(
- mesh_ranks,
- group_ranks,
)#
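The consistency check expresses a simple invariant: a mesh built for FSDP is only valid if its ranks are exactly the members of the process group it is meant to cover. A sketch under that assumption (the helper name is illustrative, not the module's):

```python
def mesh_and_group_consistent(mesh_ranks, group_ranks):
    """True when the flattened mesh covers exactly the group's ranks,
    regardless of ordering."""
    flat = [rank for row in mesh_ranks for rank in row]
    return sorted(flat) == sorted(group_ranks)

print(mesh_and_group_consistent([[0, 1], [2, 3]], [0, 1, 2, 3]))  # True
```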
- core.distributed.fsdp.mcore_fsdp_adapter._get_rng_state_dict()#
- core.distributed.fsdp.mcore_fsdp_adapter._load_rng_state_dict(rng_state_dict)#
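The save/load RNG helpers follow the usual checkpoint-and-rewind pattern. The sketch below illustrates it with only Python's stdlib generator; the real helpers presumably also capture torch (and CUDA) generator states, and the function names here simply echo the module's.

```python
import random

def get_rng_state_dict():
    """Capture the current generator state (stdlib RNG only, for illustration)."""
    return {"python": random.getstate()}

def load_rng_state_dict(state):
    """Restore a previously captured generator state."""
    random.setstate(state["python"])

state = get_rng_state_dict()
a = random.random()
load_rng_state_dict(state)   # rewind the generator
b = random.random()
print(a == b)  # True: the restored generator replays the same draw
```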