core.transformer.moe.experts

Module Contents

Classes

GroupedLinearFc1Interface

Interface for linear_fc1 module in TEGroupedMLP.

GroupedLinearFc1Builder

Protocol describing how to build a linear_fc1 layer in TEGroupedMLP.

GroupedLinearFc2Interface

Protocol for linear_fc2 module in TEGroupedMLP.

GroupedLinearFc2Builder

Protocol describing how to build a linear_fc2 layer in TEGroupedMLP.

TEGroupedMLPSubmodules

Dataclass holding the ModuleSpecs for the TEGroupedMLP submodules: linear_fc1, the activation function, and linear_fc2.

TEGroupedMLP

An efficient implementation of the Experts layer using TE’s GroupedLinear.

InferenceGroupedMLP

Inference-optimized GroupedMLP with GPU-resident offsets.

SequentialMLP

An implementation of the Experts layer using a sequence of MLP layers.

Data

API

core.transformer.moe.experts.logger

‘getLogger(…)’

class core.transformer.moe.experts.GroupedLinearFc1Interface

Bases: typing.Protocol

Interface for linear_fc1 module in TEGroupedMLP.

forward(
permuted_local_hidden_states: torch.Tensor,
tokens_per_expert: list[int],
/,
) -> tuple[torch.Tensor, torch.Tensor | None]

Forward method for linear_fc1 module.

backward_dw() -> None

Backward pass for the weight gradients of the linear_fc1 module.

class core.transformer.moe.experts.GroupedLinearFc1Builder

Bases: typing.Protocol

Protocol describing how to build a linear_fc1 layer in TEGroupedMLP.

__call__(
num_local_experts: int,
input_size: int,
output_size: int,
/,
*,
config: megatron.core.transformer.transformer_config.TransformerConfig,
init_method: collections.abc.Callable[[torch.Tensor], None],
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str | None,
pg_collection: megatron.core.transformer.moe.moe_utils.ProcessGroupCollection | None,
) -> core.transformer.moe.experts.GroupedLinearFc1Interface

Builds a linear_fc1 layer for TEGroupedMLP.
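Because the interface and the builder are both `typing.Protocol` classes, any object whose methods have matching signatures satisfies them structurally, with no inheritance required. The sketch below shows this using a toy, pure-Python grouped linear layer. Everything in it is hypothetical. `Fc1Like` is a local mirror of GroupedLinearFc1Interface, not the Megatron class. Lists of rows stand in for `torch.Tensor`. The builder accepts the documented keyword-only arguments but ignores them. This is not the Transformer Engine implementation.

```python
from typing import Optional, Protocol, runtime_checkable

# Hypothetical stand-in for torch.Tensor: a list of rows.
Matrix = list[list[float]]


@runtime_checkable
class Fc1Like(Protocol):
    """Local mirror of the documented linear_fc1 interface (sketch only)."""

    def forward(
        self, permuted_local_hidden_states: Matrix, tokens_per_expert: list[int], /
    ) -> tuple[Matrix, Optional[Matrix]]: ...

    def backward_dw(self) -> None: ...


class ToyGroupedLinear:
    """Toy grouped linear layer: expert e multiplies its tokens by (e + 1) * I."""

    def __init__(self, num_local_experts: int, input_size: int, output_size: int):
        # weights[e] has shape [output_size, input_size]
        self.weights = [
            [[float(e + 1) if i == o else 0.0 for i in range(input_size)] for o in range(output_size)]
            for e in range(num_local_experts)
        ]
        self.dw_calls = 0

    def forward(self, x: Matrix, tokens_per_expert: list[int], /) -> tuple[Matrix, Optional[Matrix]]:
        out: Matrix = []
        start = 0
        # Tokens arrive grouped by expert, so expert e owns the next n rows.
        for e, n in enumerate(tokens_per_expert):
            w = self.weights[e]
            for row in x[start : start + n]:
                out.append([sum(wi * xi for wi, xi in zip(w_row, row)) for w_row in w])
            start += n
        return out, None  # this toy never returns a separate bias

    def backward_dw(self) -> None:
        # A real layer would compute weight gradients here; the toy only counts calls.
        self.dw_calls += 1


def toy_fc1_builder(
    num_local_experts: int, input_size: int, output_size: int, /, *,
    config=None, init_method=None, bias: bool = False, skip_bias_add: bool = False,
    is_expert: bool = True, tp_comm_buffer_name: Optional[str] = None, pg_collection=None,
) -> Fc1Like:
    # Keyword-only arguments match the documented call shape but are unused here.
    return ToyGroupedLinear(num_local_experts, input_size, output_size)


fc1 = toy_fc1_builder(2, 2, 2, config=None, init_method=None, bias=False, skip_bias_add=False,
                      is_expert=True, tp_comm_buffer_name=None, pg_collection=None)
out, bias_out = fc1.forward([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [1, 2])
print(isinstance(fc1, Fc1Like), out, bias_out)
# True [[1.0, 2.0], [6.0, 8.0], [10.0, 12.0]] None
```

The same pattern applies to the linear_fc2 interface and builder. Only the name of the first forward argument differs.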

class core.transformer.moe.experts.GroupedLinearFc2Interface

Bases: typing.Protocol

Protocol for linear_fc2 module in TEGroupedMLP.

forward(
intermediate_parallel: torch.Tensor,
tokens_per_expert: list[int],
/,
) -> tuple[torch.Tensor, torch.Tensor | None]

Forward method for linear_fc2 module.

backward_dw() -> None

Backward pass for the weight gradients of the linear_fc2 module.

class core.transformer.moe.experts.GroupedLinearFc2Builder

Bases: typing.Protocol

Protocol describing how to build a linear_fc2 layer in TEGroupedMLP.

__call__(
num_local_experts: int,
input_size: int,
output_size: int,
/,
*,
config: megatron.core.transformer.transformer_config.TransformerConfig,
init_method: collections.abc.Callable[[torch.Tensor], None],
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str | None,
pg_collection: megatron.core.transformer.moe.moe_utils.ProcessGroupCollection | None,
) -> core.transformer.moe.experts.GroupedLinearFc2Interface

Builds a linear_fc2 layer for TEGroupedMLP.

class core.transformer.moe.experts.TEGroupedMLPSubmodules

Dataclass holding the ModuleSpecs for the TEGroupedMLP submodules: linear_fc1, the activation function, and linear_fc2.

linear_fc1: core.transformer.moe.experts.GroupedLinearFc1Builder

None

linear_fc2: core.transformer.moe.experts.GroupedLinearFc2Builder

None

activation_func: megatron.core.transformer.mlp.TEActivationFunctionBuilder | None

None

Builder for an activation function module; only used if config.use_te_activation_func is True.

class core.transformer.moe.experts.TEGroupedMLP(
num_local_experts: int,
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: core.transformer.moe.experts.TEGroupedMLPSubmodules,
pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
)

Bases: megatron.core.transformer.module.MegatronModule

An efficient implementation of the Experts layer using TE’s GroupedLinear.

Executes multiple experts in parallel to maximize computational efficiency.

Initialization

static _apply_bias(
intermediate_parallel,
bias_parallel,
tokens_per_expert,
permuted_probs,
)

bias_act_func(intermediate_parallel, bias_parallel, permuted_probs)

Applies bias and activation function to the output of linear_fc1.
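Below is a minimal per-token sketch of what "bias plus activation" can look like on the linear_fc1 output. Two points are assumptions that the description above does not state. First, each local expert is assumed to have its own bias vector, applied to that expert's contiguous slice of tokens (as the `tokens_per_expert` argument of `_apply_bias` suggests). Second, the router probability is assumed to scale each token's activated output. The function name `toy_bias_act` and the ReLU default are hypothetical, and lists stand in for tensors.

```python
from typing import Callable, Optional


def relu(v: float) -> float:
    return v if v > 0.0 else 0.0


def toy_bias_act(
    fc1_out: list[list[float]],
    expert_biases: Optional[list[list[float]]],
    tokens_per_expert: list[int],
    probs: list[float],
    act: Callable[[float], float] = relu,
) -> list[list[float]]:
    """Hypothetical sketch: y = act(x + b_e) * p for every token row."""
    assert sum(tokens_per_expert) == len(fc1_out) == len(probs)
    out: list[list[float]] = []
    row = 0
    for e, n in enumerate(tokens_per_expert):
        for _ in range(n):
            x = fc1_out[row]
            # Expert e's bias is added to every token routed to expert e.
            b = expert_biases[e] if expert_biases is not None else [0.0] * len(x)
            # Apply the activation, then scale by this token's router probability.
            out.append([act(xi + bi) * probs[row] for xi, bi in zip(x, b)])
            row += 1
    return out


y = toy_bias_act(
    fc1_out=[[1.0, -2.0], [0.5, 0.5], [-1.0, 3.0]],
    expert_biases=[[0.0, 1.0], [1.0, -1.0]],
    tokens_per_expert=[1, 2],
    probs=[0.5, 1.0, 0.25],
)
print(y)  # [[0.5, 0.0], [1.5, 0.0], [0.0, 0.5]]
```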

forward(
permuted_local_hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
permuted_probs: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]

Forward pass of TEGroupedMLP.

Parameters:
  • permuted_local_hidden_states (torch.Tensor) – The permuted input hidden states of the local experts.

  • tokens_per_expert (torch.Tensor) – The number of tokens per expert.

  • permuted_probs (torch.Tensor) – The permuted probabilities of each token, produced by the router.

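Returns:

A tuple whose first element is the output of the local experts. Per the signature, the second element is an optional tensor and may be None.

Return type:

Tuple[torch.Tensor, Optional[torch.Tensor]]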

sharded_state_dict(
prefix: str = '',
sharded_offsets: tuple = (),
metadata: Optional[dict] = None,
) -> megatron.core.dist_checkpointing.mapping.ShardedStateDict

Maps local experts to global experts. The sharded state dict is interchangeable with SequentialMLP’s.
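The local-to-global mapping can be pictured as simple index arithmetic. The sketch assumes that experts are split into equal, contiguous blocks across expert-parallel ranks, so that local expert i on rank r is global expert r * num_local_experts + i. The helper name is hypothetical, and the sketch produces only indices, not the keys that Megatron actually writes.

```python
def local_to_global_expert_ids(ep_rank: int, num_local_experts: int) -> list[int]:
    """Hypothetical helper: global expert ids owned by one expert-parallel rank,
    assuming contiguous, equal-sized blocks of experts per rank."""
    if ep_rank < 0 or num_local_experts <= 0:
        raise ValueError("ep_rank must be >= 0 and num_local_experts > 0")
    offset = ep_rank * num_local_experts
    return [offset + i for i in range(num_local_experts)]


# 8 experts over 4 expert-parallel ranks: 2 local experts per rank.
for rank in range(4):
    print(rank, local_to_global_expert_ids(rank, 2))
# 0 [0, 1]
# 1 [2, 3]
# 2 [4, 5]
# 3 [6, 7]
```

If both classes key their checkpoint entries by global expert index under the same layout, their sharded state dicts refer to the same entries. This is one way to read the interchangeability claim above.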

backward_dw()

Performs backward pass for weight gradients in TEGroupedMLP.

This method executes the backward pass for weight gradients by calling backward_dw() on the linear layers in reverse order (fc2 followed by fc1). If an error occurs during execution, it is caught and re-raised with a descriptive message.
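The control flow described above can be sketched with stand-in layers: fc2 runs before fc1, and any error is re-raised with added context. The class names and the error message text are hypothetical. Only the ordering and the catch-and-re-raise pattern come from the description.

```python
class RecordingLayer:
    """Stand-in for a grouped linear layer that logs its backward_dw calls."""

    def __init__(self, name: str, log: list[str], fail: bool = False):
        self.name, self.log, self.fail = name, log, fail

    def backward_dw(self) -> None:
        if self.fail:
            raise RuntimeError(f"{self.name} wgrad failed")
        self.log.append(self.name)


class ToyExperts:
    def __init__(self, linear_fc1: RecordingLayer, linear_fc2: RecordingLayer):
        self.linear_fc1, self.linear_fc2 = linear_fc1, linear_fc2

    def backward_dw(self) -> None:
        try:
            # Reverse of the forward order: fc2 ran last in forward, so its wgrad goes first.
            self.linear_fc2.backward_dw()
            self.linear_fc1.backward_dw()
        except Exception as e:
            # Re-raise with a descriptive message, chaining the original exception.
            raise RuntimeError(f"Error in ToyExperts.backward_dw: {e}") from e


log: list[str] = []
ToyExperts(RecordingLayer("fc1", log), RecordingLayer("fc2", log)).backward_dw()
print(log)  # ['fc2', 'fc1']
```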

class core.transformer.moe.experts.InferenceGroupedMLP(
num_local_experts: int,
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: megatron.core.transformer.mlp.MLPSubmodules,
pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
)

Bases: core.transformer.moe.experts.TEGroupedMLP

Inference-optimized GroupedMLP with GPU-resident offsets.

Inherits from TEGroupedMLP to reuse its weight initialization and to stay checkpoint-compatible with it. Supports three forward paths:

  • Training: delegates to parent TEGroupedMLP

  • Inference + CUDA graphed: FlashInfer cutlass_fused_moe (fused permute + GEMM)

  • Inference + eager: torch._grouped_mm with GPU-resident cumsum offsets
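The eager inference path describes its per-expert group boundaries as a cumulative sum of `tokens_per_expert`. Keeping that sum on the GPU is presumably meant to avoid copying token counts back to the host. The pure-Python sketch below covers only the arithmetic and has no notion of a device. It assumes, without confirmation from this page, that each offset marks the exclusive end row of an expert's group, so that expert e owns rows offsets[e-1] up to offsets[e]. Both helper names are hypothetical.

```python
from itertools import accumulate


def group_offsets(tokens_per_expert: list[int]) -> list[int]:
    """Cumulative sum: offsets[e] is the (exclusive) end row of expert e's group."""
    return list(accumulate(tokens_per_expert))


def group_slices(tokens_per_expert: list[int]) -> list[tuple[int, int]]:
    """(start, end) row range for each expert, derived from the offsets."""
    ends = group_offsets(tokens_per_expert)
    starts = [0] + ends[:-1]
    return list(zip(starts, ends))


tpe = [3, 0, 2, 1]  # expert 1 received no tokens
print(group_offsets(tpe))  # [3, 3, 5, 6]
print(group_slices(tpe))   # [(0, 3), (3, 3), (3, 5), (5, 6)]
```

An expert with zero tokens produces an empty range (3, 3), so no special case is needed.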

Initialization

_resolve_flashinfer_activation_type()

Maps the Megatron activation config to a FlashInfer ActivationType.

set_inference_cuda_graphed_iteration()

Enable CUDA-graphed iteration mode.

unset_inference_cuda_graphed_iteration()

Disable CUDA-graphed iteration mode.

_build_concatenated_weights()

Creates large contiguous weight tensors with per-expert views, for checkpoint compatibility.

Creates _fc1_weight and _fc2_weight as contiguous tensors of shape [num_experts, out_features, in_features]. Replaces TE’s individual weight{i} parameters with views into these tensors.

This allows:

  • load_state_dict to load into the weight{i} views, which writes into the underlying contiguous tensor

  • forward() to use the contiguous tensor directly with torch._grouped_mm or FlashInfer
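The view mechanism can be reproduced with the standard library: one contiguous buffer plus per-expert `memoryview` slices into it. This is only an analogue of the torch code, which this page does not show. Here `array` and `memoryview` stand in for a [num_experts, out_features, in_features] tensor and its weight{i} views, flattened to one dimension. The helper name is hypothetical.

```python
from array import array


def build_concatenated(num_experts: int, out_features: int, in_features: int):
    """Return one contiguous buffer and a per-expert view into it."""
    per_expert = out_features * in_features
    big = array("d", [0.0] * (num_experts * per_expert))  # contiguous storage
    mv = memoryview(big)
    # Each slice shares memory with `big`; nothing is copied.
    views = [mv[e * per_expert : (e + 1) * per_expert] for e in range(num_experts)]
    return big, views


big, views = build_concatenated(num_experts=2, out_features=2, in_features=3)

# "Loading a checkpoint" into expert 1's view writes straight into the big buffer.
views[1][:] = array("d", [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
print(big.tolist())  # [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```

Writes flow in both directions. An update made through the big buffer is visible in every view, and an update made through a view is visible in the big buffer. This is the property that lets checkpoint loading and a single batched kernel share the same storage.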

_flashinfer_forward(hidden_states, routing_map, probs)

FlashInfer fused MoE kernel for CUDA-graphed inference iterations.

_torch_grouped_mm_forward(
permuted_local_hidden_states,
tokens_per_expert,
permuted_probs,
)

forward(
permuted_local_hidden_states: torch.Tensor,
tokens_per_expert: Optional[torch.Tensor],
permuted_probs: torch.Tensor,
routing_map: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]

Forward pass with three modes:

  • Training: delegates to parent TEGroupedMLP.

  • Inference + CUDA graphed: FlashInfer cutlass_fused_moe. tokens_per_expert is not used in this path; the FlashInfer kernel operates directly on routing_map.

  • Inference + eager: torch._grouped_mm with GPU-resident cumsum offsets.

Parameters:
  • permuted_local_hidden_states – [num_tokens, hidden_size] input hidden states.

  • tokens_per_expert – [num_experts] number of tokens routed to each expert. None when using the CUDA-graphed FlashInfer path.

  • permuted_probs – [num_tokens, topk] routing probabilities.

  • routing_map – [num_tokens, topk] token-to-expert assignment indices. Required for the FlashInfer CUDA-graphed path, None otherwise.
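The three-way dispatch can be summarized as a small decision function. The sketch mirrors only the branching and argument contract described above: routing_map is required on the CUDA-graphed path, and tokens_per_expert is required on the eager path. The class, the `_cuda_graphed` flag, and the returned labels are hypothetical, and no real kernels are called. The set/unset method names are reused from this page only to show where the flag would be toggled.

```python
from typing import Optional


class ToyInferenceExperts:
    """Hypothetical stand-in that only reports which path forward() would take."""

    def __init__(self) -> None:
        self.training = True
        self._cuda_graphed = False  # toggled by the set/unset methods below

    def set_inference_cuda_graphed_iteration(self) -> None:
        self._cuda_graphed = True

    def unset_inference_cuda_graphed_iteration(self) -> None:
        self._cuda_graphed = False

    def select_path(
        self, tokens_per_expert: Optional[list[int]], routing_map: Optional[list[list[int]]]
    ) -> str:
        if self.training:
            return "parent TEGroupedMLP"
        if self._cuda_graphed:
            if routing_map is None:
                raise ValueError("routing_map is required on the CUDA-graphed path")
            return "FlashInfer cutlass_fused_moe"
        if tokens_per_expert is None:
            raise ValueError("tokens_per_expert is required on the eager path")
        return "torch._grouped_mm"


m = ToyInferenceExperts()
print(m.select_path([2, 1], None))            # parent TEGroupedMLP
m.training = False
print(m.select_path([2, 1], None))            # torch._grouped_mm
m.set_inference_cuda_graphed_iteration()
print(m.select_path(None, [[0, 1], [1, 0]]))  # FlashInfer cutlass_fused_moe
```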

class core.transformer.moe.experts.SequentialMLP(
num_local_experts,
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: megatron.core.transformer.mlp.MLPSubmodules,
pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
)

Bases: megatron.core.transformer.module.MegatronModule

An implementation of the Experts layer using a sequence of MLP layers.

This class executes each expert sequentially.

Initialization

_pad_tensor_for_quantization(hidden, probs)

Pads the tensor shapes to multiples of 16 or 32 for quantization.
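Padding to a multiple of 16 or 32 comes down to rounding the row count up and appending filler rows. The sketch assumes three things that the one-line description does not state: the token (row) dimension is the one padded, both the hidden states and the probabilities are padded with zeros, and the caller strips the padded rows afterwards. Treat those details, and the helper names, as hypothetical.

```python
def round_up(n: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= n."""
    return ((n + multiple - 1) // multiple) * multiple


def pad_rows(hidden: list[list[float]], probs: list[float], multiple: int = 16):
    """Append zero rows (and zero probs) so len(hidden) is a multiple of `multiple`."""
    num_tokens = len(hidden)
    padded = round_up(num_tokens, multiple)
    width = len(hidden[0]) if hidden else 0
    extra = padded - num_tokens
    return (
        hidden + [[0.0] * width for _ in range(extra)],
        probs + [0.0] * extra,
        num_tokens,  # original length, needed to strip the padding afterwards
    )


h, p, n = pad_rows([[1.0, 2.0]] * 5, [0.5] * 5, multiple=16)
print(len(h), len(p), n)                   # 16 16 5
print(round_up(32, 32), round_up(33, 32))  # 32 64
```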

forward(
permuted_local_hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
permuted_probs: torch.Tensor,
)

Forward pass of the SequentialMLP.
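Running the experts sequentially involves three steps: split the permuted tokens by tokens_per_expert, run each expert on its own slice, and concatenate the results in the same order. In the sketch, each "expert" is a hypothetical Python callable that takes a list of rows. The real experts are MLP modules, and the real method also receives permuted_probs, which this sketch ignores.

```python
from typing import Callable

Rows = list[list[float]]


def sequential_experts(
    hidden: Rows, tokens_per_expert: list[int], experts: list[Callable[[Rows], Rows]]
) -> Rows:
    """Run each expert on its contiguous slice of tokens, one expert at a time."""
    if len(experts) != len(tokens_per_expert) or sum(tokens_per_expert) != len(hidden):
        raise ValueError("token counts do not match inputs")
    outputs: Rows = []
    start = 0
    for expert, n in zip(experts, tokens_per_expert):
        chunk = hidden[start : start + n]
        if n > 0:  # skip experts that received no tokens
            outputs.extend(expert(chunk))
        start += n
    return outputs


def scale(k: float) -> Callable[[Rows], Rows]:
    """Hypothetical 'expert' that multiplies every value by k."""
    return lambda rows: [[k * v for v in r] for r in rows]


out = sequential_experts([[1.0], [2.0], [3.0]], [2, 0, 1], [scale(10.0), scale(-1.0), scale(0.5)])
print(out)  # [[10.0], [20.0], [1.5]]
```

Compared with the grouped version, this loop launches one expert at a time. It is simpler, but it gives up the batching that TEGroupedMLP uses to run experts in parallel.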

backward_dw()

Backward pass for weight gradients in SequentialMLP.

sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)

Maps local experts to global experts.