core.inference.moe.vllm_fused_moe#
vLLM-style Triton fused MoE kernel (BF16) for Megatron inference.
CUDA-graph compatible: all indirection table construction happens on-device via Triton kernels with fixed-size buffers and valid_tokens gating.
Module Contents#
Functions#
_select_block_size_m: Select BLOCK_SIZE_M based on the token buffer size.
_fused_moe_kernel: Persistent fused MoE grouped GEMM with indirect token addressing.
_init_sorted_ids_kernel: Initialize sorted_token_ids to SENTINEL and expert_ids to -1.
_scatter_token_indices_kernel: Scatter local-expert pair indices into the padded indirection table.
_fill_expert_block_ids_kernel: Fill expert_ids with expert index for each BLOCK_SIZE_M block.
_moe_align_block_size_cuda_graphable: Build indirection tables for the vLLM kernel, fully on-device.
_invoke_fused_moe_kernel: Launch the persistent Triton fused-MoE kernel for one GEMM pass.
_moe_sum_kernel: Reduce topk dimension with valid_tokens gating and routing weight application.
_moe_sum: Fused topk reduction: [max_tokens*topk, K] bf16 → [max_tokens, K].
vllm_fused_moe: Fused MoE using the vLLM Triton grouped-GEMM kernel (BF16).
Data#
API#
- core.inference.moe.vllm_fused_moe._select_block_size_m(max_tokens: int) int#
Select BLOCK_SIZE_M based on the token buffer size.
Smaller tiles reduce padding waste in the indirection table when each expert sees few tokens (decode). Larger tiles improve compute density for large batches (prefill). Minimum is 16 (tl.dot requirement on NVIDIA).
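The selection heuristic itself is not reproduced in this reference; a minimal sketch of the documented trade-off might look like the following (the thresholds are illustrative assumptions, not the module's actual cut-offs):

```python
def select_block_size_m(max_tokens: int) -> int:
    """Pick a BLOCK_SIZE_M tile height for the indirection table.

    The thresholds below are hypothetical; the real heuristic lives in
    _select_block_size_m. 16 is the floor because tl.dot on NVIDIA
    requires tiles of at least 16.
    """
    if max_tokens <= 256:    # decode-like: few tokens per expert
        return 16
    if max_tokens <= 4096:   # mid-sized batches
        return 32
    return 64                # prefill-like: favor compute density
```

Smaller return values shrink per-expert padding in the indirection table; larger ones amortize tile overhead when every expert is saturated.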
- core.inference.moe.vllm_fused_moe._AUTOTUNE_CONFIGS#
None
- core.inference.moe.vllm_fused_moe._fused_moe_kernel(
- a_ptr,
- b_ptr,
- c_ptr,
- topk_weights_ptr,
- sorted_token_ids_ptr,
- expert_ids_ptr,
- num_tokens_post_padded_ptr,
- N,
- K,
- num_valid_tokens,
- num_sms,
- stride_am,
- stride_ak,
- stride_be,
- stride_bk,
- stride_bn,
- stride_cm,
- stride_cn,
- MUL_ROUTED_WEIGHT: triton.language.constexpr,
- FUSE_SQUARED_RELU: triton.language.constexpr,
- top_k: triton.language.constexpr,
- BLOCK_SIZE_M: triton.language.constexpr,
- BLOCK_SIZE_N: triton.language.constexpr,
- BLOCK_SIZE_K: triton.language.constexpr,
- GROUP_SIZE_M: triton.language.constexpr,
Persistent fused MoE grouped GEMM with indirect token addressing.
Launches a fixed grid of num_sms CTAs. Each CTA loops over its share of tiles, with total tile count determined device-side from num_tokens_post_padded. This decouples grid size from buffer size, keeping the kernel CUDA-graph safe while avoiding excess CTA overhead.
- core.inference.moe.vllm_fused_moe._ceil_div(a, b)#
- core.inference.moe.vllm_fused_moe._init_sorted_ids_kernel(
- sorted_token_ids_ptr,
- expert_ids_ptr,
- max_sorted,
- max_blocks,
- SENTINEL: triton.language.constexpr,
- BLOCK: triton.language.constexpr,
Initialize sorted_token_ids to SENTINEL and expert_ids to -1.
- core.inference.moe.vllm_fused_moe._scatter_token_indices_kernel(
- routing_map_ptr,
- sorted_token_ids_ptr,
- counters_ptr,
- valid_tokens_ptr,
- topk: triton.language.constexpr,
- local_expert_start,
- num_local_experts: triton.language.constexpr,
- max_pairs,
- BLOCK_SIZE: triton.language.constexpr,
Scatter local-expert pair indices into the padded indirection table.
Only local expert pairs are written; non-local pairs are skipped (the _moe_sum kernel handles them by checking the routing map directly).
- core.inference.moe.vllm_fused_moe._fill_expert_block_ids_kernel(
- expert_ids_ptr,
- exclusive_offsets_ptr,
- inclusive_offsets_ptr,
- BLOCK_SIZE_M: triton.language.constexpr,
- BLOCK: triton.language.constexpr,
Fill expert_ids with expert index for each BLOCK_SIZE_M block.
Grid: one CTA per expert (parallelised across experts). Inner loop uses vectorised stores of BLOCK elements at a time.
- core.inference.moe.vllm_fused_moe._moe_align_block_size_cuda_graphable(
- routing_map: torch.Tensor,
- block_size: int,
- num_local_experts: int,
- local_expert_start: int,
- valid_tokens: torch.Tensor,
Build indirection tables for the vLLM kernel, fully on-device.
Replaces the original _moe_align_block_size which used .item() calls and host-side loops. All buffers are allocated at fixed max sizes so the function is safe for CUDA graph capture.
- Parameters:
routing_map – [max_tokens, topk] expert assignments.
block_size – BLOCK_SIZE_M for the vLLM kernel.
num_local_experts – experts on this rank.
local_expert_start – first global expert index on this rank.
valid_tokens – scalar int32 CUDA tensor.
- Returns:
sorted_token_ids ([max_sorted] int32 indirection table), expert_ids ([max_blocks] int32, one expert index per block), and num_tokens_post_padded ([1] int32, local-expert padded count).
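For intuition, the table construction can be expressed as a host-side reference (pure Python, hypothetical helper; the real version runs entirely in Triton over fixed-size buffers):

```python
def align_block_size_reference(routing_map, block_size, num_local_experts,
                               local_expert_start, sentinel):
    """Reference for the indirection tables the Triton kernels build.

    routing_map: per-token expert lists, [num_tokens][topk].
    Returns (sorted_token_ids, expert_ids, num_tokens_post_padded).
    sorted_token_ids holds flattened (token * topk + slot) pair indices,
    grouped by local expert and padded per expert to a multiple of
    block_size with `sentinel` (padded rows are skipped by the GEMM).
    """
    topk = len(routing_map[0])
    per_expert = [[] for _ in range(num_local_experts)]
    for t, experts in enumerate(routing_map):
        for k, e in enumerate(experts):
            le = e - local_expert_start
            if 0 <= le < num_local_experts:   # non-local pairs are skipped
                per_expert[le].append(t * topk + k)
    sorted_token_ids, expert_ids = [], []
    for le, pairs in enumerate(per_expert):
        padded = -(-len(pairs) // block_size) * block_size  # ceil to block
        sorted_token_ids += pairs + [sentinel] * (padded - len(pairs))
        expert_ids += [le] * (padded // block_size)  # one expert per block
    return sorted_token_ids, expert_ids, len(sorted_token_ids)
```

The on-device version replaces the Python loops with the _scatter_token_indices_kernel / _fill_expert_block_ids_kernel pair and atomic per-expert counters, so no .item() synchronization is needed.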
- core.inference.moe.vllm_fused_moe._invoke_fused_moe_kernel(
- A: torch.Tensor,
- B: torch.Tensor,
- C: torch.Tensor,
- topk_weights: Optional[torch.Tensor],
- sorted_token_ids: torch.Tensor,
- expert_ids: torch.Tensor,
- num_tokens_post_padded: torch.Tensor,
- mul_routed_weight: bool,
- top_k: int,
- block_size_m: int,
- fuse_squared_relu: bool = False,
Launch the persistent Triton fused-MoE kernel for one GEMM pass.
Uses a fixed grid of NUM_SMS CTAs for CUDA-graph safety. Each CTA loops over its share of tiles, with actual work determined device-side.
- core.inference.moe.vllm_fused_moe._moe_sum_kernel(
- input_ptr,
- output_ptr,
- topk_weights_ptr,
- valid_tokens_ptr,
- routing_map_ptr,
- local_expert_start,
- num_local_experts: triton.language.constexpr,
- K,
- topk: triton.language.constexpr,
- BLOCK_K: triton.language.constexpr,
- NUM_K_BLOCKS: triton.language.constexpr,
Reduce topk dimension with valid_tokens gating and routing weight application.
input: [max_tokens * topk, K] bf16. output: [max_tokens, K]; dtype matches the output buffer (fp32 or bf16).
For token t < valid_tokens: output[t] = sum of input[t*topk + k] * prob[t*topk + k] over the topk slots k whose expert is local. Non-local slots are skipped (their values in input are undefined because FC2 only processes local-expert blocks). For token t >= valid_tokens: output[t] = 0. Routing-weight multiplication and accumulation happen in fp32 for numerical accuracy.
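These semantics can be written as a pure-Python reference (hypothetical helper; the Triton kernel vectorises the inner loop over K in BLOCK_K chunks):

```python
def moe_sum_reference(inp, probs, routing_map, valid_tokens,
                      local_expert_start, num_local_experts, topk):
    """Reference for the gated topk reduction.

    inp: [max_tokens * topk] rows of length K (list of lists).
    probs, routing_map: [max_tokens][topk].
    Rows belonging to non-local experts are undefined and must not be
    read; output rows at t >= valid_tokens are zeroed.
    """
    max_tokens = len(routing_map)
    K = len(inp[0])
    out = [[0.0] * K for _ in range(max_tokens)]
    for t in range(valid_tokens):
        for k in range(topk):
            le = routing_map[t][k] - local_expert_start
            if 0 <= le < num_local_experts:  # skip non-local slots
                w = probs[t][k]              # routing weight, fp32
                row = inp[t * topk + k]
                for j in range(K):
                    out[t][j] += w * row[j]
    return out
```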
- core.inference.moe.vllm_fused_moe._moe_sum(
- input: torch.Tensor,
- topk_weights: torch.Tensor,
- max_tokens: int,
- topk: int,
- K: int,
- valid_tokens: torch.Tensor,
- routing_map: torch.Tensor,
- local_expert_start: int,
- num_local_experts: int,
- out: Optional[torch.Tensor] = None,
Fused topk reduction: [max_tokens*topk, K] bf16 → [max_tokens, K].
Applies routing weights and reduces over topk in a single kernel. Accumulates in fp32. When out is None, allocates and returns an fp32 buffer. When out is provided (e.g. the RSV symmetric memory tensor), writes directly into it; tl.store handles the cast to the buffer's dtype. Rows beyond valid_tokens are zeroed. Only contributions from local experts are accumulated; non-local topk slots are skipped (their values in input are undefined).
- core.inference.moe.vllm_fused_moe.vllm_fused_moe(
- hidden_states: torch.Tensor,
- probs: torch.Tensor,
- fc1_weight: torch.Tensor,
- fc2_weight: torch.Tensor,
- activation_type: megatron.core.inference.moe.fused_moe.ActivationType,
- num_local_experts: int,
- local_expert_start: int,
- valid_tokens: torch.Tensor,
- routing_map: torch.Tensor,
- out: Optional[torch.Tensor] = None,
- num_tokens_hint: Optional[int] = None,
Fused MoE using the vLLM Triton grouped-GEMM kernel (BF16).
CUDA-graph compatible: indirection tables are built entirely on-device using fixed-size buffers gated by valid_tokens.
- Parameters:
hidden_states – [max_tokens, hidden_size] BF16 input. Only the first valid_tokens rows are valid; the rest are ignored.
probs – [max_tokens, topk] fp32 routing probabilities.
fc1_weight – [num_local_experts, fc1_out, hidden_size] BF16.
fc2_weight – [num_local_experts, hidden_size, fc1_out] BF16.
activation_type – ActivationType enum.
num_local_experts – experts on this rank.
local_expert_start – first global expert index on this rank.
valid_tokens – scalar int32 CUDA tensor with number of valid tokens.
routing_map – [max_tokens, topk] int expert assignments.
out – optional [max_tokens, hidden_size] output buffer (e.g. the RSV symmetric memory tensor). If None, an fp32 buffer is allocated. When provided, tl.store casts to the buffer’s dtype automatically.
num_tokens_hint – optional host-side int with the expected number of valid tokens (e.g. batch_size * ep_size). Used to select a better BLOCK_SIZE_M instead of using the worst-case buffer size.
- Returns:
[max_tokens, hidden_size] output (fp32 when out=None, else out’s dtype). tl.store handles the implicit cast when out is a different dtype.
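End to end, the computation this path fuses is equivalent to a dense per-token reference. The sketch below (pure Python, hypothetical names) uses squared ReLU as the activation, matching the kernel's FUSE_SQUARED_RELU option; the real module selects the activation from the ActivationType enum:

```python
def fused_moe_reference(hidden, probs, routing_map, fc1_w, fc2_w,
                        valid_tokens, local_expert_start, num_local_experts):
    """Dense MoE reference: out[t] = sum over local topk slots of
    prob * (FC2 @ act(FC1 @ x)), with squared-ReLU activation.

    hidden: [max_tokens][H]; fc1_w: [E][F][H]; fc2_w: [E][H][F].
    """
    def matvec(w, x):  # w: [rows][cols], x: [cols]
        return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

    def squared_relu(v):
        return [max(x, 0.0) ** 2 for x in v]

    H = len(hidden[0])
    out = [[0.0] * H for _ in hidden]
    for t in range(valid_tokens):
        for k, e in enumerate(routing_map[t]):
            le = e - local_expert_start
            if not (0 <= le < num_local_experts):
                continue  # non-local expert: handled on another rank
            h = squared_relu(matvec(fc1_w[le], hidden[t]))
            y = matvec(fc2_w[le], h)
            for j in range(H):
                out[t][j] += probs[t][k] * y[j]
    return out
```

Rows past valid_tokens stay zero, mirroring the valid_tokens gating performed by _moe_sum.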