nemo_automodel.components.distributed.mesh_utils
Device mesh creation utilities for distributed training.
This module provides a central function to create device meshes based on the distributed config type (FSDP2, MegatronFSDP, or DDP).
Usage:
from nemo_automodel.components.distributed.config import FSDP2Config
from nemo_automodel.components.distributed.mesh_utils import create_device_mesh

config = FSDP2Config(sequence_parallel=True)
device_mesh, moe_mesh = create_device_mesh(
    config,
    tp_size=2,
    pp_size=1,
    dp_replicate_size=2,
    world_size=8,
)
Module Contents
Functions
create_device_mesh | Create device mesh based on distributed config type.
_create_fsdp2_device_mesh | Create device mesh for FSDP2.
_create_megatron_fsdp_device_mesh | Create device mesh for MegatronFSDP.
_unflatten_compat | Compatibility shim for DeviceMesh._unflatten(), which was added in PyTorch 2.10.
get_flat_mesh | Access a 1D submesh by parallelism name (e.g. "dp", "tp", "dp_cp").
get_submesh | Access a submesh by parallelism dim names.
get_fsdp_dp_mesh | Return the DP mesh for FSDP2 without losing the original root mesh.
API
- nemo_automodel.components.distributed.mesh_utils.create_device_mesh(
- distributed_config: Union[nemo_automodel.components.distributed.config.FSDP2Config, nemo_automodel.components.distributed.config.MegatronFSDPConfig, nemo_automodel.components.distributed.config.DDPConfig],
- *,
- dp_size: Optional[int] = None,
- dp_replicate_size: Optional[int] = None,
- tp_size: int = 1,
- pp_size: int = 1,
- cp_size: int = 1,
- ep_size: int = 1,
- world_size: int,
Create device mesh based on distributed config type.
Routes to the appropriate mesh creation logic based on config type.
- Parameters:
distributed_config – The distributed config (FSDP2Config, MegatronFSDPConfig, or DDPConfig).
dp_size – Data parallel size. If None, inferred from world_size and other parallelism sizes.
dp_replicate_size – FSDP2-only. Size of the replication group for HSDP (Hybrid Sharded Data Parallel). If None or <= 0, defaults to 1. Must be a divisor of dp_size.
tp_size – Tensor parallel size.
pp_size – Pipeline parallel size.
cp_size – Context parallel size.
ep_size – Expert parallel size (for MoE models).
world_size – Total number of processes.
- Returns:
(device_mesh, moe_mesh)
  - For FSDP2Config: full device mesh plus an optional moe_mesh (if ep_size > 1).
  - For MegatronFSDPConfig: device mesh plus None.
  - For DDPConfig: (None, None), because DDP doesn't use a device mesh.
- Return type:
tuple
- Raises:
ValueError – If dp_replicate_size is provided with non-FSDP2 config.
ValueError – If world_size is not divisible by parallelism sizes.
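The size rules in the parameter list above can be sketched in plain Python. This is a minimal illustration, not the module's implementation: the helper names `infer_dp_size` and `normalize_dp_replicate` are hypothetical, and the inference formula (world_size divided by tp_size * pp_size * cp_size, with ep_size left out of the product) is an assumption. Only the "None or <= 0 defaults to 1" and "must divide dp_size" rules come directly from the parameter descriptions.

```python
from typing import Optional


def infer_dp_size(
    world_size: int,
    dp_size: Optional[int] = None,
    *,
    tp_size: int = 1,
    pp_size: int = 1,
    cp_size: int = 1,
) -> int:
    """Hypothetical helper: infer dp_size when it is not given.

    Assumption: dp_size = world_size / (tp_size * pp_size * cp_size).
    """
    if dp_size is not None:
        # An explicit value is passed through unchanged in this sketch.
        return dp_size
    model_parallel = tp_size * pp_size * cp_size
    if world_size % model_parallel != 0:
        raise ValueError(
            f"world_size={world_size} is not divisible by "
            f"tp*pp*cp={model_parallel}"
        )
    return world_size // model_parallel


def normalize_dp_replicate(dp_size: int, dp_replicate_size: Optional[int]) -> int:
    """Hypothetical helper: apply the documented dp_replicate_size rules."""
    if dp_replicate_size is None or dp_replicate_size <= 0:
        return 1  # documented default
    if dp_size % dp_replicate_size != 0:
        raise ValueError("dp_replicate_size must be a divisor of dp_size")
    return dp_replicate_size


# Same sizes as the module-level usage example: 8 ranks, tp=2, pp=1.
dp = infer_dp_size(8, tp_size=2, pp_size=1)
replicate = normalize_dp_replicate(dp, 2)
print(dp, replicate)
```

Under these assumptions, the usage example at the top of the page resolves to a data parallel size of 4, split into 2 replicas of 2 shards each.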
- nemo_automodel.components.distributed.mesh_utils._create_fsdp2_device_mesh(
- dp_size: Optional[int],
- dp_replicate_size: Optional[int],
- tp_size: int,
- pp_size: int,
- cp_size: int,
- ep_size: int,
- world_size: int,
- backend: str,
Create device mesh for FSDP2.
Mesh shape: (pp_size, dp_replicate_size, dp_shard_size, cp_size, tp_size)
Mesh names: ("pp", "dp_replicate", "dp_shard", "cp", "tp")
Also creates flattened submeshes:
  - "dp": dp_replicate + dp_shard
  - "dp_shard_cp": dp_shard + cp
  - "dp_cp": dp_replicate + dp_shard + cp
- Parameters:
dp_size – Data parallel size. If None, inferred from world_size.
dp_replicate_size – Size of the replication group for HSDP.
tp_size – Tensor parallel size.
pp_size – Pipeline parallel size.
cp_size – Context parallel size.
ep_size – Expert parallel size (for MoE models).
world_size – Total number of processes.
backend – Distributed backend ('nccl' or 'gloo').
- Returns:
(device_mesh, moe_mesh)
- Return type:
tuple
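To make the mesh layout concrete, the sketch below computes the named dimension sizes and the sizes of the flattened submeshes in plain Python. It does not create any process groups. The function name `fsdp2_mesh_layout` is hypothetical, and deriving dp_shard_size as dp_size // dp_replicate_size is an assumption consistent with the divisor requirement on dp_replicate_size.

```python
from typing import Dict, Tuple

FSDP2_MESH_NAMES = ("pp", "dp_replicate", "dp_shard", "cp", "tp")


def fsdp2_mesh_layout(
    dp_size: int,
    dp_replicate_size: int,
    tp_size: int,
    pp_size: int,
    cp_size: int,
) -> Tuple[Dict[str, int], Dict[str, int]]:
    """Hypothetical helper: return (native dim sizes, flattened dim sizes)."""
    # Assumption: the shard group is what remains of dp after replication.
    dp_shard_size = dp_size // dp_replicate_size
    shape = (pp_size, dp_replicate_size, dp_shard_size, cp_size, tp_size)
    native = dict(zip(FSDP2_MESH_NAMES, shape))

    # A flattened submesh spans the product of the dims it combines.
    flattened = {
        "dp": native["dp_replicate"] * native["dp_shard"],
        "dp_shard_cp": native["dp_shard"] * native["cp"],
        "dp_cp": native["dp_replicate"] * native["dp_shard"] * native["cp"],
    }
    return native, flattened


native, flattened = fsdp2_mesh_layout(
    dp_size=4, dp_replicate_size=2, tp_size=2, pp_size=1, cp_size=1
)
print(native)     # sizes keyed by mesh name, in mesh order
print(flattened)  # sizes of the "dp", "dp_shard_cp", and "dp_cp" submeshes
```

With cp_size=1 the "dp" and "dp_cp" submeshes have the same size, which is why they only differ once context parallelism is enabled.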
- nemo_automodel.components.distributed.mesh_utils._create_megatron_fsdp_device_mesh(
- dp_size: Optional[int],
- tp_size: int,
- cp_size: int,
- world_size: int,
- backend: str,
Create device mesh for MegatronFSDP.
Mesh shape: (dp_size, cp_size, tp_size)
Mesh names: ("dp", "cp", "tp")
Also creates flattened submesh "dp_cp" if cp_size > 1.
- Parameters:
dp_size – Data parallel size. If None, inferred from world_size.
tp_size – Tensor parallel size.
cp_size – Context parallel size.
world_size – Total number of processes.
backend – Distributed backend ('nccl' or 'gloo').
- Returns:
The device mesh for MegatronFSDP.
- Return type:
DeviceMesh
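As a rough aid for reading the (dp, cp, tp) shape, the sketch below maps a global rank to its coordinates in such a mesh. It assumes ranks are laid out in row-major order over the mesh shape, which is how a mesh built from consecutive ranks is normally arranged; whether this module uses that layout is not stated here, so treat it as an assumption. The helper name `rank_to_coords` is hypothetical.

```python
from typing import Dict, Sequence, Tuple

MEGATRON_FSDP_MESH_NAMES = ("dp", "cp", "tp")


def rank_to_coords(rank: int, shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: row-major unravel of a rank into mesh coordinates."""
    coords = []
    # Peel off the fastest-varying (last) dimension first.
    for size in reversed(shape):
        coords.append(rank % size)
        rank //= size
    return tuple(reversed(coords))


def describe_rank(rank: int, dp_size: int, cp_size: int, tp_size: int) -> Dict[str, int]:
    """Pair each coordinate with its mesh dimension name."""
    coords = rank_to_coords(rank, (dp_size, cp_size, tp_size))
    return dict(zip(MEGATRON_FSDP_MESH_NAMES, coords))


# Rank 5 in a 2 x 2 x 2 mesh: 5 = 1*4 + 0*2 + 1.
print(describe_rank(5, dp_size=2, cp_size=2, tp_size=2))
```

Under the row-major assumption, "tp" varies fastest, so ranks in the same tensor parallel group are adjacent.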
- nemo_automodel.components.distributed.mesh_utils._unflatten_compat(
- flat_mesh: torch.distributed.device_mesh.DeviceMesh,
- dim: int,
- sizes: tuple,
- names: tuple,
Compatibility shim for DeviceMesh._unflatten(), which was added in PyTorch 2.10.
Reconstructs a multi-dimensional mesh from a flat mesh by reshaping its rank tensor.
dim must be 0 (the only case used in this codebase).
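The reshaping idea can be illustrated without PyTorch by treating the flat mesh as a list of global ranks. This is only a sketch of "reshape the rank tensor", using nested lists in place of a tensor; the function name `unflatten_ranks` is hypothetical and it does not construct a DeviceMesh or any process groups.

```python
from math import prod
from typing import List, Sequence, Union

Nested = Union[int, List["Nested"]]


def unflatten_ranks(flat_ranks: Sequence[int], sizes: Sequence[int]) -> List[Nested]:
    """Hypothetical helper: reshape a flat rank list into nested lists (row-major)."""
    if prod(sizes) != len(flat_ranks):
        raise ValueError(
            f"sizes {tuple(sizes)} do not multiply to {len(flat_ranks)} ranks"
        )

    def build(offset: int, dims: Sequence[int]) -> List[Nested]:
        if len(dims) == 1:
            # Innermost dimension: a contiguous run of ranks.
            return list(flat_ranks[offset : offset + dims[0]])
        stride = prod(dims[1:])  # ranks covered by one slice of the outer dim
        return [build(offset + i * stride, dims[1:]) for i in range(dims[0])]

    return build(0, tuple(sizes))


# A flat 4-rank mesh (e.g. ranks with the same tp coordinate) split into 2 x 2.
print(unflatten_ranks([0, 2, 4, 6], (2, 2)))  # [[0, 2], [4, 6]]
```

The names tuple from the real signature would label the resulting dimensions; it is omitted here because the sketch only tracks rank layout.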
- nemo_automodel.components.distributed.mesh_utils.get_flat_mesh(
- device_mesh: torch.distributed.device_mesh.DeviceMesh,
- name: str,
Access a 1D submesh by parallelism name (e.g. "dp", "tp", "dp_cp").
PyTorch 2.11 deprecates root_mesh["name"] for dimensions created via _flatten(). This reads the _flatten() result directly.
- Parameters:
device_mesh – Any DeviceMesh (root or submesh).
name – Parallelism dimension name.
- nemo_automodel.components.distributed.mesh_utils.get_submesh(
- device_mesh: torch.distributed.device_mesh.DeviceMesh,
- names: tuple,
Access a submesh by parallelism dim names.
Handles all cases: single dims, multi-dim slices, and combinations that include _flatten()-created dims (e.g. ("dp_replicate", "dp_shard_cp")). For the latter, finds the parent _flatten() result and calls _unflatten() to decompose it into the requested shape.
- Parameters:
device_mesh – Any DeviceMesh (root or submesh).
names – Tuple of dimension names.
- nemo_automodel.components.distributed.mesh_utils.get_fsdp_dp_mesh(
- device_mesh: torch.distributed.device_mesh.DeviceMesh,
- dp_replicate_name: str = MeshAxisName.DP_REPLICATE,
- dp_shard_cp_name: str = MeshAxisName.DP_SHARD_CP,
Return the DP mesh for FSDP2 without losing the original root mesh.
get_submesh() may rebuild a fresh DeviceMesh when asked to compose native and flattened dims like ("dp_replicate", "dp_shard_cp"). That is fine for many local operations, but FSDP2 expects its DP mesh to share the same root mesh as TP/EP meshes. On multi-node TP runs a rebuilt mesh can break group construction in non-obvious ways.
Prefer native dimensions whenever possible:
  - cp=1, dp_replicate=1 -> device_mesh["dp_shard"]
  - cp=1, dp_replicate>1 -> device_mesh[("dp_replicate", "dp_shard")]
  - cp>1, dp_replicate=1 -> device_mesh[("dp_shard", "cp")]
When both CP and replicated DP are active we fall back to get_submesh() because the composed mesh is genuinely multi-level.
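The selection rules above amount to a small decision table. The sketch below encodes it over dimension sizes only and returns which dim names would be requested, without touching a real DeviceMesh. The function name `select_fsdp_dp_dims` and the `(names, uses_get_submesh)` return shape are hypothetical; the fallback names match the documented defaults dp_replicate and dp_shard_cp.

```python
from typing import Tuple


def select_fsdp_dp_dims(cp_size: int, dp_replicate_size: int) -> Tuple[Tuple[str, ...], bool]:
    """Hypothetical helper: return (dim names, whether get_submesh() is needed)."""
    if cp_size == 1 and dp_replicate_size == 1:
        # Plain FSDP: a single native dimension.
        return ("dp_shard",), False
    if cp_size == 1:
        # HSDP without CP: two native dimensions sliced from the root mesh.
        return ("dp_replicate", "dp_shard"), False
    if dp_replicate_size == 1:
        # CP without replication: still only native dimensions.
        return ("dp_shard", "cp"), False
    # CP and replication together: a native dim composed with a flattened dim,
    # which is the case that falls back to get_submesh().
    return ("dp_replicate", "dp_shard_cp"), True


for cp, rep in [(1, 1), (1, 2), (2, 1), (2, 2)]:
    names, fallback = select_fsdp_dp_dims(cp, rep)
    print(f"cp={cp} dp_replicate={rep} -> {names} fallback={fallback}")
```

Only the last branch leaves the root mesh, which is the trade-off the description calls out for runs that combine CP with replicated DP.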