core.inference.moe.fused_moe

Fused MoE: permute -> FC1 -> activation -> FC2 -> unpermute.

Supports BF16 weights via torch.nn.functional.grouped_mm and MXFP8 weights via a scaled grouped GEMM. All permutation logic is handled internally, so callers invoke a single function.
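For orientation, the permute -> FC1 -> activation -> FC2 -> unpermute stages can be written as an unfused NumPy reference. This is a sketch under stated assumptions, not the module's implementation: the helper name `reference_moe`, the weight layouts (`fc1` as `[E, H, F]`, `fc2` as `[E, F, H]`), and the choice to apply each routing probability to its expert's output during unpermute are all illustrative assumptions.

```python
import numpy as np


def reference_moe(x, routing_map, probs, fc1, fc2, act):
    """Unfused reference: permute -> FC1 -> activation -> FC2 -> unpermute.

    Illustrative shapes (assumptions, not the module's layout):
      x [T, H], routing_map [T, K] (expert ids), probs [T, K],
      fc1 [E, H, F], fc2 [E, F, H].
    """
    _, K = routing_map.shape
    # Permute: flatten (token, slot) pairs and stable-sort them by expert,
    # so each expert's rows form one contiguous block.
    flat_expert = routing_map.reshape(-1)
    order = np.argsort(flat_expert, kind="stable")
    token_idx = order // K              # source token of each permuted row
    permuted = x[token_idx]             # [T*K, H]
    sorted_expert = flat_expert[order]

    out_perm = np.zeros_like(permuted)
    for e in range(fc1.shape[0]):
        rows = sorted_expert == e                # this expert's contiguous block
        hidden = act(permuted[rows] @ fc1[e])    # FC1 + activation
        out_perm[rows] = hidden @ fc2[e]         # FC2
    # Unpermute: scale each row by its routing prob, scatter-add to its token.
    weights = probs.reshape(-1)[order][:, None]
    y = np.zeros_like(x)
    np.add.at(y, token_idx, out_perm * weights)
    return y
```

Because the rows are sorted by expert, each expert's work is one contiguous matrix multiply, which is what allows a single grouped GEMM to replace the per-expert loop.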

Module Contents

Classes

ActivationType

Activation functions supported by mcore_fused_moe.

Functions

_bf16_grouped_mm

BF16 grouped GEMM using torch.nn.functional.grouped_mm.

_mxfp8_grouped_mm

MXFP8 scaled_grouped_mm with pre-quantized activations and weights.

_get_activation_func

Resolve ActivationType enum to a concrete kernel.

mcore_fused_moe

Fused MoE: permute -> pad -> FC1 -> activation -> FC2 -> unpad -> unpermute.

API

class core.inference.moe.fused_moe.ActivationType(*args, **kwds)

Bases: enum.Enum

Activation functions supported by mcore_fused_moe.

SQUARED_RELU

‘squared_relu’

core.inference.moe.fused_moe._bf16_grouped_mm(
x_bf16: torch.Tensor,
weight: torch.Tensor,
offs: torch.Tensor,
) -> torch.Tensor

BF16 grouped GEMM using torch.nn.functional.grouped_mm.
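The role of `offs` is easiest to see in a loop-based reference. The sketch below assumes the PyTorch grouped-GEMM convention for a 2D input and a 3D weight: rows are pre-sorted by expert, `weight` is `[num_experts, K, N]`, and `offs` holds the cumulative end row of each expert's group. Whether this module stores its weights in that orientation or transposed is not stated here, so treat the layout as an assumption; `grouped_mm_reference` is a hypothetical name.

```python
import numpy as np


def grouped_mm_reference(x, weight, offs):
    """Loop-based reference for a 2D x 3D grouped GEMM.

    Assumed layout:
      x:      [total_rows, K]  rows already sorted by group
      weight: [G, K, N]        one matrix per group
      offs:   [G] int          cumulative END row offset of each group
    Group g covers rows offs[g-1]:offs[g], with the first group starting at 0.
    """
    out = np.zeros((x.shape[0], weight.shape[2]), dtype=x.dtype)
    start = 0
    for g, end in enumerate(offs):
        # Each group multiplies only its own contiguous slice of rows.
        out[start:end] = x[start:end] @ weight[g]
        start = end
    # Rows past offs[-1] belong to no group and are left as zero here.
    return out
```

An expert that receives no tokens simply has the same offset as its predecessor, so its slice is empty and contributes nothing.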

core.inference.moe.fused_moe._mxfp8_grouped_mm(
act: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
weight: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
offs: torch.Tensor,
) -> torch.Tensor

MXFP8 scaled_grouped_mm with pre-quantized activations and weights.
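MXFP8, as defined by the OCP Microscaling (MX) specification, stores FP8 E4M3 elements in blocks of 32 that share one power-of-two (E8M0) scale. The NumPy sketch below simulates that encoding to show what "pre-quantized" operands contain. It is not `MXFP8Tensor`'s implementation: the scale rule follows the spec's `2**(floor(log2(amax)) - 8)` convention, values that overflow 448 are saturated, scales are kept as floats rather than packed exponent bytes, and the helper names are hypothetical.

```python
import numpy as np

BLOCK = 32          # MX block size: 32 elements share one scale
E4M3_MAX = 448.0    # largest finite FP8 E4M3 magnitude
E4M3_EMAX = 8       # exponent of the largest E4M3 power of two (2**8)


def round_to_e4m3(v):
    """Round values to the nearest FP8 E4M3 value, saturating at 448."""
    mag = np.abs(v)
    # Per-value exponent, clamped at -6 so subnormals share a fixed step.
    exp = np.floor(np.log2(np.where(mag > 0, mag, 1.0)))
    exp = np.maximum(exp, -6)
    step = 2.0 ** (exp - 3)              # 3 mantissa bits
    q = np.round(mag / step) * step      # np.round rounds half to even
    return np.sign(v) * np.minimum(q, E4M3_MAX)


def mxfp8_quantize(x):
    """Quantize a 1-D array (length divisible by 32) into MXFP8 blocks.

    Returns (elements, scales) such that x ~= elements * repeat(scales, 32).
    """
    blocks = x.reshape(-1, BLOCK)
    amax = np.abs(blocks).max(axis=1)
    # Shared power-of-two scale per block: 2**(floor(log2(amax)) - emax).
    safe = np.where(amax > 0, amax, 1.0)
    scales = 2.0 ** (np.floor(np.log2(safe)) - E4M3_EMAX)
    elems = round_to_e4m3(blocks / scales[:, None])
    return elems.reshape(x.shape), scales
```

In MX formats each operand is therefore a pair of tensors (FP8 elements and per-block scales), which is what a scaled grouped GEMM consumes.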

core.inference.moe.fused_moe._get_activation_func(
activation_type: core.inference.moe.fused_moe.ActivationType,
fused_quant: bool = False,
) -> Callable

Resolve ActivationType enum to a concrete kernel.

If fused_quant=True, returns the fused activation + MXFP8 quantize kernel.
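This is a small dispatch from enum member to kernel. A minimal sketch of that pattern follows; the kernel names are hypothetical, and the fused variant is a placeholder that computes only the activation, since the real fused kernel's outputs (FP8 elements plus block scales) are not documented here. Squared ReLU is `max(x, 0) ** 2`.

```python
import enum
from typing import Callable, Dict, Tuple

import numpy as np


class ActivationType(enum.Enum):
    SQUARED_RELU = "squared_relu"


def squared_relu(x):
    # Plain elementwise squared ReLU.
    return np.square(np.maximum(x, 0))


def squared_relu_fused_quant(x):
    # Hypothetical placeholder for a fused activation + MXFP8 quantize kernel.
    # A real kernel would emit FP8 elements and block scales in one launch;
    # this stand-in returns the activation and no scales.
    return squared_relu(x), None


# (plain kernel, fused kernel) for each supported activation.
_KERNELS: Dict[ActivationType, Tuple[Callable, Callable]] = {
    ActivationType.SQUARED_RELU: (squared_relu, squared_relu_fused_quant),
}


def get_activation_func(activation_type: ActivationType,
                        fused_quant: bool = False) -> Callable:
    try:
        plain, fused = _KERNELS[activation_type]
    except KeyError:
        raise ValueError(f"unsupported activation: {activation_type}") from None
    return fused if fused_quant else plain
```

Keeping both variants in one table keyed by the enum means adding an activation touches a single entry, and an unsupported member fails loudly instead of silently falling back.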

core.inference.moe.fused_moe.mcore_fused_moe(
hidden_states: torch.Tensor,
probs: torch.Tensor,
fc1_weight,
fc2_weight,
activation_type: core.inference.moe.fused_moe.ActivationType,
num_local_experts: int,
local_expert_start: int,
valid_tokens: torch.Tensor,
routing_map: torch.Tensor,
disable_fused_quant_kernels: bool = False,
out: torch.Tensor = None,
) -> torch.Tensor

Fused MoE: permute -> pad -> FC1 -> activation -> FC2 -> unpad -> unpermute.

When the weights are MXFP8, fused kernels combine permutation with MXFP8 quantization, and activation with MXFP8 quantization, so each pair runs as a single kernel launch. Set disable_fused_quant_kernels=True to use separate launches instead.

Parameters:
  • hidden_states – [max_tokens, hidden_size] BF16 input. max_tokens = max_local_tokens * ep_size; only the first valid_tokens rows are valid.

  • probs – [max_tokens, topk] routing probabilities.

  • fc1_weight – stacked weight for FC1 (torch.Tensor for BF16, MXFP8Tensor for MXFP8).

  • fc2_weight – stacked weight for FC2 (same type as fc1_weight).

  • activation_type – ActivationType enum member (SQUARED_RELU is the only one defined).

  • num_local_experts – number of experts on this rank.

  • local_expert_start – first global expert index on this rank.

  • valid_tokens – scalar int32 CUDA tensor holding the number of valid tokens in this iteration. Kernels use it to ignore rows beyond the valid prefix. This is required for CUDA graph compatibility, because hidden_states is always allocated at its maximum size.

  • routing_map – [max_tokens, topk] int expert assignments.

  • disable_fused_quant_kernels – if True, disable fused permute+quantize and activation+quantize kernels for MXFP8, using separate launches instead. Useful for debugging. Ignored when weights are BF16.

  • out – optional pre-allocated output buffer. If provided, unpermute writes directly into this tensor (e.g. the RSV symmetric buffer), avoiding a separate copy before reduce-scatter.
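How routing_map, valid_tokens, num_local_experts, and local_expert_start interact in the permute and pad steps can be sketched on the host in NumPy. Assumptions: routing_map holds global expert indices, (token, slot) pairs routed to experts outside `[local_expert_start, local_expert_start + num_local_experts)` are skipped on this rank, and each expert's group is padded to a multiple of a hypothetical `align` value. valid_tokens is a plain int here, whereas the real function reads it from a CUDA tensor; `local_permutation` is a hypothetical name.

```python
import numpy as np


def local_permutation(routing_map, valid_tokens, num_local_experts,
                      local_expert_start, align=1):
    """Order this rank's (token, slot) pairs by local expert, with padding.

    Returns:
      token_idx: source token of each permuted row (-1 marks a padding row)
      slot_idx:  top-k slot of each permuted row (-1 marks a padding row)
      offs:      cumulative end row offset of each local expert's group
    `align` is a hypothetical per-group row alignment.
    """
    # Only the valid prefix of rows participates.
    local = routing_map[:valid_tokens] - local_expert_start
    token_idx, slot_idx, offs = [], [], []
    for e in range(num_local_experts):
        t, k = np.nonzero(local == e)     # pairs routed to local expert e
        pad = (-len(t)) % align           # rows needed to reach a multiple of align
        token_idx.extend(t.tolist() + [-1] * pad)
        slot_idx.extend(k.tolist() + [-1] * pad)
        offs.append(len(token_idx))
    return (np.array(token_idx, dtype=np.int64),
            np.array(slot_idx, dtype=np.int64),
            np.array(offs, dtype=np.int32))
```

In this sketch, padding rows carry no token, so an unpad step would drop them before unpermute, and offs has the cumulative-end form that a grouped GEMM over the padded rows expects.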

Returns:

[max_tokens, hidden_size] BF16 output. Only the first valid_tokens rows are meaningful; rows beyond that are undefined.