core.inference.moe.fused_moe#

Fused MoE: permute -> FC1 -> activation -> FC2 -> unpermute.

Supports BF16 weights via torch.nn.functional.grouped_mm and MXFP8 weights via scaled grouped GEMM. All permutation logic is handled internally, so callers invoke a single function.

Module Contents#

Classes#

ActivationType

Activation functions supported by mcore_fused_moe.

Functions#

_bf16_grouped_mm

BF16 grouped GEMM using torch.nn.functional.grouped_mm.

_mxfp8_grouped_mm

MXFP8 scaled_grouped_mm with pre-quantized activations and weights.

_get_activation_func

Resolve ActivationType enum to a concrete kernel.

mcore_fused_moe

Fused MoE: [permute ->] pad -> FC1 -> activation -> FC2 -> unpad [-> unpermute].

API#

class core.inference.moe.fused_moe.ActivationType(*args, **kwds)#

Bases: enum.Enum

Activation functions supported by mcore_fused_moe.

Initialization

SQUARED_RELU#

‘squared_relu’

core.inference.moe.fused_moe._bf16_grouped_mm(
x_bf16: torch.Tensor,
weight: torch.Tensor,
offs: torch.Tensor,
) → torch.Tensor#

BF16 grouped GEMM using torch.nn.functional.grouped_mm.
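The semantics of a grouped GEMM can be sketched in pure Python: rows of the input are partitioned by cumulative end offsets (`offs`), and each slice is multiplied by its group's weight matrix. The helper names below (`matmul`, `grouped_mm_ref`) are illustrative, not part of this API, and the real kernel runs the per-group products as a single fused launch rather than a loop.

```python
def matmul(a, b):
    # naive [m, k] x [k, n] matrix product on lists of lists
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

def grouped_mm_ref(x, weights, offs):
    """Reference semantics for a grouped GEMM: slice rows of `x` at the
    cumulative end offsets in `offs` and multiply each slice by the
    corresponding group's weight matrix."""
    out, start = [], 0
    for w, end in zip(weights, offs):
        out.extend(matmul(x[start:end], w))
        start = end
    return out
```

Here `offs = [1, 3]` would mean rows `0:1` go through the first weight and rows `1:3` through the second, mirroring how tokens grouped by expert share one weight per group.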

core.inference.moe.fused_moe._mxfp8_grouped_mm(
act: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
weight: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
offs: torch.Tensor,
) → torch.Tensor#

MXFP8 scaled_grouped_mm with pre-quantized activations and weights.
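Conceptually, a scaled matmul dequantizes both operands with their scales before multiplying (in practice the scaling is fused into the GEMM). The sketch below uses one scale per row for brevity; actual MXFP8 carries one shared power-of-two (E8M0) scale per 32-element block. `scaled_mm_ref` is an illustrative name, not this module's API.

```python
def matmul(a, b):
    # naive [m, k] x [k, n] matrix product on lists of lists
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

def scaled_mm_ref(a_q, a_scale, b_q, b_scale):
    """Dequantize-then-multiply semantics of a scaled GEMM.
    Simplification: one scale per row instead of one E8M0 scale per
    32-element block as in real MXFP8."""
    a = [[v * s for v in row] for row, s in zip(a_q, a_scale)]
    b = [[v * s for v in row] for row, s in zip(b_q, b_scale)]
    return matmul(a, b)
```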

core.inference.moe.fused_moe._get_activation_func(
activation_type: core.inference.moe.fused_moe.ActivationType,
fused_quant: bool = False,
) → Callable#

Resolve ActivationType enum to a concrete kernel.

If fused_quant=True, returns the fused activation + MXFP8 quantize kernel.
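A minimal sketch of this kind of resolver, assuming a simple lookup table keyed on `(activation_type, fused_quant)`; the scalar kernels and table layout here are illustrative stand-ins for the real tensor kernels (and the fused variant elides the actual MXFP8 quantization step).

```python
from enum import Enum
from typing import Callable

class ActivationType(Enum):
    SQUARED_RELU = "squared_relu"

def _squared_relu(x: float) -> float:
    # squared ReLU: max(x, 0) ** 2, applied elementwise by the real kernel
    r = max(x, 0.0)
    return r * r

def _squared_relu_fused_quant(x: float) -> float:
    # stand-in for the fused activation + MXFP8 quantize kernel;
    # quantization is elided in this sketch
    return _squared_relu(x)

def get_activation_func(activation_type: ActivationType,
                        fused_quant: bool = False) -> Callable:
    """Resolve an ActivationType to a concrete kernel (illustrative)."""
    table = {
        (ActivationType.SQUARED_RELU, False): _squared_relu,
        (ActivationType.SQUARED_RELU, True): _squared_relu_fused_quant,
    }
    try:
        return table[(activation_type, fused_quant)]
    except KeyError:
        raise ValueError(f"unsupported activation: {activation_type!r}")
```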

core.inference.moe.fused_moe.mcore_fused_moe(
hidden_states: torch.Tensor,
probs: torch.Tensor,
fc1_weight,
fc2_weight,
activation_type: core.inference.moe.fused_moe.ActivationType,
num_local_experts: int,
local_expert_start: int,
routing_map: Optional[torch.Tensor] = None,
tokens_per_expert: Optional[torch.Tensor] = None,
skip_permute: bool = False,
disable_fused_quant_kernels: bool = False,
) → torch.Tensor#

Fused MoE: [permute ->] pad -> FC1 -> activation -> FC2 -> unpad [-> unpermute].

Two modes:

  • skip_permute=False (default): tokens arrive in their original (unpermuted) order. Requires routing_map. Performs the full permute -> compute -> unpermute pipeline.

  • skip_permute=True: tokens are already permuted by the dispatcher. Requires tokens_per_expert. Pads to alignment, computes, then unpads. Probs are applied during unpad.
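The pad/unpad step in the skip_permute=True mode can be sketched as follows: each expert's contiguous token group is zero-padded up to an alignment multiple before the grouped GEMMs, and the padding rows are dropped afterwards, with each surviving row scaled by its routing probability. The helper names (`pad_groups`, `unpad_groups`) are illustrative, not part of this API.

```python
def pad_groups(tokens, tokens_per_expert, align):
    """Zero-pad each expert's token group to a multiple of `align`.
    Returns the padded rows plus cumulative group-end offsets."""
    hidden = len(tokens[0])
    out, offs, start = [], [], 0
    for n in tokens_per_expert:
        padded_n = -(-n // align) * align  # ceil to a multiple of align
        group = tokens[start:start + n] + [[0.0] * hidden] * (padded_n - n)
        out.extend(group)
        offs.append(len(out))
        start += n
    return out, offs

def unpad_groups(padded, tokens_per_expert, probs, align):
    """Drop padding rows; scale each real row by its routing prob."""
    out, start, p = [], 0, 0
    for n in tokens_per_expert:
        padded_n = -(-n // align) * align
        for row in padded[start:start + n]:
            out.append([v * probs[p] for v in row])
            p += 1
        start += padded_n
    return out
```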

When weights are MXFP8 (and disable_fused_quant_kernels=False), fused kernels combine permute and activation with MXFP8 quantization into single kernel launches.

Parameters:
  • hidden_states – [num_tokens, hidden_size] BF16 input.

  • probs – routing probabilities. Shape is [num_tokens, topk] when skip_permute=False, or [num_tokens] (already gathered) when skip_permute=True.

  • fc1_weight – stacked weight for FC1 (torch.Tensor for BF16, MXFP8Tensor for MXFP8).

  • fc2_weight – stacked weight for FC2 (same type as fc1_weight).

  • activation_type – ActivationType enum (SQUARED_RELU).

  • num_local_experts – number of experts on this rank.

  • local_expert_start – first global expert index on this rank.

  • routing_map – [num_tokens, topk] int expert assignments. Required when skip_permute=False.

  • tokens_per_expert – [num_local_experts] int32 token counts. Required when skip_permute=True.

  • skip_permute – if True, skip permute/unpermute (tokens already in expert order).

  • disable_fused_quant_kernels – if True, disable fused permute+quantize and activation+quantize kernels for MXFP8, using separate launches instead. Useful for debugging. Ignored when weights are BF16.

Returns:

[num_tokens, hidden_size] BF16 output.
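End to end, the skip_permute=False path computes the following reference semantics; `moe_forward_ref` is an illustrative pure-Python sketch, not this API (weights here are one `[hidden, out]` matrix per expert for plain row-vector matmuls, rather than the stacked tensors the real function takes, and the padding/grouped-GEMM machinery is replaced by a per-expert loop).

```python
def matmul(a, b):
    # naive [m, k] x [k, n] matrix product on lists of lists
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

def squared_relu(row):
    return [max(x, 0.0) ** 2 for x in row]

def moe_forward_ref(hidden_states, probs, routing_map, fc1, fc2, num_experts):
    """Reference MoE forward: permute -> FC1 -> squared ReLU -> FC2 ->
    unpermute, weighting each expert output by its routing probability."""
    num_tokens = len(hidden_states)
    out = [[0.0] * len(hidden_states[0]) for _ in range(num_tokens)]
    for e in range(num_experts):
        # permute: gather tokens routed to expert e, with their probs
        rows, idxs, ps = [], [], []
        for t in range(num_tokens):
            for k, expert in enumerate(routing_map[t]):
                if expert == e:
                    rows.append(hidden_states[t])
                    idxs.append(t)
                    ps.append(probs[t][k])
        if not rows:
            continue
        h = [squared_relu(v) for v in matmul(rows, fc1[e])]
        y = matmul(h, fc2[e])
        # unpermute: scatter-add back to token order, scaled by prob
        for row, t, p in zip(y, idxs, ps):
            for j, v in enumerate(row):
                out[t][j] += p * v
    return out
```

This makes the data flow explicit: the fused implementation produces the same result but replaces the Python loops with a single permute, two grouped GEMMs, and an unpermute.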