core.inference.moe.activations
Padding-aware activation kernels for fused MoE.
These kernels skip padding rows (where permutation_map == -1) to avoid wasted computation on aligned-but-empty expert slots.
Module Contents
Functions
- _squared_relu_kernel – Squared ReLU that skips rows beyond n_used and alignment-padding rows (perm_map == -1).
- padded_squared_relu – Squared ReLU activation that skips rows beyond n_used and alignment-padding rows.
- _squared_relu_quantize_kernel – Fused squared ReLU + MXFP8 quantize + swizzle in one kernel.
- squared_relu_and_quantize_mxfp8 – Fused squared ReLU + MXFP8 quantize + swizzle.
API
- core.inference.moe.activations._ceil_div(a, b)
- core.inference.moe.activations._squared_relu_kernel(input_ptr, output_ptr, src_idx_ptr, n_used_ptr, N, max_rows, BLOCK_N: triton.language.constexpr, NUM_BLOCKS: triton.language.constexpr)
Squared ReLU that skips rows beyond n_used and alignment-padding rows (perm_map == -1).
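Grid: a fixed number of CTAs (NUM_BLOCKS), each iterating over multiple rows. The value at n_used_ptr gates how many rows are processed. This is required for CUDA graph compatibility: a captured graph replays the same launch configuration, so the grid size cannot depend on the number of active rows, and the kernel instead reads n_used from device memory at run time.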
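The row scheduling described above can be simulated in plain Python. This is a minimal sketch, assuming each of the NUM_BLOCKS CTAs walks the rows with a stride of NUM_BLOCKS (a grid-stride loop); the actual row-to-CTA assignment inside _squared_relu_kernel is not documented on this page, and the helper name rows_processed is hypothetical.

```python
def rows_processed(src_idx, n_used, num_blocks):
    """Simulate which rows a fixed-size grid of CTAs would process.

    Hypothetical helper: assumes a grid-stride loop in which CTA `pid`
    visits rows pid, pid + num_blocks, pid + 2 * num_blocks, ...
    src_idx plays the role of permutation_map (-1 marks padding).
    """
    max_rows = len(src_idx)
    processed = {}
    for pid in range(num_blocks):              # grid size is fixed, independent of n_used
        for row in range(pid, max_rows, num_blocks):
            if row >= n_used:                  # rows past n_used are never touched
                break
            if src_idx[row] == -1:             # alignment-padding row: skip it
                continue
            processed.setdefault(pid, []).append(row)
    return processed


# Same grid (num_blocks=2), two different n_used values.
perm = [3, 0, -1, -1, 1, 2, -1, -1]
print(rows_processed(perm, n_used=6, num_blocks=2))  # {0: [0, 4], 1: [1, 5]}
print(rows_processed(perm, n_used=2, num_blocks=2))  # {0: [0], 1: [1]}
```

Only n_used changes between the two calls; the number of CTAs stays the same, which is the property that lets one launch configuration serve different token counts.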
- core.inference.moe.activations.padded_squared_relu(x: torch.Tensor, permutation_map: torch.Tensor, n_used: torch.Tensor)
Squared ReLU activation that skips rows beyond n_used and alignment-padding rows.
- Parameters:
x – [output_size, ffn_hidden] BF16 FC1 output.
permutation_map – [output_size] int32, original token index or -1 for padding.
n_used – scalar int32 CUDA tensor equal to inclusive_expert_offsets[-1], i.e. the number of leading rows in use.
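As a reading aid, the row-level semantics of padded_squared_relu can be written as a pure-Python reference. This is a sketch under stated assumptions: rows are plain lists of floats rather than BF16 CUDA tensors, and skipped rows are returned as None because this page does not say what the kernel leaves in them. The name reference_padded_squared_relu is hypothetical.

```python
def reference_padded_squared_relu(x, permutation_map, n_used):
    """Pure-Python reference for squared ReLU with row skipping (hypothetical).

    x: list of rows (each a list of floats), one per expert slot.
    permutation_map: original token index per row, or -1 for padding.
    n_used: number of leading rows in use (inclusive_expert_offsets[-1]).
    Skipped rows are returned as None; their real contents are unspecified here.
    """
    out = []
    for row_idx, row in enumerate(x):
        if row_idx >= n_used or permutation_map[row_idx] == -1:
            out.append(None)                          # row is never computed
            continue
        out.append([max(v, 0.0) ** 2 for v in row])   # squared ReLU: max(v, 0)^2
    return out


x = [[-1.0, 2.0], [3.0, -0.5], [1.0, 1.0], [4.0, 4.0]]
print(reference_padded_squared_relu(x, [0, -1, 1, 2], n_used=3))
# [[0.0, 4.0], None, [1.0, 1.0], None]
```

Row 1 is skipped because it is an alignment-padding row, and row 3 is skipped because it lies beyond n_used even though its permutation_map entry is valid.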
- core.inference.moe.activations._squared_relu_quantize_kernel(input_ptr, out_fp8_ptr, out_scale_ptr, src_idx_ptr, n_used_ptr, K, n_col_blocks, max_rows, REAL_GROUPS: triton.language.constexpr, BLOCK_K: triton.language.constexpr, BLOCK_GROUPS: triton.language.constexpr, NUM_BLOCKS: triton.language.constexpr)
Fused squared ReLU + MXFP8 quantize + swizzle in one kernel.
Grid: fixed NUM_BLOCKS CTAs, each iterating over multiple rows. Rows beyond n_used and alignment-padding rows (perm_map == -1) are skipped.
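This page does not spell out the scale swizzle. One common target is the block-scaled layout that cuBLAS uses for MXFP8 GEMMs, in which the scale matrix (one scale per 32-element group of K) is tiled into blocks of 128 rows by 4 scale columns, each stored as 512 contiguous bytes. Assuming that layout, and assuming n_col_blocks counts 4-column tiles stored in row-major tile order, the destination offset of the scale at (row, col) could be computed as below. swizzled_scale_offset is a hypothetical helper, not part of this module.

```python
def swizzled_scale_offset(row, col, n_col_blocks):
    """Flat offset of scale (row, col) in an assumed 128x4-tiled swizzled layout.

    Assumptions: tiles of 128 rows x 4 scale columns, stored tile by tile in
    row-major tile order (n_col_blocks tiles per tile row), 512 scales per
    tile; inside a tile, row r and column c map to (r % 32) * 16 + (r // 32) * 4 + c.
    """
    tile_row, r = divmod(row, 128)      # which 128-row tile, row inside it
    tile_col, c = divmod(col, 4)        # which 4-column tile, column inside it
    tile_base = (tile_row * n_col_blocks + tile_col) * 512
    return tile_base + (r % 32) * 16 + (r // 32) * 4 + c


print(swizzled_scale_offset(1, 0, n_col_blocks=1))   # 16
print(swizzled_scale_offset(32, 0, n_col_blocks=1))  # 4
```

In this layout each 16-byte run holds the four scales of rows r, r + 32, r + 64 and r + 96, so the four 32-row groups of a tile are interleaved rather than stored one after another.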
- core.inference.moe.activations.squared_relu_and_quantize_mxfp8(x: torch.Tensor, permutation_map: torch.Tensor, n_used: torch.Tensor)
Fused squared ReLU + MXFP8 quantize + swizzle.
Reads the BF16 FC1 output, applies squared ReLU, and quantizes the result to FP8 with swizzled scales. This single kernel replaces the two-step sequence of padded_squared_relu followed by mxfp8_quantize.
- Parameters:
x – [output_size, K] BF16 FC1 output.
permutation_map – [output_size] int32, original token index or -1 for padding.
n_used – scalar int32 CUDA tensor equal to inclusive_expert_offsets[-1]. Rows past the first n_used are skipped before permutation_map is even checked.
- Returns:
MXFP8Tensor with .data [output_size, K] float8_e4m3fn and .scale (swizzled e8m0).
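For reference, the OCP Microscaling (MX) v1.0 specification defines MXFP8 as groups of 32 FP8 E4M3 elements sharing one E8M0 (power-of-two) scale, with the shared exponent computed as floor(log2(amax)) minus 8, the exponent of E4M3's largest normal value (448 = 1.75 * 2^8). The sketch below applies squared ReLU to one 32-element group and follows that recipe. It stops before the final cast to float8_e4m3fn, treats an all-zero group with a choice of its own, and the kernel may round the scale differently, so it illustrates the format rather than the kernel's exact arithmetic. relu2_mxfp8_group is a hypothetical name.

```python
import math

E4M3_EMAX = 8        # exponent of the largest normal E4M3 value, 448 = 1.75 * 2**8
E4M3_MAX = 448.0     # largest finite E4M3 magnitude
MX_BLOCK = 32        # elements sharing one E8M0 scale in MXFP8


def relu2_mxfp8_group(group):
    """Squared ReLU + MXFP8 scaling for one 32-element group (hypothetical sketch).

    Returns (e8m0_byte, scaled): the biased shared exponent and the values
    divided by 2**shared_exp, ready for a cast to float8_e4m3fn (cast omitted).
    """
    assert len(group) == MX_BLOCK
    act = [max(v, 0.0) ** 2 for v in group]          # squared ReLU, so act >= 0
    amax = max(act)
    if amax == 0.0:
        shared_exp = -127                            # this sketch's choice for an all-zero group
    else:
        _, e = math.frexp(amax)                      # amax = m * 2**e with 0.5 <= m < 1
        shared_exp = (e - 1) - E4M3_EMAX             # floor(log2(amax)) - emax_elem
    shared_exp = max(-127, min(127, shared_exp))     # E8M0 represents 2**-127 .. 2**127
    scale = 2.0 ** shared_exp
    scaled = [min(a / scale, E4M3_MAX) for a in act] # saturate to E4M3's max normal
    return shared_exp + 127, scaled                  # E8M0 stores the exponent with bias 127


byte, scaled = relu2_mxfp8_group([3.0] + [-1.0] * 31)
print(byte, scaled[0])  # 122 288.0
```

With this recipe the largest element of each group lands in [256, 512) after scaling, so values above 448 saturate; here 3.0 squares to 9.0, the shared exponent is 3 - 8 = -5, and 9.0 * 2**5 = 288.0.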