core.inference.moe.fused_moe#
Fused MoE: permute -> FC1 -> activation -> FC2 -> unpermute.
Supports BF16 weights with torch.nn.functional.grouped_mm. All permutation logic is handled internally, so callers invoke a single function.
Module Contents#
Classes#
| ActivationType | Activation functions supported by mcore_fused_moe. |
Functions#
| _bf16_grouped_mm | BF16 grouped GEMM using torch.nn.functional.grouped_mm. |
| _mxfp8_grouped_mm | MXFP8 scaled_grouped_mm with pre-quantized activations and weights. |
| _get_activation_func | Resolve ActivationType enum to a concrete kernel. |
| mcore_fused_moe | Fused MoE: permute -> pad -> FC1 -> activation -> FC2 -> unpad -> unpermute. |
API#
- class core.inference.moe.fused_moe.ActivationType(*args, **kwds)#
Bases: enum.Enum
Activation functions supported by mcore_fused_moe.
- SQUARED_RELU#
'squared_relu'
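As a point of reference, squared ReLU is conventionally `relu(x) ** 2` applied elementwise. The sketch below assumes that definition; the `ActivationType` class here is a local stand-in that mirrors the documented member, and `squared_relu_reference` is a hypothetical helper, not the kernel the module actually dispatches to.

```python
from enum import Enum

import torch


class ActivationType(Enum):
    """Local stand-in mirroring the documented enum (illustration only)."""

    SQUARED_RELU = "squared_relu"


def squared_relu_reference(x: torch.Tensor) -> torch.Tensor:
    """Elementwise relu(x) ** 2. Hypothetical reference, not the fused kernel."""
    return torch.relu(x).square()


x = torch.tensor([-2.0, -0.5, 0.0, 1.5, 3.0])
y = squared_relu_reference(x)  # negative inputs map to 0, positive inputs are squared
```

A reference like this can be useful for checking a kernel's output numerically on small inputs.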
- core.inference.moe.fused_moe._bf16_grouped_mm(
      x_bf16: torch.Tensor,
      weight: torch.Tensor,
      offs: torch.Tensor,
  )
BF16 grouped GEMM using torch.nn.functional.grouped_mm.
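To make the role of `offs` concrete, here is a loop-based sketch of what a grouped GEMM computes. It assumes that `offs` holds cumulative end offsets of each expert's contiguous block of rows and that weights are stacked as `[num_groups, k, n]`; both are assumptions for illustration. `grouped_mm_reference` is a hypothetical name, and the example uses float32 on CPU rather than BF16.

```python
import torch


def grouped_mm_reference(x, weight, offs):
    """Loop-based reference for a grouped GEMM (illustrative only).

    x:      [total_rows, k], rows sorted so each group's rows are contiguous.
    weight: [num_groups, k, n], one matrix per group (assumed layout).
    offs:   [num_groups] cumulative end offsets into x (assumed convention).
    """
    out = x.new_zeros(x.shape[0], weight.shape[-1])
    start = 0
    for g, end in enumerate(offs.tolist()):
        # Rows start..end belong to group g; an empty range is a no-op.
        out[start:end] = x[start:end] @ weight[g]
        start = end
    return out


torch.manual_seed(0)
x = torch.randn(6, 4)
weight = torch.randn(3, 4, 5)
offs = torch.tensor([2, 2, 6], dtype=torch.int32)  # group sizes 2, 0, 4
out = grouped_mm_reference(x, weight, offs)
```

A single grouped kernel replaces this Python loop with one launch, which matters when there are many experts with few rows each.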
- core.inference.moe.fused_moe._mxfp8_grouped_mm(
      act: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
      weight: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
      offs: torch.Tensor,
  )
MXFP8 scaled_grouped_mm with pre-quantized activations and weights.
- core.inference.moe.fused_moe._get_activation_func(
      activation_type: core.inference.moe.fused_moe.ActivationType,
      fused_quant: bool = False,
  )
Resolve ActivationType enum to a concrete kernel.
If fused_quant=True, returns the fused activation + MXFP8 quantize kernel.
- core.inference.moe.fused_moe.mcore_fused_moe(
      hidden_states: torch.Tensor,
      probs: torch.Tensor,
      fc1_weight,
      fc2_weight,
      activation_type: core.inference.moe.fused_moe.ActivationType,
      num_local_experts: int,
      local_expert_start: int,
      valid_tokens: torch.Tensor,
      routing_map: torch.Tensor,
      disable_fused_quant_kernels: bool = False,
      out: torch.Tensor = None,
  )
Fused MoE: permute -> pad -> FC1 -> activation -> FC2 -> unpad -> unpermute.
When the weights are MXFP8 and disable_fused_quant_kernels is False (the default), this function uses fused kernels that combine the permute and activation steps with MXFP8 quantization, so each pair runs as a single kernel launch.
- Parameters:
hidden_states – [max_tokens, hidden_size] BF16 input. max_tokens = max_local_tokens * ep_size; only the first valid_tokens rows are valid.
probs – [max_tokens, topk] routing probabilities.
fc1_weight – stacked weight for FC1 (torch.Tensor for BF16, MXFP8Tensor for MXFP8).
fc2_weight – stacked weight for FC2 (same type as fc1_weight).
activation_type – ActivationType enum (SQUARED_RELU).
num_local_experts – number of experts on this rank.
local_expert_start – first global expert index on this rank.
valid_tokens – scalar int32 CUDA tensor holding the number of valid tokens this iteration. Kernels use this to ignore rows beyond the valid prefix, which is required for CUDA graph compatibility because hidden_states is always max-sized.
routing_map – [max_tokens, topk] int expert assignments.
disable_fused_quant_kernels – if True, disable fused permute+quantize and activation+quantize kernels for MXFP8, using separate launches instead. Useful for debugging. Ignored when weights are BF16.
out – optional pre-allocated output buffer. If provided, unpermute writes directly into this tensor (e.g. the RSV symmetric buffer), avoiding a separate copy before reduce-scatter.
- Returns:
[max_tokens, hidden_size] BF16 output. Only the first valid_tokens rows are meaningful; rows beyond that are undefined.
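The pipeline described above can be approximated by an unfused reference in plain PyTorch, which may help when validating outputs on small inputs. This is a sketch under several assumptions not stated in the source: weights are stacked as `[num_local_experts, hidden, ffn]` and `[num_local_experts, ffn, hidden]`, `routing_map` holds global expert indices, and routing probabilities are applied when expert outputs are combined. `moe_reference` is a hypothetical name. Unlike the fused function, the sketch zeroes rows beyond `valid_tokens` and reads `valid_tokens` on the host, which would force a synchronization in real CUDA graph code.

```python
import torch


def moe_reference(hidden_states, probs, fc1_weight, fc2_weight,
                  num_local_experts, local_expert_start,
                  valid_tokens, routing_map):
    """Unfused sketch: permute -> FC1 -> squared ReLU -> FC2 -> unpermute."""
    n = int(valid_tokens)  # host read; the fused kernels avoid this
    topk = routing_map.shape[1]
    device = hidden_states.device
    out = torch.zeros_like(hidden_states)

    # Flatten the (token, slot) pairs of the valid prefix.
    token_idx = torch.arange(n, device=device).repeat_interleave(topk)
    expert_idx = routing_map[:n].reshape(-1).long() - local_expert_start
    weight = probs[:n].reshape(-1)

    # Keep only the pairs routed to an expert on this rank.
    local = (expert_idx >= 0) & (expert_idx < num_local_experts)
    token_idx, expert_idx, weight = token_idx[local], expert_idx[local], weight[local]

    # Permute: sort so that each expert's rows are contiguous.
    order = torch.argsort(expert_idx, stable=True)
    token_idx, expert_idx, weight = token_idx[order], expert_idx[order], weight[order]
    permuted = hidden_states[token_idx]

    # FC1 -> activation -> FC2, one expert at a time.
    counts = torch.bincount(expert_idx, minlength=num_local_experts).tolist()
    expert_out = torch.empty_like(permuted)
    start = 0
    for e, count in enumerate(counts):
        rows = slice(start, start + count)
        h = torch.relu(permuted[rows] @ fc1_weight[e]).square()
        expert_out[rows] = h @ fc2_weight[e]
        start += count

    # Unpermute: scale by routing probability and accumulate per token.
    out.index_add_(0, token_idx, (expert_out * weight[:, None]).to(out.dtype))
    return out


torch.manual_seed(0)
hidden_states = torch.randn(4, 8)            # max_tokens=4, hidden_size=8
routing_map = torch.tensor([[0, 1], [1, 2], [3, 0], [2, 3]])
probs = torch.full((4, 2), 0.5)
fc1 = torch.randn(2, 8, 16)                  # this rank owns global experts 2 and 3
fc2 = torch.randn(2, 16, 8)
valid_tokens = torch.tensor(3, dtype=torch.int32)

out = moe_reference(hidden_states, probs, fc1, fc2,
                    num_local_experts=2, local_expert_start=2,
                    valid_tokens=valid_tokens, routing_map=routing_map)
```

In this example token 0 is routed only to experts on other ranks and token 3 lies beyond the valid prefix, so both rows of `out` stay zero; tokens 1 and 2 each receive one local expert's contribution.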