core.inference.moe.fused_moe#
Fused MoE: permute -> FC1 -> activation -> FC2 -> unpermute.
Supports BF16 weights (torch.nn.functional.grouped_mm) and MXFP8 weights (scaled_grouped_mm). All permutation logic is handled internally — callers invoke a single function.
Module Contents#
Classes#
ActivationType – Activation functions supported by mcore_fused_moe.
Functions#
_bf16_grouped_mm – BF16 grouped GEMM using torch.nn.functional.grouped_mm.
_mxfp8_grouped_mm – MXFP8 scaled_grouped_mm with pre-quantized activations and weights.
_get_activation_func – Resolve ActivationType enum to a concrete kernel.
mcore_fused_moe – Fused MoE: [permute ->] pad -> FC1 -> activation -> FC2 -> unpad [-> unpermute].
API#
- class core.inference.moe.fused_moe.ActivationType(*args, **kwds)#
Bases: enum.Enum
Activation functions supported by mcore_fused_moe.
Initialization
- SQUARED_RELU#
'squared_relu'
- core.inference.moe.fused_moe._bf16_grouped_mm(
- x_bf16: torch.Tensor,
- weight: torch.Tensor,
- offs: torch.Tensor,
- )#
BF16 grouped GEMM using torch.nn.functional.grouped_mm.
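The offsets-based layout that a grouped GEMM consumes can be illustrated with an unfused reference loop. The helper name and shapes below are illustrative, not part of the module:

```python
import torch

def grouped_mm_reference(x: torch.Tensor, weight: torch.Tensor,
                         offs: torch.Tensor) -> torch.Tensor:
    """Reference semantics of a grouped GEMM (hypothetical helper):
    expert e multiplies its contiguous row slice of x by weight[e].

    x:      [total_tokens, K]   tokens already grouped by expert
    weight: [num_experts, K, N] stacked per-expert weight matrices
    offs:   [num_experts]       cumulative end offset of each expert's rows
    """
    out = x.new_empty(x.shape[0], weight.shape[-1])
    start = 0
    for e in range(weight.shape[0]):
        end = int(offs[e])
        # One plain matmul per expert; the fused kernel performs all of
        # these slices in a single launch.
        out[start:end] = x[start:end] @ weight[e]
        start = end
    return out
```

The loop makes the contract visible: rows must already be sorted by expert, and `offs` holds cumulative end indices, so expert `e` owns rows `offs[e-1]:offs[e]`.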
- core.inference.moe.fused_moe._mxfp8_grouped_mm(
- act: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
- weight: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
- offs: torch.Tensor,
- )#
MXFP8 scaled_grouped_mm with pre-quantized activations and weights.
- core.inference.moe.fused_moe._get_activation_func(
- activation_type: core.inference.moe.fused_moe.ActivationType,
- fused_quant: bool = False,
- )#
Resolve ActivationType enum to a concrete kernel.
If fused_quant=True, returns the fused activation + MXFP8 quantize kernel.
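The SQUARED_RELU variant computes y = ReLU(x)^2. A minimal unfused sketch of that activation (the resolved kernel may additionally fuse MXFP8 quantization when fused_quant=True):

```python
import torch

def squared_relu(x: torch.Tensor) -> torch.Tensor:
    # Unfused sketch of SQUARED_RELU: y = max(x, 0) ** 2.
    return torch.relu(x).square()
```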
- core.inference.moe.fused_moe.mcore_fused_moe(
- hidden_states: torch.Tensor,
- probs: torch.Tensor,
- fc1_weight,
- fc2_weight,
- activation_type: core.inference.moe.fused_moe.ActivationType,
- num_local_experts: int,
- local_expert_start: int,
- routing_map: Optional[torch.Tensor] = None,
- tokens_per_expert: Optional[torch.Tensor] = None,
- skip_permute: bool = False,
- disable_fused_quant_kernels: bool = False,
- )#
Fused MoE: [permute ->] pad -> FC1 -> activation -> FC2 -> unpad [-> unpermute].
Two modes:
- skip_permute=False (default): tokens arrive unpermuted. Requires routing_map. Performs the full permute -> compute -> unpermute pipeline.
- skip_permute=True: tokens are already permuted by the dispatcher. Requires tokens_per_expert. Pads to alignment, computes, then unpads; probs are applied during unpad.
When weights are MXFP8 (and unless disable_fused_quant_kernels=True), fused kernels combine permute/activation with MXFP8 quantization into single kernel launches.
- Parameters:
hidden_states – [num_tokens, hidden_size] BF16 input.
probs – routing probabilities. Shape is [num_tokens, topk] when skip_permute=False, or [num_tokens] (already gathered) when skip_permute=True.
fc1_weight – stacked weight for FC1 (torch.Tensor for BF16, MXFP8Tensor for MXFP8).
fc2_weight – stacked weight for FC2 (same type as fc1_weight).
activation_type – ActivationType enum (SQUARED_RELU).
num_local_experts – number of experts on this rank.
local_expert_start – first global expert index on this rank.
routing_map – [num_tokens, topk] int expert assignments. Required when skip_permute=False.
tokens_per_expert – [num_local_experts] int32 token counts. Required when skip_permute=True.
skip_permute – if True, skip permute/unpermute (tokens already in expert order).
disable_fused_quant_kernels – if True, disable fused permute+quantize and activation+quantize kernels for MXFP8, using separate launches instead. Useful for debugging. Ignored when weights are BF16.
- Returns:
[num_tokens, hidden_size] BF16 output.
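The skip_permute=False data flow above can be sketched as an unfused reference. The helper name is illustrative and the per-token loop is for clarity only; the real path batches everything through grouped GEMMs:

```python
import torch

def moe_reference(hidden_states, probs, fc1_weight, fc2_weight, routing_map):
    """Unfused reference of the skip_permute=False path (illustrative only):
    for each token and each of its top-k experts, compute
    FC1 -> squared ReLU -> FC2, and accumulate weighted by probs."""
    out = torch.zeros_like(hidden_states)
    num_tokens, topk = routing_map.shape
    for t in range(num_tokens):
        for k in range(topk):
            e = int(routing_map[t, k])
            h = torch.relu(hidden_states[t] @ fc1_weight[e]).square()
            out[t] += probs[t, k] * (h @ fc2_weight[e])
    return out
```

The fused implementation reaches the same result by permuting tokens into expert order once, running two grouped GEMMs over all experts, and folding the probs multiply into the unpermute step.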