core.inference.moe.activations#
Padding-aware activation kernels for fused MoE.
These kernels skip padding rows (where permutation_map == -1) to avoid wasted computation on aligned-but-empty expert slots.
Module Contents#
Functions#
_squared_relu_kernel | Squared ReLU that skips padding rows (permutation_map == -1).
padded_squared_relu | Squared ReLU activation that skips padding rows.
_squared_relu_quantize_kernel | Fused squared ReLU + MXFP8 quantize + swizzle in one kernel.
squared_relu_and_quantize_mxfp8 | Fused squared ReLU + MXFP8 quantize + swizzle.
API#
- core.inference.moe.activations._ceil_div(a, b)#
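The helper's body is not shown here; a plausible sketch of the standard ceiling-division idiom it presumably implements (the name `ceil_div` below is local to this example):

```python
def ceil_div(a: int, b: int) -> int:
    # Ceiling division for positive integers: ceil(a / b) without floats.
    # Typically used to compute grid sizes, e.g. the number of BLOCK_N
    # tiles needed to cover N columns.
    return (a + b - 1) // b
```

For example, `ceil_div(10, 4)` is 3: three tiles of width 4 are needed to cover 10 columns.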
- core.inference.moe.activations._squared_relu_kernel(
- input_ptr,
- output_ptr,
- src_idx_ptr,
- M,
- N,
- BLOCK_N: triton.language.constexpr,
- )#
Squared ReLU that skips padding rows (permutation_map == -1).
- core.inference.moe.activations.padded_squared_relu(
- x: torch.Tensor,
- permutation_map: torch.Tensor,
- )#
Squared ReLU activation that skips padding rows.
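For reference, the semantics can be sketched in plain NumPy (an illustration, not the Triton implementation; zero-filling the padded rows is an assumption, since the kernel may simply leave them unwritten and downstream code ignores them):

```python
import numpy as np

def padded_squared_relu_ref(x: np.ndarray, permutation_map: np.ndarray) -> np.ndarray:
    """Reference semantics: squared ReLU on valid rows only.

    Rows where permutation_map == -1 are padding (aligned-but-empty
    expert slots) and are skipped. Writing zeros to them is an
    assumption made for this sketch.
    """
    out = np.zeros_like(x)
    valid = permutation_map != -1
    y = np.maximum(x[valid], 0.0)   # ReLU
    out[valid] = y * y              # squared ReLU
    return out
```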
- core.inference.moe.activations._squared_relu_quantize_kernel(
- input_ptr,
- out_fp8_ptr,
- out_scale_ptr,
- src_idx_ptr,
- K,
- n_col_blocks,
- skip_padding: triton.language.constexpr,
- REAL_GROUPS: triton.language.constexpr,
- BLOCK_K: triton.language.constexpr,
- BLOCK_GROUPS: triton.language.constexpr,
- )#
Fused squared ReLU + MXFP8 quantize + swizzle in one kernel.
Grid: (M,) — one program per row. Reads BF16 FC1 output, applies squared ReLU, quantizes to FP8, writes FP8 data + swizzled scales in place.
- core.inference.moe.activations.squared_relu_and_quantize_mxfp8(
- x: torch.Tensor,
- permutation_map: torch.Tensor,
- skip_padding: bool = True,
- )#
Fused squared ReLU + MXFP8 quantize + swizzle.
Reads the BF16 FC1 output, applies squared ReLU, and quantizes to FP8 with swizzled scales. A single kernel replaces the padded_squared_relu + mxfp8_quantize pair.
- Parameters:
x – [M, K] BF16 FC1 output.
permutation_map – [M] int32, original token index or -1 for padding.
skip_padding – if True, skip rows where permutation_map == -1.
- Returns:
MXFP8Tensor with .data [M, K] float8_e4m3fn and .scale (swizzled e8m0).
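The group-scaled quantization step can be illustrated with a NumPy sketch, assuming the standard MX block size of 32 elements per e8m0 (power-of-two) scale and the float8_e4m3fn maximum magnitude of 448; the actual FP8 rounding, in-place writes, and scale swizzle are omitted:

```python
import numpy as np

GROUP = 32          # assumed MX block size: 32 elements share one e8m0 scale
E4M3_MAX = 448.0    # max representable magnitude of float8_e4m3fn

def squared_relu_quantize_ref(x: np.ndarray):
    """Sketch of squared ReLU + per-group power-of-two scaling.

    Returns (scaled_values, scale_exponents). The real kernel emits
    float8_e4m3fn data and swizzled e8m0 scales; both the FP8 rounding
    and the swizzle layout are omitted here.
    """
    y = np.maximum(x, 0.0) ** 2                    # squared ReLU
    g = y.reshape(y.shape[0], -1, GROUP)           # [M, K/GROUP, GROUP]
    amax = np.abs(g).max(axis=-1, keepdims=True)
    # e8m0 scale: smallest power of two s such that amax / s <= E4M3_MAX
    exp = np.ceil(np.log2(np.maximum(amax, 1e-30) / E4M3_MAX))
    q = g / np.exp2(exp)                           # now within e4m3 range
    return q.reshape(y.shape), exp.squeeze(-1)
```

Each row's K elements are split into K/32 groups, and every value in a group is divided by that group's power-of-two scale before the (elided) cast to FP8.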