nemo_automodel.components.moe.parallelizer#
Module Contents#
Classes#
ExpertParallel | ExpertParallel class is used to shard the MoE parameters on the EP mesh. Dim 0 of each parameter is sharded since that is the expert dimension.
Functions#
_iter_moe_blocks | Yield decoder blocks that may contain MoE sublayers.
apply_ep | Applies EP to MoE module.
apply_ac | Apply activation checkpointing to the model.
apply_fsdp | Apply FSDP wrapping to MoE transformer blocks and model-level modules.
apply_cp | Configure context parallelism for attention and MoE layers.
parallelize_model | Apply context, expert, activation-checkpointing, and FSDP parallelism.
Data#
API#
- nemo_automodel.components.moe.parallelizer.logger#
‘getLogger(…)’
- nemo_automodel.components.moe.parallelizer._CP_STREAM#
None
- nemo_automodel.components.moe.parallelizer._is_deepseek_v4_model(model: torch.nn.Module) → bool#
- nemo_automodel.components.moe.parallelizer._get_cp_stream() → torch.cuda.Stream#
- nemo_automodel.components.moe.parallelizer._iter_transformer_and_mtp_blocks(model: torch.nn.Module)#
- nemo_automodel.components.moe.parallelizer._get_moe_module(
    block: torch.nn.Module,
  )#
- class nemo_automodel.components.moe.parallelizer.ExpertParallel#
Bases: torch.distributed.tensor.parallel.ParallelStyle
ExpertParallel class is used to shard the MoE parameters on the EP mesh. Dim 0 of each parameter is sharded since that is the expert dimension.
- _partition_fn(name, module, device_mesh)#
- _apply(
    module: torch.nn.Module,
    device_mesh: torch.distributed.device_mesh.DeviceMesh,
  )#
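As a rough illustration of what sharding dim 0 means, the sketch below computes which experts, and what local parameter shape, each EP rank would hold. It is plain Python rather than the class's actual DTensor logic. The helper names `local_expert_range` and `local_shape` are hypothetical, and the sketch assumes the expert count divides evenly by the EP mesh size.

```python
def local_expert_range(num_experts: int, ep_size: int, ep_rank: int) -> range:
    """Return the expert indices held by `ep_rank` under an even split of dim 0."""
    # Assumption of this sketch: experts divide evenly across EP ranks.
    if num_experts % ep_size != 0:
        raise ValueError("this sketch assumes num_experts is divisible by ep_size")
    if not 0 <= ep_rank < ep_size:
        raise ValueError("ep_rank must be in [0, ep_size)")
    per_rank = num_experts // ep_size
    start = ep_rank * per_rank
    return range(start, start + per_rank)


def local_shape(global_shape: tuple, ep_size: int) -> tuple:
    """Shape of the local shard when only dim 0 (the expert dim) is split."""
    return (global_shape[0] // ep_size, *global_shape[1:])


# Example: 8 experts on a 4-rank EP mesh.
for rank in range(4):
    print(rank, list(local_expert_range(8, 4, rank)))
print(local_shape((8, 1024, 4096), 4))
```

Only the leading dimension changes between the global and local shapes; the per-expert weight dimensions stay intact on every rank.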
- nemo_automodel.components.moe.parallelizer._iter_moe_blocks(
    model_wrapper: torch.nn.Module,
    backbone: torch.nn.Module,
  )#
Yield decoder blocks that may contain MoE sublayers.
Covers the main backbone (backbone.layers) plus an optional MTP auxiliary head (model_wrapper.mtp.layers) when present. MTP sublayers are not registered under backbone.layers but carry the same MoE structure and must receive the same EP / FSDP treatment so that their state dict round-trips cleanly.
- Parameters:
model_wrapper – Outer model (e.g. NemotronHForCausalLM) whose mtp attribute may carry the MTP head.
backbone – Inner backbone (model_wrapper.model, possibly text-only after VLM unwrapping) whose .layers holds the main decoder stack.
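The traversal described above can be pictured with a small generator. This is a sketch under assumptions, not the library's implementation: it assumes `layers` is a plain iterable of blocks, uses `SimpleNamespace` stand-ins instead of real modules, and the name `iter_blocks_sketch` is hypothetical.

```python
from types import SimpleNamespace


def iter_blocks_sketch(model_wrapper, backbone):
    """Yield main decoder blocks first, then MTP blocks if the wrapper has any."""
    # Main decoder stack lives on the backbone.
    yield from backbone.layers
    # The MTP head is optional and hangs off the outer wrapper, not the backbone.
    mtp = getattr(model_wrapper, "mtp", None)
    if mtp is not None:
        yield from getattr(mtp, "layers", ())


# Stand-in objects: strings play the role of decoder blocks.
backbone = SimpleNamespace(layers=["block0", "block1"])
with_mtp = SimpleNamespace(model=backbone, mtp=SimpleNamespace(layers=["mtp0"]))
without_mtp = SimpleNamespace(model=backbone)

print(list(iter_blocks_sketch(with_mtp, backbone)))
print(list(iter_blocks_sketch(without_mtp, backbone)))
```

Routing every consumer through one iterator of this kind is what lets EP and FSDP treat the MTP blocks the same way as the main stack.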
- nemo_automodel.components.moe.parallelizer.apply_ep(
    model: torch.nn.Module,
    ep_mesh: torch.distributed.device_mesh.DeviceMesh,
    moe_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
  )#
Applies EP to MoE module.
- nemo_automodel.components.moe.parallelizer.apply_ac(
    model: torch.nn.Module,
    ignore_router: bool = False,
    hidden_size: int | None = None,
    num_experts: int | None = None,
  )#
Apply activation checkpointing to the model.
- Parameters:
model – The model to apply activation checkpointing to.
ignore_router – If True, uses selective checkpointing that saves router outputs.
hidden_size – Hidden dimension size. If None, derived from model.config.hidden_size.
num_experts – Number of routed experts. If None, derived from moe_config.n_routed_experts first, then falls back to model.config attributes.
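The fallback order for `hidden_size` and `num_experts` can be sketched as follows. This is an illustration only: the source does not say where `moe_config` lives or which `model.config` attributes are consulted, so `moe_config` is passed in explicitly here and the names in `_CONFIG_FALLBACKS` are hypothetical.

```python
from types import SimpleNamespace

# Hypothetical attribute names; the source only says "model.config attributes".
_CONFIG_FALLBACKS = ("n_routed_experts", "num_experts", "num_local_experts")


def resolve_hidden_size(model, hidden_size=None):
    """An explicit argument wins; otherwise read model.config.hidden_size."""
    if hidden_size is not None:
        return hidden_size
    return model.config.hidden_size


def resolve_num_experts(model, moe_config=None, num_experts=None):
    """Explicit argument, then moe_config.n_routed_experts, then model.config."""
    if num_experts is not None:
        return num_experts
    value = getattr(moe_config, "n_routed_experts", None)
    if value is not None:
        return value
    for name in _CONFIG_FALLBACKS:
        value = getattr(model.config, name, None)
        if value is not None:
            return value
    raise ValueError("could not determine the number of routed experts")


# Stand-in objects in place of a real model and MoE config.
model = SimpleNamespace(config=SimpleNamespace(hidden_size=2048, num_experts=64))
moe_config = SimpleNamespace(n_routed_experts=128)

print(resolve_hidden_size(model))
print(resolve_num_experts(model, moe_config))
print(resolve_num_experts(model))
```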
- nemo_automodel.components.moe.parallelizer.apply_fsdp(
    model: torch.nn.Module,
    fsdp_mesh: torch.distributed.device_mesh.DeviceMesh,
    ep_enabled: bool,
    ep_shard_enabled: bool,
    ep_shard_mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
    mp_policy: torch.distributed.fsdp._fully_shard.MixedPrecisionPolicy | None = None,
    offload_policy: torch.distributed.fsdp._fully_shard.OffloadPolicy | None = None,
    reshard_after_forward: bool = False,
    lm_head_precision: str | torch.dtype | None = None,
    wrap_outer_model: bool = True,
  )#
Apply FSDP wrapping to MoE transformer blocks and model-level modules.
- nemo_automodel.components.moe.parallelizer.apply_cp(
    model: torch.nn.Module,
    cp_mesh: torch.distributed.device_mesh.DeviceMesh,
    cp_comm_type: str = 'p2p',
  )#
Configure context parallelism for attention and MoE layers.
- nemo_automodel.components.moe.parallelizer.parallelize_model(
    model: torch.nn.Module,
    world_mesh: torch.distributed.device_mesh.DeviceMesh,
    moe_mesh: torch.distributed.device_mesh.DeviceMesh | None,
    *,
    dp_axis_names: tuple[str, ...],
    cp_axis_name: str | None = None,
    tp_axis_name: str | None = None,
    ep_axis_name: str | None = None,
    ep_shard_axis_names: tuple[str, ...] | None = None,
    activation_checkpointing: bool = False,
    ignore_router_for_ac: bool = False,
    reshard_after_forward: bool = False,
    lm_head_precision: str | torch.dtype | None = None,
    wrap_outer_model: bool = True,
    mp_policy: torch.distributed.fsdp._fully_shard.MixedPrecisionPolicy | None = None,
  )#
Apply context, expert, activation-checkpointing, and FSDP parallelism.
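One way to picture how the keyword arguments select stages is the planning helper below. It is a sketch built on assumptions rather than the function's actual control flow: it assumes the stages run in the order the docstring lists them, that a stage is skipped when its axis name is None or its flag is False, and that FSDP is always applied. The name `planned_stages` is hypothetical.

```python
def planned_stages(*, cp_axis_name=None, ep_axis_name=None,
                   activation_checkpointing=False):
    """List the stages this sketch would run, in docstring order."""
    stages = []
    if cp_axis_name is not None:       # assumption: CP needs a named mesh axis
        stages.append("cp")
    if ep_axis_name is not None:       # assumption: EP needs a named mesh axis
        stages.append("ep")
    if activation_checkpointing:
        stages.append("ac")
    stages.append("fsdp")              # assumption: FSDP is always applied
    return stages


print(planned_stages(ep_axis_name="ep", activation_checkpointing=True))
print(planned_stages())
```

The real function may also inspect mesh dimension sizes before enabling a stage, which this sketch does not model.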