nemo_automodel.components.moe.layers#

Module Contents#

Classes#

MLP

Multi-Layer Perceptron (MLP) used as a feed-forward layer.

FakeBalancedGate

Load-balanced gate that spreads tokens uniformly across all experts. This class exists for performance experiments that measure how load imbalance with real data impacts end-to-end performance.

Gate

Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.

MoE

Mixture-of-Experts (MoE) module.

Functions#

_init_weights

Data#

_shared_experts_stream

API#

nemo_automodel.components.moe.layers._shared_experts_stream: Optional[torch.cuda.Stream]#

None

class nemo_automodel.components.moe.layers.MLP(
dim: int,
inter_dim: int,
backend: str,
dtype: torch.dtype = torch.bfloat16,
activation: str = 'swiglu',
bias: bool = False,
)#

Bases: torch.nn.Module

Multi-Layer Perceptron (MLP) used as a feed-forward layer.

Supports both gated activations (SwiGLU) and simple activations (ReLU²).

.. attribute:: gate_proj

Linear layer for the gate projection in gated activations (serves as the up projection for simple activations).

Type:

nn.Module

.. attribute:: down_proj

Linear layer for hidden-to-output transformation.

Type:

nn.Module

.. attribute:: up_proj

Additional linear layer for gated activations (None for simple).

Type:

nn.Module

Initialization

Initializes the MLP layer.

Parameters:
  • dim (int) – Input and output dimensionality.

  • inter_dim (int) – Hidden layer dimensionality.

  • backend (str) – Backend for linear layers.

  • dtype (torch.dtype) – Data type for weights.

  • activation (str) – Activation function: "swiglu" (default) or "relu2".

  • bias (bool) – Whether to use bias in linear layers.

forward(x: torch.Tensor) torch.Tensor#

Forward pass for the MLP layer.

Parameters:

x (torch.Tensor) – Input tensor.

Returns:

Output tensor after MLP computation.

Return type:

torch.Tensor

init_weights(
buffer_device: torch.device,
init_std: float = 0.02,
) None#
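To illustrate the gated-activation path, here is a minimal pure-Python sketch of the SwiGLU computation `down_proj(SiLU(gate_proj(x)) * up_proj(x))`. The helper names and list-of-rows weight layout are illustrative assumptions, not the module's actual implementation, which operates on torch tensors.

```python
import math


def silu(v: float) -> float:
    # SiLU (swish) activation: v * sigmoid(v)
    return v / (1.0 + math.exp(-v))


def swiglu_mlp(x, w_gate, w_up, w_down):
    """Sketch of a SwiGLU MLP forward pass on plain Python lists.

    x: input vector of length dim; w_gate/w_up: [inter_dim x dim] matrices
    (lists of rows); w_down: [dim x inter_dim]. Hypothetical helper for
    illustration only.
    """
    def matvec(w, v):
        return [sum(wi * vi for wi, vi in zip(row, v)) for row in w]

    gate = [silu(g) for g in matvec(w_gate, x)]   # gate branch, [inter_dim]
    up = matvec(w_up, x)                          # up branch, [inter_dim]
    hidden = [g * u for g, u in zip(gate, up)]    # element-wise gating
    return matvec(w_down, hidden)                 # project back to [dim]
```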
class nemo_automodel.components.moe.layers.FakeBalancedGate(
config: nemo_automodel.components.moe.config.MoEConfig,
skip_first_n_experts: int = 0,
noise: float = 0.0,
)#

Bases: torch.nn.Module

Load-balanced gate that spreads tokens uniformly across all experts. This class exists for performance experiments that measure how load imbalance with real data impacts end-to-end performance.

When noise > 0, random perturbation is added to mimic realistic routing imbalance. A noise value of 0.0 gives perfectly balanced assignment, while 1.0 gives fully random expert selection and non-uniform weights.

Initialization

forward(
x: torch.Tensor,
token_mask: torch.Tensor,
cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
) tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]#

Forward pass for the gating mechanism.

Parameters:
  • x (torch.Tensor) – Input tensor.

  • token_mask (torch.Tensor) – Boolean mask indicating valid tokens.

  • cp_mesh (Optional[DeviceMesh]) – Device mesh for context parallel computation.

Returns:

  • weights (torch.Tensor): Routing weights for the selected experts.

  • indices (torch.Tensor): Indices of the selected experts.

  • aux_loss (Optional[torch.Tensor]): Auxiliary loss for load balancing.

Return type:

tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]

update_bias() None#
init_weights(
buffer_device: torch.device,
init_std: float = 0.02,
) None#
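The balanced-routing idea can be sketched in pure Python: with `noise=0.0` each routing slot is assigned round-robin so every expert receives the same number of tokens, and with `noise > 0` some slots are replaced by random picks. The function below is a hypothetical illustration of that scheme, not the class's actual tensor implementation.

```python
import random


def fake_balanced_route(num_tokens, n_experts, topk, noise=0.0, seed=0):
    """Sketch of load-balanced routing: slot k of token t goes to expert
    (t * topk + k) % n_experts, with probability `noise` of a random pick.
    Returns (indices, uniform weights). Illustration only."""
    rng = random.Random(seed)
    indices = []
    for t in range(num_tokens):
        row = []
        for k in range(topk):
            if rng.random() < noise:
                row.append(rng.randrange(n_experts))       # random imbalance
            else:
                row.append((t * topk + k) % n_experts)     # round-robin slot
        indices.append(row)
    # Every selected expert gets the same routing weight 1/topk.
    weights = [[1.0 / topk] * topk for _ in range(num_tokens)]
    return indices, weights
```

With `noise=0.0` the per-expert load is exactly uniform whenever `num_tokens * topk` is a multiple of `n_experts`.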
class nemo_automodel.components.moe.layers.Gate(
config: nemo_automodel.components.moe.config.MoEConfig,
gate_precision: torch.dtype | None = None,
)#

Bases: torch.nn.Module

Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.

.. attribute:: dim

Dimensionality of input features.

Type:

int

.. attribute:: topk

Number of top experts activated for each input.

Type:

int

.. attribute:: n_groups

Number of groups for routing.

Type:

int

.. attribute:: topk_groups

Number of groups to route inputs to.

Type:

int

.. attribute:: score_func

Scoring function ('softmax' or 'sigmoid').

Type:

str

.. attribute:: route_scale

Scaling factor for routing weights.

Type:

float

.. attribute:: weight

Learnable weights for the gate.

Type:

torch.nn.Parameter

.. attribute:: bias

Optional bias term for the gate.

Type:

Optional[torch.nn.Parameter]

Initialization

Initializes the Gate module.

Parameters:
  • config (MoEConfig) – Model configuration containing gating parameters.

  • gate_precision (torch.dtype | None) – Precision for gate computations (linear, softmax/sigmoid).

forward(
x: torch.Tensor,
token_mask: torch.Tensor,
cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
) tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]#

Forward pass for the gating mechanism.

Parameters:
  • x (torch.Tensor) – Input tensor.

  • token_mask (torch.Tensor) – Boolean mask indicating valid tokens.

  • cp_mesh (Optional[DeviceMesh]) – Device mesh for context parallel computation.

Returns:

  • weights (torch.Tensor): Routing weights for the selected experts.

  • indices (torch.Tensor): Indices of the selected experts.

  • aux_loss (Optional[torch.Tensor]): Auxiliary loss for load balancing.

Return type:

tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]

update_bias() None#

Updates the correction bias used in the gate based on the popularity of experts. This function is a no-op if the gate is not trained.

To avoid routing collapse, and to promote better load balance of experts, DeepSeek-V3 uses a correction mechanism to adjust the scores of experts using a learned bias parameter. The bias parameter is updated based on the popularity of experts, i.e., the number of tokens routed to each expert. If an expert is more popular than the average, its bias term is decreased, and vice versa. This encourages the model to route tokens to less popular experts, promoting better load balance.
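The update described above can be sketched as follows: compare each expert's load against the mean and nudge its bias in the opposite direction. The helper name and the fixed update rate are assumptions for illustration; DeepSeek-V3 describes this sign-based update, but the module's exact rule may differ.

```python
def update_correction_bias(bias, expert_load, update_rate=0.001):
    """Sketch of sign-based bias correction: experts more popular than
    average have their bias decreased, less popular ones increased.
    bias and expert_load are per-expert lists. Illustration only."""
    mean_load = sum(expert_load) / len(expert_load)
    new_bias = []
    for b, load in zip(bias, expert_load):
        if load > mean_load:
            new_bias.append(b - update_rate)   # over-used: discourage routing
        elif load < mean_load:
            new_bias.append(b + update_rate)   # under-used: encourage routing
        else:
            new_bias.append(b)                 # exactly average: unchanged
    return new_bias
```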

_compute_expert_load(
indices: torch.Tensor,
token_mask: torch.Tensor,
) torch.Tensor#

Computes the load of each expert based on the selected indices.

Parameters:
  • indices (torch.Tensor) – Indices of the selected experts. Shape is [num_tokens, num_activated_experts].

  • token_mask (torch.Tensor) – Boolean mask indicating valid tokens. Shape is [num_tokens].

Returns:

Load of each expert (number of tokens routed to each expert). Shape is [num_local_experts].

Return type:

torch.Tensor
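The load computation amounts to a masked count of expert indices, which in torch would typically be a `bincount`-style reduction. A pure-Python sketch of the same logic (hypothetical helper, illustration only):

```python
def compute_expert_load(indices, token_mask, n_experts):
    """Sketch: count how many routing slots go to each expert,
    ignoring masked-out tokens.

    indices: [num_tokens][topk] expert ids; token_mask: [num_tokens]
    booleans (True = valid token). Returns a per-expert count list."""
    load = [0] * n_experts
    for row, valid in zip(indices, token_mask):
        if not valid:
            continue  # masked tokens (e.g. padding) do not contribute
        for e in row:
            load[e] += 1
    return load
```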

_compute_aux_loss(
original_scores: torch.Tensor,
expert_load: torch.Tensor,
token_mask: torch.Tensor,
cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
) torch.Tensor#

Computes the auxiliary loss for load balancing.

Warning: Assumes batch size = 1; if batch size > 1, the aux_loss is computed across multiple sequences.

Parameters:
  • original_scores (torch.Tensor) – Original scores from the gating mechanism. Shape is [num_tokens, num_experts].

  • expert_load (torch.Tensor) – Load of each expert (number of tokens routed to each expert). Shape is [num_experts].

  • token_mask (torch.Tensor) – Boolean mask indicating valid tokens. Shape is [num_tokens].

  • cp_mesh (Optional[DeviceMesh]) – Device mesh for context parallel computation.

Returns:

Auxiliary loss for load balancing. Shape is [].

Return type:

torch.Tensor
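A common form of this loss (used in switch-style routers) is `alpha * N * sum_i f_i * P_i`, where `f_i` is the fraction of routing slots assigned to expert `i` and `P_i` is the mean gate score for expert `i`. The sketch below implements that formulation on plain lists; the coefficient `alpha` and the exact normalization are assumptions, and the module's implementation may differ in detail.

```python
def aux_balance_loss(original_scores, expert_load, num_valid_tokens, topk,
                     alpha=0.001):
    """Sketch of a load-balancing auxiliary loss: alpha * N * sum_i f_i * P_i.

    original_scores: [num_tokens][n_experts] gate scores for valid tokens;
    expert_load: per-expert routed-slot counts. Illustration only."""
    n_experts = len(expert_load)
    total_slots = num_valid_tokens * topk
    loss = 0.0
    for i in range(n_experts):
        f_i = expert_load[i] / total_slots                       # routed fraction
        p_i = sum(s[i] for s in original_scores) / num_valid_tokens  # mean score
        loss += f_i * p_i
    return alpha * n_experts * loss
```

The loss reaches its minimum under a perfectly uniform routing distribution, which is what pushes the gate toward balance.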

init_weights(
buffer_device: torch.device,
init_std: float = 0.02,
) None#
class nemo_automodel.components.moe.layers.MoE(
config: nemo_automodel.components.moe.config.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig,
)#

Bases: torch.nn.Module

Mixture-of-Experts (MoE) module.

.. attribute:: dim

Dimensionality of input features.

Type:

int

.. attribute:: n_routed_experts

Total number of experts in the model.

Type:

int

.. attribute:: n_local_experts

Number of experts handled locally in distributed systems.

Type:

int

.. attribute:: n_activated_experts

Number of experts activated for each input.

Type:

int

.. attribute:: gate

Gating mechanism to route inputs to experts.

Type:

nn.Module

.. attribute:: experts

List of expert modules.

Type:

nn.ModuleList

.. attribute:: shared_experts

Shared experts applied to all inputs.

Type:

nn.Module

Initialization

Initializes the MoE module.

Parameters:
  • config (MoEConfig) – Model configuration containing MoE parameters.

  • backend (BackendConfig) – Backend configuration for the gate and expert implementations.

forward(
x: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
) tuple[torch.Tensor, Optional[torch.Tensor]]#

Forward pass for the MoE module.

Parameters:
  • x (torch.Tensor) – Input tensor.

  • padding_mask (Optional[torch.Tensor]) – Boolean mask indicating padding positions.

  • cp_mesh (Optional[DeviceMesh]) – Device mesh for context parallel computation.

Returns:

  • output (torch.Tensor): Output tensor after expert routing and computation.

  • aux_loss (Optional[torch.Tensor]): Auxiliary loss for load balancing (if applicable).

Return type:

tuple[torch.Tensor, Optional[torch.Tensor]]

init_weights(
buffer_device: torch.device,
init_std: float = 0.02,
) None#
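Putting the pieces together, the per-token MoE dispatch is: score the experts, keep the top-k, renormalize their scores into weights, and sum the weighted expert outputs. The sketch below shows that flow for a single token with experts as plain callables; it is a hypothetical illustration, not the module's batched tensor implementation.

```python
def moe_forward_sketch(x, gate_scores, experts, topk):
    """Sketch of MoE dispatch for one token: select the top-k experts by
    gate score, renormalize the selected scores into routing weights, and
    return the weighted sum of expert outputs. Illustration only."""
    ranked = sorted(range(len(gate_scores)), key=lambda i: -gate_scores[i])
    chosen = ranked[:topk]                        # top-k expert indices
    total = sum(gate_scores[i] for i in chosen)
    out = [0.0] * len(x)
    for i in chosen:
        w = gate_scores[i] / total                # renormalized routing weight
        y = experts[i](x)                         # run the selected expert
        out = [o + w * yi for o, yi in zip(out, y)]
    return out
```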
nemo_automodel.components.moe.layers._init_weights(
module,
buffer_device: torch.device,
init_std: float = 0.02,
)#