core.inference.moe.permute#
Triton kernels for token permutation and unpermutation in fused MoE.
Includes:
- Token counting per expert
- Expert offset computation (aligned prefix sums)
- Permuting tokens into expert-grouped order
- Unpermuting expert outputs back to original token order
Module Contents#
Functions#
| Function | Summary |
|---|---|
| `_count_local_tokens_kernel` | Count tokens routed to experts on this rank, ignoring tokens routed elsewhere. |
| `compute_local_tokens_per_expert` | Count tokens routed to each local expert. |
| `_prefix_sum_kernel` | Exclusive and inclusive prefix sums of aligned token counts. |
| `compute_expert_offsets` | Compute exclusive and inclusive prefix sums of aligned token counts. |
| `_permute_tokens_kernel` | Permute tokens into expert-grouped order. |
| `permute_tokens` | Permute tokens into expert-grouped order. |
| `_unpermute_tokens_kernel` | Scatter weighted expert outputs back to original token positions. |
| `unpermute_tokens` | Unpermute expert outputs back to original token order. |
| `_permute_quantize_mxfp8_kernel` | Fused permute + MXFP8 quantize + swizzle in one kernel. |
| `permute_and_quantize_mxfp8` | Fused permute + MXFP8 quantize + swizzle. |
API#
- core.inference.moe.permute._ceil_div(a, b)#
- core.inference.moe.permute._count_local_tokens_kernel(
- routing_map_ptr,
- tokens_per_expert_ptr,
- total_pairs,
- local_expert_start,
- num_local_experts: triton.language.constexpr,
- BLOCK_SIZE: triton.language.constexpr,
)#
Count tokens routed to experts on this rank, ignoring tokens routed elsewhere.
Each program processes BLOCK_SIZE (token, topk) pairs. Tokens assigned to experts outside [local_expert_start, local_expert_start + num_local_experts) are silently skipped.
- core.inference.moe.permute.compute_local_tokens_per_expert(
- routing_map: torch.Tensor,
- local_expert_start: int,
- num_local_experts: int,
)#
Count tokens routed to each local expert.
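The counting semantics can be sketched in pure Python (a hypothetical reference, not the module's Triton implementation; the function name `count_local_tokens` and list-based types are illustrative only):

```python
# Hypothetical pure-Python reference for the counting semantics; the real
# implementation is a Triton kernel over flattened (token, topk) pairs.
def count_local_tokens(routing_map, local_expert_start, num_local_experts):
    counts = [0] * num_local_experts
    for experts_for_token in routing_map:
        for e in experts_for_token:
            local = e - local_expert_start
            # Pairs routed to experts on other ranks are silently skipped.
            if 0 <= local < num_local_experts:
                counts[local] += 1
    return counts

# Two tokens with topk=2; this rank owns global experts [2, 4).
print(count_local_tokens([[0, 2], [2, 3]], 2, 2))  # -> [2, 1]
```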
- core.inference.moe.permute._prefix_sum_kernel(
- tokens_per_expert_ptr,
- exclusive_offsets_ptr,
- inclusive_offsets_ptr,
- num_local_experts,
- alignment: triton.language.constexpr,
- BLOCK_SIZE: triton.language.constexpr,
)#
Exclusive and inclusive prefix sums of aligned token counts.
Each expert’s token count is rounded up to the nearest multiple of `alignment` (experts with 0 tokens stay at 0). The inclusive offsets are used as `offs` by grouped_mm / scaled_grouped_mm.
- core.inference.moe.permute.compute_expert_offsets(
- tokens_per_expert: torch.Tensor,
- alignment: int = 1,
)#
Compute exclusive and inclusive prefix sums of aligned token counts.
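A pure-Python sketch of the offset computation (hypothetical reference; the real function operates on a `torch.Tensor` of counts via the Triton prefix-sum kernel):

```python
# Hypothetical reference for the offset computation: each count is rounded up
# to a multiple of `alignment` (zero counts stay zero) before prefix-summing.
def compute_expert_offsets_ref(tokens_per_expert, alignment=1):
    exclusive, inclusive, running = [], [], 0
    for count in tokens_per_expert:
        aligned = -(-count // alignment) * alignment   # ceil to alignment
        exclusive.append(running)
        running += aligned
        inclusive.append(running)
    return exclusive, inclusive

# Counts [5, 0, 3] with alignment=4 round to [8, 0, 4].
print(compute_expert_offsets_ref([5, 0, 3], alignment=4))
# -> ([0, 8, 8], [8, 8, 12])
```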
- core.inference.moe.permute._permute_tokens_kernel(
- hidden_ptr,
- probs_ptr,
- routing_map_ptr,
- out_hidden_ptr,
- out_probs_ptr,
- out_src_idx_ptr,
- counters_ptr,
- num_tokens,
- hidden_dim,
- topk: triton.language.constexpr,
- local_expert_start,
- num_local_experts: triton.language.constexpr,
- BLOCK_H: triton.language.constexpr,
)#
Permute tokens into expert-grouped order.
Grid: one program per (token, topk) pair. Each program looks up the assigned expert, skips non-local experts, then atomically claims a position within that expert’s block and copies the hidden state + prob + source index.
- core.inference.moe.permute.permute_tokens(
- hidden_states: torch.Tensor,
- probs: torch.Tensor,
- routing_map: torch.Tensor,
- local_expert_start: int,
- num_local_experts: int,
- alignment: int = 1,
)#
Permute tokens into expert-grouped order.
Computes token counts, aligned expert offsets, output sizing, and permutation in a single call.
- Parameters:
hidden_states – [num_tokens, hidden_size] input.
probs – [num_tokens, topk] routing probabilities.
routing_map – [num_tokens, topk] expert assignments.
local_expert_start – first global expert index on this rank.
num_local_experts – number of experts on this rank.
alignment – per-expert token alignment (default 1).
- Returns:
(permuted_hidden, permuted_probs, permutation_map, inclusive_offsets)
permuted_hidden: [output_size, hidden_size]
permuted_probs: [output_size]
permutation_map: [output_size] int32, maps each permuted row back to its original token index. Used by unpermute_tokens to scatter expert outputs back and by activation kernels to skip padding rows (-1).
inclusive_offsets: [num_local_experts] int32 cumulative offsets for grouped_mm
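The end-to-end permute semantics can be sketched with plain Python lists (a hypothetical reference, not the kernel; the real `_permute_tokens_kernel` claims rows with atomics, so row order within an expert's block is nondeterministic, whereas this sketch fills rows deterministically in token order):

```python
# Hypothetical pure-Python reference for permute_tokens' semantics.
def permute_tokens_ref(hidden, probs, routing_map,
                       local_expert_start, num_local_experts, alignment=1):
    # Per-expert counts, then aligned exclusive offsets.
    counts = [0] * num_local_experts
    for experts in routing_map:
        for e in experts:
            local = e - local_expert_start
            if 0 <= local < num_local_experts:
                counts[local] += 1
    aligned = [-(-c // alignment) * alignment for c in counts]
    offsets = [0]
    for a in aligned:
        offsets.append(offsets[-1] + a)
    out_size = offsets[-1]

    perm_hidden = [[0.0] * len(hidden[0]) for _ in range(out_size)]
    perm_probs = [0.0] * out_size
    perm_map = [-1] * out_size          # -1 marks padding rows
    cursor = offsets[:-1].copy()        # next free row per expert block
    for t, experts in enumerate(routing_map):
        for k, e in enumerate(experts):
            local = e - local_expert_start
            if not 0 <= local < num_local_experts:
                continue                # token routed to another rank
            row = cursor[local]
            cursor[local] += 1
            perm_hidden[row] = list(hidden[t])
            perm_probs[row] = probs[t][k]
            perm_map[row] = t
    return perm_hidden, perm_probs, perm_map, offsets[1:]
```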
- core.inference.moe.permute._unpermute_tokens_kernel(
- expert_out_ptr,
- probs_ptr,
- src_idx_ptr,
- output_ptr,
- hidden_dim,
- BLOCK_H: triton.language.constexpr,
)#
Scatter weighted expert outputs back to original token positions.
Grid: one program per row of expert_out. Padding rows (src_idx == -1) are skipped. Multiple topk selections for the same token are accumulated via atomic adds. All arithmetic is in fp32 to avoid precision loss.
- core.inference.moe.permute.unpermute_tokens(
- expert_output: torch.Tensor,
- permuted_probs: torch.Tensor,
- permutation_map: torch.Tensor,
- num_tokens: int,
)#
Unpermute expert outputs back to original token order.
Accumulates in fp32 to avoid precision loss from multiple topk atomic adds. Returns fp32 output.
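A hypothetical pure-Python reference of the scatter-accumulate semantics (the real kernel does this with per-row programs and fp32 atomic adds):

```python
# Each row of expert_output is weighted by its routing prob and accumulated
# into the original token's slot; padding rows (src == -1) contribute nothing.
def unpermute_tokens_ref(expert_output, permuted_probs, permutation_map,
                         num_tokens):
    hidden_dim = len(expert_output[0])
    out = [[0.0] * hidden_dim for _ in range(num_tokens)]
    for row, src in enumerate(permutation_map):
        if src == -1:                   # padding row, skip
            continue
        w = permuted_probs[row]
        for j in range(hidden_dim):
            out[src][j] += w * expert_output[row][j]
    return out

# Two topk hits for token 0 accumulate; token 1 gets one weighted row.
print(unpermute_tokens_ref([[10.0], [20.0], [30.0]],
                           [0.5, 0.25, 0.75], [0, 0, 1], 2))
# -> [[10.0], [22.5]]
```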
- core.inference.moe.permute._permute_quantize_mxfp8_kernel(
- hidden_ptr,
- probs_ptr,
- routing_map_ptr,
- out_fp8_ptr,
- out_scale_ptr,
- out_probs_ptr,
- out_src_idx_ptr,
- counters_ptr,
- num_tokens,
- K,
- n_col_blocks,
- topk: triton.language.constexpr,
- local_expert_start,
- num_local_experts: triton.language.constexpr,
- REAL_GROUPS: triton.language.constexpr,
- BLOCK_K: triton.language.constexpr,
- BLOCK_GROUPS: triton.language.constexpr,
)#
Fused permute + MXFP8 quantize + swizzle in one kernel.
Grid: (num_tokens * topk,) — one program per (token, k) pair. Reads BF16 from source token, quantizes to FP8 e4m3, writes FP8 data + swizzled e8m0 scales to the permuted write position.
- core.inference.moe.permute.permute_and_quantize_mxfp8(
- hidden_states: torch.Tensor,
- probs: torch.Tensor,
- routing_map: torch.Tensor,
- local_expert_start: int,
- num_local_experts: int,
- alignment: int = 128,
)#
Fused permute + MXFP8 quantize + swizzle.
Self-contained API matching permute_tokens: computes token counts, aligned expert offsets, output sizing, permutation, and MXFP8 quantization in a single kernel launch.
- Parameters:
hidden_states – [num_tokens, hidden_size] BF16 input.
probs – [num_tokens, topk] routing probabilities.
routing_map – [num_tokens, topk] expert assignments.
local_expert_start – first global expert index on this rank.
num_local_experts – number of experts on this rank.
alignment – per-expert token alignment (default 128, required for MXFP8 swizzle).
- Returns:
(permuted_mxfp8, permuted_probs, permutation_map, inclusive_offsets)
permuted_mxfp8: MXFP8Tensor with .data [output_size, K] and .scale (swizzled)
permuted_probs: [output_size] routing probs
permutation_map: [output_size] int32, original token index or -1 for padding
inclusive_offsets: [num_local_experts] int32 cumulative offsets for scaled_grouped_mm
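The per-block e8m0 scale selection can be sketched as follows, assuming the OCP MX convention (one power-of-two scale shared by 32 elements, exponent chosen as `floor(log2(amax)) - emax` of the element format); the function name is hypothetical, and the hardware-specific swizzled scale layout expected by scaled_grouped_mm is not shown:

```python
import math

E4M3_EMAX = 8     # largest e4m3 exponent (max normal value 448 = 1.75 * 2**8)
MX_BLOCK = 32     # MX convention: one shared e8m0 scale per 32 elements

def mxfp8_block_scale(block):
    # Shared power-of-two scale for one MX_BLOCK-sized group, following the
    # OCP MX recipe: exponent = floor(log2(amax)) - emax(element format).
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0
    return 2.0 ** (math.floor(math.log2(amax)) - E4M3_EMAX)

block = [0.5, -3.0, 1792.0]
scale = mxfp8_block_scale(block)
print(scale, [v / scale for v in block])   # scale 4.0; 1792/4 = 448 fits e4m3
```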