nemo_automodel.components.distributed.mesh

MeshContext dataclass, construction, and validation.

MeshContext is the single source of truth for distributed topology: device meshes, parallelism sizes, and axis names.

Parallelism sizes (tp_size, pp_size, etc.) are derived at runtime from the attached DeviceMesh objects via @property. When no mesh is present, the properties return safe defaults: 1 for most sizes, and None for dp_size and dp_replicate_size (HSDP).

All inputs and outputs are typed Python objects (dataclasses, enums, etc.). YAML / dict parsing belongs in the recipe layer — see nemo_automodel.recipes._dist_utils.

Module Contents

Classes

| Name | Description |
| --- | --- |
| MeshAxisName | Canonical mesh axis names used by DeviceMesh and helpers. |
| MeshContext | Runtime distributed topology context. |
| ParallelismSizes | Build-time requested parallelism sizes. |

Functions

| Name | Description |
| --- | --- |
| _get_axis_size | Return the size of axis if present in mesh, else default. |
| _optional_axis | Return axis if present in mesh, else None. |
| _validate_mesh_axis_names | Ensure every axis name in the attached meshes is a MeshAxisName. |

Data

_VALID_AXIS_NAMES

__all__

API

class nemo_automodel.components.distributed.mesh.MeshAxisName

Bases: str, enum.Enum

Canonical mesh axis names used by DeviceMesh and helpers.

Inherits from str so each member compares equal to (and can be used wherever) a plain string — e.g. MeshAxisName.TP == "tp".
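
The string-compatibility described above can be illustrated with a small stand-in enum. This is a sketch only: AxisName below is a hypothetical class that copies three of the documented members, not the library's own MeshAxisName.

```python
from enum import Enum


class AxisName(str, Enum):
    """Hypothetical stand-in for MeshAxisName (subset of documented members)."""

    DP = "dp"
    TP = "tp"
    PP = "pp"


# A member compares equal to the plain string it wraps.
assert AxisName.TP == "tp"

# Because the member is a str, it can index a dict keyed by plain strings.
sizes_by_axis = {"tp": 8, "pp": 2}
tp_degree = sizes_by_axis[AxisName.TP]

# A plain string can also be turned back into the member by value lookup.
pp_axis = AxisName("pp")
```

Mixing in str means code that receives axis names from a mesh as plain strings and code that passes enum members can interoperate without explicit conversion.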

CP = 'cp'
DP = 'dp'
DP_CP = 'dp_cp'
DP_REPLICATE = 'dp_replicate'
DP_SHARD = 'dp_shard'
DP_SHARD_CP = 'dp_shard_cp'
EP = 'ep'
EP_SHARD = 'ep_shard'
PP = 'pp'
TP = 'tp'

class nemo_automodel.components.distributed.mesh.MeshContext(
device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None,
moe_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None
)
Dataclass

Runtime distributed topology context.

Parallelism sizes (tp_size, pp_size, etc.) are not stored as fields; they are @property accessors that read directly from the attached device_mesh / moe_mesh. When no mesh is present, the properties return safe defaults: 1 for most sizes, and None for dp_size and dp_replicate_size (HSDP).

All DeviceMesh objects passed in must use axis names from MeshAxisName; a ValueError is raised on construction if any unknown name is encountered.

Lifecycle

  1. Recipes parse YAML to obtain sizes and strategy configs.
  2. Sizes are passed to build, which constructs the DeviceMesh objects.
  3. MeshContext is created with those meshes; axis names are validated automatically in __post_init__.

Alternatively, from_meshes constructs an instance directly from DeviceMesh objects (used by NeMoAutoModel.from_pretrained).
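
The "derived, not stored" design can be sketched without PyTorch. Everything below is illustrative: FakeMesh and TopologySketch are hypothetical stand-ins for DeviceMesh and MeshContext, and the lookup logic is an assumption about how such properties might be written, not the library's implementation.

```python
from dataclasses import dataclass, field
from typing import Optional, Tuple


@dataclass
class FakeMesh:
    """Hypothetical stand-in for a DeviceMesh: named axes and their sizes."""

    mesh_dim_names: Tuple[str, ...]
    shape: Tuple[int, ...]

    def axis_size(self, name: str) -> int:
        # Look up the size of the axis with the given name.
        return self.shape[self.mesh_dim_names.index(name)]


@dataclass
class TopologySketch:
    """Illustrative context whose sizes are read from the mesh on access."""

    device_mesh: Optional[FakeMesh] = field(default=None, repr=False)

    def _size(self, name: str, default: Optional[int]) -> Optional[int]:
        mesh = self.device_mesh
        # Fall back to the default when there is no mesh or no such axis.
        if mesh is None or name not in mesh.mesh_dim_names:
            return default
        return mesh.axis_size(name)

    @property
    def tp_size(self) -> int:
        return self._size("tp", 1)

    @property
    def pp_size(self) -> int:
        return self._size("pp", 1)

    @property
    def dp_size(self) -> Optional[int]:
        return self._size("dp", None)

    @property
    def pp_enabled(self) -> bool:
        return self.pp_size > 1


no_mesh = TopologySketch()
with_mesh = TopologySketch(FakeMesh(("pp", "dp", "tp"), (2, 2, 4)))
```

Here no_mesh reports the documented defaults (sizes of 1, dp_size of None), while with_mesh reports whatever the attached mesh contains, so the sizes cannot drift out of sync with the mesh.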

cp_size: int

Context-parallel degree (from device_mesh, default 1).

device_mesh: Optional[DeviceMesh] = field(default=None, repr=False)

dp_replicate_size: Optional[int]

HSDP replication degree (from device_mesh, default None).

dp_shard_size: int

DP shard degree (from device_mesh, default 1).

dp_size: Optional[int]

Data-parallel degree (from device_mesh, default None).

ep_size: int

Expert-parallel degree (from moe_mesh, default 1).

moe_mesh: Optional[DeviceMesh] = field(default=None, repr=False)

pp_enabled: bool

True when pp_size > 1.

pp_size: int

Pipeline-parallel degree (from device_mesh, default 1).

tp_size: int

Tensor-parallel degree (from device_mesh, default 1).

nemo_automodel.components.distributed.mesh.MeshContext.__post_init__() -> None

nemo_automodel.components.distributed.mesh.MeshContext._dp_axis_names() -> typing.Tuple[str, ...]

DP axis names for FSDP mesh slicing.

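nemo_automodel.components.distributed.mesh.MeshContext.build(
strategy_config: DistributedStrategyConfig,
parallelism_sizes: ParallelismSizes | None = None,
world_size: int | None = None
) -> nemo_automodel.components.distributed.mesh.MeshContext
classmethod

Build a topology-only MeshContext from parallelism sizes.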

Parameters:

strategy_config
DistributedStrategyConfig

Already-instantiated distributed strategy config.

parallelism_sizes
ParallelismSizes | None, defaults to None

Requested data, tensor, pipeline, context, and expert parallelism sizes. If None, defaults to no parallelism with DP inferred from world_size.

world_size
int | None, defaults to None

Total process count. If None, inferred from the distributed environment.
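
The page does not spell out how DP is inferred from world_size. A common convention, assumed here purely for illustration, is to give data parallelism whatever ranks remain after tensor, pipeline, and context parallelism are accounted for. SizesSketch and infer_dp_size are hypothetical names; SizesSketch only mirrors the documented ParallelismSizes fields.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SizesSketch:
    """Hypothetical stand-in mirroring the documented ParallelismSizes fields."""

    dp_size: Optional[int] = None
    dp_replicate_size: Optional[int] = None
    tp_size: int = 1
    pp_size: int = 1
    cp_size: int = 1
    ep_size: int = 1


def infer_dp_size(sizes: SizesSketch, world_size: int) -> int:
    """Assumed rule: dp = world_size / (tp * pp * cp) unless dp is explicit."""
    if sizes.dp_size is not None:
        return sizes.dp_size
    model_parallel = sizes.tp_size * sizes.pp_size * sizes.cp_size
    if world_size % model_parallel != 0:
        raise ValueError(
            f"world_size={world_size} is not divisible by "
            f"tp*pp*cp={model_parallel}"
        )
    return world_size // model_parallel


# No parallelism requested: every rank becomes a data-parallel rank.
dp_default = infer_dp_size(SizesSketch(), world_size=8)

# tp=2 and pp=2 on 8 ranks leaves 2 ranks for data parallelism.
dp_mixed = infer_dp_size(SizesSketch(tp_size=2, pp_size=2), world_size=8)
```

Expert parallelism is left out of the product in this sketch because the documented ep_size is read from a separate moe_mesh; whether the library treats it that way during build is an assumption.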

nemo_automodel.components.distributed.mesh.MeshContext.from_meshes(
device_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh],
moe_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None
) -> nemo_automodel.components.distributed.mesh.MeshContext
classmethod

Build a MeshContext from DeviceMesh objects.

This is the entry-point used by NeMoAutoModel.from_pretrained / from_config where the caller has raw meshes rather than a parsed YAML config.

nemo_automodel.components.distributed.mesh.MeshContext.parallelize_axis_kwargs() -> typing.Dict[str, object]

Axis-name kwargs for parallelize_fn (EP/FSDP, no pp_axis_name).

nemo_automodel.components.distributed.mesh.MeshContext.pipeline_axis_kwargs() -> typing.Dict[str, object]

Axis-name kwargs for AutoPipeline.

class nemo_automodel.components.distributed.mesh.ParallelismSizes(
dp_size: int | None = None,
dp_replicate_size: int | None = None,
tp_size: int = 1,
pp_size: int = 1,
cp_size: int = 1,
ep_size: int = 1
)
Dataclass

Build-time requested parallelism sizes.

This is durable user intent, not runtime topology. MeshContext derives its size properties from live DeviceMesh objects after build.

cp_size: int = 1
dp_replicate_size: int | None = None
dp_size: int | None = None
ep_size: int = 1
pp_size: int = 1
tp_size: int = 1

nemo_automodel.components.distributed.mesh._get_axis_size(
mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh],
axis: nemo_automodel.components.distributed.mesh.MeshAxisName,
default = 1
) -> typing.Optional[int]

Return the size of axis if present in mesh, else default.

nemo_automodel.components.distributed.mesh._optional_axis(
mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh],
axis: nemo_automodel.components.distributed.mesh.MeshAxisName
) -> typing.Optional[str]

Return axis if present in mesh, else None.

nemo_automodel.components.distributed.mesh._validate_mesh_axis_names(
mesh_context: nemo_automodel.components.distributed.mesh.MeshContext
) -> None

Ensure every axis name in the attached meshes is a MeshAxisName.
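
A check of this kind might look like the sketch below. It operates on plain tuples of names rather than real meshes (PyTorch's DeviceMesh exposes its axis names as mesh_dim_names), and AxisName, VALID_NAMES, and validate_axis_names are hypothetical stand-ins, not the module's private helpers.

```python
from enum import Enum
from typing import Iterable, Optional


class AxisName(str, Enum):
    """Hypothetical stand-in for MeshAxisName (subset of documented members)."""

    DP = "dp"
    TP = "tp"
    PP = "pp"
    EP = "ep"


# Members hash and compare like their string values, so plain strings
# can be tested for membership directly.
VALID_NAMES = frozenset(AxisName)


def validate_axis_names(*name_groups: Optional[Iterable[str]]) -> None:
    """Raise ValueError if any group contains a name outside AxisName."""
    for names in name_groups:
        if names is None:  # mesh absent or without named axes
            continue
        unknown = [name for name in names if name not in VALID_NAMES]
        if unknown:
            raise ValueError(f"Unknown mesh axis names: {unknown}")


# Canonical names pass silently; a missing second mesh is skipped.
validate_axis_names(("pp", "dp", "tp"), None)
```

Passing a group such as ("rows", "tp") would raise ValueError, which matches the documented behavior of rejecting unknown names at construction time.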

nemo_automodel.components.distributed.mesh._VALID_AXIS_NAMES: frozenset = frozenset(MeshAxisName)

nemo_automodel.components.distributed.mesh.__all__ = ['MeshAxisName', 'MeshContext', 'ParallelismSizes']