nemo_automodel.components.moe.experts#

Module Contents#

Classes#

GroupedExperts

Sparse MoE implementation using all-gather/reduce-scatter primitives.

GroupedExpertsDeepEP

Sparse MoE implementation using DeepEP.

GroupedExpertsTE

MoE experts using TE’s GroupedLinear module directly.

Functions#

is_gated_activation

Check if activation requires gating (gate_proj + up_proj).

swiglu

quick_geglu

relu2

ReLU² activation: relu(x)^2

get_expert_activation

quick_geglu_deepep

relu2_deepep

ReLU² activation for DeepEP: relu(x)^2

get_expert_activation_for_deepep

_init_weights

API#

nemo_automodel.components.moe.experts.is_gated_activation(activation: str) bool#

Check if activation requires gating (gate_proj + up_proj).

Gated activations (SwiGLU, Quick-GEGLU) use both gate_proj and up_proj, requiring a gate_and_up_projs tensor with shape [n_experts, dim, 2*inter_dim].

Non-gated activations (ReLU²) use only up_proj, requiring an up_projs tensor with shape [n_experts, dim, inter_dim], a 50% memory saving.
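A minimal sketch of the shape contract this implies; the dimension values are made up for illustration and the concrete MoEConfig wiring is omitted:

    import torch

    # Illustrative dimensions only.
    n_experts, dim, inter_dim = 8, 1024, 4096

    def expert_weight_shape(gated: bool) -> tuple:
        # Gated activations stack gate and up projections in one tensor;
        # non-gated activations store only the up projection.
        return (n_experts, dim, 2 * inter_dim) if gated else (n_experts, dim, inter_dim)

    gate_and_up_projs = torch.empty(expert_weight_shape(gated=True))   # [8, 1024, 8192]
    up_projs = torch.empty(expert_weight_shape(gated=False))           # [8, 1024, 4096]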

nemo_automodel.components.moe.experts.swiglu(
x,
*,
gate_and_up_proj,
down_proj,
gate_up_proj_bias=None,
down_proj_bias=None,
)#
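For reference, a single-expert sketch of the SwiGLU computation. Whether the gate and up halves are stored chunked (as assumed here) or interleaved inside gate_and_up_proj is not documented above, so treat the split as an assumption:

    import torch
    import torch.nn.functional as F

    def swiglu_reference(x, gate_and_up_proj, down_proj):
        # Project, split into gate/up halves, gate with SiLU, project down.
        h = x @ gate_and_up_proj          # [tokens, 2 * inter_dim]
        gate, up = h.chunk(2, dim=-1)
        return (F.silu(gate) * up) @ down_proj

    out = swiglu_reference(torch.randn(4, 16), torch.randn(16, 64), torch.randn(32, 16))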
nemo_automodel.components.moe.experts.quick_geglu(
x,
*,
gate_and_up_proj,
down_proj,
gate_up_proj_bias=None,
down_proj_bias=None,
alpha: float = 1.702,
limit: float | None = 7.0,
)#
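A single-expert sketch of quick-GEGLU. The gate uses the quick-GELU approximation x · sigmoid(alpha · x) with alpha = 1.702; the exact clamping scheme shown here is an assumption based on the limit parameter:

    import torch

    def quick_geglu_reference(x, gate_and_up_proj, down_proj, alpha=1.702, limit=7.0):
        h = x @ gate_and_up_proj
        gate, up = h.chunk(2, dim=-1)
        if limit is not None:
            # Assumed clamping: bound pre-activations to keep the gate stable.
            gate = gate.clamp(max=limit)
            up = up.clamp(min=-limit, max=limit)
        glu = gate * torch.sigmoid(alpha * gate)  # quick-GELU approximation
        return (glu * up) @ down_proj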
nemo_automodel.components.moe.experts.relu2(
x,
*,
gate_and_up_proj,
down_proj,
gate_up_proj_bias=None,
down_proj_bias=None,
)#

ReLU² activation: relu(x)^2

Uses the gate_and_up_proj tensor with per-expert shape [dim, inter_dim]. This is the memory-efficient pathway: no gate half is stored, halving the projection weight memory relative to gated activations.
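A single-expert sketch of the non-gated ReLU² pathway:

    import torch
    import torch.nn.functional as F

    def relu2_reference(x, up_proj, down_proj):
        # One up projection, ReLU squared, then the down projection.
        h = F.relu(x @ up_proj) ** 2
        return h @ down_proj

    out = relu2_reference(torch.randn(4, 16), torch.randn(16, 32), torch.randn(32, 16))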

nemo_automodel.components.moe.experts.get_expert_activation(
config: nemo_automodel.components.moe.config.MoEConfig,
)#
class nemo_automodel.components.moe.experts.GroupedExperts(config: nemo_automodel.components.moe.config.MoEConfig)#

Bases: torch.nn.Module

Sparse MoE implementation using all-gather/reduce-scatter primitives.

Once the experts for a particular token have been identified, this module is invoked to compute and average the output of the activated experts.

Attributes:
  • n_routed_experts (int) – Total number of experts in the model.

  • gate_and_up_projs (nn.Parameter) – Weight tensor for the gate+up projection (gated activations) or the up projection alone (non-gated).

  • down_projs (nn.Parameter) – Weight tensor for the hidden-to-output (down) projection.

Initialization

Initializes the GroupedExperts module.

Parameters:

config (MoEConfig) – MoE configuration containing the number of routed experts and the model and intermediate dimensions.

forward(
x: torch.Tensor,
token_mask: torch.Tensor,
weights: torch.Tensor,
indices: torch.Tensor,
) torch.Tensor#

Forward pass for the grouped experts.

Parameters:
  • x (torch.Tensor) – Input tensor. Shape is [num_tokens, model_dim].

  • token_mask (torch.Tensor) – Boolean mask indicating valid tokens. Shape is [num_tokens].

  • weights (torch.Tensor) – Routing weights for the selected experts. Shape is [num_tokens, num_activated_experts].

  • indices (torch.Tensor) – Indices of the selected experts. Shape is [num_tokens, num_activated_experts].

Returns:

Output tensor after expert computation. Shape is [num_tokens, model_dim].

Return type:

torch.Tensor
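The reference semantics of this reduction, sketched with a plain loop; the real implementation uses batched grouped projections, and expert_fns here is a hypothetical stand-in for the per-expert MLPs:

    import torch

    def grouped_experts_reference(x, token_mask, weights, indices, expert_fns):
        # Each valid token's output is the routing-weighted sum of the
        # outputs of its selected experts; masked tokens yield zeros.
        out = torch.zeros_like(x)
        for t in range(x.shape[0]):
            if not token_mask[t]:
                continue
            for k in range(indices.shape[1]):
                e = int(indices[t, k])
                out[t] += weights[t, k] * expert_fns[e](x[t])
        return out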

init_weights(
buffer_device: torch.device,
init_std: float = 0.02,
) None#
nemo_automodel.components.moe.experts.quick_geglu_deepep(
x,
permuted_probs,
alpha: float = 1.702,
limit: float = 7.0,
linear_offset: float = 1.0,
)#
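By analogy with quick_geglu above, a heavily hedged sketch; every detail beyond the signature (the chunking, the clamping, the placement of linear_offset, and the probability scaling) is an assumption:

    import torch

    def quick_geglu_deepep_reference(x, permuted_probs, alpha=1.702, limit=7.0, linear_offset=1.0):
        # Assumptions: x is the fused gate/up projection output, split into
        # halves; the gate uses quick-GELU, the up half receives a linear
        # offset, and routing probabilities are applied to the result here.
        gate, up = x.chunk(2, dim=-1)
        gate = gate.clamp(max=limit)
        up = up.clamp(min=-limit, max=limit)
        glu = gate * torch.sigmoid(alpha * gate)
        return glu * (up + linear_offset) * permuted_probs.unsqueeze(-1)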
nemo_automodel.components.moe.experts.relu2_deepep(x, permuted_probs)#

ReLU² activation for DeepEP: relu(x)^2

For DeepEP with ReLU², x is the output of the up projection (already computed) and therefore already has shape […, inter_dim].
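A sketch of this step; applying the routing probabilities here, rather than after the down projection, is an assumption:

    import torch
    import torch.nn.functional as F

    def relu2_deepep_reference(x, permuted_probs):
        # x: up-projection output, shape [..., inter_dim];
        # permuted_probs: per-token routing probabilities, shape [...].
        return F.relu(x) ** 2 * permuted_probs.unsqueeze(-1)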

nemo_automodel.components.moe.experts.get_expert_activation_for_deepep(
config: nemo_automodel.components.moe.config.MoEConfig,
)#
class nemo_automodel.components.moe.experts.GroupedExpertsDeepEP(
config: nemo_automodel.components.moe.config.MoEConfig,
)#

Bases: torch.nn.Module

Sparse MoE implementation using DeepEP.

Once the experts for a particular token have been identified, this module is invoked to compute and average the output of the activated experts.

Attributes:
  • n_routed_experts (int) – Total number of experts in the model.

  • gate_and_up_projs (nn.Parameter) – Weight tensor for the gate+up projection (gated activations) or the up projection alone (non-gated).

  • down_projs (nn.Parameter) – Weight tensor for the hidden-to-output (down) projection.

Initialization

Initializes the GroupedExpertsDeepEP module.

Parameters:

config (MoEConfig) – MoE configuration containing the number of routed experts and the model and intermediate dimensions.

static _apply_bias(value, bias, tokens_per_expert, permuted_probs=None)#
init_token_dispatcher(
ep_mesh: torch.distributed.device_mesh.DeviceMesh,
)#
forward(
x: torch.Tensor,
token_mask: torch.Tensor,
weights: torch.Tensor,
indices: torch.Tensor,
) torch.Tensor#

Forward pass for the grouped experts.

Parameters:
  • x (torch.Tensor) – Input tensor. Shape is [num_tokens, model_dim].

  • token_mask (torch.Tensor) – Boolean mask indicating valid tokens. Shape is [num_tokens].

  • weights (torch.Tensor) – Routing weights for the selected experts. Shape is [num_tokens, num_activated_experts].

  • indices (torch.Tensor) – Indices of the selected experts. Shape is [num_tokens, num_activated_experts].

Returns:

Output tensor after expert computation. Shape is [num_tokens, model_dim].

Return type:

torch.Tensor

init_weights(
buffer_device: torch.device,
init_std: float = 0.02,
) None#
class nemo_automodel.components.moe.experts.GroupedExpertsTE(
config: nemo_automodel.components.moe.config.MoEConfig,
)#

Bases: torch.nn.Module

MoE experts using TE’s GroupedLinear module directly.

Uses TE’s native GroupedLinear for computation, providing:

  • Optimized grouped GEMM kernels from TE

For expert parallelism, each rank creates GroupedLinear with num_local_experts = n_routed_experts / ep_size.
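The sharding arithmetic, as a quick check (values are illustrative):

    # Each expert-parallel rank owns an equal slice of the experts.
    n_routed_experts, ep_size = 64, 8
    assert n_routed_experts % ep_size == 0           # experts must divide evenly
    num_local_experts = n_routed_experts // ep_size  # 8 experts per rank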

Attributes:
  • n_routed_experts (int) – Total number of experts in the model.

  • gate_up_linear (GroupedLinear) – Combined gate and up projection.

  • down_linear (GroupedLinear) – Down projection.

Initialization

Initialize the GroupedExpertsTE module.

Parameters:

config – MoE configuration containing expert parameters.

_get_stacked_weight(
linear: transformer_engine.pytorch.GroupedLinear,
transpose: bool = False,
) torch.Tensor#
_get_stacked_bias(
linear: transformer_engine.pytorch.GroupedLinear,
) Optional[torch.Tensor]#
_set_stacked_weight(
linear: transformer_engine.pytorch.GroupedLinear,
stacked: torch.Tensor,
transpose: bool = False,
)#
_set_stacked_bias(
linear: transformer_engine.pytorch.GroupedLinear,
stacked: torch.Tensor,
)#
_to_ep_dtensor(tensor: torch.Tensor) torch.Tensor#
_normalize_moe_mesh(
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
) Optional[torch.distributed.device_mesh.DeviceMesh]#
set_moe_mesh(
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
) None#
property gate_and_up_projs: torch.Tensor#
property down_projs: torch.Tensor#
property gate_up_proj_bias: Optional[torch.Tensor]#
property down_proj_bias: Optional[torch.Tensor]#
state_dict(
*args,
destination=None,
prefix='',
keep_vars=False,
**kwargs,
) Dict[str, Any]#

Return state dict with stacked tensors in DeepEP format.

Converts TE GroupedLinear’s weight{i} parameters to stacked format:

  • gate_and_up_projs: [num_local_experts, dim, moe_inter_dim * 2]

  • down_projs: [num_local_experts, moe_inter_dim, dim]

When EP is enabled, returns DTensors sharded on dimension 0.
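A sketch of the conversion, assuming TE's GroupedLinear stores one weight{i} per local expert in the usual [out_features, in_features] layout (an assumption about TE internals):

    import torch

    num_local_experts, dim, moe_inter_dim = 4, 16, 32
    # Stand-ins for GroupedLinear's per-expert weight0, weight1, ... parameters.
    per_expert = [torch.randn(2 * moe_inter_dim, dim) for _ in range(num_local_experts)]

    # Transpose each expert's weight to [dim, 2 * moe_inter_dim] and stack on
    # a new leading expert dimension to obtain the DeepEP-format tensor.
    gate_and_up_projs = torch.stack([w.t() for w in per_expert], dim=0)
    assert gate_and_up_projs.shape == (num_local_experts, dim, 2 * moe_inter_dim)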

_load_from_state_dict(
state_dict: Dict[str, Any],
prefix: str,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)#

Load state dict with stacked tensors in DeepEP format.

Converts stacked format to TE GroupedLinear’s weight{i} parameters:

  • gate_and_up_projs: [num_local_experts, dim, moe_inter_dim * 2]

  • down_projs: [num_local_experts, moe_inter_dim, dim]

init_token_dispatcher(
ep_mesh: torch.distributed.device_mesh.DeviceMesh,
moe_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
)#

Initialize the token dispatcher for expert parallelism.

Called by the parallelizer after model initialization.

Parameters:
  • ep_mesh – Device mesh for expert parallelism.

  • moe_mesh – Optional device mesh for the combined MoE parallel dimensions; defaults to None.

forward(
x: torch.Tensor,
token_mask: torch.Tensor,
weights: torch.Tensor,
indices: torch.Tensor,
) torch.Tensor#

Forward pass using TE’s GroupedLinear with native FP8 support.

Parameters:
  • x – [num_tokens, model_dim] input tensor

  • token_mask – [num_tokens] boolean mask for valid tokens

  • weights – [num_tokens, num_activated_experts] routing weights

  • indices – [num_tokens, num_activated_experts] expert indices

Returns:

[num_tokens, model_dim] output tensor

init_weights(
buffer_device: torch.device,
init_std: float = 0.02,
) None#

Initialize weights using reset_parameters().

nemo_automodel.components.moe.experts._init_weights(
module,
buffer_device: torch.device,
init_std: float = 0.02,
)#