core.distributed.fsdp.mcore_fsdp_adapter#

Module Contents#

Classes#

FullyShardedDataParallel

Fully Sharded Data Parallel (FSDP) wrapper for the Megatron model.

Functions#

Data#

API#

core.distributed.fsdp.mcore_fsdp_adapter.logger#

`getLogger(…)`

class core.distributed.fsdp.mcore_fsdp_adapter.FullyShardedDataParallel(
config: megatron.core.transformer.transformer_config.TransformerConfig,
ddp_config: megatron.core.distributed.distributed_data_parallel_config.DistributedDataParallelConfig,
module: torch.nn.Module,
fsdp_unit_modules: Optional[List[torch.nn.Module]] = None,
disable_bucketing: bool = False,
device: Optional[torch.device] = None,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#

Bases: megatron.core.distributed.data_parallel_base._BaseDataParallel

Fully Sharded Data Parallel (FSDP) wrapper for the Megatron model.

Initialization

_MODULE_TYPE_REGISTRY: Dict[str, set]#

None

load_state_dict(state_dict, strict=True)#

Load the state dictionary into the module.

_detect_parallelism_type(
param_name: str,
module: torch.nn.Module,
) → Optional[str]#

Infer tensor-parallelism type for a parameter under a given module (forked from Megatron-Bridge).

Returns:

“column”, “row”, or “replicated” if a type can be inferred, else None.
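The inference heuristic can be sketched in plain, self-contained Python. Note this is illustrative only: the registry contents, module-type names, and classification rules below are assumptions for the sketch, not the actual Megatron logic (the real `_detect_parallelism_type` inspects the parameter's owning `torch.nn.Module`).

```python
from typing import Optional

# Hypothetical registry mapping parallelism modes to module-type names;
# the real adapter keys its registry on Megatron layer classes.
MODULE_TYPE_REGISTRY = {
    "column": {"ColumnParallelLinear"},
    "row": {"RowParallelLinear"},
}

def detect_parallelism_type(param_name: str, module_type: str) -> Optional[str]:
    """Classify a parameter as "column", "row", or "replicated" (illustrative)."""
    for mode, type_names in MODULE_TYPE_REGISTRY.items():
        if module_type in type_names:
            # Biases on row-parallel layers are typically replicated.
            if mode == "row" and param_name.endswith("bias"):
                return "replicated"
            return mode
    # Normalization parameters are replicated across tensor-parallel ranks.
    if param_name.endswith(("layernorm.weight", "layernorm.bias")):
        return "replicated"
    return None  # type could not be inferred

print(detect_parallelism_type("mlp.fc1.weight", "ColumnParallelLinear"))  # column
```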

_annotate_tensor_parallelism(root_module: torch.nn.Module) → None#

Annotate parameters under root_module with inferred tensor-parallel metadata.

Each parameter that can be classified will get a _tensor_parallel_mode attribute set to one of: “column”, “row”, or “replicated”.
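The annotation pass can be illustrated with a minimal stand-in that skips torch entirely: instead of walking `root_module`, it takes an iterable of `(name, param)` pairs and a classifier callback. The names and rules below are hypothetical; only the `_tensor_parallel_mode` attribute matches the documented behavior.

```python
from types import SimpleNamespace

def annotate_tensor_parallelism(named_params, classify):
    """Set `_tensor_parallel_mode` on each parameter that can be classified.

    `named_params` yields (name, param) pairs and `classify` maps a name to
    "column", "row", "replicated", or None -- a stand-in for the real walk
    over `root_module` (illustrative only).
    """
    for name, param in named_params:
        mode = classify(name)
        if mode is not None:
            param._tensor_parallel_mode = mode

# Dummy parameters and hypothetical classification rules for the demo.
params = {name: SimpleNamespace() for name in ("fc1.weight", "norm.weight", "other")}
rules = {"fc1.weight": "column", "norm.weight": "replicated"}
annotate_tensor_parallelism(params.items(), rules.get)

print(getattr(params["fc1.weight"], "_tensor_parallel_mode", None))   # column
print(getattr(params["other"], "_tensor_parallel_mode", None))        # None
```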

_init_dist_index(pg_collection)#

Initialize the distributed index for the module.

stop_communication()#

Stop communication for the module.

sync_rng_states_across_tp_group()#

Synchronize the tensor parallel random number generator states.

core.distributed.fsdp.mcore_fsdp_adapter._get_hsdp_tp_mesh(
outer_fsdp_dp_group,
dp_cp_group,
tp_group,
ep_size=1,
)#
core.distributed.fsdp.mcore_fsdp_adapter._get_dp_tp_mesh(dp_cp_group, tp_group, ep_size=1)#
core.distributed.fsdp.mcore_fsdp_adapter._check_mesh_ranks_and_group_ranks_are_consistent(
mesh_ranks,
group_ranks,
)#
core.distributed.fsdp.mcore_fsdp_adapter._get_rng_state_dict()#
core.distributed.fsdp.mcore_fsdp_adapter._load_rng_state_dict(rng_state_dict)#
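The RNG helpers above capture and restore random-generator state as a dict. A minimal stand-in using only Python's stdlib `random` (the real helpers also handle torch CPU/CUDA generators and Megatron's tensor-parallel RNG tracker; the dict key here is an assumption):

```python
import random

def get_rng_state_dict():
    """Capture the current RNG state. Illustrative: only the stdlib
    generator is recorded; the real helper also records torch states."""
    return {"python": random.getstate()}

def load_rng_state_dict(rng_state_dict):
    """Restore a state previously captured by get_rng_state_dict."""
    random.setstate(rng_state_dict["python"])

random.seed(1234)
state = get_rng_state_dict()
first = random.random()
load_rng_state_dict(state)        # rewind the generator
assert random.random() == first   # same draw after restore
```

Round-tripping the state dict like this is what lets a restarted training job reproduce the exact random stream it was interrupted in.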