core.distributed.fsdp.mcore_fsdp_adapter#

Module Contents#

Classes#

FullyShardedDataParallel

Fully Sharded Data Parallel (FSDP) wrapper for the Megatron model.

Functions#

_get_hsdp_tp_mesh

_get_dp_tp_mesh

_check_mesh_ranks_and_group_ranks_are_consistent

_get_rng_state_dict

_load_rng_state_dict

Data#

logger

API#

core.distributed.fsdp.mcore_fsdp_adapter.logger#

'getLogger(...)'

class core.distributed.fsdp.mcore_fsdp_adapter.FullyShardedDataParallel(
config: megatron.core.transformer.transformer_config.TransformerConfig,
ddp_config: megatron.core.distributed.distributed_data_parallel_config.DistributedDataParallelConfig,
module: torch.nn.Module,
fsdp_unit_modules: Optional[List[torch.nn.Module]] = None,
disable_bucketing: bool = False,
device: Optional[torch.device] = None,
pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#

Bases: megatron.core.distributed.data_parallel_base._BaseDataParallel

Fully Sharded Data Parallel (FSDP) wrapper for the Megatron model.

Initialization
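The wrapper is constructed around an existing Megatron module. Below is a minimal, hypothetical sketch: it assumes `torch.distributed` and Megatron's model-parallel state are already initialized, the config values are illustrative only, and `build_module()` stands in for real model construction.

```python
# Minimal usage sketch (not taken from the module itself). Assumes torch.distributed
# and Megatron's model-parallel state are already initialized; build_module() is a
# placeholder for however the underlying torch.nn.Module is actually created.
import torch

from megatron.core.distributed.distributed_data_parallel_config import (
    DistributedDataParallelConfig,
)
from megatron.core.distributed.fsdp.mcore_fsdp_adapter import FullyShardedDataParallel
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=2,
    hidden_size=1024,
    num_attention_heads=8,
)  # illustrative values only
ddp_config = DistributedDataParallelConfig()  # defaults shown; tune for your setup

module = build_module(config)  # placeholder for real model construction

model = FullyShardedDataParallel(
    config=config,
    ddp_config=ddp_config,
    module=module,
    device=torch.device("cuda", torch.cuda.current_device()),
)
```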

load_state_dict(state_dict, strict=True)#

Load the state dictionary into the module.
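A minimal usage sketch, assuming a plain PyTorch checkpoint file; the path is a placeholder and `model` refers to the wrapper constructed above.

```python
# Hypothetical sketch: restore weights into the wrapped model.
# "checkpoint.pt" is a placeholder path.
import torch

state_dict = torch.load("checkpoint.pt", map_location="cpu")
model.load_state_dict(state_dict, strict=True)
```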

_fix_tensor_parallel_attributes(module)#
_init_dist_index(pg_collection)#

Initialize the distributed index for the module.

stop_communication()#

Stop communication for the module.

sync_rng_states_across_tp_group()#

Synchronize the tensor parallel random number generator states.
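The method's implementation is not reproduced here. The sketch below only illustrates one common pattern for this kind of synchronization, broadcasting the CUDA RNG state from the first rank of a tensor-parallel group; `tp_group` is assumed to be an already-created process group.

```python
# Illustrative pattern (an assumption, not the adapter's actual code): make every
# rank in a tensor-parallel group adopt the CUDA RNG state of the group's first rank.
import torch
import torch.distributed as dist

def sync_cuda_rng_state(tp_group: dist.ProcessGroup) -> None:
    # The group's first rank contributes its RNG state; all other ranks receive it.
    src = dist.get_global_rank(tp_group, 0)
    obj = [torch.cuda.get_rng_state()]
    dist.broadcast_object_list(obj, src=src, group=tp_group)
    torch.cuda.set_rng_state(obj[0])
```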

core.distributed.fsdp.mcore_fsdp_adapter._get_hsdp_tp_mesh(outer_fsdp_dp_group, dp_cp_group, tp_group)#
core.distributed.fsdp.mcore_fsdp_adapter._get_dp_tp_mesh(dp_cp_group, tp_group, ep_size=1)#
core.distributed.fsdp.mcore_fsdp_adapter._check_mesh_ranks_and_group_ranks_are_consistent(
mesh_ranks,
group_ranks,
)#
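These mesh helpers are private. As a rough illustration only, the sketch below shows how a comparable 2-D data-parallel / tensor-parallel mesh can be built with PyTorch's public DeviceMesh API; `dp_size` and `tp_size` are placeholders, and none of this is the module's actual implementation.

```python
# Illustrative sketch (an assumption, not this module's code): build a 2-D (dp, tp)
# device mesh with PyTorch's public API. Requires torch.distributed to be initialized,
# and dp_size * tp_size must equal the world size.
from torch.distributed.device_mesh import init_device_mesh

dp_size, tp_size = 4, 2  # placeholder sizes
mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))

# Each named dimension exposes a process group; its ranks should line up with the
# existing Megatron groups, which is the kind of consistency the check above verifies.
dp_group = mesh.get_group("dp")
tp_group = mesh.get_group("tp")
```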
core.distributed.fsdp.mcore_fsdp_adapter._get_rng_state_dict()#
core.distributed.fsdp.mcore_fsdp_adapter._load_rng_state_dict(rng_state_dict)#
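The sketch below only illustrates the general pattern of capturing and restoring RNG state that the two RNG helpers above suggest; the keys and generators covered are assumptions, not the module's actual state-dict layout.

```python
# Conceptual sketch of saving and restoring RNG state; the dictionary keys and the
# set of generators covered are assumptions, not the module's implementation.
import random
import torch

def get_rng_state_dict() -> dict:
    state = {
        "python": random.getstate(),
        "torch_cpu": torch.get_rng_state(),
    }
    if torch.cuda.is_available():
        state["torch_cuda"] = torch.cuda.get_rng_state()
    return state

def load_rng_state_dict(state: dict) -> None:
    random.setstate(state["python"])
    torch.set_rng_state(state["torch_cpu"])
    if "torch_cuda" in state and torch.cuda.is_available():
        torch.cuda.set_rng_state(state["torch_cuda"])
```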