nemo_automodel.components.distributed.mesh_utils

Device mesh construction and access utilities for distributed training.

Module Contents

Classes

| Name | Description |
| --- | --- |
| _MeshSpec | Named mesh shape plus derived flattened axes. |

Functions

| Name | Description |
| --- | --- |
| _create_device_meshes | Create raw device meshes based on the distributed strategy config type. |
| _create_fsdp2_device_mesh | Create the FSDP2 root mesh and optional MoE mesh. |
| _create_megatron_fsdp_device_mesh | Create the Megatron FSDP mesh. |
| _create_moe_mesh | - |
| _degree | - |
| _infer_dp_size | - |
| _init_named_mesh | - |
| _mesh_device_type | - |
| _register_flattened_axes | - |
| _require_size_one | - |
| _unflatten_compat | Compatibility shim for DeviceMesh._unflatten(), which was added in PyTorch 2.10. |
| _validate_mesh_spec | - |
| get_flat_mesh | Access a 1D submesh by parallelism name (e.g. "dp", "tp", "dp_cp"). |
| get_fsdp_dp_mesh | Return the DP mesh for FSDP2 without losing the original root mesh. |
| get_submesh | Access a submesh by parallelism dim names. |

Data

__all__

API

class nemo_automodel.components.distributed.mesh_utils._MeshSpec(
shape: tuple[int, ...],
axes: tuple[nemo_automodel.components.distributed.mesh.MeshAxisName, ...],
flattened_axes: dict[nemo_automodel.components.distributed.mesh.MeshAxisName, tuple[nemo_automodel.components.distributed.mesh.MeshAxisName, ...]] = dict()
)
Dataclass

Named mesh shape plus derived flattened axes.

axes
tuple[MeshAxisName, ...]
flattened_axes
dict[MeshAxisName, tuple[MeshAxisName, ...]] = field(default_factory=dict)
shape
tuple[int, ...]
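
To make the relationship between shape, axes, and flattened_axes concrete, here is a minimal pure-Python sketch. It assumes that a flattened axis spans the product of its component axis sizes, which is how DeviceMesh._flatten() combines dimensions. The class name MeshSpecSketch, the particular axis groupings, and the sizes are hypothetical, and plain strings stand in for MeshAxisName members.

```python
from dataclasses import dataclass, field
from math import prod


@dataclass
class MeshSpecSketch:
    """Hypothetical stand-in for _MeshSpec using plain strings for axis names."""

    shape: tuple[int, ...]
    axes: tuple[str, ...]
    flattened_axes: dict[str, tuple[str, ...]] = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Every native axis needs exactly one size.
        if len(self.shape) != len(self.axes):
            raise ValueError("shape and axes must have the same length")
        # Flattened axes may only combine native axes that actually exist.
        for name, parts in self.flattened_axes.items():
            missing = [p for p in parts if p not in self.axes]
            if missing:
                raise ValueError(f"{name} refers to unknown axes {missing}")

    def size(self, name: str) -> int:
        """Size of a native axis, or the product of sizes for a flattened one."""
        sizes = dict(zip(self.axes, self.shape))
        if name in sizes:
            return sizes[name]
        return prod(sizes[p] for p in self.flattened_axes[name])


# Hypothetical 8-rank layout; the flattened groupings are illustrative only.
spec = MeshSpecSketch(
    shape=(2, 2, 2, 1),
    axes=("dp_replicate", "dp_shard", "cp", "tp"),
    flattened_axes={
        "dp": ("dp_replicate", "dp_shard"),
        "dp_shard_cp": ("dp_shard", "cp"),
        "dp_cp": ("dp_replicate", "dp_shard", "cp"),
    },
)
print(spec.size("dp_shard_cp"), spec.size("dp_cp"))  # 4 8
```

Keeping flattened axes as a name-to-components mapping means their sizes never need to be stored separately; they are always derived from the native shape.
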
nemo_automodel.components.distributed.mesh_utils._create_device_meshes(
strategy_config: nemo_automodel.components.distributed.config.DistributedStrategyConfig,
parallelism: nemo_automodel.components.distributed.mesh.ParallelismSizes,
world_size: int
) -> tuple[torch.distributed.device_mesh.DeviceMesh | None, torch.distributed.device_mesh.DeviceMesh | None]

Create raw device meshes based on the distributed strategy config type.

nemo_automodel.components.distributed.mesh_utils._create_fsdp2_device_mesh(
parallelism: nemo_automodel.components.distributed.mesh.ParallelismSizes,
world_size: int
) -> tuple[torch.distributed.device_mesh.DeviceMesh, torch.distributed.device_mesh.DeviceMesh | None]

Create the FSDP2 root mesh and optional MoE mesh.

nemo_automodel.components.distributed.mesh_utils._create_megatron_fsdp_device_mesh(
parallelism: nemo_automodel.components.distributed.mesh.ParallelismSizes,
world_size: int
) -> torch.distributed.device_mesh.DeviceMesh

Create the Megatron FSDP mesh.

nemo_automodel.components.distributed.mesh_utils._create_moe_mesh(
device_mesh: torch.distributed.device_mesh.DeviceMesh,
ep_shard_size: int,
ep_size: int
) -> torch.distributed.device_mesh.DeviceMesh

nemo_automodel.components.distributed.mesh_utils._degree(
value: int | None
) -> int

nemo_automodel.components.distributed.mesh_utils._infer_dp_size(
dp_size: int | None,
world_size: int,
non_dp_size: int,
expression: str,
factors: tuple[int, ...]
) -> int

nemo_automodel.components.distributed.mesh_utils._init_named_mesh(
spec: nemo_automodel.components.distributed.mesh_utils._MeshSpec
) -> torch.distributed.device_mesh.DeviceMesh

nemo_automodel.components.distributed.mesh_utils._mesh_device_type() -> str

nemo_automodel.components.distributed.mesh_utils._register_flattened_axes(
device_mesh: torch.distributed.device_mesh.DeviceMesh,
flattened_axes: dict[nemo_automodel.components.distributed.mesh.MeshAxisName, tuple[nemo_automodel.components.distributed.mesh.MeshAxisName, ...]]
) -> None

nemo_automodel.components.distributed.mesh_utils._require_size_one(
strategy_name: str,
size: int | None,
feature_name: str
) -> None

nemo_automodel.components.distributed.mesh_utils._unflatten_compat(
flat_mesh: torch.distributed.device_mesh.DeviceMesh,
axis: int,
sizes: tuple,
names: tuple
) -> torch.distributed.device_mesh.DeviceMesh

Compatibility shim for DeviceMesh._unflatten(), which was added in PyTorch 2.10.
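
Conceptually, unflattening splits one mesh dimension into several dimensions whose sizes multiply to the original size, keeping ranks in row-major order. The pure-Python sketch below shows that reshaping on a flat list of ranks. It illustrates the idea under that row-major assumption and is not the shim's implementation; the helper name unflatten_ranks is hypothetical.

```python
from math import prod


def unflatten_ranks(ranks: list[int], sizes: tuple[int, ...]) -> list:
    """Reshape a flat rank list into nested lists with the given sizes (row-major)."""
    if prod(sizes) != len(ranks):
        raise ValueError(f"sizes {sizes} do not multiply to {len(ranks)}")
    if len(sizes) == 1:
        return list(ranks)
    # Each chunk along the leading axis holds prod(sizes[1:]) consecutive ranks.
    step = prod(sizes[1:])
    return [
        unflatten_ranks(ranks[i : i + step], sizes[1:])
        for i in range(0, len(ranks), step)
    ]


# A hypothetical 8-rank flat axis split into dp_replicate=2 x dp_shard_cp=4.
print(unflatten_ranks(list(range(8)), (2, 4)))
# [[0, 1, 2, 3], [4, 5, 6, 7]]
```
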

nemo_automodel.components.distributed.mesh_utils._validate_mesh_spec(
spec: nemo_automodel.components.distributed.mesh_utils._MeshSpec
) -> None

nemo_automodel.components.distributed.mesh_utils.get_flat_mesh(
device_mesh: torch.distributed.device_mesh.DeviceMesh,
name: str
) -> torch.distributed.device_mesh.DeviceMesh

Access a 1D submesh by parallelism name (e.g. "dp", "tp", "dp_cp").

PyTorch 2.11 deprecates indexing the root mesh as root_mesh["name"] for dimensions created via _flatten(), so this function reads the _flatten() result directly instead.

Parameters:

device_mesh
DeviceMesh

Any DeviceMesh (root or submesh).

name
str

Parallelism dimension name.
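
A 1D submesh such as "dp_shard_cp" groups the ranks that differ only along its component dimensions. The pure-Python sketch below enumerates those rank groups for a small row-major mesh, which can help sanity-check what get_flat_mesh() is expected to return on each rank. The helper name rank_groups, the axis layout, and the sizes are hypothetical.

```python
from itertools import product
from math import prod


def rank_groups(
    shape: tuple[int, ...], axes: tuple[str, ...], group_axes: tuple[str, ...]
) -> list[list[int]]:
    """Rank groups spanned by group_axes in a row-major mesh of the given shape."""
    # Row-major strides: the last axis varies fastest.
    strides = [prod(shape[i + 1 :]) for i in range(len(shape))]
    grouped = [axes.index(a) for a in group_axes]
    fixed = [i for i in range(len(shape)) if i not in grouped]
    groups = []
    # One group per combination of coordinates on the non-grouped axes.
    for fixed_coords in product(*(range(shape[i]) for i in fixed)):
        base = sum(c * strides[i] for c, i in zip(fixed_coords, fixed))
        group = [
            base + sum(c * strides[i] for c, i in zip(coords, grouped))
            for coords in product(*(range(shape[i]) for i in grouped))
        ]
        groups.append(group)
    return groups


shape = (2, 2, 2)  # dp_replicate x dp_shard x cp
axes = ("dp_replicate", "dp_shard", "cp")
print(rank_groups(shape, axes, ("dp_shard", "cp")))
# [[0, 1, 2, 3], [4, 5, 6, 7]]
print(rank_groups(shape, axes, ("dp_replicate", "dp_shard")))
# [[0, 2, 4, 6], [1, 3, 5, 7]]
```
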

nemo_automodel.components.distributed.mesh_utils.get_fsdp_dp_mesh(
device_mesh: torch.distributed.device_mesh.DeviceMesh,
dp_replicate_name: str = MeshAxisName.DP_REPLICATE,
dp_shard_cp_name: str = MeshAxisName.DP_SHARD_CP
) -> torch.distributed.device_mesh.DeviceMesh

Return the DP mesh for FSDP2 without losing the original root mesh.

get_submesh() may rebuild a fresh DeviceMesh when asked to compose native and flattened dims such as ("dp_replicate", "dp_shard_cp"). That is fine for many local operations, but FSDP2 expects its DP mesh to share the same root mesh as the TP/EP meshes. On multi-node TP runs, a rebuilt mesh can break process-group construction in non-obvious ways.

Prefer native dimensions whenever possible:

  • cp=1, dp_replicate=1 -> device_mesh["dp_shard"]
  • cp=1, dp_replicate>1 -> device_mesh[("dp_replicate", "dp_shard")]
  • cp>1, dp_replicate=1 -> device_mesh[("dp_shard", "cp")]

When both CP and replicated DP are active (cp>1, dp_replicate>1), we fall back to get_submesh() because the composed mesh is genuinely multi-level.
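
The selection rule for the FSDP2 DP mesh can be written as a small pure-Python decision function. This sketch assumes cp and dp_replicate are the configured sizes (each at least 1); choose_fsdp_dp_dims is a hypothetical name, and returning None to mean "fall back to get_submesh()" is a convention chosen only for illustration.

```python
from typing import Optional


def choose_fsdp_dp_dims(cp: int, dp_replicate: int) -> Optional[tuple[str, ...]]:
    """Return the native mesh dims to index for the FSDP2 DP mesh.

    Returns None when both CP and replicated DP are active, signalling
    that the caller should fall back to get_submesh().
    """
    if cp == 1 and dp_replicate == 1:
        return ("dp_shard",)
    if cp == 1:
        # Only replicated DP is active.
        return ("dp_replicate", "dp_shard")
    if dp_replicate == 1:
        # Only CP is active.
        return ("dp_shard", "cp")
    # Both are active: the composed mesh is multi-level.
    return None


print(choose_fsdp_dp_dims(cp=2, dp_replicate=1))  # ('dp_shard', 'cp')
```

Each non-None result names only native dimensions, which is what keeps the resulting DP mesh attached to the same root mesh as the TP/EP meshes.
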

nemo_automodel.components.distributed.mesh_utils.get_submesh(
device_mesh: torch.distributed.device_mesh.DeviceMesh,
names: tuple
) -> torch.distributed.device_mesh.DeviceMesh

Access a submesh by parallelism dim names.

Handles all cases: single dims, multi-dim slices, and combinations that include _flatten()-created dims (e.g. ("dp_replicate", "dp_shard_cp")). For the latter, it finds the parent _flatten() result and calls _unflatten() to decompose it into the requested shape.

Parameters:

device_mesh
DeviceMesh

Any DeviceMesh (root or submesh).

names
tuple

Tuple of dimension names.
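
The flattened-dimension case can be sketched as a planning step: expand each requested name into native dimensions, look for a _flatten() result with exactly that expansion, and compute the sizes to pass to _unflatten(). This pure-Python version assumes each flattened dim is registered by name together with its ordered native components; plan_unflatten, the registry layout, and the axis sizes are hypothetical, and the real function operates on DeviceMesh objects.

```python
from math import prod
from typing import Optional


def plan_unflatten(
    names: tuple[str, ...],
    sizes: dict[str, int],
    flattened: dict[str, tuple[str, ...]],
) -> Optional[tuple[str, tuple[int, ...]]]:
    """Find a flattened parent whose native dims equal `names` expanded in order.

    Returns (parent_name, unflatten_sizes), or None if no parent matches.
    """
    # Expand each requested name into its native component dims.
    expanded: list[str] = []
    for name in names:
        expanded.extend(flattened.get(name, (name,)))
    for parent, parts in flattened.items():
        if tuple(expanded) == parts:
            # Each requested dim's size is the product of its native components.
            out_sizes = tuple(
                prod(sizes[p] for p in flattened.get(n, (n,))) for n in names
            )
            return parent, out_sizes
    return None


sizes = {"dp_replicate": 2, "dp_shard": 2, "cp": 2, "tp": 2}
flattened = {
    "dp": ("dp_replicate", "dp_shard"),
    "dp_shard_cp": ("dp_shard", "cp"),
    "dp_cp": ("dp_replicate", "dp_shard", "cp"),
}
print(plan_unflatten(("dp_replicate", "dp_shard_cp"), sizes, flattened))
# ('dp_cp', (2, 4))
```
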

nemo_automodel.components.distributed.mesh_utils.__all__ = ['_create_device_meshes', '_create_fsdp2_device_mesh', '_create_megatron_fsdp_de...