core.distributed.fsdp.src.megatron_fsdp.utils#
Module Contents#
Classes#
| Class | Description |
| --- | --- |
| FSDPDistributedIndex | Class containing references to the process groups utilized by Megatron-FSDP. |
| GlobalMemoryBuffer | Global buffer to avoid dynamic memory allocations. Caller should ensure that buffers of the same name are not used concurrently. |
Functions#
| Function | Description |
| --- | --- |
| get_te_version | Get the TE version from `__version__`; if unavailable, fall back to the version reported by pip. The result is cached. |
| is_te_min_version | Check if the minimum version of transformer-engine is installed. |
| is_submodule | Check if a module is a submodule of another module. |
| get_mesh_names | Get all the sub-mesh ("dp", "cp", etc.) and flattened-mesh ("dp_cp", etc.) names in the DeviceMesh. When only_submesh_dims=True, only checks for sub-mesh dimensions. |
| contains_submesh | Check if a sub-mesh exists in the device mesh by name. |
| _get_cuda_rng_state | Return the random number generator state of the specified GPU. |
| _set_cuda_rng_state | Sets the random number generator state of the current GPU. |
| initialize_rng_tracker | Create the RNG tracker. 'use_te_rng_tracker' determines whether to use Megatron's or TransformerEngine's implementation. |
| get_cuda_rng_tracker | Get the CUDA RNG tracker. |
| get_global_memory_buffer | Return the global GlobalMemoryBuffer object. |
| create_updated_function_signature | Given a function, create a new version of the function with extended keyword-only arguments or parameters. Used to patch or extend methods in instances of a class. |
| is_mcore_tensor_model_parallel | Check if the given parameter is Megatron-Core tensor model parallel. |
| is_mcore_tensor_parallel_duplicated | Check if the given parameter is Megatron-Core tensor model parallel and duplicated. |
| get_mcore_tensor_parallel_partition_dim | Get the partition dimension for a Megatron-Core tensor model parallel parameter. |
Data#
API#
- core.distributed.fsdp.src.megatron_fsdp.utils.logger#
‘getLogger(…)’
- core.distributed.fsdp.src.megatron_fsdp.utils._MODEL_PARALLEL_RNG_TRACKER_NAME#
‘model-parallel-rng’
- core.distributed.fsdp.src.megatron_fsdp.utils.get_te_version()#
Get the TE version from `__version__`; if unavailable, fall back to the version reported by pip. The result is cached.
- core.distributed.fsdp.src.megatron_fsdp.utils.is_te_min_version(vers, check_equality=True)#
Check if the minimum version of transformer-engine is installed.
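A minimal usage sketch of these version helpers, assuming the module is importable as megatron_fsdp.utils (adjust the import path to your installation); the version string "1.9.0" is only an example:

```python
# Usage sketch; assumes the module is importable as megatron_fsdp.utils
# (adjust the import path to your installation).
from megatron_fsdp.utils import get_te_version, is_te_min_version

print(f"TransformerEngine version: {get_te_version()}")

# Guard a code path that relies on a newer TransformerEngine release.
# The version string "1.9.0" is only an example.
if is_te_min_version("1.9.0"):
    pass  # safe to rely on newer TE features here
else:
    pass  # fall back to a conservative implementation
```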
- core.distributed.fsdp.src.megatron_fsdp.utils.is_submodule(module, parent_module, strict=True)#
Check if a module is a submodule of another module.
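An illustrative sketch; the toy model below is hypothetical, and the strict=True behavior for a module compared with itself is assumed from the parameter name:

```python
# Illustrative sketch with a toy model; assumes megatron_fsdp.utils is importable.
import torch.nn as nn
from megatron_fsdp.utils import is_submodule

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
layer = model[0]

print(is_submodule(layer, model))              # True: layer is contained in model
# With strict=True, a module is presumably not treated as its own submodule.
print(is_submodule(model, model, strict=True))
```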
- core.distributed.fsdp.src.megatron_fsdp.utils.get_mesh_names(
- device_mesh: Optional[torch.distributed.DeviceMesh] = None,
- only_submesh_dims: bool = False,
Get all the sub-mesh (“dp”, “cp”, etc.) and flattened-mesh (“dp_cp”, etc.) names in the DeviceMesh. When only_submesh_dims=True, only checks for sub-mesh dimensions.
- core.distributed.fsdp.src.megatron_fsdp.utils.contains_submesh(
- device_mesh: Optional[torch.distributed.DeviceMesh],
- submesh_names: Optional[str | Sequence[str]],
Check if a sub-mesh exists in the device mesh by name.
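A sketch combining get_mesh_names and contains_submesh, assuming a distributed launch (e.g. torchrun with 4 processes) and example mesh dimension names:

```python
# Sketch only: assumes a distributed launch (e.g. torchrun --nproc-per-node=4)
# and that megatron_fsdp.utils is importable. Dimension names are examples.
from torch.distributed.device_mesh import init_device_mesh
from megatron_fsdp.utils import contains_submesh, get_mesh_names

device_mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

print(get_mesh_names(device_mesh))                         # e.g. ["dp", "tp"]
print(get_mesh_names(device_mesh, only_submesh_dims=True)) # sub-mesh dims only
print(contains_submesh(device_mesh, "dp"))                 # True
print(contains_submesh(device_mesh, ("dp", "tp")))         # check multiple names at once
```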
- core.distributed.fsdp.src.megatron_fsdp.utils._get_cuda_rng_state(
- device: Union[int, str, torch.device] = 'cuda',
- clone: bool = False,
- graph_safe: bool = False,
Return the random number generator state of the specified GPU.
- Parameters:
device (int) – The GPU from which to retrieve the RNG state.
clone (bool) – Whether to also clone the retrieved RNG state.
graph_safe (bool) – Get the RNG state in a graph-safe manner.
This function is adapted from torch.cuda.random.get_rng_state().
- core.distributed.fsdp.src.megatron_fsdp.utils._set_cuda_rng_state(
- new_state: torch.Tensor,
- device: int = -1,
- graph_safe: bool = False,
Sets the random number generator state of the current GPU.
- Parameters:
new_state (torch.ByteTensor) – The desired state.
device (int) – The GPU on which to set the RNG state.
graph_safe (bool) – Set the RNG state in a graph-safe manner.
This function is adapted from the PyTorch repo (torch.cuda.set_rng_state) with a single change: the input state is not cloned. Cloning caused major performance issues for 4+ GPU cases.
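A sketch of saving and restoring the CUDA RNG state with these (private) helpers; it assumes a CUDA device is available:

```python
# Sketch: save and restore the CUDA RNG state with the private helpers.
# Requires a CUDA device; uses the default "cuda" device throughout.
import torch
from megatron_fsdp.utils import _get_cuda_rng_state, _set_cuda_rng_state

saved_state = _get_cuda_rng_state(device="cuda", clone=True)

sample = torch.randn(4, device="cuda")

# Restore the saved state so the next draw repeats the one above exactly.
_set_cuda_rng_state(saved_state, device=torch.cuda.current_device())
assert torch.equal(torch.randn(4, device="cuda"), sample)
```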
- core.distributed.fsdp.src.megatron_fsdp.utils.initialize_rng_tracker(
- use_te_rng_tracker: bool = False,
- inference_rng_tracker: bool = False,
- use_cudagraphable_rng: bool = False,
- force_reset: bool = False,
Create the RNG tracker. 'use_te_rng_tracker' determines whether to use Megatron's or TransformerEngine's implementation. In particular, TransformerEngine's implementation is cudagraphable and supports FP8.
- core.distributed.fsdp.src.megatron_fsdp.utils.get_cuda_rng_tracker(
- use_te_rng_tracker: bool = False,
- inference_rng_tracker: bool = False,
- use_cudagraphable_rng: bool = False,
Get the CUDA RNG tracker.
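A sketch of initializing and using the tracker; the .add(name, seed) and .fork(name) methods are assumed to follow the Megatron-Core CudaRNGStatesTracker API:

```python
# Sketch of per-region RNG management; assumes the returned tracker exposes the
# Megatron-Core style .add(name, seed) and .fork(name) API.
import torch
from megatron_fsdp.utils import get_cuda_rng_tracker, initialize_rng_tracker

initialize_rng_tracker(use_te_rng_tracker=False)
tracker = get_cuda_rng_tracker()

# Register a named RNG state (e.g. one shared across tensor-parallel ranks).
tracker.add("model-parallel-rng", seed=1234)

# Temporarily switch to that state for an operation that must be reproducible
# across the tracked group; the default state is restored on exit.
with tracker.fork("model-parallel-rng"):
    weight = torch.empty(16, 16, device="cuda")
    torch.nn.init.normal_(weight)
```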
- class core.distributed.fsdp.src.megatron_fsdp.utils.FSDPDistributedIndex(
- device_mesh: torch.distributed.DeviceMesh,
- dp_shard_dim: Optional[str] = None,
- dp_outer_dim: Optional[str] = None,
- tp_dim: Optional[str] = None,
- hybrid_fsdp_group: Optional[torch.distributed.ProcessGroup] = None,
- hsdp_outer_dp_shard: bool = False,
- expt_device_mesh: Optional[torch.distributed.DeviceMesh] = None,
Class containing references to the process groups utilized by Megatron-FSDP.
This class tracks the device mesh and the different process groups required for fully sharded data parallelism (FSDP), including support for hybrid and tensor/data parallel strategies.
Initialization
- Parameters:
device_mesh (DeviceMesh) – The DeviceMesh to use for the DistributedIndex.
dp_shard_dim (Optional[str]) – The dimension name of the data parallel (and context parallel) sharding sub-mesh.
dp_outer_dim (Optional[str]) – The dimension name of the “outer” data parallel sub-mesh for replication or sharding when using HSDP.
tp_dim (Optional[str]) – The dimension name of the tensor parallel sub-mesh.
hybrid_fsdp_group (Optional[torch.distributed.ProcessGroup]) – The process group for hybrid FSDP communication, which is the flattened combination of the dp_outer and dp_shard process groups.
hsdp_outer_dp_shard (bool) – Whether to have outer DP group sharding in hybrid FSDP. Specifying outer sharding will lift the bucket sharding coordinate system to flattened ranks of (dp_shard, dp_outer) instead of just sharding across dp_shard ranks and replicating across dp_outer ranks.
expt_device_mesh (Optional[DeviceMesh]) – The expert parallel device mesh to use for the DistributedIndex.
- expt_fsdp_group#
None
Megatron-FSDP is responsible for storing all required DeviceMesh objects, as per best practices recommended by the DeviceMesh API.
NOTE(@cspades): In PyTorch 2.11, retrieving flattened mesh dimensions will be impossible via the device_mesh[…] API. We will require all users to correctly _unflatten() their DeviceMesh such that all dimensions used by Megatron-FSDP are sub-meshes of the DeviceMesh. contains_submesh(…) -> get_mesh_names(only_submesh_dims=True).
- get_submesh(
- mesh_dim_names: str | Sequence[str],
- is_expert_parallel: bool = False,
Retrieve a Megatron-FSDP-registered submesh by name(s).
- get_dp_group(
- is_expert_parallel: bool = False,
Get the data parallel process group.
- get_fsdp_group(
- is_expert_parallel: bool = False,
Get the FSDP process group.
- get_outer_fsdp_group() torch.distributed.ProcessGroup#
Get the outer-FSDP process group.
- get_root_mesh(
- is_expert_parallel: bool = False,
Get the device mesh.
- get_logical_hybrid_fsdp_rank()#
Returns the logical rank of the current process within the full-shard hybrid FSDP group.
In full-shard hybrid FSDP, parameters are first sharded across the inner data-parallel group, then across the outer data-parallel group. This changes the effective rank mapping compared to standard data parallelism. Use this method to get the correct rank index for the hybrid group.
- Returns:
The index of the current process in the hybrid FSDP group.
- Return type:
int
- Raises:
AssertionError – If full-shard hybrid FSDP is not enabled.
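A constructor sketch, assuming an 8-process torchrun launch; the mesh shape and dimension names ("dp_shard_cp", "tp") are examples only:

```python
# Sketch only: assumes an 8-process torchrun launch and that the mesh dimension
# names below ("dp_shard_cp", "tp") match how the rest of the setup expects them.
from torch.distributed.device_mesh import init_device_mesh
from megatron_fsdp.utils import FSDPDistributedIndex

# 4-way FSDP sharding x 2-way tensor parallelism.
device_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp_shard_cp", "tp"))

dist_index = FSDPDistributedIndex(
    device_mesh=device_mesh,
    dp_shard_dim="dp_shard_cp",
    tp_dim="tp",
)

fsdp_group = dist_index.get_fsdp_group()   # process group parameters are sharded over
dp_group = dist_index.get_dp_group()       # data-parallel process group
root_mesh = dist_index.get_root_mesh()     # the DeviceMesh passed in above
```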
- class core.distributed.fsdp.src.megatron_fsdp.utils.GlobalMemoryBuffer#
Global buffer to avoid dynamic memory allocations. Caller should ensure that buffers of the same name are not used concurrently.
Initialization
- get_tensor(
- tensor_shape,
- dtype,
- name,
- mem_alloc_context: Optional[Callable] = None,
Return (potentially) a sub-tensor of self.buffer with the given shape and dtype.
- core.distributed.fsdp.src.megatron_fsdp.utils.get_global_memory_buffer()#
Return the global GlobalMemoryBuffer object.
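A sketch of reusing a named workspace tensor through the global buffer; the buffer name and shapes are illustrative, and the reuse behavior is inferred from the class docstring:

```python
# Sketch: reuse a named workspace tensor instead of allocating one every step.
# The buffer name and shapes are illustrative.
import torch
from megatron_fsdp.utils import get_global_memory_buffer

buf = get_global_memory_buffer()

workspace = buf.get_tensor((1024, 1024), torch.bfloat16, "allgather_workspace")
# A later, smaller request under the same name is expected to return a view
# into the same storage rather than triggering a new allocation (inferred from
# the class docstring about named buffers).
workspace2 = buf.get_tensor((512, 1024), torch.bfloat16, "allgather_workspace")
```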
- core.distributed.fsdp.src.megatron_fsdp.utils.create_updated_function_signature(
- original_function,
- **extended_kwargs: dict,
Given a function, create a new version of the function with extended keyword-only arguments or parameters. Used to patch or extend methods in instances of a class.
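A heavily hedged sketch of the intended call pattern; the exact wrapping and patching semantics are assumptions based on the description above, and the log_metrics keyword is hypothetical:

```python
# Heavily hedged sketch: the exact wrapping semantics of
# create_updated_function_signature are assumptions based on its description.
from megatron_fsdp.utils import create_updated_function_signature

class Trainer:
    def step(self, batch):
        return batch

trainer = Trainer()

# Hypothetical: build a version of the bound method whose signature also carries
# a keyword-only `log_metrics` argument, then patch it onto this instance.
trainer.step = create_updated_function_signature(trainer.step, log_metrics=False)
```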
- core.distributed.fsdp.src.megatron_fsdp.utils.is_mcore_tensor_model_parallel(param: torch.Tensor) bool#
Check if the given parameter is Megatron-Core tensor model parallel.
- core.distributed.fsdp.src.megatron_fsdp.utils.is_mcore_tensor_parallel_duplicated(param: torch.Tensor) bool#
Check if the given parameter is Megatron-Core tensor model parallel and duplicated.
- core.distributed.fsdp.src.megatron_fsdp.utils.get_mcore_tensor_parallel_partition_dim(
- param: torch.Tensor,
Get the partition dimension for a Megatron-Core tensor model parallel parameter.
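A sketch that inspects a model's parameters with these helpers; it assumes the parameters carry Megatron-Core tensor-parallel attributes set elsewhere:

```python
# Sketch: report how Megatron-Core tensor parallelism treats each parameter.
# Assumes the model's parameters carry Megatron-Core TP attributes set elsewhere.
import torch
from megatron_fsdp.utils import (
    get_mcore_tensor_parallel_partition_dim,
    is_mcore_tensor_model_parallel,
    is_mcore_tensor_parallel_duplicated,
)

def report_tensor_parallel_params(model: torch.nn.Module) -> None:
    for name, param in model.named_parameters():
        if is_mcore_tensor_model_parallel(param):
            dim = get_mcore_tensor_parallel_partition_dim(param)
            print(f"{name}: tensor-parallel, partitioned along dim {dim}")
        elif is_mcore_tensor_parallel_duplicated(param):
            print(f"{name}: tensor-parallel, duplicated across TP ranks")
        else:
            print(f"{name}: not tensor-model-parallel")
```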