nemo_automodel.components.moe.layers

Module Contents

Classes

| Name | Description |
| --- | --- |
| FakeBalancedGate | Load balanced gate implementation, spreads tokens uniformly across all experts. |
| Gate | Gating mechanism for routing inputs in a mixture-of-experts (MoE) model. |
| MLP | Multi-Layer Perceptron (MLP) used as a feed-forward layer. |
| MoE | Mixture-of-Experts (MoE) module. |

Functions

| Name | Description |
| --- | --- |
| _init_weights | - |

API

class nemo_automodel.components.moe.layers.FakeBalancedGate(
config: nemo_automodel.components.moe.config.MoEConfig,
skip_first_n_experts: int = 0,
noise: float = 0.0
)

Bases: Module

Load-balanced gate implementation that spreads tokens uniformly across all experts. It exists for performance experiments that measure how much the load imbalance seen with real data affects end-to-end performance.

When noise > 0, random perturbation is added to mimic realistic routing imbalance. A noise value of 0.0 gives perfectly balanced assignment, while 1.0 gives fully random expert selection and non-uniform weights.
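The perfectly balanced case (noise of 0.0) can be illustrated with a round-robin assignment. The sketch below is a dependency-free approximation using plain Python lists rather than tensors; the function name `balanced_assignment` is hypothetical, the round-robin order is an assumption, and the noise path is not modeled.

```python
def balanced_assignment(num_tokens, n_routed_experts, n_activated_experts,
                        skip_first_n_experts=0):
    """Round-robin expert indices with uniform weights (noise == 0 only).

    Hypothetical helper for illustration; not the library's implementation.
    """
    available = n_routed_experts - skip_first_n_experts
    indices, weights = [], []
    slot = 0  # running counter over all (token, activated-expert) slots
    for _ in range(num_tokens):
        row = []
        for _ in range(n_activated_experts):
            # Cycle through the experts that are not skipped.
            row.append(skip_first_n_experts + slot % available)
            slot += 1
        indices.append(row)
        # Each selected expert gets the same share of the token.
        weights.append([1.0 / n_activated_experts] * n_activated_experts)
    return indices, weights


indices, weights = balanced_assignment(
    num_tokens=4, n_routed_experts=4, n_activated_experts=2
)
# Count how many slots each expert received.
load = [sum(row.count(e) for row in indices) for e in range(4)]
```

With 4 tokens, 4 experts, and 2 activated experts per token, every expert ends up with the same number of assignments, which is the property the class is meant to provide for benchmarking.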

bias_update_factor = 0.0
n_activated_experts = config.n_activated_experts
n_routed_experts = config.n_routed_experts

nemo_automodel.components.moe.layers.FakeBalancedGate.forward(
x: torch.Tensor,
token_mask: torch.Tensor,
cp_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh]
) -> tuple[torch.Tensor, torch.Tensor, typing.Optional[torch.Tensor]]

Forward pass for the gating mechanism.

Parameters:

x
torch.Tensor

Input tensor.

token_mask
torch.Tensor

Boolean mask indicating valid tokens.

cp_mesh
Optional[DeviceMesh]

Device mesh for context parallel computation.

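Returns: tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]

Routing weights for the selected experts, the indices of the selected experts, and an optional auxiliary loss for load balancing.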

nemo_automodel.components.moe.layers.FakeBalancedGate.init_weights(
buffer_device: torch.device,
init_std: float = 0.02
) -> None

nemo_automodel.components.moe.layers.FakeBalancedGate.update_bias() -> None

class nemo_automodel.components.moe.layers.Gate(
config: nemo_automodel.components.moe.config.MoEConfig,
gate_precision: torch.dtype | None = None
)

Bases: Module

Gating mechanism for routing inputs in a mixture-of-experts (MoE) model.

_cumulative_expert_load: Optional[Tensor] = None
_last_aux_loss: Optional[Tensor] = None
_last_expert_load: Optional[Tensor] = None
_track_load_balance: bool = False
aux_loss_coeff = config.aux_loss_coeff
bias
bias_update_factor = config.gate_bias_update_factor
dim = config.dim
n_experts = config.n_routed_experts
n_groups = config.n_expert_groups
norm_topk_prob = config.norm_topk_prob
route_scale = config.route_scale
score_func = config.score_func
softmax_before_topk = config.softmax_before_topk
topk = config.n_activated_experts
topk_groups = config.n_limited_groups
train_gate = config.train_gate
weight

nemo_automodel.components.moe.layers.Gate._compute_aux_loss(
original_scores: torch.Tensor,
expert_load: torch.Tensor,
token_mask: torch.Tensor,
cp_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh]
) -> torch.Tensor

Computes the auxiliary loss for load balancing.

Warning: This assumes a batch size of 1. If the batch size is greater than 1, the auxiliary loss is computed across multiple sequences.

Parameters:

original_scores
torch.Tensor

Original scores from the gating mechanism. Shape is [num_tokens, num_experts].

expert_load
torch.Tensor

Load of each expert (number of tokens routed to each expert). Shape is [num_experts].

token_mask
torch.Tensor

Boolean mask indicating valid tokens. Shape is [num_tokens].

cp_mesh
Optional[DeviceMesh]

Device mesh for context parallel computation.

Returns: torch.Tensor

torch.Tensor: Auxiliary loss for load balancing. Shape is [].
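To make the inputs concrete, here is a minimal sketch of one common load-balancing loss: the product of each expert's share of routed tokens and its mean gate score, summed over experts. This is a widely used formulation and is not confirmed to be the exact formula in this method. The function name `aux_loss_sketch` and the `topk` and `coeff` arguments are hypothetical, and context-parallel reduction over `cp_mesh` is not modeled.

```python
def aux_loss_sketch(original_scores, expert_load, token_mask, topk, coeff=1.0):
    """Load-balancing loss: coeff * n_experts * sum_i(f_i * P_i).

    f_i: fraction of routing slots assigned to expert i.
    P_i: mean gate score of expert i over valid tokens.
    Assumed formulation for illustration only.
    """
    n_experts = len(expert_load)
    # Keep only the score rows of valid (unmasked) tokens.
    valid = [row for row, keep in zip(original_scores, token_mask) if keep]
    num_tokens = len(valid)
    if num_tokens == 0:
        return 0.0
    frac = [load / (topk * num_tokens) for load in expert_load]
    mean_prob = [
        sum(row[e] for row in valid) / num_tokens for e in range(n_experts)
    ]
    return coeff * n_experts * sum(f * p for f, p in zip(frac, mean_prob))


balanced = aux_loss_sketch(
    [[0.5, 0.5], [0.5, 0.5]], expert_load=[1, 1],
    token_mask=[True, True], topk=1,
)
skewed = aux_loss_sketch(
    [[0.9, 0.1], [0.9, 0.1]], expert_load=[2, 0],
    token_mask=[True, True], topk=1,
)
```

Under this formulation the loss is smallest when both the load and the scores are uniform, so `skewed` comes out larger than `balanced`.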

nemo_automodel.components.moe.layers.Gate._compute_expert_load(
indices: torch.Tensor,
token_mask: torch.Tensor
) -> torch.Tensor

Computes the load of each expert based on the selected indices.

Parameters:

indices
torch.Tensor

Indices of the selected experts. Shape is [num_tokens, num_activated_experts].

token_mask
torch.Tensor

Boolean mask indicating valid tokens. Shape is [num_tokens].

Returns: torch.Tensor

torch.Tensor: Load of each expert (number of tokens routed to each expert). Shape is [num_local_experts].
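The counting described above can be sketched without tensors. The helper name `expert_load_sketch` and the explicit `n_experts` argument are hypothetical; the only behavior assumed is that masked-out tokens do not contribute to any expert's count.

```python
def expert_load_sketch(indices, token_mask, n_experts):
    """Count how many routing slots each expert receives from valid tokens."""
    load = [0] * n_experts
    for row, keep in zip(indices, token_mask):
        if not keep:
            continue  # padded or otherwise invalid token: skip it
        for expert in row:
            load[expert] += 1
    return load


# Three tokens, two activated experts each; the last token is masked out.
load = expert_load_sketch(
    indices=[[0, 1], [1, 2], [0, 3]],
    token_mask=[True, True, False],
    n_experts=4,
)
```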

nemo_automodel.components.moe.layers.Gate.forward(
x: torch.Tensor,
token_mask: torch.Tensor,
cp_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh]
) -> tuple[torch.Tensor, torch.Tensor, typing.Optional[torch.Tensor]]

Forward pass for the gating mechanism.

Parameters:

x
torch.Tensor

Input tensor.

token_mask
torch.Tensor

Boolean mask indicating valid tokens.

cp_mesh
Optional[DeviceMesh]

Device mesh for context parallel computation.

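Returns: tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]

Routing weights for the selected experts, the indices of the selected experts, and an optional auxiliary loss for load balancing.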
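For a single token, top-k routing can be sketched as follows. This is a simplified illustration that assumes softmax scoring and no expert groups or correction bias. The function name `topk_route` is hypothetical, and the `norm_topk_prob` and `route_scale` arguments borrow the attribute names listed above with assumed semantics (renormalize the selected weights, then scale them).

```python
import math


def topk_route(logits, topk, norm_topk_prob=True, route_scale=1.0):
    """Softmax over expert logits for one token, then keep the top-k experts."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Expert ids sorted by descending probability, truncated to k entries.
    indices = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:topk]
    weights = [probs[i] for i in indices]
    if norm_topk_prob:
        denom = sum(weights)
        weights = [w / denom for w in weights]
    return [w * route_scale for w in weights], indices


weights, indices = topk_route([2.0, 0.5, 1.0, -1.0], topk=2)
```

Here the two largest logits belong to experts 0 and 2, so those are selected, and their weights are renormalized to sum to 1.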

nemo_automodel.components.moe.layers.Gate.init_weights(
buffer_device: torch.device,
init_std: float = 0.02
) -> None

nemo_automodel.components.moe.layers.Gate.update_bias() -> None

Updates the correction bias used in the gate based on the popularity of experts. This function is a no-op if the gate is not trained.

To avoid routing collapse, and to promote better load balance of experts, DeepSeek-V3 uses a correction mechanism to adjust the scores of experts using a learned bias parameter. The bias parameter is updated based on the popularity of experts, i.e., the number of tokens routed to each expert. If an expert is more popular than the average, its bias term is decreased, and vice versa. This encourages the model to route tokens to less popular experts, promoting better load balance.
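The update rule described above can be sketched as a sign-based step: decrease the bias of experts loaded above the mean and increase it for those below. This is a minimal approximation under that assumption. The function name `update_bias_sketch` is hypothetical, and the actual method may accumulate load across steps or ranks in ways not shown here.

```python
def update_bias_sketch(bias, expert_load, bias_update_factor):
    """Nudge each expert's bias toward balance by a fixed step size."""
    mean_load = sum(expert_load) / len(expert_load)

    def sign(value):
        # Returns -1, 0, or 1.
        return (value > 0) - (value < 0)

    return [
        b + bias_update_factor * sign(mean_load - load)
        for b, load in zip(bias, expert_load)
    ]


# Expert 0 is over-loaded, expert 1 is under-loaded, expert 2 is average.
new_bias = update_bias_sketch(
    bias=[0.0, 0.0, 0.0], expert_load=[10, 2, 6], bias_update_factor=0.001
)
```

Because only the sign of the imbalance is used, the step size stays constant regardless of how far an expert is from the mean load.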

class nemo_automodel.components.moe.layers.MLP(
dim: int,
inter_dim: int,
backend: str,
dtype: torch.dtype = torch.bfloat16,
activation: str = 'swiglu',
bias: bool = False,
swiglu_limit: float = 0.0
)

Bases: Module

Multi-Layer Perceptron (MLP) used as a feed-forward layer.

Supports both gated activations (SwiGLU) and simple activations (ReLU²).
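The difference between the two activation families can be sketched elementwise. In the gated case two projections of the input are combined; in the simple case a single projection is activated. The sketch below assumes the conventional definitions of SwiGLU and ReLU² and operates on plain lists standing in for the projection outputs; the function names are hypothetical, and `swiglu_limit` is not modeled.

```python
import math


def silu(value):
    """SiLU (swish): x * sigmoid(x)."""
    return value / (1.0 + math.exp(-value))


def swiglu(gate_out, up_out):
    """Gated activation: silu(gate) * up, elementwise."""
    return [silu(g) * u for g, u in zip(gate_out, up_out)]


def relu2(up_out):
    """Simple activation: relu(x) squared, elementwise."""
    return [max(v, 0.0) ** 2 for v in up_out]


gated = swiglu([0.0, 1.0], [3.0, 2.0])   # needs both projections
simple = relu2([-1.0, 2.0])              # needs only one projection
```

In either case the result would then be passed through a down projection to return to the model dimension.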

down_proj
gate_proj
is_gated = is_gated_activation(activation)
swiglu_limit = float(swiglu_limit)
up_proj

nemo_automodel.components.moe.layers.MLP.forward(
x: torch.Tensor
) -> torch.Tensor

Forward pass for the MLP layer.

Parameters:

x
torch.Tensor

Input tensor.

Returns: torch.Tensor

torch.Tensor: Output tensor after MLP computation.

nemo_automodel.components.moe.layers.MLP.init_weights(
buffer_device: torch.device,
init_std: float = 0.02
) -> None

class nemo_automodel.components.moe.layers.MoE(
config: nemo_automodel.components.moe.config.MoEConfig,
backend: nemo_automodel.components.models.common.BackendConfig
)

Bases: Module

Mixture-of-Experts (MoE) module.

cp_mesh: Optional[DeviceMesh] = None
dim = config.dim
experts = GroupedExperts(config, backend=backend)
fc1_latent_proj
fc2_latent_proj
gate
n_activated_experts = config.n_activated_experts
n_routed_experts = config.n_routed_experts
shared_expert_gate
shared_experts

nemo_automodel.components.moe.layers.MoE.forward(
x: torch.Tensor,
padding_mask: typing.Optional[torch.Tensor] = None,
cp_mesh: typing.Optional[torch.distributed.device_mesh.DeviceMesh] = None
) -> tuple[torch.Tensor, typing.Optional[torch.Tensor]]

Forward pass for the MoE module.

Parameters:

x
torch.Tensor

Input tensor.

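padding_mask
Optional[torch.Tensor], defaults to None

Boolean mask indicating padding positions.

cp_mesh
Optional[DeviceMesh], defaults to None

Device mesh for context parallel computation.

Returns: tuple[torch.Tensor, Optional[torch.Tensor]]

The first element is the output tensor after expert routing and computation. The second element is an optional tensor, as given in the signature.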
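Conceptually, expert routing and computation combine each token's selected expert outputs using the routing weights. The sketch below illustrates that weighted sum with scalar tokens and plain Python callables as experts. The function name `moe_forward_sketch` is hypothetical, the shared expert is added unconditionally as a simplification, and grouped expert execution, latent projections, and padding are not modeled.

```python
def moe_forward_sketch(tokens, routing, experts, shared_expert=None):
    """Weighted sum of selected expert outputs for each token.

    routing: one (weights, indices) pair per token, as a gate would produce.
    """
    outputs = []
    for x, (weights, indices) in zip(tokens, routing):
        y = sum(w * experts[i](x) for w, i in zip(weights, indices))
        if shared_expert is not None:
            y += shared_expert(x)  # assumed: shared expert sees every token
        outputs.append(y)
    return outputs


# Toy experts acting on scalars.
experts = [lambda x: 2 * x, lambda x: x + 1, lambda x: -x]
routing = [([0.5, 0.5], [0, 1]), ([1.0], [2])]
out = moe_forward_sketch([1.0, 2.0], routing, experts)
```

The first token is split evenly between experts 0 and 1, while the second goes entirely to expert 2.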

nemo_automodel.components.moe.layers.MoE.init_weights(
buffer_device: torch.device,
init_std: float = 0.02
) -> None

nemo_automodel.components.moe.layers._init_weights(
module,
buffer_device: torch.device,
init_std: float = 0.02
)