core.distributed.fsdp.src.megatron_fsdp.utils#

Module Contents#

Classes#

FSDPDistributedIndex

Class containing references to the process groups utilized by Megatron-FSDP.

GlobalMemoryBuffer

Global buffer to avoid dynamic memory allocations. Caller should ensure that buffers of the same name are not used concurrently.

Functions#

get_te_version

Get the TE version from __version__; if not available, fall back to pip’s package metadata. Results are cached.

is_te_min_version

Check if minimum version of transformer-engine is installed.

is_submodule

Check if a module is a submodule of another module.

is_float8tensor

Check if a tensor is a Transformer Engine Float8Tensor.

get_mesh_names

Get all the sub-mesh (“dp”, “cp”, etc.) and flattened-mesh (“dp_cp”, etc.) names in the DeviceMesh. When only_submesh_dims=True, only sub-mesh dimension names are returned.

contains_submesh

Check if a sub-mesh exists in the device mesh by name.

_multi_tensor_copy_this_to_that

Use multi-tensor-applier to copy values from one list of tensors to another. There is no bfloat16 implementation, so for now, if overflow_buf is not provided, we fall back to a simple loop copy for bfloat16 compatibility.

modify_underlying_storage

Replace the underlying raw data of a tensor with new data.

quantize_param_shard

Cast sharded fp32 main parameters to fp8 model parameters.

_get_cuda_rng_state

Return the random number generator state of the specified GPU.

_set_cuda_rng_state

Set the random number generator state of the current GPU.

initialize_rng_tracker

Create the RNG tracker. The ‘use_te_rng_tracker’ flag determines whether to use Megatron’s or TransformerEngine’s implementation. In particular, TransformerEngine’s implementation is cudagraphable and supports FP8.

get_cuda_rng_tracker

Get cuda rng tracker.

get_global_memory_buffer

Return the global GlobalMemoryBuffer object.

create_updated_function_signature

Given a function, create a new version of the function with extended keyword-only arguments or parameters. Used to patch or extend methods in instances of a class.

is_mcore_tensor_model_parallel

Check if the given parameter is Megatron-Core tensor model parallel.

is_mcore_tensor_parallel_duplicated

Check if the given parameter is Megatron-Core tensor model parallel and duplicated.

get_mcore_tensor_parallel_partition_dim

Get the partition dimension for a Megatron-Core tensor model parallel parameter.

Data#

API#

core.distributed.fsdp.src.megatron_fsdp.utils.logger#

‘getLogger(…)’

core.distributed.fsdp.src.megatron_fsdp.utils._MODEL_PARALLEL_RNG_TRACKER_NAME#

‘model-parallel-rng’

core.distributed.fsdp.src.megatron_fsdp.utils.get_te_version()#

Get the TE version from __version__; if not available, fall back to pip’s package metadata. Results are cached.

core.distributed.fsdp.src.megatron_fsdp.utils.is_te_min_version(vers, check_equality=True)#

Check if minimum version of transformer-engine is installed.
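
A minimal usage sketch (the import path below is an assumption and may differ depending on how megatron_fsdp is installed in your environment):

    # Guard TE-version-specific code paths behind a version check.
    # NOTE: the import path is assumed; adjust it to your installation.
    from megatron_fsdp.utils import is_te_min_version

    if is_te_min_version("2.0"):
        ...  # rely on TE 2.x behavior (e.g. QuantizedTensor-based FP8 classes)
    else:
        ...  # fall back to the TE 1.x code path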

core.distributed.fsdp.src.megatron_fsdp.utils.is_submodule(module, parent_module, strict=True)#

Check if a module is a submodule of another module.
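
A small illustration (import path assumed; the exact semantics of the strict flag are not documented here, so the example relies only on the basic containment check):

    import torch.nn as nn
    # NOTE: the import path is assumed; adjust it to your installation.
    from megatron_fsdp.utils import is_submodule

    parent = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
    child = parent[0]

    assert is_submodule(child, parent)  # the Linear layer is contained in the Sequential parent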

core.distributed.fsdp.src.megatron_fsdp.utils.is_float8tensor(tensor: torch.Tensor) → bool#

Check if a tensor is a Transformer Engine Float8Tensor.

Note that in TE 2.x, the design of the FP8 tensor class changed in order to support more recipes: Float8Tensor is now used only for current scaling and delayed scaling, while mxfp8 and blockwise scaling have their own FP8 tensor classes, all of which inherit from QuantizedTensor. So, for TE 1.x, FP8_TENSOR_CLASS is Float8Tensor, and for TE 2.x, FP8_TENSOR_CLASS is QuantizedTensor.
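
As a rough illustration of the note above, the version-dependent class selection might look like the following sketch. The Transformer Engine import locations are assumptions, not a statement of this module’s actual implementation:

    # Illustrative sketch only: select the FP8 base class by TE major version.
    # NOTE: import paths (megatron_fsdp.utils and the TE modules) are assumptions.
    from megatron_fsdp.utils import is_te_min_version

    try:
        if is_te_min_version("2.0"):
            # TE 2.x: current/delayed scaling, mxfp8, and blockwise scaling classes
            # all inherit from QuantizedTensor.
            from transformer_engine.pytorch.tensor import QuantizedTensor as FP8_TENSOR_CLASS
        else:
            # TE 1.x: Float8Tensor is the single FP8 tensor class.
            from transformer_engine.pytorch.float8_tensor import Float8Tensor as FP8_TENSOR_CLASS
    except ImportError:
        FP8_TENSOR_CLASS = None  # Transformer Engine is not installed.

    def looks_like_fp8(tensor) -> bool:
        """Approximation of is_float8tensor() under the assumptions above."""
        return FP8_TENSOR_CLASS is not None and isinstance(tensor, FP8_TENSOR_CLASS)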

core.distributed.fsdp.src.megatron_fsdp.utils.get_mesh_names(
device_mesh: Optional[torch.distributed.DeviceMesh] = None,
only_submesh_dims: bool = False,
) → list[str]#

Get all the sub-mesh (“dp”, “cp”, etc.) and flattened-mesh (“dp_cp”, etc.) names in the DeviceMesh. When only_submesh_dims=True, only sub-mesh dimension names are returned.

core.distributed.fsdp.src.megatron_fsdp.utils.contains_submesh(
device_mesh: Optional[torch.distributed.DeviceMesh],
submesh_names: Optional[str | Sequence[str]],
) → bool#

Check if a sub-mesh exists in the device mesh by name.
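
A hedged example covering get_mesh_names and contains_submesh (import path assumed; requires an initialized torch.distributed environment whose world size matches the mesh shape):

    from torch.distributed.device_mesh import init_device_mesh
    # NOTE: the import path is assumed; adjust it to your installation.
    from megatron_fsdp.utils import contains_submesh, get_mesh_names

    # 2-D mesh with named sub-meshes "dp" and "tp" (world size 8 assumed).
    mesh = init_device_mesh("cuda", (2, 4), mesh_dim_names=("dp", "tp"))

    print(get_mesh_names(mesh))          # sub-mesh names, e.g. ["dp", "tp"], plus any flattened names
    print(contains_submesh(mesh, "dp"))  # True
    print(contains_submesh(mesh, "pp"))  # False: no such dimension in this mesh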

core.distributed.fsdp.src.megatron_fsdp.utils._multi_tensor_copy_this_to_that(
this: List[torch.Tensor],
that: List[torch.Tensor],
overflow_buf: Optional[torch.Tensor] = None,
)#

Use multi-tensor-applier to copy values from one list of tensors to another. There is no bfloat16 implementation, so for now, if overflow_buf is not provided, we fall back to a simple loop copy for bfloat16 compatibility.

core.distributed.fsdp.src.megatron_fsdp.utils.modify_underlying_storage(
tensor: torch.Tensor,
new_raw_data: torch.Tensor,
)#

Replace the underlying raw data of a tensor with new data.

core.distributed.fsdp.src.megatron_fsdp.utils.quantize_param_shard(
model_params,
main_params,
start_offsets,
data_parallel_group,
fsdp_shard_model_params=None,
)#

Cast sharded fp32 main parameters to fp8 model parameters.

core.distributed.fsdp.src.megatron_fsdp.utils._get_cuda_rng_state(
device: Union[int, str, torch.device] = 'cuda',
clone: bool = False,
graph_safe: bool = False,
) → torch.Tensor#

Return the random number generator state of the specified GPU.

Parameters:
  • device (int) – The GPU from which to retrieve the RNG state.

  • clone (bool) – Whether to also clone the retrieved RNG state.

  • graph_safe (bool) – Whether to get the RNG state in a graph-safe manner.

This function is adapted from torch.cuda.random.get_rng_state()

core.distributed.fsdp.src.megatron_fsdp.utils._set_cuda_rng_state(
new_state: torch.Tensor,
device: int = -1,
graph_safe: bool = False,
)#

Set the random number generator state of the current GPU.

Parameters:
  • new_state (torch.ByteTensor) – The desired state.

  • device (int) – The GPU whose RNG state should be set.

  • graph_safe (bool) – Whether to set the RNG state in a graph-safe manner.

This function is adapted from the PyTorch repo (torch.cuda.set_rng_state) with a single change: the input state is not cloned. Cloning caused major performance issues for cases with 4+ GPUs.

core.distributed.fsdp.src.megatron_fsdp.utils.initialize_rng_tracker(
use_te_rng_tracker: bool = False,
inference_rng_tracker: bool = False,
use_cudagraphable_rng: bool = False,
force_reset: bool = False,
)#

Create the RNG tracker. The ‘use_te_rng_tracker’ flag determines whether to use Megatron’s or TransformerEngine’s implementation. In particular, TransformerEngine’s implementation is cudagraphable and supports FP8.

core.distributed.fsdp.src.megatron_fsdp.utils.get_cuda_rng_tracker(
use_te_rng_tracker: bool = False,
inference_rng_tracker: bool = False,
use_cudagraphable_rng: bool = False,
)#

Get cuda rng tracker.
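
A usage sketch, assuming the returned tracker exposes the usual Megatron-style interface of add(name, seed) and a fork() context manager (import path assumed):

    import torch
    # NOTE: the import path is assumed; adjust it to your installation.
    from megatron_fsdp.utils import get_cuda_rng_tracker, initialize_rng_tracker

    initialize_rng_tracker()          # Megatron implementation by default
    tracker = get_cuda_rng_tracker()
    # add/fork interface assumed; the name matches _MODEL_PARALLEL_RNG_TRACKER_NAME above.
    tracker.add("model-parallel-rng", 1234)

    with tracker.fork("model-parallel-rng"):
        # Random ops here draw from the tracked state, leaving the default CUDA RNG state untouched.
        noise = torch.randn(4, 4, device="cuda")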

class core.distributed.fsdp.src.megatron_fsdp.utils.FSDPDistributedIndex(
device_mesh: torch.distributed.DeviceMesh,
dp_shard_dim: Optional[str] = None,
dp_outer_dim: Optional[str] = None,
tp_dim: Optional[str] = None,
hybrid_fsdp_group: Optional[torch.distributed.ProcessGroup] = None,
hsdp_outer_dp_shard: bool = False,
expt_device_mesh: Optional[torch.distributed.DeviceMesh] = None,
)#

Class containing references to the process groups utilized by Megatron-FSDP.

This class tracks the device mesh and different process groups required for full-sharded data parallelism (FSDP), including support for hybrid and tensor/data parallel strategies.

Initialization

Parameters:
  • device_mesh (DeviceMesh) – The DeviceMesh to use for the DistributedIndex.

  • dp_shard_dim (Optional[str]) – The dimension name of the data parallel (and context parallel) sharding sub-mesh.

  • dp_outer_dim (Optional[str]) – The dimension name of the “outer” data parallel sub-mesh for replication or sharding when using HSDP.

  • tp_dim (Optional[str]) – The dimension name of the tensor parallel sub-mesh.

  • hybrid_fsdp_group (Optional[torch.distributed.ProcessGroup]) – The process group for hybrid FSDP communication, which is the flattened combination of the dp_outer and dp_shard process groups.

  • hsdp_outer_dp_shard (bool) – Whether to have outer DP group sharding in hybrid FSDP. Specifying outer sharding will lift the bucket sharding coordinate system to flattened ranks of (dp_shard, dp_outer) instead of just sharding across dp_shard ranks and replicating across dp_outer ranks.

  • expt_device_mesh (Optional[DeviceMesh]) – The expert parallel device mesh to use for the DistributedIndex.

expt_fsdp_group#

None

Megatron-FSDP is responsible for storing all required DeviceMesh objects, as per the best practices recommended by the DeviceMesh API.

NOTE(@cspades): In PyTorch 2.11, retrieving flattened mesh dimensions will be impossible via the device_mesh[…] API. We will require all users to correctly _unflatten() their DeviceMesh such that all dimensions used by Megatron-FSDP are sub-meshes of the DeviceMesh. contains_submesh(…) -> get_mesh_names(only_submesh_dims=True).

get_submesh(
mesh_dim_names: str | Sequence[str],
is_expert_parallel: bool = False,
) → torch.distributed.DeviceMesh#

Retrieve a Megatron-FSDP-registered submesh by name(s).

get_dp_group(
is_expert_parallel: bool = False,
) → torch.distributed.ProcessGroup#

Get the data parallel process group.

get_fsdp_group(
is_expert_parallel: bool = False,
) → torch.distributed.ProcessGroup#

Get the FSDP process group.

get_outer_fsdp_group() → torch.distributed.ProcessGroup#

Get the outer-FSDP process group.

get_root_mesh(
is_expert_parallel: bool = False,
) → torch.distributed.DeviceMesh#

Get the device mesh.

get_logical_hybrid_fsdp_rank()#

Returns the logical rank of the current process within the full-shard hybrid FSDP group.

In full-shard hybrid FSDP, parameters are first sharded across the inner data-parallel group, then across the outer data-parallel group. This changes the effective rank mapping compared to standard data parallelism. Use this method to get the correct rank index for the hybrid group.

Returns:

The index of the current process in the hybrid FSDP group.

Return type:

int

Raises:

AssertionError – If full-shard hybrid FSDP is not enabled.
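
A minimal construction sketch for FSDPDistributedIndex, assuming an initialized torch.distributed environment and a DeviceMesh whose dimension names match the arguments (the import path and dimension names below are illustrative):

    from torch.distributed.device_mesh import init_device_mesh
    # NOTE: the import path is assumed; adjust it to your installation.
    from megatron_fsdp.utils import FSDPDistributedIndex

    # Example: 4-way FSDP sharding combined with 2-way tensor parallelism (world size 8 assumed).
    mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp_shard", "tp"))

    dist_index = FSDPDistributedIndex(
        device_mesh=mesh,
        dp_shard_dim="dp_shard",
        tp_dim="tp",
    )

    dp_group = dist_index.get_dp_group()      # data parallel process group
    fsdp_group = dist_index.get_fsdp_group()  # FSDP (sharding) process group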

class core.distributed.fsdp.src.megatron_fsdp.utils.GlobalMemoryBuffer#

Global buffer to avoid dynamic memory allocations. Caller should ensure that buffers of the same name are not used concurrently.

Initialization

get_tensor(
tensor_shape,
dtype,
name,
mem_alloc_context: Optional[Callable] = None,
)#

Return a tensor of the requested shape and dtype, potentially as a sub-tensor of the buffer cached under the given name.

core.distributed.fsdp.src.megatron_fsdp.utils.get_global_memory_buffer()#

Return the global GlobalMemoryBuffer object.
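
A usage sketch (the import path and the buffer name are illustrative):

    import torch
    # NOTE: the import path is assumed; adjust it to your installation.
    from megatron_fsdp.utils import get_global_memory_buffer

    # Reuse a named workspace instead of allocating a fresh tensor on every call.
    workspace = get_global_memory_buffer().get_tensor((1024, 1024), torch.bfloat16, "fsdp-workspace")
    # 'workspace' may be a view into a cached allocation; do not use the same buffer name
    # concurrently from overlapping code paths (see the GlobalMemoryBuffer note above).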

core.distributed.fsdp.src.megatron_fsdp.utils.create_updated_function_signature(
original_function,
**extended_kwargs: dict,
)#

Given a function, create a new version of the function with extended keyword-only arguments or parameters. Used to patch or extend methods in instances of a class.

core.distributed.fsdp.src.megatron_fsdp.utils.is_mcore_tensor_model_parallel(param: torch.Tensor) → bool#

Check if the given parameter is Megatron-Core tensor model parallel.

core.distributed.fsdp.src.megatron_fsdp.utils.is_mcore_tensor_parallel_duplicated(param: torch.Tensor) → bool#

Check if the given parameter is Megatron-Core tensor model parallel and duplicated.

core.distributed.fsdp.src.megatron_fsdp.utils.get_mcore_tensor_parallel_partition_dim(
param: torch.Tensor,
) → Optional[int]#

Get the partition dimension for a Megatron-Core tensor model parallel parameter.
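
These helpers read Megatron-Core’s tensor-parallel metadata from parameters. As a rough illustration only: Megatron-Core conventionally tags tensor-parallel parameters with attributes such as tensor_model_parallel and partition_dim, and the sketch below approximates what the helpers report under that assumption (the attribute names are assumptions, not a specification of this module):

    import torch

    # Illustrative sketch only: tag a parameter the way Megatron-Core conventionally does.
    param = torch.nn.Parameter(torch.empty(16, 8))
    param.tensor_model_parallel = True  # marked as tensor model parallel
    param.partition_dim = 0             # sharded along dim 0 (column-parallel style)

    # Approximations of is_mcore_tensor_model_parallel() and
    # get_mcore_tensor_parallel_partition_dim() under the assumption above.
    is_tp = getattr(param, "tensor_model_parallel", False)
    partition_dim = getattr(param, "partition_dim", None) if is_tp else None
    print(is_tp, partition_dim)  # True 0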