core.inference.moe.permute
Triton kernels for token permutation and unpermutation in fused MoE.
Includes:
- Token counting per expert
- Expert offset computation (aligned prefix sums)
- Permuting tokens into expert-grouped order
- Unpermuting expert outputs back to original token order
Module Contents
Functions
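| Function | Description |
| --- | --- |
| _count_local_tokens_kernel | Count tokens routed to experts on this rank, ignoring tokens routed elsewhere. |
| _count_local_tokens_kernel_persistent | Count tokens routed to local experts using a persistent grid. |
| compute_local_tokens_per_expert | Count tokens routed to each local expert. |
| _prefix_sum_kernel | Exclusive and inclusive prefix sums of aligned token counts. |
| _init_permutation_map_kernel | Initialize permutation_map entries to -1 up to n_used rows. |
| init_permutation_map | Fill permutation_map[0:n_used] with -1. |
| compute_expert_offsets | Compute exclusive and inclusive prefix sums of aligned token counts. |
| _permute_tokens_kernel | Permute tokens into expert-grouped order. |
| permute_tokens | Permute tokens into expert-grouped order. |
| _zero_output_rows_kernel | Zero rows [0, valid_tokens) of the fp32 output buffer. |
| _unpermute_tokens_kernel | Scatter weighted expert outputs back to original token positions. |
| unpermute_tokens | Unpermute expert outputs back to original token order. |
| _permute_quantize_mxfp8_kernel | Fused permute + MXFP8 quantize + swizzle in one kernel. |
| permute_and_quantize_mxfp8 | Fused permute + MXFP8 quantize + swizzle. |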
Data
API
- core.inference.moe.permute._NUM_SMS: Optional[int] = None
- core.inference.moe.permute._get_num_sms(device: torch.device) → int
- core.inference.moe.permute._ceil_div(a, b)
- core.inference.moe.permute._count_local_tokens_kernel(
    routing_map_ptr,
    tokens_per_expert_ptr,
    valid_tokens_ptr,
    topk,
    local_expert_start,
    num_local_experts: triton.language.constexpr,
    BLOCK_SIZE: triton.language.constexpr,
  )
Count tokens routed to experts on this rank, ignoring tokens routed elsewhere.
Each program processes BLOCK_SIZE (token, topk) pairs. Tokens assigned to experts outside [local_expert_start, local_expert_start + num_local_experts) or beyond valid_tokens are silently skipped.
Grid is launched at max size (max_tokens * topk); valid_tokens gates which pairs are actually processed — required for CUDA graph compatibility.
- core.inference.moe.permute._count_local_tokens_kernel_persistent(
    routing_map_ptr,
    tokens_per_expert_ptr,
    valid_tokens_ptr,
    topk,
    local_expert_start,
    num_local_experts: triton.language.constexpr,
    num_sms,
    BLOCK_SIZE: triton.language.constexpr,
  )
Count tokens routed to local experts using a persistent grid.
Launches num_sms CTAs. Each CTA loops over its share of BLOCK_SIZE-sized chunks, with total work determined device-side from valid_tokens.
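The work split can be illustrated in plain Python. This is a minimal sketch, not the Triton kernel: it assumes chunks are dealt out to CTAs in a grid-stride pattern, which the description above does not confirm, and `chunks_for_cta` is a hypothetical helper introduced only for illustration.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division."""
    return -(-a // b)


def chunks_for_cta(cta_id, num_sms, valid_tokens, topk, block_size):
    """List the (token, topk) pair ranges one CTA would loop over.

    Assumption: chunks are dealt out in a grid-stride pattern, so CTA i
    takes chunks i, i + num_sms, i + 2 * num_sms, and so on.
    """
    total_pairs = valid_tokens * topk  # read on the device in the real kernel
    num_chunks = ceil_div(total_pairs, block_size)
    work = []
    for chunk in range(cta_id, num_chunks, num_sms):
        start = chunk * block_size
        work.append(range(start, min(start + block_size, total_pairs)))
    return work


# 5 valid tokens, topk=2 -> 10 pairs, split into chunks of 4 across 2 CTAs.
for cta in range(2):
    print(cta, chunks_for_cta(cta, num_sms=2, valid_tokens=5, topk=2, block_size=4))
```

Because the loop bound depends only on `valid_tokens`, the number of launched CTAs can stay fixed while the amount of work changes from step to step.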
- core.inference.moe.permute.compute_local_tokens_per_expert(
    routing_map: torch.Tensor,
    local_expert_start: int,
    num_local_experts: int,
    valid_tokens: torch.Tensor,
    persistent: bool = False,
  )
Count tokens routed to each local expert.
- Parameters:
routing_map – [max_tokens, topk] expert assignments. Only the first valid_tokens rows are processed; the rest are ignored.
local_expert_start – first global expert index on this rank.
num_local_experts – number of experts on this rank.
valid_tokens – scalar int32 CUDA tensor with the number of valid tokens this iteration. Fixed address; value updated each step before graph replay.
persistent – use persistent-grid kernel variant (fewer CTAs, looped).
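The counting semantics described by these parameters can be written as a pure-Python reference. This is only a sketch of the expected result on nested lists, not the Triton implementation; `count_local_tokens` is a hypothetical name, and `valid_tokens` is a plain int here rather than a CUDA tensor.

```python
def count_local_tokens(routing_map, local_expert_start, num_local_experts, valid_tokens):
    """Reference count of (token, topk) pairs per local expert."""
    counts = [0] * num_local_experts
    for row in routing_map[:valid_tokens]:      # rows past valid_tokens are ignored
        for expert in row:                      # one entry per top-k choice
            local = expert - local_expert_start
            if 0 <= local < num_local_experts:  # skip experts owned by other ranks
                counts[local] += 1
    return counts


# This rank owns global experts 4 and 5; the last row is padding.
routing_map = [[0, 5], [4, 5], [6, 1], [4, 4]]
counts = count_local_tokens(
    routing_map, local_expert_start=4, num_local_experts=2, valid_tokens=3
)
# counts == [1, 2]
```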
- core.inference.moe.permute._prefix_sum_kernel(
    tokens_per_expert_ptr,
    exclusive_offsets_ptr,
    inclusive_offsets_ptr,
    num_local_experts,
    alignment: triton.language.constexpr,
    BLOCK_SIZE: triton.language.constexpr,
  )
Exclusive and inclusive prefix sums of aligned token counts.
Each expert’s token count is rounded up to the nearest multiple of alignment (experts with 0 tokens stay at 0). The inclusive offsets are used as offs by grouped_mm / scaled_grouped_mm.
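A pure-Python reference for the aligned prefix sums may help make the rounding concrete. It is a sketch of the arithmetic only; `aligned_offsets` is a hypothetical name and operates on a plain list rather than a device tensor.

```python
def aligned_offsets(tokens_per_expert, alignment=1):
    """Return (exclusive, inclusive) prefix sums of aligned token counts."""
    exclusive, inclusive = [], []
    running = 0
    for count in tokens_per_expert:
        aligned = -(-count // alignment) * alignment  # round up; 0 stays 0
        exclusive.append(running)   # first row of this expert's segment
        running += aligned
        inclusive.append(running)   # one past the last row of the segment
    return exclusive, inclusive


exclusive, inclusive = aligned_offsets([3, 0, 130], alignment=128)
# exclusive == [0, 128, 128]
# inclusive == [128, 128, 384]
```

The last inclusive offset (384 here) is the total number of rows in use, including alignment padding.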
- core.inference.moe.permute._init_permutation_map_kernel(
    perm_map_ptr,
    n_used_ptr,
    BLOCK_SIZE: triton.language.constexpr,
  )
Initialize permutation_map entries to -1 up to n_used rows.
Grid is launched at max size; entries beyond n_used are left untouched — the activation and unpermute kernels are gated by the same n_used pointer so they never read those entries.
- core.inference.moe.permute.init_permutation_map(
    permutation_map: torch.Tensor,
    n_used: torch.Tensor,
  )
Fill permutation_map[0:n_used] with -1.
- Parameters:
permutation_map – [output_size] int32 buffer (pre-allocated at max size).
n_used – scalar int32 CUDA tensor = inclusive_expert_offsets[-1].
- core.inference.moe.permute.compute_expert_offsets(
    tokens_per_expert: torch.Tensor,
    alignment: int = 1,
  )
Compute exclusive and inclusive prefix sums of aligned token counts.
- core.inference.moe.permute._permute_tokens_kernel(
    hidden_ptr,
    probs_ptr,
    routing_map_ptr,
    out_hidden_ptr,
    out_probs_ptr,
    out_src_idx_ptr,
    counters_ptr,
    valid_tokens_ptr,
    hidden_dim,
    max_pairs,
    topk: triton.language.constexpr,
    local_expert_start,
    num_local_experts: triton.language.constexpr,
    BLOCK_H: triton.language.constexpr,
    NUM_BLOCKS: triton.language.constexpr,
  )
Permute tokens into expert-grouped order.
Grid: fixed NUM_BLOCKS CTAs, each iterating over multiple (token, topk) pairs. valid_tokens gates which pairs are actually processed — required for CUDA graph compatibility since the grid size never changes across steps.
- core.inference.moe.permute.permute_tokens(
    hidden_states: torch.Tensor,
    probs: torch.Tensor,
    routing_map: torch.Tensor,
    local_expert_start: int,
    num_local_experts: int,
    valid_tokens: torch.Tensor,
    alignment: int = 1,
  )
Permute tokens into expert-grouped order.
Computes token counts, aligned expert offsets, output sizing, and permutation in a single call.
- Parameters:
hidden_states – [max_tokens, hidden_size] input. Only the first valid_tokens rows are valid; the rest are ignored.
probs – [max_tokens, topk] routing probabilities.
routing_map – [max_tokens, topk] expert assignments.
local_expert_start – first global expert index on this rank.
num_local_experts – number of experts on this rank.
valid_tokens – scalar int32 CUDA tensor with the number of valid tokens this iteration. Fixed address; value updated each step before graph replay.
alignment – per-expert token alignment (default 1).
- Returns:
(permuted_hidden, permuted_probs, permutation_map, inclusive_offsets)
permuted_hidden: [output_size, hidden_size]
permuted_probs: [output_size]
permutation_map: [output_size] int32, maps each permuted row back to its original token index. Used by unpermute_tokens to scatter expert outputs back and by activation kernels to skip padding rows (-1).
inclusive_offsets: [num_local_experts] int32 cumulative offsets for grouped_mm
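The relationship between the four return values can be sketched with a pure-Python reference on nested lists. This is not the Triton implementation, and `permute_reference` is a hypothetical name. Two simplifying assumptions: the buffers are sized to exactly `inclusive_offsets[-1]` rows instead of a pre-allocated maximum, and rows within an expert follow token order, whereas a kernel that claims slots with atomic counters may order them differently.

```python
def permute_reference(hidden, probs, routing_map, local_expert_start,
                      num_local_experts, valid_tokens, alignment=1):
    """Group (token, topk) pairs by local expert, padding each segment."""
    # 1. Collect the pairs routed to local experts and count them.
    counts = [0] * num_local_experts
    pairs = []
    for t in range(valid_tokens):
        for k, expert in enumerate(routing_map[t]):
            local = expert - local_expert_start
            if 0 <= local < num_local_experts:
                counts[local] += 1
                pairs.append((local, t, k))

    # 2. Aligned exclusive / inclusive offsets.
    exclusive, inclusive, running = [], [], 0
    for c in counts:
        exclusive.append(running)
        running += -(-c // alignment) * alignment
        inclusive.append(running)

    # 3. Scatter rows into expert segments; padding rows keep -1.
    hidden_dim = len(hidden[0])
    permuted_hidden = [[0.0] * hidden_dim for _ in range(running)]
    permuted_probs = [0.0] * running
    permutation_map = [-1] * running
    cursor = list(exclusive)  # next free row per expert
    for local, t, k in pairs:
        row = cursor[local]
        cursor[local] += 1
        permuted_hidden[row] = list(hidden[t])
        permuted_probs[row] = probs[t][k]
        permutation_map[row] = t
    return permuted_hidden, permuted_probs, permutation_map, inclusive


hidden = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
probs = [[0.6, 0.4], [0.7, 0.3], [0.5, 0.5]]
routing_map = [[0, 1], [1, 2], [0, 3]]  # experts 2 and 3 live on another rank
p_hidden, p_probs, p_map, offsets = permute_reference(
    hidden, probs, routing_map, local_expert_start=0,
    num_local_experts=2, valid_tokens=3, alignment=4,
)
# p_map   == [0, 2, -1, -1, 0, 1, -1, -1]
# offsets == [4, 8]
```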
- core.inference.moe.permute._zero_output_rows_kernel(
    output_ptr,
    valid_tokens_ptr,
    hidden_dim,
    num_tokens,
    BLOCK_H: triton.language.constexpr,
    NUM_BLOCKS: triton.language.constexpr,
  )
Zero rows [0, valid_tokens) of the fp32 output buffer.
Grid: fixed NUM_BLOCKS CTAs, each iterating over multiple rows. valid_tokens gates which rows are zeroed — required for CUDA graph compatibility.
- core.inference.moe.permute._unpermute_tokens_kernel(
    expert_out_ptr,
    probs_ptr,
    src_idx_ptr,
    output_ptr,
    n_used_ptr,
    hidden_dim,
    max_rows,
    BLOCK_H: triton.language.constexpr,
    NUM_BLOCKS: triton.language.constexpr,
  )
Scatter weighted expert outputs back to original token positions.
Grid: fixed NUM_BLOCKS CTAs, each iterating over multiple rows. Rows beyond n_used and alignment-padding rows (src_idx == -1) are skipped. Multiple topk selections for the same token are accumulated via atomic adds. All arithmetic is in fp32 to avoid precision loss.
- core.inference.moe.permute.unpermute_tokens(
    expert_output: torch.Tensor,
    permuted_probs: torch.Tensor,
    permutation_map: torch.Tensor,
    num_tokens: int,
    n_used: torch.Tensor,
    valid_tokens: torch.Tensor,
    out: torch.Tensor = None,
  )
Unpermute expert outputs back to original token order.
Accumulates in fp32 to avoid precision loss from multiple topk atomic adds. Returns fp32 output.
- Parameters:
expert_output – [output_size, hidden_dim] expert outputs in permuted order.
permuted_probs – [output_size] fp32 routing probabilities.
permutation_map – [output_size] int32, original token index or -1 for padding.
num_tokens – max token count (output buffer height); always fixed for CUDA graph (CG) compatibility.
n_used – scalar int32 CUDA tensor = inclusive_expert_offsets[-1]. Rows beyond this are skipped without reading permutation_map.
valid_tokens – scalar int32 CUDA tensor = number of valid input tokens. Only rows [0, valid_tokens) are zeroed; all atomic_adds target source_idx < valid_tokens so rows beyond are never written.
out – optional pre-allocated [num_tokens, hidden_dim] fp32 output buffer. Pass a symmetric memory tensor to scatter directly into it, avoiding a separate copy before RSV. If None, a local buffer is allocated.
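The scatter step can be expressed as a pure-Python reference. This is a sketch of the result only, not the Triton kernel: `unpermute_reference` is a hypothetical name, `n_used` is a plain int, and the sketch zeroes the whole output because it allocates it fresh, whereas the description above zeroes only rows [0, valid_tokens).

```python
def unpermute_reference(expert_output, permuted_probs, permutation_map,
                        num_tokens, n_used):
    """Weighted scatter-add of permuted rows back to token order."""
    hidden_dim = len(expert_output[0])
    out = [[0.0] * hidden_dim for _ in range(num_tokens)]
    for row in range(n_used):          # rows at or beyond n_used are never read
        src = permutation_map[row]
        if src < 0:                    # alignment padding
            continue
        weight = permuted_probs[row]
        for j in range(hidden_dim):    # the kernel does this with atomic adds
            out[src][j] += weight * expert_output[row][j]
    return out


expert_output = [[1.0, 2.0], [10.0, 20.0], [9.0, 9.0], [3.0, 4.0]]
permuted_probs = [0.5, 0.25, 0.0, 0.5]
permutation_map = [0, 1, -1, 0]        # token 0 was routed to two experts
out = unpermute_reference(expert_output, permuted_probs, permutation_map,
                          num_tokens=3, n_used=4)
# out == [[2.0, 3.0], [2.5, 5.0], [0.0, 0.0]]
```

Token 0 receives contributions from two permuted rows, which is why the accumulation needs atomic adds when rows are processed in parallel.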
- core.inference.moe.permute._permute_quantize_mxfp8_kernel(
    hidden_ptr,
    probs_ptr,
    routing_map_ptr,
    out_fp8_ptr,
    out_scale_ptr,
    out_probs_ptr,
    out_src_idx_ptr,
    counters_ptr,
    valid_tokens_ptr,
    K,
    n_col_blocks,
    max_pairs,
    topk: triton.language.constexpr,
    local_expert_start,
    num_local_experts: triton.language.constexpr,
    REAL_GROUPS: triton.language.constexpr,
    BLOCK_K: triton.language.constexpr,
    BLOCK_GROUPS: triton.language.constexpr,
    NUM_BLOCKS: triton.language.constexpr,
  )
Fused permute + MXFP8 quantize + swizzle in one kernel.
Grid: fixed NUM_BLOCKS CTAs, each iterating over multiple (token, topk) pairs. valid_tokens gates which pairs are actually processed — required for CUDA graph compatibility since the grid size never changes across steps.
- core.inference.moe.permute.permute_and_quantize_mxfp8(
    hidden_states: torch.Tensor,
    probs: torch.Tensor,
    routing_map: torch.Tensor,
    local_expert_start: int,
    num_local_experts: int,
    valid_tokens: torch.Tensor,
    alignment: int = 128,
  )
Fused permute + MXFP8 quantize + swizzle.
Self-contained API matching permute_tokens: computes token counts, aligned expert offsets, output sizing, permutation, and MXFP8 quantization in a single kernel launch.
- Parameters:
hidden_states – [max_tokens, hidden_size] BF16 input. Only the first valid_tokens rows are valid; the rest are ignored.
probs – [max_tokens, topk] routing probabilities.
routing_map – [max_tokens, topk] expert assignments.
local_expert_start – first global expert index on this rank.
num_local_experts – number of experts on this rank.
valid_tokens – scalar int32 CUDA tensor with the number of valid tokens this iteration. Fixed address; value updated each step before graph replay.
alignment – per-expert token alignment (default 128, required for MXFP8 swizzle).
- Returns:
(permuted_mxfp8, permuted_probs, permutation_map, inclusive_offsets)
permuted_mxfp8: MXFP8Tensor with .data [output_size, K] and .scale (swizzled)
permuted_probs: [output_size] routing probs
permutation_map: [output_size] int32, original token index or -1 for padding
inclusive_offsets: [num_local_experts] int32 cumulative offsets for scaled_grouped_mm
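For readers unfamiliar with MXFP8, the per-block scaling idea can be sketched in plain Python. Everything here is an assumption rather than a description of this kernel: the sketch follows the OCP Microscaling convention of 32-element blocks with a shared power-of-two scale, assumes E4M3 elements, maps all-zero blocks to exponent 0 as an arbitrary choice, and omits both the FP8 rounding and the scale swizzle. The actual kernel may round the scale differently. `block_scale_exponents` and `scale_block` are hypothetical helpers.

```python
import math

BLOCK = 32        # elements sharing one scale (MX block size, assumed)
E4M3_MAX = 448.0  # largest finite E4M3 magnitude
E4M3_EMAX = 8     # exponent of the largest E4M3 power of two (2**8)


def block_scale_exponents(row):
    """Shared power-of-two exponent for each 32-element block of a row."""
    exps = []
    for start in range(0, len(row), BLOCK):
        amax = max(abs(v) for v in row[start:start + BLOCK])
        if amax == 0.0:
            exps.append(0)  # arbitrary choice for this sketch
            continue
        floor_log2 = math.frexp(amax)[1] - 1  # exact floor(log2(amax))
        exps.append(floor_log2 - E4M3_EMAX)
    return exps


def scale_block(block, exp):
    """Divide by 2**exp and saturate to the E4M3 range (no FP8 rounding)."""
    scale = 2.0 ** exp
    return [max(-E4M3_MAX, min(E4M3_MAX, v / scale)) for v in block]


row = [3.0] * 32 + [0.5] * 32
exps = block_scale_exponents(row)
# exps == [-7, -9]
scaled_first = scale_block(row[:32], exps[0])  # every entry becomes 384.0
```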