core.inference.moe.permute

Triton kernels for token permutation and unpermutation in fused MoE.

Includes:

  • Token counting per expert

  • Expert offset computation (aligned prefix sums)

  • Permute tokens into expert-grouped order

  • Unpermute expert outputs back to original token order

Module Contents

Functions

_get_num_sms

_ceil_div

_count_local_tokens_kernel

Count tokens routed to experts on this rank, ignoring tokens routed elsewhere.

_count_local_tokens_kernel_persistent

Count tokens routed to local experts using a persistent grid.

compute_local_tokens_per_expert

Count tokens routed to each local expert.

_prefix_sum_kernel

Exclusive and inclusive prefix sums of aligned token counts.

_init_permutation_map_kernel

Initialize permutation_map entries to -1 up to n_used rows.

init_permutation_map

Fill permutation_map[0:n_used] with -1.

compute_expert_offsets

Compute exclusive and inclusive prefix sums of aligned token counts.

_permute_tokens_kernel

Permute tokens into expert-grouped order.

permute_tokens

Permute tokens into expert-grouped order.

_zero_output_rows_kernel

Zero rows [0, valid_tokens) of the fp32 output buffer.

_unpermute_tokens_kernel

Scatter weighted expert outputs back to original token positions.

unpermute_tokens

Unpermute expert outputs back to original token order.

_permute_quantize_mxfp8_kernel

Fused permute + MXFP8 quantize + swizzle in one kernel.

permute_and_quantize_mxfp8

Fused permute + MXFP8 quantize + swizzle.

Data

API

core.inference.moe.permute._NUM_SMS: Optional[int] = None

core.inference.moe.permute._get_num_sms(device: torch.device) → int

core.inference.moe.permute._ceil_div(a, b)
core.inference.moe.permute._count_local_tokens_kernel(
routing_map_ptr,
tokens_per_expert_ptr,
valid_tokens_ptr,
topk,
local_expert_start,
num_local_experts: triton.language.constexpr,
BLOCK_SIZE: triton.language.constexpr,
)

Count tokens routed to experts on this rank, ignoring tokens routed elsewhere.

Each program processes BLOCK_SIZE (token, topk) pairs. Pairs whose expert lies outside [local_expert_start, local_expert_start + num_local_experts), or whose token index is at or beyond valid_tokens, are silently skipped.

The grid is always launched at its maximum size (enough programs to cover max_tokens * topk pairs); valid_tokens, read on the device, gates which pairs are actually processed. Keeping the launch configuration fixed is required for CUDA graph compatibility.
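To make the masking concrete, here is a pure-Python model of the fixed-size grid described above. It is a sketch, not the Triton kernel: the helper name count_local_tokens_fixed_grid is hypothetical, each program's BLOCK_SIZE lanes are modeled by an inner loop, and the per-expert update is assumed to be an atomic add in the real kernel.

```python
def count_local_tokens_fixed_grid(routing_map, topk, valid_tokens,
                                  local_expert_start, num_local_experts,
                                  block_size=4):
    """Pure-Python model of a max-size, device-gated grid (hypothetical helper)."""
    max_pairs = len(routing_map) * topk
    # Launch size depends only on max_tokens, never on valid_tokens.
    num_programs = (max_pairs + block_size - 1) // block_size
    counts = [0] * num_local_experts
    for pid in range(num_programs):
        for lane in range(block_size):          # models tl.arange(0, BLOCK_SIZE)
            pair = pid * block_size + lane
            if pair >= valid_tokens * topk:     # mask: pair outside valid rows
                continue
            token, k = divmod(pair, topk)
            local = routing_map[token][k] - local_expert_start
            if 0 <= local < num_local_experts:  # skip experts on other ranks
                counts[local] += 1              # assumed atomic add in the kernel
    return counts


routing_map = [[0, 2], [1, 3], [2, 4], [1, 0]]  # max_tokens=4, topk=2
print(count_local_tokens_fixed_grid(routing_map, 2, valid_tokens=3,
                                    local_expert_start=1, num_local_experts=2))
# [1, 2]
```

Changing valid_tokens changes only which lanes pass the mask, not num_programs, which is what allows a captured graph to be replayed with a different token count.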

core.inference.moe.permute._count_local_tokens_kernel_persistent(
routing_map_ptr,
tokens_per_expert_ptr,
valid_tokens_ptr,
topk,
local_expert_start,
num_local_experts: triton.language.constexpr,
num_sms,
BLOCK_SIZE: triton.language.constexpr,
)

Count tokens routed to local experts using a persistent grid.

Launches num_sms CTAs. Each CTA loops over its share of BLOCK_SIZE-sized chunks, with total work determined device-side from valid_tokens.
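One common way to divide work in a persistent kernel is a grid-stride loop, where CTA pid handles chunks pid, pid + num_sms, pid + 2 * num_sms, and so on. The sketch below assumes that round-robin assignment (the description above does not say whether chunks are strided or contiguous), and persistent_chunk_schedule is a hypothetical helper.

```python
def persistent_chunk_schedule(valid_tokens, topk, block_size, num_sms):
    """Map each CTA to the BLOCK_SIZE-sized chunks it would process.

    Assumes a round-robin (grid-stride) assignment; hypothetical helper.
    """
    total_pairs = valid_tokens * topk                  # computed on the device
    num_chunks = (total_pairs + block_size - 1) // block_size
    # The launch always has num_sms CTAs; only the loop trip counts vary.
    return {pid: list(range(pid, num_chunks, num_sms)) for pid in range(num_sms)}


print(persistent_chunk_schedule(valid_tokens=5, topk=2, block_size=4, num_sms=2))
# {0: [0, 2], 1: [1]}
```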

core.inference.moe.permute.compute_local_tokens_per_expert(
routing_map: torch.Tensor,
local_expert_start: int,
num_local_experts: int,
valid_tokens: torch.Tensor,
persistent: bool = False,
) → torch.Tensor

Count tokens routed to each local expert.

Parameters:
  • routing_map – [max_tokens, topk] expert assignments. Only the first valid_tokens rows are processed; the rest are ignored.

  • local_expert_start – first global expert index on this rank.

  • num_local_experts – number of experts on this rank.

  • valid_tokens – scalar int32 CUDA tensor holding the number of valid tokens this iteration. Its address stays fixed; its value is updated each step before graph replay.

  • persistent – if True, use the persistent-grid kernel variant (num_sms CTAs, each looping over several chunks) instead of one program per chunk.

core.inference.moe.permute._prefix_sum_kernel(
tokens_per_expert_ptr,
exclusive_offsets_ptr,
inclusive_offsets_ptr,
num_local_experts,
alignment: triton.language.constexpr,
BLOCK_SIZE: triton.language.constexpr,
)

Exclusive and inclusive prefix sums of aligned token counts.

Each expert’s token count is rounded up to the nearest multiple of alignment (experts with 0 tokens stay at 0). The inclusive offsets are passed as the offs argument to grouped_mm / scaled_grouped_mm.
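The same arithmetic can be written as a pure-Python reference, which is handy for checking the kernel's output on small inputs. aligned_expert_offsets is a hypothetical helper, and its (exclusive, inclusive) return order is chosen here for illustration.

```python
def aligned_expert_offsets(tokens_per_expert, alignment=1):
    """Exclusive and inclusive prefix sums of alignment-rounded counts."""
    exclusive, inclusive = [], []
    running = 0
    for count in tokens_per_expert:
        # Ceiling division, then scale back up: 0 stays 0, 3 -> 4 for alignment 4.
        aligned = -(-count // alignment) * alignment
        exclusive.append(running)   # first row of this expert's segment
        running += aligned
        inclusive.append(running)   # one past its last row (the offs value)
    return exclusive, inclusive


print(aligned_expert_offsets([3, 0, 5], alignment=4))
# ([0, 4, 4], [4, 4, 12])
```

The last inclusive offset (12 here) is the n_used value consumed by init_permutation_map and unpermute_tokens; the 4 - 3 = 1 and 8 - 5 = 3 extra rows are alignment padding, which permutation_map marks with -1.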

core.inference.moe.permute._init_permutation_map_kernel(
perm_map_ptr,
n_used_ptr,
BLOCK_SIZE: triton.language.constexpr,
)

Initialize permutation_map entries to -1 up to n_used rows.

The grid is launched at its maximum size; entries at or beyond n_used are left untouched. This is safe because the activation and unpermute kernels are gated by the same n_used pointer and therefore never read those entries.

core.inference.moe.permute.init_permutation_map(
permutation_map: torch.Tensor,
n_used: torch.Tensor,
) → None

Fill permutation_map[0:n_used] with -1.

Parameters:
  • permutation_map – [output_size] int32 buffer (pre-allocated at max size).

  • n_used – scalar int32 CUDA tensor holding inclusive_expert_offsets[-1], i.e. the total number of aligned rows in use.

core.inference.moe.permute.compute_expert_offsets(
tokens_per_expert: torch.Tensor,
alignment: int = 1,
) → tuple

Compute exclusive and inclusive prefix sums of aligned token counts.

core.inference.moe.permute._permute_tokens_kernel(
hidden_ptr,
probs_ptr,
routing_map_ptr,
out_hidden_ptr,
out_probs_ptr,
out_src_idx_ptr,
counters_ptr,
valid_tokens_ptr,
hidden_dim,
max_pairs,
topk: triton.language.constexpr,
local_expert_start,
num_local_experts: triton.language.constexpr,
BLOCK_H: triton.language.constexpr,
NUM_BLOCKS: triton.language.constexpr,
)

Permute tokens into expert-grouped order.

Grid: fixed NUM_BLOCKS CTAs, each iterating over multiple (token, topk) pairs. Because the grid size never changes across steps (as CUDA graph compatibility requires), valid_tokens gates which pairs are actually processed.

core.inference.moe.permute.permute_tokens(
hidden_states: torch.Tensor,
probs: torch.Tensor,
routing_map: torch.Tensor,
local_expert_start: int,
num_local_experts: int,
valid_tokens: torch.Tensor,
alignment: int = 1,
) → tuple

Permute tokens into expert-grouped order.

Computes token counts, aligned expert offsets, output sizing, and permutation in a single call.

Parameters:
  • hidden_states – [max_tokens, hidden_size] input. Only the first valid_tokens rows are valid; the rest are ignored.

  • probs – [max_tokens, topk] routing probabilities.

  • routing_map – [max_tokens, topk] expert assignments.

  • local_expert_start – first global expert index on this rank.

  • num_local_experts – number of experts on this rank.

  • valid_tokens – scalar int32 CUDA tensor holding the number of valid tokens this iteration. Its address stays fixed; its value is updated each step before graph replay.

  • alignment – per-expert token alignment (default 1).

Returns:

(permuted_hidden, permuted_probs, permutation_map, inclusive_offsets)

  • permuted_hidden: [output_size, hidden_size]

  • permuted_probs: [output_size]

  • permutation_map: [output_size] int32, maps each permuted row back to its original token index. Used by unpermute_tokens to scatter expert outputs back and by activation kernels to skip padding rows (-1).

  • inclusive_offsets: [num_local_experts] int32 cumulative offsets for grouped_mm
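For testing on small inputs, the contract above can be modeled in pure Python. This is a sketch with two simplifications: buffers are sized to exactly n_used rows rather than a pre-allocated output_size, and rows within an expert's segment are filled in token order. The kernel takes a counters_ptr argument, presumably per-expert fill counters, so if those are updated atomically its order within a segment may differ. permute_tokens_reference is a hypothetical name.

```python
def permute_tokens_reference(hidden_states, probs, routing_map,
                             local_expert_start, num_local_experts,
                             valid_tokens, alignment=1):
    """Pure-Python model of permute_tokens (hypothetical helper)."""
    topk = len(routing_map[0])
    pairs = []  # (token, k, local_expert) for every locally routed valid pair
    for t in range(valid_tokens):
        for k in range(topk):
            local = routing_map[t][k] - local_expert_start
            if 0 <= local < num_local_experts:
                pairs.append((t, k, local))

    # Aligned offsets: expert e owns rows [start[e], inclusive[e]).
    counts = [0] * num_local_experts
    for _, _, e in pairs:
        counts[e] += 1
    start, inclusive, running = [], [], 0
    for c in counts:
        start.append(running)
        running += -(-c // alignment) * alignment
        inclusive.append(running)

    n_used = running
    permuted_hidden = [[0.0] * len(hidden_states[0]) for _ in range(n_used)]
    permuted_probs = [0.0] * n_used
    permutation_map = [-1] * n_used          # -1 marks alignment padding
    cursor = list(start)                     # next free row per expert
    for t, k, e in pairs:
        row = cursor[e]
        cursor[e] += 1
        permuted_hidden[row] = list(hidden_states[t])
        permuted_probs[row] = probs[t][k]
        permutation_map[row] = t
    return permuted_hidden, permuted_probs, permutation_map, inclusive


hidden = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
probs = [[0.6, 0.4], [0.7, 0.3], [0.5, 0.5]]
routing = [[0, 1], [1, 2], [0, 2]]           # expert 2 lives on another rank
_, p, pmap, offs = permute_tokens_reference(hidden, probs, routing, 0, 2,
                                            valid_tokens=3, alignment=4)
print(pmap)   # [0, 2, -1, -1, 0, 1, -1, -1]
print(offs)   # [4, 8]
```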

core.inference.moe.permute._zero_output_rows_kernel(
output_ptr,
valid_tokens_ptr,
hidden_dim,
num_tokens,
BLOCK_H: triton.language.constexpr,
NUM_BLOCKS: triton.language.constexpr,
)

Zero rows [0, valid_tokens) of the fp32 output buffer.

Grid: fixed NUM_BLOCKS CTAs, each iterating over multiple rows. Because the grid size is fixed for CUDA graph compatibility, valid_tokens gates which rows are zeroed.

core.inference.moe.permute._unpermute_tokens_kernel(
expert_out_ptr,
probs_ptr,
src_idx_ptr,
output_ptr,
n_used_ptr,
hidden_dim,
max_rows,
BLOCK_H: triton.language.constexpr,
NUM_BLOCKS: triton.language.constexpr,
)

Scatter weighted expert outputs back to original token positions.

Grid: fixed NUM_BLOCKS CTAs, each iterating over multiple rows. Rows beyond n_used and alignment-padding rows (src_idx == -1) are skipped. Multiple topk selections for the same token are accumulated via atomic adds. All arithmetic is in fp32 to avoid precision loss.

core.inference.moe.permute.unpermute_tokens(
expert_output: torch.Tensor,
permuted_probs: torch.Tensor,
permutation_map: torch.Tensor,
num_tokens: int,
n_used: torch.Tensor,
valid_tokens: torch.Tensor,
out: torch.Tensor = None,
) → torch.Tensor

Unpermute expert outputs back to original token order.

Accumulates in fp32 to avoid the precision loss that would come from summing multiple topk contributions with low-precision atomic adds. Returns an fp32 tensor.

Parameters:
  • expert_output – [output_size, hidden_dim] expert outputs in permuted order.

  • permuted_probs – [output_size] fp32 routing probabilities.

  • permutation_map – [output_size] int32, original token index or -1 for padding.

  • num_tokens – maximum token count (the output buffer height); always fixed so that CUDA graphs can be replayed.

  • n_used – scalar int32 CUDA tensor holding inclusive_expert_offsets[-1]. Rows at or beyond this index are skipped without reading permutation_map.

  • valid_tokens – scalar int32 CUDA tensor holding the number of valid input tokens. Only rows [0, valid_tokens) are zeroed; every atomic add targets a source index below valid_tokens, so rows at or beyond it are never written.

  • out – optional pre-allocated [num_tokens, hidden_dim] fp32 output buffer. Pass a symmetric memory tensor to scatter directly into it, avoiding a separate copy before RSV. If None, a local buffer is allocated.
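The scatter-and-accumulate contract can likewise be modeled in pure Python. This is a sketch: unpermute_tokens_reference is a hypothetical name, the atomic adds become ordinary += in a sequential loop, and Python floats are float64, so it illustrates the indexing and gating rather than the kernel's fp32 rounding.

```python
def unpermute_tokens_reference(expert_output, permuted_probs, permutation_map,
                               num_tokens, n_used, valid_tokens, out=None):
    """Pure-Python model of unpermute_tokens (hypothetical helper)."""
    hidden_dim = len(expert_output[0])
    if out is None:
        out = [[0.0] * hidden_dim for _ in range(num_tokens)]
    # Zero only rows [0, valid_tokens); later rows keep whatever they held.
    for t in range(valid_tokens):
        out[t] = [0.0] * hidden_dim
    for row in range(n_used):                # rows >= n_used are never read
        src = permutation_map[row]
        if src < 0:                          # alignment padding row
            continue
        weight = permuted_probs[row]
        for h in range(hidden_dim):
            # The kernel performs this accumulation with fp32 atomic adds.
            out[src][h] += weight * expert_output[row][h]
    return out


pmap = [0, 2, -1, -1, 0, 1, -1, -1]
probs = [0.6, 0.5, 0.0, 0.0, 0.4, 0.7, 0.0, 0.0]
expert_out = [[float(r), 1.0] for r in range(8)]
stale = [[9.0, 9.0] for _ in range(4)]       # pre-allocated buffer, num_tokens=4
result = unpermute_tokens_reference(expert_out, probs, pmap, num_tokens=4,
                                    n_used=8, valid_tokens=3, out=stale)
print(result[3])   # [9.0, 9.0]: row 3 is beyond valid_tokens, never touched
```

Token 0 was routed to two local experts (rows 0 and 4), so its output row is the probability-weighted sum of both expert outputs.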

core.inference.moe.permute._permute_quantize_mxfp8_kernel(
hidden_ptr,
probs_ptr,
routing_map_ptr,
out_fp8_ptr,
out_scale_ptr,
out_probs_ptr,
out_src_idx_ptr,
counters_ptr,
valid_tokens_ptr,
K,
n_col_blocks,
max_pairs,
topk: triton.language.constexpr,
local_expert_start,
num_local_experts: triton.language.constexpr,
REAL_GROUPS: triton.language.constexpr,
BLOCK_K: triton.language.constexpr,
BLOCK_GROUPS: triton.language.constexpr,
NUM_BLOCKS: triton.language.constexpr,
)

Fused permute + MXFP8 quantize + swizzle in one kernel.

Grid: fixed NUM_BLOCKS CTAs, each iterating over multiple (token, topk) pairs. Because the grid size never changes across steps (as CUDA graph compatibility requires), valid_tokens gates which pairs are actually processed.
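In MX formats, blocks of 32 consecutive elements share one power-of-two scale stored as an 8-bit exponent (E8M0). The sketch below follows the scale rule in the OCP Microscaling (MX) v1.0 specification for FP8 E4M3 elements. Several things are assumptions: that this kernel uses E4M3 with exactly that rule, the block size along K, and the helper names, which are hypothetical. It also omits the final E4M3 rounding and the scale swizzle.

```python
import math

MX_BLOCK = 32        # elements that share one scale in the MX formats
E4M3_MAX = 448.0     # largest finite float8 e4m3fn value
E4M3_EMAX = 8        # exponent of the largest power of two below E4M3_MAX


def mxfp8_block_scale(block):
    """Return (biased E8M0 exponent, scaled values) for one block (hypothetical)."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        shared_exp = -127                    # all-zero block: smallest scale
    else:
        _, e = math.frexp(amax)              # amax = m * 2**e, 0.5 <= m < 1
        shared_exp = (e - 1) - E4M3_EMAX     # floor(log2(amax)) - emax
    shared_exp = max(-127, min(127, shared_exp))
    scale = 2.0 ** shared_exp
    # Saturate to the e4m3 range; a real cast would also round to e4m3.
    scaled = [max(-E4M3_MAX, min(E4M3_MAX, v / scale)) for v in block]
    return shared_exp + 127, scaled


def mxfp8_row_scales(row):
    """Biased E8M0 scale exponents for one row of length K (K % 32 == 0)."""
    assert len(row) % MX_BLOCK == 0
    return [mxfp8_block_scale(row[i:i + MX_BLOCK])[0]
            for i in range(0, len(row), MX_BLOCK)]


row = [1.0] * 32 + [3.9] * 32
print(mxfp8_row_scales(row))                 # [119, 120]
print(mxfp8_block_scale(row[32:])[1][0])     # 448.0 (499.2 saturated)
```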

core.inference.moe.permute.permute_and_quantize_mxfp8(
hidden_states: torch.Tensor,
probs: torch.Tensor,
routing_map: torch.Tensor,
local_expert_start: int,
num_local_experts: int,
valid_tokens: torch.Tensor,
alignment: int = 128,
) → tuple

Fused permute + MXFP8 quantize + swizzle.

Self-contained API matching permute_tokens: computes token counts, aligned expert offsets, output sizing, permutation, and MXFP8 quantization in a single call, with permutation, quantization, and scale swizzling fused into one kernel.

Parameters:
  • hidden_states – [max_tokens, hidden_size] BF16 input. Only the first valid_tokens rows are valid; the rest are ignored.

  • probs – [max_tokens, topk] routing probabilities.

  • routing_map – [max_tokens, topk] expert assignments.

  • local_expert_start – first global expert index on this rank.

  • num_local_experts – number of experts on this rank.

  • valid_tokens – scalar int32 CUDA tensor holding the number of valid tokens this iteration. Its address stays fixed; its value is updated each step before graph replay.

  • alignment – per-expert token alignment (default 128, required for MXFP8 swizzle).

Returns:

(permuted_mxfp8, permuted_probs, permutation_map, inclusive_offsets)

  • permuted_mxfp8: MXFP8Tensor with .data [output_size, K] and .scale (swizzled)

  • permuted_probs: [output_size] routing probs

  • permutation_map: [output_size] int32, original token index or -1 for padding

  • inclusive_offsets: [num_local_experts] int32 cumulative offsets for scaled_grouped_mm