core.inference.moe.activations

Padding-aware activation kernels for fused MoE.

These kernels skip padding rows (where permutation_map == -1) to avoid wasted computation on aligned-but-empty expert slots.

Module Contents

Functions

  • _ceil_div

  • _squared_relu_kernel – Squared ReLU that skips rows beyond n_used and alignment-padding rows (perm_map == -1).

  • padded_squared_relu – Squared ReLU activation that skips rows beyond n_used and alignment-padding rows.

  • _squared_relu_quantize_kernel – Fused squared ReLU + MXFP8 quantize + swizzle in one kernel.

  • squared_relu_and_quantize_mxfp8 – Fused squared ReLU + MXFP8 quantize + swizzle.

API

core.inference.moe.activations._ceil_div(a, b)
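
The source gives no description for _ceil_div. A helper with this name conventionally returns the smallest integer not less than a / b, which is typically used to count how many fixed-size blocks cover a dimension. The sketch below assumes that behaviour and positive integer arguments; ceil_div is a hypothetical stand-in, not the module's implementation.

```python
def ceil_div(a: int, b: int) -> int:
    """Ceiling division for positive integers (assumed behaviour of _ceil_div)."""
    return (a + b - 1) // b


# Example: number of 128-wide column blocks needed to cover 300 columns.
print(ceil_div(300, 128))  # 3
```
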

core.inference.moe.activations._squared_relu_kernel(
    input_ptr,
    output_ptr,
    src_idx_ptr,
    n_used_ptr,
    N,
    max_rows,
    BLOCK_N: triton.language.constexpr,
    NUM_BLOCKS: triton.language.constexpr,
)

Squared ReLU that skips rows beyond n_used and alignment-padding rows (perm_map == -1).

Grid: fixed NUM_BLOCKS CTAs, each iterating over multiple rows. n_used_ptr gates how many rows are processed, so the launch shape never depends on the row count; this is required for CUDA graph compatibility.
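
The fixed-grid scheme described above can be simulated on the host in plain Python. This is only a sketch of the row-skipping logic: the strided assignment of rows to CTAs is an assumption (the kernel may partition rows differently), and rows_per_block is a hypothetical helper that does not exist in the module.

```python
def rows_per_block(num_blocks: int, max_rows: int, n_used: int,
                   perm_map: list[int]) -> list[list[int]]:
    """Return, for each simulated CTA, the rows it would actually process."""
    processed = [[] for _ in range(num_blocks)]
    for pid in range(num_blocks):                     # one entry per CTA in the fixed grid
        for row in range(pid, max_rows, num_blocks):  # assumed strided walk over all rows
            if row >= n_used:                         # beyond the rows in use
                continue
            if perm_map[row] == -1:                   # alignment padding inside the used range
                continue
            processed[pid].append(row)
    return processed


perm_map = [4, 0, -1, -1, 2, 1, 3, -1, -1, -1]
work = rows_per_block(num_blocks=3, max_rows=10, n_used=8, perm_map=perm_map)
print(work)  # [[0, 6], [1, 4], [5]]
```

The grid size (3 here) is the same whatever n_used is; only the amount of work done inside each CTA changes.
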

core.inference.moe.activations.padded_squared_relu(
    x: torch.Tensor,
    permutation_map: torch.Tensor,
    n_used: torch.Tensor,
) → torch.Tensor

Squared ReLU activation that skips rows beyond n_used and alignment-padding rows.

Parameters:
  • x – [output_size, ffn_hidden] BF16 FC1 output.

  • permutation_map – [output_size] int32, original token index or -1 for padding.

  • n_used – scalar int32 CUDA tensor equal to inclusive_expert_offsets[-1]; rows beyond it are skipped.
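
The intended result can be illustrated with a small reference written over plain Python lists instead of tensors. It is a sketch under two assumptions: squared ReLU means max(v, 0)² per element, and skipped rows are shown as zeros (the real kernel may leave them unwritten). padded_squared_relu_reference is a hypothetical name.

```python
def padded_squared_relu_reference(x, permutation_map, n_used):
    """Squared ReLU on valid rows; skipped rows are left as zeros in this sketch."""
    out = [[0.0] * len(row) for row in x]
    for r, row in enumerate(x):
        # Skip rows past the used range and alignment-padding rows.
        if r >= n_used or permutation_map[r] == -1:
            continue
        out[r] = [max(v, 0.0) ** 2 for v in row]
    return out


x = [[-1.0, 2.0], [3.0, -4.0], [5.0, 6.0], [7.0, 8.0]]
permutation_map = [1, -1, 0, 2]
result = padded_squared_relu_reference(x, permutation_map, n_used=3)
print(result)  # [[0.0, 4.0], [0.0, 0.0], [25.0, 36.0], [0.0, 0.0]]
```

Row 1 is skipped because its map entry is -1, and row 3 is skipped because it lies beyond n_used even though its map entry is valid.
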

core.inference.moe.activations._squared_relu_quantize_kernel(
    input_ptr,
    out_fp8_ptr,
    out_scale_ptr,
    src_idx_ptr,
    n_used_ptr,
    K,
    n_col_blocks,
    max_rows,
    REAL_GROUPS: triton.language.constexpr,
    BLOCK_K: triton.language.constexpr,
    BLOCK_GROUPS: triton.language.constexpr,
    NUM_BLOCKS: triton.language.constexpr,
)

Fused squared ReLU + MXFP8 quantize + swizzle in one kernel.

Grid: fixed NUM_BLOCKS CTAs, each iterating over multiple rows. Rows beyond n_used and alignment-padding rows (perm_map == -1) are skipped.

core.inference.moe.activations.squared_relu_and_quantize_mxfp8(
    x: torch.Tensor,
    permutation_map: torch.Tensor,
    n_used: torch.Tensor,
)

Fused squared ReLU + MXFP8 quantize + swizzle.

Reads the BF16 FC1 output, applies squared ReLU, and quantizes to FP8 with swizzled scales. A single kernel replaces the separate padded_squared_relu and mxfp8_quantize steps.

Parameters:
  • x – [output_size, K] BF16 FC1 output.

  • permutation_map – [output_size] int32, original token index or -1 for padding.

  • n_used – scalar int32 CUDA tensor = inclusive_expert_offsets[-1]. Rows beyond this are skipped before even checking the permutation_map.

Returns:

MXFP8Tensor with .data [output_size, K] float8_e4m3fn and .scale (swizzled e8m0).
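
The arithmetic behind the fused step can be illustrated in plain Python for one row. This is a sketch, not the kernel's algorithm: it assumes the MX convention of one power-of-two scale per group of 32 elements, picks the smallest scale that keeps each group within the float8_e4m3fn range (implementations differ in how they round the scale), and omits both the final cast to FP8 and the swizzled scale layout. squared_relu_group_scales is a hypothetical helper.

```python
import math

E4M3_MAX = 448.0   # largest finite float8_e4m3fn value
GROUP_SIZE = 32    # MX block size (assumed to match this kernel's grouping)


def squared_relu_group_scales(row):
    """Apply squared ReLU to one row, then scale each 32-wide group by a power of two."""
    act = [max(v, 0.0) ** 2 for v in row]
    exponents, scaled = [], []
    for start in range(0, len(act), GROUP_SIZE):
        group = act[start:start + GROUP_SIZE]
        amax = max(group)  # values are non-negative after squared ReLU
        # Smallest power-of-two scale that brings amax within the E4M3 range.
        exp = math.ceil(math.log2(amax / E4M3_MAX)) if amax > 0.0 else 0
        exponents.append(exp)  # an e8m0 scale would store this exponent with a bias of 127
        scaled.extend(v / 2.0 ** exp for v in group)  # cast to float8_e4m3fn omitted
    return exponents, scaled


row = [30.0] * 32 + [-1.0] * 16 + [0.5] * 16
exponents, scaled = squared_relu_group_scales(row)
print(exponents)  # [2, -10]
```

In the first group the activation 900 is divided by 2² to give 225; in the second, 0.25 is divided by 2⁻¹⁰ to give 256. Both fit below 448, so nothing is clipped before the cast.
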