core.inference.moe.pad#
Pad / unpad utilities for already-permuted expert tokens.
When the token dispatcher has already permuted tokens into expert-grouped order, these functions insert/remove alignment padding so that each expert’s token block satisfies the alignment requirements of grouped_mm / scaled_grouped_mm.
Module Contents#
Functions#
| Function | Description |
|---|---|
| `_pad_tokens_kernel` | Copy one input row into the padded output buffer. |
| `pad_to_alignment` | Pad already-permuted tokens so each expert’s block is aligned. |
| `_unpad_tokens_kernel` | Copy one real (non-padding) row from padded to unpadded layout. |
| `unpad_from_alignment` | Remove alignment padding, scattering results back to original positions. |
API#
- core.inference.moe.pad._pad_tokens_kernel(
- src_ptr,
- dst_ptr,
- perm_map_ptr,
- tpe_ptr,
- hidden_dim,
- num_experts: triton.language.constexpr,
- alignment: triton.language.constexpr,
- BLOCK_H: triton.language.constexpr,
- )
Copy one input row into the padded output buffer.
Computes unpadded and padded cumulative offsets inline from tokens_per_expert, avoiding a separate cumsum kernel launch.
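The inline offset computation can be illustrated with a small pure-Python sketch: for expert `e`, the unpadded offset is the cumulative sum of `tokens_per_expert[:e]`, and the padded offset is the cumulative sum with each count rounded up to `alignment`. The function name and variables here are illustrative, not the kernel's actual symbols.

```python
def inline_offsets(tokens_per_expert, alignment):
    """Illustrative mirror of the per-expert offsets the kernel derives
    inline, avoiding a separate cumsum kernel launch."""
    unpadded_off, padded_off = [], []
    u = p = 0
    for count in tokens_per_expert:
        unpadded_off.append(u)
        padded_off.append(p)
        u += count
        p += -(-count // alignment) * alignment  # round count up to alignment
    return unpadded_off, padded_off

# Example: counts [3, 5, 2] with alignment 4
# unpadded offsets -> [0, 3, 8]; padded offsets -> [0, 4, 12]
```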
- core.inference.moe.pad.pad_to_alignment(
- hidden_states: torch.Tensor,
- tokens_per_expert: torch.Tensor,
- alignment: int,
- )
Pad already-permuted tokens so each expert’s block is aligned.
- Parameters:
hidden_states – [total_tokens, hidden_size] already permuted by dispatcher.
tokens_per_expert – [num_local_experts] int32 token counts.
alignment – per-expert alignment.
- Returns:
(padded_hidden, permutation_map, inclusive_offsets)
padded_hidden: [padded_total, hidden_size]
permutation_map: [padded_total] int32, original row index or -1 for padding.
inclusive_offsets: [num_local_experts] int32 cumulative aligned offsets for grouped_mm.
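A pure-Python reference of the documented contract (padded rows, `-1` sentinel in the permutation map, inclusive aligned offsets) may help when validating the Triton path. This is a sketch of the documented semantics, not the actual implementation; rows are plain lists rather than tensors.

```python
def pad_to_alignment_ref(hidden_states, tokens_per_expert, alignment):
    """Reference sketch of pad_to_alignment's documented contract.
    hidden_states: list of rows, already permuted into expert-grouped order."""
    padded_hidden, permutation_map, inclusive_offsets = [], [], []
    hidden_size = len(hidden_states[0])
    src = 0   # next real row to copy
    total = 0 # running padded total
    for count in tokens_per_expert:
        aligned = -(-count // alignment) * alignment  # ceil to alignment
        for i in range(aligned):
            if i < count:
                padded_hidden.append(hidden_states[src])
                permutation_map.append(src)  # original row index
                src += 1
            else:
                padded_hidden.append([0.0] * hidden_size)
                permutation_map.append(-1)   # -1 marks a padding row
        total += aligned
        inclusive_offsets.append(total)      # inclusive cumulative aligned offset
    return padded_hidden, permutation_map, inclusive_offsets
```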
- core.inference.moe.pad._unpad_tokens_kernel(
- src_ptr,
- dst_ptr,
- perm_map_ptr,
- probs_ptr,
- hidden_dim,
- has_probs: triton.language.constexpr,
- BLOCK_H: triton.language.constexpr,
- )
Copy one real (non-padding) row from padded to unpadded layout.
Optionally multiplies each row by its routing probability.
- core.inference.moe.pad.unpad_from_alignment(
- padded_output: torch.Tensor,
- permutation_map: torch.Tensor,
- original_size: int,
- probs: torch.Tensor = None,
- )
Remove alignment padding, scattering results back to original positions.
- Parameters:
padded_output – [padded_total, hidden_size] output from expert computation.
permutation_map – [padded_total] int32, original row index or -1 for padding.
original_size – number of rows in the unpadded output.
probs – optional [original_size] routing probabilities to multiply during unpad.
- Returns:
[original_size, hidden_size] unpadded output.
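The unpad side is a scatter keyed on the permutation map: padding rows (`-1`) are skipped, and each real row is optionally scaled by its routing probability before landing at its original index. The sketch below mirrors that documented behavior in pure Python; it is illustrative, not the Triton kernel.

```python
def unpad_from_alignment_ref(padded_output, permutation_map,
                             original_size, probs=None):
    """Reference sketch of unpad_from_alignment's documented contract."""
    hidden_size = len(padded_output[0])
    out = [[0.0] * hidden_size for _ in range(original_size)]
    for row, dst in zip(padded_output, permutation_map):
        if dst == -1:
            continue  # skip alignment-padding rows
        scale = probs[dst] if probs is not None else 1.0
        out[dst] = [v * scale for v in row]  # scatter back, optionally scaled
    return out
```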