core.inference.quantization.mxfp8_tensor#

Module Contents#

Classes#

MXFP8Tensor

MXFP8 tensor wrapper storing quantized fp8_e4m3 data and swizzled e8m0 scales.

Functions#

_ceil_div

Integer ceiling division helper.

API#

core.inference.quantization.mxfp8_tensor._ceil_div(a, b)#

Integer ceiling division: returns the smallest integer greater than or equal to a / b.
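The implementation body is not shown in this reference; a minimal sketch of what a helper with this name and signature presumably computes:

```python
def _ceil_div(a: int, b: int) -> int:
    # Smallest integer >= a / b, e.g. _ceil_div(1000, 32) == 32.
    return (a + b - 1) // b
```

This helper matches the padding arithmetic quoted in scale_2d below.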
class core.inference.quantization.mxfp8_tensor.MXFP8Tensor#

MXFP8 tensor wrapper storing quantized fp8_e4m3 data and swizzled e8m0 scales.

data: torch.Tensor#

None

scale: torch.Tensor#

None

backend: Optional[str]#

None
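A hedged sketch of the field layout these attributes imply; the manual construction, the default-constructibility of MXFP8Tensor, and the concrete dtypes (torch.float8_e4m3fn payload, e8m0 scales as raw uint8 exponent bytes) are all assumptions for illustration, not documented guarantees:

```python
import torch

# Hypothetical manual construction; real instances come from MXFP8Tensor.from_bf16.
M, K = 128, 1024
t = MXFP8Tensor()  # assumed default-constructible (fields default to None)
t.data = torch.empty(M, K, dtype=torch.float8_e4m3fn, device="cuda")  # quantized fp8_e4m3 payload
# One e8m0 scale per 32-element group along K, flattened into a 1D swizzled
# buffer padded to (multiple-of-128 rows) x (multiple-of-4 scale columns).
t.scale = torch.empty(M * (K // 32), dtype=torch.uint8, device="cuda")
t.backend = "flashinfer"
```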

size(idx: Optional[int] = None)#

Wrapper around self.data.size(): returns the full shape, or the size of dimension idx when one is given.

scale_2d(K: Optional[int] = None) torch.Tensor#

Reshape the 1D swizzled scale buffer to 2D for use with scaled_grouped_mm / scaled_mm.

The swizzle pads rows to a multiple of 128 and scale columns to a multiple of 4; the returned shape is (padded_M, padded_cols), where padded_cols = _ceil_div(K // 32, 4) * 4.
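A worked instance of this padding arithmetic, assuming _ceil_div is ordinary ceiling division (shapes here are illustrative):

```python
def _ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

M, K = 200, 1024
scale_cols = K // 32                        # one e8m0 scale per 32-element group -> 32
padded_M = _ceil_div(M, 128) * 128          # rows padded to a multiple of 128 -> 256
padded_cols = _ceil_div(scale_cols, 4) * 4  # columns padded to a multiple of 4 -> 32
# scale_2d(K=1024) would thus return a (256, 32) view of the 1D swizzled buffer.
```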

classmethod from_bf16(
x: torch.Tensor,
group_size: int = 32,
backend: str = 'flashinfer',
)#

Quantize BF16 tensor to MXFP8.

Parameters:
  • x – [M, K] BF16 tensor on CUDA.

  • group_size – MXFP8 group size (default 32).

  • backend – 'triton' (fused quantize + swizzle Triton kernel) or 'flashinfer' (single fused FlashInfer CUDA kernel).

Returns:
  An MXFP8Tensor holding the quantized fp8_e4m3 data and swizzled e8m0 scales.
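A hedged end-to-end sketch of the intended call pattern; the tensor shapes are illustrative, and the hand-off to scaled_mm / scaled_grouped_mm is inferred from the scale_2d docstring rather than shown in the source:

```python
import torch
from core.inference.quantization.mxfp8_tensor import MXFP8Tensor

x = torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda")

# Default backend is 'flashinfer'; pass backend='triton' for the fused Triton kernel.
t = MXFP8Tensor.from_bf16(x, group_size=32)

print(t.size())              # mirrors x.size(): torch.Size([256, 1024])
print(t.size(0))             # 256
scales = t.scale_2d(K=1024)  # (padded_M, padded_cols) layout for scaled_mm / scaled_grouped_mm
```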