core.inference.quantization.mxfp8_tensor#
Module Contents#
Classes#
MXFP8Tensor — MXFP8 tensor wrapper storing quantized fp8_e4m3 data and swizzled e8m0 scales.
Functions#
_ceil_div — Integer ceiling-division helper.
API#
- core.inference.quantization.mxfp8_tensor._ceil_div(a, b)#
- class core.inference.quantization.mxfp8_tensor.MXFP8Tensor#
MXFP8 tensor wrapper storing quantized fp8_e4m3 data and swizzled e8m0 scales.
- data: torch.Tensor#
None
- scale: torch.Tensor#
None
- backend: Optional[str]#
None
- size(idx: Optional[int] = None)#
Wrapper for calling self.data.size()
- scale_2d(K: Optional[int] = None) torch.Tensor#
Reshape 1D swizzled scale to 2D for scaled_grouped_mm / scaled_mm.
Swizzle pads rows to multiples of 128 and columns to multiples of 4. Returns a (padded_M, padded_cols) tensor, where padded_cols = _ceil_div(_ceil_div(K, 32), 4) * 4 (one e8m0 scale per 32-element group, then padded up to a multiple of 4).
- classmethod from_bf16(
- x: torch.Tensor,
- group_size: int = 32,
- backend: str = 'flashinfer',
- )#
Quantize BF16 tensor to MXFP8.
- Parameters:
x – [M, K] BF16 tensor on CUDA.
group_size – MXFP8 group size (default 32).
backend – 'triton' (fused quantize + swizzle Triton kernel) or 'flashinfer' (single fused FlashInfer CUDA kernel).
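To make the quantization step concrete, here is a simplified pure-Python sketch of per-group e8m0 scale selection. It is an illustration of the general MXFP8 idea (one power-of-two scale per 32-element group so scaled values fit the fp8_e4m3 range), not the exact rounding rule used by the Triton or FlashInfer kernels; `group_scale` and `FP8_E4M3_MAX` are names introduced here for the example:

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite value representable in fp8_e4m3

def group_scale(group):
    # Illustrative only: choose a power-of-two scale for one 32-element
    # group. e8m0 stores only an 8-bit exponent, so the shared scale must
    # be an exact power of two; dividing by it keeps |x| within fp8 range.
    amax = max(abs(v) for v in group)
    if amax == 0.0:
        return 1.0
    exp = math.ceil(math.log2(amax / FP8_E4M3_MAX))
    return 2.0 ** exp
```

A group whose amax is exactly 448 needs no rescaling (scale 1.0), while a group containing 1000.0 gets scale 4.0 so that 1000 / 4 = 250 fits in fp8_e4m3.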