core.distributed.fsdp.src.megatron_fsdp.utils#

Module Contents#

Classes#

FSDPDistributedIndex

Class containing references to the process groups utilized by Megatron-FSDP.

GlobalMemoryBuffer

Global buffer to avoid dynamic memory allocations. Caller should ensure that buffers of the same name are not used concurrently.

Functions#

get_te_version

Get the TE version from __version__; if not available, fall back to pip's package metadata. The result is cached.

is_te_min_version

Check if minimum version of transformer-engine is installed.

is_submodule

Check if a module is a submodule of another module.

get_mesh_names

Get all the sub-mesh (“dp”, “cp”, etc.) and flattened-mesh (“dp_cp”, etc.) names in the DeviceMesh. When only_submesh_dims=True, only checks for sub-mesh dimensions.

contains_submesh

Check if a sub-mesh exists in the device mesh by name.

_get_cuda_rng_state

Return the random number generator state of the specified GPU.

_set_cuda_rng_state

Sets the random number generator state of the current GPU.

initialize_rng_tracker

Create the RNG tracker. ‘use_te_rng_tracker’ determines whether to use Megatron or TransformerEngine’s implementation. In particular, TransformerEngine’s implementation is cudagraphable and supports FP8.

get_cuda_rng_tracker

Get cuda rng tracker.

get_global_memory_buffer

Return the global GlobalMemoryBuffer object.

create_updated_function_signature

Given a function, create a new version of the function with extended keyword-only arguments or parameters. Used to patch or extend methods in instances of a class.

is_mcore_tensor_model_parallel

Check if the given parameter is Megatron-Core tensor model parallel.

is_mcore_tensor_parallel_duplicated

Check if the given parameter is Megatron-Core tensor model parallel and duplicated.

get_mcore_tensor_parallel_partition_dim

Get the partition dimension for a Megatron-Core tensor model parallel parameter.

Data#

API#

core.distributed.fsdp.src.megatron_fsdp.utils.logger#

‘getLogger(…)’

core.distributed.fsdp.src.megatron_fsdp.utils._MODEL_PARALLEL_RNG_TRACKER_NAME#

‘model-parallel-rng’

core.distributed.fsdp.src.megatron_fsdp.utils.get_te_version()#

Get the TE version from __version__; if not available, fall back to pip's package metadata. The result is cached.

core.distributed.fsdp.src.megatron_fsdp.utils.is_te_min_version(vers, check_equality=True)#

Check if minimum version of transformer-engine is installed.
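A minimal usage sketch for get_te_version and is_te_min_version, assuming transformer-engine is installed. The megatron. import prefix is an assumption based on the standard Megatron-LM package layout, and the version threshold is illustrative.

```python
# Sketch: only enable a TE-dependent code path when a sufficiently new
# transformer-engine is installed. get_te_version() reads TE's version
# (falling back to pip metadata, per the docstring above) and caches it.
from megatron.core.distributed.fsdp.src.megatron_fsdp.utils import (
    get_te_version,
    is_te_min_version,
)

print("Installed TE version:", get_te_version())
use_te_feature = is_te_min_version("1.9.0")  # illustrative version threshold
```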

core.distributed.fsdp.src.megatron_fsdp.utils.is_submodule(module, parent_module, strict=True)#

Check if a module is a submodule of another module.
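A hedged sketch of the intended call pattern using a toy torch.nn model; whether strict excludes the trivial case where module is parent_module itself is an assumption, noted in the comments.

```python
import torch.nn as nn

from megatron.core.distributed.fsdp.src.megatron_fsdp.utils import is_submodule

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

print(is_submodule(model[0], model))  # True: the Linear layer is contained in model
# strict=True is assumed to exclude the trivial case `module is parent_module`;
# pass strict=False if that case should also count as a submodule.
print(is_submodule(model, model, strict=False))
```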

core.distributed.fsdp.src.megatron_fsdp.utils.get_mesh_names(
device_mesh: Optional[torch.distributed.DeviceMesh] = None,
only_submesh_dims: bool = False,
) list[str]#

Get all the sub-mesh (“dp”, “cp”, etc.) and flattened-mesh (“dp_cp”, etc.) names in the DeviceMesh. When only_submesh_dims=True, only checks for sub-mesh dimensions.

core.distributed.fsdp.src.megatron_fsdp.utils.contains_submesh(
device_mesh: Optional[torch.distributed.DeviceMesh],
submesh_names: Optional[str | Sequence[str]],
) bool#

Check if a sub-mesh exists in the device mesh by name.
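A sketch covering get_mesh_names and contains_submesh together, assuming a torchrun launch with 8 ranks and illustrative dimension names. DeviceMesh._flatten is the private PyTorch API used here to register a flattened mesh; it is not part of megatron_fsdp.utils.

```python
# Run under torchrun with 8 ranks; dimension names ("dp", "cp", "dp_cp")
# are illustrative.
from torch.distributed.device_mesh import init_device_mesh

from megatron.core.distributed.fsdp.src.megatron_fsdp.utils import (
    contains_submesh,
    get_mesh_names,
)

mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "cp"))
mesh["dp", "cp"]._flatten("dp_cp")  # register a flattened "dp_cp" mesh (private PyTorch API)

print(get_mesh_names(mesh))                          # sub-mesh and flattened-mesh names
print(get_mesh_names(mesh, only_submesh_dims=True))  # sub-mesh dimension names only
print(contains_submesh(mesh, "dp_cp"))               # True once "dp_cp" is registered
print(contains_submesh(mesh, ("dp", "cp")))          # a sequence of names can be checked at once
```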

core.distributed.fsdp.src.megatron_fsdp.utils._get_cuda_rng_state(
device: Union[int, str, torch.device] = 'cuda',
clone: bool = False,
graph_safe: bool = False,
) torch.Tensor#

Return the random number generator state of the specified GPU.

Parameters:
  • device (Union[int, str, torch.device]) – The GPU whose RNG state is retrieved.

  • clone (bool) – Whether to also clone the retrieved RNG state.

  • graph_safe (bool) – Retrieve the RNG state in a graph-safe manner.

This function is adapted from torch.cuda.random.get_rng_state().

core.distributed.fsdp.src.megatron_fsdp.utils._set_cuda_rng_state(
new_state: torch.Tensor,
device: int = -1,
graph_safe: bool = False,
)#

Sets the random number generator state of the current GPU.

Parameters:
  • new_state (torch.ByteTensor) – The desired RNG state.

  • device (int) – The GPU whose RNG state is set.

  • graph_safe (bool) – Set the RNG state in a graph-safe manner.

This function is adapted from the PyTorch repo (torch.cuda.set_rng_state) with a single change: the input state is not cloned. Cloning caused major performance issues for runs with 4+ GPUs.
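A sketch of the save/restore round trip these two private helpers enable, assuming a CUDA device is available; the assertion reflects the documented purpose (restoring a previously captured state reproduces the same random draws).

```python
import torch

from megatron.core.distributed.fsdp.src.megatron_fsdp.utils import (
    _get_cuda_rng_state,
    _set_cuda_rng_state,
)

# Capture the current GPU RNG state (clone=True so later draws cannot
# mutate the saved copy), draw some numbers, restore, and draw again.
state = _get_cuda_rng_state(device="cuda", clone=True)
a = torch.rand(4, device="cuda")

_set_cuda_rng_state(state, device=torch.cuda.current_device())
b = torch.rand(4, device="cuda")

assert torch.equal(a, b)  # same state -> same draws
```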

core.distributed.fsdp.src.megatron_fsdp.utils.initialize_rng_tracker(
use_te_rng_tracker: bool = False,
inference_rng_tracker: bool = False,
use_cudagraphable_rng: bool = False,
force_reset: bool = False,
)#

Create the RNG tracker. ‘use_te_rng_tracker’ determines whether to use Megatron or TransformerEngine’s implementation. In particular, TransformerEngine’s implementation is cudagraphable and supports FP8.

core.distributed.fsdp.src.megatron_fsdp.utils.get_cuda_rng_tracker(
use_te_rng_tracker: bool = False,
inference_rng_tracker: bool = False,
use_cudagraphable_rng: bool = False,
)#

Get cuda rng tracker.
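A sketch of the expected tracker workflow. The add()/fork() calls are assumed to follow Megatron-Core's CudaRNGStatesTracker interface; the state name and seed are illustrative, with _MODEL_PARALLEL_RNG_TRACKER_NAME above suggesting 'model-parallel-rng' as the conventional name.

```python
import torch

from megatron.core.distributed.fsdp.src.megatron_fsdp.utils import (
    get_cuda_rng_tracker,
    initialize_rng_tracker,
)

initialize_rng_tracker()        # build the tracker once per process
tracker = get_cuda_rng_tracker()

# add()/fork() follow Megatron-Core's CudaRNGStatesTracker interface (assumed
# here): register a named RNG state with its seed, then fork into it so draws
# inside the block use that state; the default state is restored on exit.
tracker.add("model-parallel-rng", seed=1234)
with tracker.fork("model-parallel-rng"):
    w = torch.empty(8, 8, device="cuda")
    torch.nn.init.normal_(w)
```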

class core.distributed.fsdp.src.megatron_fsdp.utils.FSDPDistributedIndex(
device_mesh: torch.distributed.DeviceMesh,
dp_shard_dim: Optional[str] = None,
dp_outer_dim: Optional[str] = None,
tp_dim: Optional[str] = None,
hybrid_fsdp_group: Optional[torch.distributed.ProcessGroup] = None,
hsdp_outer_dp_shard: bool = False,
expt_device_mesh: Optional[torch.distributed.DeviceMesh] = None,
)#

Class containing references to the process groups utilized by Megatron-FSDP.

This class tracks the device mesh and different process groups required for full-sharded data parallelism (FSDP), including support for hybrid and tensor/data parallel strategies.

Initialization

Parameters:
  • device_mesh (DeviceMesh) – The DeviceMesh to use for the DistributedIndex.

  • dp_shard_dim (Optional[str]) – The dimension name of the data parallel (and context parallel) sharding sub-mesh.

  • dp_outer_dim (Optional[str]) – The dimension name of the “outer” data parallel sub-mesh for replication or sharding when using HSDP.

  • tp_dim (Optional[str]) – The dimension name of the tensor parallel sub-mesh.

  • hybrid_fsdp_group (Optional[torch.distributed.ProcessGroup]) – The process group for hybrid FSDP communication, which is the flattened combination of the dp_outer and dp_shard process groups.

  • hsdp_outer_dp_shard (bool) – Whether to have outer DP group sharding in hybrid FSDP. Specifying outer sharding will lift the bucket sharding coordinate system to flattened ranks of (dp_shard, dp_outer) instead of just sharding across dp_shard ranks and replicating across dp_outer ranks.

  • expt_device_mesh (Optional[DeviceMesh]) – The expert parallel device mesh to use for the DistributedIndex.

expt_fsdp_group#

None

Megatron-FSDP is responsible for storing all required DeviceMesh objects, as per the best practices recommended by the DeviceMesh API.

NOTE(@cspades): In PyTorch 2.11, retrieving flattened mesh dimensions will be impossible via the device_mesh[…] API. We will require all users to correctly _unflatten() their DeviceMesh such that all dimensions used by Megatron-FSDP are sub-meshes of the DeviceMesh. contains_submesh(…) -> get_mesh_names(only_submesh_dims=True).

get_submesh(
mesh_dim_names: str | Sequence[str],
is_expert_parallel: bool = False,
) torch.distributed.DeviceMesh#

Retrieve a Megatron-FSDP-registered sub-mesh by name(s).

get_dp_group(
is_expert_parallel: bool = False,
) torch.distributed.ProcessGroup#

Get the data parallel process group.

get_fsdp_group(
is_expert_parallel: bool = False,
) torch.distributed.ProcessGroup#

Get the FSDP process group.

get_outer_fsdp_group() torch.distributed.ProcessGroup#

Get the outer-FSDP process group.

get_root_mesh(
is_expert_parallel: bool = False,
) torch.distributed.DeviceMesh#

Get the device mesh.

get_logical_hybrid_fsdp_rank()#

Returns the logical rank of the current process within the full-shard hybrid FSDP group.

In full-shard hybrid FSDP, parameters are first sharded across the inner data-parallel group, then across the outer data-parallel group. This changes the effective rank mapping compared to standard data parallelism. Use this method to get the correct rank index for the hybrid group.

Returns:

The index of the current process in the hybrid FSDP group.

Return type:

int

Raises:

AssertionError – If full-shard hybrid FSDP is not enabled.
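A construction sketch assuming a torchrun launch with 8 ranks and a simple FSDP + TP layout (no hybrid/outer data parallelism, so hybrid_fsdp_group is omitted); the mesh shape and dimension names are illustrative.

```python
from torch.distributed.device_mesh import init_device_mesh

from megatron.core.distributed.fsdp.src.megatron_fsdp.utils import FSDPDistributedIndex

# 4-way FSDP sharding x 2-way tensor parallelism. The names only need to
# match the dims passed to FSDPDistributedIndex below.
mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp_shard", "tp"))

dist_index = FSDPDistributedIndex(
    device_mesh=mesh,
    dp_shard_dim="dp_shard",
    tp_dim="tp",
)

fsdp_group = dist_index.get_fsdp_group()   # process group used for parameter sharding
dp_group = dist_index.get_dp_group()       # data-parallel process group
root_mesh = dist_index.get_root_mesh()     # the DeviceMesh passed at construction
tp_mesh = dist_index.get_submesh("tp")     # a registered sub-mesh, looked up by name
```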

class core.distributed.fsdp.src.megatron_fsdp.utils.GlobalMemoryBuffer#

Global buffer to avoid dynamic memory allocations. Caller should ensure that buffers of the same name are not used concurrently.

Initialization

get_tensor(
tensor_shape,
dtype,
name,
mem_alloc_context: Optional[Callable] = None,
)#

Return a tensor of the requested shape and dtype, potentially as a sub-tensor (view) of the named self.buffer allocation.

core.distributed.fsdp.src.megatron_fsdp.utils.get_global_memory_buffer()#

Return the global GlobalMemoryBuffer object.
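A usage sketch: fetch the process-wide buffer and carve a reusable scratch tensor out of it (assumes a CUDA device). The buffer name, shape, and dtype are illustrative; per the class docstring above, the same name must not be used for two buffers that are live at the same time.

```python
import torch

from megatron.core.distributed.fsdp.src.megatron_fsdp.utils import get_global_memory_buffer

buf = get_global_memory_buffer()

# Reuse one named allocation across iterations instead of allocating a fresh
# tensor every step (name/shape/dtype are illustrative).
scratch = buf.get_tensor((1024, 1024), torch.bfloat16, "all_gather_scratch")
scratch.zero_()
```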

core.distributed.fsdp.src.megatron_fsdp.utils.create_updated_function_signature(
original_function,
**extended_kwargs: dict,
)#

Given a function, create a new version of the function with extended keyword-only arguments or parameters. Used to patch or extend methods in instances of a class.

core.distributed.fsdp.src.megatron_fsdp.utils.is_mcore_tensor_model_parallel(param: torch.Tensor) bool#

Check if the given parameter is Megatron-Core tensor model parallel.

core.distributed.fsdp.src.megatron_fsdp.utils.is_mcore_tensor_parallel_duplicated(param: torch.Tensor) bool#

Check if the given parameter is Megatron-Core tensor model parallel and duplicated.

core.distributed.fsdp.src.megatron_fsdp.utils.get_mcore_tensor_parallel_partition_dim(
param: torch.Tensor,
) Optional[int]#

Get the partition dimension for a Megatron-Core tensor model parallel parameter.
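A sketch of the intended call pattern over a model's parameters. Here model is a hypothetical Megatron-Core module whose parameters already carry tensor-parallel metadata (as set by Megatron-Core's tensor-parallel layers); these helpers only read that metadata.

```python
from megatron.core.distributed.fsdp.src.megatron_fsdp.utils import (
    get_mcore_tensor_parallel_partition_dim,
    is_mcore_tensor_model_parallel,
    is_mcore_tensor_parallel_duplicated,
)

# `model` is a hypothetical Megatron-Core model; its parameters are assumed
# to carry Megatron-Core tensor-parallel attributes.
for name, param in model.named_parameters():
    if is_mcore_tensor_model_parallel(param):
        if is_mcore_tensor_parallel_duplicated(param):
            print(f"{name}: tensor-parallel, duplicated across TP ranks")
        else:
            dim = get_mcore_tensor_parallel_partition_dim(param)
            print(f"{name}: tensor-parallel, partitioned along dim {dim}")
    else:
        print(f"{name}: not tensor-parallel")
```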