nemo_automodel.components.moe.experts#
Module Contents#
Classes#
- GroupedExperts – Sparse MoE implementation using all-gather/reduce-scatter primitives.
- GroupedExpertsDeepEP – Sparse MoE implementation using grouped GEMM with DeepEP token dispatch.
- GroupedExpertsTE – MoE experts using TE's GroupedLinear module directly.
Functions#
- is_gated_activation – Check if activation requires gating (gate_proj + up_proj).
- _permute_tokens_for_grouped_mm – Permute tokens by expert assignment and compute offs for torch._grouped_mm.
- _apply_bias – Apply per-expert bias to grouped GEMM output.
- relu2_deepep – ReLU² activation for DeepEP: relu(x)^2.
API#
- nemo_automodel.components.moe.experts.is_gated_activation(activation: str) -> bool#
Check if activation requires gating (gate_proj + up_proj).
Gated activations (SwiGLU, Quick-GEGLU) use both gate_proj and up_proj, requiring a gate_and_up_projs tensor of shape [n_experts, dim, 2*inter_dim].
Non-gated activations (ReLU²) use only up_proj, requiring an up_projs tensor of shape [n_experts, dim, inter_dim] for a 50% memory saving.
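For example, a caller allocating expert weights might branch on this check. The sketch below is illustrative only; the activation strings ("swiglu", "relu2") and the dimensions are assumptions, not taken from the library's configuration.

```python
# Illustrative sketch: sizing the fused expert projection based on gating.
# The activation names and dimensions here are assumptions for the example.
import torch
from nemo_automodel.components.moe.experts import is_gated_activation

n_experts, dim, inter_dim = 8, 1024, 4096

if is_gated_activation("swiglu"):
    # Gated: gate_proj and up_proj fused into one tensor.
    gate_and_up_projs = torch.empty(n_experts, dim, 2 * inter_dim)
else:
    # Non-gated (e.g. "relu2"): up_proj only, half the memory.
    gate_and_up_projs = torch.empty(n_experts, dim, inter_dim)
```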
- nemo_automodel.components.moe.experts._permute_tokens_for_grouped_mm(
- indices: torch.Tensor,
- weights: torch.Tensor,
- token_mask: torch.Tensor,
- n_local_experts: int,
- experts_start_idx: int,
Permute tokens by expert assignment and compute offs for torch._grouped_mm.
Takes the raw router outputs and produces sorted token IDs, routing weights, tokens_per_expert counts, and cumulative offsets ready for grouped GEMM.
- Returns:
sorted_token_ids – Token indices sorted by expert assignment.
sorted_weights – Routing weights in the same sorted order.
tokens_per_expert – Count of tokens per local expert.
offs – Cumulative token counts (int32) for torch._grouped_mm.
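The core of this permutation can be sketched with argsort and bincount. The snippet below is a simplified illustration, not the function's actual implementation; it ignores token_mask and the experts_start_idx offset used under expert parallelism.

```python
# Simplified argsort-based permutation producing the four outputs described
# above; token_mask and experts_start_idx handling are omitted for brevity.
import torch

num_tokens, top_k, n_local_experts = 16, 2, 4
indices = torch.randint(0, n_local_experts, (num_tokens, top_k))
weights = torch.rand(num_tokens, top_k)

flat_experts = indices.reshape(-1)                     # [num_tokens * top_k]
sort_order = torch.argsort(flat_experts, stable=True)  # group slots by expert
sorted_token_ids = sort_order // top_k                 # source token of each slot
sorted_weights = weights.reshape(-1)[sort_order]
tokens_per_expert = torch.bincount(flat_experts, minlength=n_local_experts)
offs = torch.cumsum(tokens_per_expert, dim=0).to(torch.int32)  # for torch._grouped_mm
```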
- nemo_automodel.components.moe.experts._apply_bias(value, bias, tokens_per_expert, permuted_probs=None)#
Apply per-expert bias to grouped GEMM output.
NOTE: torch._grouped_mm accepts a bias kwarg in its schema but raises "RuntimeError: Bias not supported yet" as of PyTorch 2.9.0. Additionally, the down projection bias needs weighting by routing probs (bias * permuted_probs), which native bias support wouldn't handle.
- Parameters:
value – Output from grouped GEMM, shape [total_tokens, features].
bias – Per-expert bias, shape [num_experts, features].
tokens_per_expert – Token counts per expert.
permuted_probs – If provided, bias is weighted by routing probs (for down projection).
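A minimal sketch of this logic, assuming the per-expert bias rows are expanded to each expert's token count with repeat_interleave (the actual implementation may differ):

```python
# Sketch only: expand each expert's bias to its token count, optionally
# weight by routing probs (down projection case), then add to the output.
import torch

def apply_bias_sketch(value, bias, tokens_per_expert, permuted_probs=None):
    per_token_bias = torch.repeat_interleave(bias, tokens_per_expert, dim=0)
    if permuted_probs is not None:
        per_token_bias = per_token_bias * permuted_probs.unsqueeze(-1)
    return value + per_token_bias
```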
- class nemo_automodel.components.moe.experts.GroupedExperts(
- config: nemo_automodel.components.moe.config.MoEConfig,
- backend: Optional[nemo_automodel.components.models.common.utils.BackendConfig] = None,
- )#
Bases: torch.nn.Module
Sparse MoE implementation using all-gather/reduce-scatter primitives.
Supports two compute backends:
- Per-expert loop with gather/scatter (default)
- torch._grouped_mm with argsort-based permutation (backend.experts="torch_mm")
.. attribute:: n_routed_experts
Total number of experts in the model.
- Type:
int
.. attribute:: gate_and_up_projs
Linear layer for gate+up (gated) or just up (non-gated).
- Type:
nn.Parameter
.. attribute:: down_projs
Linear layer for hidden-to-output transformation.
- Type:
nn.Parameter
Initialization
Initializes the GroupedExperts module.
- Parameters:
config – MoE configuration containing expert parameters.
backend – Backend configuration. When backend.experts == “torch_mm”, uses torch._grouped_mm instead of per-expert loop.
- forward(
- x: torch.Tensor,
- token_mask: torch.Tensor,
- weights: torch.Tensor,
- indices: torch.Tensor,
- )#
Forward pass for the grouped experts.
- Parameters:
x (torch.Tensor) – Input tensor. Shape is [num_tokens, model_dim].
token_mask (torch.Tensor) – Boolean mask indicating valid tokens. Shape is [num_tokens].
weights (torch.Tensor) – Routing weights for the selected experts. Shape is [num_tokens, num_activated_experts].
indices (torch.Tensor) – Indices of the selected experts. Shape is [num_tokens, num_activated_experts].
- Returns:
Output tensor after expert computation. Shape is [num_tokens, model_dim].
- Return type:
torch.Tensor
- _forward_loop(
- x,
- weights,
- indices,
- token_mask,
- gate_and_up_projs,
- down_projs,
- n_local_experts,
- experts_start_idx,
- experts_end_idx,
- )#
Per-expert loop forward path using gather/scatter.
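Conceptually, the loop path does something like the following. This is a gated (SwiGLU) sketch with illustrative names, not the module's exact code; masking and expert-parallel index ranges are omitted.

```python
# Illustrative per-expert gather/scatter loop for a SwiGLU expert MLP.
import torch
import torch.nn.functional as F

def expert_loop_sketch(x, weights, indices, gate_and_up_projs, down_projs):
    out = torch.zeros_like(x)
    for e in range(gate_and_up_projs.shape[0]):
        token_idx, slot_idx = torch.where(indices == e)   # tokens routed to expert e
        if token_idx.numel() == 0:
            continue
        h = x[token_idx] @ gate_and_up_projs[e]            # [n_e, 2 * inter_dim]
        gate, up = h.chunk(2, dim=-1)
        h = (F.silu(gate) * up) @ down_projs[e]            # [n_e, dim]
        out.index_add_(0, token_idx, h * weights[token_idx, slot_idx].unsqueeze(-1))
    return out
```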
- _forward_grouped_mm(
- x,
- token_mask,
- weights,
- indices,
- gate_and_up_projs,
- down_projs,
- n_local_experts,
- experts_start_idx,
- )#
Grouped GEMM forward path using torch._grouped_mm.
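Under the hood this path runs two grouped GEMMs over the permuted tokens. The sketch below assumes a CUDA build of PyTorch that exposes the private torch._grouped_mm prototype (bf16 tensors) and uses made-up sizes; it is not the module's code.

```python
# Sketch of the grouped-GEMM path; requires a CUDA PyTorch build providing
# the private torch._grouped_mm prototype. Sizes are illustrative.
import torch
import torch.nn.functional as F

n_experts, dim, inter_dim = 4, 256, 512
tokens_per_expert = torch.tensor([3, 0, 5, 2])
offs = torch.cumsum(tokens_per_expert, dim=0).to(torch.int32).cuda()

x = torch.randn(int(tokens_per_expert.sum()), dim, device="cuda", dtype=torch.bfloat16)
gate_and_up = torch.randn(n_experts, dim, 2 * inter_dim, device="cuda", dtype=torch.bfloat16)
down = torch.randn(n_experts, inter_dim, dim, device="cuda", dtype=torch.bfloat16)

h = torch._grouped_mm(x, gate_and_up, offs=offs)              # [total_tokens, 2 * inter_dim]
gate, up = h.chunk(2, dim=-1)
out = torch._grouped_mm(F.silu(gate) * up, down, offs=offs)   # [total_tokens, dim]
```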
- init_weights(
- buffer_device: torch.device,
- init_std: float = 0.02,
- )#
- nemo_automodel.components.moe.experts.quick_geglu_deepep(
- x,
- permuted_probs,
- alpha: float = 1.702,
- limit: float = 7.0,
- linear_offset: float = 1.0,
- )#
- nemo_automodel.components.moe.experts.relu2_deepep(x, permuted_probs)#
ReLU² activation for DeepEP: relu(x)^2
For DeepEP with ReLU², x is the already-computed output of the up projection and has shape […, inter_dim] from the efficient up_proj.
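A minimal sketch of the activation itself, keeping permuted_probs only to match the DeepEP activation signature (how the probs are consumed is not shown here and is not confirmed by the source):

```python
# Minimal sketch: ReLU^2 on the up-projection output. permuted_probs is part
# of the DeepEP activation interface; its handling is intentionally omitted.
import torch

def relu2_sketch(x: torch.Tensor, permuted_probs: torch.Tensor) -> torch.Tensor:
    return torch.relu(x).square()
```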
- nemo_automodel.components.moe.experts.get_expert_activation_for_deepep()#
- class nemo_automodel.components.moe.experts.GroupedExpertsDeepEP(
- config: nemo_automodel.components.moe.config.MoEConfig,
- backend: Optional[nemo_automodel.components.models.common.utils.BackendConfig] = None,
- )#
Bases: torch.nn.Module
Sparse MoE implementation using grouped GEMM with DeepEP token dispatch.
Supports two GEMM backends via BackendConfig.experts:
- grouped_gemm.ops.gmm (experts="gmm", default)
- torch._grouped_mm (experts="torch_mm", no external dependency)
Once the experts for a particular token have been identified, this module is invoked to compute and average the output of the activated experts.
.. attribute:: n_routed_experts
Total number of experts in the model.
- Type:
int
.. attribute:: gate_and_up_projs
Linear layer for gate+up (gated) or just up (non-gated).
- Type:
nn.Parameter
.. attribute:: down_projs
Linear layer for hidden-to-output transformation.
- Type:
nn.Parameter
Initialization
Initializes the GroupedExpertsDeepEP module.
- Parameters:
config – MoE configuration containing expert parameters.
backend – Backend configuration. When backend.experts == “torch_mm”, uses torch._grouped_mm; otherwise uses grouped_gemm.ops.gmm.
- init_token_dispatcher(
- ep_mesh: torch.distributed.device_mesh.DeviceMesh,
- )#
- forward(
- x: torch.Tensor,
- token_mask: torch.Tensor,
- weights: torch.Tensor,
- indices: torch.Tensor,
- )#
Forward pass for the grouped experts.
- Parameters:
x (torch.Tensor) – Input tensor. Shape is [num_tokens, model_dim].
token_mask (torch.Tensor) – Boolean mask indicating valid tokens. Shape is [num_tokens].
weights (torch.Tensor) – Routing weights for the selected experts. Shape is [num_tokens, num_activated_experts].
indices (torch.Tensor) – Indices of the selected experts. Shape is [num_tokens, num_activated_experts].
- Returns:
Output tensor after expert computation. Shape is [num_tokens, model_dim].
- Return type:
torch.Tensor
- init_weights(
- buffer_device: torch.device,
- init_std: float = 0.02,
- )#
- nemo_automodel.components.moe.experts._torch_mm_experts_fwd(
- hidden_states,
- gate_and_up_projs,
- down_projs,
- tokens_per_expert,
- permuted_probs,
- activation_fn,
- )#
- class nemo_automodel.components.moe.experts.GroupedExpertsTE(
- config: nemo_automodel.components.moe.config.MoEConfig,
- backend: Optional[nemo_automodel.components.models.common.utils.BackendConfig] = None,
- )#
Bases: torch.nn.Module
MoE experts using TE's GroupedLinear module directly.
Uses TE’s native GroupedLinear for computation, providing:
- Optimized grouped GEMM kernels from TE
For expert parallelism, each rank creates GroupedLinear with num_local_experts = n_routed_experts / ep_size.
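As a rough sketch of that sharding arithmetic, the constructor calls below follow Transformer Engine's public GroupedLinear interface; the exact options this class passes are not shown in the source and are assumptions here.

```python
# Sketch of per-rank expert sharding with TE's GroupedLinear; only the
# num_local_experts arithmetic described above is taken from the source.
import transformer_engine.pytorch as te

n_routed_experts, ep_size = 64, 8
num_local_experts = n_routed_experts // ep_size        # 8 experts on this rank

dim, moe_inter_dim = 1024, 4096
gate_up_linear = te.GroupedLinear(num_local_experts, dim, 2 * moe_inter_dim, bias=False)
down_linear = te.GroupedLinear(num_local_experts, moe_inter_dim, dim, bias=False)
```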
.. attribute:: n_routed_experts
Total number of experts in the model.
- Type:
int
.. attribute:: gate_up_linear
Combined gate and up projection.
- Type:
GroupedLinear
.. attribute:: down_linear
Down projection.
- Type:
GroupedLinear
Initialization
Initialize the GroupedExpertsTE module.
- Parameters:
config – MoE configuration containing expert parameters.
backend – Backend configuration (reserved for future use).
- _get_stacked_weight(
- linear: transformer_engine.pytorch.GroupedLinear,
- transpose: bool = False,
- )#
- _get_stacked_bias(
- linear: transformer_engine.pytorch.GroupedLinear,
- )#
- _set_stacked_weight(
- linear: transformer_engine.pytorch.GroupedLinear,
- stacked: torch.Tensor,
- transpose: bool = False,
- )#
- _set_stacked_bias(
- linear: transformer_engine.pytorch.GroupedLinear,
- stacked: torch.Tensor,
- )#
- _to_ep_dtensor(tensor: torch.Tensor) -> torch.Tensor#
- _normalize_moe_mesh(
- moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
- )#
- set_moe_mesh(
- moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
- )#
- property gate_and_up_projs: torch.Tensor#
- property down_projs: torch.Tensor#
- property gate_up_proj_bias: Optional[torch.Tensor]#
- property down_proj_bias: Optional[torch.Tensor]#
- state_dict(
- *args,
- destination=None,
- prefix='',
- keep_vars=False,
- **kwargs,
- )#
Return state dict with stacked tensors in DeepEP format.
Converts TE GroupedLinear’s weight{i} parameters to stacked format:
- gate_and_up_projs: [num_local_experts, dim, moe_inter_dim * 2]
- down_projs: [num_local_experts, moe_inter_dim, dim]
When EP is enabled, returns DTensors sharded on dimension 0.
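The conversion amounts to stacking (and transposing) the per-expert weights. A small illustrative sketch, assuming TE stores each weight{i} as [out_features, in_features] like nn.Linear:

```python
# Illustrative conversion from per-expert weight{i} tensors to the stacked
# DeepEP layout described above. Assumes [out_features, in_features] storage.
import torch

num_local_experts, dim, moe_inter_dim = 2, 8, 16
te_weights = [torch.randn(2 * moe_inter_dim, dim) for _ in range(num_local_experts)]

# gate_and_up_projs: [num_local_experts, dim, moe_inter_dim * 2]
gate_and_up_projs = torch.stack([w.t() for w in te_weights], dim=0)
assert gate_and_up_projs.shape == (num_local_experts, dim, 2 * moe_inter_dim)
```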
- _load_from_state_dict(
- state_dict: Dict[str, Any],
- prefix: str,
- local_metadata,
- strict,
- missing_keys,
- unexpected_keys,
- error_msgs,
- )#
Load state dict with stacked tensors in DeepEP format.
Converts stacked format to TE GroupedLinear’s weight{i} parameters:
- gate_and_up_projs: [num_local_experts, dim, moe_inter_dim * 2]
- down_projs: [num_local_experts, moe_inter_dim, dim]
- init_token_dispatcher(
- ep_mesh: torch.distributed.device_mesh.DeviceMesh,
- moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- )#
Initialize the token dispatcher for expert parallelism.
Called by the parallelizer after model initialization.
- Parameters:
ep_mesh – Device mesh for expert parallelism.
- forward(
- x: torch.Tensor,
- token_mask: torch.Tensor,
- weights: torch.Tensor,
- indices: torch.Tensor,
- )#
Forward pass using TE’s GroupedLinear with native FP8 support.
- Parameters:
x – [num_tokens, model_dim] input tensor
token_mask – [num_tokens] boolean mask for valid tokens
weights – [num_tokens, num_activated_experts] routing weights
indices – [num_tokens, num_activated_experts] expert indices
- Returns:
[num_tokens, model_dim] output tensor
- init_weights(
- buffer_device: torch.device,
- init_std: float = 0.02,
- )#
Initialize weights using reset_parameters().
- nemo_automodel.components.moe.experts._init_weights(
- module,
- buffer_device: torch.device,
- init_std: float = 0.02,
- )#