nemo_automodel.components.moe.experts#

Module Contents#

Classes#

GroupedExperts

Sparse MoE implementation using all-gather/reduce-scatter primitives.

GroupedExpertsDeepEP

Sparse MoE implementation using grouped GEMM with DeepEP token dispatch.

GroupedExpertsTE

MoE experts using TE’s GroupedLinear module directly.

Functions#

is_gated_activation

Check if activation requires gating (gate_proj + up_proj).

_permute_tokens_for_grouped_mm

Permute tokens by expert assignment and compute offs for torch._grouped_mm.

_apply_bias

Apply per-expert bias to grouped GEMM output.

quick_geglu_deepep

relu2_deepep

ReLU² activation for DeepEP: relu(x)^2

get_expert_activation_for_deepep

_torch_mm_experts_fwd

_init_weights

API#

nemo_automodel.components.moe.experts.is_gated_activation(activation: str) bool#

Check if activation requires gating (gate_proj + up_proj).

Gated activations (SwiGLU, Quick-GEGLU) use both gate_proj and up_proj, requiring a gate_and_up_projs tensor with shape [n_experts, dim, 2*inter_dim].

Non-gated activations (ReLU²) use only up_proj, requiring an up_projs tensor with shape [n_experts, dim, inter_dim], a 50% memory saving for this parameter.
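
A minimal sketch of the shape difference, assuming "swiglu" and "relu2" are among the accepted activation names (dimensions are illustrative only):

    import torch

    from nemo_automodel.components.moe.experts import is_gated_activation

    # Illustrative dimensions; "swiglu" and "relu2" are assumed activation names,
    # not necessarily the exact strings the library accepts.
    n_experts, dim, inter_dim = 8, 1024, 4096

    gated_out = 2 * inter_dim if is_gated_activation("swiglu") else inter_dim
    gate_and_up_projs = torch.empty(n_experts, dim, gated_out)   # expected [8, 1024, 8192]

    plain_out = 2 * inter_dim if is_gated_activation("relu2") else inter_dim
    up_projs = torch.empty(n_experts, dim, plain_out)            # expected [8, 1024, 4096]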

nemo_automodel.components.moe.experts._permute_tokens_for_grouped_mm(
indices: torch.Tensor,
weights: torch.Tensor,
token_mask: torch.Tensor,
n_local_experts: int,
experts_start_idx: int,
)#

Permute tokens by expert assignment and compute offs for torch._grouped_mm.

Takes the raw router outputs and produces sorted token IDs, routing weights, tokens_per_expert counts, and cumulative offsets ready for grouped GEMM.

Returns:
  • sorted_token_ids – Token indices sorted by expert assignment.

  • sorted_weights – Routing weights in the same sorted order.

  • tokens_per_expert – Count of tokens per local expert.

  • offs – Cumulative token counts (int32) for torch._grouped_mm.
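
A standalone sketch of the argsort-based permutation this helper performs (illustrative, not the library implementation; variable names mirror the return values above):

    import torch

    num_tokens, top_k, n_local_experts = 6, 2, 4
    indices = torch.randint(0, n_local_experts, (num_tokens, top_k))   # expert ids per token
    weights = torch.rand(num_tokens, top_k)                            # routing weights per token

    flat_experts = indices.reshape(-1)
    order = torch.argsort(flat_experts, stable=True)    # group (token, slot) pairs by expert
    sorted_token_ids = order // top_k                   # original token index for each slot
    sorted_weights = weights.reshape(-1)[order]
    tokens_per_expert = torch.bincount(flat_experts, minlength=n_local_experts)
    offs = torch.cumsum(tokens_per_expert, dim=0).to(torch.int32)   # cumulative ends for torch._grouped_mm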

nemo_automodel.components.moe.experts._apply_bias(value, bias, tokens_per_expert, permuted_probs=None)#

Apply per-expert bias to grouped GEMM output.

NOTE: torch._grouped_mm accepts a bias kwarg in its schema but raises “RuntimeError: Bias not supported yet” as of PyTorch 2.9.0. Additionally, down projection bias needs weighting by routing probs (bias * permuted_probs) which native bias support wouldn’t handle.

Parameters:
  • value – Output from grouped GEMM, shape [total_tokens, features].

  • bias – Per-expert bias, shape [num_experts, features].

  • tokens_per_expert – Token counts per expert.

  • permuted_probs – If provided, bias is weighted by routing probs (for down projection).
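
A minimal sketch of the bias expansion described above, using repeat_interleave to broadcast each expert's bias row over its tokens (illustrative shapes; not the library code):

    import torch

    num_experts, features = 4, 8
    tokens_per_expert = torch.tensor([3, 0, 2, 1])
    total_tokens = int(tokens_per_expert.sum())

    value = torch.randn(total_tokens, features)       # grouped GEMM output, expert-sorted rows
    bias = torch.randn(num_experts, features)         # per-expert bias
    permuted_probs = torch.rand(total_tokens)         # routing probs in the same permuted order

    expanded = bias.repeat_interleave(tokens_per_expert, dim=0)        # [total_tokens, features]
    up_out = value + expanded                                          # e.g. gate/up projection
    down_out = value + expanded * permuted_probs.unsqueeze(-1)         # down projection: prob-weighted bias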

class nemo_automodel.components.moe.experts.GroupedExperts(
config: nemo_automodel.components.moe.config.MoEConfig,
backend: Optional[nemo_automodel.components.models.common.utils.BackendConfig] = None,
)#

Bases: torch.nn.Module

Sparse MoE implementation using all-gather/reduce-scatter primitives.

Supports two compute backends:

  • Per-expert loop with gather/scatter (default)

  • torch._grouped_mm with argsort-based permutation (backend.experts=”torch_mm”)

.. attribute:: n_routed_experts

Total number of experts in the model.

Type:

int

.. attribute:: gate_and_up_projs

Linear layer for gate+up (gated) or just up (non-gated).

Type:

nn.Parameter

.. attribute:: down_projs

Linear layer for hidden-to-output transformation.

Type:

nn.Parameter

Initialization

Initializes the GroupedExperts module.

Parameters:
  • config – MoE configuration containing expert parameters.

  • backend – Backend configuration. When backend.experts == “torch_mm”, uses torch._grouped_mm instead of per-expert loop.

forward(
x: torch.Tensor,
token_mask: torch.Tensor,
weights: torch.Tensor,
indices: torch.Tensor,
) torch.Tensor#

Forward pass for the grouped experts.

Parameters:
  • x (torch.Tensor) – Input tensor. Shape is [num_tokens, model_dim].

  • token_mask (torch.Tensor) – Boolean mask indicating valid tokens. Shape is [num_tokens].

  • weights (torch.Tensor) – Routing weights for the selected experts. Shape is [num_tokens, num_activated_experts].

  • indices (torch.Tensor) – Indices of the selected experts. Shape is [num_tokens, num_activated_experts].

Returns:

Output tensor after expert computation. Shape is [num_tokens, model_dim].

Return type:

torch.Tensor

_forward_loop(
x,
weights,
indices,
token_mask,
gate_and_up_projs,
down_projs,
n_local_experts,
experts_start_idx,
experts_end_idx,
)#

Per-expert loop forward path using gather/scatter.
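
A conceptual sketch of the gather/scatter loop, assuming a SwiGLU-style gated activation (illustrative only; the real method also handles token_mask and the expert-parallel index range):

    import torch
    import torch.nn.functional as F

    def per_expert_loop(x, weights, indices, gate_and_up_projs, down_projs):
        # For each expert, gather its tokens, run the FFN, and scatter the
        # weighted result back into the output buffer.
        out = torch.zeros_like(x)
        for e in range(gate_and_up_projs.shape[0]):
            token_ids, slot = torch.where(indices == e)
            if token_ids.numel() == 0:
                continue
            h = x[token_ids] @ gate_and_up_projs[e]        # [n_e, 2 * inter_dim]
            gate, up = h.chunk(2, dim=-1)
            h = (F.silu(gate) * up) @ down_projs[e]        # [n_e, dim]
            out.index_add_(0, token_ids, h * weights[token_ids, slot].unsqueeze(-1))
        return out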

_forward_grouped_mm(
x,
token_mask,
weights,
indices,
gate_and_up_projs,
down_projs,
n_local_experts,
experts_start_idx,
)#

Grouped GEMM forward path using torch._grouped_mm.

init_weights(
buffer_device: torch.device,
init_std: float = 0.02,
) None#
nemo_automodel.components.moe.experts.quick_geglu_deepep(
x,
permuted_probs,
alpha: float = 1.702,
limit: float = 7.0,
linear_offset: float = 1.0,
)#
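
No docstring is attached to quick_geglu_deepep; the following is a rough sketch of a quick-GEGLU-style activation using the same hyperparameters. The split convention, clamping details, and how permuted_probs is applied are assumptions, not the library's exact formula:

    import torch

    def quick_geglu_sketch(x, permuted_probs, alpha=1.702, limit=7.0, linear_offset=1.0):
        # Split the fused gate/up projection output into its two halves (assumed layout).
        gate, up = x.chunk(2, dim=-1)
        gate = gate.clamp(max=limit)                 # bound the gate pre-activation
        up = up.clamp(min=-limit, max=limit)
        glu = gate * torch.sigmoid(alpha * gate)     # "quick GELU" gate
        out = glu * (up + linear_offset)
        return out * permuted_probs.unsqueeze(-1)    # fold in routing probs (assumed here)
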
nemo_automodel.components.moe.experts.relu2_deepep(x, permuted_probs)#

ReLU² activation for DeepEP: relu(x)^2

For DeepEP with ReLU², x is the already-computed output of the up projection, so it arrives with shape […, inter_dim] (the non-gated up_projs layout).
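
A minimal sketch of the computation; whether the routing probs are applied inside this function or later is an implementation detail, so treat the prob weighting below as an assumption:

    import torch
    import torch.nn.functional as F

    x = torch.randn(16, 4096)                # up-projection output, [..., inter_dim]
    permuted_probs = torch.rand(16)
    y = F.relu(x).square() * permuted_probs.unsqueeze(-1)   # relu(x)^2, prob-weighted (assumed)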

nemo_automodel.components.moe.experts.get_expert_activation_for_deepep(
config: nemo_automodel.components.moe.config.MoEConfig,
)#
class nemo_automodel.components.moe.experts.GroupedExpertsDeepEP(
config: nemo_automodel.components.moe.config.MoEConfig,
backend: Optional[nemo_automodel.components.models.common.utils.BackendConfig] = None,
)#

Bases: torch.nn.Module

Sparse MoE implementation using grouped GEMM with DeepEP token dispatch.

Supports two GEMM backends via BackendConfig.experts:

  • grouped_gemm.ops.gmm (experts=”gmm”, default)

  • torch._grouped_mm (experts=”torch_mm”, no external dependency)

Once the experts for a particular token have been identified, this module is invoked to compute and average the output of the activated experts.

.. attribute:: n_routed_experts

Total number of experts in the model.

Type:

int

.. attribute:: gate_and_up_projs

Linear layer for gate+up (gated) or just up (non-gated).

Type:

nn.Parameter

.. attribute:: down_projs

Linear layer for hidden-to-output transformation.

Type:

nn.Parameter

Initialization

Initializes the GroupedExpertsDeepEP module.

Parameters:
  • config – MoE configuration containing expert parameters.

  • backend – Backend configuration. When backend.experts == “torch_mm”, uses torch._grouped_mm; otherwise uses grouped_gemm.ops.gmm.

init_token_dispatcher(
ep_mesh: torch.distributed.device_mesh.DeviceMesh,
)#
forward(
x: torch.Tensor,
token_mask: torch.Tensor,
weights: torch.Tensor,
indices: torch.Tensor,
) torch.Tensor#

Forward pass for the grouped experts.

Parameters:
  • x (torch.Tensor) – Input tensor. Shape is [num_tokens, model_dim].

  • token_mask (torch.Tensor) – Boolean mask indicating valid tokens. Shape is [num_tokens].

  • weights (torch.Tensor) – Routing weights for the selected experts. Shape is [num_tokens, num_activated_experts].

  • indices (torch.Tensor) – Indices of the selected experts. Shape is [num_tokens, num_activated_experts].

Returns:

Output tensor after expert computation. Shape is [num_tokens, model_dim].

Return type:

torch.Tensor

init_weights(
buffer_device: torch.device,
init_std: float = 0.02,
) None#
nemo_automodel.components.moe.experts._torch_mm_experts_fwd(
hidden_states,
gate_and_up_projs,
down_projs,
tokens_per_expert,
permuted_probs,
activation_fn,
)#
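
A sketch of what a grouped-GEMM expert forward of this shape could look like (illustrative, not the library's implementation). It assumes the 2D-activation by 3D-weight form of torch._grouped_mm with int32 cumulative offsets, which typically requires bf16 CUDA tensors on supported hardware:

    import torch

    def grouped_mm_experts_sketch(hidden_states, gate_and_up_projs, down_projs,
                                  tokens_per_expert, permuted_probs, activation_fn):
        # Cumulative per-expert token counts delimit each expert's row block.
        offs = torch.cumsum(tokens_per_expert, dim=0).to(torch.int32)
        h = torch._grouped_mm(hidden_states, gate_and_up_projs, offs=offs)
        h = activation_fn(h, permuted_probs)     # gated or non-gated activation, prob-weighted
        return torch._grouped_mm(h, down_projs, offs=offs)
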
class nemo_automodel.components.moe.experts.GroupedExpertsTE(
config: nemo_automodel.components.moe.config.MoEConfig,
backend: Optional[nemo_automodel.components.models.common.utils.BackendConfig] = None,
)#

Bases: torch.nn.Module

MoE experts using TE’s GroupedLinear module directly.

Uses TE’s native GroupedLinear for computation, providing optimized grouped GEMM kernels from TE.

For expert parallelism, each rank creates GroupedLinear with num_local_experts = n_routed_experts / ep_size.
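
For example, with illustrative numbers:

    n_routed_experts = 64    # total experts in the model
    ep_size = 8              # expert-parallel world size
    num_local_experts = n_routed_experts // ep_size   # each rank owns 8 experts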

.. attribute:: n_routed_experts

Total number of experts in the model.

Type:

int

.. attribute:: gate_up_linear

Combined gate and up projection.

Type:

GroupedLinear

.. attribute:: down_linear

Down projection.

Type:

GroupedLinear

Initialization

Initialize the GroupedExpertsTE module.

Parameters:
  • config – MoE configuration containing expert parameters.

  • backend – Backend configuration (reserved for future use).

_get_stacked_weight(
linear: transformer_engine.pytorch.GroupedLinear,
transpose: bool = False,
) torch.Tensor#
_get_stacked_bias(
linear: transformer_engine.pytorch.GroupedLinear,
) Optional[torch.Tensor]#
_set_stacked_weight(
linear: transformer_engine.pytorch.GroupedLinear,
stacked: torch.Tensor,
transpose: bool = False,
)#
_set_stacked_bias(
linear: transformer_engine.pytorch.GroupedLinear,
stacked: torch.Tensor,
)#
_to_ep_dtensor(tensor: torch.Tensor) torch.Tensor#
_normalize_moe_mesh(
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
) Optional[torch.distributed.device_mesh.DeviceMesh]#
set_moe_mesh(
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
) None#
property gate_and_up_projs: torch.Tensor#
property down_projs: torch.Tensor#
property gate_up_proj_bias: Optional[torch.Tensor]#
property down_proj_bias: Optional[torch.Tensor]#
state_dict(
*args,
destination=None,
prefix='',
keep_vars=False,
**kwargs,
) Dict[str, Any]#

Return state dict with stacked tensors in DeepEP format.

Converts TE GroupedLinear’s weight{i} parameters to stacked format:

  • gate_and_up_projs: [num_local_experts, dim, moe_inter_dim * 2]

  • down_projs: [num_local_experts, moe_inter_dim, dim]

When EP is enabled, returns DTensors sharded on dimension 0.
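
A sketch of the stacking step, assuming each TE GroupedLinear weight{i} is a standard [out_features, in_features] linear weight (illustrative, not the library code):

    import torch

    num_local_experts, dim, moe_inter_dim = 4, 1024, 2048
    # Per-expert fused gate+up weights in [out_features, in_features] layout (assumed).
    per_expert = [torch.randn(2 * moe_inter_dim, dim) for _ in range(num_local_experts)]

    # Transpose each to [in, out] and stack on a new expert dimension:
    # gate_and_up_projs has shape [num_local_experts, dim, moe_inter_dim * 2].
    gate_and_up_projs = torch.stack([w.t() for w in per_expert], dim=0)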

_load_from_state_dict(
state_dict: Dict[str, Any],
prefix: str,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)#

Load state dict with stacked tensors in DeepEP format.

Converts stacked format to TE GroupedLinear’s weight{i} parameters:

  • gate_and_up_projs: [num_local_experts, dim, moe_inter_dim * 2]

  • down_projs: [num_local_experts, moe_inter_dim, dim]

init_token_dispatcher(
ep_mesh: torch.distributed.device_mesh.DeviceMesh,
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
)#

Initialize the token dispatcher for expert parallelism.

Called by the parallelizer after model initialization.

Parameters:
  • ep_mesh – Device mesh for expert parallelism.

  • moe_mesh – Optional device mesh used when sharding the stacked expert weights; defaults to None.

forward(
x: torch.Tensor,
token_mask: torch.Tensor,
weights: torch.Tensor,
indices: torch.Tensor,
) torch.Tensor#

Forward pass using TE’s GroupedLinear with native FP8 support.

Parameters:
  • x – [num_tokens, model_dim] input tensor

  • token_mask – [num_tokens] boolean mask for valid tokens

  • weights – [num_tokens, num_activated_experts] routing weights

  • indices – [num_tokens, num_activated_experts] expert indices

Returns:

[num_tokens, model_dim] output tensor

init_weights(
buffer_device: torch.device,
init_std: float = 0.02,
) None#

Initialize weights using reset_parameters().

nemo_automodel.components.moe.experts._init_weights(
module,
buffer_device: torch.device,
init_std: float = 0.02,
)#