nemo_automodel.components.moe.experts
Module Contents
Classes
Functions
API
Bases: Module
Sparse MoE implementation using all-gather/reduce-scatter primitives.
Supports two compute backends:
- Per-expert loop with gather/scatter (default)
- torch._grouped_mm with argsort-based permutation (backend.experts="torch_mm")
Grouped GEMM forward path using torch._grouped_mm.
Per-expert loop forward path using gather/scatter.
Forward pass for the grouped experts.
Parameters:
Input tensor. Shape is [num_tokens, model_dim].
Boolean mask indicating valid tokens. Shape is [num_tokens].
Routing weights for the selected experts. Shape is [num_tokens, num_activated_experts].
Indices of the selected experts. Shape is [num_tokens, num_activated_experts].
Returns: torch.Tensor
torch.Tensor: Output tensor after expert computation. Shape is [num_tokens, model_dim]
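The per-expert loop path can be pictured with the following minimal sketch. It is not the library's implementation: `per_expert_loop` and the plain `Linear` experts are hypothetical stand-ins, and only the argument shapes follow the parameter descriptions above.

```python
import torch

def per_expert_loop(x, token_mask, weights, indices, expert_fns):
    """Gather the tokens routed to each expert, compute, and scatter-add back."""
    out = torch.zeros_like(x)
    for e, fn in enumerate(expert_fns):
        # Token rows and top-k slots that selected expert e, ignoring masked tokens.
        tok, slot = torch.where((indices == e) & token_mask[:, None])
        if tok.numel() == 0:
            continue
        y = fn(x[tok])                                  # gather + expert compute
        w = weights[tok, slot].unsqueeze(-1)            # routing weight per routed token
        out.index_add_(0, tok, y * w)                   # weighted scatter
    return out

torch.manual_seed(0)
num_tokens, model_dim, n_experts, top_k = 6, 4, 3, 2
x = torch.randn(num_tokens, model_dim)
token_mask = torch.tensor([True, True, True, True, True, False])
indices = torch.stack([torch.randperm(n_experts)[:top_k] for _ in range(num_tokens)])
weights = torch.softmax(torch.randn(num_tokens, top_k), dim=-1)
# Hypothetical experts: plain linear layers instead of gated MLPs.
experts = [torch.nn.Linear(model_dim, model_dim, bias=False) for _ in range(n_experts)]

with torch.no_grad():
    out = per_expert_loop(x, token_mask, weights, indices, experts)
```

The output keeps the `[num_tokens, model_dim]` shape, and rows whose mask entry is `False` stay zero.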
Bases: Module
Sparse MoE implementation using grouped GEMM with DeepEP token dispatch.
Supports two GEMM backends via BackendConfig.experts:
- grouped_gemm.ops.gmm (experts="gmm", default)
- torch._grouped_mm (experts="torch_mm", no external dependency)
Once the experts for a particular token have been identified, this module is invoked to compute and average the output of the activated experts.
Initialize DeepEP communication buffers before activation checkpointing.
Forward pass for the grouped experts.
Parameters:
Input tensor. Shape is [num_tokens, model_dim].
Boolean mask indicating valid tokens. Shape is [num_tokens].
Routing weights for the selected experts. Shape is [num_tokens, num_activated_experts].
Indices of the selected experts. Shape is [num_tokens, num_activated_experts].
Returns: torch.Tensor
torch.Tensor: Output tensor after expert computation. Shape is [num_tokens, model_dim]
Bases: Module
MoE experts using Transformer Engine (TE) GroupedLinear modules directly.
Uses TE’s native GroupedLinear for computation, providing:
- Optimized grouped GEMM kernels from TE
For expert parallelism, each rank creates GroupedLinear with num_local_experts = n_routed_experts / ep_size.
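As a small illustration of the split, the sketch below computes which global expert ids one rank would own. The helper `local_expert_range` is hypothetical, and the contiguous assignment of experts to ranks is an assumption; the source only states the per-rank count.

```python
def local_expert_range(n_routed_experts: int, ep_size: int, ep_rank: int) -> range:
    """Global expert ids owned by one EP rank, assuming an even, contiguous split."""
    if n_routed_experts % ep_size != 0:
        raise ValueError("n_routed_experts must be divisible by ep_size")
    num_local_experts = n_routed_experts // ep_size
    start = ep_rank * num_local_experts
    return range(start, start + num_local_experts)

# 8 routed experts over 4 EP ranks: 2 local experts per rank.
owned = [list(local_expert_range(8, 4, rank)) for rank in range(4)]
```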
Load state dict with stacked tensors in DeepEP format.
Converts stacked format to TE GroupedLinear’s weight{i} parameters:
- gate_and_up_projs: [num_local_experts, dim, moe_inter_dim * 2]
- down_projs: [num_local_experts, moe_inter_dim, dim]
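The conversion between the stacked format and per-expert weight{i} tensors might look like the sketch below. Both helper functions are hypothetical, and the per-expert layout `[out_features, in_features]` (the `torch.nn.Linear` convention) is an assumption about how GroupedLinear stores its weights, not something the source states.

```python
import torch
import torch.nn.functional as F

def unstack_expert_weights(stacked: torch.Tensor, prefix: str = "weight") -> dict:
    """Split [num_local_experts, in, out] into per-expert [out, in] tensors (assumed layout)."""
    return {f"{prefix}{i}": w.t().contiguous() for i, w in enumerate(stacked.unbind(0))}

def stack_expert_weights(per_expert: dict, prefix: str = "weight") -> torch.Tensor:
    """Inverse direction: rebuild the stacked [num_local_experts, in, out] tensor."""
    return torch.stack([per_expert[f"{prefix}{i}"].t() for i in range(len(per_expert))], dim=0)

torch.manual_seed(0)
num_local_experts, dim, moe_inter_dim = 2, 4, 3
gate_and_up_projs = torch.randn(num_local_experts, dim, moe_inter_dim * 2)

per_expert = unstack_expert_weights(gate_and_up_projs)
restored = stack_expert_weights(per_expert)

# Both layouts describe the same linear map for expert 1.
x = torch.randn(5, dim)
y_stacked = x @ gate_and_up_projs[1]
y_linear = F.linear(x, per_expert["weight1"])
```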
Forward pass using TE’s GroupedLinear with native FP8 support.
Parameters:
[num_tokens, model_dim] input tensor
[num_tokens] boolean mask for valid tokens
[num_tokens, num_activated_experts] routing weights
[num_tokens, num_activated_experts] expert indices
Returns: torch.Tensor
[num_tokens, model_dim] output tensor
Initialize the token dispatcher for expert parallelism.
Called by the parallelizer after model initialization.
Parameters:
Device mesh for expert parallelism.
Initialize weights using reset_parameters()
Return state dict with stacked tensors in DeepEP format.
Converts TE GroupedLinear’s weight{i} parameters to stacked format:
- gate_and_up_projs: [num_local_experts, dim, moe_inter_dim * 2]
- down_projs: [num_local_experts, moe_inter_dim, dim]
When EP is enabled, returns DTensors sharded on dimension 0.
Bases: Function
All-gather with variable local lengths and autograd-safe backward.
Backward uses all-reduce + local narrow instead of reduce-scatter to avoid monitoredBarrier deadlocks observed with mixed FSDP/EP backward collective ordering.
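The equivalence the backward relies on can be checked in a single process. The sketch below simulates the ranks with a Python list instead of calling torch.distributed, so all names in it are illustrative: summing the per-rank gradients stands in for the all-reduce, and `narrow` picks each rank's slice using its variable local length.

```python
import torch

torch.manual_seed(0)
local_lens = [2, 3, 1]                       # variable per-rank lengths
offsets = [sum(local_lens[:r]) for r in range(len(local_lens))]
total = sum(local_lens)

# grad_per_rank[r] simulates the gradient rank r holds for the full gathered tensor.
grad_per_rank = [torch.randn(total, 4) for _ in local_lens]

# Step 1: all-reduce (sum), after which every rank sees the same summed gradient.
summed = torch.stack(grad_per_rank).sum(dim=0)

# Step 2: each rank narrows the summed gradient to its own rows.
local_grads = [summed.narrow(0, off, n) for off, n in zip(offsets, local_lens)]
```

Each entry of `local_grads` matches what a reduce-scatter with uneven splits would deliver to that rank, at the cost of communicating the full tensor.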
Apply per-expert bias to grouped GEMM output.
NOTE: torch._grouped_mm accepts a bias kwarg in its schema but raises
“RuntimeError: Bias not supported yet” as of PyTorch 2.9.0.
Additionally, down projection bias needs weighting by routing probs
(bias * permuted_probs) which native bias support wouldn’t handle.
Parameters:
Output from grouped GEMM, shape [total_tokens, features].
Per-expert bias, shape [num_experts, features].
Token counts per expert.
If provided, bias is weighted by routing probs (for down projection).
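A minimal sketch of this kind of bias application follows. The function name `apply_bias_sketch` is hypothetical, and it assumes that rows of the output are already grouped by expert and that `permuted_probs` has shape `[total_tokens]`.

```python
import torch

def apply_bias_sketch(out, bias, tokens_per_expert, permuted_probs=None):
    # Expand [num_experts, features] to one bias row per token, in expert order.
    per_token_bias = torch.repeat_interleave(bias, tokens_per_expert, dim=0)
    if permuted_probs is not None:
        # Down-projection case: weight the bias by each token's routing prob.
        per_token_bias = per_token_bias * permuted_probs.unsqueeze(-1)
    return out + per_token_bias

out = torch.zeros(3, 2)
bias = torch.tensor([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
tokens_per_expert = torch.tensor([2, 0, 1])   # expert 1 received no tokens
probs = torch.tensor([0.5, 1.0, 2.0])

plain = apply_bias_sketch(out, bias, tokens_per_expert)
weighted = apply_bias_sketch(out, bias, tokens_per_expert, probs)
```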
Permute tokens by expert assignment and compute offs for torch._grouped_mm.
Takes the raw router outputs and produces sorted token IDs, routing weights, tokens_per_expert counts, and cumulative offsets ready for grouped GEMM.
Returns:
Token indices sorted by expert assignment.
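The argsort-based permutation can be sketched as below. This is an illustration under assumptions rather than the module's code: `permute_by_expert` is a hypothetical name, the token mask is ignored for brevity, and the offsets are taken to be int32 cumulative end positions per expert.

```python
import torch

def permute_by_expert(indices, weights, n_experts):
    top_k = indices.shape[1]
    flat_experts = indices.reshape(-1)                    # [num_tokens * top_k]
    # Stable sort keeps the original token order inside each expert group.
    order = torch.argsort(flat_experts, stable=True)
    sorted_token_ids = torch.div(order, top_k, rounding_mode="floor")
    sorted_weights = weights.reshape(-1)[order]
    tokens_per_expert = torch.bincount(flat_experts, minlength=n_experts)
    offs = torch.cumsum(tokens_per_expert, dim=0).to(torch.int32)
    return sorted_token_ids, sorted_weights, tokens_per_expert, offs

indices = torch.tensor([[2, 0], [1, 2], [0, 1]])
weights = torch.tensor([[0.6, 0.4], [0.7, 0.3], [0.5, 0.5]])
token_ids, sorted_w, counts, offs = permute_by_expert(indices, weights, n_experts=3)
```

With these inputs, rows `0:offs[0]` of the permuted batch belong to expert 0, rows `offs[0]:offs[1]` to expert 1, and so on.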
Return the DeepEP expert activation function selected by the MoE config.
Check if activation requires gating (gate_proj + up_proj).
Gated activations (SwiGLU, Quick-GEGLU) use both gate_proj and up_proj, requiring gate_and_up_projs tensor with shape [n_experts, dim, 2*inter_dim].
Non-gated activations (ReLU²) only use up_proj, requiring up_projs tensor with shape [n_experts, dim, inter_dim], which is half the size of the gated tensor (50% memory savings on this projection).
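The saving follows directly from the two shapes. The helper below is hypothetical and only counts elements of the first projection tensor; the example sizes are arbitrary.

```python
def first_proj_params(n_experts: int, dim: int, inter_dim: int, gated: bool) -> int:
    """Element count of gate_and_up_projs (gated) or up_projs (non-gated)."""
    width = 2 * inter_dim if gated else inter_dim
    return n_experts * dim * width

gated_count = first_proj_params(n_experts=8, dim=1024, inter_dim=4096, gated=True)
non_gated_count = first_proj_params(n_experts=8, dim=1024, inter_dim=4096, gated=False)
saving = 1 - non_gated_count / gated_count
```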
Apply DeepEP Quick-GEGLU activation and routing probabilities.
ReLU² activation for DeepEP: relu(x)^2
For DeepEP with ReLU², x is the already-computed output of the up projection, so it has shape [..., inter_dim] rather than [..., 2 * inter_dim].
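The activation itself is a one-liner. The sketch below shows only relu(x)^2 on an up-projection output; the name `relu2_sketch` is hypothetical and any routing-probability handling is left out.

```python
import torch
import torch.nn.functional as F

def relu2_sketch(x: torch.Tensor) -> torch.Tensor:
    # x: [..., inter_dim], the output of the up projection.
    return F.relu(x) ** 2

y = relu2_sketch(torch.tensor([[-1.0, 0.0, 2.0]]))
```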
Clamped SwiGLU (DeepSeek V4 style) for DeepEP.
Gate is clamped at max=limit and up at (-limit, +limit) in FP32
before silu(gate) * up; the result is multiplied by the permuted
routing probs and cast back. Matches the official V4 Expert.forward::
    gate = self.w1(x).float()
    up = self.w3(x).float()
    if self.swiglu_limit > 0:
        up = torch.clamp(up, min=-swiglu_limit, max=swiglu_limit)
        gate = torch.clamp(gate, max=swiglu_limit)
    y = F.silu(gate) * up
x has shape [..., 2 * inter_dim] with gate in the first half
and up in the second half (same layout as weighted_bias_swiglu_impl).
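Put together, the description above could be sketched as follows. `clamped_swiglu_sketch` is a hypothetical name, and `permuted_probs` is assumed to have shape `[num_tokens, 1]` so that it broadcasts over the feature dimension.

```python
import torch
import torch.nn.functional as F

def clamped_swiglu_sketch(x, permuted_probs, limit):
    dtype = x.dtype
    # Concatenated layout: gate in the first half, up in the second half.
    gate, up = x.float().chunk(2, dim=-1)
    if limit > 0:
        gate = torch.clamp(gate, max=limit)            # one-sided clamp
        up = torch.clamp(up, min=-limit, max=limit)    # two-sided clamp
    y = F.silu(gate) * up * permuted_probs.float()
    return y.to(dtype)                                 # cast back to input dtype

x = torch.tensor([[10.0, -10.0]], dtype=torch.bfloat16)   # inter_dim = 1
probs = torch.ones(1, 1)
y = clamped_swiglu_sketch(x, probs, limit=2.0)
```

Here both halves exceed the limit, so the result equals silu(2) * (-2) before the cast.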
SwiGLU-OAI (GPT-OSS / MiniMax-M3) activation for grouped experts.
Computes gate * sigmoid(alpha * gate) * (up + 1) in fp32 with gate
clamped max=limit and up clamped +/-limit (when limit > 0).
Unlike quick_geglu_deepep (which expects an interleaved gate/up
layout, x[..., ::2] / x[..., 1::2]), this reads the concatenated
[gate | up] layout produced by MoESplitExpertsStateDictMixin
(torch.cat([gate_t, up_t], dim=-1)), matching sglang’s
swiglu_no_interleaved_with_alpha_and_limit.
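The formula and the layout difference can be illustrated with the sketch below. `swiglu_oai_sketch` is a hypothetical name, and the `alpha` and `limit` values in the example are arbitrary choices for illustration, not defaults taken from the source.

```python
import torch

def swiglu_oai_sketch(x, alpha, limit):
    dtype = x.dtype
    # Concatenated [gate | up] layout: split the last dimension in half.
    gate, up = x.float().chunk(2, dim=-1)
    if limit > 0:
        gate = gate.clamp(max=limit)
        up = up.clamp(min=-limit, max=limit)
    return (gate * torch.sigmoid(alpha * gate) * (up + 1)).to(dtype)

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])

# Concatenated layout reads gate=[1, 2], up=[3, 4].
cat_gate, cat_up = x.chunk(2, dim=-1)
# Interleaved layout would read gate=[1, 3], up=[2, 4] from the same tensor.
int_gate, int_up = x[..., ::2], x[..., 1::2]

y = swiglu_oai_sketch(x, alpha=1.0, limit=7.0)
```

Feeding a tensor in one layout to a function expecting the other silently pairs the wrong gate and up values, which is why the layout is called out explicitly.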