nemo_automodel.components.distributed.mesh#

Typed MeshContext dataclass, validation, and strategy map.

MeshContext is the single source of truth for everything related to distributed training: strategy config, device meshes, and axis names.

Parallelism sizes (tp_size, pp_size, etc.) are derived at runtime from the attached DeviceMesh objects via @property. When no mesh is present the properties return safe defaults: 1 for pp_size, tp_size, cp_size, and ep_size, and None for dp_size and dp_replicate_size (HSDP).

All inputs and outputs are typed Python objects (dataclasses, enums, etc.). YAML / dict parsing belongs in the recipe layer — see nemo_automodel.recipes._dist_setup.

Module Contents#

Classes#

MeshAxisName

Canonical mesh-dimension names used by DeviceMesh and helpers.

MeshContext

Runtime distributed training context: configs + device meshes.

Functions#

_get_axis_size

Return the size of axis if present in mesh, else default.

_optional_axis

Return axis if present in mesh, else None.

_validate_mesh_dim_names

Ensure every dimension name in the attached meshes is a MeshAxisName.

_validate_distributed_setup

Validate cross-field constraints on a MeshContext.

Data#

API#

nemo_automodel.components.distributed.mesh.STRATEGY_MAP: Dict[str, type]#

None
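The page gives only the type of STRATEGY_MAP (Dict[str, type]), not its contents. As a rough illustration of how such a name-to-class map is typically used, here is a minimal sketch. The keys ("fsdp2", "megatron_fsdp", "ddp"), the stub classes, and the resolve_strategy helper are all assumptions made for this example and are not taken from the module.

```python
from typing import Dict


class FSDP2ConfigStub:
    """Hypothetical stand-in for FSDP2Config."""


class MegatronFSDPConfigStub:
    """Hypothetical stand-in for MegatronFSDPConfig."""


class DDPConfigStub:
    """Hypothetical stand-in for DDPConfig."""


# Assumption: the real STRATEGY_MAP keys are not documented on this page;
# these names are illustrative only.
STRATEGY_MAP_SKETCH: Dict[str, type] = {
    "fsdp2": FSDP2ConfigStub,
    "megatron_fsdp": MegatronFSDPConfigStub,
    "ddp": DDPConfigStub,
}


def resolve_strategy(name: str) -> type:
    """Return the config class registered under `name` (hypothetical helper)."""
    try:
        return STRATEGY_MAP_SKETCH[name]
    except KeyError:
        known = ", ".join(sorted(STRATEGY_MAP_SKETCH))
        raise ValueError(f"Unknown strategy {name!r}; expected one of: {known}") from None


config_cls = resolve_strategy("ddp")
config = config_cls()  # instantiate the selected config class
```

Keeping the map as plain data lets the recipe layer translate a string from YAML into a typed config object without branching on strategy names.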

class nemo_automodel.components.distributed.mesh.MeshAxisName#

Bases: str, enum.Enum

Canonical mesh-dimension names used by DeviceMesh and helpers.

Inherits from str so each member compares equal to (and can be used wherever) a plain string — e.g. MeshAxisName.TP == "tp".
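The string-compatibility described above can be checked with a trimmed stand-in. The AxisName class below is an assumption for illustration and reproduces only two of the members listed on this page.

```python
from enum import Enum


class AxisName(str, Enum):
    """Trimmed stand-in for MeshAxisName with two of its members."""

    PP = "pp"
    TP = "tp"


# Members compare equal to their plain-string values.
assert AxisName.TP == "tp"

# Lookup by value returns the canonical member.
assert AxisName("pp") is AxisName.PP

# Members hash like their string values, so they can index string-keyed dicts.
axis_sizes = {"pp": 2, "tp": 4}
assert axis_sizes[AxisName.TP] == 4

# Plain strings can be checked against the set of valid values.
valid_names = frozenset(member.value for member in AxisName)
assert "tp" in valid_names and "xyz" not in valid_names
```

Note that str() and f-string formatting of such members changed across Python versions, so .value is the safer choice when a plain string is needed for output.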

PP#

‘pp’

DP#

‘dp’

DP_REPLICATE#

‘dp_replicate’

DP_SHARD#

‘dp_shard’

DP_SHARD_CP#

‘dp_shard_cp’

DP_CP#

‘dp_cp’

CP#

‘cp’

TP#

‘tp’

EP#

‘ep’

EP_SHARD#

‘ep_shard’

nemo_automodel.components.distributed.mesh._VALID_AXIS_NAMES: frozenset#

‘frozenset(…)’

class nemo_automodel.components.distributed.mesh.MeshContext#

Runtime distributed training context: configs + device meshes.

Parallelism sizes (tp_size, pp_size, etc.) are not stored as fields; they are @property accessors that read directly from the attached device_mesh / moe_mesh. When no mesh is present the properties return safe defaults: 1 for pp_size, tp_size, cp_size, and ep_size, and None for dp_size and dp_replicate_size (HSDP).

All DeviceMesh objects passed in must use dimension names from MeshAxisName; a ValueError is raised on construction if any unknown name is encountered.
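The construction-time check can be sketched without PyTorch. In the example below, FakeMesh, Context, and AxisName are hypothetical stand-ins; only the mesh_dim_names attribute mirrors the real DeviceMesh, and the error message is invented for illustration. The actual implementation may differ.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple


class AxisName(str, Enum):
    """Trimmed stand-in for MeshAxisName."""

    PP = "pp"
    DP = "dp"
    TP = "tp"


VALID_AXIS_NAMES = frozenset(member.value for member in AxisName)


@dataclass
class FakeMesh:
    """Hypothetical stand-in exposing only mesh_dim_names."""

    mesh_dim_names: Optional[Tuple[str, ...]] = None


@dataclass
class Context:
    """Hypothetical stand-in for MeshContext with the two mesh fields."""

    device_mesh: Optional[FakeMesh] = None
    moe_mesh: Optional[FakeMesh] = None

    def __post_init__(self) -> None:
        # Check every attached mesh; a missing mesh is simply skipped.
        for mesh in (self.device_mesh, self.moe_mesh):
            if mesh is None or mesh.mesh_dim_names is None:
                continue
            unknown = [n for n in mesh.mesh_dim_names if n not in VALID_AXIS_NAMES]
            if unknown:
                raise ValueError(f"Unknown mesh dimension name(s): {unknown}")


accepted = Context(device_mesh=FakeMesh(("dp", "tp")))  # all names are valid

rejected = None
try:
    Context(device_mesh=FakeMesh(("dp", "model")))  # "model" is not a valid axis
except ValueError as exc:
    rejected = str(exc)
```

Validating in __post_init__ means an invalid mesh fails at construction rather than later, when a property tries to look up an axis.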

Lifecycle#

  1. Recipes parse YAML to obtain sizes and strategy configs.

  2. Sizes are passed to create_device_mesh to build DeviceMesh objects.

  3. MeshContext is created with those meshes; dimension names are validated automatically in __post_init__.

Alternatively, from_meshes constructs an instance directly from DeviceMesh objects (used by NeMoAutoModel.from_pretrained).

Attributes:

strategy_config – Strategy-specific config (FSDP2, MegatronFSDP, or DDP).

device_mesh – Device mesh for distributed training.

moe_mesh – MoE-specific device mesh.

pipeline_config – Pipeline-parallel schedule/splitting config.

moe_config – MoE parallelizer settings.

activation_checkpointing – Whether activation checkpointing is enabled.

strategy_config: Optional[Union[nemo_automodel.components.distributed.config.FSDP2Config, nemo_automodel.components.distributed.config.MegatronFSDPConfig, nemo_automodel.components.distributed.config.DDPConfig]]#

None

pipeline_config: Optional[nemo_automodel.components.distributed.pipelining.config.PipelineConfig]#

None

moe_config: Optional[nemo_automodel.components.moe.config.MoEParallelizerConfig]#

None

activation_checkpointing: bool#

False

device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh]#

‘field(…)’

moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh]#

‘field(…)’

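__post_init__() → None#

Validate the dimension names of the attached meshes. When a strategy_config is present, _validate_distributed_setup is also called.

property pp_size: int#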

Pipeline-parallel degree (from device_mesh, default 1).

property pp_enabled: bool#

True when pp_size > 1.

property tp_size: int#

Tensor-parallel degree (from device_mesh, default 1).

property cp_size: int#

Context-parallel degree (from device_mesh, default 1).

property ep_size: int#

Expert-parallel degree (from moe_mesh, default 1).

property dp_size: Optional[int]#

Data-parallel degree (from device_mesh, default None).

property dp_replicate_size: Optional[int]#

HSDP replication degree (from device_mesh, default None).
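The size properties above can be illustrated with a small PyTorch-free sketch. FakeMesh, axis_size, and Context are hypothetical names for this example; the stand-in imitates only mesh_dim_names and size(mesh_dim), and the real properties may read the mesh differently.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class FakeMesh:
    """Hypothetical stand-in for a DeviceMesh: named axes with sizes."""

    mesh_dim_names: Tuple[str, ...]
    shape: Tuple[int, ...]

    def size(self, mesh_dim: int) -> int:
        return self.shape[mesh_dim]


def axis_size(
    mesh: Optional[FakeMesh], axis: str, default: Optional[int] = 1
) -> Optional[int]:
    """Return the size of `axis` if the mesh has it, else `default`."""
    if mesh is None or axis not in mesh.mesh_dim_names:
        return default
    return mesh.size(mesh.mesh_dim_names.index(axis))


@dataclass
class Context:
    """Hypothetical stand-in for MeshContext; sizes are never stored as fields."""

    device_mesh: Optional[FakeMesh] = None
    moe_mesh: Optional[FakeMesh] = None

    @property
    def pp_size(self) -> int:
        return axis_size(self.device_mesh, "pp")

    @property
    def pp_enabled(self) -> bool:
        return self.pp_size > 1

    @property
    def tp_size(self) -> int:
        return axis_size(self.device_mesh, "tp")

    @property
    def cp_size(self) -> int:
        return axis_size(self.device_mesh, "cp")

    @property
    def ep_size(self) -> int:
        # Expert parallelism is read from the MoE mesh, not the main mesh.
        return axis_size(self.moe_mesh, "ep")

    @property
    def dp_size(self) -> Optional[int]:
        # None, not 1, signals that no data-parallel axis exists.
        return axis_size(self.device_mesh, "dp", default=None)


empty = Context()
ctx = Context(
    device_mesh=FakeMesh(("pp", "dp", "tp"), (2, 2, 4)),
    moe_mesh=FakeMesh(("ep",), (8,)),
)
```

Because the values are computed on access, they cannot drift out of sync with the meshes they describe.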

_dp_axis_names() → Tuple[str, ...]#

DP axis names for FSDP mesh slicing.

pipeline_axis_kwargs() → Dict[str, object]#

Axis-name kwargs for AutoPipeline.

parallelize_axis_kwargs() → Dict[str, object]#

Axis-name kwargs for parallelize_fn (EP/FSDP, no pp_axis_name).

classmethod from_meshes(
device_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
*,
strategy_config: Optional[Union[nemo_automodel.components.distributed.config.FSDP2Config, nemo_automodel.components.distributed.config.MegatronFSDPConfig, nemo_automodel.components.distributed.config.DDPConfig]] = None,
pipeline_config: Optional[nemo_automodel.components.distributed.pipelining.config.PipelineConfig] = None,
moe_config: Optional[nemo_automodel.components.moe.config.MoEParallelizerConfig] = None,
activation_checkpointing: bool = False,
) → nemo_automodel.components.distributed.mesh.MeshContext#

Build a MeshContext from DeviceMesh objects.

This is the entry-point used by NeMoAutoModel.from_pretrained / from_config where the caller has raw meshes rather than a parsed YAML config.

nemo_automodel.components.distributed.mesh._get_axis_size(
mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
axis: nemo_automodel.components.distributed.mesh.MeshAxisName,
default=1,
) → Optional[int]#

Return the size of axis if present in mesh, else default.

nemo_automodel.components.distributed.mesh._optional_axis(
mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
axis: nemo_automodel.components.distributed.mesh.MeshAxisName,
) → Optional[str]#

Return axis if present in mesh, else None.
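A helper with this contract might look like the sketch below. FakeMesh and optional_axis are hypothetical names for this example, and the stand-in models only the mesh_dim_names attribute; the module's own implementation is not shown on this page.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class FakeMesh:
    """Hypothetical stand-in exposing only mesh_dim_names."""

    mesh_dim_names: Optional[Tuple[str, ...]] = None


def optional_axis(mesh: Optional[FakeMesh], axis: str) -> Optional[str]:
    """Return `axis` when the mesh defines it, otherwise None."""
    if mesh is None or mesh.mesh_dim_names is None:
        return None
    return axis if axis in mesh.mesh_dim_names else None


mesh = FakeMesh(("dp", "tp"))
present = optional_axis(mesh, "tp")  # axis exists, so its name is returned
absent = optional_axis(mesh, "pp")   # axis missing from this mesh
no_mesh = optional_axis(None, "tp")  # no mesh attached at all
```

Returning None rather than raising lets callers pass the result straight through as an optional axis-name argument.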

nemo_automodel.components.distributed.mesh._validate_mesh_dim_names(
mesh_context: nemo_automodel.components.distributed.mesh.MeshContext,
) → None#

Ensure every dimension name in the attached meshes is a MeshAxisName.

nemo_automodel.components.distributed.mesh._validate_distributed_setup(
mesh_context: nemo_automodel.components.distributed.mesh.MeshContext,
) → None#

Validate cross-field constraints on a MeshContext.

Called automatically by MeshContext.__post_init__ when a strategy_config is present. Can also be invoked explicitly after mutating a context.

Raises:

ValueError – If any constraint is violated.