core.inference.moe.vllm_fused_moe
vLLM-style Triton fused MoE kernel (BF16) for Megatron inference.
CUDA-graph compatible: all indirection table construction happens on-device via Triton kernels with fixed-size buffers and valid_tokens gating.
Module Contents
Functions
_get_default_config | Pick BLOCK_SIZE_*, GROUP_SIZE_M, num_warps, num_stages from M, E, top_k.
_fused_moe_kernel | Fused MoE grouped GEMM with indirect token addressing.
_init_sorted_ids_kernel | Initialize sorted_token_ids to SENTINEL and expert_ids to -1.
_scatter_token_indices_kernel | Scatter local-expert pair indices into the padded indirection table.
_fill_expert_block_ids_kernel | Fill expert_ids with expert index for each BLOCK_SIZE_M block.
_moe_align_block_size_cuda_graphable | Build indirection tables for the vLLM kernel, fully on-device.
_invoke_fused_moe_kernel | Launch the Triton fused-MoE kernel for one GEMM pass.
_moe_sum_kernel | Reduce topk dimension with routing weight application.
_moe_sum | Fused topk reduction: [max_tokens*topk, K] bf16 → [max_tokens, K].
vllm_fused_moe | Fused MoE using the vLLM Triton grouped-GEMM kernel (BF16).
API
- core.inference.moe.vllm_fused_moe._get_default_config(M: int, E: int, top_k: int) dict#
Pick BLOCK_SIZE_*, GROUP_SIZE_M, num_warps, num_stages from M, E, top_k.
Mirrors vLLM’s get_default_config (bf16/fp16 branch) verbatim: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py
M here is the host-side token-count hint (num_tokens_hint in vllm_fused_moe), NOT hidden_states.size(0). The hint is the expected per-step token count; using the worst-case buffer size instead would tune for prefill on every decode step.
Two intuitions drive the choices:
Small M is memory-bound (favor tall/narrow tiles, more pipeline stages); large M is compute-bound (favor short/wide tiles, more warps).
Padding tax dominates at small M: the indirection table pads M-tiles per expert, so small M-tiles minimize wasted rows.
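The padding-tax intuition can be made concrete with a small host-side calculation. The sketch below is illustrative only: `padded_rows`, the routing patterns, and the tile sizes 16 and 64 are assumptions chosen for the example, not values taken from `_get_default_config`.

```python
from collections import Counter


def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division."""
    return (a + b - 1) // b


def padded_rows(expert_per_pair, block_size_m: int) -> int:
    """Rows occupied in the indirection table when each expert's segment
    is padded up to a multiple of block_size_m (hypothetical helper)."""
    counts = Counter(expert_per_pair)
    return sum(ceil_div(c, block_size_m) * block_size_m for c in counts.values())


# Decode-like step (assumed): 4 tokens, top_k = 2, and all 8 (token, slot)
# pairs land on distinct experts.
decode_pairs = list(range(8))
decode_small_tile = padded_rows(decode_pairs, 16)  # 8 experts * 16 rows
decode_large_tile = padded_rows(decode_pairs, 64)  # 8 experts * 64 rows

# Prefill-like step (assumed): 4096 pairs spread evenly over 8 experts.
prefill_pairs = [i % 8 for i in range(4096)]
prefill_small_tile = padded_rows(prefill_pairs, 16)
prefill_large_tile = padded_rows(prefill_pairs, 64)

print("decode :", decode_small_tile, "vs", decode_large_tile, "rows for 8 useful rows")
print("prefill:", prefill_small_tile, "vs", prefill_large_tile, "rows for 4096 useful rows")
```

In the decode-like case the larger M-tile quadruples the number of padded rows, while in the prefill-like case both tile sizes pad to the same length, which is consistent with preferring small M-tiles only when M is small.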
- core.inference.moe.vllm_fused_moe._fused_moe_kernel(
- a_ptr,
- b_ptr,
- c_ptr,
- topk_weights_ptr,
- sorted_token_ids_ptr,
- expert_ids_ptr,
- num_tokens_post_padded_ptr,
- N,
- K,
- num_valid_tokens,
- stride_am,
- stride_ak,
- stride_be,
- stride_bk,
- stride_bn,
- stride_cm,
- stride_cn,
- MUL_ROUTED_WEIGHT: triton.language.constexpr,
- FUSE_SQUARED_RELU: triton.language.constexpr,
- top_k: triton.language.constexpr,
- BLOCK_SIZE_M: triton.language.constexpr,
- BLOCK_SIZE_N: triton.language.constexpr,
- BLOCK_SIZE_K: triton.language.constexpr,
- GROUP_SIZE_M: triton.language.constexpr,
Fused MoE grouped GEMM with indirect token addressing.
Body mirrors vLLM’s fused_moe_kernel verbatim except for the FUSE_SQUARED_RELU branch (Megatron applies relu+square in fp32 on the accumulator before the bf16 cast; strictly more accurate than upstream’s separate post-FC1 activation kernel).
Grid is sized host-side from num_tokens_hint (the typical-case token count), not the worst-case buffer length, so launch overhead at decode stays small. When the actual padded length exceeds the hinted grid size (rare prefill spikes), each CTA strides over multiple tiles via the outer tl.range loop.
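The striding behavior described above can be modeled on the host in plain Python. This is a sketch under stated assumptions: the formula used for `grid_size`, the helper `tiles_for_cta`, and all sizes are hypothetical, and the real kernel reads the actual padded length on-device rather than from a Python integer.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division."""
    return (a + b - 1) // b


def tiles_for_cta(pid: int, grid_size: int, num_tiles: int) -> list:
    """Tiles visited by one CTA when it strides over the tile space in
    steps of grid_size (a host-side model of an outer strided loop)."""
    return list(range(pid, num_tiles, grid_size))


# Illustrative sizes (assumptions, not values from the module).
BLOCK_SIZE_M, BLOCK_SIZE_N, N = 16, 64, 256
hint_padded_rows = 64     # padded length implied by the host-side hint
actual_padded_rows = 160  # padded length actually seen on-device (a spike)

# Assumed sizing: one CTA per tile of the hinted problem.
grid_size = ceil_div(hint_padded_rows, BLOCK_SIZE_M) * ceil_div(N, BLOCK_SIZE_N)
num_tiles = ceil_div(actual_padded_rows, BLOCK_SIZE_M) * ceil_div(N, BLOCK_SIZE_N)

schedule = {pid: tiles_for_cta(pid, grid_size, num_tiles) for pid in range(grid_size)}
covered = sorted(t for tiles in schedule.values() for t in tiles)

print("grid_size:", grid_size, "num_tiles:", num_tiles)
print("CTA 0 visits:", schedule[0])
```

Every tile is visited exactly once even though the grid was sized for the smaller hinted problem; when the hint is accurate, each CTA handles a single tile and the loop runs once.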
- core.inference.moe.vllm_fused_moe._ceil_div(a, b)#
- core.inference.moe.vllm_fused_moe._init_sorted_ids_kernel(
- sorted_token_ids_ptr,
- expert_ids_ptr,
- max_sorted,
- max_blocks,
- SENTINEL: triton.language.constexpr,
- BLOCK: triton.language.constexpr,
Initialize sorted_token_ids to SENTINEL and expert_ids to -1.
- core.inference.moe.vllm_fused_moe._scatter_token_indices_kernel(
- routing_map_ptr,
- sorted_token_ids_ptr,
- counters_ptr,
- valid_tokens_ptr,
- topk: triton.language.constexpr,
- local_expert_start,
- num_local_experts: triton.language.constexpr,
- max_pairs,
- BLOCK_SIZE: triton.language.constexpr,
Scatter local-expert pair indices into the padded indirection table.
Only local expert pairs are written; non-local pairs are skipped (the _moe_sum kernel handles them by checking the routing map directly).
- core.inference.moe.vllm_fused_moe._fill_expert_block_ids_kernel(
- expert_ids_ptr,
- exclusive_offsets_ptr,
- inclusive_offsets_ptr,
- BLOCK_SIZE_M: triton.language.constexpr,
- BLOCK: triton.language.constexpr,
Fill expert_ids with expert index for each BLOCK_SIZE_M block.
Grid: one CTA per expert (parallelised across experts). Inner loop uses vectorised stores of BLOCK elements at a time.
- core.inference.moe.vllm_fused_moe._moe_align_block_size_cuda_graphable(
- routing_map: torch.Tensor,
- block_size: int,
- num_local_experts: int,
- local_expert_start: int,
- valid_tokens: torch.Tensor,
Build indirection tables for the vLLM kernel, fully on-device.
Replaces the original _moe_align_block_size which used .item() calls and host-side loops. All buffers are allocated at fixed max sizes so the function is safe for CUDA graph capture.
- Parameters:
routing_map – [max_tokens, topk] expert assignments.
block_size – BLOCK_SIZE_M for the vLLM kernel.
num_local_experts – experts on this rank.
local_expert_start – first global expert index on this rank.
valid_tokens – scalar int32 CUDA tensor.
- Returns:
sorted_token_ids – [max_sorted] int32 indirection table.
expert_ids – [max_blocks] int32 expert per block.
num_tokens_post_padded – [1] int32 (local expert padded count).
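To illustrate what the three returned tables contain, here is a host-side reference sketch in plain Python. It is not the on-device implementation: the sentinel value (taken as max_tokens * topk), the formula for max_sorted, and the use of local expert indices in expert_ids are all assumptions made for the example.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division."""
    return (a + b - 1) // b


def align_block_size_reference(routing_map, block_size, num_local_experts,
                               local_expert_start, valid_tokens):
    """Hypothetical host-side model of the indirection tables."""
    max_tokens, topk = len(routing_map), len(routing_map[0])
    max_pairs = max_tokens * topk
    sentinel = max_pairs  # assumption: any index >= max_pairs marks padding
    # Assumed worst case: each local expert wastes up to block_size - 1 rows.
    max_sorted = ceil_div(max_pairs + num_local_experts * (block_size - 1),
                          block_size) * block_size
    sorted_token_ids = [sentinel] * max_sorted
    expert_ids = [-1] * (max_sorted // block_size)

    # Bucket the flat pair index t * topk + k of every valid, local pair.
    buckets = [[] for _ in range(num_local_experts)]
    for t in range(valid_tokens):
        for k in range(topk):
            local = routing_map[t][k] - local_expert_start
            if 0 <= local < num_local_experts:
                buckets[local].append(t * topk + k)

    # Lay the buckets out back to back, each padded to a block multiple.
    offset = 0
    for e, bucket in enumerate(buckets):
        sorted_token_ids[offset:offset + len(bucket)] = bucket
        n_blocks = ceil_div(len(bucket), block_size)
        for b in range(n_blocks):
            expert_ids[offset // block_size + b] = e
        offset += n_blocks * block_size
    return sorted_token_ids, expert_ids, offset


routing_map = [[0, 2], [2, 3], [1, 2], [9, 9]]  # last row is beyond valid_tokens
sorted_ids, expert_ids, num_post_padded = align_block_size_reference(
    routing_map, block_size=2, num_local_experts=2,
    local_expert_start=2, valid_tokens=3)
print(sorted_ids, expert_ids, num_post_padded)
```

In this example the rank owns global experts 2 and 3. Pairs routed to experts 0 and 1 are skipped, the invalid fourth token never appears, and unused blocks keep the -1 marker.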
- core.inference.moe.vllm_fused_moe._invoke_fused_moe_kernel(
- A: torch.Tensor,
- B: torch.Tensor,
- C: torch.Tensor,
- topk_weights: Optional[torch.Tensor],
- sorted_token_ids: torch.Tensor,
- expert_ids: torch.Tensor,
- num_tokens_post_padded: torch.Tensor,
- mul_routed_weight: bool,
- top_k: int,
- config: dict,
- grid_size: int,
- fuse_squared_relu: bool = False,
Launch the Triton fused-MoE kernel for one GEMM pass.
Body matches upstream vLLM fused_moe_kernel (1 CTA per (pid_m, pid_n) tile, raw pointer arithmetic with % N on the N axis), apart from the optional fused squared-relu activation in fp32.
grid_size is sized host-side from num_tokens_hint so launch overhead at decode is small. When the actual padded length exceeds the hinted grid size, each CTA strides over additional tiles via the kernel’s outer tl.range. The full launch config (tile sizes, warps, stages) is picked host-side by _get_default_config from M = num_tokens_hint.
- core.inference.moe.vllm_fused_moe._moe_sum_kernel(
- input_ptr,
- output_ptr,
- topk_weights_ptr,
- valid_tokens_ptr,
- routing_map_ptr,
- local_expert_start,
- num_local_experts: triton.language.constexpr,
- K,
- topk: triton.language.constexpr,
- BLOCK_M: triton.language.constexpr,
- BLOCK_K: triton.language.constexpr,
- NUM_K_BLOCKS: triton.language.constexpr,
Reduce topk dimension with routing weight application.
input: [max_tokens * topk, K] bf16
output: [max_tokens, K]; dtype matches the output buffer (fp32 or bf16)
For token t < valid_tokens: output[t] = sum of input[t*topk+k] * prob[t*topk+k] over topk slots k where the expert is local. Non-local slots are skipped (their values in input are undefined because FC2 only processes local-expert blocks). Rows for t >= valid_tokens are not written; downstream consumers (e.g. reduce-scatter-v) only read the first valid_tokens rows. Routing weight multiplication and accumulation are done in fp32 for numerical accuracy.
Persistent grid: launches BLOCK_M CTAs that stride over valid_tokens. CUDA-graph safe (grid is static); the loop bound is loaded device-side.
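The reduction rule can be written as a short host-side reference in plain Python. This is a sketch, not the Triton kernel: `moe_sum_reference` is a hypothetical name, nested lists stand in for tensors, and Python floats (fp64) stand in for the fp32 accumulator.

```python
def moe_sum_reference(inp, probs, routing_map, valid_tokens,
                      local_expert_start, num_local_experts, out):
    """Hypothetical reference: weighted sum over local top-k slots.

    inp:   [max_tokens * topk][K] rows indexed by t * topk + k
    probs: [max_tokens][topk] routing weights
    out:   [max_tokens][K]; only the first valid_tokens rows are written
    """
    topk = len(routing_map[0])
    K = len(inp[0])
    for t in range(valid_tokens):
        acc = [0.0] * K
        for k in range(topk):
            local = routing_map[t][k] - local_expert_start
            if not 0 <= local < num_local_experts:
                continue  # non-local slot: its input row is undefined
            row, w = inp[t * topk + k], probs[t][k]
            for j in range(K):
                acc[j] += row[j] * w
        out[t] = acc
    return out


nan = float("nan")  # stands in for an undefined, non-local row
inp = [[1.0, 2.0], [3.0, 4.0],   # token 0, slots 0 and 1 (both local)
       [5.0, 6.0], [nan, nan],   # token 1, slot 1 is non-local
       [7.0, 8.0], [9.0, 10.0]]  # token 2 is beyond valid_tokens
probs = [[0.5, 0.5], [0.25, 0.75], [1.0, 0.0]]
routing_map = [[0, 1], [1, 2], [0, 0]]
out = [[-1.0, -1.0] for _ in range(3)]

moe_sum_reference(inp, probs, routing_map, valid_tokens=2,
                  local_expert_start=0, num_local_experts=2, out=out)
print(out)
```

The NaN row never reaches the accumulator because its slot is non-local, and the third output row keeps its initial contents because it lies beyond valid_tokens.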
- core.inference.moe.vllm_fused_moe._moe_sum(
- input: torch.Tensor,
- topk_weights: torch.Tensor,
- max_tokens: int,
- topk: int,
- K: int,
- valid_tokens: torch.Tensor,
- routing_map: torch.Tensor,
- local_expert_start: int,
- num_local_experts: int,
- out: Optional[torch.Tensor] = None,
Fused topk reduction: [max_tokens*topk, K] bf16 → [max_tokens, K].
Applies routing weights and reduces over topk in a single kernel. Accumulates in fp32. When out is None, allocates and returns an fp32 buffer. When out is provided (e.g. the RSV symmetric memory tensor), writes directly into it; tl.store handles the cast to the buffer’s dtype. Only writes the first valid_tokens rows; rows beyond are left untouched (downstream RSV reads only the valid range). Only accumulates contributions from local experts; non-local topk slots are skipped (their values in input are undefined).
- core.inference.moe.vllm_fused_moe.vllm_fused_moe(
- hidden_states: torch.Tensor,
- probs: torch.Tensor,
- fc1_weight: torch.Tensor,
- fc2_weight: torch.Tensor,
- activation_type: megatron.core.inference.moe.fused_moe.ActivationType,
- num_local_experts: int,
- local_expert_start: int,
- valid_tokens: torch.Tensor,
- routing_map: torch.Tensor,
- out: Optional[torch.Tensor] = None,
- num_tokens_hint: Optional[int] = None,
Fused MoE using the vLLM Triton grouped-GEMM kernel (BF16).
CUDA-graph compatible: indirection tables are built entirely on-device using fixed-size buffers gated by valid_tokens.
- Parameters:
hidden_states – [max_tokens, hidden_size] BF16 input. Only the first valid_tokens rows are valid; the rest are ignored.
probs – [max_tokens, topk] fp32 routing probabilities.
fc1_weight – [num_local_experts, fc1_out, hidden_size] BF16.
fc2_weight – [num_local_experts, hidden_size, fc1_out] BF16.
activation_type – ActivationType enum.
num_local_experts – experts on this rank.
local_expert_start – first global expert index on this rank.
valid_tokens – scalar int32 CUDA tensor with number of valid tokens.
routing_map – [max_tokens, topk] int expert assignments.
out – optional [max_tokens, hidden_size] output buffer (e.g. the RSV symmetric memory tensor). If None, an fp32 buffer is allocated. When provided, tl.store casts to the buffer’s dtype automatically.
num_tokens_hint – optional host-side int with the expected number of valid tokens (e.g. batch_size * ep_size). Used to select a better BLOCK_SIZE_M instead of using the worst-case buffer size.
- Returns:
[max_tokens, hidden_size] output (fp32 when out=None, else out’s dtype). tl.store handles the implicit cast when out is a different dtype.