core.inference.moe.vllm_fused_moe#

vLLM-style Triton fused MoE kernel (BF16) for Megatron inference.

CUDA-graph compatible: all indirection table construction happens on-device via Triton kernels with fixed-size buffers and valid_tokens gating.

Module Contents#

Functions#

  • _select_block_size_m – Select BLOCK_SIZE_M based on the token buffer size.

  • _fused_moe_kernel – Persistent fused MoE grouped GEMM with indirect token addressing.

  • _ceil_div – Integer ceiling division helper.

  • _init_sorted_ids_kernel – Initialize sorted_token_ids to SENTINEL and expert_ids to -1.

  • _scatter_token_indices_kernel – Scatter local-expert pair indices into the padded indirection table.

  • _fill_expert_block_ids_kernel – Fill expert_ids with the expert index for each BLOCK_SIZE_M block.

  • _moe_align_block_size_cuda_graphable – Build indirection tables for the vLLM kernel, fully on-device.

  • _invoke_fused_moe_kernel – Launch the persistent Triton fused-MoE kernel for one GEMM pass.

  • _moe_sum_kernel – Reduce the topk dimension with valid_tokens gating and routing-weight application.

  • _moe_sum – Fused topk reduction: [max_tokens*topk, K] bf16 → [max_tokens, K].

  • vllm_fused_moe – Fused MoE using the vLLM Triton grouped-GEMM kernel (BF16).

Data#

API#

core.inference.moe.vllm_fused_moe._select_block_size_m(max_tokens: int) int#

Select BLOCK_SIZE_M based on the token buffer size.

Smaller tiles reduce padding waste in the indirection table when each expert sees few tokens (decode). Larger tiles improve compute density for large batches (prefill). Minimum is 16 (tl.dot requirement on NVIDIA).
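
For intuition, a minimal sketch of this kind of heuristic; the function name is real, but the thresholds below are illustrative assumptions, not the module's actual cutoffs:

```python
def select_block_size_m_sketch(max_tokens: int) -> int:
    # Decode-like workloads (few tokens per expert): smallest legal tile.
    # tl.dot on NVIDIA requires each tile dimension to be at least 16.
    if max_tokens <= 256:    # illustrative threshold, not the real cutoff
        return 16
    # Prefill-like workloads: larger tiles amortize pointer arithmetic and
    # improve tensor-core utilization, at the cost of more padding per expert.
    if max_tokens <= 4096:   # illustrative threshold
        return 32
    return 64
```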

core.inference.moe.vllm_fused_moe._AUTOTUNE_CONFIGS#

None

core.inference.moe.vllm_fused_moe._fused_moe_kernel(
a_ptr,
b_ptr,
c_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
N,
K,
num_valid_tokens,
num_sms,
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
MUL_ROUTED_WEIGHT: triton.language.constexpr,
FUSE_SQUARED_RELU: triton.language.constexpr,
top_k: triton.language.constexpr,
BLOCK_SIZE_M: triton.language.constexpr,
BLOCK_SIZE_N: triton.language.constexpr,
BLOCK_SIZE_K: triton.language.constexpr,
GROUP_SIZE_M: triton.language.constexpr,
)#

Persistent fused MoE grouped GEMM with indirect token addressing.

Launches a fixed grid of num_sms CTAs. Each CTA loops over its share of tiles, with total tile count determined device-side from num_tokens_post_padded. This decouples grid size from buffer size, keeping the kernel CUDA-graph safe while avoiding excess CTA overhead.
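
A stripped-down sketch of the persistent-grid pattern described above, showing only the tile loop; the GEMM body and MoE indirection are omitted, and everything except the documented parameters is an assumption:

```python
import triton
import triton.language as tl

@triton.jit
def _persistent_grid_sketch(
    num_tokens_post_padded_ptr,  # [1] int32, written on-device by the alignment step
    N, num_sms,
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
):
    pid = tl.program_id(axis=0)  # grid is fixed at (num_sms,), one CTA per SM
    # The tile count comes from device memory, so the host-side launch (and
    # any CUDA graph capturing it) never depends on the actual token count.
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    num_pid_m = tl.cdiv(num_tokens_post_padded, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    # Each CTA processes tiles pid, pid + num_sms, pid + 2 * num_sms, ...
    for tile_id in range(pid, num_pid_m * num_pid_n, num_sms):
        pid_m = tile_id // num_pid_n
        pid_n = tile_id % num_pid_n
        # ... load the A/B tiles for (pid_m, pid_n), tl.dot, store C ...
```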

core.inference.moe.vllm_fused_moe._ceil_div(a, b)#

Integer ceiling division helper.
core.inference.moe.vllm_fused_moe._init_sorted_ids_kernel(
sorted_token_ids_ptr,
expert_ids_ptr,
max_sorted,
max_blocks,
SENTINEL: triton.language.constexpr,
BLOCK: triton.language.constexpr,
)#

Initialize sorted_token_ids to SENTINEL and expert_ids to -1.

core.inference.moe.vllm_fused_moe._scatter_token_indices_kernel(
routing_map_ptr,
sorted_token_ids_ptr,
counters_ptr,
valid_tokens_ptr,
topk: triton.language.constexpr,
local_expert_start,
num_local_experts: triton.language.constexpr,
max_pairs,
BLOCK_SIZE: triton.language.constexpr,
)#

Scatter local-expert pair indices into the padded indirection table.

Only local expert pairs are written; non-local pairs are skipped (the _moe_sum kernel handles them by checking the routing map directly).
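
In comment form, the scatter amounts to the following, assuming counters holds per-local-expert write cursors seeded at each expert's padded base offset in sorted_token_ids:

```python
# Pseudocode for one assignment pair; the real kernel is a Triton JIT function.
# For each pair (token t, slot k) with expert e = routing_map[t, k]:
#   if t < valid_tokens and local_expert_start <= e < local_expert_start + num_local_experts:
#       pos = atomic_add(counters[e - local_expert_start], 1)  # cursor into e's region
#       sorted_token_ids[pos] = t * topk + k                   # pair index read by the GEMM
```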

core.inference.moe.vllm_fused_moe._fill_expert_block_ids_kernel(
expert_ids_ptr,
exclusive_offsets_ptr,
inclusive_offsets_ptr,
BLOCK_SIZE_M: triton.language.constexpr,
BLOCK: triton.language.constexpr,
)#

Fill expert_ids with the expert index for each BLOCK_SIZE_M block.

Grid: one CTA per expert (parallelised across experts). Inner loop uses vectorised stores of BLOCK elements at a time.

core.inference.moe.vllm_fused_moe._moe_align_block_size_cuda_graphable(
routing_map: torch.Tensor,
block_size: int,
num_local_experts: int,
local_expert_start: int,
valid_tokens: torch.Tensor,
) tuple[torch.Tensor, torch.Tensor, torch.Tensor]#

Build indirection tables for the vLLM kernel, fully on-device.

Replaces the original _moe_align_block_size, which used .item() calls and host-side loops. All buffers are allocated at fixed maximum sizes, so the function is safe for CUDA graph capture.

Parameters:
  • routing_map – [max_tokens, topk] expert assignments.

  • block_size – BLOCK_SIZE_M for the vLLM kernel.

  • num_local_experts – experts on this rank.

  • local_expert_start – first global expert index on this rank.

  • valid_tokens – scalar int32 CUDA tensor.

Returns:
  • sorted_token_ids – [max_sorted] int32 indirection table.

  • expert_ids – [max_blocks] int32 expert index per block.

  • num_tokens_post_padded – [1] int32 (local expert padded count).

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor]
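
For intuition, an eager host-side reference of the same computation (deliberately not CUDA-graph safe; the SENTINEL value and buffer sizing are assumptions based on the worst case):

```python
import torch

def moe_align_reference(routing_map, block_size, num_local_experts,
                        local_expert_start, valid_tokens):
    max_tokens, topk = routing_map.shape
    SENTINEL = max_tokens * topk  # any out-of-range pair index works as padding
    # Worst case: every local expert's count rounds up by (block_size - 1).
    max_sorted = max_tokens * topk + num_local_experts * (block_size - 1)
    sorted_token_ids = torch.full((max_sorted,), SENTINEL, dtype=torch.int32)
    expert_ids = torch.full((max_sorted // block_size,), -1, dtype=torch.int32)
    n_valid = int(valid_tokens)  # host sync: fine here, forbidden in the real path
    out = 0
    for local_e in range(num_local_experts):
        pairs = (routing_map[:n_valid] == local_expert_start + local_e).nonzero()
        ids = (pairs[:, 0] * topk + pairs[:, 1]).to(torch.int32)  # token * topk + slot
        sorted_token_ids[out:out + ids.numel()] = ids
        padded = -(-ids.numel() // block_size) * block_size  # round up to a block
        expert_ids[out // block_size:(out + padded) // block_size] = local_e
        out += padded
    num_tokens_post_padded = torch.tensor([out], dtype=torch.int32)
    return sorted_token_ids, expert_ids, num_tokens_post_padded
```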

core.inference.moe.vllm_fused_moe._invoke_fused_moe_kernel(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
topk_weights: Optional[torch.Tensor],
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
block_size_m: int,
fuse_squared_relu: bool = False,
)#

Launch the persistent Triton fused-MoE kernel for one GEMM pass.

Uses a fixed grid of NUM_SMS CTAs for CUDA-graph safety. Each CTA loops over its share of tiles, with actual work determined device-side.
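
A short sketch of what a fixed-grid launch looks like; the NUM_SMS query is the standard torch API, while the launch line itself is schematic:

```python
import torch

# One CTA per SM; the grid never changes, so the launch can be captured in a
# CUDA graph. The kernel derives its real tile count on-device from
# num_tokens_post_padded.
NUM_SMS = torch.cuda.get_device_properties(
    torch.cuda.current_device()
).multi_processor_count

grid = (NUM_SMS,)
# _fused_moe_kernel[grid](..., num_sms=NUM_SMS, ...)  # schematic launch
```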

core.inference.moe.vllm_fused_moe._moe_sum_kernel(
input_ptr,
output_ptr,
topk_weights_ptr,
valid_tokens_ptr,
routing_map_ptr,
local_expert_start,
num_local_experts: triton.language.constexpr,
K,
topk: triton.language.constexpr,
BLOCK_K: triton.language.constexpr,
NUM_K_BLOCKS: triton.language.constexpr,
)#

Reduce the topk dimension with valid_tokens gating and routing-weight application.

input: [max_tokens * topk, K] bf16
output: [max_tokens, K] (dtype matches the output buffer, fp32 or bf16)

For token t < valid_tokens: output[t] = sum of input[t*topk + k] * prob[t*topk + k] over the topk slots k whose expert is local. Non-local slots are skipped (their values in input are undefined because FC2 only processes local-expert blocks). For token t >= valid_tokens: output[t] = 0. Routing-weight multiplication and accumulation are performed in fp32 for numerical accuracy.

core.inference.moe.vllm_fused_moe._moe_sum(
input: torch.Tensor,
topk_weights: torch.Tensor,
max_tokens: int,
topk: int,
K: int,
valid_tokens: torch.Tensor,
routing_map: torch.Tensor,
local_expert_start: int,
num_local_experts: int,
out: Optional[torch.Tensor] = None,
) torch.Tensor#

Fused topk reduction: [max_tokens*topk, K] bf16 → [max_tokens, K].

Applies routing weights and reduces over topk in a single kernel. Accumulates in fp32. When out is None, allocates and returns an fp32 buffer. When out is provided (e.g. the RSV symmetric memory tensor), writes directly into it — tl.store handles the cast to the buffer’s dtype. Rows beyond valid_tokens are zeroed. Only accumulates contributions from local experts; non-local topk slots are skipped (their values in input are undefined).
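
An eager reference of the same semantics, assuming routing_map[t, k] holds global expert ids as documented above:

```python
import torch

def moe_sum_reference(inp, probs, routing_map, valid_tokens,
                      local_expert_start, num_local_experts):
    max_tokens, topk = routing_map.shape
    K = inp.shape[1]
    out = torch.zeros(max_tokens, K, dtype=torch.float32, device=inp.device)
    n = int(valid_tokens)  # host sync; the real kernel reads this on-device
    local = ((routing_map[:n] >= local_expert_start) &
             (routing_map[:n] < local_expert_start + num_local_experts))
    contrib = inp[:n * topk].float().view(n, topk, K)
    # Mask non-local slots before weighting: their values in `inp` are
    # undefined, so multiplying by a zero weight alone could still yield NaN.
    contrib = contrib.masked_fill(~local.unsqueeze(-1), 0.0)
    out[:n] = (contrib * probs[:n].float().unsqueeze(-1)).sum(dim=1)
    return out  # rows n..max_tokens-1 stay zero, matching the kernel
```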

core.inference.moe.vllm_fused_moe.vllm_fused_moe(
hidden_states: torch.Tensor,
probs: torch.Tensor,
fc1_weight: torch.Tensor,
fc2_weight: torch.Tensor,
activation_type: megatron.core.inference.moe.fused_moe.ActivationType,
num_local_experts: int,
local_expert_start: int,
valid_tokens: torch.Tensor,
routing_map: torch.Tensor,
out: Optional[torch.Tensor] = None,
num_tokens_hint: Optional[int] = None,
) torch.Tensor#

Fused MoE using the vLLM Triton grouped-GEMM kernel (BF16).

CUDA-graph compatible: indirection tables are built entirely on-device using fixed-size buffers gated by valid_tokens.

Parameters:
  • hidden_states – [max_tokens, hidden_size] BF16 input. Only the first valid_tokens rows are valid; the rest are ignored.

  • probs – [max_tokens, topk] fp32 routing probabilities.

  • fc1_weight – [num_local_experts, fc1_out, hidden_size] BF16.

  • fc2_weight – [num_local_experts, hidden_size, fc1_out] BF16.

  • activation_type – ActivationType enum.

  • num_local_experts – experts on this rank.

  • local_expert_start – first global expert index on this rank.

  • valid_tokens – scalar int32 CUDA tensor with number of valid tokens.

  • routing_map – [max_tokens, topk] int expert assignments.

  • out – optional [max_tokens, hidden_size] output buffer (e.g. the RSV symmetric memory tensor). If None, an fp32 buffer is allocated. When provided, tl.store casts to the buffer’s dtype automatically.

  • num_tokens_hint – optional host-side int with the expected number of valid tokens (e.g. batch_size * ep_size). Used to select a better BLOCK_SIZE_M instead of using the worst-case buffer size.

Returns:

[max_tokens, hidden_size] output (fp32 when out=None, else out’s dtype). tl.store handles the implicit cast when out is a different dtype.
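
A hedged end-to-end call sketch. All sizes are illustrative, and the ActivationType member name is hypothetical; check fused_moe for the real enum values:

```python
import torch
from megatron.core.inference.moe.fused_moe import ActivationType
from megatron.core.inference.moe.vllm_fused_moe import vllm_fused_moe

max_tokens, hidden, topk = 256, 4096, 8            # illustrative sizes
num_experts, num_local_experts, local_expert_start = 64, 8, 16
fc1_out = 14336

dev = "cuda"
hidden_states = torch.randn(max_tokens, hidden, dtype=torch.bfloat16, device=dev)
probs = torch.rand(max_tokens, topk, dtype=torch.float32, device=dev)
fc1_weight = torch.randn(num_local_experts, fc1_out, hidden, dtype=torch.bfloat16, device=dev)
fc2_weight = torch.randn(num_local_experts, hidden, fc1_out, dtype=torch.bfloat16, device=dev)
routing_map = torch.randint(0, num_experts, (max_tokens, topk), dtype=torch.int32, device=dev)
valid_tokens = torch.tensor(128, dtype=torch.int32, device=dev)  # scalar, stays on-device

out = vllm_fused_moe(
    hidden_states, probs, fc1_weight, fc2_weight,
    activation_type=ActivationType.SQUARED_RELU,  # hypothetical member name
    num_local_experts=num_local_experts,
    local_expert_start=local_expert_start,
    valid_tokens=valid_tokens,
    routing_map=routing_map,
    num_tokens_hint=128,  # host-side hint for BLOCK_SIZE_M selection
)  # [max_tokens, hidden] fp32, since out=None
```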