nemo_automodel.components.moe.parallelizer

Module Contents

Classes

| Name | Description |
| --- | --- |
| ExpertParallel | ExpertParallel class is used to shard the MoE parameters on the EP mesh. |

Functions

| Name | Description |
| --- | --- |
| _get_cp_stream | - |
| _get_model_moe_config | Return the model-level MoE config exposed by custom MoE architectures. |
| _get_moe_module | - |
| _is_deepseek_v4_model | - |
| _is_selective_ac | Return True when the AC mode requests selective checkpointing. |
| _iter_moe_blocks | Yield decoder blocks that may contain MoE sublayers. |
| _iter_transformer_and_mtp_blocks | - |
| _module_weights_are_tied | Return True when two modules expose the same weight parameter object. |
| _moe_shard_placement | FSDP shard placement for grouped-expert params. |
| _shard_fp32_param_holders | Shard each _fp32_params holder in block as its own fp32 FSDP unit. |
| apply_ac | Apply activation checkpointing to the model. |
| apply_cp | Configure context parallelism for attention and MoE layers. |
| apply_ep | Applies EP to MoE module. |
| apply_fsdp | Apply FSDP wrapping to MoE transformer blocks and model-level modules. |
| parallelize_model | Apply context, expert, activation-checkpointing, and FSDP parallelism. |

Data

_CP_STREAM

logger

API

class nemo_automodel.components.moe.parallelizer.ExpertParallel()

Bases: ParallelStyle

ExpertParallel class is used to shard the MoE parameters on the EP mesh. Dim 0 of each parameter is sharded since that is the expert dimension.
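To make the dim-0 split concrete, here is a minimal pure-Python sketch of which experts each EP rank would hold. It is not the library's implementation: the helper names `local_expert_ids` and `local_shape` are hypothetical, and the sketch assumes the expert count divides evenly by the EP mesh size.

```python
def local_expert_ids(num_experts: int, ep_size: int, ep_rank: int) -> list[int]:
    """Global expert indices held by one EP rank under an even dim-0 split.

    Hypothetical helper for illustration only.
    """
    if num_experts % ep_size != 0:
        raise ValueError("this sketch assumes num_experts is divisible by ep_size")
    if not 0 <= ep_rank < ep_size:
        raise ValueError("ep_rank must be in [0, ep_size)")
    per_rank = num_experts // ep_size
    start = ep_rank * per_rank
    return list(range(start, start + per_rank))


def local_shape(global_shape: tuple[int, ...], ep_size: int) -> tuple[int, ...]:
    """Shape of the local shard when dim 0 (the expert dim) is split evenly."""
    return (global_shape[0] // ep_size, *global_shape[1:])


# 8 experts on a 4-way EP mesh: rank 1 holds experts 2 and 3.
print(local_expert_ids(8, 4, 1))
# A grouped weight of shape (experts, in, out) keeps its trailing dims.
print(local_shape((8, 16, 32), 4))
```

Only the leading dimension shrinks; the per-expert weight matrices stay whole on the rank that owns them.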

nemo_automodel.components.moe.parallelizer.ExpertParallel._apply(
module: torch.nn.Module,
device_mesh: torch.distributed.device_mesh.DeviceMesh
) -> torch.nn.Module

nemo_automodel.components.moe.parallelizer.ExpertParallel._partition_fn(
name,
module,
device_mesh
)

nemo_automodel.components.moe.parallelizer._get_cp_stream() -> torch.cuda.Stream

nemo_automodel.components.moe.parallelizer._get_model_moe_config(
model: torch.nn.Module
)

Return the model-level MoE config exposed by custom MoE architectures.

nemo_automodel.components.moe.parallelizer._get_moe_module(
block: torch.nn.Module
) -> nemo_automodel.components.moe.layers.MoE | None

nemo_automodel.components.moe.parallelizer._is_deepseek_v4_model(
model: torch.nn.Module
) -> bool

nemo_automodel.components.moe.parallelizer._is_selective_ac(
activation_checkpointing: object
) -> bool

Return True when the AC mode requests selective checkpointing.

Kept inline (rather than imported from the dense FSDP2 parallelizer) so that threading the mode does not pull the heavy distributed.parallelizer module into the lightweight call path.

nemo_automodel.components.moe.parallelizer._iter_moe_blocks(
model_wrapper: torch.nn.Module,
backbone: torch.nn.Module
)

Yield decoder blocks that may contain MoE sublayers.

Covers the main backbone (backbone.layers) plus an optional MTP auxiliary head (model_wrapper.mtp.layers) when present. MTP sublayers are not registered under backbone.layers but carry the same MoE structure and must receive the same EP / FSDP treatment so their state-dict round-trips cleanly.

Parameters:

model_wrapper
nn.Module

Outer model (e.g. NemotronHForCausalLM) whose mtp attribute may carry the MTP head.

backbone
nn.Module

Inner backbone (model_wrapper.model, possibly text-only after VLM unwrapping) whose .layers holds the main decoder stack.
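The traversal described above can be sketched with plain Python objects. This is an illustration under assumptions, not the module's code: `iter_moe_blocks_sketch` and `_values` are hypothetical names, and the handling of dict-like layer containers is a guess about how a `.layers` attribute might be stored.

```python
from types import SimpleNamespace


def _values(layers):
    """Iterate a layer container that is either list-like or dict-like (assumption)."""
    return layers.values() if hasattr(layers, "values") else layers


def iter_moe_blocks_sketch(model_wrapper, backbone):
    """Yield main decoder blocks first, then MTP blocks when the wrapper has them."""
    layers = getattr(backbone, "layers", None)
    if layers is not None:
        yield from _values(layers)

    # The MTP head hangs off the outer wrapper, not the backbone.
    mtp = getattr(model_wrapper, "mtp", None)
    mtp_layers = getattr(mtp, "layers", None) if mtp is not None else None
    if mtp_layers is not None:
        yield from _values(mtp_layers)


# Stand-in objects; strings play the role of decoder blocks.
backbone = SimpleNamespace(layers=["block0", "block1"])
with_mtp = SimpleNamespace(model=backbone, mtp=SimpleNamespace(layers=["mtp0"]))
without_mtp = SimpleNamespace(model=backbone)

print(list(iter_moe_blocks_sketch(with_mtp, backbone)))
print(list(iter_moe_blocks_sketch(without_mtp, backbone)))
```

Visiting both stacks through one iterator lets the same EP and FSDP steps run over the MTP layers without a separate code path.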

nemo_automodel.components.moe.parallelizer._iter_transformer_and_mtp_blocks(
model: torch.nn.Module
)

nemo_automodel.components.moe.parallelizer._module_weights_are_tied(
left: torch.nn.Module | None,
right: torch.nn.Module | None
) -> bool

Return True when two modules expose the same weight parameter object.
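"Same weight parameter object" is an identity check rather than a value comparison. The sketch below shows one way such a check could look; `weights_are_tied_sketch` is a hypothetical name, and plain namespaces stand in for modules so the example runs without PyTorch.

```python
from types import SimpleNamespace


def weights_are_tied_sketch(left, right) -> bool:
    """True only when both modules exist and share one weight object."""
    if left is None or right is None:
        return False
    left_weight = getattr(left, "weight", None)
    right_weight = getattr(right, "weight", None)
    # Identity, not equality: two separate tensors with equal values are not tied.
    return left_weight is not None and left_weight is right_weight


shared = [0.1, 0.2]  # stand-in for a parameter object
embed = SimpleNamespace(weight=shared)
lm_head_tied = SimpleNamespace(weight=shared)
lm_head_copy = SimpleNamespace(weight=list(shared))

print(weights_are_tied_sketch(embed, lm_head_tied))
print(weights_are_tied_sketch(embed, lm_head_copy))
print(weights_are_tied_sketch(embed, None))
```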

nemo_automodel.components.moe.parallelizer._moe_shard_placement(
param
)

FSDP shard placement for grouped-expert params.

Expert weights (2D or higher) are sharded on dim 1, because the number of FSDP shards may exceed the number of experts (dim 0). A 1D param (e.g. the per-expert bias of the experts="te" GroupedLinear path, shape [out_features]) has no dim 1, so it is sharded on dim 0 instead. FSDP all-gathers before use, so the shard dim is a storage detail and does not change compute.
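The dimension choice can be sketched from the parameter shape alone. This is an illustration, not the function's source: the real function presumably returns a placement object, whereas the hypothetical `moe_shard_dim_sketch` below returns only the dim index. `empty_shards` is likewise hypothetical and assumes `torch.chunk`-style splitting, where each chunk has ceil(dim_size / num_shards) rows.

```python
import math


def moe_shard_dim_sketch(shape: tuple[int, ...]) -> int:
    """Pick dim 1 for >=2D params, dim 0 for 1D params (which have no dim 1)."""
    return 1 if len(shape) >= 2 else 0


def empty_shards(dim_size: int, num_shards: int) -> int:
    """Number of ranks left with an empty shard under chunk-style splitting."""
    chunk = math.ceil(dim_size / num_shards)
    non_empty = math.ceil(dim_size / chunk)
    return num_shards - non_empty


expert_weight = (4, 1024, 4096)  # (local experts, in_features, out_features)
expert_bias = (4096,)            # 1D param, shape [out_features]

print(moe_shard_dim_sketch(expert_weight))
print(moe_shard_dim_sketch(expert_bias))

# With 8 shards, splitting the 4-expert dim leaves ranks empty; dim 1 does not.
print(empty_shards(expert_weight[0], 8))
print(empty_shards(expert_weight[1], 8))
```

In this example the expert dim has only 4 rows for 8 shards, which is the situation the dim-1 choice avoids.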

nemo_automodel.components.moe.parallelizer._shard_fp32_param_holders(
block,
fsdp_mesh,
reshard_after_forward,
offload_policy
)

Shard each _fp32_params holder in block as its own fp32 FSDP unit.

Model implementations own the architecture-specific decision to create these holders (for example Qwen3.5/Qwen3-Next GatedDeltaNet A_log/dt_bias). FSDP only treats the holder as a dtype-uniform fp32 unit and excludes its params from the block’s bf16 FSDP unit.

Returns the set of holder parameters to exclude from the block’s FSDP wrap. Blocks that do not expose named_modules (e.g. non-nn.Module test stubs) cannot hold fp32 holders, so an empty set is returned.

nemo_automodel.components.moe.parallelizer.apply_ac(
model: torch.nn.Module,
ignore_router: bool = True,
hidden_size: int | None = None,
num_experts: int | None = None,
selective: bool = False
)

Apply activation checkpointing to the model.

Parameters:

model
nn.Module

The model to apply activation checkpointing to.

ignore_router
bool, defaults to True

If True (the default), saves the MoE router output so the dispatch is not recomputed under activation checkpointing (avoids a CheckpointError from non-deterministic re-routing on recompute). If False, a warning is emitted.

hidden_size
int | None, defaults to None

Hidden dimension size. If None, derived from model.config.hidden_size.

num_experts
int | None, defaults to None

Number of routed experts. If None, derived from moe_config.n_routed_experts first, then falls back to model.config attributes.

selective
bool, defaults to False

If True, applies TorchTitan-style per-op selective activation checkpointing (shared with the dense FSDP2 path) to each block. Takes precedence over ignore_router; the shared policy already saves expert-parallel communication collectives and topk, so it composes with expert parallelism.
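The fallback order documented for num_experts (explicit argument, then moe_config.n_routed_experts, then model.config attributes) can be sketched as follows. The function name `resolve_num_experts_sketch` is hypothetical, and the attribute names in `fallback_attrs` are assumptions chosen for illustration; the source does not say which model.config attributes are consulted.

```python
from types import SimpleNamespace


def resolve_num_experts_sketch(
    num_experts,
    moe_config,
    model_config,
    # Assumed attribute names; the real fallback list is not documented here.
    fallback_attrs=("n_routed_experts", "num_experts", "num_local_experts"),
):
    """Resolve the routed-expert count using the documented precedence."""
    if num_experts is not None:
        return num_experts  # an explicit argument wins

    value = getattr(moe_config, "n_routed_experts", None)
    if value is not None:
        return value  # model-level MoE config comes next

    for name in fallback_attrs:
        value = getattr(model_config, name, None)
        if value is not None:
            return value  # last resort: attributes on model.config
    return None


moe_cfg = SimpleNamespace(n_routed_experts=64)
hf_cfg = SimpleNamespace(num_experts=128)

print(resolve_num_experts_sketch(None, moe_cfg, hf_cfg))
print(resolve_num_experts_sketch(None, None, hf_cfg))
print(resolve_num_experts_sketch(8, moe_cfg, hf_cfg))
```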

nemo_automodel.components.moe.parallelizer.apply_cp(
model: torch.nn.Module,
cp_mesh: torch.distributed.device_mesh.DeviceMesh,
cp_comm_type: str = 'p2p'
)

Configure context parallelism for attention and MoE layers.

nemo_automodel.components.moe.parallelizer.apply_ep(
model: torch.nn.Module,
ep_mesh: torch.distributed.device_mesh.DeviceMesh,
moe_mesh: torch.distributed.device_mesh.DeviceMesh | None = None
)

Applies EP to MoE module.

nemo_automodel.components.moe.parallelizer.apply_fsdp(
model: torch.nn.Module,
fsdp_mesh: torch.distributed.device_mesh.DeviceMesh,
ep_enabled: bool,
ep_shard_enabled: bool,
ep_shard_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
mp_policy: torch.distributed.fsdp._fully_shard.MixedPrecisionPolicy | None = None,
offload_policy: torch.distributed.fsdp._fully_shard.OffloadPolicy | None = None,
reshard_after_forward: bool = False,
lm_head_precision: str | torch.dtype | None = None,
wrap_outer_model: bool = True
)

Apply FSDP wrapping to MoE transformer blocks and model-level modules.

nemo_automodel.components.moe.parallelizer.parallelize_model(
model: torch.nn.Module,
world_mesh: torch.distributed.device_mesh.DeviceMesh,
moe_mesh: torch.distributed.device_mesh.DeviceMesh | None,
dp_axis_names: tuple[str, ...],
cp_axis_name: str | None = None,
tp_axis_name: str | None = None,
ep_axis_name: str | None = None,
ep_shard_axis_names: tuple[str, ...] | None = None,
activation_checkpointing: bool | str = False,
ignore_router_for_ac: bool = True,
reshard_after_forward: bool = False,
lm_head_precision: str | torch.dtype | None = None,
wrap_outer_model: bool = True,
mp_policy: torch.distributed.fsdp._fully_shard.MixedPrecisionPolicy | None = None
)

Apply context, expert, activation-checkpointing, and FSDP parallelism.

nemo_automodel.components.moe.parallelizer._CP_STREAM = None

nemo_automodel.components.moe.parallelizer.logger = logging.getLogger(__name__)