core.inference.quantization.mxfp8_quantize#

Standalone MXFP8 quantization kernel with fused scale swizzle.

One block per token. Quantizes BF16 → FP8 e4m3 and writes scales directly in cuBLAS 2D blocked (swizzled) layout. No FP4, no triton_kernels dependency.

Usage:

```python
from megatron.core.inference.quantization.mxfp8_quantize import mxfp8_quantize

data, swizzled_scales = mxfp8_quantize(x_bf16)
# data: [M, K] float8_e4m3fn
# swizzled_scales: 1D uint8 in cuBLAS blocked layout
```

Module Contents#

Functions#

_ceil_div

_mxfp8_quant_swizzle_kernel

Each Triton block quantizes one row to FP8 e4m3 and writes its scales directly in the swizzled layout.

mxfp8_quantize

Quantize a 2D tensor to MXFP8 with fused scale swizzle.

API#

core.inference.quantization.mxfp8_quantize._ceil_div(a, b)#
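
`_ceil_div` carries no docstring here; a minimal sketch of the standard integer ceiling-division idiom it presumably implements (the body below is an assumption, not copied from the module):

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division: the smallest n with n * b >= a (for b > 0).

    Sketch of the usual (a + b - 1) // b idiom; the actual _ceil_div
    body in the module may differ.
    """
    return (a + b - 1) // b
```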
core.inference.quantization.mxfp8_quantize._mxfp8_quant_swizzle_kernel(
out_ptr,
scale_ptr,
src_ptr,
K,
n_col_blocks,
REAL_GROUPS: triton.language.constexpr,
BLOCK_K: triton.language.constexpr,
BLOCK_GROUPS: triton.language.constexpr,
)#

Each Triton block quantizes one row to FP8 e4m3 and writes its scales directly in the swizzled layout.

We round up in the scale calculation; see Mishra et al., Recipes for Pre-training LLMs with MXFP8 (https://arxiv.org/pdf/2506.08027).

The implementation borrows code from the upstream Triton MXFP downcast kernel: https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
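
To illustrate the round-up rule, here is a hedged pure-Python sketch of choosing a per-group e8m0 scale exponent so the group's absolute maximum fits in e4m3 range. The function name and the log2 formulation are assumptions for illustration; the Triton kernel works on exponent bits directly and also handles e8m0 range clamping, which is omitted here:

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn magnitude


def e8m0_scale_exponent(group_amax: float) -> int:
    """Round-up shared exponent e so that group_amax / 2**e <= FP8_E4M3_MAX.

    Rounding up (ceil) guarantees the scaled values never overflow e4m3;
    rounding down could leave group_amax / 2**e above 448.
    """
    if group_amax == 0.0:
        # Zero group: any scale works; real kernels special-case this.
        return 0
    return math.ceil(math.log2(group_amax / FP8_E4M3_MAX))
```

For example, a group with amax 1.0 gets exponent -8 (1.0 / 2**-8 = 256 ≤ 448), whereas rounding down would give -9 and an out-of-range 512.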

Note on swizzled scale layout (torch.nn.functional.SwizzleType.SWIZZLE_32_4_4):

Background: In MXFP8, every group of 32 elements shares one 1-byte scale
(an e8m0 exponent). For an [M, K] matrix, this gives an [M, K//32] scale
matrix. cuBLAS doesn't read these scales in simple row-major order — it
expects a "swizzled" layout optimized for its internal access patterns.

Step 1 — Divide into macro-tiles:
    The scale matrix is partitioned into 128-row x 4-col macro-tiles.
    Each tile is stored as a contiguous 512-byte (128 x 4) block.

Step 2 — Interleave within each tile:
    Within a macro-tile, the 128 rows are NOT stored sequentially.
    Instead, they are split into 4 groups of 32 rows:
        group 0: rows   0- 31
        group 1: rows  32- 63
        group 2: rows  64- 95
        group 3: rows  96-127

    Rows with the same position within their group (the same "sub_row")
    are stored next to each other, interleaved across the 4 groups.

    Concretely, for sub_row=0:
        byte 0:  row  0, col 0
        byte 1:  row  0, col 1
        byte 2:  row  0, col 2
        byte 3:  row  0, col 3
        byte 4:  row 32, col 0
        byte 5:  row 32, col 1
        byte 6:  row 32, col 2
        byte 7:  row 32, col 3
        byte 8:  row 64, col 0
        ...
        byte 15: row 96, col 3

The formula to map logical (row, col) → byte offset:
    tile_idx = (row // 128) * n_col_blocks + (col // 4)
    sub_row  = row % 32
    group    = (row % 128) // 32
    local_col = col % 4
    offset   = tile_idx * 512 + sub_row * 16 + group * 4 + local_col
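
The mapping above can be transcribed directly as a pure-Python reference (the helper name is hypothetical; the actual kernel computes these offsets vectorized in Triton):

```python
def swizzled_scale_offset(row: int, col: int, n_col_blocks: int) -> int:
    """Map a logical (row, col) in the [M, K//32] scale matrix to its
    byte offset in the cuBLAS SWIZZLE_32_4_4 blocked layout.

    n_col_blocks is the number of 4-column macro-tiles per row of tiles,
    i.e. ceil((K // 32) / 4).
    """
    tile_idx = (row // 128) * n_col_blocks + (col // 4)  # which 128x4 macro-tile
    sub_row = row % 32            # position within the 32-row group
    group = (row % 128) // 32     # which of the 4 groups in the tile
    local_col = col % 4           # column within the macro-tile
    return tile_idx * 512 + sub_row * 16 + group * 4 + local_col
```

Running this for sub_row=0 reproduces the byte table above: (row 0, col 0) → byte 0, (row 32, col 0) → byte 4, (row 96, col 3) → byte 15.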
core.inference.quantization.mxfp8_quantize.mxfp8_quantize(x: torch.Tensor) → tuple[torch.Tensor, torch.Tensor]#

Quantize a 2D tensor to MXFP8 with fused scale swizzle.

Parameters:

x – [M, K] tensor in bf16/fp16/fp32. K must be divisible by 32.

Returns:

data: [M, K] float8_e4m3fn
swizzled_scales: 1D tensor in cuBLAS blocked layout (uint8 holding e8m0 exponents)

Return type:

(data, swizzled_scales)