nemo_automodel.components.moe.layers#
Module Contents#
Classes#
| Class | Description |
|---|---|
| MLP | Multi-Layer Perceptron (MLP) used as a feed-forward layer. |
| FakeBalancedGate | Load-balanced gate implementation that spreads tokens uniformly across all experts. The rationale for this class is to run performance experiments to understand how load imbalance with real data impacts end-to-end performance. |
| Gate | Gating mechanism for routing inputs in a mixture-of-experts (MoE) model. |
| MoE | Mixture-of-Experts (MoE) module. |
Functions#
Data#
API#
- class nemo_automodel.components.moe.layers.MLP(
- dim: int,
- inter_dim: int,
- backend: str,
- dtype: torch.dtype = torch.bfloat16,
- activation: str = 'swiglu',
- bias: bool = False,
- )
Bases:
torch.nn.Module
Multi-Layer Perceptron (MLP) used as a feed-forward layer.
Supports both gated activations (SwiGLU) and simple activations (ReLU²).
.. attribute:: gate_proj
Linear layer for gate in gated activations (or up_proj for simple).
- Type:
nn.Module
.. attribute:: down_proj
Linear layer for hidden-to-output transformation.
- Type:
nn.Module
.. attribute:: up_proj
Additional linear layer for gated activations (None for simple).
- Type:
nn.Module
Initialization
Initializes the MLP layer.
- Parameters:
dim (int) – Input and output dimensionality.
inter_dim (int) – Hidden layer dimensionality.
backend (str) – Backend for linear layers.
dtype (torch.dtype) – Data type for weights.
activation (str) – Activation function: 'swiglu' (default) or 'relu2'.
bias (bool) – Whether to use bias in linear layers.
- forward(x: torch.Tensor) -> torch.Tensor#
Forward pass for the MLP layer.
- Parameters:
x (torch.Tensor) – Input tensor.
- Returns:
Output tensor after MLP computation.
- Return type:
torch.Tensor
- init_weights(
- buffer_device: torch.device,
- init_std: float = 0.02,
- )
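The gated forward pass described above can be sketched in plain Python, with lists standing in for tensors. Helper names here (`silu`, `swiglu_mlp`) are illustrative, not the module's actual internals; the structure follows the standard SwiGLU formulation down_proj(SiLU(gate_proj(x)) * up_proj(x)):

```python
import math

def silu(v: float) -> float:
    # SiLU / swish activation: v * sigmoid(v)
    return v / (1.0 + math.exp(-v))

def swiglu_mlp(x, w_gate, w_up, w_down):
    # Gated forward pass: down_proj(silu(gate_proj(x)) * up_proj(x)),
    # written element-wise over 1-D lists instead of torch tensors.
    hidden = [
        silu(sum(wg * xi for wg, xi in zip(row_g, x)))
        * sum(wu * xi for wu, xi in zip(row_u, x))
        for row_g, row_u in zip(w_gate, w_up)
    ]
    return [sum(wd * h for wd, h in zip(row_d, hidden)) for row_d in w_down]
```

With identity weight matrices, each output element reduces to silu(x_i) * x_i, which makes the gating easy to verify by hand.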
- class nemo_automodel.components.moe.layers.FakeBalancedGate(
- config: nemo_automodel.components.moe.config.MoEConfig,
- skip_first_n_experts: int = 0,
- noise: float = 0.0,
- )
Bases:
torch.nn.Module
Load-balanced gate implementation that spreads tokens uniformly across all experts. The rationale for this class is to run performance experiments to understand how load imbalance with real data impacts end-to-end performance.
When noise > 0, random perturbation is added to mimic realistic routing imbalance. A noise value of 0.0 gives perfectly balanced assignment, while 1.0 gives fully random expert selection and non-uniform weights.
Initialization
- forward(
- x: torch.Tensor,
- token_mask: torch.Tensor,
- cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
- )
Forward pass for the gating mechanism.
- Parameters:
x (torch.Tensor) – Input tensor.
token_mask (torch.Tensor) – Boolean mask indicating valid tokens.
cp_mesh (Optional[DeviceMesh]) – Device mesh for context parallel computation.
- Returns:
weights (torch.Tensor): Routing weights for the selected experts.
indices (torch.Tensor): Indices of the selected experts.
aux_loss (Optional[torch.Tensor]): Auxiliary loss for load balancing.
- Return type:
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
- update_bias() -> None#
- init_weights(
- buffer_device: torch.device,
- init_std: float = 0.02,
- )
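The uniform-spread idea behind this gate can be illustrated with a deterministic round-robin sketch. The real module operates on tensors and optionally perturbs the assignment with noise; this pure-Python version (function name and scheme are illustrative) only shows how consecutive tokens can be assigned so every expert receives an equal share:

```python
def balanced_topk_indices(num_tokens: int, n_experts: int, topk: int):
    # Deterministic round-robin assignment: token t takes topk consecutive
    # experts starting at (t * topk) % n_experts, so assignments cycle
    # through all experts and the per-expert load stays (almost) equal.
    return [
        [(t * topk + j) % n_experts for j in range(topk)]
        for t in range(num_tokens)
    ]
```

When num_tokens * topk is a multiple of n_experts, the resulting load is exactly uniform, which is the "perfectly balanced" baseline the class provides for performance experiments.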
- class nemo_automodel.components.moe.layers.Gate(
- config: nemo_automodel.components.moe.config.MoEConfig,
- gate_precision: torch.dtype | None = None,
- )
Bases:
torch.nn.Module
Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.
.. attribute:: dim
Dimensionality of input features.
- Type:
int
.. attribute:: topk
Number of top experts activated for each input.
- Type:
int
.. attribute:: n_groups
Number of groups for routing.
- Type:
int
.. attribute:: topk_groups
Number of groups to route inputs to.
- Type:
int
.. attribute:: score_func
Scoring function ('softmax' or 'sigmoid').
- Type:
str
.. attribute:: route_scale
Scaling factor for routing weights.
- Type:
float
.. attribute:: weight
Learnable weights for the gate.
- Type:
torch.nn.Parameter
.. attribute:: bias
Optional bias term for the gate.
- Type:
Optional[torch.nn.Parameter]
Initialization
Initializes the Gate module.
- Parameters:
config (MoEConfig) – Model configuration containing gating parameters.
gate_precision (torch.dtype | None) – Precision for gate computations (linear, softmax/sigmoid).
- forward(
- x: torch.Tensor,
- token_mask: torch.Tensor,
- cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
- )
Forward pass for the gating mechanism.
- Parameters:
x (torch.Tensor) – Input tensor.
token_mask (torch.Tensor) – Boolean mask indicating valid tokens.
cp_mesh (Optional[DeviceMesh]) – Device mesh for context parallel computation.
- Returns:
weights (torch.Tensor): Routing weights for the selected experts.
indices (torch.Tensor): Indices of the selected experts.
aux_loss (Optional[torch.Tensor]): Auxiliary loss for load balancing.
- Return type:
Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
- update_bias() -> None#
Updates the correction bias used in the gate based on the popularity of experts. This function is a NoOp if the gate is not trained.
To avoid routing collapse, and to promote better load balance of experts, DeepSeek-V3 uses a correction mechanism to adjust the scores of experts using a learned bias parameter. The bias parameter is updated based on the popularity of experts, i.e., the number of tokens routed to each expert. If an expert is more popular than the average, its bias term is decreased, and vice versa. This encourages the model to route tokens to less popular experts, promoting better load balance.
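The correction described above can be sketched in plain Python. The update rule (sign-based step with an `update_rate` hyperparameter) follows the DeepSeek-V3 auxiliary-loss-free balancing scheme; the function name, step rule, and rate value here are illustrative, not this module's exact implementation:

```python
def update_correction_bias(bias, expert_load, update_rate=0.001):
    # Experts that received more tokens than average get their score bias
    # decreased; under-loaded experts get it increased. Future routing
    # decisions then favor the less popular experts.
    mean_load = sum(expert_load) / len(expert_load)
    def step(load):
        return 1 if load > mean_load else (-1 if load < mean_load else 0)
    return [b - update_rate * step(load) for b, load in zip(bias, expert_load)]
```

Note the bias only shifts which experts are selected; the routing weights themselves are still computed from the unbiased scores.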
- _compute_expert_load(
- indices: torch.Tensor,
- token_mask: torch.Tensor,
- )
Computes the load of each expert based on the selected indices.
- Parameters:
indices (torch.Tensor) – Indices of the selected experts. Shape is [num_tokens, num_activated_experts].
token_mask (torch.Tensor) – Boolean mask indicating valid tokens. Shape is [num_tokens].
- Returns:
Load of each expert (number of tokens routed to each expert). Shape is [num_local_experts].
- Return type:
torch.Tensor
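The counting this method performs can be sketched without tensors. In torch this is typically a masked `bincount`; the pure-Python version below (names are illustrative) makes the semantics explicit:

```python
from collections import Counter

def compute_expert_load(indices, token_mask, n_experts):
    # Count how many valid tokens were routed to each expert.
    # indices: [num_tokens][topk] expert ids; token_mask: [num_tokens] bools.
    counts = Counter(
        e for row, keep in zip(indices, token_mask) if keep for e in row
    )
    # Densify into a length-n_experts load vector (0 for unused experts).
    return [counts.get(e, 0) for e in range(n_experts)]
```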
- _compute_aux_loss(
- original_scores: torch.Tensor,
- expert_load: torch.Tensor,
- token_mask: torch.Tensor,
- cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh],
- )
Computes the auxiliary loss for load balancing.
Warning: assumes batch size = 1; if batch size > 1, the aux_loss will be computed across multiple sequences.
- Parameters:
original_scores (torch.Tensor) – Original scores from the gating mechanism. Shape is [num_tokens, num_experts].
expert_load (torch.Tensor) – Load of each expert (number of tokens routed to each expert). Shape is [num_experts].
token_mask (torch.Tensor) – Boolean mask indicating valid tokens. Shape is [num_tokens].
cp_mesh (Optional[DeviceMesh]) – Device mesh for context parallel computation.
- Returns:
Auxiliary loss for load balancing. Shape is [].
- Return type:
torch.Tensor
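A common form of this load-balancing loss combines the routed token fractions with the mean gate probabilities. The exact scaling used by this Gate is an implementation detail; the sketch below follows the widely used n_experts * Σ fᵢ·Pᵢ formulation, and all names are illustrative:

```python
def aux_load_balance_loss(original_scores, expert_load, topk):
    # original_scores: [num_tokens][num_experts] gate probabilities.
    # expert_load: [num_experts] token counts from the routing step.
    n_tokens = len(original_scores)
    n_experts = len(expert_load)
    # f_i: fraction of routed assignments that landed on expert i.
    f = [load / (n_tokens * topk) for load in expert_load]
    # P_i: mean gate probability assigned to expert i.
    p = [sum(row[i] for row in original_scores) / n_tokens
         for i in range(n_experts)]
    # Minimized when both the load and the probabilities are uniform.
    return n_experts * sum(fi * pi for fi, pi in zip(f, p))
```

Under perfectly uniform routing the loss evaluates to 1.0, which is its minimum for this formulation; any imbalance pushes it higher.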
- init_weights(
- buffer_device: torch.device,
- init_std: float = 0.02,
- )
- class nemo_automodel.components.moe.layers.MoE(
- config: nemo_automodel.components.moe.config.MoEConfig,
- backend: nemo_automodel.components.models.common.BackendConfig,
- )
Bases:
torch.nn.Module
Mixture-of-Experts (MoE) module.
.. attribute:: dim
Dimensionality of input features.
- Type:
int
.. attribute:: n_routed_experts
Total number of experts in the model.
- Type:
int
.. attribute:: n_local_experts
Number of experts handled locally in distributed systems.
- Type:
int
.. attribute:: n_activated_experts
Number of experts activated for each input.
- Type:
int
.. attribute:: gate
Gating mechanism to route inputs to experts.
- Type:
nn.Module
.. attribute:: experts
List of expert modules.
- Type:
nn.ModuleList
.. attribute:: shared_experts
Shared experts applied to all inputs.
- Type:
nn.Module
Initialization
Initializes the MoE module.
- Parameters:
config (MoEConfig) – Model configuration containing MoE parameters.
backend (BackendConfig) – Backend configuration for the gate and expert implementations.
- forward(
- x: torch.Tensor,
- padding_mask: Optional[torch.Tensor] = None,
- cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
- )
Forward pass for the MoE module.
- Parameters:
x (torch.Tensor) – Input tensor.
padding_mask (Optional[torch.Tensor]) – Boolean mask indicating padding positions.
cp_mesh (Optional[DeviceMesh]) – Device mesh for context parallel computation.
- Returns:
output (torch.Tensor): Output tensor after expert routing and computation.
aux_loss (Optional[torch.Tensor]): Auxiliary loss for load balancing (if applicable).
- Return type:
Tuple[torch.Tensor, Optional[torch.Tensor]]
- init_weights(
- buffer_device: torch.device,
- init_std: float = 0.02,
- )
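The combine step of an MoE forward pass (gate output applied to the selected experts, plus the shared experts applied to every token) can be sketched in plain Python. Experts are plain callables here and all names are illustrative; the real module dispatches torch tensors to its `experts` and `shared_experts` submodules:

```python
def moe_combine(x, weights, indices, experts, shared_expert=None):
    # Per-token combine: run each token through its topk selected experts,
    # sum the outputs scaled by the routing weights, then add the shared
    # expert's output (which every token receives, bypassing the gate).
    out = []
    for token, w_row, idx_row in zip(x, weights, indices):
        y = sum(w * experts[i](token) for w, i in zip(w_row, idx_row))
        if shared_expert is not None:
            y += shared_expert(token)
        out.append(y)
    return out
```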
- nemo_automodel.components.moe.layers._init_weights(
- module,
- buffer_device: torch.device,
- init_std: float = 0.02,
- )