nemo_automodel.components.distributed.mesh_utils
Device mesh construction and access utilities for distributed training.
Module Contents
Classes
Functions
Data
API
Named mesh shape plus derived flattened axes.
Create raw device meshes based on distributed config type.
Create the FSDP2 root mesh and optional MoE mesh.
Create the Megatron FSDP mesh.
Compatibility shim for DeviceMesh._unflatten(), added in PyTorch 2.10.
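A shim of this kind is commonly written with feature detection: use the method when the installed version provides it, otherwise call a fallback. The sketch below shows only that general pattern on stand-in classes. It is not the module's implementation, and `call_with_fallback`, `OldMesh`, `NewMesh`, and `manual_unflatten` are hypothetical names introduced for illustration.

```python
from typing import Any, Callable


def call_with_fallback(
    obj: Any,
    method_name: str,
    fallback: Callable[..., Any],
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Call obj.<method_name>(*args) if it exists, else fallback(obj, *args)."""
    method = getattr(obj, method_name, None)
    if callable(method):
        return method(*args, **kwargs)
    return fallback(obj, *args, **kwargs)


class OldMesh:
    """Stand-in for an object from a version without the method."""


class NewMesh:
    """Stand-in for an object from a version that has the method."""

    def _unflatten(self, shape):
        return ("native", shape)


def manual_unflatten(mesh, shape):
    """Hypothetical fallback used when the native method is missing."""
    return ("fallback", shape)


new_result = call_with_fallback(NewMesh(), "_unflatten", manual_unflatten, (2, 2))
old_result = call_with_fallback(OldMesh(), "_unflatten", manual_unflatten, (2, 2))
```

Detecting the attribute rather than parsing a version string keeps the check valid for nightly or patched builds, at the cost of not distinguishing signature changes.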
Access a 1D submesh by parallelism name (e.g. "dp", "tp", "dp_cp").
PyTorch 2.11 deprecates root_mesh["name"] indexing for dimensions created via
_flatten(). This function reads the _flatten() result directly instead.
Parameters:
Any DeviceMesh (root or submesh).
Parallelism dimension name.
Return the DP mesh for FSDP2 without losing the original root mesh.
get_submesh() may rebuild a fresh DeviceMesh when asked to compose native
and flattened dims like ("dp_replicate", "dp_shard_cp"). That is fine
for many local operations, but FSDP2 expects its DP mesh to share the same
root mesh as the TP/EP meshes. On multi-node TP runs, a rebuilt mesh can break
process-group construction in non-obvious ways.
Prefer native dimensions whenever possible:
- cp=1, dp_replicate=1 -> device_mesh["dp_shard"]
- cp=1, dp_replicate>1 -> device_mesh[("dp_replicate", "dp_shard")]
- cp>1, dp_replicate=1 -> device_mesh[("dp_shard", "cp")]
When both CP and replicated DP are active (cp>1 and dp_replicate>1), we fall
back to get_submesh() because the composed mesh is genuinely multi-level.
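The selection rules above can be sketched as a small pure-Python helper. This is an illustrative sketch rather than the module's code: the name `pick_dp_dim_names` is hypothetical, and it only returns the dimension names that would be used to index the device mesh (or None to signal the get_submesh() fallback) without touching a real DeviceMesh.

```python
from typing import Optional, Tuple, Union

DimNames = Union[str, Tuple[str, ...]]


def pick_dp_dim_names(cp: int, dp_replicate: int) -> Optional[DimNames]:
    """Hypothetical helper: choose native dim names for the FSDP2 DP mesh.

    Returns None when both CP and replicated DP are active, meaning the
    caller should fall back to get_submesh().
    """
    if cp < 1 or dp_replicate < 1:
        raise ValueError("cp and dp_replicate must be >= 1")
    if cp == 1 and dp_replicate == 1:
        return "dp_shard"
    if cp == 1:
        # Replicated DP only: two native dims, same root mesh.
        return ("dp_replicate", "dp_shard")
    if dp_replicate == 1:
        # Context parallelism only: shard over dp_shard and cp together.
        return ("dp_shard", "cp")
    # Both active: the composed mesh is multi-level.
    return None
```

Keeping the first three cases on native dimensions means the resulting submesh is sliced from the original root mesh, which is the property FSDP2 relies on.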
Access a submesh by parallelism dim names.
Handles all cases: single dims, multi-dim slices, and combinations that
include _flatten()-created dims (e.g. ("dp_replicate", "dp_shard_cp")).
For the latter, finds the parent _flatten() result and calls _unflatten()
to decompose it into the requested shape.
Parameters:
Any DeviceMesh (root or submesh).
Tuple of dimension names.
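To make the decomposition step concrete, the toy sketch below models a flattened dim as a flat list of ranks and reshapes it into a requested shape. It is pure Python and does not call DeviceMesh. The helper name `unflatten_ranks`, the row-major layout, and the example sizes (dp_shard=2, cp=2) are assumptions chosen for illustration.

```python
from math import prod
from typing import Sequence


def unflatten_ranks(ranks: Sequence[int], shape: Sequence[int]) -> list:
    """Reshape a flat rank sequence into nested lists of `shape` (row-major)."""
    if not shape or any(dim < 1 for dim in shape):
        raise ValueError("shape must be non-empty with positive sizes")
    if prod(shape) != len(ranks):
        raise ValueError("shape does not match the number of ranks")
    if len(shape) == 1:
        return list(ranks)
    step = len(ranks) // shape[0]
    return [
        unflatten_ranks(ranks[i * step:(i + 1) * step], shape[1:])
        for i in range(shape[0])
    ]


# A flattened "dp_shard_cp" dim of size 4, assumed to combine dp_shard=2 and cp=2.
dp_shard_cp = [0, 1, 2, 3]
grid = unflatten_ranks(dp_shard_cp, (2, 2))  # rows: dp_shard, columns: cp
```

In this model each row of `grid` is one dp_shard index and each column is one cp index, which is the sense in which a flattened dim can be decomposed back into its constituent dims.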