core.inference.quantization.utils
Module Contents
Functions
| Function | Summary |
|---|---|
| `quantize_model_to_mxfp8` | Convert TE MXFP8 weights to mcore MXFP8Tensor format. |
| `_should_quantize_param` | Return True if a parameter should be quantized to FlashInfer MXFP8. |
| `_to_bf16` | Convert a parameter value to BF16 for quantization. |
| `collect_mxfp8_param_metadata` | Record shape/dtype/device for each parameter that will be quantized. |
| `quantize_params_to_mxfp8` | Quantize model parameters to MXFP8Tensor format. |
| `_mm_mxfp8_flashinfer` | MXFP8 matmul via FlashInfer. |
| `_mm_mxfp8_torch` | MXFP8 matmul via torch.nn.functional.scaled_mm. |
| `mm_mxfp8` | Compute a matmul in MXFP8. |
API
- core.inference.quantization.utils._verify_te_to_mcore_mxfp8_conversion(
- te_dequantized,
- fi_quantized: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
- )
- core.inference.quantization.utils.quantize_model_to_mxfp8(
- model: torch.nn.Module,
- backend: str = 'flashinfer',
- )
Convert TE MXFP8 weights to mcore MXFP8Tensor format.
Recursively walks the model and replaces each TEMXFP8Tensor parameter with an MXFP8Tensor re-quantized via the specified backend.
- Parameters:
model – The model whose TE MXFP8 parameters should be converted.
backend – ‘flashinfer’ or ‘triton’ quantization backend.
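The recursive walk-and-replace this function performs can be sketched in plain Python. The `Module`, `TEMXFP8Param`, and `MXFP8Param` classes below are illustrative stand-ins, not the real torch/TE/mcore types:

```python
# Toy sketch of quantize_model_to_mxfp8's walk-and-replace pattern.
# The classes here are hypothetical stand-ins for torch.nn.Module,
# the TE MXFP8 parameter type, and the mcore MXFP8Tensor.

class Module:
    def __init__(self, **children):
        self.children = children  # name -> child Module
        self.params = {}          # name -> parameter object

class TEMXFP8Param:
    def __init__(self, data):
        self.data = data

class MXFP8Param:
    def __init__(self, data, backend):
        self.data, self.backend = data, backend

def convert(module, backend="flashinfer"):
    """Recursively replace every TEMXFP8Param with a re-quantized
    MXFP8Param produced by the requested backend."""
    for name, p in module.params.items():
        if isinstance(p, TEMXFP8Param):
            module.params[name] = MXFP8Param(p.data, backend)
    for child in module.children.values():
        convert(child, backend)
```

In the real function, the replacement step dequantizes the TE tensor and re-quantizes it via FlashInfer or Triton; the traversal shape is the same.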
- core.inference.quantization.utils._should_quantize_param(val: torch.Tensor) -> bool
Return True if a parameter should be quantized to FlashInfer MXFP8.
- core.inference.quantization.utils._to_bf16(val: torch.Tensor) -> torch.Tensor
Convert a parameter value to BF16 for quantization.
- core.inference.quantization.utils.collect_mxfp8_param_metadata(
- model: torch.nn.Module,
- )
Record shape/dtype/device for each parameter that will be quantized.
Called once before the first quantization to record the original parameter metadata (shape, dtype, device) before any format conversion.
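The bookkeeping amounts to a dict keyed by fully-qualified (dotted) parameter names. A minimal sketch, with names and shapes purely illustrative and plain tuples standing in for what `model.named_parameters()` would yield:

```python
# Hypothetical sketch: record the original shape/dtype/device of each
# to-be-quantized parameter before any format conversion, keyed by the
# parameter's dotted name.
def collect_metadata(named_params):
    """named_params: iterable of (name, shape, dtype, device) tuples,
    a stand-in for iterating a real model's parameters."""
    return {
        name: {"shape": shape, "dtype": dtype, "device": device}
        for name, shape, dtype, device in named_params
    }
```

Recording this before quantization matters because the MXFP8 representation changes dtype (and may change storage layout), so the original metadata is otherwise lost.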
- core.inference.quantization.utils.quantize_params_to_mxfp8(
- model: torch.nn.Module,
- persistent_buffers: Optional[Dict[str, megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor]] = None,
- _prefix: str = '',
- backend: str = 'flashinfer',
- )
Quantize model parameters to MXFP8Tensor format.
Handles both TEMXFP8Tensor (fp8_param=True) and BF16/FP16 nn.Parameter inputs. When persistent_buffers is provided, new quantized values are copy_()’d into the existing MXFP8Tensor objects so that CUDA-graph device-pointer captures remain valid.
- Parameters:
model – The model whose parameters should be quantized.
persistent_buffers – If not None, a dict mapping fully-qualified parameter names to previously-created MXFP8Tensor objects. Updated in-place and returned.
_prefix – Internal recursion prefix; callers should not set this.
backend – ‘flashinfer’ or ‘triton’ quantization backend.
- Returns:
The persistent_buffers dict (created on first call if None).
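The copy-into-existing-buffers requirement exists because a captured CUDA graph records raw device pointers: if re-quantization allocated a fresh tensor, the graph would keep reading stale memory. A pure-Python sketch of the pattern, with lists standing in for device tensors and `copy_into` mimicking `Tensor.copy_()` (names here are illustrative, not the module's API):

```python
# Sketch of the persistent-buffer update pattern. A CUDA graph capture
# holds the buffer's device pointer, so re-quantization must mutate the
# existing object rather than rebind the name to a new one.

def copy_into(dst, src):
    """In-place update: dst keeps its identity (its 'device pointer')."""
    dst[:] = src

def quantize(name, new_values, buffers):
    """Quantize one parameter into buffers, allocating only on first call."""
    if name in buffers:
        copy_into(buffers[name], new_values)   # CUDA-graph-safe path
    else:
        buffers[name] = list(new_values)       # first call: allocate
    return buffers
```

After the first call, any pointer captured from `buffers[name]` stays valid across subsequent re-quantizations, because only the contents change.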
- core.inference.quantization.utils._mm_mxfp8_flashinfer(
- x_mxfp8: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
- weight: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
- out=None,
- )
MXFP8 matmul via FlashInfer.
- core.inference.quantization.utils._mm_mxfp8_torch(
- x_mxfp8: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
- weight: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
- out=None,
- )
MXFP8 matmul via torch.nn.functional.scaled_mm.
- core.inference.quantization.utils.mm_mxfp8(
- x: torch.Tensor,
- weight: megatron.core.inference.quantization.mxfp8_tensor.MXFP8Tensor,
- out: torch.Tensor = None,
- )
Compute a matmul in MXFP8.
Quantizes the bf16 input activation tensor on the fly. Weight must be pre-quantized. Dispatches to FlashInfer or torch based on weight.backend.
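For intuition on the on-the-fly quantization step: per the OCP Microscaling format, MXFP8 stores each 32-element block with one shared power-of-two (E8M0) scale and FP8 E4M3 elements. The sketch below shows only the scale-selection arithmetic in pure Python; the real kernels live in FlashInfer/Triton, and rounding of elements to actual E4M3 codes is omitted:

```python
import math

BLOCK = 32           # MXFP8 shares one scale per 32-element block
E4M3_MAX = 448.0     # largest finite FP8-E4M3 magnitude

def block_scale(block):
    """Pick the power-of-two (E8M0-representable) scale so the largest
    |value| in the block fits into E4M3 range after scaling."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0
    return 2.0 ** math.ceil(math.log2(amax / E4M3_MAX))

def quantize_block(block):
    """Return (scaled elements, scale); E4M3 element rounding omitted."""
    s = block_scale(block)
    return [v / s for v in block], s

def dequantize_block(scaled, s):
    """Invert the scaling; exact here because s is a power of two."""
    return [v * s for v in scaled]
```

Because the scale is a power of two, scaling and unscaling are exact in binary floating point; the only lossy step in a real kernel is rounding each scaled element to its nearest E4M3 code.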