nemo_automodel.components.distributed.mesh_utils

Device mesh creation utilities for distributed training.

This module provides a central function to create device meshes based on the distributed config type (FSDP2, MegatronFSDP, or DDP).

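Usage:

from nemo_automodel.components.distributed.config import FSDP2Config
from nemo_automodel.components.distributed.mesh_utils import create_device_mesh

# 8 ranks with TP=2, PP=1 and CP=1 leave dp_size = 8 // 2 = 4,
# which HSDP splits into dp_replicate=2 x dp_shard=2.
config = FSDP2Config(sequence_parallel=True)
device_mesh, moe_mesh = create_device_mesh(
    config,
    tp_size=2,
    pp_size=1,
    dp_replicate_size=2,
    world_size=8,
)
# moe_mesh is None here because ep_size defaults to 1.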

Module Contents

Functions

create_device_mesh

Create device mesh based on distributed config type.

_create_fsdp2_device_mesh

Create device mesh for FSDP2.

_create_megatron_fsdp_device_mesh

Create device mesh for MegatronFSDP.

_unflatten_compat

Compatibility shim for DeviceMesh._unflatten(), which was added in PyTorch 2.10.

get_flat_mesh

Access a 1D submesh by parallelism name (e.g. "dp", "tp", "dp_cp").

get_submesh

Access a submesh by parallelism dim names.

get_fsdp_dp_mesh

Return the DP mesh for FSDP2 without losing the original root mesh.

API

nemo_automodel.components.distributed.mesh_utils.create_device_mesh(
distributed_config: Union[nemo_automodel.components.distributed.config.FSDP2Config, nemo_automodel.components.distributed.config.MegatronFSDPConfig, nemo_automodel.components.distributed.config.DDPConfig],
*,
dp_size: Optional[int] = None,
dp_replicate_size: Optional[int] = None,
tp_size: int = 1,
pp_size: int = 1,
cp_size: int = 1,
ep_size: int = 1,
world_size: int,
) -> Tuple[Optional[torch.distributed.device_mesh.DeviceMesh], Optional[torch.distributed.device_mesh.DeviceMesh]]

Create device mesh based on distributed config type.

Routes to the appropriate mesh creation logic based on config type.

Parameters:
  • distributed_config – The distributed config (FSDP2Config, MegatronFSDPConfig, or DDPConfig).

  • dp_size – Data parallel size. If None, inferred from world_size and other parallelism sizes.

  • dp_replicate_size – FSDP2-only. Size of the replication group for HSDP (Hybrid Sharded Data Parallel). If None or <= 0, defaults to 1. Must be a divisor of dp_size.

  • tp_size – Tensor parallel size.

  • pp_size – Pipeline parallel size.

  • cp_size – Context parallel size.

  • ep_size – Expert parallel size (for MoE models).

  • world_size – Total number of processes.
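The sizing rules above can be sketched in plain Python. This is a minimal sketch, assuming that dp_size is whatever is left of world_size after TP, PP and CP, which matches the product of the FSDP2 mesh shape documented for _create_fsdp2_device_mesh. ep_size is left out because that shape has no EP dimension. `infer_fsdp2_dp_sizes` is a hypothetical helper, not part of mesh_utils.

```python
from typing import Optional, Tuple


def infer_fsdp2_dp_sizes(
    world_size: int,
    *,
    tp_size: int = 1,
    pp_size: int = 1,
    cp_size: int = 1,
    dp_size: Optional[int] = None,
    dp_replicate_size: Optional[int] = None,
) -> Tuple[int, int, int]:
    """Hypothetical helper returning (dp_size, dp_replicate_size, dp_shard_size).

    Assumption: dp_size = world_size / (tp * pp * cp), as in the FSDP2 mesh shape.
    """
    non_dp = tp_size * pp_size * cp_size
    if dp_size is None:
        if world_size % non_dp != 0:
            raise ValueError(
                f"world_size={world_size} is not divisible by tp*pp*cp={non_dp}"
            )
        # Data parallelism absorbs every rank not used by TP, PP or CP.
        dp_size = world_size // non_dp
    if dp_size * non_dp != world_size:
        raise ValueError("dp_size * tp_size * pp_size * cp_size must equal world_size")
    # Documented behavior: None or <= 0 means "no replication".
    if dp_replicate_size is None or dp_replicate_size <= 0:
        dp_replicate_size = 1
    if dp_size % dp_replicate_size != 0:
        raise ValueError(
            f"dp_replicate_size={dp_replicate_size} must divide dp_size={dp_size}"
        )
    # The remaining DP ranks shard parameters within each replica.
    return dp_size, dp_replicate_size, dp_size // dp_replicate_size


# Values from the module-level usage example: 8 ranks, TP=2, two replicas.
print(infer_fsdp2_dp_sizes(8, tp_size=2, dp_replicate_size=2))  # (4, 2, 2)
```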

Returns:

(device_mesh, moe_mesh), where:

  • For FSDP2Config: the full device mesh, plus moe_mesh if ep_size > 1 (otherwise None).

  • For MegatronFSDPConfig: the device mesh and None.

  • For DDPConfig: (None, None), since DDP doesn't use a device mesh.

Return type:

tuple

Raises:
  • ValueError – If dp_replicate_size is provided with non-FSDP2 config.

  • ValueError – If world_size is not divisible by the product of the parallelism sizes.
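The routing described under Returns and Raises can be modelled with stand-in classes. This sketch does not import NeMo AutoModel: the three config dataclasses below are hypothetical placeholders for the real classes, and `describe_mesh_routing` (also hypothetical) only reports which builder would run instead of building meshes. Treating any non-None dp_replicate_size as "provided" is an assumption.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


# Hypothetical stand-ins for the real config classes in
# nemo_automodel.components.distributed.config.
@dataclass
class FSDP2Config:
    sequence_parallel: bool = False


@dataclass
class MegatronFSDPConfig:
    pass


@dataclass
class DDPConfig:
    pass


def describe_mesh_routing(
    config, *, ep_size: int = 1, dp_replicate_size: Optional[int] = None
) -> Tuple[Optional[str], Optional[str]]:
    """Return (builder that would run, what the second tuple slot would hold)."""
    # Assumption: any non-None dp_replicate_size counts as "provided".
    if dp_replicate_size is not None and not isinstance(config, FSDP2Config):
        raise ValueError("dp_replicate_size is only supported with FSDP2Config")
    if isinstance(config, FSDP2Config):
        # FSDP2 builds the full mesh; a separate MoE mesh only when EP is used.
        return "_create_fsdp2_device_mesh", ("moe_mesh" if ep_size > 1 else None)
    if isinstance(config, MegatronFSDPConfig):
        return "_create_megatron_fsdp_device_mesh", None
    if isinstance(config, DDPConfig):
        # DDP does not use a device mesh, so the result is (None, None).
        return None, None
    # This sketch's own choice for unknown config types.
    raise TypeError(f"unsupported config type: {type(config).__name__}")


print(describe_mesh_routing(FSDP2Config(), ep_size=4))
# ('_create_fsdp2_device_mesh', 'moe_mesh')
```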

nemo_automodel.components.distributed.mesh_utils._create_fsdp2_device_mesh(
dp_size: Optional[int],
dp_replicate_size: Optional[int],
tp_size: int,
pp_size: int,
cp_size: int,
ep_size: int,
world_size: int,
backend: str,
) -> Tuple[torch.distributed.device_mesh.DeviceMesh, Optional[torch.distributed.device_mesh.DeviceMesh]]

Create device mesh for FSDP2.

Mesh shape: (pp_size, dp_replicate_size, dp_shard_size, cp_size, tp_size)
Mesh names: ("pp", "dp_replicate", "dp_shard", "cp", "tp")

Also creates flattened submeshes:

  • "dp": dp_replicate + dp_shard

  • "dp_shard_cp": dp_shard + cp

  • "dp_cp": dp_replicate + dp_shard + cp
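To see which ranks end up in each flattened submesh, the following sketch enumerates rank groups for a row-major mesh, assuming the default contiguous layout that torch.distributed.device_mesh.init_device_mesh produces (ranks 0..world_size-1 reshaped to the mesh shape). It works on plain rank lists; `flattened_groups` is a hypothetical helper, not the DeviceMesh._flatten() implementation.

```python
import itertools
from typing import List, Sequence


def flattened_groups(
    shape: Sequence[int], names: Sequence[str], flatten: Sequence[str]
) -> List[List[int]]:
    """Hypothetical helper: rank groups from collapsing `flatten` dims of a mesh.

    Assumes a row-major rank layout; `flatten` should list dims in mesh order.
    """
    index = {name: i for i, name in enumerate(names)}
    flat_dims = [index[n] for n in flatten]
    other_dims = [d for d in range(len(shape)) if d not in flat_dims]
    # Row-major strides: rank = sum(coord[d] * strides[d]).
    strides = [1] * len(shape)
    for d in range(len(shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * shape[d + 1]
    groups = []
    # One group per coordinate of the dims that are NOT flattened.
    for fixed in itertools.product(*(range(shape[d]) for d in other_dims)):
        base = sum(c * strides[d] for c, d in zip(fixed, other_dims))
        groups.append([
            base + sum(c * strides[d] for c, d in zip(coords, flat_dims))
            for coords in itertools.product(*(range(shape[d]) for d in flat_dims))
        ])
    return groups


NAMES = ("pp", "dp_replicate", "dp_shard", "cp", "tp")
SHAPE = (1, 2, 2, 1, 2)  # 8 ranks: HSDP 2x2 with TP=2
print(flattened_groups(SHAPE, NAMES, ("dp_replicate", "dp_shard")))
# [[0, 2, 4, 6], [1, 3, 5, 7]]
```

In this 8-rank example each "dp" group holds the four data-parallel ranks that share a TP coordinate, which is why ranks 0 and 1 (one TP pair) fall into different "dp" groups.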

Parameters:
  • dp_size – Data parallel size. If None, inferred from world_size.

  • dp_replicate_size – Size of the replication group for HSDP.

  • tp_size – Tensor parallel size.

  • pp_size – Pipeline parallel size.

  • cp_size – Context parallel size.

  • ep_size – Expert parallel size (for MoE models).

  • world_size – Total number of processes.

  • backend – Distributed backend ('nccl' or 'gloo').

Returns:

(device_mesh, moe_mesh)

Return type:

tuple

nemo_automodel.components.distributed.mesh_utils._create_megatron_fsdp_device_mesh(
dp_size: Optional[int],
tp_size: int,
cp_size: int,
world_size: int,
backend: str,
) -> torch.distributed.device_mesh.DeviceMesh

Create device mesh for MegatronFSDP.

Mesh shape: (dp_size, cp_size, tp_size)
Mesh names: ("dp", "cp", "tp")

Also creates flattened submesh “dp_cp” if cp_size > 1.
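The MegatronFSDP layout is simpler than the FSDP2 one, and its sizing can be sketched the same way. This minimal sketch assumes that dp_size, when None, is world_size divided by tp_size * cp_size, since the mesh shape has only those three dimensions; `megatron_fsdp_layout` is hypothetical and returns plain tuples instead of a DeviceMesh.

```python
from typing import List, Optional, Tuple


def megatron_fsdp_layout(
    world_size: int, *, tp_size: int = 1, cp_size: int = 1, dp_size: Optional[int] = None
) -> Tuple[Tuple[int, int, int], Tuple[str, str, str], List[str]]:
    """Hypothetical helper: (mesh shape, mesh names, flattened submesh names).

    Assumption: dp_size defaults to world_size / (tp_size * cp_size).
    """
    if dp_size is None:
        if world_size % (tp_size * cp_size) != 0:
            raise ValueError("world_size must be divisible by tp_size * cp_size")
        dp_size = world_size // (tp_size * cp_size)
    if dp_size * cp_size * tp_size != world_size:
        raise ValueError("dp_size * cp_size * tp_size must equal world_size")
    # "dp_cp" is only flattened when context parallelism is actually in use.
    flattened = ["dp_cp"] if cp_size > 1 else []
    return (dp_size, cp_size, tp_size), ("dp", "cp", "tp"), flattened


print(megatron_fsdp_layout(8, tp_size=2, cp_size=2))
# ((2, 2, 2), ('dp', 'cp', 'tp'), ['dp_cp'])
```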

Parameters:
  • dp_size – Data parallel size. If None, inferred from world_size.

  • tp_size – Tensor parallel size.

  • cp_size – Context parallel size.

  • world_size – Total number of processes.

  • backend – Distributed backend ('nccl' or 'gloo').

Returns:

The device mesh for MegatronFSDP.

Return type:

DeviceMesh

nemo_automodel.components.distributed.mesh_utils._unflatten_compat(
flat_mesh: torch.distributed.device_mesh.DeviceMesh,
dim: int,
sizes: tuple,
names: tuple,
) -> torch.distributed.device_mesh.DeviceMesh

Compatibility shim for DeviceMesh._unflatten(), which was added in PyTorch 2.10.

Reconstructs a multi-dimensional mesh from a flat mesh by reshaping its rank tensor. Only dim=0 is supported, as that is the only case used in this codebase.
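The reshape that _unflatten_compat performs on the rank tensor can be illustrated on a plain list. This is a sketch assuming a row-major reshape along dim 0, the same semantics as torch.Tensor.reshape; `unflatten_ranks` is hypothetical and does not construct a DeviceMesh or touch process groups.

```python
import math
from typing import Any, List, Sequence


def unflatten_ranks(flat_ranks: Sequence[int], sizes: Sequence[int]) -> List[Any]:
    """Hypothetical helper: reshape a 1D rank list to `sizes` (row-major, dim 0)."""
    if math.prod(sizes) != len(flat_ranks):
        raise ValueError(f"cannot reshape {len(flat_ranks)} ranks into {tuple(sizes)}")

    def build(offset: int, dims: Sequence[int]) -> List[Any]:
        if len(dims) == 1:
            return list(flat_ranks[offset:offset + dims[0]])
        # Each step along the leading dim skips over one full sub-block.
        stride = math.prod(dims[1:])
        return [build(offset + i * stride, dims[1:]) for i in range(dims[0])]

    return build(0, list(sizes))


# A flattened "dp" group from an HSDP mesh, split back into (dp_replicate, dp_shard).
print(unflatten_ranks([1, 3, 5, 7], (2, 2)))  # [[1, 3], [5, 7]]
```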

nemo_automodel.components.distributed.mesh_utils.get_flat_mesh(
device_mesh: torch.distributed.device_mesh.DeviceMesh,
name: str,
) -> torch.distributed.device_mesh.DeviceMesh

Access a 1D submesh by parallelism name (e.g. "dp", "tp", "dp_cp").

PyTorch 2.11 deprecates root_mesh["name"] for dimensions created via _flatten(), so this function reads the _flatten() result directly instead.

Parameters:
  • device_mesh – Any DeviceMesh (root or submesh).

  • name – Parallelism dimension name.

nemo_automodel.components.distributed.mesh_utils.get_submesh(
device_mesh: torch.distributed.device_mesh.DeviceMesh,
names: tuple,
) -> torch.distributed.device_mesh.DeviceMesh

Access a submesh by parallelism dim names.

Handles all cases: single dims, multi-dim slices, and combinations that include _flatten()-created dims (e.g. ("dp_replicate", "dp_shard_cp")). For the latter, finds the parent _flatten() result and calls _unflatten() to decompose it into the requested shape.

Parameters:
  • device_mesh – Any DeviceMesh (root or submesh).

  • names – Tuple of dimension names.
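For the composed case described above, the rank membership of a ("dp_replicate", "dp_shard_cp") view can be worked out by taking the flattened "dp_cp" group (dp_replicate + dp_shard + cp) and splitting it into rows of dp_shard * cp ranks. The sketch below assumes the row-major FSDP2 layout (pp, dp_replicate, dp_shard, cp, tp). `replicate_by_shard_cp` is hypothetical and models only which ranks appear where, not how get_submesh() builds the DeviceMesh.

```python
from typing import List, Tuple


def replicate_by_shard_cp(rank: int, shape: Tuple[int, int, int, int, int]) -> List[List[int]]:
    """Hypothetical helper: ranks of the (dp_replicate, dp_shard_cp) view holding `rank`.

    Assumes `shape` is the FSDP2 root mesh shape (pp, dp_replicate, dp_shard, cp, tp)
    laid out row-major.
    """
    pp, rep, shard, cp, tp = shape
    per_stage = rep * shard * cp * tp
    # dp_replicate, dp_shard and cp are adjacent, so the flattened "dp_cp"
    # group is a stride-tp run of ranks within this pipeline stage.
    stage, tp_idx = rank // per_stage, rank % tp
    base = stage * per_stage + tp_idx
    dp_cp = [base + k * tp for k in range(rep * shard * cp)]
    # Unflatten dim 0 into (dp_replicate, dp_shard * cp), row-major.
    width = shard * cp
    return [dp_cp[r * width:(r + 1) * width] for r in range(rep)]


# 8 ranks: dp_replicate=2, dp_shard=2, cp=2, no TP; rank 5 sees both replicas.
print(replicate_by_shard_cp(5, (1, 2, 2, 2, 1)))
# [[0, 1, 2, 3], [4, 5, 6, 7]]
```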

nemo_automodel.components.distributed.mesh_utils.get_fsdp_dp_mesh(
device_mesh: torch.distributed.device_mesh.DeviceMesh,
dp_replicate_name: str = MeshAxisName.DP_REPLICATE,
dp_shard_cp_name: str = MeshAxisName.DP_SHARD_CP,
) -> torch.distributed.device_mesh.DeviceMesh

Return the DP mesh for FSDP2 without losing the original root mesh.

get_submesh() may rebuild a fresh DeviceMesh when asked to compose native and flattened dims like ("dp_replicate", "dp_shard_cp"). That is fine for many local operations, but FSDP2 expects its DP mesh to share the same root mesh as the TP/EP meshes. On multi-node TP runs, a rebuilt mesh that does not share that root can break group construction in non-obvious ways.

Prefer native dimensions whenever possible:

  • cp=1, dp_replicate=1 -> device_mesh["dp_shard"]

  • cp=1, dp_replicate>1 -> device_mesh[("dp_replicate", "dp_shard")]

  • cp>1, dp_replicate=1 -> device_mesh[("dp_shard", "cp")]

When both CP and replicated DP are active, we fall back to get_submesh() because the composed mesh is genuinely multi-level.
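The selection rules above fit in a small decision function. This sketch takes dp_replicate_size and cp_size directly rather than reading them from a DeviceMesh, and it assumes the fallback passes the default axis names ("dp_replicate", "dp_shard_cp") to get_submesh(). `fsdp_dp_mesh_choice` is hypothetical.

```python
from typing import Tuple, Union

MeshKey = Union[str, Tuple[str, ...]]


def fsdp_dp_mesh_choice(dp_replicate_size: int, cp_size: int) -> Tuple[str, MeshKey]:
    """Hypothetical helper: ("index", key) for device_mesh[key], or ("get_submesh", names)."""
    if cp_size == 1 and dp_replicate_size == 1:
        return "index", "dp_shard"
    if cp_size == 1:
        return "index", ("dp_replicate", "dp_shard")
    if dp_replicate_size == 1:
        return "index", ("dp_shard", "cp")
    # Both CP and replicated DP: the composed mesh is genuinely multi-level.
    # Assumption: the fallback composes the default axis names.
    return "get_submesh", ("dp_replicate", "dp_shard_cp")


print(fsdp_dp_mesh_choice(dp_replicate_size=2, cp_size=1))
# ('index', ('dp_replicate', 'dp_shard'))
```

With a real root mesh, an ("index", key) result corresponds to device_mesh[key], which keeps the DP mesh attached to the same root as the TP/EP meshes.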