core.inference.moe.activations#

Padding-aware activation kernels for fused MoE.

These kernels skip padding rows (where permutation_map == -1) to avoid wasted computation on aligned-but-empty expert slots.

Module Contents#

Functions#

_ceil_div

Integer ceiling division.

_squared_relu_kernel

Squared ReLU that skips padding rows (permutation_map == -1).

padded_squared_relu

Squared ReLU activation that skips padding rows.

_squared_relu_quantize_kernel

Fused squared ReLU + MXFP8 quantize + swizzle in one kernel.

squared_relu_and_quantize_mxfp8

Fused squared ReLU + MXFP8 quantize + swizzle.

API#

core.inference.moe.activations._ceil_div(a, b)#

Integer ceiling division.
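
The helper's behavior is presumably the standard integer ceiling-division idiom (a sketch for illustration, not the module's actual source):

```python
def ceil_div(a: int, b: int) -> int:
    """Smallest n such that n * b >= a, for positive integers."""
    return (a + b - 1) // b
```

In kernels like these, it typically sizes the launch grid, e.g. `ceil_div(N, BLOCK_N)` column blocks per row.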
core.inference.moe.activations._squared_relu_kernel(
input_ptr,
output_ptr,
src_idx_ptr,
M,
N,
BLOCK_N: triton.language.constexpr,
)#

Squared ReLU that skips padding rows (permutation_map == -1).

core.inference.moe.activations.padded_squared_relu(
x: torch.Tensor,
permutation_map: torch.Tensor,
) -> torch.Tensor#

Squared ReLU activation that skips padding rows.
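
The intended semantics can be pictured in plain Python (a sketch only; the real op is a Triton kernel over torch tensors, and it may leave padding rows unwritten rather than zero-filling them as done here):

```python
def padded_squared_relu_ref(x, permutation_map):
    """Elementwise max(v, 0)**2, skipping padding rows.

    A row is padding when permutation_map[row] == -1 (an
    aligned-but-empty expert slot); no activation is computed for it.
    """
    out = []
    for row, src in zip(x, permutation_map):
        if src == -1:
            out.append([0.0] * len(row))  # padding row: no work
        else:
            out.append([max(v, 0.0) ** 2 for v in row])
    return out
```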

core.inference.moe.activations._squared_relu_quantize_kernel(
input_ptr,
out_fp8_ptr,
out_scale_ptr,
src_idx_ptr,
K,
n_col_blocks,
skip_padding: triton.language.constexpr,
REAL_GROUPS: triton.language.constexpr,
BLOCK_K: triton.language.constexpr,
BLOCK_GROUPS: triton.language.constexpr,
)#

Fused squared ReLU + MXFP8 quantize + swizzle in one kernel.

Grid: (M,) — one program per row. Reads BF16 FC1 output, applies squared ReLU, quantizes to FP8, writes FP8 data + swizzled scales in place.
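
The per-program control flow can be sketched conceptually in plain Python (the real kernel is Triton and vectorizes over BLOCK_K-sized column chunks; the quantize/swizzle step is elided here):

```python
def fused_row_program(m, x, src_idx, skip_padding=True):
    """What one program instance (grid index m) conceptually does."""
    if skip_padding and src_idx[m] == -1:
        return None                            # padding row: exit early
    row = [max(v, 0.0) ** 2 for v in x[m]]     # squared ReLU
    # ...the real kernel then quantizes `row` to FP8 blocks and
    # writes the FP8 data plus swizzled e8m0 scales.
    return row
```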

core.inference.moe.activations.squared_relu_and_quantize_mxfp8(
x: torch.Tensor,
permutation_map: torch.Tensor,
skip_padding: bool = True,
)#

Fused squared ReLU + MXFP8 quantize + swizzle.

Reads the BF16 FC1 output, applies squared ReLU, and quantizes to FP8 with swizzled scales. A single kernel replaces the separate padded_squared_relu + mxfp8_quantize pair.

Parameters:
  • x – [M, K] BF16 FC1 output.

  • permutation_map – [M] int32, original token index or -1 for padding.

  • skip_padding – If True, skip rows where permutation_map == -1.

Returns:

MXFP8Tensor with .data [M, K] float8_e4m3fn and .scale (swizzled e8m0).
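
In the MX (Microscaling) convention, each group of 32 elements shares one power-of-two (e8m0) scale chosen so the scaled values fit in the e4m3 range. The sketch below illustrates one simple scale-selection convention in pure Python; it omits actual float8_e4m3fn rounding and the swizzled scale layout, and the block size of 32 is an assumption carried over from the MX spec:

```python
import math

FP8_E4M3_MAX = 448.0   # largest finite float8_e4m3fn magnitude
MX_BLOCK = 32          # assumed MX block size: one scale per 32 elements

def mx_block_scale(block):
    """Power-of-two (e8m0-style) scale so block / scale stays in e4m3 range."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0
    # smallest power of two with amax / scale <= FP8_E4M3_MAX
    return 2.0 ** math.ceil(math.log2(amax / FP8_E4M3_MAX))

def quantize_block(block):
    """Scale one block; real MXFP8 would also round each value to e4m3."""
    s = mx_block_scale(block)
    return [v / s for v in block], s
```

The fused kernel additionally writes the per-block scales in a swizzled layout expected by the downstream FC2 GEMM, which this sketch does not model.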