core.transformer.moe.experts#
Module Contents#
Classes#

| Class | Description |
|---|---|
| GroupedLinearFc1Interface | Interface for linear_fc1 module in TEGroupedMLP. |
| GroupedLinearFc1Builder | Protocol describing how to build a linear_fc1 layer in TEGroupedMLP. |
| GroupedLinearFc2Interface | Protocol for linear_fc2 module in TEGroupedMLP. |
| GroupedLinearFc2Builder | Protocol describing how to build a linear_fc2 layer in TEGroupedMLP. |
| TEGroupedMLPSubmodules | The dataclass for ModuleSpecs of TEGroupedMLP submodules including linear fc1, activation function, linear fc2. |
| TEGroupedMLP | An efficient implementation of the Experts layer using TE’s GroupedLinear. |
| InferenceGroupedMLP | Inference-optimized GroupedMLP with GPU-resident offsets. |
| SequentialMLP | An implementation of the Experts layer using a sequence of MLP layers. |
Data#
API#
- core.transformer.moe.experts.logger#
‘getLogger(…)’
- class core.transformer.moe.experts.GroupedLinearFc1Interface#
Bases: typing.Protocol

Interface for linear_fc1 module in TEGroupedMLP.
- forward(
- permuted_local_hidden_states: torch.Tensor,
- tokens_per_expert: list[int],
- /,
)#
Forward method for linear_fc1 module.
- backward_dw() None#
Backward method for linear_fc1 module.
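Since the interface is a `typing.Protocol`, any module whose `forward` and `backward_dw` methods match is accepted structurally, with no inheritance required. A minimal, torch-free sketch of this pattern (a plain list stands in for `torch.Tensor`; the concrete class name is illustrative):

```python
from typing import Protocol, runtime_checkable

# Stand-in for torch.Tensor so the sketch runs without PyTorch.
Tensor = list

@runtime_checkable
class GroupedLinearFc1Interface(Protocol):
    def forward(self, permuted_local_hidden_states: Tensor,
                tokens_per_expert: list[int], /) -> Tensor: ...
    def backward_dw(self) -> None: ...

class MyGroupedFc1:
    """Satisfies the protocol structurally, without inheriting from it."""
    def forward(self, permuted_local_hidden_states, tokens_per_expert, /):
        # Identity transform: a real implementation would run the grouped GEMM.
        return permuted_local_hidden_states
    def backward_dw(self):
        pass

print(isinstance(MyGroupedFc1(), GroupedLinearFc1Interface))  # True
```

Note that `runtime_checkable` only verifies that the named methods exist, not their signatures; the static type checker is what enforces the full protocol.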
- class core.transformer.moe.experts.GroupedLinearFc1Builder#
Bases: typing.Protocol

Protocol describing how to build a linear_fc1 layer in TEGroupedMLP.
- __call__(
- num_local_experts: int,
- input_size: int,
- output_size: int,
- /,
- *,
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- init_method: collections.abc.Callable[[torch.Tensor], None],
- bias: bool,
- skip_bias_add: bool,
- is_expert: bool,
- tp_comm_buffer_name: str | None,
- pg_collection: megatron.core.transformer.moe.moe_utils.ProcessGroupCollection | None,
)#
Builds a linear_fc1 layer for TEGroupedMLP.
- class core.transformer.moe.experts.GroupedLinearFc2Interface#
Bases: typing.Protocol

Protocol for linear_fc2 module in TEGroupedMLP.
- forward(
- intermediate_parallel: torch.Tensor,
- tokens_per_expert: list[int],
- /,
)#
Forward method for linear_fc2 module.
- backward_dw() None#
Backward method for linear_fc2 module.
- class core.transformer.moe.experts.GroupedLinearFc2Builder#
Bases: typing.Protocol

Protocol describing how to build a linear_fc2 layer in TEGroupedMLP.
- __call__(
- num_local_experts: int,
- input_size: int,
- output_size: int,
- /,
- *,
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- init_method: collections.abc.Callable[[torch.Tensor], None],
- bias: bool,
- skip_bias_add: bool,
- is_expert: bool,
- tp_comm_buffer_name: str | None,
- pg_collection: megatron.core.transformer.moe.moe_utils.ProcessGroupCollection | None,
)#
Builds a linear_fc2 layer for TEGroupedMLP.
- class core.transformer.moe.experts.TEGroupedMLPSubmodules#
The dataclass for ModuleSpecs of TEGroupedMLP submodules including linear fc1, activation function, linear fc2.
- linear_fc1: core.transformer.moe.experts.GroupedLinearFc1Builder#
None
- linear_fc2: core.transformer.moe.experts.GroupedLinearFc2Builder#
None
- activation_func: megatron.core.transformer.mlp.TEActivationFunctionBuilder | None#
None
Builder for an activation function module; only used if config.use_te_activation_func is True.
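A hypothetical, simplified analog of this dataclass illustrates the pattern: each field holds a builder callable rather than an instantiated module, so the spec can be declared once and materialized later with the runtime config. The names below are illustrative, not the Megatron API:

```python
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class GroupedMLPSubmodulesSketch:
    """Toy analog of TEGroupedMLPSubmodules: fields are builders, not modules."""
    linear_fc1: Callable[..., object]
    linear_fc2: Callable[..., object]
    # Optional, mirroring activation_func: only used when the config enables
    # a TE activation-function module.
    activation_func: Optional[Callable[..., object]] = None

# Declare the spec with stand-in builders; a real spec would pass the
# GroupedLinear builder callables.
spec = GroupedMLPSubmodulesSketch(
    linear_fc1=lambda **kw: ("fc1", kw),
    linear_fc2=lambda **kw: ("fc2", kw),
)
print(spec.activation_func is None)  # True
```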
- class core.transformer.moe.experts.TEGroupedMLP(
- num_local_experts: int,
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- submodules: core.transformer.moe.experts.TEGroupedMLPSubmodules,
- pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
)#

Bases: megatron.core.transformer.module.MegatronModule

An efficient implementation of the Experts layer using TE’s GroupedLinear.
Executes multiple experts in parallel to maximize computational efficiency.
Initialization
- static _apply_bias(
- intermediate_parallel,
- bias_parallel,
- tokens_per_expert,
- permuted_probs,
- bias_act_func,
)#
Applies bias and activation function to the output of linear_fc1 via bias_act_func(intermediate_parallel, bias_parallel, permuted_probs).
- forward(
- permuted_local_hidden_states: torch.Tensor,
- tokens_per_expert: torch.Tensor,
- permuted_probs: torch.Tensor,
)#
Forward of TEGroupedMLP.
- Parameters:
permuted_local_hidden_states (torch.Tensor) – The permuted input hidden states of the local experts.
tokens_per_expert (torch.Tensor) – The number of tokens per expert.
permuted_probs (torch.Tensor) – The permuted probs of each token produced by the router.
- Returns:
The output of the local experts.
- Return type:
output (torch.Tensor)
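The grouped execution can be sketched in plain Python: the permuted token stream is split into contiguous per-expert chunks according to `tokens_per_expert`, and each expert transforms only its own chunk. Scalar multipliers stand in for the expert weight matrices, and lists for tensors; this is a semantic sketch, not the TE kernel:

```python
from itertools import accumulate

def grouped_mlp_forward_sketch(permuted_tokens, tokens_per_expert, expert_scales):
    """Toy analog of the grouped GEMM: expert e processes the contiguous
    slice [offsets[e], offsets[e+1]) of the permuted token stream."""
    offsets = [0] + list(accumulate(tokens_per_expert))
    out = []
    for e, scale in enumerate(expert_scales):
        chunk = permuted_tokens[offsets[e]:offsets[e + 1]]
        out.extend(x * scale for x in chunk)  # stand-in for the expert MLP
    return out

# 3 tokens routed to expert 0, 1 token routed to expert 1
print(grouped_mlp_forward_sketch([1.0, 2.0, 3.0, 4.0], [3, 1], [10, 100]))
# [10.0, 20.0, 30.0, 400.0]
```

The efficiency of the real implementation comes from launching all expert GEMMs as one grouped kernel instead of this Python loop.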
- sharded_state_dict(
- prefix: str = '',
- sharded_offsets: tuple = (),
- metadata: Optional[dict] = None,
)#
Maps local experts to global experts. The sharded state dict is interchangeable with SequentialMLP’s.
- backward_dw()#
Performs backward pass for weight gradients in TEGroupedMLP.
This method executes the backward pass for weight gradients by calling backward_dw() on the linear layers in reverse order (fc2 followed by fc1). If an error occurs during execution, it is caught and re-raised with a descriptive message.
- class core.transformer.moe.experts.InferenceGroupedMLP(
- num_local_experts: int,
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- submodules: megatron.core.transformer.mlp.MLPSubmodules,
- pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
)#

Bases: core.transformer.moe.experts.TEGroupedMLP

Inference-optimized GroupedMLP with GPU-resident offsets.
Inherits from TEGroupedMLP to reuse weight initialization and checkpoint compatibility. Supports three forward paths:
Training: delegates to parent TEGroupedMLP
Inference + CUDA graphed: FlashInfer cutlass_fused_moe (fused permute + GEMM)
Inference + eager: torch._grouped_mm with GPU-resident cumsum offsets
Initialization
- _resolve_flashinfer_activation_type()#
Map megatron activation config to FlashInfer ActivationType.
- set_inference_cuda_graphed_iteration()#
Enable CUDA-graphed iteration mode.
- unset_inference_cuda_graphed_iteration()#
Disable CUDA-graphed iteration mode.
- _build_concatenated_weights()#
Create big contiguous weight tensors with per-expert views for checkpoint compatibility.
Creates _fc1_weight and _fc2_weight as contiguous tensors of shape [num_experts, out_features, in_features]. Replaces TE’s individual weight{i} parameters with views into these tensors.
This allows:
load_state_dict to load into weight{i} views -> writes into big tensor
forward() to use big tensor directly with torch._grouped_mm or FlashInfer
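The view-into-one-buffer idea can be sketched with stdlib types: a `bytearray` plays the role of the contiguous weight tensor, and `memoryview` slices play the per-expert `weight{i}` views. Writing through a view mutates the shared buffer, which is what makes checkpoint loading into the views equivalent to writing the big tensor:

```python
# Sketch of "big tensor + per-expert views" (bytearray/memoryview stand in
# for a contiguous torch tensor and its per-expert slices).
num_experts, bytes_per_expert = 4, 8
big = bytearray(num_experts * bytes_per_expert)  # contiguous storage
views = [
    memoryview(big)[i * bytes_per_expert:(i + 1) * bytes_per_expert]
    for i in range(num_experts)
]  # per-expert "weight{i}" views into the same storage

# Loading a checkpoint shard into one expert's view...
views[2][:] = b"\x01" * bytes_per_expert
# ...writes through to the big buffer that forward() would consume directly.
print(big[2 * bytes_per_expert])  # 1
```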
- _flashinfer_forward(hidden_states, routing_map, probs)#
FlashInfer fused MoE kernel for CUDA-graphed inference iterations.
- _torch_grouped_mm_forward(
- permuted_local_hidden_states,
- tokens_per_expert,
- permuted_probs,
)#
- forward(
- permuted_local_hidden_states: torch.Tensor,
- tokens_per_expert: Optional[torch.Tensor],
- permuted_probs: torch.Tensor,
- routing_map: Optional[torch.Tensor] = None,
)#
Forward pass with three modes:
Training: delegates to parent TEGroupedMLP.
Inference + CUDA graphed: FlashInfer cutlass_fused_moe. tokens_per_expert is not used in this path; the FlashInfer kernel operates directly on routing_map.
Inference + eager: torch._grouped_mm with GPU-resident cumsum offsets.
- Parameters:
permuted_local_hidden_states – [num_tokens, hidden_size] input hidden states.
tokens_per_expert – [num_experts] number of tokens routed to each expert. None when using the CUDA-graphed FlashInfer path.
permuted_probs – [num_tokens, topk] routing probabilities.
routing_map – [num_tokens, topk] token-to-expert assignment indices. Required for the FlashInfer CUDA-graphed path, None otherwise.
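For the eager path, the grouped matmul needs cumulative per-expert end offsets derived from `tokens_per_expert`. In the real code this is a GPU-resident cumsum (so no host synchronization is needed), but the arithmetic itself is just a prefix sum, sketched here with the stdlib:

```python
from itertools import accumulate

def grouped_mm_offsets(tokens_per_expert):
    """End offset of each expert's contiguous token slice: expert e owns
    rows [offsets[e-1], offsets[e]) of the permuted input (offsets[-1]=0)."""
    return list(accumulate(tokens_per_expert))

print(grouped_mm_offsets([3, 0, 5, 2]))  # [3, 3, 8, 10]
```

Note how an expert with zero routed tokens yields a repeated offset, i.e. an empty slice, which the grouped kernel simply skips.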
- class core.transformer.moe.experts.SequentialMLP(
- num_local_experts,
- config: megatron.core.transformer.transformer_config.TransformerConfig,
- submodules: megatron.core.transformer.mlp.MLPSubmodules,
- pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
)#

Bases: megatron.core.transformer.module.MegatronModule

An implementation of the Experts layer using a sequence of MLP layers.
This class executes each expert sequentially.
Initialization
- _pad_tensor_for_quantization(hidden, probs)#
Pads the tensor shapes to multiples of 16/32.
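The number of padding rows needed to reach the next multiple of 16 or 32 is simple modular arithmetic; a sketch of just that row-count calculation (the actual method presumably pads both `hidden` and `probs` with this count):

```python
def pad_to_multiple(n: int, multiple: int) -> int:
    """Rows of padding required so that n becomes a multiple of `multiple`
    (16 or 32 for the quantized-GEMM alignment mentioned above)."""
    return (-n) % multiple

print(pad_to_multiple(10, 16))  # 6  -> padded length 16
print(pad_to_multiple(32, 16))  # 0  -> already aligned
```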
- forward(
- permuted_local_hidden_states: torch.Tensor,
- tokens_per_expert: torch.Tensor,
- permuted_probs: torch.Tensor,
)#
Forward step of the SequentialMLP.
- backward_dw()#
Backward pass for weight gradients in SequentialMLP.
- sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#
Maps local experts to global experts.
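Assuming the usual expert-parallel layout, where rank r owns the contiguous block of experts [r·num_local, (r+1)·num_local), the local-to-global index mapping used when sharding the state dict reduces to one multiply-add. This is a hypothetical sketch of that arithmetic, not the exact Megatron code:

```python
def local_to_global_expert_index(ep_rank: int, num_local_experts: int,
                                 local_idx: int) -> int:
    """Global index of a local expert under a contiguous block layout
    (assumption: expert-parallel rank r owns experts r*num_local..)."""
    return ep_rank * num_local_experts + local_idx

# rank 2 with 4 local experts: local expert 1 is global expert 9
print(local_to_global_expert_index(2, 4, 1))  # 9
```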