core.inference.moe.permute#

Triton kernels for token permutation and unpermutation in fused MoE.

Includes:

  • Token counting per expert

  • Expert offset computation (aligned prefix sums)

  • Permute tokens into expert-grouped order

  • Unpermute expert outputs back to original token order

Module Contents#

Functions#

_ceil_div

_count_local_tokens_kernel

Count tokens routed to experts on this rank, ignoring tokens routed elsewhere.

compute_local_tokens_per_expert

Count tokens routed to each local expert.

_prefix_sum_kernel

Exclusive and inclusive prefix sums of aligned token counts.

compute_expert_offsets

Compute exclusive and inclusive prefix sums of aligned token counts.

_permute_tokens_kernel

Permute tokens into expert-grouped order.

permute_tokens

Permute tokens into expert-grouped order.

_unpermute_tokens_kernel

Scatter weighted expert outputs back to original token positions.

unpermute_tokens

Unpermute expert outputs back to original token order.

_permute_quantize_mxfp8_kernel

Fused permute + MXFP8 quantize + swizzle in one kernel.

permute_and_quantize_mxfp8

Fused permute + MXFP8 quantize + swizzle.

API#

core.inference.moe.permute._ceil_div(a, b)#

Integer ceiling division of a by b.
core.inference.moe.permute._count_local_tokens_kernel(
routing_map_ptr,
tokens_per_expert_ptr,
total_pairs,
local_expert_start,
num_local_experts: triton.language.constexpr,
BLOCK_SIZE: triton.language.constexpr,
)#

Count tokens routed to experts on this rank, ignoring tokens routed elsewhere.

Each program processes BLOCK_SIZE (token, topk) pairs. Tokens assigned to experts outside [local_expert_start, local_expert_start + num_local_experts) are silently skipped.
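The counting semantics can be sketched in plain Python. This is a reference for what the kernel computes, not the kernel itself (the real kernel processes BLOCK_SIZE pairs per program and accumulates with atomic adds); the function name and list-based inputs here are illustrative.

```python
def count_local_tokens_ref(routing_map, local_expert_start, num_local_experts):
    """Reference: count (token, topk) pairs routed to each local expert.

    routing_map[token][k] is a global expert id. Pairs routed to experts
    outside [local_expert_start, local_expert_start + num_local_experts)
    are silently skipped, matching the kernel's behavior.
    """
    counts = [0] * num_local_experts
    for token_experts in routing_map:
        for expert in token_experts:
            local = expert - local_expert_start
            if 0 <= local < num_local_experts:  # local to this rank?
                counts[local] += 1
    return counts

# 4 tokens, topk=2, 8 global experts; this rank owns experts 4..7.
routing_map = [[0, 4], [4, 5], [6, 1], [7, 4]]
print(count_local_tokens_ref(routing_map, local_expert_start=4, num_local_experts=4))
# → [3, 1, 1, 1]
```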

core.inference.moe.permute.compute_local_tokens_per_expert(
routing_map: torch.Tensor,
local_expert_start: int,
num_local_experts: int,
) torch.Tensor#

Count tokens routed to each local expert.

core.inference.moe.permute._prefix_sum_kernel(
tokens_per_expert_ptr,
exclusive_offsets_ptr,
inclusive_offsets_ptr,
num_local_experts,
alignment: triton.language.constexpr,
BLOCK_SIZE: triton.language.constexpr,
)#

Exclusive and inclusive prefix sums of aligned token counts.

Each expert’s token count is rounded up to the nearest multiple of alignment (experts with 0 tokens stay at 0). The inclusive offsets are used as offs by grouped_mm / scaled_grouped_mm.
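The alignment and prefix-sum rules above can be written out as a short pure-Python reference (illustrative only; the kernel computes both sums in one pass on device):

```python
def expert_offsets_ref(tokens_per_expert, alignment=1):
    """Reference: exclusive/inclusive prefix sums of alignment-padded counts."""
    exclusive, inclusive = [], []
    running = 0
    for count in tokens_per_expert:
        # Round each nonzero count up to a multiple of `alignment`;
        # experts with 0 tokens contribute nothing.
        aligned = ((count + alignment - 1) // alignment) * alignment if count else 0
        exclusive.append(running)
        running += aligned
        inclusive.append(running)
    return exclusive, inclusive

print(expert_offsets_ref([3, 0, 5, 1], alignment=4))
# → ([0, 4, 4, 12], [4, 4, 12, 16])
```

With alignment=4, the counts 3, 0, 5, 1 pad to 4, 0, 8, 4; the last inclusive offset (16) is the total padded output size.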

core.inference.moe.permute.compute_expert_offsets(
tokens_per_expert: torch.Tensor,
alignment: int = 1,
) tuple#

Compute exclusive and inclusive prefix sums of aligned token counts, returning (exclusive_offsets, inclusive_offsets).

core.inference.moe.permute._permute_tokens_kernel(
hidden_ptr,
probs_ptr,
routing_map_ptr,
out_hidden_ptr,
out_probs_ptr,
out_src_idx_ptr,
counters_ptr,
num_tokens,
hidden_dim,
topk: triton.language.constexpr,
local_expert_start,
num_local_experts: triton.language.constexpr,
BLOCK_H: triton.language.constexpr,
)#

Permute tokens into expert-grouped order.

Grid: one program per (token, topk) pair. Each program looks up the assigned expert, skips non-local experts, then atomically claims a position within that expert’s block and copies the hidden state + prob + source index.
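A sequential Python sketch of the same logic, with a per-expert counter standing in for the kernel's atomic position claim (names and list-based data are illustrative assumptions, not the module's code):

```python
def permute_tokens_ref(hidden, probs, routing_map, exclusive_offsets,
                       local_expert_start, num_local_experts, output_size):
    """Reference permute: copy each local (token, k) pair into its expert's
    block. Rows never claimed keep src_idx == -1 (alignment padding)."""
    out_hidden = [None] * output_size
    out_probs = [0.0] * output_size
    src_idx = [-1] * output_size
    claimed = [0] * num_local_experts  # stands in for atomic counters
    for token, token_experts in enumerate(routing_map):
        for k, expert in enumerate(token_experts):
            local = expert - local_expert_start
            if not (0 <= local < num_local_experts):
                continue  # token routed to another rank
            pos = exclusive_offsets[local] + claimed[local]
            claimed[local] += 1
            out_hidden[pos] = hidden[token]
            out_probs[pos] = probs[token][k]
            src_idx[pos] = token
    return out_hidden, out_probs, src_idx

hidden = [[1.0], [2.0], [3.0]]
probs = [[0.5], [0.7], [0.9]]
routing_map = [[0], [1], [0]]          # topk = 1, two local experts
out_h, out_p, src = permute_tokens_ref(
    hidden, probs, routing_map,
    exclusive_offsets=[0, 2],          # expert 0 owns rows 0-1, expert 1 row 2
    local_expert_start=0, num_local_experts=2, output_size=3)
print(src)  # → [0, 2, 1]: tokens 0 and 2 (expert 0), then token 1 (expert 1)
```

Note that because the real kernel claims positions atomically across programs, the order of rows within an expert's block is nondeterministic; this reference fixes token order only for readability.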

core.inference.moe.permute.permute_tokens(
hidden_states: torch.Tensor,
probs: torch.Tensor,
routing_map: torch.Tensor,
local_expert_start: int,
num_local_experts: int,
alignment: int = 1,
) tuple#

Permute tokens into expert-grouped order.

Computes token counts, aligned expert offsets, output sizing, and permutation in a single call.

Parameters:
  • hidden_states – [num_tokens, hidden_size] input.

  • probs – [num_tokens, topk] routing probabilities.

  • routing_map – [num_tokens, topk] expert assignments.

  • local_expert_start – first global expert index on this rank.

  • num_local_experts – number of experts on this rank.

  • alignment – per-expert token alignment (default 1).

Returns:

(permuted_hidden, permuted_probs, permutation_map, inclusive_offsets)

  • permuted_hidden: [output_size, hidden_size]

  • permuted_probs: [output_size]

  • permutation_map: [output_size] int32, maps each permuted row back to its original token index. unpermute_tokens uses it to scatter expert outputs back, and activation kernels use it to skip padding rows (marked -1).

  • inclusive_offsets: [num_local_experts] int32 cumulative offsets for grouped_mm
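A small pure-Python illustration of the output sizing and padding contract described above (assumed from the return description, not taken from the library's code): the output row count is the sum of the aligned per-expert counts, and rows not claimed by any token remain -1 in permutation_map.

```python
def output_layout(tokens_per_expert, alignment):
    """Illustrate sizing: aligned counts give the padded per-expert blocks;
    their sum is the total output row count (the last inclusive offset)."""
    aligned = [((c + alignment - 1) // alignment) * alignment if c else 0
               for c in tokens_per_expert]
    output_size = sum(aligned)
    # permutation_map starts at -1; only real tokens overwrite their slot,
    # so alignment padding rows stay -1 for downstream kernels to skip.
    permutation_map = [-1] * output_size
    return aligned, output_size, permutation_map

print(output_layout([3, 1], alignment=4))
# → ([4, 4], 8, [-1, -1, -1, -1, -1, -1, -1, -1])
```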

core.inference.moe.permute._unpermute_tokens_kernel(
expert_out_ptr,
probs_ptr,
src_idx_ptr,
output_ptr,
hidden_dim,
BLOCK_H: triton.language.constexpr,
)#

Scatter weighted expert outputs back to original token positions.

Grid: one program per row of expert_out. Padding rows (src_idx == -1) are skipped. Multiple topk selections for the same token are accumulated via atomic adds. All arithmetic is in fp32 to avoid precision loss.
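The scatter-and-accumulate semantics can be sketched sequentially in Python (a reference for what the kernel computes; sequential `+=` stands in for its fp32 atomic adds, and the names here are illustrative):

```python
def unpermute_tokens_ref(expert_out, probs, src_idx, num_tokens):
    """Reference unpermute: weight each expert-output row by its routing
    prob and accumulate into the original token's slot."""
    hidden_dim = len(expert_out[0])
    output = [[0.0] * hidden_dim for _ in range(num_tokens)]
    for row, token in enumerate(src_idx):
        if token == -1:
            continue  # alignment padding row
        for h in range(hidden_dim):
            output[token][h] += probs[row] * expert_out[row][h]
    return output

expert_out = [[2.0], [4.0], [10.0]]
probs = [0.5, 0.5, 1.0]
src_idx = [0, 0, 1]          # token 0 was routed to two experts (topk = 2)
print(unpermute_tokens_ref(expert_out, probs, src_idx, num_tokens=2))
# → [[3.0], [10.0]]  (token 0: 0.5*2 + 0.5*4)
```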

core.inference.moe.permute.unpermute_tokens(
expert_output: torch.Tensor,
permuted_probs: torch.Tensor,
permutation_map: torch.Tensor,
num_tokens: int,
) torch.Tensor#

Unpermute expert outputs back to original token order.

Accumulates in fp32 to avoid precision loss from multiple topk atomic adds. Returns fp32 output.

core.inference.moe.permute._permute_quantize_mxfp8_kernel(
hidden_ptr,
probs_ptr,
routing_map_ptr,
out_fp8_ptr,
out_scale_ptr,
out_probs_ptr,
out_src_idx_ptr,
counters_ptr,
num_tokens,
K,
n_col_blocks,
topk: triton.language.constexpr,
local_expert_start,
num_local_experts: triton.language.constexpr,
REAL_GROUPS: triton.language.constexpr,
BLOCK_K: triton.language.constexpr,
BLOCK_GROUPS: triton.language.constexpr,
)#

Fused permute + MXFP8 quantize + swizzle in one kernel.

Grid: (num_tokens * topk,) — one program per (token, k) pair. Reads BF16 from source token, quantizes to FP8 e4m3, writes FP8 data + swizzled e8m0 scales to the permuted write position.
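For intuition on the e8m0 scales, here is how a shared power-of-two scale for one 32-element MX group can be derived under the OCP MX convention (exponent = floor(log2(amax)) minus the element format's max exponent, 8 for e4m3). This is a hedged illustration of the format, not this kernel's exact rounding, which may differ:

```python
import math

E4M3_MAX_EXP = 8  # exponent of the largest e4m3 value (448 lies in [2^8, 2^9))

def mx_scale_exponent(group):
    """Shared power-of-two (e8m0) scale exponent for one 32-element MX group,
    following the OCP MX spec's floor(log2(amax)) - emax_elem rule."""
    amax = max(abs(v) for v in group)
    if amax == 0.0:
        return 0
    return math.floor(math.log2(amax)) - E4M3_MAX_EXP

group = [1.0, -0.25, 0.5] + [0.0] * 29   # one 32-value group
exp = mx_scale_exponent(group)
scale = 2.0 ** exp
print(exp, [v / scale for v in group[:3]])
# → -8 [256.0, -64.0, 128.0]  (all within e4m3's ±448 range)
```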

core.inference.moe.permute.permute_and_quantize_mxfp8(
hidden_states: torch.Tensor,
probs: torch.Tensor,
routing_map: torch.Tensor,
local_expert_start: int,
num_local_experts: int,
alignment: int = 128,
) tuple#

Fused permute + MXFP8 quantize + swizzle.

Self-contained API matching permute_tokens: computes token counts, aligned expert offsets, output sizing, permutation, and MXFP8 quantization in a single call.

Parameters:
  • hidden_states – [num_tokens, hidden_size] BF16 input.

  • probs – [num_tokens, topk] routing probabilities.

  • routing_map – [num_tokens, topk] expert assignments.

  • local_expert_start – first global expert index on this rank.

  • num_local_experts – number of experts on this rank.

  • alignment – per-expert token alignment (default 128, required for MXFP8 swizzle).

Returns:

(permuted_mxfp8, permuted_probs, permutation_map, inclusive_offsets)

  • permuted_mxfp8: MXFP8Tensor with .data [output_size, K] and .scale (swizzled)

  • permuted_probs: [output_size] routing probs

  • permutation_map: [output_size] int32, original token index or -1 for padding

  • inclusive_offsets: [num_local_experts] int32 cumulative offsets for scaled_grouped_mm