core.inference.quantization.mxfp8_quantize#
Standalone MXFP8 quantization kernel with fused scale swizzle.
One block per token. Quantizes BF16 → FP8 e4m3 and writes scales directly in cuBLAS 2D blocked (swizzled) layout. No FP4, no triton_kernels dependency.
Usage:

    from megatron.core.inference.quantization.mxfp8_quantize import mxfp8_quantize

    data, swizzled_scales = mxfp8_quantize(x_bf16)
    # data: [M, K] float8_e4m3fn
    # swizzled_scales: 1D uint8 in cuBLAS blocked layout
Module Contents#
Functions#
| Function | Description |
|---|---|
| `_ceil_div` | Ceiling division helper. |
| `_mxfp8_quant_swizzle_kernel` | Each Triton block quantizes one row → FP8 e4m3 and writes scales directly in swizzled layout. |
| `mxfp8_quantize` | Quantize a 2D tensor to MXFP8 with fused scale swizzle. |
API#
- core.inference.quantization.mxfp8_quantize._ceil_div(a, b)#
- core.inference.quantization.mxfp8_quantize._mxfp8_quant_swizzle_kernel(out_ptr, scale_ptr, src_ptr, K, n_col_blocks, REAL_GROUPS: triton.language.constexpr, BLOCK_K: triton.language.constexpr, BLOCK_GROUPS: triton.language.constexpr)#
Each Triton block quantizes one row to FP8 e4m3 and writes its scales directly in swizzled layout.

The scale calculation rounds the exponent up; see Mishra et al., "Recipes for Pre-training LLMs with MXFP8" (https://arxiv.org/pdf/2506.08027).

The implementation borrows code from the upstream Triton MXFP downcast kernel: https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py
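As an illustrative, pure-Python sketch of the round-up rule (assuming the float8_e4m3fn maximum of 448; the function name here is ours, not part of the module):

```python
import math

E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def e8m0_scale_round_up(values):
    """Shared e8m0 scale for one group of elements: round the exponent UP so
    that amax / scale stays within the e4m3 range (illustrative sketch)."""
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        return 1.0  # all-zero group: any scale works; use 2^0
    exp = math.ceil(math.log2(amax / E4M3_MAX))
    exp = max(-127, min(127, exp))  # clamp to the e8m0 exponent range
    return 2.0 ** exp
```

For example, a group with amax = 1000 gets scale 2^2 = 4, so its largest quantized magnitude is 250 ≤ 448; flooring the exponent instead would give scale 2 and a value of 500, which overflows e4m3.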
Note on swizzled scale layout (torch.nn.functional.SwizzleType.SWIZZLE_32_4_4):
Background: In MXFP8, every group of 32 elements shares one 1-byte scale (an e8m0 exponent). For an [M, K] matrix, this gives an [M, K//32] scale matrix. cuBLAS doesn't read these scales in simple row-major order; it expects a "swizzled" layout optimized for its internal access patterns.

Step 1 — Divide into macro-tiles: the scale matrix is partitioned into 128-row × 4-column macro-tiles. Each tile is stored as a contiguous 512-byte (128 × 4) block.

Step 2 — Interleave within each tile: within a macro-tile, the 128 rows are NOT stored sequentially. Instead, they are split into 4 groups of 32 rows:

    group 0: rows   0- 31
    group 1: rows  32- 63
    group 2: rows  64- 95
    group 3: rows  96-127

Rows with the same position within their group (the same "sub_row") are placed next to each other. Concretely, for sub_row=0 the memory layout is:

    byte  0: row  0, col 0
    byte  1: row  0, col 1
    byte  2: row  0, col 2
    byte  3: row  0, col 3
    byte  4: row 32, col 0
    byte  5: row 32, col 1
    byte  6: row 32, col 2
    byte  7: row 32, col 3
    byte  8: row 64, col 0
    ...
    byte 15: row 96, col 3

The formula mapping logical (row, col) → byte offset:

    tile_idx  = (row // 128) * n_col_blocks + (col // 4)
    sub_row   = row % 32
    group     = (row % 128) // 32
    local_col = col % 4
    offset    = tile_idx * 512 + sub_row * 16 + group * 4 + local_col
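The mapping above can be written directly in Python (a standalone illustration; `swizzle_offset` is our name for it, not a function exported by the module):

```python
def swizzle_offset(row: int, col: int, n_col_blocks: int) -> int:
    """Map a logical (row, col) in the [M, K//32] scale matrix to its byte
    offset in the cuBLAS swizzled layout, per the formula above."""
    tile_idx = (row // 128) * n_col_blocks + (col // 4)  # which 128x4 macro-tile
    sub_row = row % 32                                   # position within a 32-row group
    group = (row % 128) // 32                            # which of the 4 row groups
    local_col = col % 4                                  # column within the macro-tile
    return tile_idx * 512 + sub_row * 16 + group * 4 + local_col
```

This reproduces the byte table above: `swizzle_offset(32, 1, 1)` is 5 and `swizzle_offset(96, 3, 1)` is 15.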
- core.inference.quantization.mxfp8_quantize.mxfp8_quantize(x: torch.Tensor) tuple[torch.Tensor, torch.Tensor]#
Quantize a 2D tensor to MXFP8 with fused scale swizzle.
- Parameters:
x – [M, K] tensor in bf16/fp16/fp32. K must be divisible by 32.
- Returns:
data: [M, K] float8_e4m3fn
swizzled_scales: 1D tensor in cuBLAS blocked layout (uint8/e8m0)
- Return type:
(data, swizzled_scales)
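For debugging, the swizzled buffer can be mapped back to the logical scale matrix by inverting the documented offset formula. A pure-Python sketch (hypothetical helper, not part of this module):

```python
def deswizzle_scales(swizzled, M, K):
    """Recover the logical [M, K // 32] scale matrix from a cuBLAS swizzled
    byte buffer by inverting the documented offset formula (sketch only)."""
    n_scale_cols = K // 32
    n_col_blocks = (n_scale_cols + 3) // 4  # 4-column macro-tiles per tile row
    out = [[0] * n_scale_cols for _ in range(M)]
    for row in range(M):
        for col in range(n_scale_cols):
            tile_idx = (row // 128) * n_col_blocks + (col // 4)
            offset = (tile_idx * 512 + (row % 32) * 16
                      + ((row % 128) // 32) * 4 + (col % 4))
            out[row][col] = swizzled[offset]
    return out
```

Each entry is the raw e8m0 exponent byte; the represented scale is 2^(byte - 127).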