core.inference.moe.pad#

Pad / unpad utilities for already-permuted expert tokens.

The token dispatcher has already permuted tokens into expert-grouped order before these functions run. Given that layout, they insert or remove alignment padding so that each expert’s token block satisfies the alignment requirements of grouped_mm / scaled_grouped_mm.
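The padding rule can be sketched in plain Python (an illustration only; the helper name below is not part of this module). Each expert’s token count is rounded up to the nearest multiple of the alignment, and the extra rows are padding that carry no real tokens:

```python
def aligned_counts(tokens_per_expert, alignment):
    """Round each expert's token count up to a multiple of `alignment`."""
    return [((n + alignment - 1) // alignment) * alignment for n in tokens_per_expert]

counts = [5, 0, 3]                 # tokens routed to each of 3 local experts
padded = aligned_counts(counts, alignment=4)
print(padded)                      # [8, 0, 4]: expert 0 gains 3 pad rows, expert 2 gains 1
print(sum(padded) - sum(counts))   # 4 padding rows in total
```

Note that an expert with zero tokens stays at zero rows; padding is only added to round a nonzero block up to the alignment.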

Module Contents#

Functions#

_pad_tokens_kernel

Copy one input row into the padded output buffer.

pad_to_alignment

Pad already-permuted tokens so each expert’s block is aligned.

_unpad_tokens_kernel

Copy one real (non-padding) row from padded to unpadded layout.

unpad_from_alignment

Remove alignment padding, scattering results back to original positions.

API#

core.inference.moe.pad._pad_tokens_kernel(
src_ptr,
dst_ptr,
perm_map_ptr,
tpe_ptr,
hidden_dim,
num_experts: triton.language.constexpr,
alignment: triton.language.constexpr,
BLOCK_H: triton.language.constexpr,
)#

Copy one input row into the padded output buffer.

Computes unpadded and padded cumulative offsets inline from tokens_per_expert, avoiding a separate cumsum kernel launch.
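The offset arithmetic can be mirrored in pure Python (an assumption-laden sketch of the kernel’s logic, not its actual Triton code): for a source row i, both the unpadded and padded block starts are accumulated expert by expert, and the destination row is i shifted by the padding inserted before its expert.

```python
def padded_row_index(i, tokens_per_expert, alignment):
    """Map unpadded row i to its row in the padded layout."""
    unpadded_off = 0   # start of the current expert's unpadded block
    padded_off = 0     # start of the current expert's padded block
    for n in tokens_per_expert:
        if i < unpadded_off + n:
            return padded_off + (i - unpadded_off)
        unpadded_off += n
        padded_off += ((n + alignment - 1) // alignment) * alignment
    raise IndexError(i)

# tokens_per_expert=[5, 3], alignment=4: expert 0 occupies padded rows 0..7,
# so expert 1's first token (unpadded row 5) lands at padded row 8.
print(padded_row_index(5, [5, 3], 4))  # 8
```

Because both running offsets are derived from tokens_per_expert inside the same loop, no separate cumsum pass over the expert counts is required.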

core.inference.moe.pad.pad_to_alignment(
hidden_states: torch.Tensor,
tokens_per_expert: torch.Tensor,
alignment: int,
) tuple#

Pad already-permuted tokens so each expert’s block is aligned.

Parameters:
  • hidden_states – [total_tokens, hidden_size] already permuted by dispatcher.

  • tokens_per_expert – [num_local_experts] int32 token counts.

  • alignment – per-expert alignment.

Returns:

(padded_hidden, permutation_map, inclusive_offsets)

  • padded_hidden: [padded_total, hidden_size]

  • permutation_map: [padded_total] int32, original row index or -1 for padding.

  • inclusive_offsets: [num_local_experts] int32 cumulative aligned offsets for grouped_mm.
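A pure-Python reference of the documented semantics (an illustration under stated assumptions: rows are plain lists rather than torch tensors, and the function name is hypothetical) shows how the three return values relate. Padding rows get permutation_map == -1 and zeroed data, and inclusive_offsets records the running aligned totals:

```python
def pad_to_alignment_ref(rows, tokens_per_expert, alignment):
    """Reference semantics of pad_to_alignment on list-of-list data."""
    hidden = len(rows[0]) if rows else 0
    padded_rows, perm_map, offsets = [], [], []
    src = 0        # next unpadded source row to copy
    running = 0    # inclusive cumulative aligned offset
    for n in tokens_per_expert:
        aligned = ((n + alignment - 1) // alignment) * alignment
        for j in range(aligned):
            if j < n:
                padded_rows.append(rows[src])
                perm_map.append(src)
                src += 1
            else:
                padded_rows.append([0.0] * hidden)
                perm_map.append(-1)  # marks an alignment padding row
        running += aligned
        offsets.append(running)
    return padded_rows, perm_map, offsets

rows = [[float(i)] * 2 for i in range(5)]  # 5 tokens, hidden_size=2
_, perm, offs = pad_to_alignment_ref(rows, tokens_per_expert=[3, 2], alignment=4)
print(perm)   # [0, 1, 2, -1, 3, 4, -1, -1]
print(offs)   # [4, 8]
```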

core.inference.moe.pad._unpad_tokens_kernel(
src_ptr,
dst_ptr,
perm_map_ptr,
probs_ptr,
hidden_dim,
has_probs: triton.language.constexpr,
BLOCK_H: triton.language.constexpr,
)#

Copy one real (non-padding) row from padded to unpadded layout.

Optionally multiplies each row by its routing probability.

core.inference.moe.pad.unpad_from_alignment(
padded_output: torch.Tensor,
permutation_map: torch.Tensor,
original_size: int,
probs: torch.Tensor = None,
) torch.Tensor#

Remove alignment padding, scattering results back to original positions.

Parameters:
  • padded_output – [padded_total, hidden_size] output from expert computation.

  • permutation_map – [padded_total] int32, original row index or -1 for padding.

  • original_size – number of rows in the unpadded output.

  • probs – optional [original_size] routing probabilities to multiply during unpad.

Returns:

[original_size, hidden_size] unpadded output.
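The unpad direction can be sketched the same way (illustrative names on plain lists; the real function operates on torch tensors): each non-padding padded row is scattered back to its original index via permutation_map, optionally scaled by its routing probability.

```python
def unpad_from_alignment_ref(padded_rows, perm_map, original_size, probs=None):
    """Reference semantics of unpad_from_alignment on list-of-list data."""
    hidden = len(padded_rows[0]) if padded_rows else 0
    out = [[0.0] * hidden for _ in range(original_size)]
    for row, dst in zip(padded_rows, perm_map):
        if dst == -1:
            continue  # skip alignment padding rows
        scale = probs[dst] if probs is not None else 1.0
        out[dst] = [v * scale for v in row]
    return out

# Round trip: one pad row (perm entry -1) is dropped, probs scale each row.
padded = [[1.0], [2.0], [0.0], [3.0]]
perm = [0, 1, -1, 2]
print(unpad_from_alignment_ref(padded, perm, 3, probs=[1.0, 0.5, 2.0]))
# [[1.0], [1.0], [6.0]]
```

Without probs, unpadding simply inverts the padded permutation, so pad followed by unpad recovers the original rows.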