core.transformer.moe.moe_layer#
Module Contents#
Classes#
| Class | Description |
|---|---|
| MoESubmodules | MoE Layer Submodule spec |
| BaseMoELayer | Base class for a mixture of experts layer. |
| MoELayer | Mixture of Experts layer. |
API#
- class core.transformer.moe.moe_layer.MoESubmodules#
MoE Layer Submodule spec
- experts: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
- shared_experts: Union[megatron.core.transformer.spec_utils.ModuleSpec, type]#
None
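A minimal construction sketch, assuming the GroupedMLP expert implementation from megatron.core.transformer.moe.experts; the concrete expert module (and any shared-expert spec) depends on the layer spec in use.

```python
# Hedged sketch: wiring routed experts into a MoESubmodules spec.
# GroupedMLP is one possible expert implementation; swap in whichever module
# your layer spec calls for.
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.moe.experts import GroupedMLP
from megatron.core.transformer.moe.moe_layer import MoESubmodules

moe_submodules = MoESubmodules(
    experts=ModuleSpec(module=GroupedMLP),  # spec for the routed experts
)
```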
- class core.transformer.moe.moe_layer.BaseMoELayer(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- layer_number: Optional[int] = None,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#
Bases: megatron.core.transformer.module.MegatronModule, abc.ABC
Base class for a mixture of experts layer.
- Parameters:
config (TransformerConfig) – Configuration object for the transformer model.
Initialization
- abstractmethod forward(hidden_states)#
Forward method for the MoE layer.
- set_layer_number(layer_number: int)#
Set the layer number for the MoE layer.
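Because forward is abstract, any concrete MoE layer must implement it. A purely illustrative, hypothetical subclass sketch:

```python
# Hypothetical subclass, for illustration only; the concrete MoELayer below is
# what Megatron-Core actually ships.
import torch
from megatron.core.transformer.moe.moe_layer import BaseMoELayer

class PassthroughMoELayer(BaseMoELayer):
    def forward(self, hidden_states: torch.Tensor):
        # A real implementation routes tokens to experts here; this stub just
        # mirrors the (output, mlp_bias) return contract.
        return hidden_states, None
```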
- class core.transformer.moe.moe_layer.MoELayer(
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- submodules: Optional[core.transformer.moe.moe_layer.MoESubmodules] = None,
- layer_number: Optional[int] = None,
- pg_collection: Optional[megatron.core.process_groups_config.ProcessGroupCollection] = None,
)#
Bases: core.transformer.moe.moe_layer.BaseMoELayer
Mixture of Experts layer.
This layer implements a Mixture of Experts model, where each token is routed to a subset of experts. This implementation supports different token dispatching strategies such as All-to-All and All-Gather.
Initialization
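A hedged construction sketch, reusing the moe_submodules spec from the MoESubmodules example above; it assumes model parallelism has already been initialized (e.g. via megatron.core.parallel_state.initialize_model_parallel) and uses illustrative config values.

```python
# Illustrative values only; the MoE-specific knobs (num_moe_experts,
# moe_router_topk, moe_token_dispatcher_type) live on TransformerConfig.
# Actually constructing the layer requires an initialized distributed /
# model-parallel environment.
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.moe.moe_layer import MoELayer

config = TransformerConfig(
    num_layers=2,
    hidden_size=1024,
    num_attention_heads=16,
    num_moe_experts=8,                     # total routed experts
    moe_router_topk=2,                     # experts selected per token
    moe_token_dispatcher_type="alltoall",  # or "allgather"
)
moe_layer = MoELayer(config=config, submodules=moe_submodules, layer_number=1)
```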
- router_and_preprocess(hidden_states: torch.Tensor)#
Compute and preprocess token routing for dispatch.
This method uses the router to determine which experts to send each token to, producing routing probabilities and a mapping. It then preprocesses the hidden states and probabilities for the token dispatcher. The original hidden states are returned as a residual connection.
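The routing decision can be pictured with a generic top-k gating example in plain PyTorch; this is an illustration, not Megatron's router implementation.

```python
# Generic top-k gating illustration: each token gets a probability over the
# experts plus the indices of its top-k experts (the routing map).
import torch

num_tokens, hidden_size, num_experts, topk = 8, 16, 4, 2
hidden_states = torch.randn(num_tokens, hidden_size)
gate_weight = torch.randn(num_experts, hidden_size)   # hypothetical router weight

logits = hidden_states @ gate_weight.t()              # [tokens, experts]
probs = torch.softmax(logits, dim=-1)
topk_probs, topk_ids = probs.topk(topk, dim=-1)       # routing probabilities and mapping
```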
- dispatch(hidden_states: torch.Tensor, probs: torch.Tensor)#
Dispatches tokens to assigned expert ranks via communication. This method performs the actual communication (e.g., All-to-All) to distribute tokens and their associated probabilities to the devices hosting their assigned experts.
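The communication itself can be sketched with the underlying torch.distributed primitive; this assumes an already-initialized expert-parallel process group (here called ep_group) and equal-sized splits, whereas the real token dispatcher also handles uneven token counts and the associated probabilities.

```python
# Bare-bones All-to-All sketch using torch.distributed; not the Megatron token
# dispatcher, which additionally exchanges routing probabilities and metadata.
import torch
import torch.distributed as dist

def all_to_all_tokens(local_tokens: torch.Tensor, ep_group) -> torch.Tensor:
    # Assumes equal-sized splits across the expert-parallel group `ep_group`.
    received = torch.empty_like(local_tokens)
    dist.all_to_all_single(received, local_tokens, group=ep_group)
    return received
```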
- shared_experts_compute(hidden_states: torch.Tensor)#
Computes the output of the shared experts.
If a shared expert is configured and not overlapped with communication, it is computed here.
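Whether a shared expert exists, and whether it is overlapped with communication, is driven by TransformerConfig; a hedged sketch with illustrative values (only the moe_shared_expert_* fields are specific to the shared expert path):

```python
# Illustrative config sketch: a non-None moe_shared_expert_intermediate_size
# enables the shared expert; moe_shared_expert_overlap=False keeps its compute
# in this step instead of overlapping it with dispatch/combine communication.
from megatron.core.transformer.transformer_config import TransformerConfig

config_with_shared_expert = TransformerConfig(
    num_layers=2,
    hidden_size=1024,
    num_attention_heads=16,
    num_moe_experts=8,
    moe_router_topk=2,
    moe_shared_expert_intermediate_size=4096,
    moe_shared_expert_overlap=False,
)
```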
- routed_experts_compute(
- hidden_states: torch.Tensor,
- probs: torch.Tensor,
- residual: torch.Tensor,
)#
Computes the output of the routed experts on the dispatched tokens.
This method first post-processes the dispatched input to get permuted tokens for each expert. It then passes the tokens through the local experts. The output from the experts is preprocessed for the combine step.
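The permutation that groups dispatched tokens by local expert can be illustrated generically in plain PyTorch (not the exact Megatron permutation utilities):

```python
# Generic illustration: sort dispatched tokens so each local expert sees a
# contiguous block, and count how many tokens each expert receives.
import torch

num_local_experts, num_tokens, hidden_size = 4, 8, 16
expert_ids = torch.randint(0, num_local_experts, (num_tokens,))
tokens = torch.randn(num_tokens, hidden_size)

order = torch.argsort(expert_ids)             # permutation by assigned expert
permuted_tokens = tokens[order]               # contiguous per-expert chunks
tokens_per_expert = torch.bincount(expert_ids, minlength=num_local_experts)
```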
- combine(
- output: torch.Tensor,
- shared_expert_output: Optional[torch.Tensor],
)#
Combines expert outputs via communication and adds shared expert output.
This method uses the token dispatcher to combine the outputs from different experts (e.g., via an All-to-All communication). It then adds the output from the shared expert if it exists.
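Conceptually, after the reverse communication each token's expert outputs are weighted by its routing probabilities and summed, with the shared expert output added on top; a generic illustration (not the token dispatcher's actual code path):

```python
# Generic combine illustration: weight each token's top-k expert outputs by
# its routing probabilities, sum them, and add the shared expert output.
import torch

num_tokens, hidden_size, topk = 8, 16, 2
expert_outputs = torch.randn(num_tokens, topk, hidden_size)  # per-token top-k outputs
topk_probs = torch.softmax(torch.randn(num_tokens, topk), dim=-1)

combined = (expert_outputs * topk_probs.unsqueeze(-1)).sum(dim=1)
shared_expert_output = torch.randn(num_tokens, hidden_size)  # omitted if no shared expert
combined = combined + shared_expert_output
```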
- forward(hidden_states: torch.Tensor)#
Forward pass for the MoE layer.
The forward pass comprises four main steps:
1. Routing & Preprocessing: Route tokens to their assigned experts and prepare for dispatch.
2. Dispatch: Tokens are sent to the expert devices using communication collectives.
3. Expert Computation: Experts process the dispatched tokens.
4. Combine: The outputs from the experts are combined and returned.
- Parameters:
hidden_states (torch.Tensor) – The input tensor to the MoE layer.
- Returns:
A tuple containing the output tensor and the MLP bias, if any.
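A hedged usage sketch, assuming the moe_layer and config built in the construction sketch above, an initialized distributed environment, and Megatron's [sequence, batch, hidden] activation layout:

```python
# The input follows the [s, b, h] convention; the output keeps the same shape.
import torch

seq_len, micro_batch_size = 128, 2
hidden_states = torch.randn(
    seq_len, micro_batch_size, config.hidden_size, device="cuda", dtype=torch.bfloat16
)
output, mlp_bias = moe_layer(hidden_states)  # mlp_bias is None unless a bias is produced
assert output.shape == hidden_states.shape
```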
- backward_dw()#
Compute weight gradients for experts and shared experts.
- set_for_recompute_pre_mlp_layernorm()#
Configure the MoE layer to recompute pre_mlp_layernorm. Only needed for FP8/FP4.