core.transformer.moe.experts#

Module Contents#

Classes#

GroupedLinearFc1Interface

Interface for linear_fc1 module in TEGroupedMLP.

GroupedLinearFc1Builder

Protocol describing how to build a linear_fc1 layer in TEGroupedMLP.

GroupedLinearFc2Interface

Protocol for linear_fc2 module in TEGroupedMLP.

GroupedLinearFc2Builder

Protocol describing how to build a linear_fc2 layer in TEGroupedMLP.

GroupedMLPSubmodules

Dataclass holding the ModuleSpecs of the TEGroupedMLP submodules: linear_fc1, the activation function, and linear_fc2.

TEGroupedMLP

An efficient implementation of the Experts layer using TE’s GroupedLinear.

InferenceGroupedMLP

Inference-optimized GroupedMLP with GPU-resident offsets.

SequentialMLP

An implementation of the Experts layer using a sequence of MLP layers.

Data#

API#

core.transformer.moe.experts.logger#

‘getLogger(…)’

class core.transformer.moe.experts.GroupedLinearFc1Interface#

Bases: typing.Protocol

Interface for linear_fc1 module in TEGroupedMLP.

forward(
permuted_local_hidden_states: torch.Tensor,
tokens_per_expert: list[int],
/,
) → tuple[torch.Tensor, torch.Tensor | None]#

Forward method for linear_fc1 module.

backward_dw() None#

Backward method for linear_fc1 module.

class core.transformer.moe.experts.GroupedLinearFc1Builder#

Bases: typing.Protocol

Protocol describing how to build a linear_fc1 layer in TEGroupedMLP.

__call__(
num_local_experts: int,
input_size: int,
output_size: int,
/,
*,
config: megatron.core.transformer.transformer_config.TransformerConfig,
init_method: collections.abc.Callable[[torch.Tensor], None],
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str | None,
pg_collection: megatron.core.transformer.moe.moe_utils.ProcessGroupCollection | None,
) → core.transformer.moe.experts.GroupedLinearFc1Interface#

Builds a linear_fc1 layer for TEGroupedMLP.
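The builder classes above are `typing.Protocol`s, so they are satisfied structurally: any callable with a matching signature works, with no inheritance required. A minimal stdlib sketch (the trimmed keyword set, `DummyLayer`, and `build_dummy` are hypothetical stand-ins, not the real Megatron signature):

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class LinearBuilder(Protocol):
    # Simplified stand-in for GroupedLinearFc1Builder: positional-only
    # sizes, keyword-only options, returning a layer object.
    def __call__(self, num_local_experts: int, input_size: int,
                 output_size: int, /, *, bias: bool) -> object: ...


class DummyLayer:
    def __init__(self, shape):
        self.shape = shape


def build_dummy(num_local_experts, input_size, output_size, /, *, bias):
    # A plain function satisfies the Protocol structurally -- it never
    # subclasses LinearBuilder.
    return DummyLayer((num_local_experts, output_size, input_size))


builder: LinearBuilder = build_dummy
layer = builder(4, 1024, 4096, bias=False)
print(layer.shape)  # -> (4, 4096, 1024)
```

This is why `GroupedMLPSubmodules` can carry arbitrary builder callables: type checkers verify the call shape, not the class hierarchy.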

class core.transformer.moe.experts.GroupedLinearFc2Interface#

Bases: typing.Protocol

Protocol for linear_fc2 module in TEGroupedMLP.

forward(
intermediate_parallel: torch.Tensor,
tokens_per_expert: list[int],
/,
) → tuple[torch.Tensor, torch.Tensor | None]#

Forward method for linear_fc2 module.

backward_dw() None#

Backward method for linear_fc2 module.

class core.transformer.moe.experts.GroupedLinearFc2Builder#

Bases: typing.Protocol

Protocol describing how to build a linear_fc2 layer in TEGroupedMLP.

__call__(
num_local_experts: int,
input_size: int,
output_size: int,
/,
*,
config: megatron.core.transformer.transformer_config.TransformerConfig,
init_method: collections.abc.Callable[[torch.Tensor], None],
bias: bool,
skip_bias_add: bool,
is_expert: bool,
tp_comm_buffer_name: str | None,
pg_collection: megatron.core.transformer.moe.moe_utils.ProcessGroupCollection | None,
) → core.transformer.moe.experts.GroupedLinearFc2Interface#

Builds a linear_fc2 layer for TEGroupedMLP.

class core.transformer.moe.experts.GroupedMLPSubmodules#

Dataclass holding the ModuleSpecs of the TEGroupedMLP submodules: linear_fc1, the activation function, and linear_fc2.

linear_fc1: core.transformer.moe.experts.GroupedLinearFc1Builder#

None

linear_fc2: core.transformer.moe.experts.GroupedLinearFc2Builder#

None

activation_func: megatron.core.transformer.mlp.TEActivationFunctionBuilder | None#

None

Builder for an activation function module; only used if config.use_te_activation_func is True.

class core.transformer.moe.experts.TEGroupedMLP(
num_local_experts: int,
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: core.transformer.moe.experts.GroupedMLPSubmodules,
pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
)#

Bases: megatron.core.transformer.module.MegatronModule

An efficient implementation of the Experts layer using TE’s GroupedLinear.

Executes multiple experts in parallel to maximize computational efficiency.

Initialization

static _apply_bias(
intermediate_parallel,
bias_parallel,
tokens_per_expert,
permuted_probs,
)#
bias_act_func(intermediate_parallel, bias_parallel, permuted_probs)#

Applies bias and activation function to the output of linear_fc1.

forward(
permuted_local_hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
permuted_probs: torch.Tensor,
) → Tuple[torch.Tensor, Optional[torch.Tensor]]#

Forward pass of TEGroupedMLP.

Parameters:
  • permuted_local_hidden_states (torch.Tensor) – The permuted input hidden states of the local experts.

  • tokens_per_expert (torch.Tensor) – The number of tokens per expert.

  • permuted_probs (torch.Tensor) – The permuted probabilities of each token produced by the router.

Returns:

The output of the local experts.

Return type:

output (torch.Tensor)

sharded_state_dict(
prefix: str = '',
sharded_offsets: tuple = (),
metadata: Optional[dict] = None,
) → megatron.core.dist_checkpointing.mapping.ShardedStateDict#

Maps local experts to global experts. The sharded state dict is interchangeable with SequentialMLP’s.

backward_dw()#

Performs backward pass for weight gradients in TEGroupedMLP.

This method executes the backward pass for weight gradients by calling backward_dw() on the linear layers in reverse order (fc2 followed by fc1). If an error occurs during execution, it is caught and re-raised with a descriptive message.
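The reverse-order weight-gradient pass described above can be sketched in plain Python (`_Layer` and the free-standing `backward_dw` helper are illustrative stand-ins, not Megatron code):

```python
class _Layer:
    """Toy layer that records when its weight-gradient pass runs."""

    def __init__(self, name, log):
        self.name, self.log = name, log

    def backward_dw(self):
        self.log.append(self.name)


def backward_dw(fc1, fc2):
    # Weight-gradient passes run in reverse order of the forward pass:
    # fc2 first, then fc1; failures are re-raised with context.
    try:
        fc2.backward_dw()
        fc1.backward_dw()
    except Exception as e:
        raise RuntimeError(f"backward_dw failed: {e}") from e


log = []
backward_dw(_Layer("fc1", log), _Layer("fc2", log))
print(log)  # -> ['fc2', 'fc1']
```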

class core.transformer.moe.experts.InferenceGroupedMLP(
num_local_experts: int,
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: core.transformer.moe.experts.GroupedMLPSubmodules,
pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
)#

Bases: core.transformer.moe.experts.TEGroupedMLP

Inference-optimized GroupedMLP with GPU-resident offsets.

Inherits from TEGroupedMLP to reuse weight initialization and checkpoint compatibility. Supports three forward paths:

  • Training: delegates to parent TEGroupedMLP

  • Inference + CUDA graphed: FlashInfer cutlass_fused_moe (fused permute + GEMM)

  • Inference + eager: torch.nn.functional.grouped_mm with GPU-resident cumsum offsets

Initialization

_resolve_flashinfer_activation_type()#

Maps the Megatron activation config to a FlashInfer ActivationType.

_resolve_mcore_activation_type()#

Maps the Megatron activation config to an mcore_fused_moe ActivationType.

set_inference_cuda_graphed_iteration()#

Enable CUDA-graphed iteration mode.

unset_inference_cuda_graphed_iteration()#

Disable CUDA-graphed iteration mode.

_build_concatenated_mxfp8_weights()#

Build stacked MXFP8 weight tensors from per-expert MXFP8Tensor attributes.

After quantize_model_to_mxfp8, each per-expert weight (weight0, weight1, …) has been replaced with an MXFP8Tensor. This method stacks their data and scales into _fc1_weight / _fc2_weight for scaled_grouped_mm.

Note: this creates a contiguous copy since per-expert MXFP8Tensor attributes are not contiguous across experts. This is a one-time cost at first forward.

Unlike _build_concatenated_weights, this does not create nn.Parameter views back into the buffer — MXFP8 weights are not nn.Parameters (they are plain MXFP8Tensor attributes set by quantize_model_to_mxfp8). This path is only intended for non-colocated inference.

_build_concatenated_weights()#

Create big contiguous weight tensors that share storage with TE’s per-expert parameters.

Creates _fc1_weight and _fc2_weight as contiguous tensors of shape [num_experts, out_features, in_features]. Instead of replacing TE’s parameters (which breaks TE’s internal bookkeeping), we redirect each parameter’s .data to be a view into the contiguous buffer. The nn.Parameter objects themselves remain untouched in TE’s module, preserving FP8 scaling state, etc.

This allows:

  • TE’s forward to work correctly (same Parameter objects, same internal state)

  • Training updates to flow through (param.data is a view into the big tensor)

  • torch.nn.functional.grouped_mm / FlashInfer to use the big tensor directly
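The storage-sharing trick above can be illustrated with stdlib `memoryview`s in place of tensor views (byte buffers stand in for weights; the sizes are arbitrary):

```python
# One contiguous byte buffer standing in for the stacked weight tensor;
# per-expert "parameters" become views into it, so a write through any
# expert's view is immediately visible in the concatenated buffer.
num_experts, weight_bytes = 3, 4
buf = bytearray(num_experts * weight_bytes)
views = [
    memoryview(buf)[i * weight_bytes:(i + 1) * weight_bytes]
    for i in range(num_experts)
]

# Simulate a training update flowing through expert 1's view.
views[1][:] = b"\x07" * weight_bytes
print(buf[weight_bytes:2 * weight_bytes])  # -> bytearray(b'\x07\x07\x07\x07')
```

The same invariant holds for `param.data` redirected into `_fc1_weight` / `_fc2_weight`: the optimizer updates the view, and grouped GEMM kernels read the contiguous buffer, with no copies after construction.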

_flashinfer_forward(hidden_states, routing_map, probs)#

FlashInfer fused MoE kernel for CUDA-graphed inference iterations.

_mcore_fused_moe_forward(
hidden_states,
probs,
routing_map=None,
tokens_per_expert=None,
skip_permute=False,
)#

Torch grouped_mm fused MoE forward via mcore_fused_moe.

forward(
permuted_local_hidden_states: torch.Tensor,
tokens_per_expert: Optional[torch.Tensor],
permuted_probs: torch.Tensor,
routing_map: Optional[torch.Tensor] = None,
) → Tuple[torch.Tensor, Optional[torch.Tensor]]#

Forward pass with three modes:

  • Training: delegates to parent TEGroupedMLP.

  • Inference + CUDA graphed: FlashInfer cutlass_fused_moe. tokens_per_expert is not used in this path; the FlashInfer kernel operates directly on routing_map.

  • Inference + eager: torch.nn.functional.grouped_mm with GPU-resident cumsum offsets.

Parameters:
  • permuted_local_hidden_states – [num_tokens, hidden_size] input hidden states.

  • tokens_per_expert – [num_experts] number of tokens routed to each expert. None when using the CUDA-graphed FlashInfer path.

  • permuted_probs – [num_tokens, topk] routing probabilities.

  • routing_map – [num_tokens, topk] token-to-expert assignment indices. Required for the FlashInfer CUDA-graphed path, None otherwise.
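The "GPU-resident cumsum offsets" used by the eager grouped-GEMM path reduce to a prefix sum over `tokens_per_expert`. A stdlib sketch of that bookkeeping (`group_offsets` is a hypothetical helper; the real path computes the cumsum with torch on device):

```python
from itertools import accumulate


def group_offsets(tokens_per_expert):
    # Exclusive-start / inclusive-end offsets delimiting each expert's
    # contiguous slice of the permuted token buffer -- the shape of
    # input a grouped matmul consumes to know where each group begins.
    ends = list(accumulate(tokens_per_expert))
    starts = [0] + ends[:-1]
    return list(zip(starts, ends))


# 3 local experts receiving 2, 0, and 3 routed tokens: expert 1's
# empty slice still gets a (start, end) pair with start == end.
print(group_offsets([2, 0, 3]))  # -> [(0, 2), (2, 2), (2, 5)]
```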

class core.transformer.moe.experts.SequentialMLP(
num_local_experts,
config: megatron.core.transformer.transformer_config.TransformerConfig,
submodules: megatron.core.transformer.mlp.MLPSubmodules,
pg_collection: Optional[megatron.core.transformer.moe.moe_utils.ProcessGroupCollection] = None,
)#

Bases: megatron.core.transformer.module.MegatronModule

An implementation of the Experts layer using a sequence of MLP layers.

This class executes each expert sequentially.
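The sequential execution pattern can be sketched with plain Python callables standing in for the per-expert MLPs (`sequential_experts` and the toy experts are illustrative, not the real forward):

```python
def sequential_experts(tokens, tokens_per_expert, experts):
    # Walk the permuted token list expert by expert, applying each
    # expert's MLP to its contiguous slice of tokens.
    out, start = [], 0
    for n, expert in zip(tokens_per_expert, experts):
        out.extend(expert(t) for t in tokens[start:start + n])
        start += n
    return out


experts = [lambda t: t * 10, lambda t: t + 1]  # toy per-expert "MLPs"
print(sequential_experts([1, 2, 3], [2, 1], experts))  # -> [10, 20, 4]
```

This is the trade-off against TEGroupedMLP: one small matmul per expert in a Python loop, instead of a single grouped GEMM over all local experts.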

Initialization

_pad_tensor_for_quantization(hidden, probs)#

Pads tensor shapes to multiples of 16/32 for quantization.
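The round-up arithmetic behind this padding is a one-liner; a small sketch (which granularity applies to which quantization format is an assumption here, not stated by this API):

```python
def pad_to_multiple(n: int, multiple: int) -> int:
    # Ceiling-divide then scale back up: rounds n up to the nearest
    # multiple (e.g. 16 or 32, as required by quantization kernels).
    return -(-n // multiple) * multiple


print(pad_to_multiple(100, 16), pad_to_multiple(100, 32))  # -> 112 128
```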

forward(
permuted_local_hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
permuted_probs: torch.Tensor,
)#

Forward step of the SequentialMLP.

backward_dw()#

Backward pass for weight gradients in SequentialMLP.

sharded_state_dict(prefix='', sharded_offsets=(), metadata=None)#

Maps local experts to global experts.