nemo_automodel.components.moe.experts

Module Contents

Classes

Name                        Description
GroupedExperts              Sparse MoE implementation using all-gather/reduce-scatter primitives.
GroupedExpertsDeepEP        Sparse MoE implementation using grouped GEMM with DeepEP token dispatch.
GroupedExpertsTE            MoE experts using TE’s GroupedLinear module directly.
_AllGatherConcatVarlenFn    All-gather with variable local lengths and autograd-safe backward.

Functions

Name                                Description
_apply_bias                         Apply per-expert bias to grouped GEMM output.
_init_weights                       -
_permute_tokens_for_grouped_mm      Permute tokens by expert assignment and compute offs for torch._grouped_mm.
_torch_mm_experts_fwd               -
get_expert_activation_for_deepep    Return the DeepEP expert activation function selected by the MoE config.
is_gated_activation                 Check if activation requires gating (gate_proj + up_proj).
quick_geglu_deepep                  Apply DeepEP Quick-GEGLU activation and routing probabilities.
relu2_deepep                        ReLU² activation for DeepEP: relu(x)^2
swiglu_clamped_deepep               Clamped SwiGLU (DeepSeek V4 style) for DeepEP.
swiglu_oai_deepep                   SwiGLU-OAI (GPT-OSS / MiniMax-M3) activation for grouped experts.

API

class nemo_automodel.components.moe.experts.GroupedExperts(
config: nemo_automodel.components.moe.config.MoEConfig,
backend: typing.Optional[nemo_automodel.components.models.common.utils.BackendConfig] = None
)

Bases: Module

Sparse MoE implementation using all-gather/reduce-scatter primitives.

Supports two compute backends:

  • Per-expert loop with gather/scatter (default)
  • torch._grouped_mm with argsort-based permutation (backend.experts="torch_mm")
down_proj_bias
down_projs
expert_activation_grouped
= get_expert_activation_for_deepep(config)
expert_bias
= config.expert_bias
gate_and_up_projs
gate_up_proj_bias
is_gated
= is_gated_activation(config.expert_activation)
n_routed_experts
= config.n_routed_experts
use_mxfp8
use_torch_mm
nemo_automodel.components.moe.experts.GroupedExperts._forward_grouped_mm(
x,
token_mask,
weights,
indices,
gate_and_up_projs,
down_projs,
gate_up_proj_bias,
down_proj_bias,
n_local_experts,
experts_start_idx
)

Grouped GEMM forward path using torch._grouped_mm.

nemo_automodel.components.moe.experts.GroupedExperts._forward_loop(
x,
weights,
indices,
token_mask,
gate_and_up_projs,
down_projs,
gate_up_proj_bias,
down_proj_bias,
n_local_experts,
experts_start_idx,
experts_end_idx
)

Per-expert loop forward path using gather/scatter.

nemo_automodel.components.moe.experts.GroupedExperts.forward(
x: torch.Tensor,
token_mask: torch.Tensor,
weights: torch.Tensor,
indices: torch.Tensor
) -> torch.Tensor

Forward pass for the grouped experts.

Parameters:

x
torch.Tensor

Input tensor. Shape is [num_tokens, model_dim].

token_mask
torch.Tensor

Boolean mask indicating valid tokens. Shape is [num_tokens].

weights
torch.Tensor

Routing weights for the selected experts. Shape is [num_tokens, num_activated_experts].

indices
torch.Tensor

Indices of the selected experts. Shape is [num_tokens, num_activated_experts].

Returns: torch.Tensor

Output tensor after expert computation. Shape is [num_tokens, model_dim].
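
The argument shapes above can be read as a routing contract: token t is sent to the experts listed in indices[t], and their outputs are combined with weights[t]. The following is a minimal sketch of that contract in the spirit of the per-expert loop path, using plain Python lists instead of tensors so it runs without dependencies. The helper name moe_forward_loop, the toy experts, and the choice to leave masked tokens at zero are assumptions for illustration, not the module's actual code.

```python
def moe_forward_loop(x, token_mask, weights, indices, experts):
    """out[t] = sum over k of weights[t][k] * experts[indices[t][k]](x[t])."""
    model_dim = len(x[0])
    out = [[0.0] * model_dim for _ in x]
    for expert_id, expert in enumerate(experts):      # one pass per expert
        for t, row in enumerate(x):
            if not token_mask[t]:
                continue                              # assumption: masked tokens stay zero
            for k, chosen in enumerate(indices[t]):
                if chosen != expert_id:
                    continue
                y = expert(row)                       # "gather" the token, run the expert
                for d in range(model_dim):            # "scatter" the weighted result back
                    out[t][d] += weights[t][k] * y[d]
    return out


# Hypothetical toy experts: expert e multiplies its input by (e + 1).
experts = [lambda v, s=e + 1: [s * a for a in v] for e in range(3)]

x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]   # [num_tokens, model_dim]
token_mask = [True, True, False]           # [num_tokens]
weights = [[0.5, 0.5], [1.0, 0.0], [0.5, 0.5]]  # [num_tokens, num_activated_experts]
indices = [[0, 2], [1, 0], [1, 2]]              # [num_tokens, num_activated_experts]

out = moe_forward_loop(x, token_mask, weights, indices, experts)
print(out)
```

The output keeps the input shape [num_tokens, model_dim], which is what lets the MoE block drop into a residual stream in place of a dense MLP.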

nemo_automodel.components.moe.experts.GroupedExperts.init_weights(
buffer_device: torch.device,
init_std: float = 0.02
) -> None
class nemo_automodel.components.moe.experts.GroupedExpertsDeepEP(
config: nemo_automodel.components.moe.config.MoEConfig,
backend: typing.Optional[nemo_automodel.components.models.common.utils.BackendConfig] = None,
dispatcher_backend: str = 'deepep',
dispatcher_num_sms: int = 20,
dispatcher_share_token_dispatcher: bool = True,
dispatcher_async_dispatch: bool = False
)

Bases: Module

Sparse MoE implementation using grouped GEMM with DeepEP token dispatch.

Supports two GEMM backends via BackendConfig.experts:

  • grouped_gemm.ops.gmm (experts="gmm", default)
  • torch._grouped_mm (experts="torch_mm", no external dependency)

Once the experts for a particular token have been identified, this module is invoked to compute and average the output of the activated experts.

down_proj_bias
down_projs
expert_activation
= get_expert_activation_for_deepep(config)
expert_bias
= config.expert_bias
gate_and_up_projs
gate_up_proj_bias
is_gated
= is_gated_activation(config.expert_activation)
use_mxfp8
use_torch_mm
nemo_automodel.components.moe.experts.GroupedExpertsDeepEP._init_deepep_buffer(
ep_group: torch.distributed.ProcessGroup
) -> None

Initialize DeepEP communication buffers before activation checkpointing.

nemo_automodel.components.moe.experts.GroupedExpertsDeepEP.forward(
x: torch.Tensor,
token_mask: torch.Tensor,
weights: torch.Tensor,
indices: torch.Tensor
) -> torch.Tensor

Forward pass for the grouped experts.

Parameters:

x
torch.Tensor

Input tensor. Shape is [num_tokens, model_dim].

token_mask
torch.Tensor

Boolean mask indicating valid tokens. Shape is [num_tokens].

weights
torch.Tensor

Routing weights for the selected experts. Shape is [num_tokens, num_activated_experts].

indices
torch.Tensor

Indices of the selected experts. Shape is [num_tokens, num_activated_experts].

Returns: torch.Tensor

Output tensor after expert computation. Shape is [num_tokens, model_dim].

nemo_automodel.components.moe.experts.GroupedExpertsDeepEP.init_token_dispatcher(
ep_mesh: torch.distributed.device_mesh.DeviceMesh
)
nemo_automodel.components.moe.experts.GroupedExpertsDeepEP.init_weights(
buffer_device: torch.device,
init_std: float = 0.02
) -> None
class nemo_automodel.components.moe.experts.GroupedExpertsTE(
config: nemo_automodel.components.moe.config.MoEConfig,
backend: typing.Optional[nemo_automodel.components.models.common.utils.BackendConfig] = None,
dispatcher_backend: str = 'deepep',
dispatcher_num_sms: int = 20,
dispatcher_share_token_dispatcher: bool = True,
dispatcher_async_dispatch: bool = False
)

Bases: Module

MoE experts using TE’s GroupedLinear module directly.

Uses TE’s native GroupedLinear for computation, providing:

  • Optimized grouped GEMM kernels from TE

For expert parallelism, each rank creates GroupedLinear with num_local_experts = n_routed_experts / ep_size.
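
As a rough illustration of the split described above, the sketch below computes how many experts a rank holds and which global expert IDs they cover. The helper name local_expert_range and the assumption that experts are assigned to ranks in contiguous blocks are hypothetical; the source only states the num_local_experts formula.

```python
def local_expert_range(n_routed_experts: int, ep_size: int, ep_rank: int):
    """Return (num_local_experts, start_idx, end_idx) for one EP rank.

    Assumes contiguous block assignment: rank r owns experts
    [r * num_local_experts, (r + 1) * num_local_experts).
    """
    if n_routed_experts % ep_size != 0:
        raise ValueError("n_routed_experts must be divisible by ep_size")
    num_local_experts = n_routed_experts // ep_size
    start_idx = ep_rank * num_local_experts
    return num_local_experts, start_idx, start_idx + num_local_experts


# 64 routed experts over 8 expert-parallel ranks, seen from rank 3.
print(local_expert_range(64, 8, 3))
```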

dim
= config.dim
down_linear
down_proj_bias
Optional[Tensor]
down_projs
Tensor
ep_rank
= 0
expert_activation
= get_expert_activation_for_deepep(config)
expert_bias
= config.expert_bias
fp8_padding
= Fp8Padding(config.n_routed_experts)
fp8_unpadding
= Fp8Unpadding(config.n_routed_experts)
gate_and_up_projs
Tensor
gate_up_linear
gate_up_proj_bias
Optional[Tensor]
is_gated
= is_gated_activation(config.expert_activation)
moe_inter_dim
= config.moe_inter_dim
num_local_experts
= config.n_routed_experts
nemo_automodel.components.moe.experts.GroupedExpertsTE._get_stacked_bias(
linear: transformer_engine.pytorch.GroupedLinear
) -> typing.Optional[torch.Tensor]
nemo_automodel.components.moe.experts.GroupedExpertsTE._get_stacked_weight(
linear: transformer_engine.pytorch.GroupedLinear,
transpose: bool = False
) -> torch.Tensor
nemo_automodel.components.moe.experts.GroupedExpertsTE._load_from_state_dict(
state_dict: typing.Dict[str, typing.Any],
prefix: str,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs
)

Load state dict with stacked tensors in DeepEP format.

Converts stacked format to TE GroupedLinear’s weight{i} parameters:

  • gate_and_up_projs: [num_local_experts, dim, moe_inter_dim * 2]
  • down_projs: [num_local_experts, moe_inter_dim, dim]
nemo_automodel.components.moe.experts.GroupedExpertsTE._normalize_moe_mesh(
moe_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh]
) -> typing.Optional[torch.distributed.device_mesh.DeviceMesh]
nemo_automodel.components.moe.experts.GroupedExpertsTE._set_stacked_bias(
linear: transformer_engine.pytorch.GroupedLinear,
stacked: torch.Tensor
)
nemo_automodel.components.moe.experts.GroupedExpertsTE._set_stacked_weight(
linear: transformer_engine.pytorch.GroupedLinear,
stacked: torch.Tensor,
transpose: bool = False
)
nemo_automodel.components.moe.experts.GroupedExpertsTE._to_ep_dtensor(
tensor: torch.Tensor
) -> torch.Tensor
nemo_automodel.components.moe.experts.GroupedExpertsTE.forward(
x: torch.Tensor,
token_mask: torch.Tensor,
weights: torch.Tensor,
indices: torch.Tensor
) -> torch.Tensor

Forward pass using TE’s GroupedLinear with native FP8 support.

Parameters:

x
torch.Tensor

[num_tokens, model_dim] input tensor

token_mask
torch.Tensor

[num_tokens] boolean mask for valid tokens

weights
torch.Tensor

[num_tokens, num_activated_experts] routing weights

indices
torch.Tensor

[num_tokens, num_activated_experts] expert indices

Returns: torch.Tensor

[num_tokens, model_dim] output tensor

nemo_automodel.components.moe.experts.GroupedExpertsTE.init_token_dispatcher(
ep_mesh: torch.distributed.device_mesh.DeviceMesh,
moe_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None
)

Initialize the token dispatcher for expert parallelism.

Called by the parallelizer after model initialization.

Parameters:

ep_mesh
DeviceMesh

Device mesh for expert parallelism.

nemo_automodel.components.moe.experts.GroupedExpertsTE.init_weights(
buffer_device: torch.device,
init_std: float = 0.02
) -> None

Initialize weights using reset_parameters().

nemo_automodel.components.moe.experts.GroupedExpertsTE.set_moe_mesh(
moe_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh]
) -> None
nemo_automodel.components.moe.experts.GroupedExpertsTE.state_dict(
args = (),
destination = None,
prefix = '',
keep_vars = False,
kwargs = {}
) -> typing.Dict[str, typing.Any]

Return state dict with stacked tensors in DeepEP format.

Converts TE GroupedLinear’s weight{i} parameters to stacked format:

  • gate_and_up_projs: [num_local_experts, dim, moe_inter_dim * 2]
  • down_projs: [num_local_experts, moe_inter_dim, dim]

When EP is enabled, returns DTensors sharded on dimension 0.
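
The conversion between per-expert weight{i} entries and a single stacked tensor can be pictured as follows. This is only a sketch with nested Python lists standing in for tensors: the helper names are hypothetical, and it deliberately ignores any transposition (the private helpers above take a transpose flag), biases, and DTensor sharding.

```python
def stack_expert_weights(per_expert: dict, num_local_experts: int) -> list:
    """Collect weight0..weight{n-1} into one list indexed by local expert ID."""
    return [per_expert[f"weight{i}"] for i in range(num_local_experts)]


def unstack_expert_weights(stacked: list) -> dict:
    """Inverse direction: split the leading expert dimension into weight{i} keys."""
    return {f"weight{i}": w for i, w in enumerate(stacked)}


# Two local experts, each with a toy 2 x 3 weight matrix.
per_expert = {
    "weight0": [[1, 2, 3], [4, 5, 6]],
    "weight1": [[7, 8, 9], [10, 11, 12]],
}
stacked = stack_expert_weights(per_expert, num_local_experts=2)
print(len(stacked), len(stacked[0]), len(stacked[0][0]))  # leading dim is the expert dim
round_trip = unstack_expert_weights(stacked)
```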

class nemo_automodel.components.moe.experts._AllGatherConcatVarlenFn()

Bases: Function

All-gather with variable local lengths and autograd-safe backward.

Backward uses all-reduce + local narrow instead of reduce-scatter to avoid monitoredBarrier deadlocks observed with mixed FSDP/EP backward collective ordering.
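
To make the backward strategy concrete, here is a single-process simulation in which each "rank" is just a list. It is a sketch under stated assumptions: the function names are hypothetical, padding to max_len is omitted, and real code would call collectives on a process group rather than summing Python lists.

```python
from itertools import accumulate


def all_gather_concat(local_tensors):
    """Forward: every rank ends up with the concatenation of all local tensors."""
    gathered = [v for t in local_tensors for v in t]
    return [list(gathered) for _ in local_tensors]


def backward_allreduce_narrow(grad_outputs, gathered_lens):
    """Backward: sum the full-length gradients (all-reduce), then slice out each rank's part."""
    reduced = [sum(vals) for vals in zip(*grad_outputs)]   # simulated all-reduce (sum)
    offsets = [0] + list(accumulate(gathered_lens))        # start offset of each rank's slice
    return [reduced[offsets[r]:offsets[r] + n] for r, n in enumerate(gathered_lens)]


local_tensors = [[1, 2], [3], [4, 5, 6]]        # variable local lengths: 2, 1, 3
gathered_lens = [len(t) for t in local_tensors]
gathered = all_gather_concat(local_tensors)

# Each rank holds a gradient for the full gathered tensor.
grad_outputs = [[rank * 10 + i for i in range(6)] for rank in range(3)]
local_grads = backward_allreduce_narrow(grad_outputs, gathered_lens)
print(local_grads)
```

Each rank receives the same values a reduce-scatter would deliver, at the cost of communicating the full-length gradient; the trade-off buys a collective whose ordering does not conflict with the other backward collectives mentioned above.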

nemo_automodel.components.moe.experts._AllGatherConcatVarlenFn.backward(
ctx,
grad_output: torch.Tensor
)
staticmethod
nemo_automodel.components.moe.experts._AllGatherConcatVarlenFn.forward(
ctx,
local_tensor: torch.Tensor,
group: torch.distributed.ProcessGroup,
gathered_lens: list[int],
max_len: int
)
staticmethod
nemo_automodel.components.moe.experts._apply_bias(
value,
bias,
tokens_per_expert,
permuted_probs = None
)

Apply per-expert bias to grouped GEMM output.

NOTE: torch._grouped_mm accepts a bias kwarg in its schema but raises “RuntimeError: Bias not supported yet” as of PyTorch 2.9.0. Additionally, down projection bias needs weighting by routing probs (bias * permuted_probs) which native bias support wouldn’t handle.

Parameters:

value

Output from grouped GEMM, shape [total_tokens, features].

bias

Per-expert bias, shape [num_experts, features].

tokens_per_expert

Token counts per expert.

permuted_probs
Defaults to None

If provided, bias is weighted by routing probs (for down projection).
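
The parameter descriptions above suggest the following behavior, sketched here with plain Python lists. It assumes the rows of value are already sorted by expert, so that the first tokens_per_expert[0] rows belong to expert 0, the next block to expert 1, and so on; the function name apply_bias_sketch is hypothetical.

```python
def apply_bias_sketch(value, bias, tokens_per_expert, permuted_probs=None):
    """Add bias[e] to every row routed to expert e, optionally scaled per row."""
    out = []
    row = 0
    for expert_id, count in enumerate(tokens_per_expert):
        for _ in range(count):
            # With permuted_probs, the bias is weighted by that row's routing prob.
            scale = 1.0 if permuted_probs is None else permuted_probs[row]
            out.append([v + scale * b for v, b in zip(value[row], bias[expert_id])])
            row += 1
    return out


value = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]   # [total_tokens, features]
bias = [[10.0, 20.0], [100.0, 200.0]]          # [num_experts, features]
tokens_per_expert = [2, 1]                     # rows 0-1 -> expert 0, row 2 -> expert 1

plain = apply_bias_sketch(value, bias, tokens_per_expert)
weighted = apply_bias_sketch(value, bias, tokens_per_expert, permuted_probs=[0.5, 1.0, 0.25])
print(plain)
print(weighted)
```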

nemo_automodel.components.moe.experts._init_weights(
module,
buffer_device: torch.device,
init_std: float = 0.02
)
nemo_automodel.components.moe.experts._permute_tokens_for_grouped_mm(
indices: torch.Tensor,
weights: torch.Tensor,
token_mask: torch.Tensor,
n_local_experts: int,
experts_start_idx: int
)

Permute tokens by expert assignment and compute offs for torch._grouped_mm.

Takes the raw router outputs and produces sorted token IDs, routing weights, tokens_per_expert counts, and cumulative offsets ready for grouped GEMM.

Returns:

Token indices sorted by expert assignment.
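
The permutation step can be sketched in plain Python as below. This is an illustration, not the module's implementation: the function name is hypothetical, and it assumes that masked tokens and assignments to non-local experts are dropped before sorting, and that the offsets are the running totals of tokens_per_expert.

```python
from itertools import accumulate


def permute_tokens_sketch(indices, weights, token_mask, n_local_experts, experts_start_idx):
    end_idx = experts_start_idx + n_local_experts
    pairs = []  # (local_expert_id, token_id, routing_weight)
    for t, (experts_t, weights_t) in enumerate(zip(indices, weights)):
        if not token_mask[t]:
            continue  # assumption: invalid tokens are never dispatched
        for e, w in zip(experts_t, weights_t):
            if experts_start_idx <= e < end_idx:
                pairs.append((e - experts_start_idx, t, w))
    pairs.sort(key=lambda p: p[0])  # stable sort, the list analogue of argsort by expert

    sorted_token_ids = [t for _, t, _ in pairs]
    sorted_weights = [w for _, _, w in pairs]
    tokens_per_expert = [0] * n_local_experts
    for e, _, _ in pairs:
        tokens_per_expert[e] += 1
    offs = list(accumulate(tokens_per_expert))  # cumulative end offset of each expert's block
    return sorted_token_ids, sorted_weights, tokens_per_expert, offs


indices = [[0, 2], [1, 0], [3, 1], [2, 0]]
weights = [[0.6, 0.4], [0.7, 0.3], [0.9, 0.1], [0.5, 0.5]]
token_mask = [True, True, True, False]

result = permute_tokens_sketch(indices, weights, token_mask, n_local_experts=2, experts_start_idx=0)
print(result)
```

After this step, all rows for one expert are contiguous, which is the layout a grouped GEMM needs in order to multiply each block by that expert's weights in a single call.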

nemo_automodel.components.moe.experts._torch_mm_experts_fwd(
hidden_states,
gate_and_up_projs,
down_projs,
tokens_per_expert,
permuted_probs,
activation_fn,
use_mxfp8 = False
)
nemo_automodel.components.moe.experts.get_expert_activation_for_deepep(
config: nemo_automodel.components.moe.config.MoEConfig
)

Return the DeepEP expert activation function selected by the MoE config.

nemo_automodel.components.moe.experts.is_gated_activation(
activation: str
) -> bool

Check if activation requires gating (gate_proj + up_proj).

Gated activations (SwiGLU, Quick-GEGLU) use both gate_proj and up_proj, requiring gate_and_up_projs tensor with shape [n_experts, dim, 2*inter_dim].

Non-gated activations (ReLU²) only use up_proj, requiring up_projs tensor with shape [n_experts, dim, inter_dim], which halves the memory of that projection.
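
The memory claim follows directly from the two shapes. A small sketch, with a hypothetical helper name and arbitrary example sizes:

```python
import math


def first_projection_shape(n_experts: int, dim: int, inter_dim: int, gated: bool):
    """Shape of the first expert projection for gated vs. non-gated activations."""
    return (n_experts, dim, 2 * inter_dim if gated else inter_dim)


# Example sizes chosen only for illustration.
gated_shape = first_projection_shape(8, 1024, 4096, gated=True)
plain_shape = first_projection_shape(8, 1024, 4096, gated=False)

gated_params = math.prod(gated_shape)
plain_params = math.prod(plain_shape)
print(gated_shape, plain_shape, plain_params / gated_params)
```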

nemo_automodel.components.moe.experts.quick_geglu_deepep(
x,
permuted_probs,
alpha: float = 1.702,
limit: float = 7.0,
linear_offset: float = 1.0
)

Apply DeepEP Quick-GEGLU activation and routing probabilities.

nemo_automodel.components.moe.experts.relu2_deepep(
x,
permuted_probs
)

ReLU² activation for DeepEP: relu(x)^2

For DeepEP with ReLU², x is the already-computed output of the up projection, so it has shape [..., inter_dim].

nemo_automodel.components.moe.experts.swiglu_clamped_deepep(
x,
permuted_probs,
limit: float
)

Clamped SwiGLU (DeepSeek V4 style) for DeepEP.

Gate is clamped at max=limit and up at (-limit, +limit) in FP32 before silu(gate) * up; the result is multiplied by the permuted routing probs and cast back. Matches the official V4 Expert.forward:

    gate = self.w1(x).float()
    up = self.w3(x).float()
    if self.swiglu_limit > 0:
        up = torch.clamp(up, min=-swiglu_limit, max=swiglu_limit)
        gate = torch.clamp(gate, max=swiglu_limit)
    y = F.silu(gate) * up

x has shape [..., 2 * inter_dim] with gate in the first half and up in the second half (same layout as weighted_bias_swiglu_impl).
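
For a single row, the description above can be sketched with scalar math. The function name is hypothetical, Python floats stand in for the FP32 intermediate, and the final cast back to the input dtype is omitted.

```python
import math


def sigmoid(v: float) -> float:
    # Numerically stable form: never exponentiates a large positive number.
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)


def swiglu_clamped_row(x_row, prob, limit):
    """x_row is [gate half | up half]; returns silu(gate) * up * prob per feature."""
    half = len(x_row) // 2
    gate, up = x_row[:half], x_row[half:]
    out = []
    for g, u in zip(gate, up):
        if limit > 0:
            u = max(-limit, min(limit, u))  # up is clamped on both sides
            g = min(g, limit)               # gate is clamped from above only
        out.append(g * sigmoid(g) * u * prob)  # silu(g) = g * sigmoid(g)
    return out


row = [100.0, 0.0, 100.0, -100.0]   # gate = [100, 0], up = [100, -100]
y = swiglu_clamped_row(row, prob=0.5, limit=7.0)
print(y)
```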

nemo_automodel.components.moe.experts.swiglu_oai_deepep(
x,
permuted_probs,
alpha: float = 1.702,
limit: float = 7.0
)

SwiGLU-OAI (GPT-OSS / MiniMax-M3) activation for grouped experts.

Computes gate * sigmoid(alpha * gate) * (up + 1) in fp32 with gate clamped max=limit and up clamped +/-limit (when limit > 0).

Unlike quick_geglu_deepep (which expects an interleaved gate/up layout, x[..., ::2] / x[..., 1::2]), this reads the concatenated [gate | up] layout produced by MoESplitExpertsStateDictMixin (torch.cat([gate_t, up_t], dim=-1)), matching sglang’s swiglu_no_interleaved_with_alpha_and_limit.
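
The two layouts and the activation formula can be compared on a single row. The sketch below uses hypothetical helper names and plain Python floats; multiplying by the routing probability is an assumption based on the permuted_probs argument, not something the description states.

```python
import math


def sigmoid(v: float) -> float:
    # Numerically stable form: never exponentiates a large positive number.
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)


def split_concatenated(x_row):
    """[gate | up]: first half is gate, second half is up."""
    half = len(x_row) // 2
    return x_row[:half], x_row[half:]


def split_interleaved(x_row):
    """gate/up alternate: even positions are gate, odd positions are up."""
    return x_row[::2], x_row[1::2]


def swiglu_oai_row(x_row, prob, alpha=1.702, limit=7.0):
    gate, up = split_concatenated(x_row)
    out = []
    for g, u in zip(gate, up):
        if limit > 0:
            g = min(g, limit)               # gate clamped at max=limit
            u = max(-limit, min(limit, u))  # up clamped to +/-limit
        # Assumption: the result is scaled by the routing probability.
        out.append(g * sigmoid(alpha * g) * (u + 1.0) * prob)
    return out


print(split_concatenated([1, 2, 3, 4]), split_interleaved([1, 2, 3, 4]))
y = swiglu_oai_row([0.0, 2.0, 5.0, 100.0], prob=1.0)
print(y)
```

Feeding a concatenated row to an interleaved reader (or the reverse) pairs the wrong gate and up values without raising any error, so the layout has to match the checkpoint conversion path.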